
Commit b10487a

formatting
1 parent a9d5154 commit b10487a

File tree: 1 file changed (+30 -17 lines)


tpu_inference/layers/vllm/fused_moe.py

Lines changed: 30 additions & 17 deletions
@@ -110,7 +110,8 @@ def tensor_sharded_gmm_merged_column_parallel(
     # adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
     m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
     n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
-    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m//mesh.shape["data"], k, n, g)
+    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m // mesh.shape["data"], k, n,
+                                                 g)
 
     _gmm = functools.partial(
         gmm,
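
Note: both tiling-size call sites divide m by mesh.shape["data"] because lhs is row-sharded along the "data" mesh axis, so the kernel tiles the per-shard row count rather than the global one. A minimal sketch of that arithmetic, assuming eight local devices arranged 2 x 4 (the mesh layout and values are made up for illustration):

    import numpy as np
    import jax
    from jax.sharding import Mesh

    # Assumption: 8 devices, "data" axis of size 2, "model" axis of size 4.
    mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ("data", "model"))
    m = 4096                           # hypothetical global token count
    m_local = m // mesh.shape["data"]  # 2048 rows per "data" shard feed gmm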
@@ -129,10 +130,14 @@ def tensor_sharded_gmm_merged_column_parallel(
     )(lhs, rhs, group_sizes)
 
     if rhs_bias is not None:
+
         def _add_bias(gmm_result_local, rhs_bias_local, group_sizes_global):
-            rhs_bis = jnp.repeat(rhs_bias_local, group_sizes_global, 0, total_repeat_length=m//mesh.shape["data"])
+            rhs_bis = jnp.repeat(rhs_bias_local,
+                                 group_sizes_global,
+                                 0,
+                                 total_repeat_length=m // mesh.shape["data"])
             return (gmm_result_local + rhs_bis).astype(gmm_result_local.dtype)
-
+
         gmm_result = shard_map(
             _add_bias,
             mesh=mesh,
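
Note on the reflowed jnp.repeat call: rhs_bias_local holds one bias row per expert, while the gmm result holds one row per expert-sorted token, so repeating row e exactly group_sizes[e] times lines the two up row for row; total_repeat_length fixes the output size statically, which keeps the op shape-stable under jit. A self-contained sketch with made-up shapes:

    import jax.numpy as jnp

    num_experts, n, m = 3, 4, 6          # hypothetical sizes
    rhs_bias = jnp.ones((num_experts, n))
    group_sizes = jnp.array([1, 3, 2])   # tokens per expert, sums to m
    # Expand per-expert bias to per-token bias along axis 0.
    per_token_bias = jnp.repeat(rhs_bias, group_sizes, 0, total_repeat_length=m)
    assert per_token_bias.shape == (m, n)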
@@ -159,7 +164,8 @@ def tensor_sharded_gmm_row_parallel(
     # adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
     m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
     n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
-    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m//mesh.shape["data"], k, n, g)
+    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m // mesh.shape["data"], k, n,
+                                                 g)
 
     _gmm = functools.partial(
         gmm,
@@ -176,16 +182,19 @@ def _gmm_all_reduce(lhs, rhs, group_sizes):
     gmm_result = shard_map(
         _gmm_all_reduce,
         mesh=mesh,
-        in_specs=(P("data", "model"),
-                  P(None, None, "model"), P("data")),
-        out_specs=(P("data")),
-        check_rep=False,
+        in_specs=(P("data", "model"), P(None, None, "model"), P("data")),
+        out_specs=(P("data")),
+        check_rep=False,
     )(lhs, rhs, group_sizes)
     if rhs_bias is not None:
+
         def _add_bias(gmm_result_local, rhs_bias_local, group_sizes_global):
-            rhs_bis = jnp.repeat(rhs_bias_local, group_sizes_global, 0, total_repeat_length=m//mesh.shape["data"])
+            rhs_bis = jnp.repeat(rhs_bias_local,
+                                 group_sizes_global,
+                                 0,
+                                 total_repeat_length=m // mesh.shape["data"])
             return (gmm_result_local + rhs_bis).astype(gmm_result_local.dtype)
-
+
         gmm_result = shard_map(
             _add_bias,
             mesh=mesh,
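
For readers skimming the in_specs change: each PartitionSpec entry names the mesh axes that shard the corresponding array dimension, and the function passed to shard_map sees only its local block, here followed by an all-reduce over the "model" axis. A minimal standalone sketch of that row-parallel pattern, assuming eight devices and the same jax.experimental.shard_map API this file already uses (shapes are made up):

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.experimental.shard_map import shard_map
    from jax.sharding import Mesh, PartitionSpec as P

    mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ("data", "model"))

    def _local_matmul(x_local, w_local):
        # x_local: [16//2, 8//4]; w_local: [8//4, 4]. Partial products are
        # summed across the "model" axis, mirroring _gmm_all_reduce above.
        return jax.lax.psum(x_local @ w_local, axis_name="model")

    x = jnp.ones((16, 8))
    w = jnp.ones((8, 4))
    y = shard_map(_local_matmul, mesh=mesh,
                  in_specs=(P("data", "model"), P("model", None)),
                  out_specs=P("data", None))(x, w)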
@@ -389,13 +398,15 @@ def _process_tokens_locally(hidden_states_local, topk_indices_local):
         topk_indices_flat = topk_indices_local.flatten()
         topk_argsort_indices = jnp.argsort(topk_indices_flat)
         topk_argsort_revert_indices = jnp.argsort(topk_argsort_indices)
-        token_indices = jnp.arange(num_tokens_local, dtype=jnp.int32).repeat(topk)
+        token_indices = jnp.arange(num_tokens_local,
+                                   dtype=jnp.int32).repeat(topk)
         token_indices_sorted = token_indices[topk_argsort_indices]
-        group_sizes_local = jnp.bincount(topk_indices_flat, length=global_num_experts)
-
+        group_sizes_local = jnp.bincount(topk_indices_flat,
+                                         length=global_num_experts)
+
         x = hidden_states_local[token_indices_sorted]
         return x, group_sizes_local, topk_argsort_revert_indices
-
+
     x, group_sizes, topk_argsort_revert_indices = shard_map(
         _process_tokens_locally,
         mesh=mesh,
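
The wrapped lines above belong to the standard sort-by-expert routing trick: sorting the flattened top-k expert ids groups token copies by expert so one grouped matmul can process them all, and the argsort of the argsort is the permutation that undoes the sort afterwards. A small sketch on made-up data:

    import jax.numpy as jnp

    topk_indices = jnp.array([[2, 0], [1, 2], [0, 1]])  # 3 tokens, topk=2
    flat = topk_indices.flatten()
    order = jnp.argsort(flat)     # positions that sort token copies by expert
    revert = jnp.argsort(order)   # inverse permutation, used later to unsort
    group_sizes = jnp.bincount(flat, length=4)   # rows per expert (4 experts)
    token_ids = jnp.arange(3, dtype=jnp.int32).repeat(2)[order]
    # hidden_states[token_ids] would now be grouped expert-by-expert.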
@@ -449,12 +460,14 @@ def _process_tokens_locally(hidden_states_local, topk_indices_local):
         mesh=mesh,
     )
 
-    def _finalize_output(x_local, topk_argsort_revert_indices_local, topk_weights_local):
-        x_local = x_local[topk_argsort_revert_indices_local].reshape(-1, topk, hidden_size)
+    def _finalize_output(x_local, topk_argsort_revert_indices_local,
+                         topk_weights_local):
+        x_local = x_local[topk_argsort_revert_indices_local].reshape(
+            -1, topk, hidden_size)
         x_local = x_local * jnp.expand_dims(topk_weights_local, axis=-1)
         x_local = x_local.sum(axis=-2)
         return x_local
-
+
     x = shard_map(
         _finalize_output,
         mesh=mesh,
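
The reflowed _finalize_output is the inverse of the routing step: gather by the revert permutation to restore token order, regroup each token's topk expert outputs, scale by the router weights, and sum. A sketch with made-up sizes, reusing the revert permutation from the routing sketch above:

    import jax.numpy as jnp

    topk, hidden_size = 2, 4
    expert_out = jnp.ones((6, hidden_size))   # expert-sorted gmm output rows
    revert = jnp.array([4, 0, 2, 5, 1, 3])    # inverse sort permutation
    topk_weights = jnp.full((3, topk), 0.5)   # hypothetical router weights

    x = expert_out[revert].reshape(-1, topk, hidden_size)  # [tokens, topk, hidden]
    x = x * jnp.expand_dims(topk_weights, axis=-1)         # weight each expert copy
    out = x.sum(axis=-2)                                   # [tokens, hidden]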
