@@ -4,13 +4,14 @@
 from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
 
 import jax.numpy as jnp
-import vllm.envs as envs
+import vllm.envs as vllm_envs
 from torchax.ops.mappings import j2t_dtype
 from tpu_info import device
 from vllm.inputs import ProcessorInputs, PromptType
 from vllm.platforms.interface import Platform, PlatformEnum
 from vllm.sampling_params import SamplingParams, SamplingType
 
+from tpu_inference import envs
 from tpu_inference.layers.jax.sharding import ShardingConfigManager
 from tpu_inference.logger import init_logger
 
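Renaming `vllm.envs` to `vllm_envs` frees the bare `envs` name for the new project-local `tpu_inference.envs` module. As a rough sketch, assuming that module centralizes the `MODEL_IMPL_TYPE` lookup this PR removes from `check_and_update_config` (the module body below is hypothetical):

```python
# Hypothetical sketch of tpu_inference/envs.py; the real module may differ.
import os

# Mirrors the inline lookup removed later in this diff:
# os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()
MODEL_IMPL_TYPE: str = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()
```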
@@ -72,7 +73,7 @@ def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
         try:
-            if envs.VLLM_TPU_USING_PATHWAYS:
+            if vllm_envs.VLLM_TPU_USING_PATHWAYS:
                 # Calling jax.devices() here causes multiprocess access to IFRT.
                 return "TPU v6 lite"
             else:
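For reference, the `else` branch (not shown in this hunk) resolves the chip name via `tpu_info` rather than `jax.devices()`. A hedged sketch of that lookup, assuming it runs on a TPU host (the `get_local_chips` usage is from memory, not confirmed against this PR):

```python
# Sketch of the non-Pathways path, assuming a local TPU host; under
# Pathways, jax.devices() would trigger multiprocess IFRT access.
from tpu_info import device

chip_type, _count = device.get_local_chips()  # e.g. a v6e chip descriptor
```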
@@ -88,7 +89,7 @@ def get_device_total_memory(cls, device_id: int = 0) -> int:
 
     @classmethod
     def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
-        return not envs.VLLM_USE_V1
+        return not vllm_envs.VLLM_USE_V1
 
     @classmethod
     def get_punica_wrapper(cls) -> str:
@@ -119,11 +120,11 @@ def _initialize_sharding_config(cls, vllm_config: VllmConfig) -> None:
 
     @classmethod
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
-        if not envs.VLLM_USE_V1:
+        if not vllm_envs.VLLM_USE_V1:
             raise RuntimeError("VLLM_USE_V1=1 must be set for JAX backend.")
 
-        if envs.VLLM_TPU_USING_PATHWAYS:
-            assert not envs.VLLM_ENABLE_V1_MULTIPROCESSING, (
+        if vllm_envs.VLLM_TPU_USING_PATHWAYS:
+            assert not vllm_envs.VLLM_ENABLE_V1_MULTIPROCESSING, (
                 "VLLM_ENABLE_V1_MULTIPROCESSING must be 0 when using Pathways(JAX_PLATFORMS=proxy)"
             )
         cls._initialize_sharding_config(vllm_config)
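The assertion ties three environment settings together. An illustrative launch configuration that satisfies both the JAX-backend requirement and the Pathways check (the values shown are an assumption, not taken from this PR):

```python
# Illustrative process environment for the Pathways path asserted above.
import os

os.environ["VLLM_USE_V1"] = "1"                     # required by the JAX backend
os.environ["JAX_PLATFORMS"] = "proxy"               # selects Pathways
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"  # must be off under Pathways
```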
@@ -145,7 +146,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             compilation_config.backend = "openxla"
 
         # If we use vLLM's model implementation in PyTorch, we should set it with the torch version of the dtype.
-        impl = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()
+        impl = envs.MODEL_IMPL_TYPE
 
         # NOTE(xiang): convert dtype to jnp.dtype
         # NOTE(wenlong): skip this logic for mm model preprocessing
@@ -165,7 +166,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             vllm_config.model_config.dtype = j2t_dtype(
                 vllm_config.model_config.dtype.dtype)
 
-        if envs.VLLM_USE_V1:
+        if vllm_envs.VLLM_USE_V1:
             # TODO(cuiq): remove this dependency.
             from vllm.v1.attention.backends.pallas import \
                 PallasAttentionBackend
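The lines above convert the model dtype back from a jnp dtype to its torch equivalent via torchax. A minimal sketch of that mapping, assuming torchax is installed (the bfloat16 example is illustrative):

```python
# Minimal sketch: j2t_dtype maps a JAX/numpy dtype object to the
# corresponding torch dtype, as used in the config update above.
import jax.numpy as jnp
from torchax.ops.mappings import j2t_dtype

torch_dtype = j2t_dtype(jnp.bfloat16.dtype)  # expected: torch.bfloat16
```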
@@ -251,7 +252,7 @@ def validate_request(
251252 """Raises if this request is unsupported on this platform"""
252253
253254 if isinstance (params , SamplingParams ):
254- if params .structured_outputs is not None and not envs .VLLM_USE_V1 :
255+ if params .structured_outputs is not None and not vllm_envs .VLLM_USE_V1 :
255256 raise ValueError ("Structured output is not supported on "
256257 f"{ cls .device_name } V0." )
257258 if params .sampling_type == SamplingType .RANDOM_SEED :
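The `RANDOM_SEED` branch is entered whenever a request pins a seed. A hedged usage example with vLLM's public `SamplingParams` (behavior assumed from upstream vLLM, not shown in this diff):

```python
# Hedged example: a fixed seed classifies the request as RANDOM_SEED
# sampling, the case the platform check above is guarding.
from vllm.sampling_params import SamplingParams, SamplingType

params = SamplingParams(temperature=0.8, seed=42)
assert params.sampling_type == SamplingType.RANDOM_SEED
```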