44
55import pytest
66
7+ import tpu_inference .envs as envs
78from tpu_inference .runner .tpu_runner import TPUModelRunner
89
910
@@ -53,7 +54,7 @@ def runner_instance(self, mock_vllm_config, mock_devices):
5354 def test_init_mesh_2d_model_without_device_order (self , runner_instance ,
5455 mock_vllm_config ):
5556 """Test 2d mesh creation without enforced device order."""
56- with patch .dict (os . environ , {'NEW_MODEL_DESIGN' : '' }), \
57+ with patch .dict (envs . environment_variables , {'NEW_MODEL_DESIGN' : lambda : False }), \
5758 patch ('tpu_inference.runner.tpu_runner.make_optimized_mesh' ) as mock_make_mesh , \
5859 patch ('tpu_inference.runner.tpu_runner.logger' ):
5960
@@ -79,7 +80,7 @@ def test_init_mesh_2d_model_with_device_order(self, runner_instance,
7980 """Test 2d mesh creation with enforced device order."""
8081 mock_vllm_config .sharding_config .device_indexes = [0 , 1 , 2 , 3 ]
8182
82- with patch .dict (os . environ , {'NEW_MODEL_DESIGN' : '' }), \
83+ with patch .dict (envs . environment_variables , {'NEW_MODEL_DESIGN' : lambda : False }), \
8384 patch ('jax.make_mesh' ) as mock_jax_mesh , \
8485 patch ('tpu_inference.runner.tpu_runner.logger' ):
8586
@@ -103,7 +104,7 @@ def test_init_mesh_2d_model_with_device_order(self, runner_instance,
103104 def test_init_mesh_new_model_single_slice (self , runner_instance ,
104105 mock_vllm_config ):
105106 """Test new model mesh creation with single slice."""
106- with patch .dict (os . environ , {'NEW_MODEL_DESIGN' : '1' , 'NUM_SLICES' : '1' }), \
107+ with patch .dict (envs . environment_variables , {'NEW_MODEL_DESIGN' : lambda : True , 'NUM_SLICES' : lambda : 1 }), \
107108 patch ('tpu_inference.runner.tpu_runner.mesh_utils' ) as mock_mesh_utils , \
108109 patch ('jax.sharding.Mesh' ) as mock_jax_mesh , \
109110 patch ('tpu_inference.runner.tpu_runner.logger' ):
@@ -134,7 +135,7 @@ def test_init_mesh_new_model_multi_slice(self, runner_instance,
134135 mock_vllm_config ):
135136 """Test new model mesh creation with multiple slices."""
136137 num_slices = 2
137- with patch .dict (os . environ , {'NEW_MODEL_DESIGN' : '1' , 'NUM_SLICES' : str ( num_slices ) }), \
138+ with patch .dict (envs . environment_variables , {'NEW_MODEL_DESIGN' : lambda : True , 'NUM_SLICES' : lambda : num_slices }), \
138139 patch ('tpu_inference.runner.tpu_runner.mesh_utils' ) as mock_mesh_utils , \
139140 patch ('jax.sharding.Mesh' ) as mock_jax_mesh , \
140141 patch ('tpu_inference.runner.tpu_runner.logger' ):
0 commit comments