Commit e37a89b

wip
1 parent f29cdf9 commit e37a89b

File tree: 1 file changed, +26 -24 lines

tpu_inference/layers/vllm/fused_moe.py

Lines changed: 26 additions & 24 deletions
@@ -103,7 +103,6 @@ def tensor_sharded_gmm_merged_column_parallel(
     rhs: jax.Array,
     rhs_bias: jax.Array | None,
     group_sizes: jax.Array,
-    group_sizes_global: jax.Array,
     transpose_rhs: bool,
     mesh: Mesh,
     intermediate_size: int,
@@ -129,10 +128,18 @@ def tensor_sharded_gmm_merged_column_parallel(
         check_rep=False,
     )(lhs, rhs, group_sizes)
 
-
     if rhs_bias is not None:
-        rhs_bis = jnp.repeat(rhs_bias, group_sizes_global, 0, total_repeat_length=m)
-        gmm_result = (gmm_result + rhs_bis).astype(gmm_result.dtype)
+        def _add_bias(gmm_result_local, rhs_bias_local, group_sizes_global):
+            rhs_bis = jnp.repeat(rhs_bias_local, group_sizes_global, 0, total_repeat_length=m//mesh.shape["data"])
+            return (gmm_result_local + rhs_bis).astype(gmm_result_local.dtype)
+
+        gmm_result = shard_map(
+            _add_bias,
+            mesh=mesh,
+            in_specs=(P(ShardingAxisName.MLP_DATA, ShardingAxisName.MLP_TENSOR), P(None, ShardingAxisName.MLP_TENSOR), P(ShardingAxisName.MLP_DATA)),
+            out_specs=(P(ShardingAxisName.MLP_DATA, ShardingAxisName.MLP_TENSOR)),
+            check_rep=False,
+        )(gmm_result, rhs_bias, group_sizes)
 
     n_shards = mesh.shape['model'] * mesh.shape.get('attn_dp', 1)
     output_sizes = [intermediate_size, intermediate_size]
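
For reference, a minimal single-device sketch of the bias expansion that the new _add_bias closure performs on each shard: jnp.repeat stretches one bias row per expert over that expert's block of routed tokens. The expert count, hidden size, and token count below are made-up illustration values, not taken from this file.

import jax.numpy as jnp

num_experts, hidden, m_local = 4, 3, 8
# One bias row per expert.
rhs_bias = jnp.arange(num_experts * hidden, dtype=jnp.float32).reshape(num_experts, hidden)
# How many routed tokens each expert received on this shard (sums to m_local).
group_sizes = jnp.array([3, 0, 4, 1], dtype=jnp.int32)
# Stand-in for the grouped-GEMM output: one row per routed token.
gmm_result = jnp.zeros((m_local, hidden), dtype=jnp.bfloat16)

# Repeat each expert's bias row over that expert's contiguous block of tokens.
# total_repeat_length keeps the output shape static (m_local rows) under jit.
expanded = jnp.repeat(rhs_bias, group_sizes, 0, total_repeat_length=m_local)
biased = (gmm_result + expanded).astype(gmm_result.dtype)
print(biased.shape)  # (8, 3)
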
@@ -146,7 +153,6 @@ def tensor_sharded_gmm_row_parallel(
     rhs: jax.Array,
     rhs_bias: jax.Array | None,
     group_sizes: jax.Array,
-    group_sizes_global: jax.Array,
     transpose_rhs: bool,
     mesh: Mesh,
 ) -> jax.Array:
@@ -177,8 +183,17 @@ def _gmm_all_reduce(lhs, rhs, group_sizes):
     )(lhs, rhs, group_sizes)
     # jax.debug.print("gmm_result before bias {} {}", gmm_result.sum(), gmm_result.ravel()[:10])
     if rhs_bias is not None:
-        rhs_bias = jnp.repeat(rhs_bias, group_sizes_global, 0, total_repeat_length=m)
-        gmm_result = (gmm_result + rhs_bias).astype(gmm_result.dtype)
+        def _add_bias(gmm_result_local, rhs_bias_local, group_sizes_global):
+            rhs_bis = jnp.repeat(rhs_bias_local, group_sizes_global, 0, total_repeat_length=m//mesh.shape["data"])
+            return (gmm_result_local + rhs_bis).astype(gmm_result_local.dtype)
+
+        gmm_result = shard_map(
+            _add_bias,
+            mesh=mesh,
+            in_specs=(P(ShardingAxisName.MLP_DATA), P(), P(ShardingAxisName.MLP_DATA)),
+            out_specs=(P(ShardingAxisName.MLP_DATA)),
+            check_rep=False,
+        )(gmm_result, rhs_bias, group_sizes)
 
     return gmm_result
 
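For reference, the same shard_map wrapping pattern reduced to a one-device mesh with a single made-up "data" axis; the axis names, specs, and shapes here are illustrative stand-ins for the ShardingAxisName mesh axes used above, not the real mesh in this repo.

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# Single-device mesh, just to exercise the pattern.
mesh = Mesh(np.asarray(jax.devices()[:1]), ("data",))
num_experts, hidden, m = 4, 3, 8

def _add_bias(gmm_result_local, rhs_bias_local, group_sizes_local):
    # Each shard expands the replicated per-expert bias with its own local counts,
    # so no globally reduced group sizes are needed.
    rows = jnp.repeat(rhs_bias_local, group_sizes_local, 0,
                      total_repeat_length=m // mesh.shape["data"])
    return (gmm_result_local + rows).astype(gmm_result_local.dtype)

add_bias = shard_map(
    _add_bias,
    mesh=mesh,
    in_specs=(P("data"), P(), P("data")),  # result and counts sharded, bias replicated
    out_specs=P("data"),
    check_rep=False,
)

gmm_result = jnp.zeros((m, hidden), dtype=jnp.bfloat16)
rhs_bias = jnp.ones((num_experts, hidden), dtype=jnp.float32)
group_sizes = jnp.array([3, 0, 4, 1], dtype=jnp.int32)
print(add_bias(gmm_result, rhs_bias, group_sizes).shape)  # (8, 3)
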
@@ -359,13 +374,6 @@ def fused_moe_func(
     assert (num_tokens * topk) % 16 == 0, (
         "The kernel requires num_tokens * topk to be a multiple of "
         f"16 but got {num_tokens}*{topk}={num_tokens*topk}")
-    hidden_states = jax.lax.with_sharding_constraint(
-        hidden_states, NamedSharding(mesh, P(ShardingAxisName.ATTN_DATA, None)))
-
-    gating_output = jax.lax.with_sharding_constraint(
-        gating_output, NamedSharding(mesh, P(ShardingAxisName.ATTN_DATA, None)))
-
-    # jax.debug.print("hidden_state before MoE {} {}", hidden_states.sum(), hidden_states.ravel()[:10])
     hidden_states = hidden_states.reshape(num_tokens, hidden_size)
     gating_output = gating_output.reshape(num_tokens, global_num_experts)
 
@@ -383,19 +391,15 @@ def _process_tokens_locally(hidden_states_local, topk_indices_local):
         token_indices = jnp.arange(num_tokens_local, dtype=jnp.int32).repeat(topk)
         token_indices_sorted = token_indices[topk_argsort_indices]
         group_sizes_local = jnp.bincount(topk_indices_flat, length=global_num_experts)
-
-        # Reduce group_sizes once across data parallel shards to get global counts
-        # This is needed for bias addition and should be done only once for efficiency
-        group_sizes_global = jax.lax.psum(group_sizes_local, axis_name=ShardingAxisName.ATTN_DATA)
-
+
         x = hidden_states_local[token_indices_sorted]
-        return x, group_sizes_local, group_sizes_global, topk_argsort_revert_indices
+        return x, group_sizes_local, topk_argsort_revert_indices
 
-    x, group_sizes, group_sizes_global, topk_argsort_revert_indices = shard_map(
+    x, group_sizes, topk_argsort_revert_indices = shard_map(
         _process_tokens_locally,
         mesh=mesh,
         in_specs=(P(ShardingAxisName.ATTN_DATA, None), P(ShardingAxisName.ATTN_DATA, None)),
-        out_specs=(P(ShardingAxisName.ATTN_DATA, None), P(ShardingAxisName.ATTN_DATA), P(), P(ShardingAxisName.ATTN_DATA)),
+        out_specs=(P(ShardingAxisName.ATTN_DATA, None), P(ShardingAxisName.ATTN_DATA), P(ShardingAxisName.ATTN_DATA)),
         check_rep=False,
     )(hidden_states, topk_indices)
 
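For reference, a small made-up example of what _process_tokens_locally computes on each data-parallel shard: sort the flattened top-k expert ids so tokens headed to the same expert are contiguous, and bincount them into per-expert group sizes. Since the bias addition now runs inside shard_map with these local counts, the psum to group_sizes_global is no longer needed. The routing table and sizes below are illustrative only.

import jax.numpy as jnp

num_tokens_local, topk, global_num_experts = 4, 2, 4
# Made-up routing: top-2 expert ids per local token.
topk_indices_local = jnp.array([[0, 2], [1, 2], [2, 3], [0, 0]], dtype=jnp.int32)

topk_indices_flat = topk_indices_local.ravel()
topk_argsort_indices = jnp.argsort(topk_indices_flat)
topk_argsort_revert_indices = jnp.argsort(topk_argsort_indices)

# Which token each (token, expert) pair came from, reordered so experts are contiguous.
token_indices = jnp.arange(num_tokens_local, dtype=jnp.int32).repeat(topk)
token_indices_sorted = token_indices[topk_argsort_indices]

# Per-expert token counts for the grouped GEMM on this shard.
group_sizes_local = jnp.bincount(topk_indices_flat, length=global_num_experts)
print(token_indices_sorted)  # tokens grouped by expert id
print(group_sizes_local)     # [3 1 3 1]
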
@@ -418,7 +422,6 @@ def _process_tokens_locally(hidden_states_local, topk_indices_local):
         w1,
         w1_bias,
         group_sizes,
-        group_sizes_global,
         transpose_rhs=True,
         mesh=mesh,
         intermediate_size=intermediate_size,
@@ -446,7 +449,6 @@ def _process_tokens_locally(hidden_states_local, topk_indices_local):
         w2,
         w2_bias,
         group_sizes,
-        group_sizes_global,
         transpose_rhs=True,
         mesh=mesh,
     )
