Skip to content

Commit 2e7ef98

Browse files
authored
ggml-cuda: add stricter checking for fusion (#17568)
* ggml-cuda: make conditions for fusion more explicit
* ggml-cuda: remove size check, as std::equal already performs it
1 parent ddf9f94 commit 2e7ef98

File tree

1 file changed

+15
-11
lines changed

1 file changed

+15
-11
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3050,7 +3050,12 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
30503050
std::initializer_list<enum ggml_op> topk_moe_ops_delayed_softmax =
30513051
ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true);
30523052

3053-
if (ops.size() == topk_moe_ops_with_norm.size() &&
3053+
const auto is_equal = [](const std::initializer_list<enum ggml_op> & list1,
3054+
const std::initializer_list<enum ggml_op> & list2) {
3055+
return std::equal(list1.begin(), list1.end(), list2.begin(), list2.end());
3056+
};
3057+
3058+
if (is_equal(topk_moe_ops_with_norm, ops) &&
30543059
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) {
30553060
ggml_tensor * softmax = cgraph->nodes[node_idx];
30563061
ggml_tensor * weights = cgraph->nodes[node_idx + 9];
@@ -3060,16 +3065,15 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
30603065
}
30613066
}
30623067

3063-
if (ops.size() == topk_moe_ops.size() &&
3064-
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
3068+
if (is_equal(topk_moe_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
30653069
ggml_tensor * softmax = cgraph->nodes[node_idx];
30663070
ggml_tensor * weights = cgraph->nodes[node_idx + 4];
30673071
if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
30683072
return true;
30693073
}
30703074
}
30713075

3072-
if (ops.size() == topk_moe_ops_delayed_softmax.size() &&
3076+
if (is_equal(topk_moe_ops_delayed_softmax, ops) &&
30733077
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) {
30743078
ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
30753079
ggml_tensor * weights = cgraph->nodes[node_idx + 5];
@@ -3085,9 +3089,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
30853089
std::initializer_list<enum ggml_op> mul_mat_id_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_MUL_MAT_ID, GGML_OP_GLU };
30863090
std::initializer_list<enum ggml_op> mul_mat_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT, GGML_OP_GLU };
30873091

3088-
if (ops.size() == 5 && (ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 4}) ||
3089-
ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 4}))) {
3090-
3092+
if ((is_equal(mul_mat_bias_glu_ops, ops) || is_equal(mul_mat_id_bias_glu_ops, ops)) &&
3093+
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 4 })) {
30913094
const ggml_tensor * ffn_gate = cgraph->nodes[node_idx];
30923095
const ggml_tensor * ffn_gate_bias = cgraph->nodes[node_idx + 1];
30933096
const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 2];
@@ -3099,9 +3102,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
30993102
}
31003103
}
31013104

3102-
if (ops.size() == 3 && (ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 2}) ||
3103-
ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 2}))) {
3104-
3105+
if ((is_equal(mul_mat_id_glu_ops, ops) || is_equal(mul_mat_glu_ops, ops)) &&
3106+
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) {
31053107
const ggml_tensor * ffn_gate = cgraph->nodes[node_idx];
31063108
const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 1];
31073109
const ggml_tensor * glu = cgraph->nodes[node_idx + 2];
@@ -3111,7 +3113,9 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
31113113
}
31123114
}
31133115

3114-
if (ops.size() == 3 && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) {
3116+
std::initializer_list<enum ggml_op> rope_set_rows_ops = { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS };
3117+
3118+
if (is_equal(rope_set_rows_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) {
31153119
const ggml_tensor * rope = cgraph->nodes[node_idx];
31163120
const ggml_tensor * view = cgraph->nodes[node_idx + 1];
31173121
const ggml_tensor * set_rows = cgraph->nodes[node_idx + 2];

0 commit comments

Comments
 (0)