Skip to content

Commit 364a9f2

Browse files
bzgoogle
authored and committed
Fix dtype bug of weight_loading. It occurs not only for sparse matmul, but also for dense matmul
1 parent 1c36c7d commit 364a9f2

File tree

2 files changed

+23
-19
lines changed

2 files changed

+23
-19
lines changed

tpu_commons/models/jax/common/moe/deepseek_moe.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,8 @@ class SparseMoE(MoE):
136136
# TODO: determine if we get it from external or extrat it in MoE class
137137
is_batch_sharded_by_expert: True if batch is sharded over 'expert' dim.
138138
"""
139-
def_sharding: Sharding
140-
fed_sharding: Sharding
139+
edf_sharding: Sharding
140+
efd_sharding: Sharding
141141
num_experts_per_tok: int
142142
#TODO: tile size is (tile_batch_seq, tile_activation_dim, tile_weight_dim,) from MaxText
143143
tile_size: tuple[int, int, int] = (128, 64, 128)
@@ -155,24 +155,24 @@ def __post_init__(self, rngs: nnx.Rngs):
155155
shape_up = (self.num_local_experts, D, F)
156156
shape_down = (self.num_local_experts, F, D)
157157

158-
self.kernel_gating_DEF = create_param(rngs,
158+
self.kernel_gating_EDF = create_param(rngs,
159159
shape=shape_gating,
160160
dtype=self.dtype,
161-
sharding=self.def_sharding,
161+
sharding=self.edf_sharding,
162162
random_init=self.random_init)
163-
self.kernel_up_proj_DEF = create_param(rngs,
163+
self.kernel_up_proj_EDF = create_param(rngs,
164164
shape=shape_up,
165165
dtype=self.dtype,
166-
sharding=self.def_sharding,
166+
sharding=self.edf_sharding,
167167
random_init=self.random_init)
168-
self.kernel_down_proj_FED = create_param(rngs,
168+
self.kernel_down_proj_EFD = create_param(rngs,
169169
shape=shape_down,
170170
dtype=self.dtype,
171-
sharding=self.fed_sharding,
171+
sharding=self.efd_sharding,
172172
random_init=self.random_init)
173173

174174
# Derive the expert sharding
175-
self.expert_axis_name = self.def_sharding[0]
175+
self.expert_axis_name = self.edf_sharding[0]
176176
if self.expert_axis_name is None:
177177
self.num_expert_parallelism = 1
178178
else:
@@ -597,10 +597,10 @@ def __call__(self, x_TD: Float):
597597
PartitionSpec(*self.activation_ffw_td), # Sharded x_TD
598598
PartitionSpec(), # Replicated router_weights_TX
599599
PartitionSpec(), # Replicated selected_experts_TX
600-
PartitionSpec(*self.def_sharding), # Sharded gating kernel
601-
PartitionSpec(*self.def_sharding), # Sharded up-projection kernel
600+
PartitionSpec(*self.edf_sharding), # Sharded gating kernel
601+
PartitionSpec(*self.edf_sharding), # Sharded up-projection kernel
602602
PartitionSpec(
603-
*self.fed_sharding), # Sharded down-projection kernel
603+
*self.efd_sharding), # Sharded down-projection kernel
604604
)
605605
out_specs = PartitionSpec(*self.activation_ffw_td)
606606

@@ -616,7 +616,7 @@ def __call__(self, x_TD: Float):
616616
x_TD,
617617
router_weights_TX,
618618
selected_experts_TX,
619-
self.kernel_gating_DEF.value,
620-
self.kernel_up_proj_DEF.value,
621-
self.kernel_down_proj_FED.value,
619+
self.kernel_gating_EDF.value,
620+
self.kernel_up_proj_EDF.value,
621+
self.kernel_down_proj_EFD.value,
622622
)

tpu_inference/models/jax/deepseek_v3.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -218,8 +218,8 @@ def _create_mla() -> MLA:
218218
random_init=self.random_init,
219219
activation_ffw_td=('data', 'model'),
220220
activation_ffw_ted=('data', None, 'model'),
221-
def_sharding=(None , 'model', 'expert'),
222-
fed_sharding=(None , 'expert', 'model'),
221+
edf_sharding=(None , 'model', 'expert'),
222+
efd_sharding=(None , 'expert', 'model'),
223223
router=router) if is_moe_layer else DenseFFW(
224224
dtype=dtype,
225225
hidden_act=hidden_act,
@@ -363,7 +363,10 @@ def __init__(self, vllm_config: VllmConfig, num_layers, hidden_size,
363363
"is_verbose", None) is not None
364364
self.num_routed_experts = num_local_experts
365365
self.model_dtype = model_dtype
366+
<<<<<<< HEAD:tpu_inference/models/jax/deepseek_v3.py
366367

368+
=======
369+
>>>>>>> 641cb6d4 (Fix dtype bug of weight_loading. It not only occurs for sparsematmul, but densematmul):tpu_commons/models/jax/deepseek_v3.py
367370
self._transpose_map = {
368371
# dense mlp
369372
r"mlp\.down_proj": (1, 0),
@@ -827,9 +830,10 @@ def load_weights(self, model_for_loading: nnx.Module):
827830

828831
def weights_dequant_cpu(x: torch.Tensor,
829832
s: torch.Tensor,
830-
output_dtype: jnp.dtype,
833+
output_dtype: torch.dtype,
831834
block_size: int = 128) -> torch.Tensor:
832835
assert x.dim() == 2 and s.dim() == 2, "Both x and s must be 2D tensors"
836+
torch_output_type = DTYPE_VIEW_MAP.get(jnp.dtype(output_dtype))
833837
M, N = x.shape
834838

835839
x = x.to(torch.float32)
@@ -863,4 +867,4 @@ def weights_dequant_cpu(x: torch.Tensor,
863867
scale = s[M // block_size, j // block_size]
864868
y[M_main:M, j:j + block_size] = block * scale
865869

866-
return y.to(j2t_dtype(jnp.dtype(output_dtype)))
870+
return y.to(torch_output_type)

0 commit comments

Comments (0)