
Commit c861ffe

Refactor environment variable handling to use EnvironmentVariables class
Signed-off-by: Xing Liu <xingliu14@gmail.com>
1 parent af6afc9 commit c861ffe

14 files changed, with 72 additions and 46 deletions

examples/offline_inference.py

Lines changed: 5 additions & 6 deletions

@@ -1,9 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-import os
-
-import vllm.envs as envs
+import tpu_inference.envs as envs
+import vllm.envs as vllm_envs
 from vllm import LLM, EngineArgs
 from vllm.utils.argparse_utils import FlexibleArgumentParser

@@ -87,10 +86,10 @@ def main(args: dict):
         'Who wrote the novel "Pride and Prejudice"?',
     ]

-    if envs.VLLM_TORCH_PROFILER_DIR is not None:
+    if vllm_envs.VLLM_TORCH_PROFILER_DIR is not None:
         llm.start_profile()
     outputs = llm.generate(prompts, sampling_params)
-    if envs.VLLM_TORCH_PROFILER_DIR is not None:
+    if vllm_envs.VLLM_TORCH_PROFILER_DIR is not None:
         llm.stop_profile()

     # Print the outputs.
@@ -104,7 +103,7 @@ def main(args: dict):

 if __name__ == "__main__":
     # Skip long warmup for local simple test.
-    os.environ['SKIP_JAX_PRECOMPILE'] = '1'
+    envs.environment_variables['SKIP_JAX_PRECOMPILE'] = lambda: True

     parser = create_parser()
     args: dict = vars(parser.parse_args())
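The override style in this example is worth noting: writing through envs.environment_variables replaces the variable's getter, so later reads of envs.SKIP_JAX_PRECOMPILE return a typed bool rather than the string '1' that os.environ would yield. A minimal sketch of the behavior, assuming attribute reads on the envs module re-evaluate the getter (the assert is illustrative, not part of the diff):

    import tpu_inference.envs as envs

    # Replace the getter; subsequent attribute reads re-evaluate it.
    envs.environment_variables['SKIP_JAX_PRECOMPILE'] = lambda: True
    assert envs.SKIP_JAX_PRECOMPILE is True  # typed bool, not the string '1'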

examples/offline_safety_model_inference.py

Lines changed: 5 additions & 6 deletions

@@ -18,9 +18,8 @@
 --max-num_batched_tokens=4096
 """

-import os
-
-import vllm.envs as envs
+import tpu_inference.envs as envs
+import vllm.envs as vllm_envs
 from vllm import LLM, EngineArgs
 from vllm.utils.argparse_utils import FlexibleArgumentParser

@@ -170,7 +169,7 @@ def main(args: dict):

         prompts.append(TokensPrompt(prompt_token_ids=tokenized_prompt))

-    if envs.VLLM_TORCH_PROFILER_DIR is not None:
+    if vllm_envs.VLLM_TORCH_PROFILER_DIR is not None:
         llm.start_profile()

     outputs = llm.generate(
@@ -179,7 +178,7 @@ def main(args: dict):
         use_tqdm=True,
     )

-    if envs.VLLM_TORCH_PROFILER_DIR is not None:
+    if vllm_envs.VLLM_TORCH_PROFILER_DIR is not None:
         llm.stop_profile()

     passed_tests = 0
@@ -220,7 +219,7 @@ def main(args: dict):

 if __name__ == "__main__":
     # Skip long warmup for local simple test.
-    os.environ['SKIP_JAX_PRECOMPILE'] = '1'
+    envs.environment_variables['SKIP_JAX_PRECOMPILE'] = lambda: True

     parser = create_parser()
     args: dict = vars(parser.parse_args())

tests/e2e/test_data_parallel.py

Lines changed: 3 additions & 2 deletions

@@ -6,6 +6,7 @@
 from dataclasses import asdict

 import pytest
+import tpu_inference.envs as envs
 from vllm import LLM, EngineArgs, SamplingParams


@@ -173,8 +174,8 @@ def test_data_parallelism_correctness(
     This test compares outputs from a single-device run with data parallel runs
     to ensure correctness, including log probabilities.
     """
-    os.environ['SKIP_JAX_PRECOMPILE'] = '1'
-    os.environ['VLLM_XLA_CHECK_RECOMPILATION'] = '0'
+    envs.environment_variables['SKIP_JAX_PRECOMPILE'] = lambda: True
+    envs.environment_variables['VLLM_XLA_CHECK_RECOMPILATION'] = lambda: False
     model_name = "Qwen/Qwen2.5-1.5B-Instruct"
     # Use a smaller subset of prompts for correctness testing
     small_prompts = test_prompts[:10]

tests/e2e/test_multi_modal_inference.py

Lines changed: 3 additions & 4 deletions

@@ -4,9 +4,9 @@
 # This script is a self-contained test that runs a single prompt and
 # compares the output to a known-good output.

-import os
 from dataclasses import asdict

+import tpu_inference.envs as envs
 from vllm import LLM, EngineArgs, SamplingParams
 from vllm.assets.image import ImageAsset
 from vllm.multimodal.image import convert_image_mode
@@ -24,9 +24,8 @@ def test_multi_modal_inference(monkeypatch):
     """
     Runs multi-modal inference and verifies the output.
     """
-    os.environ['SKIP_JAX_PRECOMPILE'] = '1'  # Skip warmup to save time.
-    os.environ[
-        'VLLM_XLA_CHECK_RECOMPILATION'] = '0'  # Allow compilation during execution.
+    envs.environment_variables['SKIP_JAX_PRECOMPILE'] = lambda: True  # Skip warmup to save time.
+    envs.environment_variables['VLLM_XLA_CHECK_RECOMPILATION'] = lambda: False  # Allow compilation during execution.

     monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")

tests/runner/test_tpu_runner_mesh.py

Lines changed: 5 additions & 4 deletions

@@ -4,6 +4,7 @@

 import pytest

+import tpu_inference.envs as envs
 from tpu_inference.runner.tpu_runner import TPUModelRunner


@@ -53,7 +54,7 @@ def runner_instance(self, mock_vllm_config, mock_devices):
     def test_init_mesh_2d_model_without_device_order(self, runner_instance,
                                                      mock_vllm_config):
         """Test 2d mesh creation without enforced device order."""
-        with patch.dict(os.environ, {'NEW_MODEL_DESIGN': ''}), \
+        with patch.dict(envs.environment_variables, {'NEW_MODEL_DESIGN': lambda: False}), \
                 patch('tpu_inference.runner.tpu_runner.make_optimized_mesh') as mock_make_mesh, \
                 patch('tpu_inference.runner.tpu_runner.logger'):

@@ -79,7 +80,7 @@ def test_init_mesh_2d_model_with_device_order(self, runner_instance,
         """Test 2d mesh creation with enforced device order."""
         mock_vllm_config.sharding_config.device_indexes = [0, 1, 2, 3]

-        with patch.dict(os.environ, {'NEW_MODEL_DESIGN': ''}), \
+        with patch.dict(envs.environment_variables, {'NEW_MODEL_DESIGN': lambda: False}), \
                 patch('jax.make_mesh') as mock_jax_mesh, \
                 patch('tpu_inference.runner.tpu_runner.logger'):

@@ -103,7 +104,7 @@ def test_init_mesh_2d_model_with_device_order(self, runner_instance,
     def test_init_mesh_new_model_single_slice(self, runner_instance,
                                               mock_vllm_config):
         """Test new model mesh creation with single slice."""
-        with patch.dict(os.environ, {'NEW_MODEL_DESIGN': '1', 'NUM_SLICES': '1'}), \
+        with patch.dict(envs.environment_variables, {'NEW_MODEL_DESIGN': lambda: True, 'NUM_SLICES': lambda: 1}), \
                 patch('tpu_inference.runner.tpu_runner.mesh_utils') as mock_mesh_utils, \
                 patch('jax.sharding.Mesh') as mock_jax_mesh, \
                 patch('tpu_inference.runner.tpu_runner.logger'):

@@ -134,7 +135,7 @@ def test_init_mesh_new_model_multi_slice(self, runner_instance,
                                              mock_vllm_config):
         """Test new model mesh creation with multiple slices."""
         num_slices = 2
-        with patch.dict(os.environ, {'NEW_MODEL_DESIGN': '1', 'NUM_SLICES': str(num_slices)}), \
+        with patch.dict(envs.environment_variables, {'NEW_MODEL_DESIGN': lambda: True, 'NUM_SLICES': lambda: num_slices}), \
                 patch('tpu_inference.runner.tpu_runner.mesh_utils') as mock_mesh_utils, \
                 patch('jax.sharding.Mesh') as mock_jax_mesh, \
                 patch('tpu_inference.runner.tpu_runner.logger'):
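Patching envs.environment_variables rather than os.environ keeps the overrides typed (lambda: True, lambda: num_slices) and lets patch.dict restore the original getters on exit. A minimal sketch of the same pattern in isolation, assuming the envs module resolves attribute access through this dict:

    from unittest.mock import patch

    import tpu_inference.envs as envs

    with patch.dict(envs.environment_variables,
                    {'NEW_MODEL_DESIGN': lambda: True, 'NUM_SLICES': lambda: 2}):
        assert envs.NEW_MODEL_DESIGN is True  # typed values inside the context
        assert envs.NUM_SLICES == 2
    # patch.dict restores the original lambdas here.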

tests/test_envs.py

Lines changed: 15 additions & 0 deletions

@@ -63,6 +63,13 @@ def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "0")
     assert envs.SKIP_JAX_PRECOMPILE is False

+    # Test VLLM_XLA_CHECK_RECOMPILATION (default False)
+    assert envs.VLLM_XLA_CHECK_RECOMPILATION is False
+    monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "1")
+    assert envs.VLLM_XLA_CHECK_RECOMPILATION is True
+    monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "0")
+    assert envs.VLLM_XLA_CHECK_RECOMPILATION is False
+
     # Test NEW_MODEL_DESIGN (default False)
     assert envs.NEW_MODEL_DESIGN is False
     monkeypatch.setenv("NEW_MODEL_DESIGN", "1")

@@ -81,6 +88,13 @@ def test_integer_env_vars(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.setenv("PYTHON_TRACER_LEVEL", "0")
     assert envs.PYTHON_TRACER_LEVEL == 0

+    # Test NUM_SLICES (default 1)
+    assert envs.NUM_SLICES == 1
+    monkeypatch.setenv("NUM_SLICES", "2")
+    assert envs.NUM_SLICES == 2
+    monkeypatch.setenv("NUM_SLICES", "4")
+    assert envs.NUM_SLICES == 4
+

 def test_lowercase_conversion(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "GRPC")

@@ -134,6 +148,7 @@ def test_dir_returns_all_env_vars():
     assert "JAX_PLATFORMS" in env_vars
     assert "TPU_NAME" in env_vars
     assert "SKIP_JAX_PRECOMPILE" in env_vars
+    assert "VLLM_XLA_CHECK_RECOMPILATION" in env_vars
     assert "MODEL_IMPL_TYPE" in env_vars

tests/worker/tpu_worker_test.py

Lines changed: 3 additions & 2 deletions

@@ -6,6 +6,7 @@
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.outputs import DraftTokenIds

+import tpu_inference.envs as envs
 # The class we are testing
 from tpu_inference.worker.tpu_worker import TPUWorker

@@ -280,7 +281,7 @@ def test_add_lora_not_implemented_lora_request(self, mock_vllm_config):
     #

     @patch('tpu_inference.worker.tpu_worker.jax')
-    @patch.dict('os.environ', {"PYTHON_TRACER_LEVEL": "1"}, clear=True)
+    @patch.dict(envs.environment_variables, {"PYTHON_TRACER_LEVEL": lambda: 1})
     def test_profile_start(self, mock_jax, mock_vllm_config):
         """Tests starting the JAX profiler."""
         worker = TPUWorker(vllm_config=mock_vllm_config,

@@ -296,7 +297,7 @@ def test_profile_start(self, mock_jax, mock_vllm_config):
         args, kwargs = mock_jax.profiler.start_trace.call_args
         assert args[0] == "/tmp/profile_dir"
         # Verify options from env var were used
-        assert kwargs['profiler_options'].python_tracer_level == '1'
+        assert kwargs['profiler_options'].python_tracer_level == 1

     @patch('tpu_inference.worker.tpu_worker.jax')
     def test_profile_stop(self, mock_jax, mock_vllm_config):

tpu_inference/envs.py

Lines changed: 8 additions & 0 deletions

@@ -15,11 +15,13 @@
     PREFILL_SLICES: str = ""
     DECODE_SLICES: str = ""
     SKIP_JAX_PRECOMPILE: bool = False
+    VLLM_XLA_CHECK_RECOMPILATION: bool = False
     MODEL_IMPL_TYPE: str = "flax_nnx"
     NEW_MODEL_DESIGN: bool = False
     PHASED_PROFILING_DIR: str = ""
     PYTHON_TRACER_LEVEL: int = 1
     USE_MOE_EP_KERNEL: bool = False
+    NUM_SLICES: int = 1
     RAY_USAGE_STATS_ENABLED: str = "0"
     VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "shm"

@@ -48,6 +50,9 @@
     # Skip JAX precompilation step during initialization
     "SKIP_JAX_PRECOMPILE":
     lambda: bool(int(os.getenv("SKIP_JAX_PRECOMPILE", "0"))),
+    # Check for XLA recompilation during execution
+    "VLLM_XLA_CHECK_RECOMPILATION":
+    lambda: bool(int(os.getenv("VLLM_XLA_CHECK_RECOMPILATION", "0"))),
     # Model implementation type (e.g., "flax_nnx")
     "MODEL_IMPL_TYPE":
     lambda: os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower(),

@@ -63,6 +68,9 @@
     # Use custom expert-parallel kernel for MoE (Mixture of Experts)
     "USE_MOE_EP_KERNEL":
     lambda: bool(int(os.getenv("USE_MOE_EP_KERNEL", "0"))),
+    # Number of TPU slices for multi-slice mesh
+    "NUM_SLICES":
+    lambda: int(os.getenv("NUM_SLICES", "1")),
     # Enable/disable Ray usage statistics collection
     "RAY_USAGE_STATS_ENABLED":
     lambda: os.getenv("RAY_USAGE_STATS_ENABLED", "0"),
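The diff only adds the two new entries; what makes both monkeypatch.setenv and patch.dict(envs.environment_variables, ...) work is that attribute access on the module re-evaluates the matching getter each time. A minimal sketch of that mechanism, assuming tpu_inference/envs.py follows the same module-level __getattr__ pattern as vLLM's envs module (the surrounding code is not shown in this diff):

    import os
    from typing import Any, Callable

    environment_variables: dict[str, Callable[[], Any]] = {
        "SKIP_JAX_PRECOMPILE":
        lambda: bool(int(os.getenv("SKIP_JAX_PRECOMPILE", "0"))),
        "NUM_SLICES":
        lambda: int(os.getenv("NUM_SLICES", "1")),
    }

    def __getattr__(name: str) -> Any:
        # Re-evaluate the getter on every access so changes to os.environ,
        # or to the dict itself, are picked up immediately.
        if name in environment_variables:
            return environment_variables[name]()
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    def __dir__() -> list[str]:
        # Lets tests assert that new variables appear in dir(envs).
        return list(environment_variables.keys())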

tpu_inference/executors/ray_distributed_executor.py

Lines changed: 7 additions & 5 deletions

@@ -18,6 +18,7 @@
 from vllm.v1.executor.ray_executor import RayWorkerMetaData
 from vllm.v1.executor.ray_utils import RayWorkerWrapper, _wait_until_pg_ready

+import tpu_inference.envs as tpu_envs
 from tpu_inference.logger import init_logger

 try:

@@ -72,7 +73,8 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
     def _init_executor(self) -> None:
         self.forward_dag: Optional[ray.dag.CompiledDAG] = None

-        os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "shm"
+        # Ensure Ray compiled DAG channel type is set for vLLM
+        os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = tpu_envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE

         # Currently, this requires USE_RAY_SPMD_WORKER=True.
         self.use_ray_compiled_dag = True

@@ -86,10 +88,10 @@ def _init_executor(self) -> None:
         self._initialize_ray_cluster()
         placement_group = self.parallel_config.placement_group

-        # Disable Ray usage stats collection.
-        ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
-        if ray_usage != "1":
-            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
+        # Ensure Ray usage stats collection setting is propagated to Ray workers.
+        # Ray workers inherit environment variables, so we explicitly set this
+        # based on our configuration (defaults to "0" to disable stats).
+        os.environ["RAY_USAGE_STATS_ENABLED"] = tpu_envs.RAY_USAGE_STATS_ENABLED

         # Create the parallel GPU workers.
         self._init_workers_ray(placement_group)

tpu_inference/layers/common/sharding.py

Lines changed: 3 additions & 3 deletions

@@ -8,7 +8,7 @@
 import numpy as np
 from jax.sharding import Mesh

-from tpu_inference import utils
+from tpu_inference import envs, utils

 if TYPE_CHECKING:
     from vllm.v1.configs.vllm_config import VllmConfig

@@ -48,7 +48,7 @@ class ShardingAxisName2D:


 try:
-    _use_base_sharding = os.getenv("NEW_MODEL_DESIGN", False)
+    _use_base_sharding = envs.NEW_MODEL_DESIGN
     if _use_base_sharding:
         ShardingAxisName = ShardingAxisNameBase
     else:

@@ -166,7 +166,7 @@ def validate(cls, vllm_config, sharding_strategy):
                 f"LoRA is not supported with data parallelism "
                 f"(DP size: {total_dp_size}). Please disable LoRA or "
                 f"set data parallelism to 1.")
-            if not os.environ.get("NEW_MODEL_DESIGN", False):
+            if not envs.NEW_MODEL_DESIGN:
                 raise ValueError(
                     "Must run DP with NEW_MODEL_DESIGN enabled. Please set the "
                     "NEW_MODEL_DESIGN=True.")
