@@ -103,6 +103,7 @@ def tensor_sharded_gmm_merged_column_parallel(
     rhs: jax.Array,
     rhs_bias: jax.Array | None,
     group_sizes: jax.Array,
+    group_sizes_global: jax.Array,
     transpose_rhs: bool,
     mesh: Mesh,
     intermediate_size: int,
@@ -128,9 +129,9 @@ def tensor_sharded_gmm_merged_column_parallel(
         check_rep=False,
     )(lhs, rhs, group_sizes)
 
+
     if rhs_bias is not None:
-        rhs_bis = jnp.repeat(rhs_bias, group_sizes, 0, total_repeat_length=m)
-        # Maybe need to add sharding constraint here
+        rhs_bis = jnp.repeat(rhs_bias, group_sizes_global, 0, total_repeat_length=m)
         gmm_result = (gmm_result + rhs_bis).astype(gmm_result.dtype)
 
     n_shards = mesh.shape['model'] * mesh.shape.get('attn_dp', 1)
@@ -145,6 +146,7 @@ def tensor_sharded_gmm_row_parallel(
     rhs: jax.Array,
     rhs_bias: jax.Array | None,
     group_sizes: jax.Array,
+    group_sizes_global: jax.Array,
     transpose_rhs: bool,
     mesh: Mesh,
 ) -> jax.Array:
@@ -173,11 +175,9 @@ def _gmm_all_reduce(lhs, rhs, group_sizes):
         out_specs=(P(ShardingAxisName.MLP_DATA)),
         check_rep=False,
     )(lhs, rhs, group_sizes)
-
+    jax.debug.print("gmm_result before bias {} {}", gmm_result.sum(), gmm_result.ravel()[:10])
     if rhs_bias is not None:
-
-        rhs_bias = jnp.repeat(rhs_bias, group_sizes, 0, total_repeat_length=m)
-        # wenxindong: Maybe need to add sharding constraint here
+        rhs_bias = jnp.repeat(rhs_bias, group_sizes_global, 0, total_repeat_length=m)
         gmm_result = (gmm_result + rhs_bias).astype(gmm_result.dtype)
 
     return gmm_result
@@ -365,6 +365,7 @@ def fused_moe_func(
     gating_output = jax.lax.with_sharding_constraint(
         gating_output, NamedSharding(mesh, P(ShardingAxisName.ATTN_DATA, None)))
 
+    jax.debug.print("hidden_state before MoE {} {}", hidden_states.sum(), hidden_states.ravel()[:10])
     hidden_states = hidden_states.reshape(num_tokens, hidden_size)
     gating_output = gating_output.reshape(num_tokens, global_num_experts)
 
@@ -381,19 +382,25 @@ def _process_tokens_locally(hidden_states_local, topk_indices_local):
         topk_argsort_revert_indices = jnp.argsort(topk_argsort_indices)
         token_indices = jnp.arange(num_tokens_local, dtype=jnp.int32).repeat(topk)
         token_indices_sorted = token_indices[topk_argsort_indices]
-        group_sizes = jnp.bincount(topk_indices_flat, length=global_num_experts)
+        group_sizes_local = jnp.bincount(topk_indices_flat, length=global_num_experts)
+
+        # Reduce group_sizes once across data-parallel shards to get global counts.
+        # This is needed for bias addition and should be done only once for efficiency.
+        group_sizes_global = jax.lax.psum(group_sizes_local, axis_name=ShardingAxisName.MLP_DATA)
 
         x = hidden_states_local[token_indices_sorted]
-        return x, group_sizes, topk_argsort_revert_indices
+        return x, group_sizes_local, group_sizes_global, topk_argsort_revert_indices
 
-    x, group_sizes, topk_argsort_revert_indices = shard_map(
+    x, group_sizes, group_sizes_global, topk_argsort_revert_indices = shard_map(
         _process_tokens_locally,
         mesh=mesh,
         in_specs=(P(ShardingAxisName.ATTN_DATA, None), P(ShardingAxisName.ATTN_DATA, None)),
-        out_specs=(P(ShardingAxisName.ATTN_DATA, None), P(), P(ShardingAxisName.ATTN_DATA)),
+        out_specs=(P(ShardingAxisName.ATTN_DATA, None), P(ShardingAxisName.ATTN_DATA), P(), P(ShardingAxisName.ATTN_DATA)),
         check_rep=False,
     )(hidden_states, topk_indices)
-
+
+    jax.debug.print("hidden_state before gmm {} {}", x.sum(), x.ravel()[:10])
+    jax.debug.print("group_sizes {} {}", group_sizes.sum(), group_sizes)
     if use_ep:
         x = expert_sharded_gmm(
             x,
@@ -411,13 +418,16 @@ def _process_tokens_locally(hidden_states_local, topk_indices_local):
             w1,
             w1_bias,
             group_sizes,
+            group_sizes_global,
             transpose_rhs=True,
             mesh=mesh,
             intermediate_size=intermediate_size,
         )
+    jax.debug.print("hidden_state after first gmm x1 {} {}", x1.sum(), x1.ravel()[:10])
+    jax.debug.print("hidden_state after first gmm x2 {} {}", x2.sum(), x2.ravel()[:10])
 
     x = activation_fn(activation, x1, x2)
-
+    jax.debug.print("hidden_state after activation {} {}", x.sum(), x.ravel()[:10])
     if use_ep:
         x = expert_sharded_gmm(
             x,
@@ -436,9 +446,11 @@ def _process_tokens_locally(hidden_states_local, topk_indices_local):
             w2,
             w2_bias,
             group_sizes,
+            group_sizes_global,
             transpose_rhs=True,
             mesh=mesh,
         )
+    jax.debug.print("hidden_state after second gmm {} {}", x.sum(), x.ravel()[:10])
 
     def _finalize_output(x_local, topk_argsort_revert_indices_local, topk_weights_local):
         x_local = x_local[topk_argsort_revert_indices_local].reshape(-1, topk, hidden_size)
@@ -453,11 +465,12 @@ def _finalize_output(x_local, topk_argsort_revert_indices_local, topk_weights_local):
         out_specs=(P(ShardingAxisName.ATTN_DATA, None)),
         check_rep=False,
     )(x, topk_argsort_revert_indices, topk_weights)
-
+    jax.debug.print("hidden_state after finalize output {} {}", x.sum(), x.ravel()[:10])
     x = x.reshape(orig_shape)
 
     if reduce_results:
         x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P(ShardingAxisName.ATTN_DATA)))
+        jax.debug.print("hidden_state after reducing result {} {}", x.sum(), x.ravel()[:10])
     return x
 
 
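For reference, here is a minimal standalone sketch (not the committed code) of the pattern this change relies on: each data-parallel shard counts its local per-expert token assignments with jnp.bincount, jax.lax.psum inside shard_map reduces those counts to global per-expert totals, and the global totals drive the per-row bias broadcast via jnp.repeat(..., total_repeat_length=m). The one-device mesh, the "data" axis name, and the toy sizes below are illustrative assumptions, not values from this repository.

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

num_experts = 4
# Toy 1-device mesh; in the real model the "data" axis spans the data-parallel replicas.
mesh = Mesh(np.array(jax.devices()[:1]), axis_names=("data",))

def _per_shard_counts(topk_indices_local):
    # Histogram of expert ids assigned on this shard ...
    local = jnp.bincount(topk_indices_local.ravel(), length=num_experts)
    # ... and the cross-shard sum giving the global per-expert counts.
    return local, jax.lax.psum(local, axis_name="data")

topk_indices = jnp.array([[0, 2], [1, 2], [3, 3]], dtype=jnp.int32)  # 3 tokens, top-2
group_sizes_local, group_sizes_global = shard_map(
    _per_shard_counts,
    mesh=mesh,
    in_specs=(P("data", None),),
    out_specs=(P("data"), P()),
    check_rep=False,
)(topk_indices)

# The grouped matmul produces m = num_tokens * topk rows sorted by expert, so each
# expert's bias row is repeated by its *global* count to line up with those rows.
m = topk_indices.size
expert_bias = jnp.arange(num_experts, dtype=jnp.float32)[:, None]   # (num_experts, 1)
bias_rows = jnp.repeat(expert_bias, group_sizes_global, axis=0,
                       total_repeat_length=m)                        # (m, 1)
print(group_sizes_global)  # [1 1 2 2]
print(bias_rows.ravel())   # [0. 1. 2. 2. 3. 3.]

Using the sharded local counts here instead of the globally reduced ones would misalign the repeated bias rows with the rows produced by the grouped matmul, which is why the diff threads group_sizes_global into both tensor-sharded gmm helpers.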