6 changes: 5 additions & 1 deletion tests/e2e/test_data_parallel.py
@@ -81,9 +81,11 @@ def _run_inference_with_config(model_name: str,
time.sleep(5)


@pytest.mark.parametrize("model_impl_type", ["vllm", "flax_nnx"])
def test_model_data_parallelism(
test_prompts: list,
sampling_params: SamplingParams,
model_impl_type: str,
):
"""
Test model-wise data parallelism where data=2 in the mesh axis.
@@ -95,6 +97,7 @@ def test_model_data_parallelism(
"""
# Use Llama 1B for this test
test_model = "meta-llama/Llama-3.2-1B-Instruct"
os.environ['MODEL_IMPL_TYPE'] = model_impl_type

# Test with data parallelism enabled
outputs = _run_inference_with_config(
@@ -103,6 +106,7 @@ def test_model_data_parallelism(
sampling_params=sampling_params,
tensor_parallel_size=1,
data_parallel_size=2,
async_scheduling=True,
)

# Verify we got outputs for all prompts
@@ -175,7 +179,7 @@ def test_data_parallelism_correctness(
"""
os.environ['SKIP_JAX_PRECOMPILE'] = '1'
os.environ['VLLM_XLA_CHECK_RECOMPILATION'] = '0'
model_name = "Qwen/Qwen2.5-1.5B-Instruct"
model_name = "meta-llama/Llama-3.2-1B-Instruct"
# Use a smaller subset of prompts for correctness testing
small_prompts = test_prompts[:10]

3 changes: 2 additions & 1 deletion tpu_inference/layers/common/sharding.py
@@ -166,9 +166,10 @@ def validate(cls, vllm_config, sharding_strategy):
f"LoRA is not supported with data parallelism "
f"(DP size: {total_dp_size}). Please disable LoRA or "
f"set data parallelism to 1.")
if sharding_strategy.attention_data_parallelism > 1:
if not os.environ.get("NEW_MODEL_DESIGN", False):
raise ValueError(
"Must run DP with NEW_MODEL_DESIGN enabled. Please set the "
"Must run Attention DP with NEW_MODEL_DESIGN enabled. Please set the "
"NEW_MODEL_DESIGN=True.")

@property
101 changes: 76 additions & 25 deletions tpu_inference/layers/vllm/fused_moe.py
@@ -110,7 +110,8 @@ def tensor_sharded_gmm_merged_column_parallel(
# adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
tm, tk, tn = _get_tiling_size_for_gmm_kernel(m // mesh.shape["data"], k, n,
g)

_gmm = functools.partial(
gmm,
@@ -123,14 +124,27 @@
gmm_result = shard_map(
_gmm,
mesh=mesh,
in_specs=(P(), P(None, "model", None), P()),
out_specs=(P(None, "model")),
in_specs=(P("data", None), P(None, "model", None), P("data")),
out_specs=(P("data", "model")),
check_rep=False,
)(lhs, rhs, group_sizes)

if rhs_bias is not None:
rhs_bis = jnp.repeat(rhs_bias, group_sizes, 0, total_repeat_length=m)
gmm_result = (gmm_result + rhs_bis).astype(gmm_result.dtype)

def _add_bias(gmm_result_local, rhs_bias_local, group_sizes_global):
rhs_bis = jnp.repeat(rhs_bias_local,
group_sizes_global,
0,
total_repeat_length=m // mesh.shape["data"])
return (gmm_result_local + rhs_bis).astype(gmm_result_local.dtype)

gmm_result = shard_map(
_add_bias,
mesh=mesh,
in_specs=(P("data", "model"), P(None, "model"), P("data")),
out_specs=(P("data", "model")),
check_rep=False,
)(gmm_result, rhs_bias, group_sizes)

n_shards = mesh.shape["model"]
output_sizes = [intermediate_size, intermediate_size]
@@ -150,7 +164,8 @@ def tensor_sharded_gmm_row_parallel(
# adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
tm, tk, tn = _get_tiling_size_for_gmm_kernel(m // mesh.shape["data"], k, n,
g)

_gmm = functools.partial(
gmm,
@@ -167,14 +182,26 @@ def _gmm_all_reduce(lhs, rhs, group_sizes):
gmm_result = shard_map(
_gmm_all_reduce,
mesh=mesh,
in_specs=(P(None, "model"), P(None, None, "model"), P()),
out_specs=(P()),
in_specs=(P("data", "model"), P(None, None, "model"), P("data")),
out_specs=(P("data")),
check_rep=False,
)(lhs, rhs, group_sizes)

if rhs_bias is not None:
rhs_bias = jnp.repeat(rhs_bias, group_sizes, 0, total_repeat_length=m)
gmm_result = (gmm_result + rhs_bias).astype(gmm_result.dtype)

def _add_bias(gmm_result_local, rhs_bias_local, group_sizes_global):
rhs_bis = jnp.repeat(rhs_bias_local,
group_sizes_global,
0,
total_repeat_length=m // mesh.shape["data"])
return (gmm_result_local + rhs_bis).astype(gmm_result_local.dtype)

gmm_result = shard_map(
Collaborator:

Personally, I prefer not to use shard_map when it can be avoided, as it is prone to numeric errors when used incorrectly. With check_rep=False, unlike other ops, there is no safety check guaranteeing that the tensor's numerics adhere to a proper SPMD / sharding annotation.

I prefer to use it only when it is really necessary (for example, when invoking a custom kernel).

Please modify this code to avoid shard_map; you can refer to this PR, where I replaced an existing use of shard_map with a regular JAX function: #590

kyuyeunk (Collaborator), Nov 20, 2025:

For more context: when using shard_map with check_rep=False and something like out_specs=P(None), the output is merely annotated as having that sharding; shard_map does not insert any collective to enforce it.

This means the output tensor's values may not actually be replicated across devices; each device can hold different numerics because shard_map provides no guarantee, which makes numeric issues really painful to debug.
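
A minimal standalone sketch of the behavior described above (editor's illustration, not code from this PR; assumes a host with at least two devices and the same shard_map API used elsewhere in this file):

import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

# 1-D "data" mesh over the first two devices.
mesh = Mesh(jax.devices()[:2], ("data",))

def local_sum(x):
    # Each device reduces only its own shard, so the local results differ.
    return x.sum(keepdims=True)

y = shard_map(
    local_sum,
    mesh=mesh,
    in_specs=P("data"),
    out_specs=P(),      # claims "replicated"; no collective is inserted to enforce it
    check_rep=False,    # check_rep=True would flag this mismatch at trace time
)(jnp.arange(4.0))

# Device 0 holds 1.0 (0 + 1) while device 1 holds 5.0 (2 + 3), even though the
# output is annotated as replicated -- exactly the silent divergence described above.
print([s.data for s in y.addressable_shards])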

Collaborator Author:

I tried to avoid shard_map but could not find a way that doesn't involve a for loop.

The complexity here is that jnp.repeat(bias, group_size, 0) expects bias and group_size to have the same size along dimension 0, but group_size has DP * num_experts entries whereas bias has only num_experts rows.
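
As a concrete, standalone illustration of the mismatch (editor's sketch with made-up sizes, not code from this PR): with a DP size of 2 and 3 experts, the concatenated group sizes have 6 entries while the bias has only 3 rows, so the direct repeat fails; tiling the bias by the DP size first restores the dimension-0 match, which is the workaround discussed below.

import jax.numpy as jnp

dp_size, num_experts, hidden = 2, 3, 4
m = 10  # total rows across both DP shards in this toy example

bias = jnp.ones((num_experts, hidden))        # shape (3, 4)
# One group-size vector per DP shard, concatenated: length dp_size * num_experts.
group_sizes = jnp.array([2, 1, 2, 3, 1, 1])   # sums to m = 10

# jnp.repeat(bias, group_sizes, 0) would fail here: 3 rows vs. 6 repeat counts.
bias_tiled = jnp.tile(bias, (dp_size, 1))     # shape (6, 4)
expanded = jnp.repeat(bias_tiled, group_sizes, axis=0, total_repeat_length=m)
print(expanded.shape)                         # (10, 4)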

Collaborator:

if that's the case, can you do something like this?

# convert (experts, model_dim) to (experts * dp_size, model_dim)
bias = jnp.repeat(bias, dp_size, 0) 

# (optional; may or may not be needed) match bias's sharding with group_size's sharding
bias = jax.lax.with_sharding_constraint(bias, P("data", "model"))

# Now the bias.shape[0] and group_size.shape[0] matches
rhs_bias = jnp.repeat(bias, group_size, 0)

Collaborator:

Note that if compiler optimization works correctly, the first two steps (the jnp.repeat and the sharding constraint) will be a no-op, because the data is already present on each DP rank and we are just telling the compiler to treat it differently from that point on.

wenxindongwork (Collaborator Author), Nov 20, 2025:

Thanks Kyuyeun for this suggestion. For correctness, I have to use jnp.tile instead of jnp.repeat. However, I am noticing a performance drop (7575.50 vs 7781.92) if I do this instead of shard_map. Maybe jnp.tile is not a no-op?

        rhs_bias = jnp.tile(rhs_bias, (mesh.shape["data"], 1)) 
        # adding the sharding constraint does not make a difference 
        rhs_bias = jnp.repeat(rhs_bias, group_sizes, 0, total_repeat_length=m)
        gmm_result = (gmm_result + rhs_bias).astype(gmm_result.dtype)

Collaborator:

Hmm, in theory it should be a no-op.

The bias is already replicated across TPUs along the dp axis, and combining tile/repeat with the sharding constraint just tells the TPU to treat it as a separate, non-replicated tensor.

I'll do some tests locally and get back to you ASAP.

_add_bias,
mesh=mesh,
in_specs=(P("data"), P(), P("data")),
out_specs=(P("data")),
check_rep=False,
)(gmm_result, rhs_bias, group_sizes)

return gmm_result

@@ -366,15 +393,27 @@ def fused_moe_func(
topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdims=True)
topk_weights = topk_weights.astype(dtype)

topk_indices_flat = topk_indices.flatten()
topk_argsort_indices = jnp.argsort(topk_indices_flat)
topk_argsort_revert_indices = jnp.argsort(topk_argsort_indices)
token_indices = jnp.arange(num_tokens, dtype=jnp.int32).repeat(topk)
token_indices_sorted = token_indices[topk_argsort_indices]
group_sizes = jnp.bincount(topk_indices_flat, length=global_num_experts)

x = hidden_states[token_indices_sorted]

def _process_tokens_locally(hidden_states_local, topk_indices_local):
num_tokens_local = hidden_states_local.shape[0]
topk_indices_flat = topk_indices_local.flatten()
topk_argsort_indices = jnp.argsort(topk_indices_flat)
topk_argsort_revert_indices = jnp.argsort(topk_argsort_indices)
token_indices = jnp.arange(num_tokens_local,
dtype=jnp.int32).repeat(topk)
token_indices_sorted = token_indices[topk_argsort_indices]
group_sizes_local = jnp.bincount(topk_indices_flat,
length=global_num_experts)

x = hidden_states_local[token_indices_sorted]
return x, group_sizes_local, topk_argsort_revert_indices

x, group_sizes, topk_argsort_revert_indices = shard_map(
_process_tokens_locally,
mesh=mesh,
in_specs=(P("data", None), P("data", None)),
out_specs=(P("data", None), P("data"), P("data")),
check_rep=False,
)(hidden_states, topk_indices)
if use_ep:
x = expert_sharded_gmm(
x,
@@ -411,7 +450,7 @@ def fused_moe_func(
)
else:
x = jax.lax.with_sharding_constraint(
x, NamedSharding(mesh, P(None, "model")))
x, NamedSharding(mesh, P("data", "model")))
x = tensor_sharded_gmm_row_parallel(
x,
w2,
@@ -421,13 +460,25 @@ def fused_moe_func(
mesh=mesh,
)

x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size)
x = x * jnp.expand_dims(topk_weights, axis=-1)
x = x.sum(axis=-2)
def _finalize_output(x_local, topk_argsort_revert_indices_local,
topk_weights_local):
x_local = x_local[topk_argsort_revert_indices_local].reshape(
-1, topk, hidden_size)
x_local = x_local * jnp.expand_dims(topk_weights_local, axis=-1)
x_local = x_local.sum(axis=-2)
return x_local

x = shard_map(
_finalize_output,
mesh=mesh,
in_specs=(P("data", None), P("data"), P("data", None)),
out_specs=(P("data", None)),
check_rep=False,
)(x, topk_argsort_revert_indices, topk_weights)
x = x.reshape(orig_shape)

if reduce_results:
x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P()))
x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P("data")))
return x


7 changes: 6 additions & 1 deletion tpu_inference/layers/vllm/quantization/common.py
@@ -61,7 +61,12 @@ def __init__(self, vllm_config: VllmConfig, mesh: Mesh, layer: LinearBase):
" bad performance.", type(layer))

self.bias_sharding = P(self.weight_sharding[0])
self.n_shards = self.mesh.shape.get(self.weight_sharding[0], 1)
if isinstance(self.weight_sharding[0], tuple):
self.n_shards = 1
for axis in self.weight_sharding[0]:
self.n_shards *= self.mesh.shape.get(axis, 1)
else:
self.n_shards = self.mesh.shape.get(self.weight_sharding[0], 1)

def get_input_sharding(self, x: torchax.tensor.Tensor):
if self.enable_sequence_parallelism:
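
For reference, a standalone sketch of what the tuple branch above computes (editor's illustration, not code from this PR; the mesh axis sizes are made up): when the dimension-0 entry of the weight sharding names several mesh axes, the effective shard count is the product of those axis sizes.

import math

# Illustrative mesh axis sizes, e.g. a ("data", "model") = 2 x 4 mesh.
mesh_shape = {"data": 2, "model": 4}

def num_shards(dim0_sharding) -> int:
    # Same idea as in __init__ above: a tuple of axis names multiplies the
    # axis sizes; a single axis name (or None/absent axis) falls back to 1.
    if isinstance(dim0_sharding, tuple):
        return math.prod(mesh_shape.get(axis, 1) for axis in dim0_sharding)
    return mesh_shape.get(dim0_sharding, 1)

assert num_shards("model") == 4
assert num_shards(("data", "model")) == 8
assert num_shards(None) == 1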