@@ -110,7 +110,8 @@ def tensor_sharded_gmm_merged_column_parallel(
     # adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
     m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
     n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
-    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m // mesh.shape["data"], k, n, g)
+    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m // mesh.shape["data"], k, n,
+                                                 g)
 
     _gmm = functools.partial(
         gmm,
@@ -129,10 +130,14 @@ def tensor_sharded_gmm_merged_column_parallel(
     )(lhs, rhs, group_sizes)
 
     if rhs_bias is not None:
+
         def _add_bias(gmm_result_local, rhs_bias_local, group_sizes_global):
-            rhs_bis = jnp.repeat(rhs_bias_local, group_sizes_global, 0, total_repeat_length=m // mesh.shape["data"])
+            rhs_bis = jnp.repeat(rhs_bias_local,
+                                 group_sizes_global,
+                                 0,
+                                 total_repeat_length=m // mesh.shape["data"])
             return (gmm_result_local + rhs_bis).astype(gmm_result_local.dtype)
-
+
         gmm_result = shard_map(
             _add_bias,
             mesh=mesh,
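
The wrapped jnp.repeat above broadcasts one bias row per expert over the expert-sorted GMM output. A minimal standalone sketch of that broadcast, assuming the same shapes (these names are illustrative, not helpers from this module):

import jax.numpy as jnp

def add_expert_bias(gmm_out, expert_bias, group_sizes, num_tokens):
    # expert_bias has one row per expert; repeating row i group_sizes[i] times
    # aligns the bias with the expert-sorted rows of gmm_out.
    bias_per_token = jnp.repeat(expert_bias, group_sizes, axis=0,
                                total_repeat_length=num_tokens)
    return (gmm_out + bias_per_token).astype(gmm_out.dtype)

total_repeat_length pins the repeated array to a static size, which is what lets this run under jit even though group_sizes is data-dependent.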
@@ -159,7 +164,8 @@ def tensor_sharded_gmm_row_parallel(
     # adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
     m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
     n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
-    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m // mesh.shape["data"], k, n, g)
+    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m // mesh.shape["data"], k, n,
+                                                 g)
 
     _gmm = functools.partial(
         gmm,
@@ -176,16 +182,19 @@ def _gmm_all_reduce(lhs, rhs, group_sizes):
     gmm_result = shard_map(
         _gmm_all_reduce,
         mesh=mesh,
-        in_specs=(P("data", "model"),
-                  P(None, None, "model"), P("data")),
-        out_specs=(P("data")),
-        check_rep=False,
+        in_specs=(P("data", "model"), P(None, None, "model"), P("data")),
+        out_specs=(P("data")),
+        check_rep=False,
     )(lhs, rhs, group_sizes)
     if rhs_bias is not None:
+
         def _add_bias(gmm_result_local, rhs_bias_local, group_sizes_global):
-            rhs_bis = jnp.repeat(rhs_bias_local, group_sizes_global, 0, total_repeat_length=m // mesh.shape["data"])
+            rhs_bis = jnp.repeat(rhs_bias_local,
+                                 group_sizes_global,
+                                 0,
+                                 total_repeat_length=m // mesh.shape["data"])
             return (gmm_result_local + rhs_bis).astype(gmm_result_local.dtype)
-
+
         gmm_result = shard_map(
             _add_bias,
             mesh=mesh,
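
For context, _gmm_all_reduce (its body falls outside this hunk) presumably follows the standard row-parallel pattern: each device multiplies against its slice of the contraction dimension and the partial results are summed across the "model" axis. A rough sketch of that pattern, with a plain matmul standing in for the Pallas gmm kernel and illustrative partition specs:

import jax
from jax.sharding import PartitionSpec as P
from jax.experimental.shard_map import shard_map

def row_parallel_matmul(lhs, rhs, mesh):
    # lhs: (tokens, k) sharded over "model" along k; rhs: (k, n) sharded the same way.
    def _local(lhs_block, rhs_block):
        partial = lhs_block @ rhs_block          # partial product on this shard
        return jax.lax.psum(partial, "model")    # combine partials across "model"

    return shard_map(
        _local,
        mesh=mesh,
        in_specs=(P("data", "model"), P("model", None)),
        out_specs=P("data", None),
        check_rep=False,
    )(lhs, rhs)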
@@ -389,13 +398,15 @@ def _process_tokens_locally(hidden_states_local, topk_indices_local):
         topk_indices_flat = topk_indices_local.flatten()
         topk_argsort_indices = jnp.argsort(topk_indices_flat)
         topk_argsort_revert_indices = jnp.argsort(topk_argsort_indices)
-        token_indices = jnp.arange(num_tokens_local, dtype=jnp.int32).repeat(topk)
+        token_indices = jnp.arange(num_tokens_local,
+                                   dtype=jnp.int32).repeat(topk)
         token_indices_sorted = token_indices[topk_argsort_indices]
-        group_sizes_local = jnp.bincount(topk_indices_flat, length=global_num_experts)
-
+        group_sizes_local = jnp.bincount(topk_indices_flat,
+                                         length=global_num_experts)
+
         x = hidden_states_local[token_indices_sorted]
         return x, group_sizes_local, topk_argsort_revert_indices
-
+
     x, group_sizes, topk_argsort_revert_indices = shard_map(
         _process_tokens_locally,
         mesh=mesh,
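
_process_tokens_locally groups each data shard's tokens by destination expert so the grouped matmul sees contiguous per-expert blocks; the argsort of the argsort is the inverse permutation used later to restore the original token order. A self-contained sketch of the same bookkeeping (function name and shapes are illustrative):

import jax.numpy as jnp

def sort_tokens_by_expert(hidden_states, topk_indices, num_experts):
    # hidden_states: (num_tokens, hidden); topk_indices: (num_tokens, topk)
    num_tokens, topk = topk_indices.shape
    flat = topk_indices.flatten()
    order = jnp.argsort(flat)            # expert-sorted order of (token, k) slots
    revert = jnp.argsort(order)          # inverse permutation for un-sorting later
    token_ids = jnp.arange(num_tokens, dtype=jnp.int32).repeat(topk)
    x = hidden_states[token_ids[order]]  # each token duplicated topk times, grouped by expert
    group_sizes = jnp.bincount(flat, length=num_experts)  # rows routed to each expert
    return x, group_sizes, revert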
@@ -449,12 +460,14 @@ def _process_tokens_locally(hidden_states_local, topk_indices_local):
         mesh=mesh,
     )
 
-    def _finalize_output(x_local, topk_argsort_revert_indices_local, topk_weights_local):
-        x_local = x_local[topk_argsort_revert_indices_local].reshape(-1, topk, hidden_size)
+    def _finalize_output(x_local, topk_argsort_revert_indices_local,
+                         topk_weights_local):
+        x_local = x_local[topk_argsort_revert_indices_local].reshape(
+            -1, topk, hidden_size)
         x_local = x_local * jnp.expand_dims(topk_weights_local, axis=-1)
         x_local = x_local.sum(axis=-2)
         return x_local
-
+
     x = shard_map(
         _finalize_output,
         mesh=mesh,
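
_finalize_output undoes the expert sort, regroups the rows to (tokens, topk, hidden_size), and combines the per-expert outputs with the router weights. A minimal sketch of that combine step under the same assumptions:

import jax.numpy as jnp

def combine_expert_outputs(x, revert_indices, topk_weights, hidden_size):
    # x: (num_tokens * topk, hidden_size) in expert-sorted order;
    # topk_weights: (num_tokens, topk) router weights.
    topk = topk_weights.shape[-1]
    x = x[revert_indices].reshape(-1, topk, hidden_size)  # back to per-token layout
    x = x * jnp.expand_dims(topk_weights, axis=-1)        # weight each expert's output
    return x.sum(axis=-2)                                 # sum over the topk experts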