import os
import tempfile
+ from dataclasses import dataclass, field
from typing import Callable, Dict, Optional, Tuple

import jax
import jaxtyping
import vllm.envs as vllm_envs
from vllm.config import VllmConfig, set_current_vllm_config
+ from vllm.distributed import get_pp_group
from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
                                          has_kv_transfer_group)
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput

from tpu_inference import envs, utils
+ from tpu_inference.distributed import jax_parallel_state
from tpu_inference.distributed.utils import (get_host_ip, get_kv_transfer_port,
                                             get_node_id)
from tpu_inference.layers.common.sharding import ShardingConfigManager
from tpu_inference.logger import init_logger
+ from tpu_inference.models.jax.jax_intermediate_tensor import \
+     JaxIntermediateTensors
from tpu_inference.runner.kv_cache import get_rpa_page_size_bytes
from tpu_inference.runner.tpu_runner import TPUModelRunner

}


+ @dataclass
+ class PPConfig:
+     rank: int
+     ip: str
+     prev_worker_ip: str
+     pp_world_size: int
+
+     # Default values for TPU_PROCESS_BOUNDS, TPU_CHIPS_PER_PROCESS_BOUNDS and
+     # TPU_VISIBLE_CHIPS when PP is used on a single host.
+     default_tpu_process_bounds: str = field(init=False)
+     default_tpu_chips_per_process_bounds: str = field(init=False)
+     default_tpu_visible_chips: str = field(init=False)
+
+     def __post_init__(self):
+         self.default_tpu_process_bounds = f"1,{self.pp_world_size},1"
+         self.default_tpu_chips_per_process_bounds = "1,1,1"
+         self.default_tpu_visible_chips = f"{self.rank}"
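+         # For illustration (values derived from the formulas above): with
+         # pp_world_size=2 and rank=1 on a single host, these resolve to
+         # "1,2,1", "1,1,1" and "1" respectively.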
+
+
class TPUWorker:

-     def __init__(self,
-                  vllm_config: VllmConfig,
-                  local_rank: int,
-                  rank: int,
-                  distributed_init_method: str,
-                  is_driver_worker: bool = False,
-                  devices=None):
+     def __init__(
+         self,
+         vllm_config: VllmConfig,
+         local_rank: int,
+         rank: int,
+         distributed_init_method: str,
+         is_driver_worker: bool = False,
+         devices=None,
+         ip: str = "localhost",
+         prev_worker_ip: str = "localhost",
+     ):
        # If we use vLLM's model implementation in PyTorch, we should set it
        # with torch version of the dtype.
        impl = envs.MODEL_IMPL_TYPE
@@ -74,6 +103,8 @@ def __init__(self,
        self.devices = devices if devices is not None else []
        self.device_ranks = set(device.id for device in self.devices
                                if isinstance(device, jaxlib._jax.Device))
+         self.pp_config = PPConfig(rank, ip, prev_worker_ip,
+                                   self.parallel_config.pipeline_parallel_size)

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
@@ -86,14 +117,22 @@ def __init__(self,
        # TPU Worker is initialized. The profiler server needs to start after
        # MP runtime is initialized.
        self.profile_dir = None
-         if vllm_envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
+         if vllm_envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1 and self.pp_config.pp_world_size == 1:
            if not self.devices or 0 in self.device_ranks:
                # For TPU, we can only have 1 active profiler session for 1 profiler
                # server. So we only profile on rank0.
                self.profile_dir = vllm_envs.VLLM_TORCH_PROFILER_DIR
                logger.info("Profiling enabled. Traces will be saved to: %s",
                            self.profile_dir)

+         # For PP we use MPMD, so we want to profile every worker.
+         if self.pp_config.pp_world_size > 1 and vllm_envs.VLLM_TORCH_PROFILER_DIR:
+             self.profile_dir = os.path.join(
+                 vllm_envs.VLLM_TORCH_PROFILER_DIR,
+                 f"pprank_{self.rank}_ppworldsize_{self.pp_config.pp_world_size}"
+             )
+             os.makedirs(self.profile_dir, exist_ok=True)
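+             # For illustration: with VLLM_TORCH_PROFILER_DIR=/tmp/profiles (a
+             # made-up path) and pp_world_size=2, rank 0 writes traces under
+             # /tmp/profiles/pprank_0_ppworldsize_2 and rank 1 under
+             # /tmp/profiles/pprank_1_ppworldsize_2.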
+
        use_jax_profiler_server = os.getenv("USE_JAX_PROFILER_SERVER", False)
        # Only one instance of profiler is allowed
        if use_jax_profiler_server and self.rank < 1:
@@ -105,32 +144,78 @@ def __init__(self,
            )
            jax.profiler.start_server(jax_profiler_server_port)

+         # step_counter is used to derive the uuid for transferring intermediate tensors.
+         self.step_counter = 0
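+         # Every rank advances step_counter once per execute_model() call, so the
+         # sending rank r and the receiving rank r + 1 compute the same uuid for a
+         # given step (see execute_model below).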
+
    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

-     def init_device(self):
+     def init_device(self,
+                     tpu_process_bounds="",
+                     tpu_chips_per_process_bounds="",
+                     tpu_visible_chips=""):
+         # Set the TPU devices visible to the JAX runtime for single-host PP.
+         multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
+         if multihost_backend != "ray" and self.parallel_config.pipeline_parallel_size > 1:
+             tpu_ports = [
+                 jax_parallel_state.BASE_JAX_PORT + i
+                 for i in range(self.pp_config.pp_world_size)
+             ]
+             os.environ["TPU_PROCESS_ADDRESSES"] = ",".join(
+                 [f"localhost:{port}" for port in tpu_ports])
+             os.environ["TPU_PROCESS_PORT"] = f"{tpu_ports[self.rank]}"
+             os.environ["CLOUD_TPU_TASK_ID"] = f"{self.rank}"
+
+             # Note: the defaults below assume a v6e8 host (8 v6e chips);
+             # replace them to match your own topology.
+             # There are two ways of subslicing a v6e8:
+             # 1) Two subslices of 4 chips each, allowing PP=2 with TP=1/2/3/4:
+             #    TPU_PROCESS_BOUNDS = "1,1,1"
+             #    TPU_CHIPS_PER_PROCESS_BOUNDS = "1,4,1"
+             #    TPU_VISIBLE_CHIPS = "0,1,2,3" or "4,5,6,7"
+             # 2) One chip per subslice, with at most 8 subslices, allowing
+             #    TP=1 and PP=1/2/3/4/5/6/7/8.
+             os.environ["TPU_PROCESS_BOUNDS"] = (
+                 tpu_process_bounds if tpu_process_bounds
+                 else self.pp_config.default_tpu_process_bounds)
+             os.environ["TPU_CHIPS_PER_PROCESS_BOUNDS"] = (
+                 tpu_chips_per_process_bounds if tpu_chips_per_process_bounds
+                 else self.pp_config.default_tpu_chips_per_process_bounds)
+             os.environ["TPU_VISIBLE_CHIPS"] = (
+                 tpu_visible_chips if tpu_visible_chips
+                 else self.pp_config.default_tpu_visible_chips)
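+             # For illustration, with pp_world_size=2 and no overrides passed in,
+             # this block sets (BASE_JAX_PORT is whatever jax_parallel_state defines):
+             #   TPU_PROCESS_ADDRESSES = "localhost:<BASE_JAX_PORT>,localhost:<BASE_JAX_PORT+1>"
+             #   TPU_PROCESS_PORT = "<BASE_JAX_PORT + rank>", CLOUD_TPU_TASK_ID = "<rank>"
+             #   TPU_PROCESS_BOUNDS = "1,2,1", TPU_CHIPS_PER_PROCESS_BOUNDS = "1,1,1",
+             #   TPU_VISIBLE_CHIPS = "<rank>"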
+
        if not self.devices:
            sharding_config: ShardingConfigManager = self.vllm_config.sharding_config
            device_indexes = sharding_config.device_indexes
            if device_indexes is not None and len(device_indexes) > 0:
                # Enforce that the devices sequence is consistent with the specified device indexes.
-                 all_devices = jax.devices()
-                 device_dict = {device.id: device for device in all_devices}
+                 all_local_devices = jax.local_devices()
+                 device_dict = {
+                     device.id: device
+                     for device in all_local_devices
+                 }
                self.devices = []
                for device_index in device_indexes:
                    device = device_dict.get(device_index)
                    if device is None:
                        raise KeyError(
                            f"Device index {device_index} not found in "
-                             f"jax.devices() with IDs {list(device_dict.keys())}!"
+                             f"jax.local_devices() with IDs {list(device_dict.keys())}!"
                        )
                    self.devices.append(device)
+                 assert len(self.devices) >= sharding_config.total_devices
                self.devices = self.devices[:sharding_config.total_devices]
            else:
-                 self.devices = jax.devices()[:sharding_config.total_devices]
-
+                 assert jax.local_device_count() >= sharding_config.total_devices
+                 self.devices = jax.local_devices()[:sharding_config.total_devices]
        # Initialize the vLLM distribution layer as a single chip environment,
        # we'll swap the model's parallel modules with TPU SPMD equivalents.
        with set_current_vllm_config(self.vllm_config):
@@ -146,15 +231,31 @@ def init_device(self):
                tensor_model_parallel_size=1,
                pipeline_model_parallel_size=1,
            )
+
+         jax_parallel_state.init_pp_distributed_environment(
+             self.pp_config.ip,
+             self.rank,
+             self.parallel_config.pipeline_parallel_size,
+             self.devices[0],
+             need_pp=self.parallel_config.pipeline_parallel_size > 1)
+
        ensure_kv_transfer_initialized(self.vllm_config)
-         self.model_runner = TPUModelRunner(self.vllm_config, self.devices)
+         self.model_runner = TPUModelRunner(
+             self.vllm_config, self.devices, self.rank, self.rank == 0,
+             self.rank == self.pp_config.pp_world_size - 1)
        logger.info(f"Init worker | "
                    f"rank={self.rank} | "
                    f"node_id={get_node_id()} | "
                    f"is_driver_worker={self.is_driver_worker} | "
                    f"hbm={utils.hbm_usage_gb(self.devices)} GiB")
        vllm_utils.report_usage_stats(self.vllm_config)

+     def initialize_pp_transfer_connect(self):
+         if self.rank == 0:
+             return
+         jax_parallel_state.connect(self.pp_config.prev_worker_ip,
+                                    self.rank - 1)
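+         # Presumably called on every worker after init_device() (which sets up the
+         # PP distributed environment), so that each rank r > 0 opens its receiving
+         # connection to rank r - 1 at prev_worker_ip.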
+
    def determine_available_memory(self) -> int:
        gpu_memory_utilization = self.cache_config.gpu_memory_utilization
        hbm_usage = utils.hbm_usage_bytes(self.devices)
@@ -194,14 +295,39 @@ def execute_model(
        # deliberate, temporary compromise for the same reasons outlined in
        # the `get_kv_cache_spec` method.

-         output = self.model_runner.execute_model(scheduler_output)
-
-         # With a connector, the scheduler expects output from all workers
-         # TODO(mrjunwan): Figure out if this is ok after https://github.com/vllm-project/vllm/pull/26866
-         if has_kv_transfer_group():
-             return output
-
-         return output if self.is_driver_worker else None
+         if self.parallel_config.pipeline_parallel_size == 1 or self.rank == 0:
+             intermediate_tensors = None
+         else:
+             # Receive intermediate tensors from the previous PP rank.
+             uuid = self.model_runner.get_uuid_for_jax_transfer(
+                 scheduler_output, self.rank - 1, self.step_counter)
+             # TODO: this method might only work for vLLM models; not sure about JAX models.
+             tensor_spec = self.model_runner.get_intermediate_tensor_spec(
+                 scheduler_output.total_num_scheduled_tokens)
+             intermediate_tensors_dict = get_pp_group().recv_tensor_dict(
+                 uuid, tensor_spec)
+             intermediate_tensors = JaxIntermediateTensors(
+                 intermediate_tensors_dict)
+
+         output = self.model_runner.execute_model(scheduler_output,
+                                                  intermediate_tensors)
+
+         if isinstance(output, JaxIntermediateTensors):
+             assert self.parallel_config.pipeline_parallel_size > 1
+             assert not get_pp_group().is_last_rank
+             # Send intermediate tensors to the next PP rank.
+             uuid = self.model_runner.get_uuid_for_jax_transfer(
+                 scheduler_output, self.rank, self.step_counter)
+             get_pp_group().send_tensor_dict(uuid, output.tensors)
+             self.step_counter += 1
+             return None
+         else:
+             self.step_counter += 1
+             # With a connector, the scheduler expects output from all workers.
+             # TODO(mrjunwan): Figure out if this is ok after https://github.com/vllm-project/vllm/pull/26866
+             if has_kv_transfer_group():
+                 return output
+             return output if self.is_driver_worker else None
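+     # PP control flow of execute_model above, summarized for illustration:
+     #   rank 0 (first stage): run its stage and send JaxIntermediateTensors onward;
+     #   middle ranks: receive from rank - 1, run their stage, send onward;
+     #   the last rank: receive from rank - 1, run its stage, return ModelRunnerOutput.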
    def sample_tokens(self,
                      grammar_output: GrammarOutput) -> ModelRunnerOutput: