@@ -14,7 +14,6 @@ def setup_new_model_design():
1414 """Automatically set NEW_MODEL_DESIGN=True for all tests."""
1515 os .environ ['NEW_MODEL_DESIGN' ] = 'True'
1616
17-
1817@pytest .fixture
1918def test_prompts ():
2019 """Simple test prompts for data parallelism testing."""
@@ -81,9 +80,11 @@ def _run_inference_with_config(model_name: str,
8180 time .sleep (5 )
8281
8382
83+ @pytest .mark .parametrize ("model_impl_type" , ["vllm" , "flax_nnx" ])
8484def test_model_data_parallelism (
8585 test_prompts : list ,
8686 sampling_params : SamplingParams ,
87+ model_impl_type : str ,
8788):
8889 """
8990 Test model-wise data parallelism where data=2 in the mesh axis.
@@ -95,6 +96,7 @@ def test_model_data_parallelism(
9596 """
9697 # Use Llama 1B for this test
9798 test_model = "meta-llama/Llama-3.2-1B-Instruct"
99+ os .environ ['MODEL_IMPL_TYPE' ] = model_impl_type
98100
99101 # Test with data parallelism enabled
100102 outputs = _run_inference_with_config (
@@ -103,6 +105,7 @@ def test_model_data_parallelism(
103105 sampling_params = sampling_params ,
104106 tensor_parallel_size = 1 ,
105107 data_parallel_size = 2 ,
108+ async_scheduling = True ,
106109 )
107110
108111 # Verify we got outputs for all prompts
@@ -175,7 +178,7 @@ def test_data_parallelism_correctness(
175178 """
176179 os .environ ['SKIP_JAX_PRECOMPILE' ] = '1'
177180 os .environ ['VLLM_XLA_CHECK_RECOMPILATION' ] = '0'
178- model_name = "Qwen/Qwen2.5-1.5B-Instruct"
181+ model_name = "meta-llama/Llama-3.2-1B-Instruct"
179182 # Use a smaller subset of prompts for correctness testing
180183 small_prompts = test_prompts [:10 ]
181184
0 commit comments