Commit 46f5523

wip
1 parent f0abc8b commit 46f5523

File tree

1 file changed: +26 -13 lines changed

tpu_inference/layers/vllm/fused_moe.py

Lines changed: 26 additions & 13 deletions
@@ -103,6 +103,7 @@ def tensor_sharded_gmm_merged_column_parallel(
     rhs: jax.Array,
     rhs_bias: jax.Array | None,
     group_sizes: jax.Array,
+    group_sizes_global: jax.Array,
     transpose_rhs: bool,
     mesh: Mesh,
     intermediate_size: int,
@@ -128,9 +129,9 @@ def tensor_sharded_gmm_merged_column_parallel(
         check_rep=False,
     )(lhs, rhs, group_sizes)

+
     if rhs_bias is not None:
-        rhs_bis = jnp.repeat(rhs_bias, group_sizes, 0, total_repeat_length=m)
-        # Maybe need to add sharding constraint here
+        rhs_bis = jnp.repeat(rhs_bias, group_sizes_global, 0, total_repeat_length=m)
         gmm_result = (gmm_result + rhs_bis).astype(gmm_result.dtype)

     n_shards = mesh.shape['model'] * mesh.shape.get('attn_dp', 1)
@@ -145,6 +146,7 @@ def tensor_sharded_gmm_row_parallel(
     rhs: jax.Array,
     rhs_bias: jax.Array | None,
     group_sizes: jax.Array,
+    group_sizes_global: jax.Array,
     transpose_rhs: bool,
     mesh: Mesh,
 ) -> jax.Array:
@@ -173,11 +175,9 @@ def _gmm_all_reduce(lhs, rhs, group_sizes):
         out_specs=(P(ShardingAxisName.MLP_DATA)),
         check_rep=False,
     )(lhs, rhs, group_sizes)
-
+    jax.debug.print("gmm_result before bias {} {}", gmm_result.sum(), gmm_result.ravel()[:10])
     if rhs_bias is not None:
-
-        rhs_bias = jnp.repeat(rhs_bias, group_sizes, 0, total_repeat_length=m)
-        # wenxindong: Maybe need to add sharding constraint here
+        rhs_bias = jnp.repeat(rhs_bias, group_sizes_global, 0, total_repeat_length=m)
         gmm_result = (gmm_result + rhs_bias).astype(gmm_result.dtype)

     return gmm_result
@@ -365,6 +365,7 @@ def fused_moe_func(
     gating_output = jax.lax.with_sharding_constraint(
         gating_output, NamedSharding(mesh, P(ShardingAxisName.ATTN_DATA, None)))

+    jax.debug.print("hidden_state before MoE {} {}", hidden_states.sum(), hidden_states.ravel()[:10])
     hidden_states = hidden_states.reshape(num_tokens, hidden_size)
     gating_output = gating_output.reshape(num_tokens, global_num_experts)

@@ -381,19 +382,25 @@ def _process_tokens_locally(hidden_states_local, topk_indices_local):
         topk_argsort_revert_indices = jnp.argsort(topk_argsort_indices)
         token_indices = jnp.arange(num_tokens_local, dtype=jnp.int32).repeat(topk)
         token_indices_sorted = token_indices[topk_argsort_indices]
-        group_sizes = jnp.bincount(topk_indices_flat, length=global_num_experts)
+        group_sizes_local = jnp.bincount(topk_indices_flat, length=global_num_experts)
+
+        # Reduce group_sizes once across data parallel shards to get global counts
+        # This is needed for bias addition and should be done only once for efficiency
+        group_sizes_global = jax.lax.psum(group_sizes_local, axis_name=ShardingAxisName.MLP_DATA)

         x = hidden_states_local[token_indices_sorted]
-        return x, group_sizes, topk_argsort_revert_indices
+        return x, group_sizes_local, group_sizes_global, topk_argsort_revert_indices

-    x, group_sizes, topk_argsort_revert_indices = shard_map(
+    x, group_sizes, group_sizes_global, topk_argsort_revert_indices = shard_map(
         _process_tokens_locally,
         mesh=mesh,
         in_specs=(P(ShardingAxisName.ATTN_DATA, None), P(ShardingAxisName.ATTN_DATA, None)),
-        out_specs=(P(ShardingAxisName.ATTN_DATA, None), P(), P(ShardingAxisName.ATTN_DATA)),
+        out_specs=(P(ShardingAxisName.ATTN_DATA, None), P(ShardingAxisName.ATTN_DATA), P(), P(ShardingAxisName.ATTN_DATA)),
         check_rep=False,
     )(hidden_states, topk_indices)
-
+
+    jax.debug.print("hidden_state before gmm {} {}", x.sum(), x.ravel()[:10])
+    jax.debug.print("group_sizes {} {}", group_sizes.sum(), group_sizes)
     if use_ep:
         x = expert_sharded_gmm(
             x,
@@ -411,13 +418,16 @@ def _process_tokens_locally(hidden_states_local, topk_indices_local):
             w1,
             w1_bias,
             group_sizes,
+            group_sizes_global,
             transpose_rhs=True,
             mesh=mesh,
             intermediate_size=intermediate_size,
         )
+    jax.debug.print("hidden_state after first gmm x1 {} {}", x1.sum(), x1.ravel()[:10])
+    jax.debug.print("hidden_state after first gmm x2 {} {}", x2.sum(), x2.ravel()[:10])

     x = activation_fn(activation, x1, x2)
-
+    jax.debug.print("hidden_state after activation {} {}", x.sum(), x.ravel()[:10])
     if use_ep:
         x = expert_sharded_gmm(
             x,
@@ -436,9 +446,11 @@ def _process_tokens_locally(hidden_states_local, topk_indices_local):
             w2,
             w2_bias,
             group_sizes,
+            group_sizes_global,
             transpose_rhs=True,
             mesh=mesh,
         )
+    jax.debug.print("hidden_state after second gmm {} {}", x.sum(), x.ravel()[:10])

     def _finalize_output(x_local, topk_argsort_revert_indices_local, topk_weights_local):
         x_local = x_local[topk_argsort_revert_indices_local].reshape(-1, topk, hidden_size)
@@ -453,11 +465,12 @@ def _finalize_output(x_local, topk_argsort_revert_indices_local, topk_weights_local):
         out_specs=(P(ShardingAxisName.ATTN_DATA, None)),
         check_rep=False,
     )(x, topk_argsort_revert_indices, topk_weights)
-
+    jax.debug.print("hidden_state after finalize output {} {}", x.sum(), x.ravel()[:10])
     x = x.reshape(orig_shape)

     if reduce_results:
         x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P(ShardingAxisName.ATTN_DATA)))
+        jax.debug.print("hidden_state after reducing result {} {}", x.sum(), x.ravel()[:10])
     return x

