qwen3/qwen3_jax/chkpt_utils.py (6 additions, 0 deletions)

@@ -207,6 +207,12 @@ def convert_model_or_layer(
     }
     new_params = {k: None for k in layer_params.keys()}

+    # some qwen3 checkpoints store both 'embed_tokens' and 'lm_head' even if the embeddings are tied
+    # in this case, we check that the weights are identical and delete 'lm_head'
+    if cfg.tie_embed and 'lm_head.weight' in torch_params:
+        torch.testing.assert_close(torch_params['lm_head.weight'], torch_params['model.embed_tokens.weight'])
+        del torch_params['lm_head.weight']
+
     def convert_weight_thread(tkey, tweight):
         with jax.default_device(device):
             jweight = convert_weight(tkey, _map_weight(tkey, tweight, custom_transform_map=custom_transform_map), cfg)
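For context, a minimal standalone sketch of the dedup logic above; the toy state dict and shapes are invented for illustration, not taken from a real checkpoint:

    import torch

    # hypothetical checkpoint where 'lm_head.weight' duplicates the embedding matrix
    embed = torch.randn(8, 4)
    torch_params = {
        "model.embed_tokens.weight": embed,
        "lm_head.weight": embed.clone(),
    }

    tie_embed = True  # stand-in for cfg.tie_embed
    if tie_embed and "lm_head.weight" in torch_params:
        # raises AssertionError if the two tensors differ
        torch.testing.assert_close(torch_params["lm_head.weight"], torch_params["model.embed_tokens.weight"])
        del torch_params["lm_head.weight"]

    print(sorted(torch_params))  # ['model.embed_tokens.weight']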
qwen3/qwen3_jax/model.py (7 additions, 4 deletions)

@@ -20,7 +20,7 @@
 from pathlib import Path
 import math
 from functools import partial
-from typing import Callable, Any
+from typing import Callable, Any, Optional

 import jax
 import jax.numpy as jnp

@@ -134,6 +134,7 @@ class Config:
     head_dim: int
     vocab_size: int
     max_seq_len: int
+    tie_embed: bool
     # Attention
     causal: bool
     # MoE

@@ -176,6 +177,7 @@ def hf_to_jax_config(hf_config: Any | dict[str, Any]) -> "Config":
         num_layers=_get(hf_config, "num_hidden_layers"),
         head_dim=_get(hf_config, "head_dim"),
         vocab_size=_get(hf_config, "vocab_size"),
+        tie_embed=_get(hf_config, "tie_word_embeddings"),
         norm_eps=_get(hf_config, "rms_norm_eps"),
         moe_experts_per_tok=_get(hf_config, "num_experts_per_tok"),
         moe_num_experts=_get(hf_config, "num_experts"),
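For reference, `tie_word_embeddings` is read straight from the HF config; a representative fragment of what hf_to_jax_config consumes (field values illustrative only):

    hf_config = {
        "num_hidden_layers": 28,
        "head_dim": 128,
        "vocab_size": 151936,
        "tie_word_embeddings": True,  # e.g. True for the smaller Qwen3 checkpoints
        "rms_norm_eps": 1e-6,
        # ... remaining fields omitted
    }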

@@ -497,7 +499,7 @@ class Weights(_Init):
     layers: list[Layer]
     embedding: jax.Array | ArrayInfo
     gamma_final: jax.Array | ArrayInfo
-    lm_head: jax.Array | ArrayInfo
+    lm_head: Optional[jax.Array | ArrayInfo]

     @classmethod
     def abstract(cls, cfg: Config):

@@ -507,7 +509,7 @@ def abstract(cls, cfg: Config):
             layers=layers,
             embedding=ArrayInfo((cfg.vocab_size, cfg.embed), cfg.dtype, ("vocab_in", "vocab_in"), init(0, 1)),
             gamma_final=ArrayInfo((cfg.embed,), cfg.dtype, ("act_embed",), jax.nn.initializers.constant(1.0)),
-            lm_head=ArrayInfo((cfg.embed, cfg.vocab_size), cfg.dtype, ("vocab_in", "vocab_out"), init(1, 0)),
+            lm_head=None if cfg.tie_embed else ArrayInfo((cfg.embed, cfg.vocab_size), cfg.dtype, ("vocab_in", "vocab_out"), init(1, 0)),
         )
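A quick sanity sketch of the branch above, assuming some Config instance cfg with tie_embed=True is already in hand:

    weights_spec = Weights.abstract(cfg)
    assert weights_spec.lm_head is None  # no separate output head is allocated when tied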


@@ -1082,7 +1084,8 @@ def forward(x: jax.Array, segment_ids: jax.Array, weights: Weights, cfg: Config,
         all_cache_updates.append(cache_updates)

     x = rms_norm(x, weights.gamma_final, cfg.norm_eps)  # Final layer norm.
-    logits = einsum("btd,dv->btv", x, weights.lm_head)  # Project to vocabulary size
+    head_weights = weights.embedding.T if cfg.tie_embed else weights.lm_head
+    logits = einsum("btd,dv->btv", x, head_weights)  # Project to vocabulary size
     if is_type(cache, KVCache):
         cache.k, cache.v = [[z[i] for z in all_cache_updates] for i in range(2)]
         additional_tokens = jnp.max(_length_minus_right_padding(segment_ids))
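The forward change reuses the transposed (vocab, embed) embedding matrix as the (embed, vocab) output projection when embeddings are tied. A minimal numerical sketch of just that projection, with random stand-ins for weights.embedding and the post-rms_norm activations:

    import jax
    import jax.numpy as jnp

    batch, time, embed_dim, vocab = 2, 3, 8, 16
    k1, k2 = jax.random.split(jax.random.PRNGKey(0))
    embedding = jax.random.normal(k1, (vocab, embed_dim))  # stand-in for weights.embedding
    x = jax.random.normal(k2, (batch, time, embed_dim))    # final hidden states

    # tied head: same "btd,dv->btv" contraction as the diff, weights shared with the input embedding
    logits = jnp.einsum("btd,dv->btv", x, embedding.T)
    assert logits.shape == (batch, time, vocab)

Besides matching checkpoints that ship only 'embed_tokens', tying avoids allocating a second embed-by-vocab matrix, which is a sizeable saving at vocabulary sizes like Qwen3's.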