Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3640,11 +3640,13 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
n_fuse++;

if (n_fuse > 1) {
+ ggml_tensor fused_add_node;
+ memcpy(&fused_add_node, node, sizeof(ggml_tensor));
Comment thread
ORippler marked this conversation as resolved.
  for (int j = 0; j < n_fuse - 1; ++j) {
- node->src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
+ fused_add_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
  }
- cgraph->nodes[i + n_fuse - 1]->data = node->data;
- ggml_cuda_op_fused_add(*cuda_ctx, node, n_fuse);
+ fused_add_node.data = cgraph->nodes[i + n_fuse - 1]->data;
+ ggml_cuda_op_fused_add(*cuda_ctx, &fused_add_node, n_fuse);
i += n_fuse - 1;

continue;
Expand Down
Loading