Skip to content

Commit 3e8e3db

Browse files
committed
Use os.environ directly for setting environment variables
Signed-off-by: Xing Liu <xingliu14@gmail.com>
1 parent 8d13d94 commit 3e8e3db

File tree

4 files changed

+13
-12
lines changed

4 files changed

+13
-12
lines changed

examples/offline_inference.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,12 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

4+
import os
5+
46
import vllm.envs as vllm_envs
57
from vllm import LLM, EngineArgs
68
from vllm.utils.argparse_utils import FlexibleArgumentParser
79

8-
import tpu_inference.envs as envs
910
from tpu_inference.core import disagg_utils
1011

1112

@@ -103,7 +104,7 @@ def main(args: dict):
103104

104105
if __name__ == "__main__":
105106
# Skip long warmup for local simple test.
106-
envs.environment_variables['SKIP_JAX_PRECOMPILE'] = lambda: True
107+
os.environ['SKIP_JAX_PRECOMPILE'] = '1'
107108

108109
parser = create_parser()
109110
args: dict = vars(parser.parse_args())

examples/offline_lora_inference.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,15 +1,14 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

4+
import os
45
import time
56

67
import vllm.envs as vllm_envs
78
from vllm import LLM, EngineArgs
89
from vllm.lora.request import LoRARequest
910
from vllm.utils.argparse_utils import FlexibleArgumentParser
1011

11-
import tpu_inference.envs as envs
12-
1312

1413
def create_parser():
1514
parser = FlexibleArgumentParser()
@@ -78,7 +77,7 @@ def main(args: dict):
7877

7978
if __name__ == "__main__":
8079
# Skip long warmup for local simple test.
81-
envs.environment_variables['SKIP_JAX_PRECOMPILE'] = lambda: True
80+
os.environ['SKIP_JAX_PRECOMPILE'] = '1'
8281

8382
parser = create_parser()
8483
args: dict = vars(parser.parse_args())

examples/offline_safety_model_inference.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -18,11 +18,12 @@
1818
--max-num_batched_tokens=4096
1919
"""
2020

21+
import os
22+
2123
import vllm.envs as vllm_envs
2224
from vllm import LLM, EngineArgs
2325
from vllm.utils.argparse_utils import FlexibleArgumentParser
2426

25-
import tpu_inference.envs as envs
2627
from tpu_inference.core import disagg_utils
2728

2829

@@ -219,7 +220,7 @@ def main(args: dict):
219220

220221
if __name__ == "__main__":
221222
# Skip long warmup for local simple test.
222-
envs.environment_variables['SKIP_JAX_PRECOMPILE'] = lambda: True
223+
os.environ['SKIP_JAX_PRECOMPILE'] = '1'
223224

224225
parser = create_parser()
225226
args: dict = vars(parser.parse_args())

tests/runner/test_tpu_runner_mesh.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,9 @@
11
"""Unit tests for TPUModelRunner mesh initialization."""
2+
import os
23
from unittest.mock import Mock, patch
34

45
import pytest
56

6-
import tpu_inference.envs as envs
77
from tpu_inference.runner.tpu_runner import TPUModelRunner
88

99

@@ -53,7 +53,7 @@ def runner_instance(self, mock_vllm_config, mock_devices):
5353
def test_init_mesh_2d_model_without_device_order(self, runner_instance,
5454
mock_vllm_config):
5555
"""Test 2d mesh creation without enforced device order."""
56-
with patch.dict(envs.environment_variables, {'NEW_MODEL_DESIGN': lambda: False}), \
56+
with patch.dict(os.environ, {'NEW_MODEL_DESIGN': ''}), \
5757
patch('tpu_inference.runner.tpu_runner.make_optimized_mesh') as mock_make_mesh, \
5858
patch('tpu_inference.runner.tpu_runner.logger'):
5959

@@ -79,7 +79,7 @@ def test_init_mesh_2d_model_with_device_order(self, runner_instance,
7979
"""Test 2d mesh creation with enforced device order."""
8080
mock_vllm_config.sharding_config.device_indexes = [0, 1, 2, 3]
8181

82-
with patch.dict(envs.environment_variables, {'NEW_MODEL_DESIGN': lambda: False}), \
82+
with patch.dict(os.environ, {'NEW_MODEL_DESIGN': ''}), \
8383
patch('jax.make_mesh') as mock_jax_mesh, \
8484
patch('tpu_inference.runner.tpu_runner.logger'):
8585

@@ -103,7 +103,7 @@ def test_init_mesh_2d_model_with_device_order(self, runner_instance,
103103
def test_init_mesh_new_model_single_slice(self, runner_instance,
104104
mock_vllm_config):
105105
"""Test new model mesh creation with single slice."""
106-
with patch.dict(envs.environment_variables, {'NEW_MODEL_DESIGN': lambda: True, 'NUM_SLICES': lambda: 1}), \
106+
with patch.dict(os.environ, {'NEW_MODEL_DESIGN': '1', 'NUM_SLICES': '1'}), \
107107
patch('tpu_inference.runner.tpu_runner.mesh_utils') as mock_mesh_utils, \
108108
patch('jax.sharding.Mesh') as mock_jax_mesh, \
109109
patch('tpu_inference.runner.tpu_runner.logger'):
@@ -134,7 +134,7 @@ def test_init_mesh_new_model_multi_slice(self, runner_instance,
134134
mock_vllm_config):
135135
"""Test new model mesh creation with multiple slices."""
136136
num_slices = 2
137-
with patch.dict(envs.environment_variables, {'NEW_MODEL_DESIGN': lambda: True, 'NUM_SLICES': lambda: num_slices}), \
137+
with patch.dict(os.environ, {'NEW_MODEL_DESIGN': '1', 'NUM_SLICES': str(num_slices)}), \
138138
patch('tpu_inference.runner.tpu_runner.mesh_utils') as mock_mesh_utils, \
139139
patch('jax.sharding.Mesh') as mock_jax_mesh, \
140140
patch('tpu_inference.runner.tpu_runner.logger'):

0 commit comments

Comments
 (0)