|
19 | 19 | manually_quantize_qwix_activation, manually_quantize_qwix_weight) |
20 | 20 |
|
21 | 21 | modeling_flax_utils = FlaxUtils() |
22 | | - |
| 22 | +jax.config.update("jax_ragged_dot_use_ragged_dot_instruction", True), |
23 | 23 |
|
24 | 24 | @dataclass |
25 | 25 | class DeepSeekV3Router(nnx.Module): |
@@ -329,8 +329,9 @@ def _unpermute(self, processed_tokens: jax.Array, sort_indices: jax.Array, |
329 | 329 | with jax.named_scope("unpermute"): |
330 | 330 | unsorted_tokens_tD = self._sort_activations( |
331 | 331 | processed_tokens, jnp.argsort(sort_indices)) |
| 332 | + D = unsorted_tokens_tD.shape[-1] |
332 | 333 | reshaped_tokens_TXD = unsorted_tokens_tD.reshape( |
333 | | - -1, self.num_experts_per_tok, self.hidden_size) |
| 334 | + -1, self.num_experts_per_tok, D) |
334 | 335 | with jax.named_scope("combine_weights"): |
335 | 336 | output_TD = jnp.einsum( |
336 | 337 | "TXD,TX -> TD", |
@@ -394,10 +395,10 @@ def _distributed_sparse_moe_fwd( |
394 | 395 |
|
395 | 396 | # TODO: update to 'expert' after we enable expert parallelism, currently experts are sharded along model axis |
396 | 397 | # or we should derive it from the model init |
397 | | - expert_shard_id = jax.lax.axis_index(self.expert_axis_name) |
398 | | - local_expert_size = self.num_local_experts // self.num_expert_parallelism |
399 | 398 |
|
400 | 399 | if self.num_expert_parallelism > 1: |
| 400 | + expert_shard_id = jax.lax.axis_index(self.expert_axis_name) |
| 401 | + local_expert_size = self.num_local_experts // self.num_expert_parallelism |
401 | 402 | if self.is_batch_sharded_by_expert: |
402 | 403 | # When token sharded in devices |
403 | 404 | # In this path, we assume the data(tokens) are fully sharded on expert, namely data_axis_name == expert_axis_name |
|
0 commit comments