@@ -360,10 +360,10 @@ def fused_moe_func(
360360 "The kernel requires num_tokens * topk to be a multiple of "
361361 f"16 but got { num_tokens } *{ topk } ={ num_tokens * topk } " )
362362 hidden_states = jax .lax .with_sharding_constraint (
363- hidden_states , NamedSharding (mesh , P (ShardingAxisName .MLP_DATA , None )))
363+ hidden_states , NamedSharding (mesh , P (ShardingAxisName .ATTN_DATA , None )))
364364
365365 gating_output = jax .lax .with_sharding_constraint (
366- gating_output , NamedSharding (mesh , P (ShardingAxisName .MLP_DATA , None )))
366+ gating_output , NamedSharding (mesh , P (ShardingAxisName .ATTN_DATA , None )))
367367
368368 hidden_states = hidden_states .reshape (num_tokens , hidden_size )
369369 gating_output = gating_output .reshape (num_tokens , global_num_experts )
@@ -389,8 +389,8 @@ def _process_tokens_locally(hidden_states_local, topk_indices_local):
389389 x , group_sizes , topk_argsort_revert_indices = shard_map (
390390 _process_tokens_locally ,
391391 mesh = mesh ,
392- in_specs = (P (ShardingAxisName .MLP_DATA , None ), P (ShardingAxisName .MLP_DATA , None )),
393- out_specs = (P (ShardingAxisName .MLP_DATA , None ), P (), P (ShardingAxisName .MLP_DATA )),
392+ in_specs = (P (ShardingAxisName .ATTN_DATA , None ), P (ShardingAxisName .ATTN_DATA , None )),
393+ out_specs = (P (ShardingAxisName .ATTN_DATA , None ), P (), P (ShardingAxisName .ATTN_DATA )),
394394 check_rep = False ,
395395 )(hidden_states , topk_indices )
396396
@@ -449,15 +449,15 @@ def _finalize_output(x_local, topk_argsort_revert_indices_local, topk_weights_lo
449449 x = shard_map (
450450 _finalize_output ,
451451 mesh = mesh ,
452- in_specs = (P (ShardingAxisName .MLP_DATA , None ), P (ShardingAxisName .MLP_DATA ), P (ShardingAxisName .MLP_DATA , None )),
453- out_specs = (P (ShardingAxisName .MLP_DATA , None )),
452+ in_specs = (P (ShardingAxisName .ATTN_DATA , None ), P (ShardingAxisName .ATTN_DATA ), P (ShardingAxisName .ATTN_DATA , None )),
453+ out_specs = (P (ShardingAxisName .ATTN_DATA , None )),
454454 check_rep = False ,
455455 )(x , topk_argsort_revert_indices , topk_weights )
456456
457457 x = x .reshape (orig_shape )
458458
459459 if reduce_results :
460- x = jax .lax .with_sharding_constraint (x , NamedSharding (mesh , P (ShardingAxisName .MLP_DATA )))
460+ x = jax .lax .with_sharding_constraint (x , NamedSharding (mesh , P (ShardingAxisName .ATTN_DATA )))
461461 return x
462462
463463
0 commit comments