@@ -175,7 +175,7 @@ def _gmm_all_reduce(lhs, rhs, group_sizes):
         out_specs=(P(ShardingAxisName.MLP_DATA)),
         check_rep=False,
     )(lhs, rhs, group_sizes)
-    jax.debug.print("gmm_result before bias {} {}", gmm_result.sum(), gmm_result.ravel()[:10])
+    # jax.debug.print("gmm_result before bias {} {}", gmm_result.sum(), gmm_result.ravel()[:10])
     if rhs_bias is not None:
         rhs_bias = jnp.repeat(rhs_bias, group_sizes_global, 0, total_repeat_length=m)
         gmm_result = (gmm_result + rhs_bias).astype(gmm_result.dtype)
@@ -365,7 +365,7 @@ def fused_moe_func(
     gating_output = jax.lax.with_sharding_constraint(
         gating_output, NamedSharding(mesh, P(ShardingAxisName.ATTN_DATA, None)))

-    jax.debug.print("hidden_state before MoE {} {}", hidden_states.sum(), hidden_states.ravel()[:10])
+    # jax.debug.print("hidden_state before MoE {} {}", hidden_states.sum(), hidden_states.ravel()[:10])
     hidden_states = hidden_states.reshape(num_tokens, hidden_size)
     gating_output = gating_output.reshape(num_tokens, global_num_experts)

@@ -386,7 +386,7 @@ def _process_tokens_locally(hidden_states_local, topk_indices_local):

     # Reduce group_sizes once across data parallel shards to get global counts
     # This is needed for bias addition and should be done only once for efficiency
-    group_sizes_global = jax.lax.psum(group_sizes_local, axis_name=ShardingAxisName.MLP_DATA)
+    group_sizes_global = jax.lax.psum(group_sizes_local, axis_name=ShardingAxisName.ATTN_DATA)

     x = hidden_states_local[token_indices_sorted]
     return x, group_sizes_local, group_sizes_global, topk_argsort_revert_indices
@@ -399,8 +399,8 @@ def _process_tokens_locally(hidden_states_local, topk_indices_local):
         check_rep=False,
     )(hidden_states, topk_indices)

-    jax.debug.print("hidden_state before gmm {} {}", x.sum(), x.ravel()[:10])
-    jax.debug.print("group_sizes {} {}", group_sizes.sum(), group_sizes)
+    # jax.debug.print("hidden_state before gmm {} {}", x.sum(), x.ravel()[:10])
+    # jax.debug.print("group_sizes {} {}", group_sizes.sum(), group_sizes)
     if use_ep:
         x = expert_sharded_gmm(
             x,
@@ -423,11 +423,11 @@ def _process_tokens_locally(hidden_states_local, topk_indices_local):
             mesh=mesh,
             intermediate_size=intermediate_size,
         )
-    jax.debug.print("hidden_state after first gmm x1 {} {}", x1.sum(), x1.ravel()[:10])
-    jax.debug.print("hidden_state after first gmm x2 {} {}", x2.sum(), x2.ravel()[:10])
+    # jax.debug.print("hidden_state after first gmm x1 {} {}", x1.sum(), x1.ravel()[:10])
+    # jax.debug.print("hidden_state after first gmm x2 {} {}", x2.sum(), x2.ravel()[:10])

     x = activation_fn(activation, x1, x2)
-    jax.debug.print("hidden_state after activation {} {}", x.sum(), x.ravel()[:10])
+    # jax.debug.print("hidden_state after activation {} {}", x.sum(), x.ravel()[:10])
     if use_ep:
         x = expert_sharded_gmm(
             x,
@@ -450,7 +450,7 @@ def _process_tokens_locally(hidden_states_local, topk_indices_local):
             transpose_rhs=True,
             mesh=mesh,
         )
-    jax.debug.print("hidden_state after second gmm {} {}", x.sum(), x.ravel()[:10])
+    # jax.debug.print("hidden_state after second gmm {} {}", x.sum(), x.ravel()[:10])

     def _finalize_output(x_local, topk_argsort_revert_indices_local, topk_weights_local):
         x_local = x_local[topk_argsort_revert_indices_local].reshape(-1, topk, hidden_size)
@@ -465,12 +465,12 @@ def _finalize_output(x_local, topk_argsort_revert_indices_local, topk_weights_local):
         out_specs=(P(ShardingAxisName.ATTN_DATA, None)),
         check_rep=False,
     )(x, topk_argsort_revert_indices, topk_weights)
-    jax.debug.print("hidden_state after finalize output {} {}", x.sum(), x.ravel()[:10])
+    # jax.debug.print("hidden_state after finalize output {} {}", x.sum(), x.ravel()[:10])
     x = x.reshape(orig_shape)

     if reduce_results:
         x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P(ShardingAxisName.ATTN_DATA)))
-        jax.debug.print("hidden_state after reducing result {} {}", x.sum(), x.ravel()[:10])
+        # jax.debug.print("hidden_state after reducing result {} {}", x.sum(), x.ravel()[:10])
     return x


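The one behavioral change in this diff is the psum axis name in _process_tokens_locally: the per-shard group_sizes are now reduced over ShardingAxisName.ATTN_DATA, the axis the tokens are actually sharded on, instead of MLP_DATA. Below is a minimal, self-contained sketch of that pattern; the mesh, the "data" axis name, and the helper are illustrative stand-ins, not the repository's code.

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# Stand-in mesh and axis name; "data" plays the role of ShardingAxisName.ATTN_DATA.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
num_experts = 8
num_tokens = 8 * jax.device_count()
topk_indices = jnp.arange(num_tokens) % num_experts  # toy token -> expert ids

def _global_group_sizes(topk_indices_local):
    # Per-shard histogram: how many of this shard's tokens go to each expert.
    group_sizes_local = jnp.bincount(topk_indices_local, length=num_experts)
    # Reduce over the axis the tokens are sharded on to get global per-expert counts.
    return jax.lax.psum(group_sizes_local, axis_name="data")

group_sizes_global = shard_map(
    _global_group_sizes,
    mesh=mesh,
    in_specs=P("data"),
    out_specs=P(),  # replicated after the psum
    check_rep=False,
)(topk_indices)
print(group_sizes_global)  # each expert receives num_tokens // num_experts tokens

Summing over the wrong axis name would either fail (if that axis is not bound inside the shard_map) or produce counts that do not cover all data-parallel shards, which would then mis-size the jnp.repeat used for the bias in _gmm_all_reduce.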