@@ -103,7 +103,6 @@ def tensor_sharded_gmm_merged_column_parallel(
     rhs: jax.Array,
     rhs_bias: jax.Array | None,
     group_sizes: jax.Array,
-    group_sizes_global: jax.Array,
     transpose_rhs: bool,
     mesh: Mesh,
     intermediate_size: int,
@@ -129,10 +128,18 @@ def tensor_sharded_gmm_merged_column_parallel(
         check_rep=False,
     )(lhs, rhs, group_sizes)
 
-
     if rhs_bias is not None:
-        rhs_bis = jnp.repeat(rhs_bias, group_sizes_global, 0, total_repeat_length=m)
-        gmm_result = (gmm_result + rhs_bis).astype(gmm_result.dtype)
+        def _add_bias(gmm_result_local, rhs_bias_local, group_sizes_global):
+            rhs_bis = jnp.repeat(rhs_bias_local, group_sizes_global, 0, total_repeat_length=m // mesh.shape["data"])
+            return (gmm_result_local + rhs_bis).astype(gmm_result_local.dtype)
+
+        gmm_result = shard_map(
+            _add_bias,
+            mesh=mesh,
+            in_specs=(P(ShardingAxisName.MLP_DATA, ShardingAxisName.MLP_TENSOR), P(None, ShardingAxisName.MLP_TENSOR), P(ShardingAxisName.MLP_DATA)),
+            out_specs=(P(ShardingAxisName.MLP_DATA, ShardingAxisName.MLP_TENSOR)),
+            check_rep=False,
+        )(gmm_result, rhs_bias, group_sizes)
 
     n_shards = mesh.shape['model'] * mesh.shape.get('attn_dp', 1)
     output_sizes = [intermediate_size, intermediate_size]
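The core of the bias change above is expanding the per-expert bias to one row per routed token with `jnp.repeat(..., total_repeat_length=...)`, now applied to each data shard's local block. A minimal, self-contained sketch of that repeat pattern with toy shapes (names and values here are illustrative, not from the source):

```python
import jax.numpy as jnp

# Toy shapes: 3 experts, hidden width 4, and a local shard of m_local = 8
# routed tokens already sorted by expert.
num_experts, hidden, m_local = 3, 4, 8
rhs_bias = jnp.arange(num_experts * hidden, dtype=jnp.float32).reshape(num_experts, hidden)
group_sizes = jnp.array([3, 1, 4], dtype=jnp.int32)   # rows per expert, sums to m_local
gmm_result = jnp.zeros((m_local, hidden), dtype=jnp.float32)

# Expand the (num_experts, hidden) bias to one row per token. total_repeat_length
# gives jnp.repeat a static output length, so the shape stays known under jit.
bias_rows = jnp.repeat(rhs_bias, group_sizes, axis=0, total_repeat_length=m_local)
out = (gmm_result + bias_rows).astype(gmm_result.dtype)
print(out.shape)  # (8, 4)
```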
@@ -146,7 +153,6 @@ def tensor_sharded_gmm_row_parallel(
     rhs: jax.Array,
     rhs_bias: jax.Array | None,
     group_sizes: jax.Array,
-    group_sizes_global: jax.Array,
     transpose_rhs: bool,
     mesh: Mesh,
 ) -> jax.Array:
@@ -177,8 +183,17 @@ def _gmm_all_reduce(lhs, rhs, group_sizes):
     )(lhs, rhs, group_sizes)
     # jax.debug.print("gmm_result before bias {} {}", gmm_result.sum(), gmm_result.ravel()[:10])
     if rhs_bias is not None:
-        rhs_bias = jnp.repeat(rhs_bias, group_sizes_global, 0, total_repeat_length=m)
-        gmm_result = (gmm_result + rhs_bias).astype(gmm_result.dtype)
+        def _add_bias(gmm_result_local, rhs_bias_local, group_sizes_global):
+            rhs_bis = jnp.repeat(rhs_bias_local, group_sizes_global, 0, total_repeat_length=m // mesh.shape["data"])
+            return (gmm_result_local + rhs_bis).astype(gmm_result_local.dtype)
+
+        gmm_result = shard_map(
+            _add_bias,
+            mesh=mesh,
+            in_specs=(P(ShardingAxisName.MLP_DATA), P(), P(ShardingAxisName.MLP_DATA)),
+            out_specs=(P(ShardingAxisName.MLP_DATA)),
+            check_rep=False,
+        )(gmm_result, rhs_bias, group_sizes)
 
     return gmm_result
 
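Both bias hunks rely on the same `shard_map` idiom: inside the mapped function every operand is the per-shard block, so the repeat length is the local row count `m // n_data`, not the global `m`. A runnable toy sketch of that local-view pattern, assuming a hypothetical 1-D mesh axis literally named `"data"` (the real code uses the model mesh and `ShardingAxisName.MLP_DATA`):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# Hypothetical 1-D mesh over whatever devices are available.
mesh = Mesh(np.array(jax.devices()), ("data",))
n_data = mesh.shape["data"]

m, hidden, num_experts = 8 * n_data, 4, 3
gmm_result = jnp.ones((m, hidden), dtype=jnp.float32)
rhs_bias = jnp.zeros((num_experts, hidden), dtype=jnp.float32)
# One per-expert count vector per data shard, concatenated along axis 0,
# mirroring how group_sizes leaves the earlier shard_map in the real code.
group_sizes = jnp.tile(jnp.array([3, 1, 4], dtype=jnp.int32), n_data)

def _add_bias(res_local, bias_local, sizes_local):
    # Per-shard view: res_local is (m // n_data, hidden), sizes_local is (num_experts,).
    rows = jnp.repeat(bias_local, sizes_local, 0, total_repeat_length=m // n_data)
    return (res_local + rows).astype(res_local.dtype)

out = shard_map(
    _add_bias,
    mesh=mesh,
    in_specs=(P("data"), P(), P("data")),   # result and counts sharded, bias replicated
    out_specs=P("data"),
    check_rep=False,
)(gmm_result, rhs_bias, group_sizes)
print(out.shape)  # (m, hidden)
```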
@@ -359,13 +374,6 @@ def fused_moe_func(
     assert (num_tokens * topk) % 16 == 0, (
         "The kernel requires num_tokens * topk to be a multiple of "
         f"16 but got {num_tokens}*{topk}={num_tokens * topk}")
-    hidden_states = jax.lax.with_sharding_constraint(
-        hidden_states, NamedSharding(mesh, P(ShardingAxisName.ATTN_DATA, None)))
-
-    gating_output = jax.lax.with_sharding_constraint(
-        gating_output, NamedSharding(mesh, P(ShardingAxisName.ATTN_DATA, None)))
-
-    # jax.debug.print("hidden_state before MoE {} {}", hidden_states.sum(), hidden_states.ravel()[:10])
     hidden_states = hidden_states.reshape(num_tokens, hidden_size)
     gating_output = gating_output.reshape(num_tokens, global_num_experts)
 
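The reshaped `gating_output` feeds the top-k routing that produces the `topk_indices` consumed by `_process_tokens_locally` below. That step is not part of this hunk; the following is only a hedged sketch of a standard MoE top-k routing, with invented names, and the real function may normalize or order these operations differently:

```python
import jax
import jax.numpy as jnp

def topk_routing(gating_output: jax.Array, topk: int):
    """Illustrative routing: softmax over experts, then pick the top-k per token."""
    weights = jax.nn.softmax(gating_output.astype(jnp.float32), axis=-1)
    topk_weights, topk_indices = jax.lax.top_k(weights, topk)   # (num_tokens, topk)
    # Renormalize the selected weights so they sum to 1 per token (a common option).
    topk_weights = topk_weights / jnp.sum(topk_weights, axis=-1, keepdims=True)
    return topk_weights, topk_indices.astype(jnp.int32)

# Example: 4 tokens, 8 experts, top-2 routing.
gates = jax.random.normal(jax.random.PRNGKey(0), (4, 8))
w, idx = topk_routing(gates, topk=2)
print(w.shape, idx.shape)  # (4, 2) (4, 2)
```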
@@ -383,19 +391,15 @@ def _process_tokens_locally(hidden_states_local, topk_indices_local):
         token_indices = jnp.arange(num_tokens_local, dtype=jnp.int32).repeat(topk)
         token_indices_sorted = token_indices[topk_argsort_indices]
         group_sizes_local = jnp.bincount(topk_indices_flat, length=global_num_experts)
-
-        # Reduce group_sizes once across data parallel shards to get global counts
-        # This is needed for bias addition and should be done only once for efficiency
-        group_sizes_global = jax.lax.psum(group_sizes_local, axis_name=ShardingAxisName.ATTN_DATA)
-
+
         x = hidden_states_local[token_indices_sorted]
-        return x, group_sizes_local, group_sizes_global, topk_argsort_revert_indices
+        return x, group_sizes_local, topk_argsort_revert_indices
 
-    x, group_sizes, group_sizes_global, topk_argsort_revert_indices = shard_map(
+    x, group_sizes, topk_argsort_revert_indices = shard_map(
         _process_tokens_locally,
         mesh=mesh,
         in_specs=(P(ShardingAxisName.ATTN_DATA, None), P(ShardingAxisName.ATTN_DATA, None)),
-        out_specs=(P(ShardingAxisName.ATTN_DATA, None), P(ShardingAxisName.ATTN_DATA), P(), P(ShardingAxisName.ATTN_DATA)),
+        out_specs=(P(ShardingAxisName.ATTN_DATA, None), P(ShardingAxisName.ATTN_DATA), P(ShardingAxisName.ATTN_DATA)),
         check_rep=False,
     )(hidden_states, topk_indices)
 
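The surviving lines of `_process_tokens_locally` implement the usual sort-tokens-by-expert bookkeeping: flatten the top-k expert assignments, argsort them so rows for each expert become contiguous, keep the inverse permutation to undo the sort later, and count rows per expert for the grouped matmul. A standalone toy walk-through of that pattern (values invented for illustration):

```python
import jax.numpy as jnp

# Toy per-shard routing result: 4 tokens, top-2 experts each, 3 experts total.
topk = 2
global_num_experts = 3
topk_indices_local = jnp.array([[0, 2], [1, 0], [2, 2], [0, 1]], dtype=jnp.int32)

topk_indices_flat = topk_indices_local.ravel()
topk_argsort_indices = jnp.argsort(topk_indices_flat)             # groups rows by expert
topk_argsort_revert_indices = jnp.argsort(topk_argsort_indices)   # undoes the grouping later

# Each token contributes `topk` rows; sorting them by expert id makes the rows
# for expert 0 contiguous, then expert 1, and so on, which is what gmm expects.
num_tokens_local = topk_indices_local.shape[0]
token_indices = jnp.arange(num_tokens_local, dtype=jnp.int32).repeat(topk)
token_indices_sorted = token_indices[topk_argsort_indices]

# Per-expert row counts for the grouped matmul.
group_sizes_local = jnp.bincount(topk_indices_flat, length=global_num_experts)
print(token_indices_sorted)   # token ids in expert order
print(group_sizes_local)      # [3 2 3]
```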
@@ -418,7 +422,6 @@ def _process_tokens_locally(hidden_states_local, topk_indices_local):
         w1,
         w1_bias,
         group_sizes,
-        group_sizes_global,
         transpose_rhs=True,
         mesh=mesh,
         intermediate_size=intermediate_size,
@@ -446,7 +449,6 @@ def _process_tokens_locally(hidden_states_local, topk_indices_local):
         w2,
         w2_bias,
         group_sizes,
-        group_sizes_global,
         transpose_rhs=True,
         mesh=mesh,
     )