@@ -110,7 +110,8 @@ def tensor_sharded_gmm_merged_column_parallel(
     # adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
     m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
     n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
-    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
+    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m // mesh.shape["data"], k, n,
+                                                 g)
 
     _gmm = functools.partial(
         gmm,
@@ -123,14 +124,26 @@ def tensor_sharded_gmm_merged_column_parallel(
     gmm_result = shard_map(
         _gmm,
         mesh=mesh,
-        in_specs=(P(), P(None, "model", None), P()),
-        out_specs=(P(None, "model")),
+        in_specs=(P("data", None), P(None, "model", None), P("data")),
+        out_specs=(P("data", "model")),
         check_rep=False,
     )(lhs, rhs, group_sizes)
 
     if rhs_bias is not None:
-        rhs_bis = jnp.repeat(rhs_bias, group_sizes, 0, total_repeat_length=m)
-        gmm_result = (gmm_result + rhs_bis).astype(gmm_result.dtype)
+
+        def _add_bias(gmm_result_local, rhs_bias_local, group_sizes_global):
+            rhs_bis = jnp.repeat(rhs_bias_local,
+                                 group_sizes_global,
+                                 0,
+                                 total_repeat_length=m // mesh.shape["data"])
+            return (gmm_result_local + rhs_bis).astype(gmm_result_local.dtype)
+
+        gmm_result = shard_map(
+            _add_bias,
+            mesh=mesh,
+            in_specs=(P("data", "model"), P(None, "model"), P("data")),
+            out_specs=(P("data", "model")),
+        )(gmm_result, rhs_bias, group_sizes)
 
     n_shards = mesh.shape["model"]
     output_sizes = [intermediate_size, intermediate_size]
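For reference, a minimal standalone sketch (hypothetical shapes, not part of this commit) of the bias expansion that `_add_bias` performs on each data shard: `jnp.repeat` turns the per-expert bias rows into per-token rows, with `total_repeat_length` pinned to the local row count so the output shape stays static under jit.

```python
# Sketch only: expand per-expert bias rows to per-token rows, as _add_bias does
# for one data shard. All shapes and values below are made up for illustration.
import jax.numpy as jnp

num_experts, n = 4, 8
m_local = 6  # rows owned by one "data" shard, i.e. m // mesh.shape["data"]

rhs_bias = jnp.arange(num_experts * n, dtype=jnp.float32).reshape(num_experts, n)
group_sizes = jnp.array([2, 0, 3, 1], dtype=jnp.int32)  # tokens per expert, sums to m_local

# total_repeat_length keeps the result shape static; row i now holds the bias
# of the expert that produced row i of the gmm output.
expanded = jnp.repeat(rhs_bias, group_sizes, axis=0, total_repeat_length=m_local)
print(expanded.shape)  # (6, 8)
```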
@@ -150,7 +163,8 @@ def tensor_sharded_gmm_row_parallel(
     # adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
     m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
     n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
-    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
+    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m // mesh.shape["data"], k, n,
+                                                 g)
 
     _gmm = functools.partial(
         gmm,
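Both tiling changes above follow from how `shard_map` presents data to the kernel: inside the mapped function each device only sees its block, so the row count the gmm kernel tiles over is `m // mesh.shape["data"]`, not `m`. A small sketch under assumed axis names; the helper `_report_local_rows` is illustrative and not from this commit.

```python
# Sketch: the block seen inside shard_map has the per-shard leading dimension.
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

# Assumed 2D mesh with the same axis names as the diff; all devices on "data".
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, ("data", "model"))

m, k = 8 * mesh.shape["data"], 4
lhs = jnp.zeros((m, k))

def _report_local_rows(lhs_local):
    # lhs_local is one block of lhs, so its row count is the per-shard size.
    return jnp.array(lhs_local.shape[0])

local_rows = shard_map(
    _report_local_rows,
    mesh=mesh,
    in_specs=(P("data", None),),
    out_specs=P(),
    check_rep=False,
)(lhs)
print(local_rows)  # == m // mesh.shape["data"]
```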
@@ -167,14 +181,25 @@ def _gmm_all_reduce(lhs, rhs, group_sizes):
     gmm_result = shard_map(
         _gmm_all_reduce,
         mesh=mesh,
-        in_specs=(P(None, "model"), P(None, None, "model"), P()),
-        out_specs=(P()),
+        in_specs=(P("data", "model"), P(None, None, "model"), P("data")),
+        out_specs=(P("data")),
         check_rep=False,
     )(lhs, rhs, group_sizes)
-
     if rhs_bias is not None:
-        rhs_bias = jnp.repeat(rhs_bias, group_sizes, 0, total_repeat_length=m)
-        gmm_result = (gmm_result + rhs_bias).astype(gmm_result.dtype)
+
+        def _add_bias(gmm_result_local, rhs_bias_local, group_sizes_global):
+            rhs_bis = jnp.repeat(rhs_bias_local,
+                                 group_sizes_global,
+                                 0,
+                                 total_repeat_length=m // mesh.shape["data"])
+            return (gmm_result_local + rhs_bis).astype(gmm_result_local.dtype)
+
+        gmm_result = shard_map(
+            _add_bias,
+            mesh=mesh,
+            in_specs=(P("data"), P(), P("data")),
+            out_specs=(P("data")),
+        )(gmm_result, rhs_bias, group_sizes)
 
     return gmm_result
 
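The body of `_gmm_all_reduce` is not shown in this hunk; based on its name and the specs above, it presumably follows the usual row-parallel pattern of multiplying each device's slice of the contraction dimension and summing the partial products with `jax.lax.psum`. A hedged sketch of that pattern, with a dense matmul standing in for the grouped gmm kernel:

```python
# Sketch of a row-parallel matmul with an all-reduce over the "model" axis.
# This is an assumed stand-in for _gmm_all_reduce, not the actual kernel call.
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

# Assumed 2D mesh; here all devices go on "model" so the contraction dim is split.
devices = np.array(jax.devices()).reshape(1, -1)
mesh = Mesh(devices, ("data", "model"))

m, k, n = 8, 4 * mesh.shape["model"], 4
lhs = jnp.ones((m, k))
rhs = jnp.ones((k, n))

def _matmul_all_reduce(lhs_local, rhs_local):
    partial = lhs_local @ rhs_local          # (m_local, n): partial over this k slice
    return jax.lax.psum(partial, "model")    # sum partial products across "model"

out = shard_map(
    _matmul_all_reduce,
    mesh=mesh,
    in_specs=(P("data", "model"), P("model", None)),
    out_specs=P("data", None),
    check_rep=False,
)(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs))  # True
```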
@@ -366,15 +391,27 @@ def fused_moe_func(
     topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdims=True)
     topk_weights = topk_weights.astype(dtype)
 
-    topk_indices_flat = topk_indices.flatten()
-    topk_argsort_indices = jnp.argsort(topk_indices_flat)
-    topk_argsort_revert_indices = jnp.argsort(topk_argsort_indices)
-    token_indices = jnp.arange(num_tokens, dtype=jnp.int32).repeat(topk)
-    token_indices_sorted = token_indices[topk_argsort_indices]
-    group_sizes = jnp.bincount(topk_indices_flat, length=global_num_experts)
-
-    x = hidden_states[token_indices_sorted]
-
+    def _process_tokens_locally(hidden_states_local, topk_indices_local):
+        num_tokens_local = hidden_states_local.shape[0]
+        topk_indices_flat = topk_indices_local.flatten()
+        topk_argsort_indices = jnp.argsort(topk_indices_flat)
+        topk_argsort_revert_indices = jnp.argsort(topk_argsort_indices)
+        token_indices = jnp.arange(num_tokens_local,
+                                   dtype=jnp.int32).repeat(topk)
+        token_indices_sorted = token_indices[topk_argsort_indices]
+        group_sizes_local = jnp.bincount(topk_indices_flat,
+                                         length=global_num_experts)
+
+        x = hidden_states_local[token_indices_sorted]
+        return x, group_sizes_local, topk_argsort_revert_indices
+
+    x, group_sizes, topk_argsort_revert_indices = shard_map(
+        _process_tokens_locally,
+        mesh=mesh,
+        in_specs=(P("data", None), P("data", None)),
+        out_specs=(P("data", None), P("data"), P("data")),
+        check_rep=False,
+    )(hidden_states, topk_indices)
     if use_ep:
         x = expert_sharded_gmm(
             x,
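`_process_tokens_locally` keeps the same routing recipe as the deleted lines, just applied per data shard: flatten the top-k expert ids, argsort them so rows headed to the same expert become contiguous, count rows per expert with `bincount`, and keep the inverse permutation to undo the sort later. A tiny worked example with hypothetical numbers:

```python
# Worked example of the token-routing recipe, with made-up numbers.
import jax.numpy as jnp

topk = 2
num_experts = 4
topk_indices = jnp.array([[2, 0],
                          [1, 2],
                          [0, 3]])                    # 3 tokens, top-2 experts each

topk_indices_flat = topk_indices.flatten()            # [2, 0, 1, 2, 0, 3]
topk_argsort_indices = jnp.argsort(topk_indices_flat) # groups rows by expert id
topk_argsort_revert_indices = jnp.argsort(topk_argsort_indices)  # inverse permutation

token_indices = jnp.arange(3, dtype=jnp.int32).repeat(topk)      # [0, 0, 1, 1, 2, 2]
token_indices_sorted = token_indices[topk_argsort_indices]       # tokens grouped by expert
group_sizes = jnp.bincount(topk_indices_flat, length=num_experts)  # [2, 1, 2, 1]

# Gathering hidden_states[token_indices_sorted] yields rows grouped by expert,
# the layout the grouped matmul expects; indexing the result with
# topk_argsort_revert_indices restores the original token order.
```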
@@ -411,7 +448,7 @@ def fused_moe_func(
         )
     else:
         x = jax.lax.with_sharding_constraint(
-            x, NamedSharding(mesh, P(None, "model")))
+            x, NamedSharding(mesh, P("data", "model")))
         x = tensor_sharded_gmm_row_parallel(
             x,
             w2,
@@ -421,13 +458,25 @@ def fused_moe_func(
             mesh=mesh,
         )
 
-    x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size)
-    x = x * jnp.expand_dims(topk_weights, axis=-1)
-    x = x.sum(axis=-2)
+    def _finalize_output(x_local, topk_argsort_revert_indices_local,
+                         topk_weights_local):
+        x_local = x_local[topk_argsort_revert_indices_local].reshape(
+            -1, topk, hidden_size)
+        x_local = x_local * jnp.expand_dims(topk_weights_local, axis=-1)
+        x_local = x_local.sum(axis=-2)
+        return x_local
+
+    x = shard_map(
+        _finalize_output,
+        mesh=mesh,
+        in_specs=(P("data", None), P("data"), P("data", None)),
+        out_specs=(P("data", None)),
+        check_rep=False,
+    )(x, topk_argsort_revert_indices, topk_weights)
     x = x.reshape(orig_shape)
 
     if reduce_results:
-        x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P()))
+        x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P("data")))
     return x
 
 
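For context, the combine step inside `_finalize_output` is the standard MoE recombination: revert the expert-sorted order, reshape to `(tokens, topk, hidden)`, and take the weighted sum over the top-k axis. A small sketch with hypothetical shapes, omitting the `shard_map` wrapper:

```python
# Sketch of the per-shard combine: weighted sum of each token's topk expert outputs.
import jax.numpy as jnp

num_tokens, topk, hidden_size = 3, 2, 4
x = jnp.arange(num_tokens * topk * hidden_size, dtype=jnp.float32)
x = x.reshape(num_tokens * topk, hidden_size)   # expert outputs, already back in token order
topk_weights = jnp.array([[0.7, 0.3],
                          [0.5, 0.5],
                          [0.9, 0.1]])

combined = x.reshape(-1, topk, hidden_size)              # (tokens, topk, hidden)
combined = combined * jnp.expand_dims(topk_weights, -1)  # weight each expert's contribution
combined = combined.sum(axis=-2)                         # (tokens, hidden)
print(combined.shape)  # (3, 4)
```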