Skip to content

Commit 3737895

Browse files
wip
1 parent fa7b2d2 commit 3737895

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

tests/e2e/test_data_parallel.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,6 @@ def setup_new_model_design():
1414
"""Automatically set NEW_MODEL_DESIGN=True for all tests."""
1515
os.environ['NEW_MODEL_DESIGN'] = 'True'
1616

17-
1817
@pytest.fixture
1918
def test_prompts():
2019
"""Simple test prompts for data parallelism testing."""
@@ -81,9 +80,11 @@ def _run_inference_with_config(model_name: str,
8180
time.sleep(5)
8281

8382

83+
@pytest.mark.parametrize("model_impl_type", ["vllm", "flax_nnx"])
8484
def test_model_data_parallelism(
8585
test_prompts: list,
8686
sampling_params: SamplingParams,
87+
model_impl_type: str,
8788
):
8889
"""
8990
Test model-wise data parallelism where data=2 in the mesh axis.
@@ -95,6 +96,7 @@ def test_model_data_parallelism(
9596
"""
9697
# Use Llama 1B for this test
9798
test_model = "meta-llama/Llama-3.2-1B-Instruct"
99+
os.environ['MODEL_IMPL_TYPE'] = model_impl_type
98100

99101
# Test with data parallelism enabled
100102
outputs = _run_inference_with_config(
@@ -103,6 +105,7 @@ def test_model_data_parallelism(
103105
sampling_params=sampling_params,
104106
tensor_parallel_size=1,
105107
data_parallel_size=2,
108+
async_scheduling=True,
106109
)
107110

108111
# Verify we got outputs for all prompts
@@ -175,7 +178,7 @@ def test_data_parallelism_correctness(
175178
"""
176179
os.environ['SKIP_JAX_PRECOMPILE'] = '1'
177180
os.environ['VLLM_XLA_CHECK_RECOMPILATION'] = '0'
178-
model_name = "Qwen/Qwen2.5-1.5B-Instruct"
181+
model_name = "meta-llama/Llama-3.2-1B-Instruct"
179182
# Use a smaller subset of prompts for correctness testing
180183
small_prompts = test_prompts[:10]
181184

0 commit comments

Comments (0)