Commit f63e270

Apply graph reduction changes
1 parent 5ecbe6e commit f63e270

src/llama-model.cpp

Lines changed: 70 additions & 81 deletions
@@ -20462,19 +20462,19 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
         ggml_build_forward_expand(gf, cur);
     }
 
-    struct ggml_tensor * delta_net_unified(struct ggml_context * ctx,
-                                           struct ggml_tensor * q,
-                                           struct ggml_tensor * k,
-                                           struct ggml_tensor * v,
-                                           struct ggml_tensor * g,
-                                           struct ggml_tensor * beta,
-                                           struct ggml_tensor * state,
-                                           struct ggml_tensor * causal_mask,
-                                           struct ggml_tensor * identity,
-                                           bool use_qk_l2norm,
-                                           float eps_norm,
-                                           int il
-                                           ) {
+    ggml_tensor * delta_net_unified(
+            ggml_context * ctx,
+            ggml_tensor * q,
+            ggml_tensor * k,
+            ggml_tensor * v,
+            ggml_tensor * g,
+            ggml_tensor * beta,
+            ggml_tensor * state,
+            ggml_tensor * causal_mask,
+            ggml_tensor * identity,
+            bool use_qk_l2norm,
+            float eps_norm,
+            int il) {
         GGML_ASSERT(ggml_is_contiguous(q));
         GGML_ASSERT(ggml_is_contiguous(k));
         GGML_ASSERT(ggml_is_contiguous(v));
@@ -20511,19 +20511,20 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
 
         beta = ggml_sigmoid(ctx, beta);
 
-        struct ggml_tensor * causal_diag_mask = ggml_add(ctx, causal_mask, identity);
+        ggml_tensor * causal_diag_mask = ggml_add(ctx, causal_mask, identity);
 
         cb(q, "q_in", il);
         cb(k, "k_in", il);
         cb(v, "v_in", il);
         cb(beta, "beta_in", il);
         cb(g, "g_in", il);
 
-        q = ggml_cont_4d(ctx, ggml_permute(ctx, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
-        k = ggml_cont_4d(ctx, ggml_permute(ctx, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
-        v = ggml_cont_4d(ctx, ggml_permute(ctx, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
+        q = ggml_cont_4d(ctx, ggml_permute(ctx, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
+        k = ggml_cont_4d(ctx, ggml_permute(ctx, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
+        v = ggml_cont_4d(ctx, ggml_permute(ctx, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
+        g = ggml_cont_4d(ctx, ggml_permute(ctx, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs);
+
         beta = ggml_cont(ctx, ggml_permute(ctx, beta, 2, 0, 1, 3));
-        g = ggml_cont(ctx, ggml_permute(ctx, g, 2, 0, 3, 1));
         state = ggml_reshape_4d(ctx, state, S_v, S_v, H_v, n_seqs);
 
         cb(q, "q_perm", il);
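
Note: ggml_cont_4d materializes the permuted data and assigns the target shape in a single graph node, which is why the standalone ggml_reshape_4d calls on q, k, v and g disappear from the next hunk. A minimal sketch of the pattern, using a hypothetical helper and tensor t (not part of the commit):

    #include "ggml.h"

    // One fused node replaces a ggml_cont followed by a ggml_reshape_4d;
    // the element count of the permuted tensor must match d0*d1*d2*d3.
    static ggml_tensor * permute_to_layout(ggml_context * ctx, ggml_tensor * t,
                                           int64_t d0, int64_t d1, int64_t d2, int64_t d3) {
        // before (two nodes): ggml_reshape_4d(ctx, ggml_cont(ctx, ggml_permute(ctx, t, 0, 2, 1, 3)), d0, d1, d2, d3)
        return ggml_cont_4d(ctx, ggml_permute(ctx, t, 0, 2, 1, 3), d0, d1, d2, d3);
    }
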
@@ -20536,39 +20537,32 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
         GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs);
         GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs);
         GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs);
-        GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 &&
-                    beta->ne[3] == n_seqs);
-        GGML_ASSERT(g->ne[0] == n_tokens && g->ne[2] == H_k && g->ne[1] == 1 && g->ne[3] == n_seqs);
-
-        struct ggml_tensor * v_beta = ggml_mul(ctx, v, beta);
-        v_beta = ggml_reshape_4d(ctx, v_beta, S_v, n_tokens, H_k, n_seqs);
-        struct ggml_tensor * k_beta = ggml_mul(ctx, k, beta);
-        k_beta = ggml_reshape_4d(ctx, k_beta, S_v, n_tokens, H_k, n_seqs);
-        k = ggml_reshape_4d(ctx, k, S_v, n_tokens, H_k, n_seqs);
-        q = ggml_reshape_4d(ctx, q, S_v, n_tokens, H_k, n_seqs);
-        v = ggml_reshape_4d(ctx, v, S_v, n_tokens, H_v, n_seqs);
-        g = ggml_reshape_4d(ctx, g, n_tokens, 1, H_k, n_seqs);
-        struct ggml_tensor * g_cumsum = ggml_cumsum(ctx, g);
+        GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs);
+
+        ggml_tensor * v_beta = ggml_mul(ctx, v, beta);
+        ggml_tensor * k_beta = ggml_mul(ctx, k, beta);
+
+        ggml_tensor * g_cumsum = ggml_cumsum(ctx, g);
 
         cb(k_beta, "k_beta", il);
         cb(v_beta, "v_beta", il);
         cb(g_cumsum, "g_cumsum", il);
 
-        struct ggml_tensor * gcs_i = ggml_cont_4d(ctx, g_cumsum, n_tokens, 1, H_v,
+        ggml_tensor * gcs_i = ggml_cont_4d(ctx, g_cumsum, n_tokens, 1, H_v,
                                                   n_seqs); // [chunk_size, 1, n_tokens, n_seqs]
-        struct ggml_tensor * gcs_j = ggml_cont_4d(ctx, g_cumsum, 1, n_tokens, H_v,
+        ggml_tensor * gcs_j = ggml_cont_4d(ctx, g_cumsum, 1, n_tokens, H_v,
                                                   n_seqs); // [1, chunk_size, n_tokens, n_seqs]
 
         // Broadcast both tensors to [chunk_size, chunk_size, H_v, n_seqs]
-        // struct ggml_tensor * gcs_i_broadcast =
+        // ggml_tensor * gcs_i_broadcast =
         //     ggml_repeat_4d(ctx, gcs_i, GGML_DELTA_NET_CHUNK, GGML_DELTA_NET_CHUNK, num_chunks * H_v,
         //     n_seqs); // [chunk_size, 1, H_v, n_seqs] -> [chunk_size, chunk_size, H_v, n_seqs]
         // Don't need this, this one will get auto-broadcast
-        struct ggml_tensor * gcs_j_broadcast =
+        ggml_tensor * gcs_j_broadcast =
             ggml_repeat_4d(ctx, gcs_j, n_tokens, n_tokens, H_v,
                            n_seqs); // [1, chunk_size, H_v, n_seqs] -> [chunk_size, chunk_size, H_v, n_seqs]
 
-        struct ggml_tensor * decay_mask = ggml_sub(ctx, gcs_j_broadcast, gcs_i);
+        ggml_tensor * decay_mask = ggml_sub(ctx, gcs_j_broadcast, gcs_i);
 
         // Apply lower triangular mask to ensure attention is causal (only past tokens influence current)
         decay_mask = ggml_mul(ctx, decay_mask, causal_diag_mask);
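
Note: in the usual gated delta-rule formulation, the tensor built here is the log-domain pairwise decay between token positions. With G the cumulative sum of the per-token log-decay g, the masked matrix corresponds, once exponentiated (presumably in the lines elided between this hunk and the next), to roughly

    D_{ij} = \exp(G_i - G_j)\,\mathbf{1}[\,j \le i\,], \qquad G_i = \sum_{t \le i} g_t.

The two ggml_cont_4d views lay G out along complementary axes so that the broadcast ggml_sub produces all pairwise differences in one node; this hunk itself only drops the struct keyword and the reshapes made redundant by the previous hunk, leaving the computation unchanged.
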
@@ -20580,12 +20574,12 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
         cb(decay_mask, "decay_mask", il);
 
         // attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
-        struct ggml_tensor * kmulkbeta = ggml_mul_mat(ctx, ggml_cont(ctx, k), ggml_cont(ctx, k_beta));
+        ggml_tensor * kmulkbeta = ggml_mul_mat(ctx, k, k_beta);
 
         cb(kmulkbeta, "kmulkbeta", il);
 
-        struct ggml_tensor * k_decay = ggml_mul(ctx, kmulkbeta, decay_mask);
-        struct ggml_tensor * attn = ggml_neg(ctx, ggml_mul(ctx, k_decay, causal_mask));
+        ggml_tensor * k_decay = ggml_mul(ctx, kmulkbeta, decay_mask);
+        ggml_tensor * attn = ggml_neg(ctx, ggml_mul(ctx, k_decay, causal_mask));
 
         cb(attn, "attn_pre_rec", il);

@@ -20597,29 +20591,28 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
         //
         // We reduce this to a linear triangular solve: AX = B, where B = attn, A = I - tril(A)
         ggml_tensor * attn_lower = ggml_mul(ctx, attn, causal_mask);
-        struct ggml_tensor * lhs =
-            ggml_sub(ctx, ggml_repeat_4d(ctx, identity, identity->ne[0], identity->ne[1], attn_lower->ne[2], attn_lower->ne[3]), attn_lower);
+        ggml_tensor * lhs = ggml_sub(ctx, ggml_repeat(ctx, identity, attn_lower), attn_lower);
 
-        struct ggml_tensor * lin_solve = ggml_solve_tri(ctx, lhs, attn);
+        ggml_tensor * lin_solve = ggml_solve_tri(ctx, lhs, attn);
         attn = ggml_mul(ctx, lin_solve, causal_mask);
-        attn = ggml_cont(ctx, ggml_add(ctx, attn, identity));
+        attn = ggml_add(ctx, attn, identity);
 
         // value = attn @ v_beta
-        v = ggml_cont(ctx, ggml_transpose(ctx, ggml_mul_mat(ctx, attn, ggml_cont(ctx, ggml_transpose(ctx0, v_beta)))));
+        v = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_transpose(ctx0, v_beta)), attn);
 
         cb(v, "value_beta", il);
 
         // k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
-        struct ggml_tensor * g_cumsum_t = ggml_cont(ctx, ggml_transpose(ctx, g_cumsum));
-        struct ggml_tensor * gexp = ggml_exp(ctx, g_cumsum_t);
+        ggml_tensor * g_cumsum_t = ggml_cont(ctx, ggml_transpose(ctx, g_cumsum));
+        ggml_tensor * gexp = ggml_exp(ctx, g_cumsum_t);
 
         cb(gexp, "g_cum_exp", il);
 
-        struct ggml_tensor * kbeta_gexp = ggml_mul(ctx, ggml_cont(ctx, k_beta), gexp);
+        ggml_tensor * kbeta_gexp = ggml_mul(ctx, k_beta, gexp);
 
         cb(kbeta_gexp, "kbeta_gexp", il);
 
-        struct ggml_tensor * k_cumdecay =
+        ggml_tensor * k_cumdecay =
             ggml_cont(ctx, ggml_transpose(ctx, ggml_mul_mat(ctx, attn, ggml_cont(ctx, ggml_transpose(ctx, kbeta_gexp)))));
 
         cb(k_cumdecay, "k_cumdecay", il);
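
Note: with L strictly lower triangular (hence nilpotent), the system described in the comment has the closed form

    (I - L)\,X = B \;\Longrightarrow\; X = (I + L + L^2 + \cdots + L^{n-1})\,B, \qquad L^n = 0,

so the single ggml_solve_tri call recovers the same matrix the reference recurrence builds row by row, without forming an explicit inverse. Replacing ggml_repeat_4d with ggml_repeat(ctx, identity, attn_lower) is equivalent here because ggml_repeat broadcasts its first argument to the shape of its second.
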
@@ -20631,28 +20624,32 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
 
         cb(attn, "attn_decay_key", il);
 
+        ggml_tensor * state_t = ggml_cont(ctx, ggml_transpose(ctx, state));
+
         // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
-        struct ggml_tensor * v_prime = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_transpose(ctx, state)), k_cumdecay);
+        ggml_tensor * v_prime = ggml_mul_mat(ctx, state_t, k_cumdecay);
 
         cb(v_prime, "v_prime", il);
 
         // v_new = v_i - v_prime
-        struct ggml_tensor * v_new = ggml_sub(ctx, ggml_repeat_4d(ctx, v, v_prime->ne[0], v_prime->ne[1], v_prime->ne[2], v_prime->ne[3]), v_prime);
+        ggml_tensor * v_new = ggml_sub(ctx, ggml_repeat(ctx, v, v_prime), v_prime);
+
+        ggml_tensor * v_new_t = ggml_cont(ctx, ggml_transpose(ctx, v_new));
 
         cb(v_new, "v_new", il);
 
         // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
-        struct ggml_tensor * q_g_exp = ggml_mul(ctx, q, gexp);
-        struct ggml_tensor * attn_inter = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_transpose(ctx, state)), q_g_exp);
+        ggml_tensor * q_g_exp = ggml_mul(ctx, q, gexp);
+        ggml_tensor * attn_inter = ggml_mul_mat(ctx, state_t, q_g_exp);
 
         cb(attn_inter, "attn_inter", il);
 
         // core_attn_out[:, :, i] = attn_inter + attn @ v_new
-        struct ggml_tensor * v_attn = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_transpose(ctx, v_new)), attn);
+        ggml_tensor * v_attn = ggml_mul_mat(ctx, v_new_t, attn);
 
         cb(v_attn, "v_attn", il);
 
-        struct ggml_tensor * core_attn_out = ggml_add(ctx, attn_inter, v_attn);
+        ggml_tensor * core_attn_out = ggml_add(ctx, attn_inter, v_attn);
 
         cb(core_attn_out, "core_attn_out", il);
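
Note: state_t and v_new_t hoist the cont-of-transpose of the state and of v_new into named nodes that are shared by several matmuls (v_prime and attn_inter here, v_attn and kgdmulvnew below), instead of rebuilding the transposed copy at each use; ggml_repeat(ctx, v, v_prime) likewise replaces the explicit ggml_repeat_4d with the same broadcast. A minimal sketch of the sharing pattern, with hypothetical tensors s, a and b whose shapes are assumed compatible:

    #include "ggml.h"

    // Illustrative only: the transposed copy is built once and fed to two matmuls,
    // so the graph contains one transpose+cont pair instead of two.
    static void shared_transpose_sketch(ggml_context * ctx, ggml_tensor * s,
                                        ggml_tensor * a, ggml_tensor * b,
                                        ggml_tensor ** y0, ggml_tensor ** y1) {
        ggml_tensor * s_t = ggml_cont(ctx, ggml_transpose(ctx, s)); // single shared node
        *y0 = ggml_mul_mat(ctx, s_t, a); // e.g. v_prime in the hunk above
        *y1 = ggml_mul_mat(ctx, s_t, b); // e.g. attn_inter in the hunk above
    }
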

@@ -20662,22 +20659,20 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
         // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
         // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
 
-        gexp = ggml_cont(ctx, gexp);
-
         ggml_tensor * g_cum_last = ggml_cont(ctx, ggml_view_4d(ctx, g_cumsum_t, g_cumsum_t->ne[0], 1, g_cumsum_t->ne[2], g_cumsum_t->ne[3], g_cumsum_t->nb[1],
                                                                g_cumsum_t->nb[2], g_cumsum_t->nb[3], g_cumsum_t->nb[0] * (g_cumsum_t->ne[1] - 1)));
 
         cb(g_cum_last, "g_cum_last", il);
 
-        ggml_tensor * gexp_last = ggml_cont_4d(ctx, ggml_exp(ctx, g_cum_last), 1, 1, g_cum_last->ne[0] * g_cum_last->ne[2], g_cum_last->ne[3]);
+        ggml_tensor * gexp_last = ggml_reshape_4d(ctx, ggml_exp(ctx, g_cum_last), 1, 1, g_cum_last->ne[0] * g_cum_last->ne[2], g_cum_last->ne[3]);
 
         cb(g_cum_last, "gexp_last", il);
 
-        ggml_tensor * g_cum_last_3d = ggml_cont_3d(ctx, g_cum_last, g_cum_last->ne[0], g_cum_last->ne[2], g_cum_last->ne[3]);
+        ggml_tensor * g_cum_last_3d = ggml_reshape_3d(ctx, g_cum_last, g_cum_last->ne[0], g_cum_last->ne[2], g_cum_last->ne[3]);
 
         cb(g_cum_last, "g_cum_last_3d", il);
 
-        ggml_tensor * g_cumsum_3d = ggml_cont_3d(ctx, g_cumsum, g_cumsum->ne[0], g_cumsum->ne[2], g_cumsum->ne[3]);
+        ggml_tensor * g_cumsum_3d = ggml_reshape_3d(ctx, g_cumsum, g_cumsum->ne[0], g_cumsum->ne[2], g_cumsum->ne[3]);
 
         cb(g_cum_last, "g_cumsum_3d", il);
 
@@ -20689,24 +20684,22 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
 
         cb(g_cum_last, "g_diff_exp", il);
 
-        ggml_tensor * key_gdiff = ggml_mul(ctx, k, ggml_cont_4d(ctx, g_diff_exp, 1, g_diff_exp->ne[0], g_diff_exp->ne[1], g_diff_exp->ne[2] * g_diff_exp->ne[3]));
+        ggml_tensor * key_gdiff = ggml_mul(ctx, k, ggml_reshape_4d(ctx, g_diff_exp, 1, g_diff_exp->ne[0], g_diff_exp->ne[1], g_diff_exp->ne[2] * g_diff_exp->ne[3]));
 
         cb(g_cum_last, "key_gdiff", il);
 
-        ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_cont(ctx, ggml_transpose(ctx, v_new))),
+        ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx, v_new_t,
                                                 ggml_cont(ctx, ggml_transpose(ctx, key_gdiff)));
 
         cb(kgdmulvnew, "kgdmulvnew", il);
 
-        struct ggml_tensor * new_state =
-            ggml_add(ctx, ggml_mul(ctx, state, ggml_cont_4d(ctx, gexp_last, 1, 1, H_v, ggml_nelements(gexp_last) / H_v)),
-                     kgdmulvnew);
+        ggml_tensor * new_state = ggml_add(ctx, ggml_mul(ctx, state, gexp_last), kgdmulvnew);
 
         cb(new_state, "new_state", il);
 
         // flatten output
-        struct ggml_tensor * flat_output = ggml_cont_1d(ctx, ggml_permute(ctx, core_attn_out, 0, 2, 1, 3), S_v * H_v * n_tokens * n_seqs);
-        struct ggml_tensor * flat_state = ggml_cont_1d(ctx, new_state, S_v * S_v * H_v * n_seqs);
+        ggml_tensor * flat_output = ggml_cont_1d(ctx, ggml_permute(ctx, core_attn_out, 0, 2, 1, 3), S_v * H_v * n_tokens * n_seqs);
+        ggml_tensor * flat_state = ggml_cont_1d(ctx, new_state, S_v * S_v * H_v * n_seqs);
 
         return ggml_concat(ctx, flat_output, flat_state, 0);
     }
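
Note: delta_net_unified returns the attention output and the updated recurrent state flattened and concatenated into a single 1-D tensor. The call site is outside this diff; a hypothetical caller-side split (names and signature are assumptions, shown for illustration only) would carve the two parts back out with ggml_view_1d, whose offset argument is in bytes:

    #include "ggml.h"

    // Hypothetical: split the packed result of delta_net_unified back into the
    // flat output (S_v*H_v*n_tokens*n_seqs elements) and the flat state.
    static void split_delta_net_result(ggml_context * ctx, ggml_tensor * packed,
                                       int64_t S_v, int64_t H_v, int64_t n_tokens, int64_t n_seqs,
                                       ggml_tensor ** out, ggml_tensor ** state) {
        const int64_t n_out   = S_v * H_v * n_tokens * n_seqs;
        const int64_t n_state = S_v * S_v * H_v * n_seqs;
        *out   = ggml_view_1d(ctx, packed, n_out,   0);
        *state = ggml_view_1d(ctx, packed, n_state, n_out * ggml_element_size(packed));
    }
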
@@ -20799,15 +20792,14 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
         return cur;
     }
 
-
-
-    ggml_tensor * build_qwen3next_linear_attn_layer(llm_graph_input_rs * inp,
-                                                    ggml_tensor * cur,
-                                                    const llama_model & model,
-                                                    const llama_ubatch & ubatch,
-                                                    ggml_tensor * causal_mask,
-                                                    ggml_tensor * identity,
-                                                    int il) {
+    ggml_tensor * build_qwen3next_linear_attn_layer(
+            llm_graph_input_rs * inp,
+            ggml_tensor * cur,
+            const llama_model & model,
+            const llama_ubatch & ubatch,
+            ggml_tensor * causal_mask,
+            ggml_tensor * identity,
+            int il) {
         const auto * mctx_cur = inp->mctx;
 
         const int64_t d_inner = hparams.ssm_d_inner;
@@ -21050,27 +21042,24 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
 
         // Reshape both attn_out_final and z to 2D tensors for normalization
         // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
-        ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, ggml_cont(ctx0, attn_out_final), head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
+        ggml_tensor * attn_out_2d_final = ggml_cont_2d(ctx0, attn_out_final, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
 
         // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
         ggml_tensor * z_2d = ggml_cont_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
 
         // Apply gated normalization: self.norm(core_attn_out, z)
         ggml_tensor * attn_out_norm = build_q3n_gated_norm(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il);
 
-        // Reshape back to original dimensions: [n_heads * n_tokens * n_seqs, head_dim] -> [head_dim, n_heads, n_tokens, n_seqs]
-        ggml_tensor * gated_output_4d = ggml_reshape_4d(ctx0, attn_out_norm, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
-
         // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim]
-        ggml_tensor * final_output = ggml_reshape_3d(ctx0, gated_output_4d, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
+        ggml_tensor * final_output = ggml_reshape_3d(ctx0, attn_out_norm, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
         cb(final_output, "final_output", il);
 
         // Output projection
         cur = build_lora_mm(model.layers[il].ssm_out, final_output);
         cb(cur, "linear_attn_out", il);
 
         // Reshape back to original dimensions
-        cur = ggml_cont(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs));
+        cur = ggml_cont_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs);
         return cur;
     }
 