
Commit f29cdf9

wip
1 parent: b028301

File tree

8 files changed: +9151 additions, -11 deletions

debug_attn_dp.txt: 233 additions, 0 deletions (large diff not rendered)
debugging.txt: 423 additions, 0 deletions (large diff not rendered)
debugging_baseline.txt: 732 additions, 0 deletions (large diff not rendered)
oss_log_98.txt: 2508 additions, 0 deletions (large diff not rendered)
oss_log_attn_dp.txt: 2564 additions, 0 deletions (large diff not rendered)
oss_log_baseline.txt: 2410 additions, 0 deletions (large diff not rendered)
oss_log_numerics.txt: 270 additions, 0 deletions (large diff not rendered)

tpu_inference/layers/vllm/fused_moe.py: 11 additions, 11 deletions

@@ -175,7 +175,7 @@ def _gmm_all_reduce(lhs, rhs, group_sizes):
         out_specs=(P(ShardingAxisName.MLP_DATA)),
         check_rep=False,
     )(lhs, rhs, group_sizes)
-    jax.debug.print("gmm_result before bias {} {}", gmm_result.sum(), gmm_result.ravel()[:10])
+    # jax.debug.print("gmm_result before bias {} {}", gmm_result.sum(), gmm_result.ravel()[:10])
     if rhs_bias is not None:
         rhs_bias = jnp.repeat(rhs_bias, group_sizes_global, 0, total_repeat_length=m)
         gmm_result = (gmm_result + rhs_bias).astype(gmm_result.dtype)
@@ -365,7 +365,7 @@ def fused_moe_func(
     gating_output = jax.lax.with_sharding_constraint(
         gating_output, NamedSharding(mesh, P(ShardingAxisName.ATTN_DATA, None)))

-    jax.debug.print("hidden_state before MoE {} {}", hidden_states.sum(), hidden_states.ravel()[:10])
+    # jax.debug.print("hidden_state before MoE {} {}", hidden_states.sum(), hidden_states.ravel()[:10])
     hidden_states = hidden_states.reshape(num_tokens, hidden_size)
     gating_output = gating_output.reshape(num_tokens, global_num_experts)

@@ -386,7 +386,7 @@ def _process_tokens_locally(hidden_states_local, topk_indices_local):

         # Reduce group_sizes once across data parallel shards to get global counts
         # This is needed for bias addition and should be done only once for efficiency
-        group_sizes_global = jax.lax.psum(group_sizes_local, axis_name=ShardingAxisName.MLP_DATA)
+        group_sizes_global = jax.lax.psum(group_sizes_local, axis_name=ShardingAxisName.ATTN_DATA)

         x = hidden_states_local[token_indices_sorted]
         return x, group_sizes_local, group_sizes_global, topk_argsort_revert_indices
@@ -399,8 +399,8 @@ def _process_tokens_locally(hidden_states_local, topk_indices_local):
         check_rep=False,
     )(hidden_states, topk_indices)

-    jax.debug.print("hidden_state before gmm {} {}", x.sum(), x.ravel()[:10])
-    jax.debug.print("group_sizes {} {}", group_sizes.sum(), group_sizes)
+    # jax.debug.print("hidden_state before gmm {} {}", x.sum(), x.ravel()[:10])
+    # jax.debug.print("group_sizes {} {}", group_sizes.sum(), group_sizes)
     if use_ep:
         x = expert_sharded_gmm(
             x,
@@ -423,11 +423,11 @@ def _process_tokens_locally(hidden_states_local, topk_indices_local):
             mesh=mesh,
             intermediate_size=intermediate_size,
         )
-    jax.debug.print("hidden_state after first gmm x1 {} {}", x1.sum(), x1.ravel()[:10])
-    jax.debug.print("hidden_state after first gmm x2 {} {}", x2.sum(), x2.ravel()[:10])
+    # jax.debug.print("hidden_state after first gmm x1 {} {}", x1.sum(), x1.ravel()[:10])
+    # jax.debug.print("hidden_state after first gmm x2 {} {}", x2.sum(), x2.ravel()[:10])

     x = activation_fn(activation, x1, x2)
-    jax.debug.print("hidden_state after activation {} {}", x.sum(), x.ravel()[:10])
+    # jax.debug.print("hidden_state after activation {} {}", x.sum(), x.ravel()[:10])
     if use_ep:
         x = expert_sharded_gmm(
             x,
@@ -450,7 +450,7 @@ def _process_tokens_locally(hidden_states_local, topk_indices_local):
             transpose_rhs=True,
             mesh=mesh,
         )
-    jax.debug.print("hidden_state after second gmm {} {}", x.sum(), x.ravel()[:10])
+    # jax.debug.print("hidden_state after second gmm {} {}", x.sum(), x.ravel()[:10])

     def _finalize_output(x_local, topk_argsort_revert_indices_local, topk_weights_local):
         x_local = x_local[topk_argsort_revert_indices_local].reshape(-1, topk, hidden_size)
@@ -465,12 +465,12 @@ def _finalize_output(x_local, topk_argsort_revert_indices_local, topk_weights_lo
         out_specs=(P(ShardingAxisName.ATTN_DATA, None)),
         check_rep=False,
     )(x, topk_argsort_revert_indices, topk_weights)
-    jax.debug.print("hidden_state after finalize output {} {}", x.sum(), x.ravel()[:10])
+    # jax.debug.print("hidden_state after finalize output {} {}", x.sum(), x.ravel()[:10])
     x = x.reshape(orig_shape)

     if reduce_results:
         x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P(ShardingAxisName.ATTN_DATA)))
-    jax.debug.print("hidden_state after reducing result {} {}", x.sum(), x.ravel()[:10])
+    # jax.debug.print("hidden_state after reducing result {} {}", x.sum(), x.ravel()[:10])
     return x
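Note on the one functional change above (original line 389): the per-shard expert token counts are now reduced over ShardingAxisName.ATTN_DATA instead of ShardingAxisName.MLP_DATA, so the psum runs over the axis the tokens are actually sharded on inside this shard_map region. The snippet below is a minimal, self-contained sketch of that pattern, not code from this repo; the mesh axis name "data" and the toy shapes are assumptions for illustration only.

import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

# Hypothetical 1-D mesh; "data" stands in for the data-parallel axis name.
mesh = Mesh(jax.devices(), axis_names=("data",))
num_shards = len(jax.devices())

def _global_counts(group_sizes_local):
    # Each data-parallel shard sees only the per-expert counts for its own tokens;
    # psum over the data axis produces the global per-expert counts that the bias
    # addition needs exactly once.
    return jax.lax.psum(group_sizes_local, axis_name="data")

group_sizes_global = shard_map(
    _global_counts,
    mesh=mesh,
    in_specs=P("data", None),    # counts sharded along the data axis
    out_specs=P(None, None),     # replicated after the psum
)(jnp.ones((num_shards, 4), dtype=jnp.int32))

print(group_sizes_global)  # each expert's total token count across all shards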

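The context lines in the first hunk show how those global counts get used: rhs_bias holds one row per expert, and jnp.repeat expands it to one row per token so it can be added to the grouped-matmul output. A tiny standalone illustration of that expansion (toy sizes assumed, not repo code):

import jax.numpy as jnp

num_experts, hidden = 3, 4
m = 6  # static number of rows in the gmm output

rhs_bias = jnp.arange(num_experts * hidden, dtype=jnp.float32).reshape(num_experts, hidden)
group_sizes_global = jnp.array([1, 3, 2])  # global tokens per expert, summing to m

# total_repeat_length keeps the result shape static, which jit requires even when
# group_sizes_global is a traced value.
per_row_bias = jnp.repeat(rhs_bias, group_sizes_global, 0, total_repeat_length=m)
print(per_row_bias.shape)  # (6, 4): one bias row per token row, grouped by expert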

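The rest of this commit just comments out the jax.debug.print probes added while debugging. If they are likely to be needed again, one alternative (a sketch only, not something this repo does) is to gate them behind an environment flag so the call sites stay in place but stage no host callback by default; DEBUG_MOE and debug_print below are hypothetical names.

import os

import jax
import jax.numpy as jnp

DEBUG_MOE = os.environ.get("DEBUG_MOE", "0") == "1"

def debug_print(fmt, *args):
    # Evaluated at trace time: when the flag is off, no jax.debug.print callback
    # is staged into the compiled program at all.
    if DEBUG_MOE:
        jax.debug.print(fmt, *args)

@jax.jit
def f(hidden_states):
    debug_print("hidden_state before MoE {} {}", hidden_states.sum(), hidden_states.ravel()[:10])
    return hidden_states * 2

f(jnp.ones((4, 8)))  # prints only when DEBUG_MOE=1 is set in the environment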