diff --git a/tests/e2e/test_data_parallel.py b/tests/e2e/test_data_parallel.py
index 9d794df29..9567731c3 100644
--- a/tests/e2e/test_data_parallel.py
+++ b/tests/e2e/test_data_parallel.py
@@ -81,9 +81,11 @@ def _run_inference_with_config(model_name: str,
     time.sleep(5)


+@pytest.mark.parametrize("model_impl_type", ["vllm", "flax_nnx"])
 def test_model_data_parallelism(
     test_prompts: list,
     sampling_params: SamplingParams,
+    model_impl_type: str,
 ):
     """
     Test model-wise data parallelism where data=2 in the mesh axis.
@@ -95,6 +97,7 @@ def test_model_data_parallelism(
     """
     # Use Llama 1B for this test
     test_model = "meta-llama/Llama-3.2-1B-Instruct"
+    os.environ['MODEL_IMPL_TYPE'] = model_impl_type

     # Test with data parallelism enabled
     outputs = _run_inference_with_config(
@@ -103,6 +106,7 @@
         sampling_params=sampling_params,
         tensor_parallel_size=1,
         data_parallel_size=2,
+        async_scheduling=True,
     )

     # Verify we got outputs for all prompts
@@ -175,7 +179,7 @@ def test_data_parallelism_correctness(
     """
     os.environ['SKIP_JAX_PRECOMPILE'] = '1'
     os.environ['VLLM_XLA_CHECK_RECOMPILATION'] = '0'
-    model_name = "Qwen/Qwen2.5-1.5B-Instruct"
+    model_name = "meta-llama/Llama-3.2-1B-Instruct"

     # Use a smaller subset of prompts for correctness testing
     small_prompts = test_prompts[:10]
diff --git a/tpu_inference/layers/common/sharding.py b/tpu_inference/layers/common/sharding.py
index 1a1a8d169..127a74fe6 100644
--- a/tpu_inference/layers/common/sharding.py
+++ b/tpu_inference/layers/common/sharding.py
@@ -166,9 +166,10 @@ def validate(cls, vllm_config, sharding_strategy):
                     f"LoRA is not supported with data parallelism "
                     f"(DP size: {total_dp_size}). Please disable LoRA or "
                     f"set data parallelism to 1.")
+        if sharding_strategy.attention_data_parallelism > 1:
             if not os.environ.get("NEW_MODEL_DESIGN", False):
                 raise ValueError(
-                    "Must run DP with NEW_MODEL_DESIGN enabled. Please set the "
+                    "Must run Attention DP with NEW_MODEL_DESIGN enabled. Please set the "
Please set the " "NEW_MODEL_DESIGN=True.") @property diff --git a/tpu_inference/layers/vllm/fused_moe.py b/tpu_inference/layers/vllm/fused_moe.py index fa9a45288..b19128475 100644 --- a/tpu_inference/layers/vllm/fused_moe.py +++ b/tpu_inference/layers/vllm/fused_moe.py @@ -110,7 +110,8 @@ def tensor_sharded_gmm_merged_column_parallel( # adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401 m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0] n = rhs.shape[1] if transpose_rhs else rhs.shape[2] - tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g) + tm, tk, tn = _get_tiling_size_for_gmm_kernel(m // mesh.shape["data"], k, n, + g) _gmm = functools.partial( gmm, @@ -123,14 +124,27 @@ def tensor_sharded_gmm_merged_column_parallel( gmm_result = shard_map( _gmm, mesh=mesh, - in_specs=(P(), P(None, "model", None), P()), - out_specs=(P(None, "model")), + in_specs=(P("data", None), P(None, "model", None), P("data")), + out_specs=(P("data", "model")), check_rep=False, )(lhs, rhs, group_sizes) if rhs_bias is not None: - rhs_bis = jnp.repeat(rhs_bias, group_sizes, 0, total_repeat_length=m) - gmm_result = (gmm_result + rhs_bis).astype(gmm_result.dtype) + + def _add_bias(gmm_result_local, rhs_bias_local, group_sizes_global): + rhs_bis = jnp.repeat(rhs_bias_local, + group_sizes_global, + 0, + total_repeat_length=m // mesh.shape["data"]) + return (gmm_result_local + rhs_bis).astype(gmm_result_local.dtype) + + gmm_result = shard_map( + _add_bias, + mesh=mesh, + in_specs=(P("data", "model"), P(None, "model"), P("data")), + out_specs=(P("data", "model")), + check_rep=False, + )(gmm_result, rhs_bias, group_sizes) n_shards = mesh.shape["model"] output_sizes = [intermediate_size, intermediate_size] @@ -150,7 +164,8 @@ def tensor_sharded_gmm_row_parallel( # adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401 m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0] n = rhs.shape[1] if transpose_rhs else rhs.shape[2] - tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g) + tm, tk, tn = _get_tiling_size_for_gmm_kernel(m // mesh.shape["data"], k, n, + g) _gmm = functools.partial( gmm, @@ -167,14 +182,26 @@ def _gmm_all_reduce(lhs, rhs, group_sizes): gmm_result = shard_map( _gmm_all_reduce, mesh=mesh, - in_specs=(P(None, "model"), P(None, None, "model"), P()), - out_specs=(P()), + in_specs=(P("data", "model"), P(None, None, "model"), P("data")), + out_specs=(P("data")), check_rep=False, )(lhs, rhs, group_sizes) - if rhs_bias is not None: - rhs_bias = jnp.repeat(rhs_bias, group_sizes, 0, total_repeat_length=m) - gmm_result = (gmm_result + rhs_bias).astype(gmm_result.dtype) + + def _add_bias(gmm_result_local, rhs_bias_local, group_sizes_global): + rhs_bis = jnp.repeat(rhs_bias_local, + group_sizes_global, + 0, + total_repeat_length=m // mesh.shape["data"]) + return (gmm_result_local + rhs_bis).astype(gmm_result_local.dtype) + + gmm_result = shard_map( + _add_bias, + mesh=mesh, + in_specs=(P("data"), P(), P("data")), + out_specs=(P("data")), + check_rep=False, + )(gmm_result, rhs_bias, group_sizes) return gmm_result @@ -366,15 +393,27 @@ def fused_moe_func( topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdims=True) topk_weights = topk_weights.astype(dtype) - topk_indices_flat = topk_indices.flatten() - topk_argsort_indices = jnp.argsort(topk_indices_flat) - topk_argsort_revert_indices = jnp.argsort(topk_argsort_indices) - token_indices = 
-    token_indices_sorted = token_indices[topk_argsort_indices]
-    group_sizes = jnp.bincount(topk_indices_flat, length=global_num_experts)
-
-    x = hidden_states[token_indices_sorted]
-
+    def _process_tokens_locally(hidden_states_local, topk_indices_local):
+        num_tokens_local = hidden_states_local.shape[0]
+        topk_indices_flat = topk_indices_local.flatten()
+        topk_argsort_indices = jnp.argsort(topk_indices_flat)
+        topk_argsort_revert_indices = jnp.argsort(topk_argsort_indices)
+        token_indices = jnp.arange(num_tokens_local,
+                                   dtype=jnp.int32).repeat(topk)
+        token_indices_sorted = token_indices[topk_argsort_indices]
+        group_sizes_local = jnp.bincount(topk_indices_flat,
+                                         length=global_num_experts)
+
+        x = hidden_states_local[token_indices_sorted]
+        return x, group_sizes_local, topk_argsort_revert_indices
+
+    x, group_sizes, topk_argsort_revert_indices = shard_map(
+        _process_tokens_locally,
+        mesh=mesh,
+        in_specs=(P("data", None), P("data", None)),
+        out_specs=(P("data", None), P("data"), P("data")),
+        check_rep=False,
+    )(hidden_states, topk_indices)
     if use_ep:
         x = expert_sharded_gmm(
             x,
@@ -411,7 +450,7 @@
         )
     else:
         x = jax.lax.with_sharding_constraint(
-            x, NamedSharding(mesh, P(None, "model")))
+            x, NamedSharding(mesh, P("data", "model")))
         x = tensor_sharded_gmm_row_parallel(
             x,
             w2,
@@ -421,13 +460,25 @@
             mesh=mesh,
         )

-    x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size)
-    x = x * jnp.expand_dims(topk_weights, axis=-1)
-    x = x.sum(axis=-2)
+    def _finalize_output(x_local, topk_argsort_revert_indices_local,
+                         topk_weights_local):
+        x_local = x_local[topk_argsort_revert_indices_local].reshape(
+            -1, topk, hidden_size)
+        x_local = x_local * jnp.expand_dims(topk_weights_local, axis=-1)
+        x_local = x_local.sum(axis=-2)
+        return x_local
+
+    x = shard_map(
+        _finalize_output,
+        mesh=mesh,
+        in_specs=(P("data", None), P("data"), P("data", None)),
+        out_specs=(P("data", None)),
+        check_rep=False,
+    )(x, topk_argsort_revert_indices, topk_weights)
     x = x.reshape(orig_shape)

     if reduce_results:
-        x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P()))
+        x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P("data")))

     return x

diff --git a/tpu_inference/layers/vllm/quantization/common.py b/tpu_inference/layers/vllm/quantization/common.py
index 381dce392..2b36a795e 100644
--- a/tpu_inference/layers/vllm/quantization/common.py
+++ b/tpu_inference/layers/vllm/quantization/common.py
@@ -61,7 +61,12 @@ def __init__(self, vllm_config: VllmConfig, mesh: Mesh, layer: LinearBase):
                 " bad performance.", type(layer))

         self.bias_sharding = P(self.weight_sharding[0])
-        self.n_shards = self.mesh.shape.get(self.weight_sharding[0], 1)
+        if isinstance(self.weight_sharding[0], tuple):
+            self.n_shards = 1
+            for axis in self.weight_sharding[0]:
+                self.n_shards *= self.mesh.shape.get(axis, 1)
+        else:
+            self.n_shards = self.mesh.shape.get(self.weight_sharding[0], 1)

     def get_input_sharding(self, x: torchax.tensor.Tensor):
         if self.enable_sequence_parallelism:
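
A minimal illustrative sketch (not part of the patch) of the shard_map pattern the fused_moe.py changes rely on: partitioning the token dimension over the "data" mesh axis means each device sees only m // mesh.shape["data"] rows inside the mapped function, which is why the GMM tiling helper and total_repeat_length now receive the per-shard row count. The mesh construction and toy per-shard function below are assumptions for illustration, not code from this repository.

    # Illustrative only; a 1x1 ("data", "model") mesh keeps it runnable on one device.
    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, PartitionSpec as P
    from jax.experimental.shard_map import shard_map

    mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1), ("data", "model"))
    tokens = jnp.ones((8, 4), dtype=jnp.float32)  # global (m, k)

    def per_shard_rows(x_local):
        # Inside shard_map each device sees m // mesh.shape["data"] rows of `tokens`.
        return x_local * 2.0

    doubled = shard_map(
        per_shard_rows,
        mesh=mesh,
        in_specs=(P("data", None),),
        out_specs=P("data", None),
        check_rep=False,
    )(tokens)
    assert doubled.shape == (8, 4)  # result is reassembled to the global shape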
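
Likewise, a small sketch of the n_shards logic added in quantization/common.py: when the first entry of weight_sharding names a tuple of mesh axes, the dimension is split by the product of those axis sizes. The mesh_shape and weight_sharding values below are assumed for illustration; in the patch they come from self.mesh.shape and self.weight_sharding.

    # Illustrative values only.
    mesh_shape = {"data": 2, "model": 4}

    def count_shards(weight_sharding):
        first = weight_sharding[0]
        if isinstance(first, tuple):
            # A dimension sharded over several axes is split by the product of their sizes.
            n_shards = 1
            for axis in first:
                n_shards *= mesh_shape.get(axis, 1)
            return n_shards
        return mesh_shape.get(first, 1)

    assert count_shards((("data", "model"), None)) == 8
    assert count_shards(("model", None)) == 4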