
Commit 746f9ee

pwilkin and CISC authored
Override SSM_A op for Qwen3 Next to reduce splits (#17587)
* Override SSM_A op for Qwen3 Next to reduce splits

* New tensor mapping SSM_A_NOSCAN for SSM_A used outside of OP_SSM_SCAN context.

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
1 parent 9810cb8 commit 746f9ee
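Background on the fix: llama.cpp's LLM_TENSOR_INFOS table (see the src/llama-arch.cpp diff below) tags every weight with a representative ggml op, and at load time that op is used to probe whether a backend can keep the tensor in its own memory; a failed probe pushes the weight to the CPU and splits the compute graph around it. Qwen3 Next only ever multiplies by ssm_a, so probing it as GGML_OP_SSM_SCAN, an op few backends implement, caused avoidable splits; the new SSM_A_NOSCAN mapping probes GGML_OP_MUL instead. Below is a minimal sketch of the probe idea, assuming a hypothetical probe_weight_op helper; the real check in llama-model.cpp is more complete:

// Minimal sketch of the op-probe idea; probe_weight_op is a hypothetical
// helper, not the actual llama.cpp loader function.
#include "ggml.h"
#include "ggml-backend.h"

// Build a dummy node applying `op` to the weight `w`, then ask the device
// whether it can execute that node. Most backends support GGML_OP_MUL, few
// support GGML_OP_SSM_SCAN, so the op chosen in LLM_TENSOR_INFOS decides
// whether the weight may stay in device memory or forces a graph split.
static bool probe_weight_op(ggml_backend_dev_t dev, ggml_tensor * w, ggml_op op) {
    ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead() * 8,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ true, // metadata only, no tensor data is allocated
    };
    ggml_context * ctx = ggml_init(params);

    ggml_tensor * op_tensor = nullptr;
    switch (op) {
        case GGML_OP_MUL: {
            // dummy activation with the same row size as the weight
            ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, w->ne[0]);
            op_tensor = ggml_mul(ctx, a, w);
        } break;
        default:
            // other ops (e.g. GGML_OP_SSM_SCAN) need their own dummy inputs
            break;
    }

    const bool ok = op_tensor != nullptr && ggml_backend_dev_supports_op(dev, op_tensor);
    ggml_free(ctx);
    return ok;
}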

File tree

src/llama-arch.cpp
src/llama-arch.h
src/llama-model.cpp

3 files changed: +4 additions, -2 deletions


src/llama-arch.cpp

Lines changed: 2 additions & 1 deletion

@@ -855,7 +855,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
     { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
     { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
     { LLM_TENSOR_FFN_UP_SHEXP,   "blk.%d.ffn_up_shexp" },
-    { LLM_TENSOR_SSM_A,          "blk.%d.ssm_a" },
+    { LLM_TENSOR_SSM_A_NOSCAN,   "blk.%d.ssm_a" },
     { LLM_TENSOR_SSM_CONV1D,     "blk.%d.ssm_conv1d" },
     { LLM_TENSOR_SSM_DT,         "blk.%d.ssm_dt" },
     { LLM_TENSOR_SSM_BETA_ALPHA, "blk.%d.ssm_ba" },
@@ -2639,6 +2639,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_FFN_ACT,      {LLM_TENSOR_LAYER_REPEATING, GGML_OP_DIV}},
     {LLM_TENSOR_SSM_CONV1D,   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}},
     {LLM_TENSOR_SSM_A,        {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_SCAN}},
+    {LLM_TENSOR_SSM_A_NOSCAN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // a version of SSM_A used for MUL instead of SSM_SCAN
     {LLM_TENSOR_SSM_DT_NORM,  {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_SSM_B_NORM,   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_SSM_C_NORM,   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
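Note that the removed SSM_A row and the new SSM_A_NOSCAN row format to the identical on-disk name, "blk.%d.ssm_a", so existing Qwen3 Next GGUF files load unchanged; only the op metadata used for the backend-support probe differs. A tiny standalone illustration of the name formatting, using std::snprintf in place of the internal tn()/LLM_TN helper:

// Both LLM_TENSOR_SSM_A and LLM_TENSOR_SSM_A_NOSCAN share this format string,
// so the change is invisible at the GGUF level (sketch, not llama.cpp code).
#include <cstdio>

int main() {
    const char * fmt = "blk.%d.ssm_a";
    char name[64];
    std::snprintf(name, sizeof(name), fmt, 3); // layer index i = 3
    std::printf("%s\n", name);                 // prints: blk.3.ssm_a
    return 0;
}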

src/llama-arch.h

Lines changed: 1 addition & 0 deletions

@@ -379,6 +379,7 @@ enum llm_tensor {
     LLM_TENSOR_SSM_DT,
     LLM_TENSOR_SSM_DT_NORM,
     LLM_TENSOR_SSM_A,
+    LLM_TENSOR_SSM_A_NOSCAN, // qwen3next special case with MUL instead of SSM_SCAN
     LLM_TENSOR_SSM_B_NORM,
     LLM_TENSOR_SSM_C_NORM,
     LLM_TENSOR_SSM_D,

src/llama-model.cpp

Lines changed: 1 addition & 1 deletion

@@ -6526,7 +6526,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
     layer.ssm_in     = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), { n_embd, qkvz_dim }, 0);
     layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0);
     layer.ssm_dt     = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0);
-    layer.ssm_a      = create_tensor(tn(LLM_TENSOR_SSM_A, i), { hparams.ssm_dt_rank }, 0);
+    layer.ssm_a      = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0);
     layer.ssm_beta_alpha = create_tensor(tn(LLM_TENSOR_SSM_BETA_ALPHA, "weight", i), { n_embd, ba_dim }, 0);
     layer.ssm_norm   = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0);
     layer.ssm_out    = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0);
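For context on why MUL is the right probe here: in the Qwen3 Next layer, ssm_a is consumed only element-wise (scaling a per-token gate that is then exponentiated into a decay), never as an input to ggml_ssm_scan. A hedged sketch of that usage pattern follows; the names, shapes, and exact formula are assumptions for illustration, not the model's actual graph-building code:

// Illustrative only -- not the real Qwen3 Next graph code in llama.cpp.
// The point: ssm_a feeds element-wise ops (GGML_OP_MUL / unary exp), so
// GGML_OP_MUL support is the right placement question for this weight.
#include "ggml.h"

static ggml_tensor * build_decay(ggml_context * ctx, ggml_tensor * ssm_a, ggml_tensor * dt) {
    // decay = exp(dt * ssm_a): pure element-wise use, no ggml_ssm_scan node
    return ggml_exp(ctx, ggml_mul(ctx, dt, ssm_a));
}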
