Skip to content

Commit 85aa68d

Browse files
committed
fix: correct environment variable type conversions and improve comments
- Fix boolean environment variables: use bool(int(...)) instead of bool(os.getenv(..., False)) - Correctly converts '0' to False and '1' to True - Fixed JAX_RANDOM_WEIGHTS, SKIP_JAX_PRECOMPILE, NEW_MODEL_DESIGN - Standardize to use os.getenv() consistently across all variables - Improve docstrings: - Clarify __getattr__ comment: function is wrapped with functools.cache() - Clarify enable_envs_cache() comment: explains caching behavior and lifecycle - Type annotations in TYPE_CHECKING now accurately reflect runtime behavior Signed-off-by: Xing Liu <xingliu14@gmail.com>
1 parent cb1447f commit 85aa68d

File tree

1 file changed

+97
-0
lines changed

1 file changed

+97
-0
lines changed

tpu_inference/envs.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the tpu-inference project
3+
4+
import functools
5+
import os
6+
from collections.abc import Callable
7+
from typing import TYPE_CHECKING, Any
8+
9+
if TYPE_CHECKING:
10+
JAX_PLATFORMS: str = ""
11+
JAX_RANDOM_WEIGHTS: bool = False
12+
TPU_ACCELERATOR_TYPE: str | None = None
13+
TPU_NAME: str | None = None
14+
TPU_WORKER_ID: str | None = None
15+
TPU_MULTIHOST_BACKEND: str = ""
16+
PREFILL_SLICES: str = ""
17+
DECODE_SLICES: str = ""
18+
SKIP_JAX_PRECOMPILE: bool = False
19+
MODEL_IMPL_TYPE: str = "flax_nnx"
20+
NEW_MODEL_DESIGN: bool = False
21+
PHASED_PROFILING_DIR: str = ""
22+
PYTHON_TRACER_LEVEL: int = 1
23+
USE_MOE_EP_KERNEL: bool = False
24+
RAY_USAGE_STATS_ENABLED: str = "0"
25+
VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "shm"
26+
27+
28+
environment_variables: dict[str, Callable[[], Any]] = {
29+
# JAX platform selection (e.g., "tpu", "cpu", "proxy")
30+
"JAX_PLATFORMS": lambda: os.getenv("JAX_PLATFORMS", ""),
31+
# Initialize model weights randomly instead of loading from checkpoint
32+
"JAX_RANDOM_WEIGHTS": lambda: bool(int(os.getenv("JAX_RANDOM_WEIGHTS", "0"))),
33+
# TPU accelerator type (e.g., "v5litepod-16", "v4-8")
34+
"TPU_ACCELERATOR_TYPE": lambda: os.getenv("TPU_ACCELERATOR_TYPE", None),
35+
# Name of the TPU resource
36+
"TPU_NAME": lambda: os.getenv("TPU_NAME", None),
37+
# Worker ID for multi-host TPU setups
38+
"TPU_WORKER_ID": lambda: os.getenv("TPU_WORKER_ID", None),
39+
# Backend for multi-host communication on TPU
40+
"TPU_MULTIHOST_BACKEND": lambda: os.getenv("TPU_MULTIHOST_BACKEND", "").lower(),
41+
# Slice configuration for disaggregated prefill workers
42+
"PREFILL_SLICES": lambda: os.getenv("PREFILL_SLICES", ""),
43+
# Slice configuration for disaggregated decode workers
44+
"DECODE_SLICES": lambda: os.getenv("DECODE_SLICES", ""),
45+
# Skip JAX precompilation step during initialization
46+
"SKIP_JAX_PRECOMPILE": lambda: bool(int(os.getenv("SKIP_JAX_PRECOMPILE", "0"))),
47+
# Model implementation type (e.g., "flax_nnx")
48+
"MODEL_IMPL_TYPE": lambda: os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower(),
49+
# Enable new experimental model design
50+
"NEW_MODEL_DESIGN": lambda: bool(int(os.getenv("NEW_MODEL_DESIGN", "0"))),
51+
# Directory to store phased profiling output
52+
"PHASED_PROFILING_DIR": lambda: os.getenv("PHASED_PROFILING_DIR", ""),
53+
# Python tracer level for profiling
54+
"PYTHON_TRACER_LEVEL": lambda: int(os.getenv("PYTHON_TRACER_LEVEL", "1")),
55+
# Use custom expert-parallel kernel for MoE (Mixture of Experts)
56+
"USE_MOE_EP_KERNEL": lambda: bool(int(os.getenv("USE_MOE_EP_KERNEL", "0"))),
57+
# Enable/disable Ray usage statistics collection
58+
"RAY_USAGE_STATS_ENABLED": lambda: os.getenv("RAY_USAGE_STATS_ENABLED", "0"),
59+
# Ray compiled DAG channel type for TPU
60+
"VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE": lambda: os.getenv(
61+
"VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "shm"
62+
),
63+
}
64+
65+
def __getattr__(name: str) -> Any:
66+
"""
67+
Gets environment variables lazily.
68+
69+
NOTE: After enable_envs_cache() invocation (which triggered after service
70+
initialization), all environment variables will be cached.
71+
"""
72+
if name in environment_variables:
73+
return environment_variables[name]()
74+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
75+
76+
77+
def enable_envs_cache() -> None:
78+
"""
79+
Enables caching of environment variables by wrapping the module's __getattr__
80+
function with functools.cache(). This improves performance by avoiding
81+
repeated re-evaluation of environment variables.
82+
83+
NOTE: This should be called after service initialization. Once enabled,
84+
environment variable values are cached and will not reflect changes to
85+
os.environ until the process is restarted.
86+
"""
87+
# Tag __getattr__ with functools.cache
88+
global __getattr__
89+
__getattr__ = functools.cache(__getattr__)
90+
91+
# Cache all environment variables
92+
for key in environment_variables:
93+
__getattr__(key)
94+
95+
96+
def __dir__() -> list[str]:
97+
return list(environment_variables.keys())

0 commit comments

Comments
 (0)