
Commit ca8596d

Author: bzgoogle
Commit message: bug fix after rebase
1 parent 364a9f2 commit ca8596d

2 files changed: +5, -8 lines


tpu_inference/layers/jax/moe/deepseek_v3_moe.py

Lines changed: 5 additions & 4 deletions
@@ -19,7 +19,7 @@
     manually_quantize_qwix_activation, manually_quantize_qwix_weight)

 modeling_flax_utils = FlaxUtils()
-
+jax.config.update("jax_ragged_dot_use_ragged_dot_instruction", True),

 @dataclass
 class DeepSeekV3Router(nnx.Module):
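
Note on the new module-level line: jax_ragged_dot_use_ragged_dot_instruction is a JAX flag which, going by its name, asks the TPU backend to lower jax.lax.ragged_dot to the hardware ragged-dot instruction rather than an emulated path. Ragged dot is relevant here because the grouped per-expert matmul of a sparse MoE layer is typically expressed with it. A minimal sketch of the primitive, with made-up sizes and not taken from this repository:

import jax
import jax.numpy as jnp

# Hypothetical sizes, for illustration only.
num_tokens, hidden, ffn, num_experts = 8, 16, 32, 4

tokens = jnp.ones((num_tokens, hidden), dtype=jnp.bfloat16)          # rows already sorted by expert
expert_w = jnp.ones((num_experts, hidden, ffn), dtype=jnp.bfloat16)  # one weight matrix per expert
group_sizes = jnp.array([3, 2, 2, 1], dtype=jnp.int32)               # tokens per expert, sums to num_tokens

# Each block of rows in `tokens` is contracted with the weight matrix of its
# expert; the result has shape (num_tokens, ffn).
out = jax.lax.ragged_dot(tokens, expert_w, group_sizes)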
@@ -329,8 +329,9 @@ def _unpermute(self, processed_tokens: jax.Array, sort_indices: jax.Array,
         with jax.named_scope("unpermute"):
             unsorted_tokens_tD = self._sort_activations(
                 processed_tokens, jnp.argsort(sort_indices))
+            D = unsorted_tokens_tD.shape[-1]
             reshaped_tokens_TXD = unsorted_tokens_tD.reshape(
-                -1, self.num_experts_per_tok, self.hidden_size)
+                -1, self.num_experts_per_tok, D)
         with jax.named_scope("combine_weights"):
             output_TD = jnp.einsum(
                 "TXD,TX -> TD",
@@ -394,10 +395,10 @@ def _distributed_sparse_moe_fwd(

         # TODO: update to 'expert' after we enable expert parallelism, currently experts are sharded along model axis
         # or we sould derive it from the model init
-        expert_shard_id = jax.lax.axis_index(self.expert_axis_name)
-        local_expert_size = self.num_local_experts // self.num_expert_parallelism

         if self.num_expert_parallelism > 1:
+            expert_shard_id = jax.lax.axis_index(self.expert_axis_name)
+            local_expert_size = self.num_local_experts // self.num_expert_parallelism
             if self.is_batch_sharded_by_expert:
                 # When token sharded in devices
                 # In this path, we assume the data(tokens) are fully sharded on expert, namely data_axis_name == expert_axis_name
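
Moving expert_shard_id and local_expert_size inside the num_expert_parallelism > 1 branch fits the commit message: jax.lax.axis_index(name) can only be evaluated where the named mesh axis is actually bound (inside a shard_map or pmap over that axis), so computing it unconditionally would break runs without expert parallelism. A minimal, hypothetical illustration of that constraint; the axis name and shapes are invented:

import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

# One-dimensional mesh over whatever devices are available; axis name invented.
mesh = Mesh(jax.devices(), axis_names=("expert",))

def per_shard(x):
    # Legal only because "expert" is a bound axis inside this shard_map.
    shard_id = jax.lax.axis_index("expert")
    return x + shard_id

f = shard_map(per_shard, mesh=mesh, in_specs=P("expert"), out_specs=P("expert"))
print(f(jnp.zeros((len(jax.devices()),), dtype=jnp.int32)))  # prints each shard's index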

tpu_inference/models/jax/deepseek_v3.py

Lines changed: 0 additions & 4 deletions
@@ -363,10 +363,6 @@ def __init__(self, vllm_config: VllmConfig, num_layers, hidden_size,
             "is_verbose", None) is not None
         self.num_routed_experts = num_local_experts
         self.model_dtype = model_dtype
-<<<<<<< HEAD:tpu_inference/models/jax/deepseek_v3.py
-
-=======
->>>>>>> 641cb6d4 (Fix dtype bug of weight_loading. It not only occurs for sparsematmul, but densematmul):tpu_commons/models/jax/deepseek_v3.py
         self._transpose_map = {
             # dense mlp
             r"mlp\.down_proj": (1, 0),
