
Commit 587c14f

[DP] Functional DP for GPT-OSS (#1137)
Signed-off-by: wenxindongwork <wenxindong@google.com>
1 parent 0ba0c92 commit 587c14f

4 files changed: +87 -28 lines changed

tests/e2e/test_data_parallel.py

Lines changed: 5 additions & 1 deletion

@@ -81,9 +81,11 @@ def _run_inference_with_config(model_name: str,
     time.sleep(5)
 
 
+@pytest.mark.parametrize("model_impl_type", ["vllm", "flax_nnx"])
 def test_model_data_parallelism(
     test_prompts: list,
     sampling_params: SamplingParams,
+    model_impl_type: str,
 ):
     """
     Test model-wise data parallelism where data=2 in the mesh axis.
@@ -95,6 +97,7 @@ def test_model_data_parallelism(
     """
     # Use Llama 1B for this test
     test_model = "meta-llama/Llama-3.2-1B-Instruct"
+    os.environ['MODEL_IMPL_TYPE'] = model_impl_type
 
     # Test with data parallelism enabled
     outputs = _run_inference_with_config(
@@ -103,6 +106,7 @@ def test_model_data_parallelism(
         sampling_params=sampling_params,
         tensor_parallel_size=1,
         data_parallel_size=2,
+        async_scheduling=True,
     )
 
     # Verify we got outputs for all prompts
@@ -175,7 +179,7 @@ def test_data_parallelism_correctness(
     """
     os.environ['SKIP_JAX_PRECOMPILE'] = '1'
     os.environ['VLLM_XLA_CHECK_RECOMPILATION'] = '0'
-    model_name = "Qwen/Qwen2.5-1.5B-Instruct"
+    model_name = "meta-llama/Llama-3.2-1B-Instruct"
     # Use a smaller subset of prompts for correctness testing
     small_prompts = test_prompts[:10]
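For reference, a minimal sketch of the pattern this test change introduces: parametrize over the model implementation, select it via the MODEL_IMPL_TYPE environment variable before the engine is constructed, and run data parallelism with async scheduling enabled. `run_dp_inference` below is a trivial stand-in, not the repository's `_run_inference_with_config` helper, so the snippet runs on its own.

```python
# Hedged sketch of the parametrized DP test pattern above.
import os

import pytest


def run_dp_inference(model_name, tensor_parallel_size, data_parallel_size,
                     async_scheduling):
    # Stand-in only: a real helper would spin up the engine with these
    # settings and return one output per prompt.
    return [f"{model_name}: dp={data_parallel_size}"]


@pytest.mark.parametrize("model_impl_type", ["vllm", "flax_nnx"])
def test_dp_smoke(model_impl_type: str):
    # Pick the model implementation before the engine is constructed.
    os.environ["MODEL_IMPL_TYPE"] = model_impl_type
    outputs = run_dp_inference(
        model_name="meta-llama/Llama-3.2-1B-Instruct",
        tensor_parallel_size=1,
        data_parallel_size=2,   # data=2 on the mesh axis
        async_scheduling=True,  # enabled for the DP run, as in the diff
    )
    assert len(outputs) > 0
```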

tpu_inference/layers/common/sharding.py

Lines changed: 2 additions & 1 deletion

@@ -166,9 +166,10 @@ def validate(cls, vllm_config, sharding_strategy):
                     f"LoRA is not supported with data parallelism "
                     f"(DP size: {total_dp_size}). Please disable LoRA or "
                     f"set data parallelism to 1.")
+        if sharding_strategy.attention_data_parallelism > 1:
             if not os.environ.get("NEW_MODEL_DESIGN", False):
                 raise ValueError(
-                    "Must run DP with NEW_MODEL_DESIGN enabled. Please set the "
+                    "Must run Attention DP with NEW_MODEL_DESIGN enabled. Please set the "
                     "NEW_MODEL_DESIGN=True.")
 
     @property
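The NEW_MODEL_DESIGN requirement is now scoped to attention data parallelism only. A small illustration of the resulting gate, assuming only the attribute and environment-variable names shown in the diff (everything else here is made up for the example):

```python
# Illustrative only: approximates the validate() gate after this change.
import os


def check_attention_dp(attention_data_parallelism: int) -> None:
    if attention_data_parallelism > 1:
        if not os.environ.get("NEW_MODEL_DESIGN", False):
            raise ValueError(
                "Must run Attention DP with NEW_MODEL_DESIGN enabled. "
                "Please set the NEW_MODEL_DESIGN=True.")


check_attention_dp(1)                   # gate not triggered when attention DP == 1
os.environ["NEW_MODEL_DESIGN"] = "True"
check_attention_dp(2)                   # attention DP with the flag set: passes
# Without NEW_MODEL_DESIGN set, check_attention_dp(2) raises ValueError.
```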

tpu_inference/layers/vllm/fused_moe.py

Lines changed: 74 additions & 25 deletions

@@ -110,7 +110,8 @@ def tensor_sharded_gmm_merged_column_parallel(
     # adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
     m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
     n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
-    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
+    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m // mesh.shape["data"], k, n,
+                                                 g)
 
     _gmm = functools.partial(
         gmm,
@@ -123,14 +124,26 @@ def tensor_sharded_gmm_merged_column_parallel(
     gmm_result = shard_map(
         _gmm,
         mesh=mesh,
-        in_specs=(P(), P(None, "model", None), P()),
-        out_specs=(P(None, "model")),
+        in_specs=(P("data", None), P(None, "model", None), P("data")),
+        out_specs=(P("data", "model")),
         check_rep=False,
     )(lhs, rhs, group_sizes)
 
     if rhs_bias is not None:
-        rhs_bis = jnp.repeat(rhs_bias, group_sizes, 0, total_repeat_length=m)
-        gmm_result = (gmm_result + rhs_bis).astype(gmm_result.dtype)
+
+        def _add_bias(gmm_result_local, rhs_bias_local, group_sizes_global):
+            rhs_bis = jnp.repeat(rhs_bias_local,
+                                 group_sizes_global,
+                                 0,
+                                 total_repeat_length=m // mesh.shape["data"])
+            return (gmm_result_local + rhs_bis).astype(gmm_result_local.dtype)
+
+        gmm_result = shard_map(
+            _add_bias,
+            mesh=mesh,
+            in_specs=(P("data", "model"), P(None, "model"), P("data")),
+            out_specs=(P("data", "model")),
+        )(gmm_result, rhs_bias, group_sizes)
 
     n_shards = mesh.shape["model"]
     output_sizes = [intermediate_size, intermediate_size]
@@ -150,7 +163,8 @@ def tensor_sharded_gmm_row_parallel(
     # adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
     m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
     n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
-    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
+    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m // mesh.shape["data"], k, n,
+                                                 g)
 
     _gmm = functools.partial(
         gmm,
@@ -167,14 +181,25 @@ def _gmm_all_reduce(lhs, rhs, group_sizes):
     gmm_result = shard_map(
         _gmm_all_reduce,
         mesh=mesh,
-        in_specs=(P(None, "model"), P(None, None, "model"), P()),
-        out_specs=(P()),
+        in_specs=(P("data", "model"), P(None, None, "model"), P("data")),
+        out_specs=(P("data")),
         check_rep=False,
     )(lhs, rhs, group_sizes)
-
     if rhs_bias is not None:
-        rhs_bias = jnp.repeat(rhs_bias, group_sizes, 0, total_repeat_length=m)
-        gmm_result = (gmm_result + rhs_bias).astype(gmm_result.dtype)
+
+        def _add_bias(gmm_result_local, rhs_bias_local, group_sizes_global):
+            rhs_bis = jnp.repeat(rhs_bias_local,
+                                 group_sizes_global,
+                                 0,
+                                 total_repeat_length=m // mesh.shape["data"])
+            return (gmm_result_local + rhs_bis).astype(gmm_result_local.dtype)
+
+        gmm_result = shard_map(
+            _add_bias,
+            mesh=mesh,
+            in_specs=(P("data"), P(), P("data")),
+            out_specs=(P("data")),
+        )(gmm_result, rhs_bias, group_sizes)
 
     return gmm_result
 
@@ -366,15 +391,27 @@ def fused_moe_func(
     topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdims=True)
     topk_weights = topk_weights.astype(dtype)
 
-    topk_indices_flat = topk_indices.flatten()
-    topk_argsort_indices = jnp.argsort(topk_indices_flat)
-    topk_argsort_revert_indices = jnp.argsort(topk_argsort_indices)
-    token_indices = jnp.arange(num_tokens, dtype=jnp.int32).repeat(topk)
-    token_indices_sorted = token_indices[topk_argsort_indices]
-    group_sizes = jnp.bincount(topk_indices_flat, length=global_num_experts)
-
-    x = hidden_states[token_indices_sorted]
-
+    def _process_tokens_locally(hidden_states_local, topk_indices_local):
+        num_tokens_local = hidden_states_local.shape[0]
+        topk_indices_flat = topk_indices_local.flatten()
+        topk_argsort_indices = jnp.argsort(topk_indices_flat)
+        topk_argsort_revert_indices = jnp.argsort(topk_argsort_indices)
+        token_indices = jnp.arange(num_tokens_local,
+                                   dtype=jnp.int32).repeat(topk)
+        token_indices_sorted = token_indices[topk_argsort_indices]
+        group_sizes_local = jnp.bincount(topk_indices_flat,
+                                         length=global_num_experts)
+
+        x = hidden_states_local[token_indices_sorted]
+        return x, group_sizes_local, topk_argsort_revert_indices
+
+    x, group_sizes, topk_argsort_revert_indices = shard_map(
+        _process_tokens_locally,
+        mesh=mesh,
+        in_specs=(P("data", None), P("data", None)),
+        out_specs=(P("data", None), P("data"), P("data")),
+        check_rep=False,
+    )(hidden_states, topk_indices)
     if use_ep:
         x = expert_sharded_gmm(
             x,
@@ -411,7 +448,7 @@ def fused_moe_func(
         )
     else:
         x = jax.lax.with_sharding_constraint(
-            x, NamedSharding(mesh, P(None, "model")))
+            x, NamedSharding(mesh, P("data", "model")))
         x = tensor_sharded_gmm_row_parallel(
             x,
             w2,
@@ -421,13 +458,25 @@ def fused_moe_func(
         mesh=mesh,
     )
 
-    x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size)
-    x = x * jnp.expand_dims(topk_weights, axis=-1)
-    x = x.sum(axis=-2)
+    def _finalize_output(x_local, topk_argsort_revert_indices_local,
+                         topk_weights_local):
+        x_local = x_local[topk_argsort_revert_indices_local].reshape(
+            -1, topk, hidden_size)
+        x_local = x_local * jnp.expand_dims(topk_weights_local, axis=-1)
+        x_local = x_local.sum(axis=-2)
+        return x_local
+
+    x = shard_map(
+        _finalize_output,
+        mesh=mesh,
+        in_specs=(P("data", None), P("data"), P("data", None)),
+        out_specs=(P("data", None)),
+        check_rep=False,
+    )(x, topk_argsort_revert_indices, topk_weights)
     x = x.reshape(orig_shape)
 
     if reduce_results:
-        x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P()))
+        x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P("data")))
     return x
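The fused_moe changes all follow one pattern: wrap the per-token MoE bookkeeping and the grouped matmuls in shard_map with the "data" mesh axis in in_specs/out_specs, so each DP replica sorts its own tokens, computes its own group sizes, and sizes the GMM tiling from its local row count (m // mesh.shape["data"]). Below is a standalone sketch of that bookkeeping step, not the repository's code: the mesh, shapes, and dummy routing are assumptions, and the actual gmm kernel call is omitted. On a CPU host, XLA_FLAGS=--xla_force_host_platform_device_count=2 gives two fake devices to shard over.

```python
# Standalone sketch of the shard_map-over-"data" token bookkeeping used above.
import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(jax.devices(), axis_names=("data",))
dp = mesh.shape["data"]

num_experts, topk, hidden = 4, 2, 8
num_tokens = 6 * dp  # token count must divide evenly across the "data" axis

hidden_states = jnp.ones((num_tokens, hidden))
topk_indices = jnp.tile(jnp.arange(topk), (num_tokens, 1))  # dummy routing


def _process_tokens_locally(hidden_local, topk_idx_local):
    flat = topk_idx_local.flatten()
    order = jnp.argsort(flat)    # group this shard's tokens by expert
    revert = jnp.argsort(order)  # permutation to undo the grouping later
    token_idx = jnp.arange(hidden_local.shape[0], dtype=jnp.int32).repeat(topk)
    x_local = hidden_local[token_idx[order]]
    group_sizes = jnp.bincount(flat, length=num_experts)  # local expert counts
    return x_local, group_sizes, revert


x, group_sizes, revert = shard_map(
    _process_tokens_locally,
    mesh=mesh,
    in_specs=(P("data", None), P("data", None)),
    out_specs=(P("data", None), P("data"), P("data")),
    check_rep=False,
)(hidden_states, topk_indices)

# group_sizes stacks one per-shard count vector per DP replica, which is why
# the sharded gmm wrappers above consume it with in_specs=P("data").
print(x.shape, group_sizes.shape, revert.shape)
```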

tpu_inference/layers/vllm/quantization/common.py

Lines changed: 6 additions & 1 deletion

@@ -61,7 +61,12 @@ def __init__(self, vllm_config: VllmConfig, mesh: Mesh, layer: LinearBase):
                 " bad performance.", type(layer))
 
         self.bias_sharding = P(self.weight_sharding[0])
-        self.n_shards = self.mesh.shape.get(self.weight_sharding[0], 1)
+        if isinstance(self.weight_sharding[0], tuple):
+            self.n_shards = 1
+            for axis in self.weight_sharding[0]:
+                self.n_shards *= self.mesh.shape.get(axis, 1)
+        else:
+            self.n_shards = self.mesh.shape.get(self.weight_sharding[0], 1)
 
     def get_input_sharding(self, x: torchax.tensor.Tensor):
         if self.enable_sequence_parallelism:
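The n_shards change handles weight shardings whose first entry names a tuple of mesh axes, such as the ("data", "model") layouts introduced above: the shard count is then the product of the corresponding mesh-axis sizes. A small illustration of the arithmetic, using a made-up 2x4 mesh shape rather than anything from the repo:

```python
# Illustration of the updated n_shards computation for tuple mesh-axis specs.
import math

mesh_shape = {"data": 2, "model": 4}  # assumed mesh sizes for the example


def n_shards(first_axis):
    """Shard count along a weight's first sharded dimension."""
    if isinstance(first_axis, tuple):
        return math.prod(mesh_shape.get(axis, 1) for axis in first_axis)
    return mesh_shape.get(first_axis, 1)


print(n_shards("model"))            # 4  -- previous single-axis behaviour
print(n_shards(("data", "model")))  # 8  -- new tuple handling: 2 * 4
print(n_shards(None))               # 1  -- unsharded dimension
```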
