Commit 8794e13

Clean up code, add missing hybrid qualifier
1 parent bb5e624 commit 8794e13

File tree: 3 files changed (+37, -25 lines)

src/llama-arch.cpp

Lines changed: 1 addition & 0 deletions
@@ -2745,6 +2745,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
         case LLM_ARCH_LFM2:
         case LLM_ARCH_LFM2MOE:
         case LLM_ARCH_NEMOTRON_H:
+        case LLM_ARCH_QWEN3NEXT:
             return true;
         default:
             return false;
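
For context, this hunk is the key functional piece of the commit: Qwen3Next mixes full-attention layers with recurrent delta-net layers, so the architecture has to report itself as hybrid. A minimal sketch of the helper after the change, abridged to the cases visible in this hunk (the real case list in src/llama-arch.cpp is longer), with a purely hypothetical usage comment at the end:

bool llm_arch_is_hybrid(const llm_arch & arch) {
    switch (arch) {
        // architectures that mix attention layers with recurrent/SSM-style layers
        case LLM_ARCH_LFM2:
        case LLM_ARCH_LFM2MOE:
        case LLM_ARCH_NEMOTRON_H:
        case LLM_ARCH_QWEN3NEXT:   // added by this commit
            return true;
        default:
            return false;
    }
}

// hypothetical call site: callers can branch on the hybrid property, e.g.
// if (llm_arch_is_hybrid(model.arch)) { /* set up recurrent state alongside the KV cache */ }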

src/llama-model.cpp

Lines changed: 1 addition & 0 deletions
@@ -6761,6 +6761,7 @@ void llama_model::print_info() const {
         arch == LLM_ARCH_FALCON_H1 ||
         arch == LLM_ARCH_PLAMO2 ||
         arch == LLM_ARCH_GRANITE_HYBRID ||
+        arch == LLM_ARCH_QWEN3NEXT ||
         arch == LLM_ARCH_NEMOTRON_H) {
         LLAMA_LOG_INFO("%s: ssm_d_conv  = %u\n", __func__, hparams.ssm_d_conv);
         LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner);

src/models/qwen3next.cpp

Lines changed: 35 additions & 25 deletions
@@ -20,9 +20,8 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr
 
     ggml_tensor * causal_mask =
         ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, ubatch.n_seq_tokens, ubatch.n_seq_tokens), 1.0f),
-            GGML_TRI_TYPE_LOWER);
-    ggml_tensor * identity = ggml_diag(
-        ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, ubatch.n_seq_tokens), 1.0f));
+                 GGML_TRI_TYPE_LOWER);
+    ggml_tensor * identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, ubatch.n_seq_tokens), 1.0f));
 
     ggml_build_forward_expand(gf, causal_mask);
     ggml_build_forward_expand(gf, identity);
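
The re-indented lines above build two constant helpers used by the delta-net math: an n_seq_tokens x n_seq_tokens lower-triangular mask of ones and an identity matrix of the same size. A small illustration for a hypothetical n_seq_tokens = 4 (whether GGML_TRI_TYPE_LOWER keeps the main diagonal is an assumption; it is drawn as inclusive here):

// causal_mask = ggml_tri(ones(4,4), LOWER)    identity = ggml_diag(ones(4))
//   [ 1 0 0 0 ]                                 [ 1 0 0 0 ]
//   [ 1 1 0 0 ]                                 [ 0 1 0 0 ]
//   [ 1 1 1 0 ]                                 [ 0 0 1 0 ]
//   [ 1 1 1 1 ]                                 [ 0 0 0 1 ]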
@@ -170,19 +169,16 @@ ggml_tensor * llm_build_qwen3next::delta_net_unified(ggml_context * ctx,
     cb(v_beta, "v_beta", il);
     cb(g_cumsum, "g_cumsum", il);
 
-    ggml_tensor * gcs_i = ggml_cont_4d(ctx, g_cumsum, n_tokens, 1, H_v,
-                                       n_seqs);  // [chunk_size, 1, n_tokens, n_seqs]
-    ggml_tensor * gcs_j = ggml_cont_4d(ctx, g_cumsum, 1, n_tokens, H_v,
-                                       n_seqs);  // [1, chunk_size, n_tokens, n_seqs]
+    ggml_tensor * gcs_i = ggml_cont_4d(ctx, g_cumsum, n_tokens, 1, H_v, n_seqs); // [chunk_size, 1, n_tokens, n_seqs]
+    ggml_tensor * gcs_j = ggml_cont_4d(ctx, g_cumsum, 1, n_tokens, H_v, n_seqs); // [1, chunk_size, n_tokens, n_seqs]
 
     // Broadcast both tensors to [chunk_size, chunk_size, H_v, n_seqs]
     // ggml_tensor * gcs_i_broadcast =
     //     ggml_repeat_4d(ctx, gcs_i, GGML_DELTA_NET_CHUNK, GGML_DELTA_NET_CHUNK, num_chunks * H_v,
     //                    n_seqs); // [chunk_size, 1, H_v, n_seqs] -> [chunk_size, chunk_size, H_v, n_seqs]
     // Don't need this, this one will get auto-broadcast
     ggml_tensor * gcs_j_broadcast =
-        ggml_repeat_4d(ctx, gcs_j, n_tokens, n_tokens, H_v,
-                       n_seqs); // [1, chunk_size, H_v, n_seqs] -> [chunk_size, chunk_size, H_v, n_seqs]
+        ggml_repeat_4d(ctx, gcs_j, n_tokens, n_tokens, H_v, n_seqs); // [1, chunk_size, H_v, n_seqs] -> [chunk_size, chunk_size, H_v, n_seqs]
 
     ggml_tensor * decay_mask = ggml_sub(ctx, gcs_j_broadcast, gcs_i);
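
For orientation: gcs_i and gcs_j are the same cumulative-gate vector laid out along ggml dim 0 and dim 1 respectively, so after broadcasting, the subtraction at the end produces pairwise gate differences per head and sequence. Roughly, with i indexing dim 0 and j indexing dim 1:

    decay_mask(i, j) = g_cumsum(j) - g_cumsum(i)

How this mask is subsequently exponentiated or masked happens outside this hunk and is not shown here.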

@@ -215,9 +211,9 @@ ggml_tensor * llm_build_qwen3next::delta_net_unified(ggml_context * ctx,
     ggml_tensor * attn_lower = ggml_mul(ctx, attn, causal_mask);
     ggml_tensor * lhs = ggml_sub(ctx, ggml_repeat(ctx, identity, attn_lower), attn_lower);
 
-    ggml_tensor * lin_solve = ggml_solve_tri(ctx, lhs, attn, true, true, false);
-    attn = ggml_mul(ctx, lin_solve, causal_mask);
-    attn = ggml_add(ctx, attn, identity);
+    ggml_tensor * lin_solve = ggml_solve_tri(ctx, lhs, attn, true, true, false);
+    attn = ggml_mul(ctx, lin_solve, causal_mask);
+    attn = ggml_add(ctx, attn, identity);
 
     // value = attn @ v_beta
     v = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_transpose(ctx0, v_beta)), attn);
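
The three re-aligned lines carry the chunkwise delta-rule inversion. Writing A for attn, M for the lower-triangular causal_mask, I for identity and ⊙ for an elementwise product, and assuming ggml_solve_tri(ctx, lhs, rhs, ...) solves the triangular system lhs · X = rhs, the intent is roughly:

    X    = (I - A ⊙ M)^(-1) · A     // triangular solve (forward substitution)
    attn = (X ⊙ M) + I              // re-mask to lower-triangular, unit diagonal

The boolean flags passed to ggml_solve_tri are not documented in this diff, so their exact meaning is left as an assumption.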
@@ -361,11 +357,11 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_attention_layer(ggml_tensor *
     Qcur_full = ggml_reshape_4d(ctx0, Qcur_full, n_embd_head * 2, n_head, n_tokens, 1);
     // Split Q projection into query and gate
     // The split should be along dimension 0 (the feature dimension)
-    struct ggml_tensor * Qcur = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1, Qcur_full->nb[1],
-                                             Qcur_full->nb[2], Qcur_full->nb[3], 0);
+    struct ggml_tensor * Qcur = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1,
+                                             Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], 0);
     struct ggml_tensor * gate =
-        ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1, Qcur_full->nb[1], Qcur_full->nb[2],
-                     Qcur_full->nb[3], n_embd_head * ggml_element_size(Qcur_full));
+        ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1,
+                     Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], n_embd_head * ggml_element_size(Qcur_full));
     cb(Qcur, "Qcur", il);
     cb(gate, "gate", il);
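
A note on what these views select: after the reshape to n_embd_head * 2 features per head, each head's slab is assumed to pack the query features first and the gate features second along dim 0, so the gate view has the same shape and strides as Qcur but starts n_embd_head elements (n_embd_head * ggml_element_size(Qcur_full) bytes) further in:

// assumed per-head layout along dim 0:
//   [ q_0 ... q_{n_embd_head-1} | gate_0 ... gate_{n_embd_head-1} ]
//     ^ Qcur view (offset 0)      ^ gate view (byte offset n_embd_head * element_size)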

@@ -395,11 +391,15 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_attention_layer(ggml_tensor *
     Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
     // Apply RoPE
-    Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
-                         attn_factor, beta_fast, beta_slow);
+    Qcur = ggml_rope_ext(
+        ctx0, Qcur, inp_pos, nullptr,
+        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+        ext_factor, attn_factor, beta_fast, beta_slow);
 
-    Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
-                         attn_factor, beta_fast, beta_slow);
+    Kcur = ggml_rope_ext(
+        ctx0, Kcur, inp_pos, nullptr,
+        n_rot, rope_type, n_ctx_orig, freq_base,
+        freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
 
     cb(Qcur, "Qcur", il);
     cb(Kcur, "Kcur", il);
@@ -408,7 +408,9 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_attention_layer(ggml_tensor *
     // Attention computation
     const float kq_scale =
         hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
-    cur = build_attn(inp_attn, nullptr, nullptr, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
+    cur = build_attn(inp_attn,
+            nullptr, nullptr,
+            Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
     cb(cur, "attn_pregate", il);
 
     struct ggml_tensor * gate_sigmoid = ggml_sigmoid(ctx0, gate);
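
As a quick sanity check on the scale feeding build_attn: when f_attention_scale is left at 0.0f, the default 1/sqrt(n_embd_head) is used. For a purely illustrative head size of 128:

    kq_scale = 1.0f / sqrtf(128.0f) ≈ 0.0884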
@@ -716,8 +718,12 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const llam
     // Add shared experts if present - following Qwen3Next reference implementation
     if (model.layers[il].ffn_up_shexp != nullptr) {
         ggml_tensor * ffn_shexp =
-            build_ffn(cur, model.layers[il].ffn_up_shexp, NULL, NULL, model.layers[il].ffn_gate_shexp, NULL, NULL,
-                      model.layers[il].ffn_down_shexp, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
+            build_ffn(cur,
+                model.layers[il].ffn_up_shexp, NULL, NULL,
+                model.layers[il].ffn_gate_shexp, NULL, NULL,
+                model.layers[il].ffn_down_shexp, NULL, NULL,
+                NULL,
+                LLM_FFN_SILU, LLM_FFN_PAR, il);
         cb(ffn_shexp, "ffn_shexp", il);
 
         // Apply shared expert gating as in the reference implementation
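
For readers unfamiliar with build_ffn's argument blocks: the (up, gate, down) weights combined with LLM_FFN_SILU and LLM_FFN_PAR are assumed to implement the usual SwiGLU-style parallel gating, roughly

    ffn_shexp = ffn_down_shexp · ( silu(ffn_gate_shexp · cur) ⊙ (ffn_up_shexp · cur) )

where ⊙ is an elementwise product and the NULL slots are presumably unused bias/scale tensors. The dense branch in the next hunk is the same call shape with the non-expert weights.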
@@ -747,8 +753,12 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const llam
         }
     } else {
         // Dense FFN branch (not currently used I believe)
-        cur = build_ffn(cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL,
-                        model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
+        cur = build_ffn(cur,
+                model.layers[il].ffn_up, NULL, NULL,
+                model.layers[il].ffn_gate, NULL, NULL,
+                model.layers[il].ffn_down, NULL, NULL,
+                NULL,
+                LLM_FFN_SILU, LLM_FFN_PAR, il);
         cb(cur, "ffn_out", il);
     }
     return cur;
