@@ -4,13 +4,14 @@
 from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
 
 import jax.numpy as jnp
-import vllm.envs as envs
+import vllm.envs as vllm_envs
 from torchax.ops.mappings import j2t_dtype
 from tpu_info import device
 from vllm.inputs import ProcessorInputs, PromptType
 from vllm.platforms.interface import Platform, PlatformEnum
 from vllm.sampling_params import SamplingParams, SamplingType
 
+from tpu_inference import envs
 from tpu_inference.layers.jax.sharding import ShardingConfigManager
 from tpu_inference.logger import init_logger
 
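The rename above frees the short name `envs` for tpu_inference's own flag module, while vLLM's module keeps the `vllm_envs` alias. Below is a hypothetical sketch of how such an `envs` module is commonly structured (a lazy, PEP 562 module-level `__getattr__`, the pattern `vllm.envs` itself uses); only the `MODEL_IMPL_TYPE` name and its `"flax_nnx"` default are grounded in this diff, taken from the `os.getenv` call replaced in a later hunk.

```python
# Hypothetical sketch of tpu_inference/envs.py -- the real module is not shown
# in this diff. Only MODEL_IMPL_TYPE and its "flax_nnx" default are taken from
# the os.getenv() call this PR replaces below.
import os
from typing import Any, Callable, Dict

environment_variables: Dict[str, Callable[[], Any]] = {
    "MODEL_IMPL_TYPE": lambda: os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower(),
}

def __getattr__(name: str) -> Any:
    # PEP 562 module-level __getattr__: the environment is read lazily, at
    # attribute access time, so envs.MODEL_IMPL_TYPE reflects the current env.
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```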
@@ -71,7 +72,7 @@ def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
         try:
-            if envs.VLLM_TPU_USING_PATHWAYS:
+            if vllm_envs.VLLM_TPU_USING_PATHWAYS:
                 # Calling jax.devices() would cause multiprocess access to IFRT.
                 return "TPU v6 lite"
             else:
@@ -87,7 +88,7 @@ def get_device_total_memory(cls, device_id: int = 0) -> int:
 
     @classmethod
     def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
-        return not envs.VLLM_USE_V1
+        return not vllm_envs.VLLM_USE_V1
 
     @classmethod
     def get_punica_wrapper(cls) -> str:
@@ -118,11 +119,11 @@ def _initialize_sharding_config(cls, vllm_config: VllmConfig) -> None:
 
     @classmethod
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
-        if not envs.VLLM_USE_V1:
+        if not vllm_envs.VLLM_USE_V1:
             raise RuntimeError("VLLM_USE_V1=1 must be set for JAX backend.")
 
-        if envs.VLLM_TPU_USING_PATHWAYS:
-            assert not envs.VLLM_ENABLE_V1_MULTIPROCESSING, (
+        if vllm_envs.VLLM_TPU_USING_PATHWAYS:
+            assert not vllm_envs.VLLM_ENABLE_V1_MULTIPROCESSING, (
                 "VLLM_ENABLE_V1_MULTIPROCESSING must be 0 when using Pathways(JAX_PLATFORMS=proxy)"
             )
         cls._initialize_sharding_config(vllm_config)
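For illustration, here is a standalone, hedged sketch of the Pathways guard above. The env-var names come from the diff, but deriving the Pathways flag from `JAX_PLATFORMS` is an assumption here, not something this hunk shows.

```python
# Standalone sketch of the Pathways guard, runnable without vLLM. The env-var
# names mirror the diff; the derivation of the Pathways flag is an assumption.
import os

def check_pathways_config() -> None:
    using_pathways = "proxy" in os.getenv("JAX_PLATFORMS", "").lower()
    if using_pathways:
        assert os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1") == "0", (
            "VLLM_ENABLE_V1_MULTIPROCESSING must be 0 when using Pathways "
            "(JAX_PLATFORMS=proxy)")
```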
@@ -144,7 +145,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         compilation_config.backend = "openxla"
 
         # If we use vLLM's model implementation in PyTorch, we should set it to the torch version of the dtype.
-        impl = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()
+        impl = envs.MODEL_IMPL_TYPE
 
         # NOTE(xiang): convert dtype to jnp.dtype
         # NOTE(wenlong): skip this logic for mm model preprocessing
@@ -164,7 +165,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             vllm_config.model_config.dtype = j2t_dtype(
                 vllm_config.model_config.dtype.dtype)
 
-        if envs.VLLM_USE_V1:
+        if vllm_envs.VLLM_USE_V1:
             # TODO(cuiq): remove this dependency.
             from vllm.v1.attention.backends.pallas import \
                 PallasAttentionBackend
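The hunk above converts the model dtype back from a JAX dtype to its torch equivalent before the PyTorch model path consumes it. A hedged example of that conversion, assuming `j2t_dtype` accepts a plain `jnp.dtype`:

```python
# Hedged example of the jnp -> torch dtype conversion performed above. That
# j2t_dtype accepts a plain jnp.dtype here is an assumption for illustration.
import jax.numpy as jnp
import torch
from torchax.ops.mappings import j2t_dtype

torch_dtype = j2t_dtype(jnp.dtype(jnp.bfloat16))
assert torch_dtype == torch.bfloat16
```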
@@ -250,7 +251,7 @@ def validate_request(
         """Raises if this request is unsupported on this platform"""
 
         if isinstance(params, SamplingParams):
-            if params.structured_outputs is not None and not envs.VLLM_USE_V1:
+            if params.structured_outputs is not None and not vllm_envs.VLLM_USE_V1:
                 raise ValueError("Structured output is not supported on "
                                  f"{cls.device_name} V0.")
             if params.sampling_type == SamplingType.RANDOM_SEED:
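To make the guard's behavior concrete, here is a minimal self-contained sketch of the V0 structured-output check, with a stand-in params type so it runs without vLLM; everything except the names taken from the hunk is illustrative.

```python
# Self-contained sketch of the structured-output V0 guard above. The stand-in
# dataclass replaces vllm.sampling_params.SamplingParams for illustration only.
import os
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class FakeSamplingParams:
    structured_outputs: Optional[Any] = None

def validate_request(params: FakeSamplingParams, device_name: str = "tpu") -> None:
    use_v1 = os.getenv("VLLM_USE_V1", "1") == "1"
    if params.structured_outputs is not None and not use_v1:
        raise ValueError(
            f"Structured output is not supported on {device_name} V0.")
```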