
Commit 0e82f10

Fix profiler directory env var reference and update tests
Signed-off-by: Xing Liu <xingliu14@gmail.com>
1 parent 3026166 commit 0e82f10

File tree

4 files changed: +21, -20 lines


tests/worker/tpu_worker_test.py

Lines changed: 4 additions & 4 deletions
@@ -63,22 +63,22 @@ def test_init_success(self, mock_vllm_config):
         assert worker.profile_dir is None
         assert worker.devices == ['tpu:0']
 
-    @patch('tpu_inference.worker.tpu_worker.envs')
+    @patch('tpu_inference.worker.tpu_worker_jax.vllm_envs')
     def test_init_with_profiler_on_rank_zero(self, mock_envs,
                                              mock_vllm_config):
         """Tests that the profiler directory is set correctly on rank 0."""
-        mock_envs.VLLM_TORCH_PROFILER_DIR = "/tmp/profiles"
+        mock_vllm_envs.VLLM_TORCH_PROFILER_DIR = "/tmp/profiles"
         worker = TPUWorker(vllm_config=mock_vllm_config,
                            local_rank=0,
                            rank=0,
                            distributed_init_method="test_method")
         assert worker.profile_dir == "/tmp/profiles"
 
-    @patch('tpu_inference.worker.tpu_worker.envs')
+    @patch('tpu_inference.worker.tpu_worker_jax.vllm_envs')
     def test_init_with_profiler_on_other_ranks(self, mock_envs,
                                                mock_vllm_config):
         """Tests that the profiler directory is NOT set on non-rank 0 workers."""
-        mock_envs.VLLM_TORCH_PROFILER_DIR = "/tmp/profiles"
+        mock_vllm_envs.VLLM_TORCH_PROFILER_DIR = "/tmp/profiles"
         worker = TPUWorker(vllm_config=mock_vllm_config,
                            local_rank=1,
                            rank=1,

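Note (not part of the commit): unittest.mock.patch must target the name where the code under test looks it up. Since the worker now binds the vLLM settings module as vllm_envs, the tests patch that attribute on the worker module instead of the old envs name. A minimal, repo-independent sketch of that lookup rule, using a stand-in namespace rather than the real modules:

from types import SimpleNamespace
from unittest.mock import patch

# Stand-in for the worker module, which binds the settings module under an
# alias, analogous to `import vllm.envs as vllm_envs` in tpu_worker.py.
worker_module = SimpleNamespace(
    vllm_envs=SimpleNamespace(VLLM_TORCH_PROFILER_DIR=None))

def read_profiler_dir():
    # Code under test resolves the alias bound in its own module namespace.
    return worker_module.vllm_envs.VLLM_TORCH_PROFILER_DIR

# Patching the alias on the consuming module is what the updated test targets;
# patching the original module elsewhere would not affect this lookup.
with patch.object(worker_module, "vllm_envs",
                  SimpleNamespace(VLLM_TORCH_PROFILER_DIR="/tmp/profiles")):
    assert read_profiler_dir() == "/tmp/profiles"

assert read_profiler_dir() is None  # original value restored on exit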
tpu_inference/models/common/model_loader.py

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,4 @@
 import functools
-import os
 from typing import Any, Optional
 
 import jax
@@ -11,6 +10,7 @@
 from vllm.config import VllmConfig
 from vllm.utils.func_utils import supports_kw
 
+from tpu_inference import envs
 from tpu_inference.layers.jax.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils.quantization.quantization_utils import (
@@ -314,7 +314,7 @@ def get_model(
     mesh: Mesh,
     is_draft_model: bool = False,
 ) -> Any:
-    impl = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()
+    impl = envs.MODEL_IMPL_TYPE
     logger.info(f"Loading model with MODEL_IMPL_TYPE={impl}")
 
     if impl == "flax_nnx":

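The tpu_inference.envs module imported here is not shown in this diff. Assuming it preserves the behavior of the removed call os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower(), a minimal sketch of such an accessor might look like the following (the lazy module-level __getattr__ is an assumption, not necessarily the repo's actual implementation):

# Hypothetical sketch of a tpu_inference/envs.py-style accessor. Assumes the
# old default and lower-casing are preserved; the real module may differ.
import os

def _model_impl_type() -> str:
    # Normalize case so "FLAX_NNX" and "flax_nnx" select the same backend.
    return os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()

def __getattr__(name: str) -> str:
    # PEP 562 module-level __getattr__: callers write envs.MODEL_IMPL_TYPE,
    # and the value is read from the environment at access time.
    if name == "MODEL_IMPL_TYPE":
        return _model_impl_type()
    raise AttributeError(name)

vllm.envs uses a similar lazy attribute lookup, which would let call sites keep reading env-backed settings as plain module attributes.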
tpu_inference/platforms/tpu_platform.py

Lines changed: 10 additions & 9 deletions
@@ -4,13 +4,14 @@
 from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
 
 import jax.numpy as jnp
-import vllm.envs as envs
+import vllm.envs as vllm_envs
 from torchax.ops.mappings import j2t_dtype
 from tpu_info import device
 from vllm.inputs import ProcessorInputs, PromptType
 from vllm.platforms.interface import Platform, PlatformEnum
 from vllm.sampling_params import SamplingParams, SamplingType
 
+from tpu_inference import envs
 from tpu_inference.layers.jax.sharding import ShardingConfigManager
 from tpu_inference.logger import init_logger
 
@@ -71,7 +72,7 @@ def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
         try:
-            if envs.VLLM_TPU_USING_PATHWAYS:
+            if vllm_envs.VLLM_TPU_USING_PATHWAYS:
                 # Causes mutliprocess accessing IFRT when calling jax.devices()
                 return "TPU v6 lite"
             else:
@@ -87,7 +88,7 @@ def get_device_total_memory(cls, device_id: int = 0) -> int:
 
     @classmethod
     def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
-        return not envs.VLLM_USE_V1
+        return not vllm_envs.VLLM_USE_V1
 
     @classmethod
     def get_punica_wrapper(cls) -> str:
@@ -118,11 +119,11 @@ def _initialize_sharding_config(cls, vllm_config: VllmConfig) -> None:
 
     @classmethod
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
-        if not envs.VLLM_USE_V1:
+        if not vllm_envs.VLLM_USE_V1:
             raise RuntimeError("VLLM_USE_V1=1 must be set for JAX backend.")
 
-        if envs.VLLM_TPU_USING_PATHWAYS:
-            assert not envs.VLLM_ENABLE_V1_MULTIPROCESSING, (
+        if vllm_envs.VLLM_TPU_USING_PATHWAYS:
+            assert not vllm_envs.VLLM_ENABLE_V1_MULTIPROCESSING, (
                 "VLLM_ENABLE_V1_MULTIPROCESSING must be 0 when using Pathways(JAX_PLATFORMS=proxy)"
             )
         cls._initialize_sharding_config(vllm_config)
@@ -144,7 +145,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             compilation_config.backend = "openxla"
 
         # If we use vLLM's model implementation in PyTorch, we should set it with torch version of the dtype.
-        impl = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()
+        impl = envs.MODEL_IMPL_TYPE
 
         # NOTE(xiang): convert dtype to jnp.dtype
         # NOTE(wenlong): skip this logic for mm model preprocessing
@@ -164,7 +165,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             vllm_config.model_config.dtype = j2t_dtype(
                 vllm_config.model_config.dtype.dtype)
 
-        if envs.VLLM_USE_V1:
+        if vllm_envs.VLLM_USE_V1:
             # TODO(cuiq): remove this dependency.
             from vllm.v1.attention.backends.pallas import \
                 PallasAttentionBackend
@@ -250,7 +251,7 @@ def validate_request(
         """Raises if this request is unsupported on this platform"""
 
         if isinstance(params, SamplingParams):
-            if params.structured_outputs is not None and not envs.VLLM_USE_V1:
+            if params.structured_outputs is not None and not vllm_envs.VLLM_USE_V1:
                 raise ValueError("Structured output is not supported on "
                                  f"{cls.device_name} V0.")
             if params.sampling_type == SamplingType.RANDOM_SEED:

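Aside (not part of the diff): after this change the platform module holds two env namespaces side by side, vllm_envs for vLLM-level flags (VLLM_USE_V1, VLLM_TPU_USING_PATHWAYS, VLLM_ENABLE_V1_MULTIPROCESSING) and tpu_inference's own envs for repo-local settings such as MODEL_IMPL_TYPE. The guards in check_and_update_config reduce to two constraints; a repo-independent sketch of that logic (the function name and plain-bool parameters are illustrative):

def validate_tpu_env(use_v1: bool, using_pathways: bool,
                     enable_v1_multiprocessing: bool) -> None:
    # Simplified restatement of the guards in check_and_update_config: the JAX
    # backend requires vLLM V1, and Pathways (JAX_PLATFORMS=proxy) is
    # incompatible with V1 multiprocessing.
    if not use_v1:
        raise RuntimeError("VLLM_USE_V1=1 must be set for JAX backend.")
    if using_pathways and enable_v1_multiprocessing:
        raise RuntimeError(
            "VLLM_ENABLE_V1_MULTIPROCESSING must be 0 when using "
            "Pathways(JAX_PLATFORMS=proxy)")

# A Pathways setup passes only with V1 multiprocessing disabled.
validate_tpu_env(use_v1=True, using_pathways=True,
                 enable_v1_multiprocessing=False)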
tpu_inference/worker/tpu_worker.py

Lines changed: 5 additions & 5 deletions
@@ -8,7 +8,7 @@
 import jax.numpy as jnp
 import jaxlib
 import jaxtyping
-import vllm.envs as envs
+import vllm.envs as vllm_envs
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
                                           has_kv_transfer_group)
@@ -22,7 +22,7 @@
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
 
-from tpu_inference import utils
+from tpu_inference import envs, utils
 from tpu_inference.distributed.utils import (get_host_ip, get_kv_transfer_port,
                                              get_node_id)
 from tpu_inference.layers.jax.sharding import ShardingConfigManager
@@ -50,7 +50,7 @@ def __init__(self,
                  devices=None):
        # If we use vLLM's model implementation in PyTorch, we should set it
        # with torch version of the dtype.
-        impl = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()
+        impl = envs.MODEL_IMPL_TYPE
         if impl != "vllm":  # vllm-pytorch implementation does not need this conversion
 
            # NOTE(wenlong): because sometimes mm needs to use torch for preprocessing
@@ -86,11 +86,11 @@ def __init__(self,
         # TPU Worker is initialized. The profiler server needs to start after
         # MP runtime is initialized.
         self.profile_dir = None
-        if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
+        if vllm_envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
             if not self.devices or 0 in self.device_ranks:
                 # For TPU, we can only have 1 active profiler session for 1 profiler
                 # server. So we only profile on rank0.
-                self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR
+                self.profile_dir = vllm_envs.VLLM_TORCH_PROFILER_DIR
                 logger.info("Profiling enabled. Traces will be saved to: %s",
                             self.profile_dir)
 

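The behavioral logic here is unchanged by the rename: a worker records a profile_dir only when VLLM_TORCH_PROFILER_DIR is set and it is the rank that owns device 0. A small repo-independent sketch of that gating as a pure function (resolve_profile_dir is a hypothetical helper, not part of the commit):

from typing import Optional, Sequence

def resolve_profile_dir(profiler_dir: Optional[str], rank: int,
                        devices: Optional[Sequence[str]],
                        device_ranks: Sequence[int]) -> Optional[str]:
    # Restates the condition in TPUWorker.__init__: TPU allows one active
    # profiler session per profiler server, so only rank 0 gets a directory.
    if profiler_dir and rank < 1:
        if not devices or 0 in device_ranks:
            return profiler_dir
    return None

assert resolve_profile_dir("/tmp/profiles", rank=0, devices=None,
                           device_ranks=[]) == "/tmp/profiles"
assert resolve_profile_dir("/tmp/profiles", rank=1, devices=["tpu:1"],
                           device_ranks=[1]) is None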