Skip to content

Commit 022a924

Browse files
wip
1 parent a98a8d6 commit 022a924

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

tpu_inference/layers/vllm/fused_moe.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -360,10 +360,10 @@ def fused_moe_func(
360360
"The kernel requires num_tokens * topk to be a multiple of "
361361
f"16 but got {num_tokens}*{topk}={num_tokens*topk}")
362362
hidden_states = jax.lax.with_sharding_constraint(
363-
hidden_states, NamedSharding(mesh, P(ShardingAxisName.MLP_DATA, None)))
363+
hidden_states, NamedSharding(mesh, P(ShardingAxisName.ATTN_DATA, None)))
364364

365365
gating_output = jax.lax.with_sharding_constraint(
366-
gating_output, NamedSharding(mesh, P(ShardingAxisName.MLP_DATA, None)))
366+
gating_output, NamedSharding(mesh, P(ShardingAxisName.ATTN_DATA, None)))
367367

368368
hidden_states = hidden_states.reshape(num_tokens, hidden_size)
369369
gating_output = gating_output.reshape(num_tokens, global_num_experts)
@@ -389,8 +389,8 @@ def _process_tokens_locally(hidden_states_local, topk_indices_local):
389389
x, group_sizes, topk_argsort_revert_indices = shard_map(
390390
_process_tokens_locally,
391391
mesh=mesh,
392-
in_specs=(P(ShardingAxisName.MLP_DATA, None), P(ShardingAxisName.MLP_DATA, None)),
393-
out_specs=(P(ShardingAxisName.MLP_DATA, None), P(), P(ShardingAxisName.MLP_DATA)),
392+
in_specs=(P(ShardingAxisName.ATTN_DATA, None), P(ShardingAxisName.ATTN_DATA, None)),
393+
out_specs=(P(ShardingAxisName.ATTN_DATA, None), P(), P(ShardingAxisName.ATTN_DATA)),
394394
check_rep=False,
395395
)(hidden_states, topk_indices)
396396

@@ -449,15 +449,15 @@ def _finalize_output(x_local, topk_argsort_revert_indices_local, topk_weights_lo
449449
x = shard_map(
450450
_finalize_output,
451451
mesh=mesh,
452-
in_specs=(P(ShardingAxisName.MLP_DATA, None), P(ShardingAxisName.MLP_DATA), P(ShardingAxisName.MLP_DATA, None)),
453-
out_specs=(P(ShardingAxisName.MLP_DATA, None)),
452+
in_specs=(P(ShardingAxisName.ATTN_DATA, None), P(ShardingAxisName.ATTN_DATA), P(ShardingAxisName.ATTN_DATA, None)),
453+
out_specs=(P(ShardingAxisName.ATTN_DATA, None)),
454454
check_rep=False,
455455
)(x, topk_argsort_revert_indices, topk_weights)
456456

457457
x = x.reshape(orig_shape)
458458

459459
if reduce_results:
460-
x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P(ShardingAxisName.MLP_DATA)))
460+
x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P(ShardingAxisName.ATTN_DATA)))
461461
return x
462462

463463

0 commit comments

Comments
 (0)