@@ -20462,19 +20462,19 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
  ggml_build_forward_expand(gf, cur);
  }

- struct ggml_tensor * delta_net_unified(struct ggml_context * ctx,
-     struct ggml_tensor * q,
-     struct ggml_tensor * k,
-     struct ggml_tensor * v,
-     struct ggml_tensor * g,
-     struct ggml_tensor * beta,
-     struct ggml_tensor * state,
-     struct ggml_tensor * causal_mask,
-     struct ggml_tensor * identity,
-     bool use_qk_l2norm,
-     float eps_norm,
-     int il
- ) {
+ ggml_tensor * delta_net_unified(
+     ggml_context * ctx,
+     ggml_tensor * q,
+     ggml_tensor * k,
+     ggml_tensor * v,
+     ggml_tensor * g,
+     ggml_tensor * beta,
+     ggml_tensor * state,
+     ggml_tensor * causal_mask,
+     ggml_tensor * identity,
+     bool use_qk_l2norm,
+     float eps_norm,
+     int il) {
  GGML_ASSERT(ggml_is_contiguous(q));
  GGML_ASSERT(ggml_is_contiguous(k));
  GGML_ASSERT(ggml_is_contiguous(v));
@@ -20511,19 +20511,20 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {

  beta = ggml_sigmoid(ctx, beta);

- struct ggml_tensor * causal_diag_mask = ggml_add(ctx, causal_mask, identity);
+ ggml_tensor * causal_diag_mask = ggml_add(ctx, causal_mask, identity);

  cb(q, "q_in", il);
  cb(k, "k_in", il);
  cb(v, "v_in", il);
  cb(beta, "beta_in", il);
  cb(g, "g_in", il);

- q = ggml_cont_4d(ctx, ggml_permute(ctx, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
- k = ggml_cont_4d(ctx, ggml_permute(ctx, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
- v = ggml_cont_4d(ctx, ggml_permute(ctx, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
+ q = ggml_cont_4d(ctx, ggml_permute(ctx, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
+ k = ggml_cont_4d(ctx, ggml_permute(ctx, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
+ v = ggml_cont_4d(ctx, ggml_permute(ctx, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
+ g = ggml_cont_4d(ctx, ggml_permute(ctx, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs);
+
  beta = ggml_cont(ctx, ggml_permute(ctx, beta, 2, 0, 1, 3));
- g = ggml_cont(ctx, ggml_permute(ctx, g, 2, 0, 3, 1));
  state = ggml_reshape_4d(ctx, state, S_v, S_v, H_v, n_seqs);

  cb(q, "q_perm", il);
@@ -20536,39 +20537,32 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
  GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs);
  GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs);
  GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs);
- GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 &&
-     beta->ne[3] == n_seqs);
- GGML_ASSERT(g->ne[0] == n_tokens && g->ne[2] == H_k && g->ne[1] == 1 && g->ne[3] == n_seqs);
-
- struct ggml_tensor * v_beta = ggml_mul(ctx, v, beta);
- v_beta = ggml_reshape_4d(ctx, v_beta, S_v, n_tokens, H_k, n_seqs);
- struct ggml_tensor * k_beta = ggml_mul(ctx, k, beta);
- k_beta = ggml_reshape_4d(ctx, k_beta, S_v, n_tokens, H_k, n_seqs);
- k = ggml_reshape_4d(ctx, k, S_v, n_tokens, H_k, n_seqs);
- q = ggml_reshape_4d(ctx, q, S_v, n_tokens, H_k, n_seqs);
- v = ggml_reshape_4d(ctx, v, S_v, n_tokens, H_v, n_seqs);
- g = ggml_reshape_4d(ctx, g, n_tokens, 1, H_k, n_seqs);
- struct ggml_tensor * g_cumsum = ggml_cumsum(ctx, g);
+ GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs);
+
+ ggml_tensor * v_beta = ggml_mul(ctx, v, beta);
+ ggml_tensor * k_beta = ggml_mul(ctx, k, beta);
+
+ ggml_tensor * g_cumsum = ggml_cumsum(ctx, g);

  cb(k_beta, "k_beta", il);
  cb(v_beta, "v_beta", il);
  cb(g_cumsum, "g_cumsum", il);

- struct ggml_tensor * gcs_i = ggml_cont_4d(ctx, g_cumsum, n_tokens, 1, H_v,
+ ggml_tensor * gcs_i = ggml_cont_4d(ctx, g_cumsum, n_tokens, 1, H_v,
  n_seqs); // [chunk_size, 1, n_tokens, n_seqs]
- struct ggml_tensor * gcs_j = ggml_cont_4d(ctx, g_cumsum, 1, n_tokens, H_v,
+ ggml_tensor * gcs_j = ggml_cont_4d(ctx, g_cumsum, 1, n_tokens, H_v,
  n_seqs); // [1, chunk_size, n_tokens, n_seqs]

  // Broadcast both tensors to [chunk_size, chunk_size, H_v, n_seqs]
- // struct ggml_tensor * gcs_i_broadcast =
+ // ggml_tensor * gcs_i_broadcast =
  // ggml_repeat_4d(ctx, gcs_i, GGML_DELTA_NET_CHUNK, GGML_DELTA_NET_CHUNK, num_chunks * H_v,
  // n_seqs); // [chunk_size, 1, H_v, n_seqs] -> [chunk_size, chunk_size, H_v, n_seqs]
  // Don't need this, this one will get auto-broadcast
- struct ggml_tensor * gcs_j_broadcast =
+ ggml_tensor * gcs_j_broadcast =
  ggml_repeat_4d(ctx, gcs_j, n_tokens, n_tokens, H_v,
  n_seqs); // [1, chunk_size, H_v, n_seqs] -> [chunk_size, chunk_size, H_v, n_seqs]

- struct ggml_tensor * decay_mask = ggml_sub(ctx, gcs_j_broadcast, gcs_i);
+ ggml_tensor * decay_mask = ggml_sub(ctx, gcs_j_broadcast, gcs_i);

  // Apply lower triangular mask to ensure attention is causal (only past tokens influence current)
  decay_mask = ggml_mul(ctx, decay_mask, causal_diag_mask);
@@ -20580,12 +20574,12 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
  cb(decay_mask, "decay_mask", il);

  // attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
- struct ggml_tensor * kmulkbeta = ggml_mul_mat(ctx, ggml_cont(ctx, k), ggml_cont(ctx, k_beta));
+ ggml_tensor * kmulkbeta = ggml_mul_mat(ctx, k, k_beta);

  cb(kmulkbeta, "kmulkbeta", il);

- struct ggml_tensor * k_decay = ggml_mul(ctx, kmulkbeta, decay_mask);
- struct ggml_tensor * attn = ggml_neg(ctx, ggml_mul(ctx, k_decay, causal_mask));
+ ggml_tensor * k_decay = ggml_mul(ctx, kmulkbeta, decay_mask);
+ ggml_tensor * attn = ggml_neg(ctx, ggml_mul(ctx, k_decay, causal_mask));

  cb(attn, "attn_pre_rec", il);

@@ -20597,29 +20591,28 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
  //
  // We reduce this to a linear triangular solve: AX = B, where B = attn, A = I - tril(A)
  ggml_tensor * attn_lower = ggml_mul(ctx, attn, causal_mask);
- struct ggml_tensor * lhs =
- ggml_sub(ctx, ggml_repeat_4d(ctx, identity, identity->ne[0], identity->ne[1], attn_lower->ne[2], attn_lower->ne[3]), attn_lower);
+ ggml_tensor * lhs = ggml_sub(ctx, ggml_repeat(ctx, identity, attn_lower), attn_lower);

- struct ggml_tensor * lin_solve = ggml_solve_tri(ctx, lhs, attn);
+ ggml_tensor * lin_solve = ggml_solve_tri(ctx, lhs, attn);
  attn = ggml_mul(ctx, lin_solve, causal_mask);
- attn = ggml_cont(ctx, ggml_add(ctx, attn, identity));
+ attn = ggml_add(ctx, attn, identity);

  // value = attn @ v_beta
- v = ggml_cont(ctx, ggml_transpose(ctx, ggml_mul_mat(ctx, attn, ggml_cont(ctx, ggml_transpose(ctx0, v_beta)))));
+ v = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_transpose(ctx0, v_beta)), attn);

  cb(v, "value_beta", il);

  // k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
- struct ggml_tensor * g_cumsum_t = ggml_cont(ctx, ggml_transpose(ctx, g_cumsum));
- struct ggml_tensor * gexp = ggml_exp(ctx, g_cumsum_t);
+ ggml_tensor * g_cumsum_t = ggml_cont(ctx, ggml_transpose(ctx, g_cumsum));
+ ggml_tensor * gexp = ggml_exp(ctx, g_cumsum_t);

  cb(gexp, "g_cum_exp", il);

- struct ggml_tensor * kbeta_gexp = ggml_mul(ctx, ggml_cont(ctx, k_beta), gexp);
+ ggml_tensor * kbeta_gexp = ggml_mul(ctx, k_beta, gexp);

  cb(kbeta_gexp, "kbeta_gexp", il);

- struct ggml_tensor * k_cumdecay =
+ ggml_tensor * k_cumdecay =
  ggml_cont(ctx, ggml_transpose(ctx, ggml_mul_mat(ctx, attn, ggml_cont(ctx, ggml_transpose(ctx, kbeta_gexp)))));

  cb(k_cumdecay, "k_cumdecay", il);
@@ -20631,28 +20624,32 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {

  cb(attn, "attn_decay_key", il);

+ ggml_tensor * state_t = ggml_cont(ctx, ggml_transpose(ctx, state));
+
  // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
- struct ggml_tensor * v_prime = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_transpose(ctx, state)), k_cumdecay);
+ ggml_tensor * v_prime = ggml_mul_mat(ctx, state_t, k_cumdecay);

  cb(v_prime, "v_prime", il);

  // v_new = v_i - v_prime
- struct ggml_tensor * v_new = ggml_sub(ctx, ggml_repeat_4d(ctx, v, v_prime->ne[0], v_prime->ne[1], v_prime->ne[2], v_prime->ne[3]), v_prime);
+ ggml_tensor * v_new = ggml_sub(ctx, ggml_repeat(ctx, v, v_prime), v_prime);
+
+ ggml_tensor * v_new_t = ggml_cont(ctx, ggml_transpose(ctx, v_new));

  cb(v_new, "v_new", il);

  // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
- struct ggml_tensor * q_g_exp = ggml_mul(ctx, q, gexp);
- struct ggml_tensor * attn_inter = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_transpose(ctx, state)), q_g_exp);
+ ggml_tensor * q_g_exp = ggml_mul(ctx, q, gexp);
+ ggml_tensor * attn_inter = ggml_mul_mat(ctx, state_t, q_g_exp);

  cb(attn_inter, "attn_inter", il);

  // core_attn_out[:, :, i] = attn_inter + attn @ v_new
- struct ggml_tensor * v_attn = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_transpose(ctx, v_new)), attn);
+ ggml_tensor * v_attn = ggml_mul_mat(ctx, v_new_t, attn);

  cb(v_attn, "v_attn", il);

- struct ggml_tensor * core_attn_out = ggml_add(ctx, attn_inter, v_attn);
+ ggml_tensor * core_attn_out = ggml_add(ctx, attn_inter, v_attn);

  cb(core_attn_out, "core_attn_out", il);

@@ -20662,22 +20659,20 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
  // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
  // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew

- gexp = ggml_cont(ctx, gexp);
-
  ggml_tensor * g_cum_last = ggml_cont(ctx, ggml_view_4d(ctx, g_cumsum_t, g_cumsum_t->ne[0], 1, g_cumsum_t->ne[2], g_cumsum_t->ne[3], g_cumsum_t->nb[1],
  g_cumsum_t->nb[2], g_cumsum_t->nb[3], g_cumsum_t->nb[0] * (g_cumsum_t->ne[1] - 1)));

  cb(g_cum_last, "g_cum_last", il);

- ggml_tensor * gexp_last = ggml_cont_4d(ctx, ggml_exp(ctx, g_cum_last), 1, 1, g_cum_last->ne[0] * g_cum_last->ne[2], g_cum_last->ne[3]);
+ ggml_tensor * gexp_last = ggml_reshape_4d(ctx, ggml_exp(ctx, g_cum_last), 1, 1, g_cum_last->ne[0] * g_cum_last->ne[2], g_cum_last->ne[3]);

  cb(g_cum_last, "gexp_last", il);

- ggml_tensor * g_cum_last_3d = ggml_cont_3d(ctx, g_cum_last, g_cum_last->ne[0], g_cum_last->ne[2], g_cum_last->ne[3]);
+ ggml_tensor * g_cum_last_3d = ggml_reshape_3d(ctx, g_cum_last, g_cum_last->ne[0], g_cum_last->ne[2], g_cum_last->ne[3]);

  cb(g_cum_last, "g_cum_last_3d", il);

- ggml_tensor * g_cumsum_3d = ggml_cont_3d(ctx, g_cumsum, g_cumsum->ne[0], g_cumsum->ne[2], g_cumsum->ne[3]);
+ ggml_tensor * g_cumsum_3d = ggml_reshape_3d(ctx, g_cumsum, g_cumsum->ne[0], g_cumsum->ne[2], g_cumsum->ne[3]);

  cb(g_cum_last, "g_cumsum_3d", il);

@@ -20689,24 +20684,22 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {

  cb(g_cum_last, "g_diff_exp", il);

- ggml_tensor * key_gdiff = ggml_mul(ctx, k, ggml_cont_4d(ctx, g_diff_exp, 1, g_diff_exp->ne[0], g_diff_exp->ne[1], g_diff_exp->ne[2] * g_diff_exp->ne[3]));
+ ggml_tensor * key_gdiff = ggml_mul(ctx, k, ggml_reshape_4d(ctx, g_diff_exp, 1, g_diff_exp->ne[0], g_diff_exp->ne[1], g_diff_exp->ne[2] * g_diff_exp->ne[3]));

  cb(g_cum_last, "key_gdiff", il);

- ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_cont(ctx, ggml_transpose(ctx, v_new))),
+ ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx, v_new_t,
  ggml_cont(ctx, ggml_transpose(ctx, key_gdiff)));

  cb(kgdmulvnew, "kgdmulvnew", il);

- struct ggml_tensor * new_state =
- ggml_add(ctx, ggml_mul(ctx, state, ggml_cont_4d(ctx, gexp_last, 1, 1, H_v, ggml_nelements(gexp_last) / H_v)),
- kgdmulvnew);
+ ggml_tensor * new_state = ggml_add(ctx, ggml_mul(ctx, state, gexp_last), kgdmulvnew);

  cb(new_state, "new_state", il);

  // flatten output
- struct ggml_tensor * flat_output = ggml_cont_1d(ctx, ggml_permute(ctx, core_attn_out, 0, 2, 1, 3), S_v * H_v * n_tokens * n_seqs);
- struct ggml_tensor * flat_state = ggml_cont_1d(ctx, new_state, S_v * S_v * H_v * n_seqs);
+ ggml_tensor * flat_output = ggml_cont_1d(ctx, ggml_permute(ctx, core_attn_out, 0, 2, 1, 3), S_v * H_v * n_tokens * n_seqs);
+ ggml_tensor * flat_state = ggml_cont_1d(ctx, new_state, S_v * S_v * H_v * n_seqs);

  return ggml_concat(ctx, flat_output, flat_state, 0);
  }
@@ -20799,15 +20792,14 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
  return cur;
  }

-
-
- ggml_tensor * build_qwen3next_linear_attn_layer(llm_graph_input_rs * inp,
-     ggml_tensor * cur,
-     const llama_model & model,
-     const llama_ubatch & ubatch,
-     ggml_tensor * causal_mask,
-     ggml_tensor * identity,
-     int il) {
+ ggml_tensor * build_qwen3next_linear_attn_layer(
+     llm_graph_input_rs * inp,
+     ggml_tensor * cur,
+     const llama_model & model,
+     const llama_ubatch & ubatch,
+     ggml_tensor * causal_mask,
+     ggml_tensor * identity,
+     int il) {
  const auto * mctx_cur = inp->mctx;

  const int64_t d_inner = hparams.ssm_d_inner;
@@ -21050,27 +21042,24 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {

  // Reshape both attn_out_final and z to 2D tensors for normalization
  // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
- ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, ggml_cont(ctx0, attn_out_final), head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
+ ggml_tensor * attn_out_2d_final = ggml_cont_2d(ctx0, attn_out_final, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);

  // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
  ggml_tensor * z_2d = ggml_cont_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);

  // Apply gated normalization: self.norm(core_attn_out, z)
  ggml_tensor * attn_out_norm = build_q3n_gated_norm(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il);

- // Reshape back to original dimensions: [n_heads * n_tokens * n_seqs, head_dim] -> [head_dim, n_heads, n_tokens, n_seqs]
- ggml_tensor * gated_output_4d = ggml_reshape_4d(ctx0, attn_out_norm, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
-
  // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim]
- ggml_tensor * final_output = ggml_reshape_3d(ctx0, gated_output_4d, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
+ ggml_tensor * final_output = ggml_reshape_3d(ctx0, attn_out_norm, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
  cb(final_output, "final_output", il);

  // Output projection
  cur = build_lora_mm(model.layers[il].ssm_out, final_output);
  cb(cur, "linear_attn_out", il);

  // Reshape back to original dimensions
- cur = ggml_cont(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs));
+ cur = ggml_cont_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs);
  return cur;
  }
