From d6bf9edc91bfe05c3e0bdb1723afcfb98c8d696a Mon Sep 17 00:00:00 2001 From: Xing Liu Date: Fri, 21 Nov 2025 01:22:30 +0000 Subject: [PATCH 1/5] Update offline_lora_inference.py to use tpu_inference.envs Signed-off-by: Xing Liu --- examples/offline_lora_inference.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/offline_lora_inference.py b/examples/offline_lora_inference.py index 386c74e5e..dd8324f66 100644 --- a/examples/offline_lora_inference.py +++ b/examples/offline_lora_inference.py @@ -1,10 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import os import time -import vllm.envs as envs +import tpu_inference.envs as envs +import vllm.envs as vllm_envs from vllm import LLM, EngineArgs from vllm.lora.request import LoRARequest from vllm.utils.argparse_utils import FlexibleArgumentParser @@ -55,13 +55,13 @@ def main(args: dict): "lora_adapter_3", 3, "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_3_adapter") - if envs.VLLM_TORCH_PROFILER_DIR is not None: + if vllm_envs.VLLM_TORCH_PROFILER_DIR is not None: llm.start_profile() start = time.perf_counter() outputs = llm.generate(prompt, sampling_params=sampling_params, lora_request=lora_request) - if envs.VLLM_TORCH_PROFILER_DIR is not None: + if vllm_envs.VLLM_TORCH_PROFILER_DIR is not None: llm.stop_profile() # Print the outputs. @@ -77,7 +77,7 @@ def main(args: dict): if __name__ == "__main__": # Skip long warmup for local simple test. - os.environ['SKIP_JAX_PRECOMPILE'] = '1' + envs.environment_variables['SKIP_JAX_PRECOMPILE'] = lambda: True parser = create_parser() args: dict = vars(parser.parse_args()) From e818a02c6548cc9f1e866d17c5fb51d3ba189d7c Mon Sep 17 00:00:00 2001 From: Xing Liu Date: Fri, 21 Nov 2025 05:28:28 +0000 Subject: [PATCH 2/5] Refactor environment variable handling to use EnvironmentVariables class Signed-off-by: Xing Liu --- examples/offline_inference.py | 11 +++++------ examples/offline_safety_model_inference.py | 11 +++++------ tests/e2e/test_data_parallel.py | 5 +++-- tests/e2e/test_multi_modal_inference.py | 7 +++---- tests/runner/test_tpu_runner_mesh.py | 9 +++++---- tests/test_envs.py | 15 +++++++++++++++ tests/worker/tpu_worker_test.py | 5 +++-- tpu_inference/envs.py | 8 ++++++++ .../executors/ray_distributed_executor.py | 12 +++++++----- tpu_inference/layers/common/sharding.py | 6 +++--- tpu_inference/runner/compilation_manager.py | 11 +++++------ tpu_inference/runner/tpu_runner.py | 12 +++++++----- tpu_inference/runner/utils.py | 4 ++-- tpu_inference/worker/tpu_worker.py | 2 +- 14 files changed, 72 insertions(+), 46 deletions(-) diff --git a/examples/offline_inference.py b/examples/offline_inference.py index 293560767..b80f985b8 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -1,9 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import os - -import vllm.envs as envs +import tpu_inference.envs as envs +import vllm.envs as vllm_envs from vllm import LLM, EngineArgs from vllm.utils.argparse_utils import FlexibleArgumentParser @@ -87,10 +86,10 @@ def main(args: dict): 'Who wrote the novel "Pride and Prejudice"?', ] - if envs.VLLM_TORCH_PROFILER_DIR is not None: + if vllm_envs.VLLM_TORCH_PROFILER_DIR is not None: llm.start_profile() outputs = llm.generate(prompts, sampling_params) - if envs.VLLM_TORCH_PROFILER_DIR is not None: + if vllm_envs.VLLM_TORCH_PROFILER_DIR is not None: llm.stop_profile() # Print 
the outputs. @@ -104,7 +103,7 @@ def main(args: dict): if __name__ == "__main__": # Skip long warmup for local simple test. - os.environ['SKIP_JAX_PRECOMPILE'] = '1' + envs.environment_variables['SKIP_JAX_PRECOMPILE'] = lambda: True parser = create_parser() args: dict = vars(parser.parse_args()) diff --git a/examples/offline_safety_model_inference.py b/examples/offline_safety_model_inference.py index ebf736148..5147a6645 100644 --- a/examples/offline_safety_model_inference.py +++ b/examples/offline_safety_model_inference.py @@ -18,9 +18,8 @@ --max-num_batched_tokens=4096 """ -import os - -import vllm.envs as envs +import tpu_inference.envs as envs +import vllm.envs as vllm_envs from vllm import LLM, EngineArgs from vllm.utils.argparse_utils import FlexibleArgumentParser @@ -170,7 +169,7 @@ def main(args: dict): prompts.append(TokensPrompt(prompt_token_ids=tokenized_prompt)) - if envs.VLLM_TORCH_PROFILER_DIR is not None: + if vllm_envs.VLLM_TORCH_PROFILER_DIR is not None: llm.start_profile() outputs = llm.generate( @@ -179,7 +178,7 @@ def main(args: dict): use_tqdm=True, ) - if envs.VLLM_TORCH_PROFILER_DIR is not None: + if vllm_envs.VLLM_TORCH_PROFILER_DIR is not None: llm.stop_profile() passed_tests = 0 @@ -220,7 +219,7 @@ def main(args: dict): if __name__ == "__main__": # Skip long warmup for local simple test. - os.environ['SKIP_JAX_PRECOMPILE'] = '1' + envs.environment_variables['SKIP_JAX_PRECOMPILE'] = lambda: True parser = create_parser() args: dict = vars(parser.parse_args()) diff --git a/tests/e2e/test_data_parallel.py b/tests/e2e/test_data_parallel.py index 9d794df29..455773524 100644 --- a/tests/e2e/test_data_parallel.py +++ b/tests/e2e/test_data_parallel.py @@ -6,6 +6,7 @@ from dataclasses import asdict import pytest +import tpu_inference.envs as envs from vllm import LLM, EngineArgs, SamplingParams @@ -173,8 +174,8 @@ def test_data_parallelism_correctness( This test compares outputs from a single-device run with data parallel runs to ensure correctness, including log probabilities. """ - os.environ['SKIP_JAX_PRECOMPILE'] = '1' - os.environ['VLLM_XLA_CHECK_RECOMPILATION'] = '0' + envs.environment_variables['SKIP_JAX_PRECOMPILE'] = lambda: True + envs.environment_variables['VLLM_XLA_CHECK_RECOMPILATION'] = lambda: False model_name = "Qwen/Qwen2.5-1.5B-Instruct" # Use a smaller subset of prompts for correctness testing small_prompts = test_prompts[:10] diff --git a/tests/e2e/test_multi_modal_inference.py b/tests/e2e/test_multi_modal_inference.py index c1d2bda77..46c14d90f 100644 --- a/tests/e2e/test_multi_modal_inference.py +++ b/tests/e2e/test_multi_modal_inference.py @@ -4,9 +4,9 @@ # This script is a self-contained test that runs a single prompt and # compares the output to a known-good output. -import os from dataclasses import asdict +import tpu_inference.envs as envs from vllm import LLM, EngineArgs, SamplingParams from vllm.assets.image import ImageAsset from vllm.multimodal.image import convert_image_mode @@ -24,9 +24,8 @@ def test_multi_modal_inference(monkeypatch): """ Runs multi-modal inference and verifies the output. """ - os.environ['SKIP_JAX_PRECOMPILE'] = '1' # Skip warmup to save time. - os.environ[ - 'VLLM_XLA_CHECK_RECOMPILATION'] = '0' # Allow compilation during execution. + envs.environment_variables['SKIP_JAX_PRECOMPILE'] = lambda: True # Skip warmup to save time. + envs.environment_variables['VLLM_XLA_CHECK_RECOMPILATION'] = lambda: False # Allow compilation during execution. 
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") diff --git a/tests/runner/test_tpu_runner_mesh.py b/tests/runner/test_tpu_runner_mesh.py index cace9531d..c2477c0c2 100644 --- a/tests/runner/test_tpu_runner_mesh.py +++ b/tests/runner/test_tpu_runner_mesh.py @@ -4,6 +4,7 @@ import pytest +import tpu_inference.envs as envs from tpu_inference.runner.tpu_runner import TPUModelRunner @@ -53,7 +54,7 @@ def runner_instance(self, mock_vllm_config, mock_devices): def test_init_mesh_2d_model_without_device_order(self, runner_instance, mock_vllm_config): """Test 2d mesh creation without enforced device order.""" - with patch.dict(os.environ, {'NEW_MODEL_DESIGN': ''}), \ + with patch.dict(envs.environment_variables, {'NEW_MODEL_DESIGN': lambda: False}), \ patch('tpu_inference.runner.tpu_runner.make_optimized_mesh') as mock_make_mesh, \ patch('tpu_inference.runner.tpu_runner.logger'): @@ -79,7 +80,7 @@ def test_init_mesh_2d_model_with_device_order(self, runner_instance, """Test 2d mesh creation with enforced device order.""" mock_vllm_config.sharding_config.device_indexes = [0, 1, 2, 3] - with patch.dict(os.environ, {'NEW_MODEL_DESIGN': ''}), \ + with patch.dict(envs.environment_variables, {'NEW_MODEL_DESIGN': lambda: False}), \ patch('jax.make_mesh') as mock_jax_mesh, \ patch('tpu_inference.runner.tpu_runner.logger'): @@ -103,7 +104,7 @@ def test_init_mesh_2d_model_with_device_order(self, runner_instance, def test_init_mesh_new_model_single_slice(self, runner_instance, mock_vllm_config): """Test new model mesh creation with single slice.""" - with patch.dict(os.environ, {'NEW_MODEL_DESIGN': '1', 'NUM_SLICES': '1'}), \ + with patch.dict(envs.environment_variables, {'NEW_MODEL_DESIGN': lambda: True, 'NUM_SLICES': lambda: 1}), \ patch('tpu_inference.runner.tpu_runner.mesh_utils') as mock_mesh_utils, \ patch('jax.sharding.Mesh') as mock_jax_mesh, \ patch('tpu_inference.runner.tpu_runner.logger'): @@ -134,7 +135,7 @@ def test_init_mesh_new_model_multi_slice(self, runner_instance, mock_vllm_config): """Test new model mesh creation with multiple slices.""" num_slices = 2 - with patch.dict(os.environ, {'NEW_MODEL_DESIGN': '1', 'NUM_SLICES': str(num_slices)}), \ + with patch.dict(envs.environment_variables, {'NEW_MODEL_DESIGN': lambda: True, 'NUM_SLICES': lambda: num_slices}), \ patch('tpu_inference.runner.tpu_runner.mesh_utils') as mock_mesh_utils, \ patch('jax.sharding.Mesh') as mock_jax_mesh, \ patch('tpu_inference.runner.tpu_runner.logger'): diff --git a/tests/test_envs.py b/tests/test_envs.py index f707c1d6f..a58cdaf6c 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -63,6 +63,13 @@ def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "0") assert envs.SKIP_JAX_PRECOMPILE is False + # Test VLLM_XLA_CHECK_RECOMPILATION (default False) + assert envs.VLLM_XLA_CHECK_RECOMPILATION is False + monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "1") + assert envs.VLLM_XLA_CHECK_RECOMPILATION is True + monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "0") + assert envs.VLLM_XLA_CHECK_RECOMPILATION is False + # Test NEW_MODEL_DESIGN (default False) assert envs.NEW_MODEL_DESIGN is False monkeypatch.setenv("NEW_MODEL_DESIGN", "1") @@ -81,6 +88,13 @@ def test_integer_env_vars(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("PYTHON_TRACER_LEVEL", "0") assert envs.PYTHON_TRACER_LEVEL == 0 + # Test NUM_SLICES (default 1) + assert envs.NUM_SLICES == 1 + monkeypatch.setenv("NUM_SLICES", "2") + assert envs.NUM_SLICES == 2 + 
monkeypatch.setenv("NUM_SLICES", "4") + assert envs.NUM_SLICES == 4 + def test_lowercase_conversion(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "GRPC") @@ -134,6 +148,7 @@ def test_dir_returns_all_env_vars(): assert "JAX_PLATFORMS" in env_vars assert "TPU_NAME" in env_vars assert "SKIP_JAX_PRECOMPILE" in env_vars + assert "VLLM_XLA_CHECK_RECOMPILATION" in env_vars assert "MODEL_IMPL_TYPE" in env_vars diff --git a/tests/worker/tpu_worker_test.py b/tests/worker/tpu_worker_test.py index 4801c861a..11e2dec2b 100644 --- a/tests/worker/tpu_worker_test.py +++ b/tests/worker/tpu_worker_test.py @@ -6,6 +6,7 @@ from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.outputs import DraftTokenIds +import tpu_inference.envs as envs # The class we are testing from tpu_inference.worker.tpu_worker import TPUWorker @@ -280,7 +281,7 @@ def test_add_lora_not_implemented_lora_request(self, mock_vllm_config): # @patch('tpu_inference.worker.tpu_worker.jax') - @patch.dict('os.environ', {"PYTHON_TRACER_LEVEL": "1"}, clear=True) + @patch.dict(envs.environment_variables, {"PYTHON_TRACER_LEVEL": lambda: 1}) def test_profile_start(self, mock_jax, mock_vllm_config): """Tests starting the JAX profiler.""" worker = TPUWorker(vllm_config=mock_vllm_config, @@ -296,7 +297,7 @@ def test_profile_start(self, mock_jax, mock_vllm_config): args, kwargs = mock_jax.profiler.start_trace.call_args assert args[0] == "/tmp/profile_dir" # Verify options from env var were used - assert kwargs['profiler_options'].python_tracer_level == '1' + assert kwargs['profiler_options'].python_tracer_level == 1 @patch('tpu_inference.worker.tpu_worker.jax') def test_profile_stop(self, mock_jax, mock_vllm_config): diff --git a/tpu_inference/envs.py b/tpu_inference/envs.py index e97993204..82bf1f053 100644 --- a/tpu_inference/envs.py +++ b/tpu_inference/envs.py @@ -15,11 +15,13 @@ PREFILL_SLICES: str = "" DECODE_SLICES: str = "" SKIP_JAX_PRECOMPILE: bool = False + VLLM_XLA_CHECK_RECOMPILATION: bool = False MODEL_IMPL_TYPE: str = "flax_nnx" NEW_MODEL_DESIGN: bool = False PHASED_PROFILING_DIR: str = "" PYTHON_TRACER_LEVEL: int = 1 USE_MOE_EP_KERNEL: bool = False + NUM_SLICES: int = 1 RAY_USAGE_STATS_ENABLED: str = "0" VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "shm" @@ -48,6 +50,9 @@ # Skip JAX precompilation step during initialization "SKIP_JAX_PRECOMPILE": lambda: bool(int(os.getenv("SKIP_JAX_PRECOMPILE", "0"))), + # Check for XLA recompilation during execution + "VLLM_XLA_CHECK_RECOMPILATION": + lambda: bool(int(os.getenv("VLLM_XLA_CHECK_RECOMPILATION", "0"))), # Model implementation type (e.g., "flax_nnx") "MODEL_IMPL_TYPE": lambda: os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower(), @@ -63,6 +68,9 @@ # Use custom expert-parallel kernel for MoE (Mixture of Experts) "USE_MOE_EP_KERNEL": lambda: bool(int(os.getenv("USE_MOE_EP_KERNEL", "0"))), + # Number of TPU slices for multi-slice mesh + "NUM_SLICES": + lambda: int(os.getenv("NUM_SLICES", "1")), # Enable/disable Ray usage statistics collection "RAY_USAGE_STATS_ENABLED": lambda: os.getenv("RAY_USAGE_STATS_ENABLED", "0"), diff --git a/tpu_inference/executors/ray_distributed_executor.py b/tpu_inference/executors/ray_distributed_executor.py index 1c411a939..0ef201efb 100644 --- a/tpu_inference/executors/ray_distributed_executor.py +++ b/tpu_inference/executors/ray_distributed_executor.py @@ -18,6 +18,7 @@ from vllm.v1.executor.ray_executor import RayWorkerMetaData from vllm.v1.executor.ray_utils import RayWorkerWrapper, _wait_until_pg_ready +import 
tpu_inference.envs as tpu_envs from tpu_inference.logger import init_logger try: @@ -72,7 +73,8 @@ class RayDistributedExecutor(RayDistributedExecutorV1): def _init_executor(self) -> None: self.forward_dag: Optional[ray.dag.CompiledDAG] = None - os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "shm" + # Ensure Ray compiled DAG channel type is set for vLLM + os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = tpu_envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE # Currently, this requires USE_RAY_SPMD_WORKER=True. self.use_ray_compiled_dag = True @@ -86,10 +88,10 @@ def _init_executor(self) -> None: self._initialize_ray_cluster() placement_group = self.parallel_config.placement_group - # Disable Ray usage stats collection. - ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0") - if ray_usage != "1": - os.environ["RAY_USAGE_STATS_ENABLED"] = "0" + # Ensure Ray usage stats collection setting is propagated to Ray workers. + # Ray workers inherit environment variables, so we explicitly set this + # based on our configuration (defaults to "0" to disable stats). + os.environ["RAY_USAGE_STATS_ENABLED"] = tpu_envs.RAY_USAGE_STATS_ENABLED # Create the parallel GPU workers. self._init_workers_ray(placement_group) diff --git a/tpu_inference/layers/common/sharding.py b/tpu_inference/layers/common/sharding.py index 1a1a8d169..a25bca52e 100644 --- a/tpu_inference/layers/common/sharding.py +++ b/tpu_inference/layers/common/sharding.py @@ -8,7 +8,7 @@ import numpy as np from jax.sharding import Mesh -from tpu_inference import utils +from tpu_inference import envs, utils if TYPE_CHECKING: from vllm.v1.configs.vllm_config import VllmConfig @@ -48,7 +48,7 @@ class ShardingAxisName2D: try: - _use_base_sharding = os.getenv("NEW_MODEL_DESIGN", False) + _use_base_sharding = envs.NEW_MODEL_DESIGN if _use_base_sharding: ShardingAxisName = ShardingAxisNameBase else: @@ -166,7 +166,7 @@ def validate(cls, vllm_config, sharding_strategy): f"LoRA is not supported with data parallelism " f"(DP size: {total_dp_size}). Please disable LoRA or " f"set data parallelism to 1.") - if not os.environ.get("NEW_MODEL_DESIGN", False): + if not envs.NEW_MODEL_DESIGN: raise ValueError( "Must run DP with NEW_MODEL_DESIGN enabled. 
Please set the " "NEW_MODEL_DESIGN=True.") diff --git a/tpu_inference/runner/compilation_manager.py b/tpu_inference/runner/compilation_manager.py index 42b9b199d..92465d1c5 100644 --- a/tpu_inference/runner/compilation_manager.py +++ b/tpu_inference/runner/compilation_manager.py @@ -1,13 +1,13 @@ -import os import time from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple import jax import jax.numpy as jnp import numpy as np -import vllm.envs as envs +import vllm.envs as vllm_envs from jax.sharding import NamedSharding, PartitionSpec +import tpu_inference.envs as envs from tpu_inference.core.disagg_utils import is_disagg_enabled from tpu_inference.layers.common.attention_metadata import AttentionMetadata from tpu_inference.layers.common.sharding import ShardingAxisName @@ -30,10 +30,10 @@ class CompilationManager: def __init__(self, runner: "TPUModelRunner"): self.runner = runner - if not envs.VLLM_DISABLE_COMPILE_CACHE: + if not vllm_envs.VLLM_DISABLE_COMPILE_CACHE: logger.info("Enabling JAX compile cache.") jax.config.update("jax_compilation_cache_dir", - envs.VLLM_XLA_CACHE_PATH) + vllm_envs.VLLM_XLA_CACHE_PATH) def _create_dummy_tensor(self, shape: Tuple[int, ...], @@ -67,8 +67,7 @@ def _run_compilation(self, name: str, fn: Callable, *args, logger.info("Compilation finished in %.2f [secs].", end - start) def capture_model(self) -> None: - if os.getenv("SKIP_JAX_PRECOMPILE", - False) or self.runner.model_config.enforce_eager: + if envs.SKIP_JAX_PRECOMPILE or self.runner.model_config.enforce_eager: return logger.info("Precompile all the subgraphs with possible input shapes.") diff --git a/tpu_inference/runner/tpu_runner.py b/tpu_inference/runner/tpu_runner.py index e76b9056b..921b5825e 100644 --- a/tpu_inference/runner/tpu_runner.py +++ b/tpu_inference/runner/tpu_runner.py @@ -11,8 +11,10 @@ import jaxtyping import numpy as np import torch -import vllm.envs as envs +import vllm.envs as vllm_envs from flax import nnx + +import tpu_inference.envs as envs from jax.experimental import mesh_utils from jax.sharding import NamedSharding, PartitionSpec from torchax.ops.mappings import j2t_dtype @@ -292,7 +294,7 @@ def _init_random(self): self.rng_key = jax.random.key(self.model_config.seed) def _init_mesh(self) -> None: - if os.getenv("NEW_MODEL_DESIGN", False): + if envs.NEW_MODEL_DESIGN: self.mesh = self._create_new_model_mesh() else: # NOTE(wenxindongwork): The new MoE kernel expects a 2D mesh, so we need @@ -303,7 +305,7 @@ def _init_mesh(self) -> None: logger.info(f"Init mesh | mesh={self.mesh}") def _create_new_model_mesh(self) -> jax.sharding.Mesh: - num_slices = int(os.environ.get('NUM_SLICES', 1)) + num_slices = envs.NUM_SLICES logger.info(f"Creating new model mesh | devices={len(self.devices)}, " f"num_slices={num_slices}") @@ -372,7 +374,7 @@ def _create_2d_mesh(self) -> jax.sharding.Mesh: devices=self.devices) def _init_phased_profiling(self) -> None: - self.phased_profiling_dir = os.getenv("PHASED_PROFILING_DIR", "") + self.phased_profiling_dir = envs.PHASED_PROFILING_DIR self.phase_based_profiler = None if self.phased_profiling_dir: self.phase_based_profiler = runner_utils.PhasedBasedProfiler( @@ -414,7 +416,7 @@ def _init_inputs(self) -> None: min_token_size=max(16, self.dp_size), max_token_size=scheduler_config.max_num_batched_tokens * self.dp_size, - padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP) + padding_gap=vllm_envs.VLLM_TPU_BUCKET_PADDING_GAP) self.num_tokens_paddings_per_dp = [ padding // self.dp_size for padding in self.num_tokens_paddings ] diff --git 
a/tpu_inference/runner/utils.py b/tpu_inference/runner/utils.py index a2d04527e..7b87989d2 100644 --- a/tpu_inference/runner/utils.py +++ b/tpu_inference/runner/utils.py @@ -15,6 +15,7 @@ from jax._src.interpreters import pxla from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput +from tpu_inference import envs from tpu_inference.logger import init_logger from tpu_inference.runner.input_batch import InputBatch @@ -306,8 +307,7 @@ def __init__(self, profile_dir: str): InferencePhase.BALANCED: False } self.default_profiling_options = jax.profiler.ProfileOptions() - self.default_profiling_options.python_tracer_level = os.getenv( - "PYTHON_TRACER_LEVEL", 0) + self.default_profiling_options.python_tracer_level = envs.PYTHON_TRACER_LEVEL self.current_phase: str = "" diff --git a/tpu_inference/worker/tpu_worker.py b/tpu_inference/worker/tpu_worker.py index efab89e07..d76f4f42f 100644 --- a/tpu_inference/worker/tpu_worker.py +++ b/tpu_inference/worker/tpu_worker.py @@ -347,7 +347,7 @@ def profile(self, is_start: bool = True): if is_start: options = jax.profiler.ProfileOptions() # default: https://docs.jax.dev/en/latest/profiling.html#general-options - options.python_tracer_level = os.getenv("PYTHON_TRACER_LEVEL", 0) + options.python_tracer_level = envs.PYTHON_TRACER_LEVEL options.host_tracer_level = os.getenv("HOST_TRACER_LEVEL", 1) jax.profiler.start_trace(self.profile_dir, profiler_options=options) From 08307158a1e50a4c718a04d273fdf649f12730f0 Mon Sep 17 00:00:00 2001 From: Xing Liu Date: Fri, 21 Nov 2025 06:41:29 +0000 Subject: [PATCH 3/5] fix Signed-off-by: Xing Liu --- tests/e2e/test_data_parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/e2e/test_data_parallel.py b/tests/e2e/test_data_parallel.py index 455773524..a6c28ee93 100644 --- a/tests/e2e/test_data_parallel.py +++ b/tests/e2e/test_data_parallel.py @@ -13,7 +13,7 @@ @pytest.fixture(autouse=True) def setup_new_model_design(): """Automatically set NEW_MODEL_DESIGN=True for all tests.""" - os.environ['NEW_MODEL_DESIGN'] = 'True' + os.environ['NEW_MODEL_DESIGN'] = '1' @pytest.fixture From 286f414e98b37bfbbe4b5a05b251ee804a7048be Mon Sep 17 00:00:00 2001 From: Xing Liu Date: Fri, 21 Nov 2025 06:49:37 +0000 Subject: [PATCH 4/5] Revert test file changes Signed-off-by: Xing Liu --- tests/e2e/test_data_parallel.py | 7 +++---- tests/e2e/test_multi_modal_inference.py | 7 ++++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/e2e/test_data_parallel.py b/tests/e2e/test_data_parallel.py index a6c28ee93..9d794df29 100644 --- a/tests/e2e/test_data_parallel.py +++ b/tests/e2e/test_data_parallel.py @@ -6,14 +6,13 @@ from dataclasses import asdict import pytest -import tpu_inference.envs as envs from vllm import LLM, EngineArgs, SamplingParams @pytest.fixture(autouse=True) def setup_new_model_design(): """Automatically set NEW_MODEL_DESIGN=True for all tests.""" - os.environ['NEW_MODEL_DESIGN'] = '1' + os.environ['NEW_MODEL_DESIGN'] = 'True' @pytest.fixture @@ -174,8 +173,8 @@ def test_data_parallelism_correctness( This test compares outputs from a single-device run with data parallel runs to ensure correctness, including log probabilities. 
""" - envs.environment_variables['SKIP_JAX_PRECOMPILE'] = lambda: True - envs.environment_variables['VLLM_XLA_CHECK_RECOMPILATION'] = lambda: False + os.environ['SKIP_JAX_PRECOMPILE'] = '1' + os.environ['VLLM_XLA_CHECK_RECOMPILATION'] = '0' model_name = "Qwen/Qwen2.5-1.5B-Instruct" # Use a smaller subset of prompts for correctness testing small_prompts = test_prompts[:10] diff --git a/tests/e2e/test_multi_modal_inference.py b/tests/e2e/test_multi_modal_inference.py index 46c14d90f..c1d2bda77 100644 --- a/tests/e2e/test_multi_modal_inference.py +++ b/tests/e2e/test_multi_modal_inference.py @@ -4,9 +4,9 @@ # This script is a self-contained test that runs a single prompt and # compares the output to a known-good output. +import os from dataclasses import asdict -import tpu_inference.envs as envs from vllm import LLM, EngineArgs, SamplingParams from vllm.assets.image import ImageAsset from vllm.multimodal.image import convert_image_mode @@ -24,8 +24,9 @@ def test_multi_modal_inference(monkeypatch): """ Runs multi-modal inference and verifies the output. """ - envs.environment_variables['SKIP_JAX_PRECOMPILE'] = lambda: True # Skip warmup to save time. - envs.environment_variables['VLLM_XLA_CHECK_RECOMPILATION'] = lambda: False # Allow compilation during execution. + os.environ['SKIP_JAX_PRECOMPILE'] = '1' # Skip warmup to save time. + os.environ[ + 'VLLM_XLA_CHECK_RECOMPILATION'] = '0' # Allow compilation during execution. monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") From 2c92d5d95909b9e355fba6b2b78b5d2f3a374993 Mon Sep 17 00:00:00 2001 From: Xing Liu Date: Fri, 21 Nov 2025 07:02:51 +0000 Subject: [PATCH 5/5] Fix isort, yapf, and ruff formatting issues Signed-off-by: Xing Liu --- examples/offline_inference.py | 2 +- examples/offline_lora_inference.py | 3 ++- examples/offline_safety_model_inference.py | 2 +- tests/runner/test_tpu_runner_mesh.py | 1 - tpu_inference/executors/ray_distributed_executor.py | 6 ++++-- tpu_inference/layers/common/sharding.py | 1 - tpu_inference/runner/tpu_runner.py | 4 +--- 7 files changed, 9 insertions(+), 10 deletions(-) diff --git a/examples/offline_inference.py b/examples/offline_inference.py index b80f985b8..c98d18d50 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -1,11 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import tpu_inference.envs as envs import vllm.envs as vllm_envs from vllm import LLM, EngineArgs from vllm.utils.argparse_utils import FlexibleArgumentParser +import tpu_inference.envs as envs from tpu_inference.core import disagg_utils diff --git a/examples/offline_lora_inference.py b/examples/offline_lora_inference.py index dd8324f66..6c2c3fe5e 100644 --- a/examples/offline_lora_inference.py +++ b/examples/offline_lora_inference.py @@ -3,12 +3,13 @@ import time -import tpu_inference.envs as envs import vllm.envs as vllm_envs from vllm import LLM, EngineArgs from vllm.lora.request import LoRARequest from vllm.utils.argparse_utils import FlexibleArgumentParser +import tpu_inference.envs as envs + def create_parser(): parser = FlexibleArgumentParser() diff --git a/examples/offline_safety_model_inference.py b/examples/offline_safety_model_inference.py index 5147a6645..9fd4b94ed 100644 --- a/examples/offline_safety_model_inference.py +++ b/examples/offline_safety_model_inference.py @@ -18,11 +18,11 @@ --max-num_batched_tokens=4096 """ -import tpu_inference.envs as envs import vllm.envs as vllm_envs from vllm import LLM, EngineArgs from 
vllm.utils.argparse_utils import FlexibleArgumentParser +import tpu_inference.envs as envs from tpu_inference.core import disagg_utils diff --git a/tests/runner/test_tpu_runner_mesh.py b/tests/runner/test_tpu_runner_mesh.py index c2477c0c2..8ab4c5dee 100644 --- a/tests/runner/test_tpu_runner_mesh.py +++ b/tests/runner/test_tpu_runner_mesh.py @@ -1,5 +1,4 @@ """Unit tests for TPUModelRunner mesh initialization.""" -import os from unittest.mock import Mock, patch import pytest diff --git a/tpu_inference/executors/ray_distributed_executor.py b/tpu_inference/executors/ray_distributed_executor.py index 0ef201efb..26b2f621f 100644 --- a/tpu_inference/executors/ray_distributed_executor.py +++ b/tpu_inference/executors/ray_distributed_executor.py @@ -74,7 +74,8 @@ def _init_executor(self) -> None: self.forward_dag: Optional[ray.dag.CompiledDAG] = None # Ensure Ray compiled DAG channel type is set for vLLM - os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = tpu_envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE + os.environ[ + "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = tpu_envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE # Currently, this requires USE_RAY_SPMD_WORKER=True. self.use_ray_compiled_dag = True @@ -91,7 +92,8 @@ def _init_executor(self) -> None: # Ensure Ray usage stats collection setting is propagated to Ray workers. # Ray workers inherit environment variables, so we explicitly set this # based on our configuration (defaults to "0" to disable stats). - os.environ["RAY_USAGE_STATS_ENABLED"] = tpu_envs.RAY_USAGE_STATS_ENABLED + os.environ[ + "RAY_USAGE_STATS_ENABLED"] = tpu_envs.RAY_USAGE_STATS_ENABLED # Create the parallel GPU workers. self._init_workers_ray(placement_group) diff --git a/tpu_inference/layers/common/sharding.py b/tpu_inference/layers/common/sharding.py index a25bca52e..817d7c76f 100644 --- a/tpu_inference/layers/common/sharding.py +++ b/tpu_inference/layers/common/sharding.py @@ -1,6 +1,5 @@ import json import math -import os from dataclasses import asdict, dataclass from typing import TYPE_CHECKING, List, Optional diff --git a/tpu_inference/runner/tpu_runner.py b/tpu_inference/runner/tpu_runner.py index 921b5825e..be0d6af52 100644 --- a/tpu_inference/runner/tpu_runner.py +++ b/tpu_inference/runner/tpu_runner.py @@ -1,6 +1,5 @@ import copy import functools -import os import random from contextlib import nullcontext from dataclasses import dataclass @@ -13,8 +12,6 @@ import torch import vllm.envs as vllm_envs from flax import nnx - -import tpu_inference.envs as envs from jax.experimental import mesh_utils from jax.sharding import NamedSharding, PartitionSpec from torchax.ops.mappings import j2t_dtype @@ -37,6 +34,7 @@ KVConnectorModelRunnerMixin from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin +import tpu_inference.envs as envs from tpu_inference import utils as common_utils from tpu_inference.layers.common.attention_metadata import AttentionMetadata from tpu_inference.layers.common.sharding import (MESH_AXIS_NAMES,
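
The common thread across these patches is the accessor pattern in tpu_inference/envs.py: every variable is registered in the environment_variables dict as a name -> lambda entry and is also readable as a plain module attribute, so call sites write envs.NUM_SLICES while tests can either monkeypatch.setenv(...) or patch.dict(envs.environment_variables, ...). The snippet below is a minimal, self-contained sketch of that pattern, assuming the module relies on the PEP 562 module-level __getattr__/__dir__ mechanism (as vllm.envs does); the variable names, parsing, and defaults are taken from the envs.py hunk in PATCH 2/5, but everything else is illustrative rather than the actual file contents.

import os

# Name -> zero-argument parser; each entry is evaluated on every attribute
# access, so changes to os.environ (e.g. via monkeypatch.setenv) are picked
# up lazily on the next read.
environment_variables = {
    "SKIP_JAX_PRECOMPILE":
    lambda: bool(int(os.getenv("SKIP_JAX_PRECOMPILE", "0"))),
    "VLLM_XLA_CHECK_RECOMPILATION":
    lambda: bool(int(os.getenv("VLLM_XLA_CHECK_RECOMPILATION", "0"))),
    "NUM_SLICES":
    lambda: int(os.getenv("NUM_SLICES", "1")),
    "PYTHON_TRACER_LEVEL":
    lambda: int(os.getenv("PYTHON_TRACER_LEVEL", "1")),
}


def __getattr__(name: str):
    # envs.NUM_SLICES -> environment_variables["NUM_SLICES"]()
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


def __dir__():
    # Lets dir(envs) enumerate every supported variable, which is what
    # tests/test_envs.py::test_dir_returns_all_env_vars relies on.
    return list(environment_variables)


# Typical call sites after this series (hypothetical usage):
#
#     import tpu_inference.envs as envs
#     if envs.SKIP_JAX_PRECOMPILE:      # bool, parsed from "0"/"1"
#         return
#     num_slices = envs.NUM_SLICES      # int, defaults to 1
#
# and in unit tests either monkeypatch.setenv("NUM_SLICES", "2") or
# patch.dict(envs.environment_variables, {"NUM_SLICES": lambda: 2}).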