Commit b8fcb0b

Centralized environment variables (#1058)
1 parent 83f38a1 commit b8fcb0b

File tree: 6 files changed, +128 -20 lines

tests/worker/tpu_worker_test.py

Lines changed: 2 additions & 2 deletions
@@ -63,7 +63,7 @@ def test_init_success(self, mock_vllm_config):
         assert worker.profile_dir is None
         assert worker.devices == ['tpu:0']
 
-    @patch('tpu_inference.worker.tpu_worker.envs')
+    @patch('tpu_inference.worker.tpu_worker.vllm_envs')
     def test_init_with_profiler_on_rank_zero(self, mock_envs,
                                              mock_vllm_config):
         """Tests that the profiler directory is set correctly on rank 0."""
@@ -74,7 +74,7 @@ def test_init_with_profiler_on_rank_zero(self, mock_envs,
                              distributed_init_method="test_method")
         assert worker.profile_dir == "/tmp/profiles"
 
-    @patch('tpu_inference.worker.tpu_worker.envs')
+    @patch('tpu_inference.worker.tpu_worker.vllm_envs')
     def test_init_with_profiler_on_other_ranks(self, mock_envs,
                                                mock_vllm_config):
         """Tests that the profiler directory is NOT set on non-rank 0 workers."""

tpu_inference/envs.py

Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the tpu-inference project
+
+import functools
+import os
+from collections.abc import Callable
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+    JAX_PLATFORMS: str = ""
+    TPU_ACCELERATOR_TYPE: str | None = None
+    TPU_NAME: str | None = None
+    TPU_WORKER_ID: str | None = None
+    TPU_MULTIHOST_BACKEND: str = ""
+    PREFILL_SLICES: str = ""
+    DECODE_SLICES: str = ""
+    SKIP_JAX_PRECOMPILE: bool = False
+    MODEL_IMPL_TYPE: str = "flax_nnx"
+    NEW_MODEL_DESIGN: bool = False
+    PHASED_PROFILING_DIR: str = ""
+    PYTHON_TRACER_LEVEL: int = 1
+    USE_MOE_EP_KERNEL: bool = False
+    RAY_USAGE_STATS_ENABLED: str = "0"
+    VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "shm"
+
+environment_variables: dict[str, Callable[[], Any]] = {
+    # JAX platform selection (e.g., "tpu", "cpu", "proxy")
+    "JAX_PLATFORMS":
+    lambda: os.getenv("JAX_PLATFORMS", ""),
+    # TPU accelerator type (e.g., "v5litepod-16", "v4-8")
+    "TPU_ACCELERATOR_TYPE":
+    lambda: os.getenv("TPU_ACCELERATOR_TYPE", None),
+    # Name of the TPU resource
+    "TPU_NAME":
+    lambda: os.getenv("TPU_NAME", None),
+    # Worker ID for multi-host TPU setups
+    "TPU_WORKER_ID":
+    lambda: os.getenv("TPU_WORKER_ID", None),
+    # Backend for multi-host communication on TPU
+    "TPU_MULTIHOST_BACKEND":
+    lambda: os.getenv("TPU_MULTIHOST_BACKEND", "").lower(),
+    # Slice configuration for disaggregated prefill workers
+    "PREFILL_SLICES":
+    lambda: os.getenv("PREFILL_SLICES", ""),
+    # Slice configuration for disaggregated decode workers
+    "DECODE_SLICES":
+    lambda: os.getenv("DECODE_SLICES", ""),
+    # Skip JAX precompilation step during initialization
+    "SKIP_JAX_PRECOMPILE":
+    lambda: bool(int(os.getenv("SKIP_JAX_PRECOMPILE", "0"))),
+    # Model implementation type (e.g., "flax_nnx")
+    "MODEL_IMPL_TYPE":
+    lambda: os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower(),
+    # Enable new experimental model design
+    "NEW_MODEL_DESIGN":
+    lambda: bool(int(os.getenv("NEW_MODEL_DESIGN", "0"))),
+    # Directory to store phased profiling output
+    "PHASED_PROFILING_DIR":
+    lambda: os.getenv("PHASED_PROFILING_DIR", ""),
+    # Python tracer level for profiling
+    "PYTHON_TRACER_LEVEL":
+    lambda: int(os.getenv("PYTHON_TRACER_LEVEL", "1")),
+    # Use custom expert-parallel kernel for MoE (Mixture of Experts)
+    "USE_MOE_EP_KERNEL":
+    lambda: bool(int(os.getenv("USE_MOE_EP_KERNEL", "0"))),
+    # Enable/disable Ray usage statistics collection
+    "RAY_USAGE_STATS_ENABLED":
+    lambda: os.getenv("RAY_USAGE_STATS_ENABLED", "0"),
+    # Ray compiled DAG channel type for TPU
+    "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE":
+    lambda: os.getenv("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "shm"),
+}
+
+
+def __getattr__(name: str) -> Any:
+    """
+    Gets environment variables lazily.
+
+    NOTE: After enable_envs_cache() is invoked (which happens after service
+    initialization), all environment variables are cached.
+    """
+    if name in environment_variables:
+        return environment_variables[name]()
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+
+
+def enable_envs_cache() -> None:
+    """
+    Enables caching of environment variables by wrapping the module's
+    __getattr__ function with functools.cache(). This improves performance by
+    avoiding repeated re-evaluation of environment variables.
+
+    NOTE: This should be called after service initialization. Once enabled,
+    environment variable values are cached and will not reflect changes to
+    os.environ until the process is restarted.
+    """
+    # Wrap __getattr__ with functools.cache
+    global __getattr__
+    __getattr__ = functools.cache(__getattr__)
+
+    # Cache all environment variables eagerly
+    for key in environment_variables:
+        __getattr__(key)
+
+
+def __dir__() -> list[str]:
+    return list(environment_variables.keys())
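
For reference, a minimal usage sketch of the new module (assuming the caller imports it as from tpu_inference import envs): attribute reads go through the module-level __getattr__, so each access re-evaluates os.environ until enable_envs_cache() freezes the values.

import os

from tpu_inference import envs

# Lazy access: each attribute read invokes the matching getter in
# environment_variables, so changes to os.environ are picked up immediately.
os.environ["SKIP_JAX_PRECOMPILE"] = "1"
print(envs.SKIP_JAX_PRECOMPILE)  # True, parsed via bool(int(...))
print(envs.MODEL_IMPL_TYPE)      # "flax_nnx" unless MODEL_IMPL_TYPE is set

# After service initialization, values can be frozen for cheap repeated reads.
envs.enable_envs_cache()
os.environ["SKIP_JAX_PRECOMPILE"] = "0"
print(envs.SKIP_JAX_PRECOMPILE)  # still True: cached values ignore later changes

# dir(envs) lists the supported variable names via the module-level __dir__.
print(dir(envs))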

tpu_inference/layers/vllm/quantization/unquantized.py

Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,3 @@
-import os
 from typing import Any, Callable, Optional, Union
 
 import jax
@@ -22,6 +21,7 @@
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 
+from tpu_inference import envs
 from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
 from tpu_inference.layers.vllm.fused_moe import fused_moe_func_padded
 from tpu_inference.layers.vllm.linear_common import (
@@ -164,7 +164,7 @@ def __init__(self,
                  ep_axis_name: str = 'model'):
         super().__init__(moe)
         self.mesh = mesh
-        self.use_kernel = bool(int(os.getenv("USE_MOE_EP_KERNEL", "0")))
+        self.use_kernel = envs.USE_MOE_EP_KERNEL
         self.ep_axis_name = ep_axis_name
         # TODO: Use autotune table once we have it.
         self.block_size = {
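
The behavior is intended to be unchanged: envs.USE_MOE_EP_KERNEL applies the same bool(int(os.getenv("USE_MOE_EP_KERNEL", "0"))) parsing that was previously inlined here. A small before/after sketch (illustrative, assuming caching has not been enabled yet):

import os

from tpu_inference import envs

# Before: each call site parsed the flag itself.
use_kernel_before = bool(int(os.getenv("USE_MOE_EP_KERNEL", "0")))

# After: the parsing and the default live in tpu_inference/envs.py.
use_kernel_after = envs.USE_MOE_EP_KERNEL

assert use_kernel_before == use_kernel_after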

tpu_inference/models/common/model_loader.py

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,4 @@
 import functools
-import os
 from typing import Any, Optional
 
 import jax
@@ -11,6 +10,7 @@
 from vllm.config import VllmConfig
 from vllm.utils.func_utils import supports_kw
 
+from tpu_inference import envs
 from tpu_inference.layers.jax.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils.quantization.quantization_utils import (
@@ -314,7 +314,7 @@ def get_model(
     mesh: Mesh,
     is_draft_model: bool = False,
 ) -> Any:
-    impl = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()
+    impl = envs.MODEL_IMPL_TYPE
     logger.info(f"Loading model with MODEL_IMPL_TYPE={impl}")
 
     if impl == "flax_nnx":

tpu_inference/platforms/tpu_platform.py

Lines changed: 10 additions & 9 deletions
@@ -4,13 +4,14 @@
 from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
 
 import jax.numpy as jnp
-import vllm.envs as envs
+import vllm.envs as vllm_envs
 from torchax.ops.mappings import j2t_dtype
 from tpu_info import device
 from vllm.inputs import ProcessorInputs, PromptType
 from vllm.platforms.interface import Platform, PlatformEnum
 from vllm.sampling_params import SamplingParams, SamplingType
 
+from tpu_inference import envs
 from tpu_inference.layers.jax.sharding import ShardingConfigManager
 from tpu_inference.logger import init_logger
 
@@ -71,7 +72,7 @@ def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
         try:
-            if envs.VLLM_TPU_USING_PATHWAYS:
+            if vllm_envs.VLLM_TPU_USING_PATHWAYS:
                 # Causes mutliprocess accessing IFRT when calling jax.devices()
                 return "TPU v6 lite"
             else:
@@ -87,7 +88,7 @@ def get_device_total_memory(cls, device_id: int = 0) -> int:
 
     @classmethod
     def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
-        return not envs.VLLM_USE_V1
+        return not vllm_envs.VLLM_USE_V1
 
     @classmethod
     def get_punica_wrapper(cls) -> str:
@@ -118,11 +119,11 @@ def _initialize_sharding_config(cls, vllm_config: VllmConfig) -> None:
 
     @classmethod
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
-        if not envs.VLLM_USE_V1:
+        if not vllm_envs.VLLM_USE_V1:
             raise RuntimeError("VLLM_USE_V1=1 must be set for JAX backend.")
 
-        if envs.VLLM_TPU_USING_PATHWAYS:
-            assert not envs.VLLM_ENABLE_V1_MULTIPROCESSING, (
+        if vllm_envs.VLLM_TPU_USING_PATHWAYS:
+            assert not vllm_envs.VLLM_ENABLE_V1_MULTIPROCESSING, (
                 "VLLM_ENABLE_V1_MULTIPROCESSING must be 0 when using Pathways(JAX_PLATFORMS=proxy)"
             )
         cls._initialize_sharding_config(vllm_config)
@@ -144,7 +145,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         compilation_config.backend = "openxla"
 
         # If we use vLLM's model implementation in PyTorch, we should set it with torch version of the dtype.
-        impl = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()
+        impl = envs.MODEL_IMPL_TYPE
 
         # NOTE(xiang): convert dtype to jnp.dtype
         # NOTE(wenlong): skip this logic for mm model preprocessing
@@ -164,7 +165,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             vllm_config.model_config.dtype = j2t_dtype(
                 vllm_config.model_config.dtype.dtype)
 
-        if envs.VLLM_USE_V1:
+        if vllm_envs.VLLM_USE_V1:
             # TODO(cuiq): remove this dependency.
             from vllm.v1.attention.backends.pallas import \
                 PallasAttentionBackend
@@ -250,7 +251,7 @@ def validate_request(
         """Raises if this request is unsupported on this platform"""
 
         if isinstance(params, SamplingParams):
-            if params.structured_outputs is not None and not envs.VLLM_USE_V1:
+            if params.structured_outputs is not None and not vllm_envs.VLLM_USE_V1:
                 raise ValueError("Structured output is not supported on "
                                  f"{cls.device_name} V0.")
             if params.sampling_type == SamplingType.RANDOM_SEED:
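
The renamed import makes the split explicit: vLLM-owned switches (VLLM_USE_V1, VLLM_TPU_USING_PATHWAYS, VLLM_ENABLE_V1_MULTIPROCESSING, ...) are read from vllm_envs, while TPU-inference-owned settings come from the new centralized module. A minimal sketch of the resulting convention (the condition itself is illustrative, not taken from the file):

import vllm.envs as vllm_envs

from tpu_inference import envs

# vLLM flags and TPU-inference flags now come from clearly separate modules.
if vllm_envs.VLLM_USE_V1 and envs.MODEL_IMPL_TYPE != "vllm":
    pass  # JAX-based implementation path (illustrative)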

tpu_inference/worker/tpu_worker.py

Lines changed: 5 additions & 5 deletions
@@ -8,7 +8,7 @@
 import jax.numpy as jnp
 import jaxlib
 import jaxtyping
-import vllm.envs as envs
+import vllm.envs as vllm_envs
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
                                           has_kv_transfer_group)
@@ -22,7 +22,7 @@
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
 
-from tpu_inference import utils
+from tpu_inference import envs, utils
 from tpu_inference.distributed.utils import (get_host_ip, get_kv_transfer_port,
                                              get_node_id)
 from tpu_inference.layers.jax.sharding import ShardingConfigManager
@@ -50,7 +50,7 @@ def __init__(self,
                  devices=None):
         # If we use vLLM's model implementation in PyTorch, we should set it
         # with torch version of the dtype.
-        impl = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()
+        impl = envs.MODEL_IMPL_TYPE
         if impl != "vllm":  # vllm-pytorch implementation does not need this conversion
 
             # NOTE(wenlong): because sometimes mm needs to use torch for preprocessing
@@ -86,11 +86,11 @@ def __init__(self,
         # TPU Worker is initialized. The profiler server needs to start after
         # MP runtime is initialized.
         self.profile_dir = None
-        if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
+        if vllm_envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
             if not self.devices or 0 in self.device_ranks:
                 # For TPU, we can only have 1 active profiler session for 1 profiler
                 # server. So we only profile on rank0.
-                self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR
+                self.profile_dir = vllm_envs.VLLM_TORCH_PROFILER_DIR
                 logger.info("Profiling enabled. Traces will be saved to: %s",
                             self.profile_dir)
 