30 changes: 30 additions & 0 deletions convert_hf_to_gguf.py
@@ -4183,6 +4183,36 @@ def set_vocab(self):
super().set_vocab()


@ModelBase.register("Qwen3NextForCausalLM")
class Qwen3NextModel(Qwen3MoeModel):
model_arch = gguf.MODEL_ARCH.QWEN3NEXT

def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["linear_conv_kernel_dim"]))
self.gguf_writer.add_ssm_state_size(self.find_hparam(["linear_key_head_dim"]))
self.gguf_writer.add_ssm_group_count(self.find_hparam(["linear_num_key_heads"]))
self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["linear_num_value_heads"]))
self.gguf_writer.add_ssm_inner_size(self.find_hparam(["linear_value_head_dim"]) * self.find_hparam(["linear_num_value_heads"]))
if (rope_dim := self.hparams.get("head_dim")) is None:
rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.25)))

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if name.startswith("mtp"):
return [] # ignore MTP layers for now
if name.endswith(".A_log"):
data_torch = -torch.exp(data_torch)
elif name.endswith(".dt_bias"):
name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
elif "conv1d" in name:
data_torch = data_torch.squeeze()
elif name.endswith("norm.weight") and not name.endswith("linear_attn.norm.weight"):
data_torch = data_torch + 1

yield from Qwen2MoeModel.modify_tensors(self, data_torch, name, bid)


@ModelBase.register("Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration")
class Qwen3VLVisionModel(MmprojModel):
def __init__(self, *args, **kwargs):
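For reference, a minimal standalone sketch of the per-tensor handling in Qwen3NextModel.modify_tensors above, assuming plain PyTorch tensors (illustrative only, not part of the diff):

import torch
from torch import Tensor

def qwen3next_transform(name: str, data: Tensor) -> tuple[str, Tensor] | None:
    # Mirrors the branches above; returns None for tensors that are dropped.
    if name.startswith("mtp"):
        return None                          # MTP layers are ignored for now
    if name.endswith(".A_log"):
        data = -torch.exp(data)              # store A = -exp(A_log)
    elif name.endswith(".dt_bias"):
        name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
    elif "conv1d" in name:
        data = data.squeeze()                # drop the singleton conv-kernel dim
    elif name.endswith("norm.weight") and not name.endswith("linear_attn.norm.weight"):
        data = data + 1                      # assumption: norm weights are stored zero-centered upstream
    return name, data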
@@ -4,6 +4,11 @@ set -e

# First try command line argument, then environment variable, then file
CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}"
MODEL_TESTING_PROMPT="${2:-"$MODEL_TESTING_PROMPT"}"

if [ -z "$MODEL_TESTING_PROMPT"]; then
MODEL_TESTING_PROMPT="Hello, my name is"
fi

# Final check if we have a model path
if [ -z "$CONVERTED_MODEL" ]; then
@@ -14,7 +19,8 @@ if [ -z "$CONVERTED_MODEL" ]; then
fi

echo $CONVERTED_MODEL
echo $MODEL_TESTING_PROMPT

cmake --build ../../build --target llama-logits -j8

../../build/bin/llama-logits -m "$CONVERTED_MODEL" "Hello, my name is"
../../build/bin/llama-logits -m "$CONVERTED_MODEL" "$MODEL_TESTING_PROMPT"
8 changes: 6 additions & 2 deletions examples/model-conversion/scripts/causal/run-org-model.py
@@ -184,8 +184,12 @@ def fn(_m, input, output):
# of using AutoModelForCausalLM.
print(f"Model class: {model.__class__.__name__}")

prompt = "Hello, my name is"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
device = next(model.parameters()).device
if os.getenv("MODEL_TESTING_PROMPT"):
prompt = os.getenv("MODEL_TESTING_PROMPT")
else:
prompt = "Hello, my name is"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

print(f"Input tokens: {input_ids}")
print(f"Input text: {repr(prompt)}")
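A self-contained sketch of the same prompt-override and device handling, runnable outside the script (the model id here is a placeholder, not something specified by this PR):

import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = os.getenv("MODEL_PATH", "gpt2")  # placeholder; any causal LM works the same way
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Prompt can be overridden via MODEL_TESTING_PROMPT, matching the diff above.
prompt = os.getenv("MODEL_TESTING_PROMPT") or "Hello, my name is"

# Keep inputs on the same device as the model parameters.
device = next(model.parameters()).device
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

with torch.no_grad():
    logits = model(input_ids).logits
print(f"Input text: {prompt!r}, logits shape: {tuple(logits.shape)}")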
2 changes: 1 addition & 1 deletion ggml/src/ggml-backend.cpp
@@ -658,7 +658,7 @@ static bool ggml_is_view_op(enum ggml_op op) {
#endif

#ifndef GGML_SCHED_MAX_SPLIT_INPUTS
#define GGML_SCHED_MAX_SPLIT_INPUTS 30
#define GGML_SCHED_MAX_SPLIT_INPUTS 60 // Qwen3 Next
#endif

#ifndef GGML_SCHED_MAX_COPIES
1 change: 0 additions & 1 deletion ggml/src/ggml-cpu/ops.cpp
@@ -9698,7 +9698,6 @@ static void ggml_compute_forward_solve_tri_f32(const struct ggml_compute_params
for (int64_t t = 0; t < i00; ++t) {
sum += A_batch[i00 * n + t] * X_batch[t * k + i01];
}

const float diag = A_batch[i00 * n + i00];
GGML_ASSERT(diag != 0.0f && "Zero diagonal in triangular matrix");
X_batch[i00 * k + i01] = (B_batch[i00 * k + i01] - sum) / diag;
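The loop above is plain forward substitution on a lower-triangular system; an equivalent NumPy sketch (illustration only, not ggml code):

import numpy as np

def solve_tri_forward(A: np.ndarray, B: np.ndarray) -> np.ndarray:
    # Solve A @ X = B for lower-triangular A, one row at a time.
    n, k = B.shape
    X = np.zeros((n, k), dtype=np.float32)
    for i in range(n):
        s = A[i, :i] @ X[:i, :]              # sum_{t < i} A[i, t] * X[t, :]
        diag = A[i, i]
        assert diag != 0.0, "Zero diagonal in triangular matrix"
        X[i, :] = (B[i, :] - s) / diag
    return X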
33 changes: 33 additions & 0 deletions gguf-py/gguf/constants.py
@@ -352,6 +352,7 @@ class MODEL_ARCH(IntEnum):
QWEN2VL = auto()
QWEN3 = auto()
QWEN3MOE = auto()
QWEN3NEXT = auto()
QWEN3VL = auto()
QWEN3VLMOE = auto()
PHI2 = auto()
@@ -516,6 +517,7 @@ class MODEL_TENSOR(IntEnum):
SSM_D = auto()
SSM_NORM = auto()
SSM_OUT = auto()
SSM_BETA_ALPHA = auto() # qwen3next
TIME_MIX_W0 = auto()
TIME_MIX_W1 = auto()
TIME_MIX_W2 = auto()
@@ -721,6 +723,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.QWEN2VL: "qwen2vl",
MODEL_ARCH.QWEN3: "qwen3",
MODEL_ARCH.QWEN3MOE: "qwen3moe",
MODEL_ARCH.QWEN3NEXT: "qwen3next",
MODEL_ARCH.QWEN3VL: "qwen3vl",
MODEL_ARCH.QWEN3VLMOE: "qwen3vlmoe",
MODEL_ARCH.PHI2: "phi2",
@@ -884,6 +887,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm",
MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
MODEL_TENSOR.SSM_BETA_ALPHA: "blk.{bid}.ssm_ba",
MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0",
MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1",
MODEL_TENSOR.TIME_MIX_W2: "blk.{bid}.time_mix_w2",
@@ -1553,6 +1557,35 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
],
MODEL_ARCH.QWEN3NEXT: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_POST_NORM,
MODEL_TENSOR.ATTN_GATE,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_INP_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.SSM_A,
MODEL_TENSOR.SSM_CONV1D,
MODEL_TENSOR.SSM_DT,
MODEL_TENSOR.SSM_NORM,
MODEL_TENSOR.SSM_IN,
MODEL_TENSOR.SSM_BETA_ALPHA,
MODEL_TENSOR.SSM_OUT,
],
MODEL_ARCH.QWEN3VL: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
22 changes: 16 additions & 6 deletions gguf-py/gguf/tensor_mapping.py
@@ -672,17 +672,19 @@ class TensorNameMap:
),

MODEL_TENSOR.SSM_IN: (
"model.layers.{bid}.in_proj", # mamba-hf
"backbone.layers.{bid}.mixer.in_proj", # mamba
"model.layers.{bid}.mamba.in_proj", # jamba falcon-h1 granite-hybrid
"model.layers.layers.{bid}.mixer.in_proj", # plamo2
"model.layers.{bid}.in_proj", # mamba-hf
"backbone.layers.{bid}.mixer.in_proj", # mamba
"model.layers.{bid}.mamba.in_proj", # jamba falcon-h1 granite-hybrid
"model.layers.layers.{bid}.mixer.in_proj", # plamo2
"model.layers.{bid}.linear_attn.in_proj_qkvz", # qwen3next
),

MODEL_TENSOR.SSM_CONV1D: (
"model.layers.{bid}.conv1d", # mamba-hf
"backbone.layers.{bid}.mixer.conv1d", # mamba
"model.layers.{bid}.mamba.conv1d", # jamba falcon-h1 granite-hybrid
"model.layers.layers.{bid}.mixer.conv1d", # plamo2
"model.layers.{bid}.linear_attn.conv1d", # qwen3next
),

MODEL_TENSOR.SSM_X: (
@@ -697,6 +699,7 @@ class TensorNameMap:
"backbone.layers.{bid}.mixer.dt_proj", # mamba
"model.layers.{bid}.mamba.dt_proj", # jamba falcon-h1 granite-hybrid
"model.layers.layers.{bid}.mixer.dt_proj", # plamo2
"model.layers.{bid}.linear_attn.dt_proj", # qwen3next
),

MODEL_TENSOR.SSM_DT_NORM: (
@@ -709,6 +712,7 @@ class TensorNameMap:
"backbone.layers.{bid}.mixer.A_log", # mamba
"model.layers.{bid}.mamba.A_log", # jamba falcon-h1 granite-hybrid
"model.layers.layers.{bid}.mixer.A_log", # plamo2
"model.layers.{bid}.linear_attn.A_log", # qwen3next
),

MODEL_TENSOR.SSM_B_NORM: (
@@ -731,17 +735,23 @@ class TensorNameMap:
),

MODEL_TENSOR.SSM_NORM: (
"model.layers.{bid}.mamba.norm", # falcon-h1 granite-hybrid
"backbone.layers.{bid}.mixer.norm", # mamba2
"model.layers.{bid}.mamba.norm", # falcon-h1 granite-hybrid
"model.layers.{bid}.linear_attn.norm", # qwen3next
"backbone.layers.{bid}.mixer.norm", # mamba2
),

MODEL_TENSOR.SSM_OUT: (
"model.layers.{bid}.out_proj", # mamba-hf
"backbone.layers.{bid}.mixer.out_proj", # mamba
"model.layers.{bid}.mamba.out_proj", # jamba falcon-h1 granite-hybrid
"model.layers.{bid}.linear_attn.out_proj", # qwen3next
"model.layers.layers.{bid}.mixer.out_proj", # plamo2
),

MODEL_TENSOR.SSM_BETA_ALPHA: (
"model.layers.{bid}.linear_attn.in_proj_ba", # qwen3next
),

MODEL_TENSOR.TIME_MIX_W0: (
"model.layers.{bid}.attention.w0", # rwkv7
),
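As an illustration of what the new qwen3next entries map to (a sketch only — the real lookup is TensorNameMap in gguf-py, which also handles suffixes such as ".weight"):

# Hypothetical helper, not part of gguf-py: HF checkpoint names for block `bid`
# rewritten to the GGUF tensor names registered in constants.py above.
QWEN3NEXT_SSM_MAP = {
    "model.layers.{bid}.linear_attn.in_proj_qkvz": "blk.{bid}.ssm_in",
    "model.layers.{bid}.linear_attn.in_proj_ba":   "blk.{bid}.ssm_ba",
    "model.layers.{bid}.linear_attn.conv1d":       "blk.{bid}.ssm_conv1d",
    "model.layers.{bid}.linear_attn.dt_proj":      "blk.{bid}.ssm_dt",
    "model.layers.{bid}.linear_attn.A_log":        "blk.{bid}.ssm_a",
    "model.layers.{bid}.linear_attn.norm":         "blk.{bid}.ssm_norm",
    "model.layers.{bid}.linear_attn.out_proj":     "blk.{bid}.ssm_out",
}

def map_qwen3next_name(hf_name: str, bid: int) -> str | None:
    for src, dst in QWEN3NEXT_SSM_MAP.items():
        if hf_name == src.format(bid=bid):
            return dst.format(bid=bid)
    return None

assert map_qwen3next_name("model.layers.5.linear_attn.in_proj_ba", 5) == "blk.5.ssm_ba"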
1 change: 1 addition & 0 deletions src/CMakeLists.txt
@@ -114,6 +114,7 @@ add_library(llama
models/qwen3vl.cpp
models/qwen3vl-moe.cpp
models/qwen3moe.cpp
models/qwen3next.cpp
models/refact.cpp
models/rwkv6-base.cpp
models/rwkv6.cpp
35 changes: 35 additions & 0 deletions src/llama-arch.cpp
@@ -32,6 +32,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_QWEN2VL, "qwen2vl" },
{ LLM_ARCH_QWEN3, "qwen3" },
{ LLM_ARCH_QWEN3MOE, "qwen3moe" },
{ LLM_ARCH_QWEN3NEXT, "qwen3next" },
{ LLM_ARCH_QWEN3VL, "qwen3vl" },
{ LLM_ARCH_QWEN3VLMOE, "qwen3vlmoe" },
{ LLM_ARCH_PHI2, "phi2" },
@@ -816,6 +817,38 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_QWEN3NEXT,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
{ LLM_TENSOR_SSM_BETA_ALPHA, "blk.%d.ssm_ba" },
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
{ LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
},
},
{
LLM_ARCH_QWEN3VL,
{
@@ -2513,6 +2546,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_SSM_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_SSM_DT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_SSM_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_SSM_BETA_ALPHA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_A1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
@@ -2711,6 +2745,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
case LLM_ARCH_LFM2:
case LLM_ARCH_LFM2MOE:
case LLM_ARCH_NEMOTRON_H:
case LLM_ARCH_QWEN3NEXT:
return true;
default:
return false;
2 changes: 2 additions & 0 deletions src/llama-arch.h
@@ -36,6 +36,7 @@ enum llm_arch {
LLM_ARCH_QWEN2VL,
LLM_ARCH_QWEN3,
LLM_ARCH_QWEN3MOE,
LLM_ARCH_QWEN3NEXT,
LLM_ARCH_QWEN3VL,
LLM_ARCH_QWEN3VLMOE,
LLM_ARCH_PHI2,
@@ -368,6 +369,7 @@ enum llm_tensor {
LLM_TENSOR_SSM_D,
LLM_TENSOR_SSM_NORM,
LLM_TENSOR_SSM_OUT,
LLM_TENSOR_SSM_BETA_ALPHA, // qwen3next
LLM_TENSOR_TIME_MIX_W0,
LLM_TENSOR_TIME_MIX_W1,
LLM_TENSOR_TIME_MIX_W2,
4 changes: 4 additions & 0 deletions src/llama-context.cpp
@@ -1,5 +1,6 @@
#include "llama-context.h"

#include "llama-arch.h"
#include "llama-impl.h"
#include "llama-batch.h"
#include "llama-io.h"
@@ -1386,6 +1387,9 @@ void llama_context::output_reorder() {
//

uint32_t llama_context::graph_max_nodes() const {
if (model.arch == LLM_ARCH_QWEN3NEXT) {
return std::max<uint32_t>(8192u, 32u*model.n_tensors());
}
return std::max<uint32_t>(1024u, 8u*model.n_tensors());
}

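A rough illustration of the graph node budget change above (the tensor count is hypothetical; only the relative scaling matters):

# Hypothetical tensor count for a Qwen3 Next model; values are illustrative.
n_tensors = 1500
default_nodes   = max(1024, 8  * n_tensors)   # 12000 nodes for other archs
qwen3next_nodes = max(8192, 32 * n_tensors)   # 48000 nodes for LLM_ARCH_QWEN3NEXT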
2 changes: 1 addition & 1 deletion src/llama-hparams.h
@@ -6,7 +6,7 @@

// bump if necessary
#define LLAMA_MAX_LAYERS 512
#define LLAMA_MAX_EXPERTS 384 // Kimi-K2
#define LLAMA_MAX_EXPERTS 512 // Qwen3 Next

enum llama_expert_gating_func_type {
LLAMA_EXPERT_GATING_FUNC_TYPE_NONE = 0,