diff --git a/qwen3/qwen3_jax/chkpt_utils.py b/qwen3/qwen3_jax/chkpt_utils.py
index 3d2e658..7fcaecb 100644
--- a/qwen3/qwen3_jax/chkpt_utils.py
+++ b/qwen3/qwen3_jax/chkpt_utils.py
@@ -207,6 +207,12 @@ def convert_model_or_layer(
   }
   new_params = {k: None for k in layer_params.keys()}
 
+  # some qwen3 checkpoints store both 'embed_tokens' and 'lm_head' even if the embeddings are tied
+  # in this case, we check that the weights are identical and delete 'lm_head'
+  if cfg.tie_embed and 'lm_head.weight' in torch_params:
+    torch.testing.assert_close(torch_params['lm_head.weight'], torch_params['model.embed_tokens.weight'])
+    del torch_params['lm_head.weight']
+
   def convert_weight_thread(tkey, tweight):
     with jax.default_device(device):
       jweight = convert_weight(tkey, _map_weight(tkey, tweight, custom_transform_map=custom_transform_map), cfg)
diff --git a/qwen3/qwen3_jax/model.py b/qwen3/qwen3_jax/model.py
index 742dc44..0aa500a 100644
--- a/qwen3/qwen3_jax/model.py
+++ b/qwen3/qwen3_jax/model.py
@@ -20,7 +20,7 @@ from pathlib import Path
 import math
 from functools import partial
 
-from typing import Callable, Any
+from typing import Callable, Any, Optional
 
 import jax
 import jax.numpy as jnp
@@ -134,6 +134,7 @@ class Config:
   head_dim: int
   vocab_size: int
   max_seq_len: int
+  tie_embed: bool
   # Attention
   causal: bool
   # MoE
@@ -176,6 +177,7 @@ def hf_to_jax_config(hf_config: Any | dict[str, Any]) -> "Config":
       num_layers=_get(hf_config, "num_hidden_layers"),
       head_dim=_get(hf_config, "head_dim"),
       vocab_size=_get(hf_config, "vocab_size"),
+      tie_embed=_get(hf_config, "tie_word_embeddings"),
       norm_eps=_get(hf_config, "rms_norm_eps"),
       moe_experts_per_tok=_get(hf_config, "num_experts_per_tok"),
       moe_num_experts=_get(hf_config, "num_experts"),
@@ -497,7 +499,7 @@ class Weights(_Init):
   layers: list[Layer]
   embedding: jax.Array | ArrayInfo
   gamma_final: jax.Array | ArrayInfo
-  lm_head: jax.Array | ArrayInfo
+  lm_head: Optional[jax.Array | ArrayInfo]
 
   @classmethod
   def abstract(cls, cfg: Config):
@@ -507,7 +509,7 @@ def abstract(cls, cfg: Config):
         layers=layers,
         embedding=ArrayInfo((cfg.vocab_size, cfg.embed), cfg.dtype, ("vocab_in", "vocab_in"), init(0, 1)),
         gamma_final=ArrayInfo((cfg.embed,), cfg.dtype, ("act_embed",), jax.nn.initializers.constant(1.0)),
-        lm_head=ArrayInfo((cfg.embed, cfg.vocab_size), cfg.dtype, ("vocab_in", "vocab_out"), init(1, 0)),
+        lm_head=None if cfg.tie_embed else ArrayInfo((cfg.embed, cfg.vocab_size), cfg.dtype, ("vocab_in", "vocab_out"), init(1, 0)),
     )
 
 
@@ -1082,7 +1084,8 @@ def forward(x: jax.Array, segment_ids: jax.Array, weights: Weights, cfg: Config,
     all_cache_updates.append(cache_updates)
 
   x = rms_norm(x, weights.gamma_final, cfg.norm_eps)  # Final layer norm.
-  logits = einsum("btd,dv->btv", x, weights.lm_head)  # Project to vocabulary size
+  head_weights = weights.embedding.T if cfg.tie_embed else weights.lm_head
+  logits = einsum("btd,dv->btv", x, head_weights)  # Project to vocabulary size
   if is_type(cache, KVCache):
     cache.k, cache.v = [[z[i] for z in all_cache_updates] for i in range(2)]
     additional_tokens = jnp.max(_length_minus_right_padding(segment_ids))
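
For reference, a minimal standalone sketch (not part of the patch; the toy shapes and the local tie_embed/lm_head/head_weights names below are illustrative only) of the tied-head behavior the forward() change relies on: when tie_embed is set, no separate lm_head parameter exists and the output projection reuses the embedding matrix transposed to (embed, vocab).

import jax.numpy as jnp

# Toy dimensions for illustration only.
vocab, embed, batch, time = 8, 4, 2, 3
tie_embed = True

embedding = jnp.linspace(0.0, 1.0, vocab * embed).reshape(vocab, embed)  # (vocab, embed)
lm_head = None if tie_embed else jnp.zeros((embed, vocab))               # stored only when untied
x = jnp.ones((batch, time, embed))                                       # final hidden states

# Same selection as the patched forward(): reuse the embedding (transposed) when tied.
head_weights = embedding.T if tie_embed else lm_head
logits = jnp.einsum("btd,dv->btv", x, head_weights)
assert logits.shape == (batch, time, vocab)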