11"""Unit tests for TPUModelRunner mesh initialization."""
2+ import os
23from unittest .mock import Mock , patch
34
45import pytest
56
6- import tpu_inference .envs as envs
77from tpu_inference .runner .tpu_runner import TPUModelRunner
88
99
@@ -53,7 +53,7 @@ def runner_instance(self, mock_vllm_config, mock_devices):
5353 def test_init_mesh_2d_model_without_device_order (self , runner_instance ,
5454 mock_vllm_config ):
5555 """Test 2d mesh creation without enforced device order."""
56- with patch .dict (envs . environment_variables , {'NEW_MODEL_DESIGN' : lambda : False }), \
56+ with patch .dict (os . environ , {'NEW_MODEL_DESIGN' : '' }), \
5757 patch ('tpu_inference.runner.tpu_runner.make_optimized_mesh' ) as mock_make_mesh , \
5858 patch ('tpu_inference.runner.tpu_runner.logger' ):
5959
@@ -79,7 +79,7 @@ def test_init_mesh_2d_model_with_device_order(self, runner_instance,
7979 """Test 2d mesh creation with enforced device order."""
8080 mock_vllm_config .sharding_config .device_indexes = [0 , 1 , 2 , 3 ]
8181
82- with patch .dict (envs . environment_variables , {'NEW_MODEL_DESIGN' : lambda : False }), \
82+ with patch .dict (os . environ , {'NEW_MODEL_DESIGN' : '' }), \
8383 patch ('jax.make_mesh' ) as mock_jax_mesh , \
8484 patch ('tpu_inference.runner.tpu_runner.logger' ):
8585
@@ -103,7 +103,7 @@ def test_init_mesh_2d_model_with_device_order(self, runner_instance,
103103 def test_init_mesh_new_model_single_slice (self , runner_instance ,
104104 mock_vllm_config ):
105105 """Test new model mesh creation with single slice."""
106- with patch .dict (envs . environment_variables , {'NEW_MODEL_DESIGN' : lambda : True , 'NUM_SLICES' : lambda : 1 }), \
106+ with patch .dict (os . environ , {'NEW_MODEL_DESIGN' : '1' , 'NUM_SLICES' : '1' }), \
107107 patch ('tpu_inference.runner.tpu_runner.mesh_utils' ) as mock_mesh_utils , \
108108 patch ('jax.sharding.Mesh' ) as mock_jax_mesh , \
109109 patch ('tpu_inference.runner.tpu_runner.logger' ):
@@ -134,7 +134,7 @@ def test_init_mesh_new_model_multi_slice(self, runner_instance,
134134 mock_vllm_config ):
135135 """Test new model mesh creation with multiple slices."""
136136 num_slices = 2
137- with patch .dict (envs . environment_variables , {'NEW_MODEL_DESIGN' : lambda : True , 'NUM_SLICES' : lambda : num_slices }), \
137+ with patch .dict (os . environ , {'NEW_MODEL_DESIGN' : '1' , 'NUM_SLICES' : str ( num_slices ) }), \
138138 patch ('tpu_inference.runner.tpu_runner.mesh_utils' ) as mock_mesh_utils , \
139139 patch ('jax.sharding.Mesh' ) as mock_jax_mesh , \
140140 patch ('tpu_inference.runner.tpu_runner.logger' ):
0 commit comments