@@ -3050,7 +3050,12 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
30503050 std::initializer_list<enum ggml_op> topk_moe_ops_delayed_softmax =
30513051 ggml_cuda_topk_moe_ops (/* with_norm=*/ false , /* delayed_softmax=*/ true );
30523052
3053- if (ops.size () == topk_moe_ops_with_norm.size () &&
3053+ const auto is_equal = [](const std::initializer_list<enum ggml_op> & list1,
3054+ const std::initializer_list<enum ggml_op> & list2) {
3055+ return std::equal (list1.begin (), list1.end (), list2.begin (), list2.end ());
3056+ };
3057+
3058+ if (is_equal (topk_moe_ops_with_norm, ops) &&
30543059 ggml_can_fuse_subgraph (cgraph, node_idx, ops, { node_idx + 3 , node_idx + 9 })) {
30553060 ggml_tensor * softmax = cgraph->nodes [node_idx];
30563061 ggml_tensor * weights = cgraph->nodes [node_idx + 9 ];
@@ -3060,16 +3065,15 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
30603065 }
30613066 }
30623067
3063- if (ops.size () == topk_moe_ops.size () &&
3064- ggml_can_fuse_subgraph (cgraph, node_idx, ops, { node_idx + 3 , node_idx + 4 })) {
3068+ if (is_equal (topk_moe_ops, ops) && ggml_can_fuse_subgraph (cgraph, node_idx, ops, { node_idx + 3 , node_idx + 4 })) {
30653069 ggml_tensor * softmax = cgraph->nodes [node_idx];
30663070 ggml_tensor * weights = cgraph->nodes [node_idx + 4 ];
30673071 if (ggml_cuda_should_use_topk_moe (softmax, weights)) {
30683072 return true ;
30693073 }
30703074 }
30713075
3072- if (ops. size () == topk_moe_ops_delayed_softmax. size ( ) &&
3076+ if (is_equal ( topk_moe_ops_delayed_softmax, ops ) &&
30733077 ggml_can_fuse_subgraph (cgraph, node_idx, ops, { node_idx + 1 , node_idx + 5 })) {
30743078 ggml_tensor * softmax = cgraph->nodes [node_idx + 4 ];
30753079 ggml_tensor * weights = cgraph->nodes [node_idx + 5 ];
@@ -3085,9 +3089,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
30853089 std::initializer_list<enum ggml_op> mul_mat_id_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_MUL_MAT_ID, GGML_OP_GLU };
30863090 std::initializer_list<enum ggml_op> mul_mat_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT, GGML_OP_GLU };
30873091
3088- if (ops.size () == 5 && (ggml_can_fuse_subgraph (cgraph, node_idx, ops, {node_idx + 4 }) ||
3089- ggml_can_fuse_subgraph (cgraph, node_idx, ops, {node_idx + 4 }))) {
3090-
3092+ if ((is_equal (mul_mat_bias_glu_ops, ops) || is_equal (mul_mat_id_bias_glu_ops, ops)) &&
3093+ ggml_can_fuse_subgraph (cgraph, node_idx, ops, { node_idx + 4 })) {
30913094 const ggml_tensor * ffn_gate = cgraph->nodes [node_idx];
30923095 const ggml_tensor * ffn_gate_bias = cgraph->nodes [node_idx + 1 ];
30933096 const ggml_tensor * ffn_up = cgraph->nodes [node_idx + 2 ];
@@ -3099,9 +3102,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
30993102 }
31003103 }
31013104
3102- if (ops.size () == 3 && (ggml_can_fuse_subgraph (cgraph, node_idx, ops, {node_idx + 2 }) ||
3103- ggml_can_fuse_subgraph (cgraph, node_idx, ops, {node_idx + 2 }))) {
3104-
3105+ if ((is_equal (mul_mat_id_glu_ops, ops) || is_equal (mul_mat_glu_ops, ops)) &&
3106+ ggml_can_fuse_subgraph (cgraph, node_idx, ops, { node_idx + 2 })) {
31053107 const ggml_tensor * ffn_gate = cgraph->nodes [node_idx];
31063108 const ggml_tensor * ffn_up = cgraph->nodes [node_idx + 1 ];
31073109 const ggml_tensor * glu = cgraph->nodes [node_idx + 2 ];
@@ -3111,7 +3113,9 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
31113113 }
31123114 }
31133115
3114- if (ops.size () == 3 && ggml_can_fuse_subgraph (cgraph, node_idx, ops, { node_idx + 2 })) {
3116+ std::initializer_list<enum ggml_op> rope_set_rows_ops = { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS };
3117+
3118+ if (is_equal (rope_set_rows_ops, ops) && ggml_can_fuse_subgraph (cgraph, node_idx, ops, { node_idx + 2 })) {
31153119 const ggml_tensor * rope = cgraph->nodes [node_idx];
31163120 const ggml_tensor * view = cgraph->nodes [node_idx + 1 ];
31173121 const ggml_tensor * set_rows = cgraph->nodes [node_idx + 2 ];
0 commit comments