
Commit 5fe5fad

Enable Pipeline Parallelism on jax worker (#1043)
Signed-off-by: Chenyaaang <chenyangli@google.com>
1 parent 98e7d81 commit 5fe5fad

File tree: 3 files changed, +174 -29 lines changed


tests/worker/tpu_worker_test.py

Lines changed: 22 additions & 6 deletions
@@ -25,6 +25,7 @@ def mock_vllm_config():
     mock_parallel_conf = MagicMock()
     mock_parallel_conf.tensor_parallel_size = 2
     mock_parallel_conf.data_parallel_size = 1
+    mock_parallel_conf.pipeline_parallel_size = 1
     mock_parallel_conf.nnodes = 1
     mock_parallel_conf.nnodes_within_dp = 1

@@ -118,8 +119,14 @@ def test_init_device_with_provided_devices(

         worker.init_device()

-        mock_jax.devices.assert_not_called()
-        mock_runner_cls.assert_called_once_with(mock_vllm_config, mock_devices)
+        mock_jax.local_devices.assert_not_called()
+        expected_rank = 0
+        expected_is_first_rank = True
+        expected_is_last_rank = True
+        mock_runner_cls.assert_called_once_with(mock_vllm_config, mock_devices,
+                                                expected_rank,
+                                                expected_is_first_rank,
+                                                expected_is_last_rank)
         assert isinstance(worker.model_runner, MagicMock)

     @patch('tpu_inference.worker.tpu_worker.TPUModelRunner')
@@ -137,15 +144,24 @@ def test_init_device_autodetects_devices(
             distributed_init_method="test_method",
             devices=[]  # No devices provided, should trigger auto-detection
         )
-        mock_jax.devices.return_value = ['tpu:0', 'tpu:1', 'tpu:2', 'tpu:3']
+        mock_jax.local_device_count.return_value = 4
+        mock_jax.local_devices.return_value = [
+            'tpu:0', 'tpu:1', 'tpu:2', 'tpu:3'
+        ]

         worker.init_device()

-        mock_jax.devices.assert_called_once()
+        mock_jax.local_devices.assert_called_once()
         expected_devices = ['tpu:0', 'tpu:1']  # Sliced by tensor_parallel_size
         assert worker.devices == expected_devices
+        expected_rank = 0
+        expected_is_first_rank = True
+        expected_is_last_rank = True
         mock_runner_cls.assert_called_once_with(mock_vllm_config,
-                                                expected_devices)
+                                                expected_devices,
+                                                expected_rank,
+                                                expected_is_first_rank,
+                                                expected_is_last_rank)

     @patch('tpu_inference.worker.tpu_worker.utils')
     def test_determine_available_memory(self, mock_utils, mock_vllm_config):
@@ -194,7 +210,7 @@ def test_execute_model(self, mock_runner_cls, mock_vllm_config):

         # Assert the runner was called with the scheduler output directly
         worker.model_runner.execute_model.assert_called_once_with(
-            mock_scheduler_input)
+            mock_scheduler_input, None)
         # Assert the final result is the concrete model output
         assert result == mock_model_output

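Why the tests now expect (0, True, True): with pipeline_parallel_size = 1 the single worker is both the first and the last pipeline stage, and the worker forwards rank, rank == 0, and rank == pp_world_size - 1 to TPUModelRunner (see the worker diff below). A minimal sketch of that derivation, written outside the repo's code purely for illustration:

def pp_flags(rank: int, pp_world_size: int):
    # Mirrors the (rank, is_first_rank, is_last_rank) arguments the worker
    # forwards to TPUModelRunner in this commit.
    return rank, rank == 0, rank == pp_world_size - 1

assert pp_flags(0, 1) == (0, True, True)    # no PP: the only stage is first and last
assert pp_flags(1, 4) == (1, False, False)  # a middle stage of a 4-stage pipeline
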
tpu_inference/runner/tpu_runner.py

Lines changed: 3 additions & 0 deletions
@@ -223,6 +223,9 @@ def __init__(
         self,
         vllm_config: VllmConfig,
         devices: List[Any],
+        rank: int = 0,
+        is_first_rank: bool = True,
+        is_last_rank: bool = True,
     ):
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config

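The runner only gains three constructor parameters here; how a pipeline stage typically consults such flags is sketched below with a toy example. This is not TPUModelRunner's implementation, just an illustration of the usual first-stage/last-stage split, using plain Python lists in place of real tensors:

def run_stage(values, rank, pp_world_size, incoming=None):
    is_first = rank == 0
    is_last = rank == pp_world_size - 1
    hidden = values if is_first else incoming   # first stage "embeds" the inputs
    hidden = [h + 1 for h in hidden]            # this stage's share of the layers
    # Only the last stage produces final outputs; other stages hand their
    # hidden states to the next rank.
    return ("logits", hidden) if is_last else ("hidden", hidden)

kind, h = run_stage([1, 2, 3], rank=0, pp_world_size=2)
assert kind == "hidden"
kind, out = run_stage(None, rank=1, pp_world_size=2, incoming=h)
assert kind == "logits" and out == [3, 4, 5]
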
tpu_inference/worker/tpu_worker.py

Lines changed: 149 additions & 23 deletions
@@ -2,6 +2,7 @@

 import os
 import tempfile
+from dataclasses import dataclass, field
 from typing import Callable, Dict, Optional, Tuple

 import jax
@@ -10,6 +11,7 @@
 import jaxtyping
 import vllm.envs as vllm_envs
 from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.distributed import get_pp_group
 from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
                                           has_kv_transfer_group)
 from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
@@ -23,10 +25,13 @@
 from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput

 from tpu_inference import envs, utils
+from tpu_inference.distributed import jax_parallel_state
 from tpu_inference.distributed.utils import (get_host_ip, get_kv_transfer_port,
                                              get_node_id)
 from tpu_inference.layers.common.sharding import ShardingConfigManager
 from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.jax_intermediate_tensor import \
+    JaxIntermediateTensors
 from tpu_inference.runner.kv_cache import get_rpa_page_size_bytes
 from tpu_inference.runner.tpu_runner import TPUModelRunner

@@ -39,15 +44,39 @@
 }


+@dataclass
+class PPConfig:
+    rank: int
+    ip: str
+    prev_worker_ip: str
+    pp_world_size: int
+
+    # Default env vars for TPU_PROCESS_BOUNDS, TPU_CHIPS_PER_PROCESS_BOUNDS
+    # and TPU_VISIBLE_CHIPS when PP is used on a single host.
+    default_tpu_process_bounds: str = field(init=False)
+    default_tpu_chips_per_process_bounds: str = field(init=False)
+    default_tpu_visible_chips: str = field(init=False)
+
+    def __post_init__(self):
+        self.default_tpu_process_bounds = f"1,{self.pp_world_size},1"
+        self.default_tpu_chips_per_process_bounds = "1,1,1"
+        self.default_tpu_visible_chips = f"{self.rank}"
+
+
 class TPUWorker:

-    def __init__(self,
-                 vllm_config: VllmConfig,
-                 local_rank: int,
-                 rank: int,
-                 distributed_init_method: str,
-                 is_driver_worker: bool = False,
-                 devices=None):
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        local_rank: int,
+        rank: int,
+        distributed_init_method: str,
+        is_driver_worker: bool = False,
+        devices=None,
+        ip: str = "localhost",
+        prev_worker_ip: str = "localhost",
+    ):
         # If we use vLLM's model implementation in PyTorch, we should set it
         # with torch version of the dtype.
         impl = envs.MODEL_IMPL_TYPE
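PPConfig derives the single-host defaults directly from rank and world size: a 1 x pp_world_size x 1 process grid, one chip per process, and each worker seeing only the chip that matches its rank. A quick check of those defaults, assuming the tpu_inference package from this commit is installed and importable:

from tpu_inference.worker.tpu_worker import PPConfig

cfg = PPConfig(rank=1, ip="localhost", prev_worker_ip="localhost",
               pp_world_size=2)
assert cfg.default_tpu_process_bounds == "1,2,1"
assert cfg.default_tpu_chips_per_process_bounds == "1,1,1"
assert cfg.default_tpu_visible_chips == "1"
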
@@ -74,6 +103,8 @@ def __init__(self,
         self.devices = devices if devices is not None else []
         self.device_ranks = set(device.id for device in self.devices
                                 if isinstance(device, jaxlib._jax.Device))
+        self.pp_config = PPConfig(rank, ip, prev_worker_ip,
+                                  self.parallel_config.pipeline_parallel_size)

         if self.model_config.trust_remote_code:
             # note: lazy import to avoid importing torch before initializing
@@ -86,14 +117,22 @@ def __init__(self,
         # TPU Worker is initialized. The profiler server needs to start after
         # MP runtime is initialized.
         self.profile_dir = None
-        if vllm_envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
+        if vllm_envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1 and self.pp_config.pp_world_size == 1:
             if not self.devices or 0 in self.device_ranks:
                 # For TPU, we can only have 1 active profiler session for 1 profiler
                 # server. So we only profile on rank0.
                 self.profile_dir = vllm_envs.VLLM_TORCH_PROFILER_DIR
                 logger.info("Profiling enabled. Traces will be saved to: %s",
                             self.profile_dir)

+        # For PP, we use MPMD so we want to profile every worker.
+        if self.pp_config.pp_world_size > 1 and vllm_envs.VLLM_TORCH_PROFILER_DIR:
+            self.profile_dir = os.path.join(
+                vllm_envs.VLLM_TORCH_PROFILER_DIR,
+                f"pprank_{self.rank}_ppworldsize_{self.pp_config.pp_world_size}"
+            )
+            os.makedirs(self.profile_dir, exist_ok=True)
+
         use_jax_profiler_server = os.getenv("USE_JAX_PROFILER_SERVER", False)
         # Only one instance of profiler is allowed
         if use_jax_profiler_server and self.rank < 1:
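Because PP runs as MPMD (one process per stage), each worker gets its own trace directory named from its rank and world size rather than sharing the single rank-0 directory. A small illustration of the resulting layout; the base path is a hypothetical stand-in for VLLM_TORCH_PROFILER_DIR:

import os

base = "/tmp/vllm_profiles"   # stand-in for VLLM_TORCH_PROFILER_DIR
pp_world_size = 2
dirs = [
    os.path.join(base, f"pprank_{rank}_ppworldsize_{pp_world_size}")
    for rank in range(pp_world_size)
]
# ['/tmp/vllm_profiles/pprank_0_ppworldsize_2',
#  '/tmp/vllm_profiles/pprank_1_ppworldsize_2']
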
@@ -105,32 +144,78 @@ def __init__(self,
             )
             jax.profiler.start_server(jax_profiler_server_port)

+        # step_counter is used to compute the uuid for transferring
+        # intermediate tensors.
+        self.step_counter = 0
+
     def initialize_cache(self, num_gpu_blocks: int,
                          num_cpu_blocks: int) -> None:
         self.cache_config.num_gpu_blocks = num_gpu_blocks
         self.cache_config.num_cpu_blocks = num_cpu_blocks

-    def init_device(self):
+    def init_device(self,
+                    tpu_process_bounds="",
+                    tpu_chips_per_process_bounds="",
+                    tpu_visible_chips=""):
+        # Set the TPU chips visible to the JAX runtime for single-host PP.
+        multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
+        if multihost_backend != "ray" and self.parallel_config.pipeline_parallel_size > 1:
+            tpu_ports = [
+                jax_parallel_state.BASE_JAX_PORT + i
+                for i in range(self.pp_config.pp_world_size)
+            ]
+            os.environ["TPU_PROCESS_ADDRESSES"] = ",".join(
+                [f"localhost:{port}" for port in tpu_ports])
+            os.environ["TPU_PROCESS_PORT"] = f"{tpu_ports[self.rank]}"
+            os.environ["CLOUD_TPU_TASK_ID"] = f"{self.rank}"
+
+            # Note: Below is the setting for a v6e-8 host (8 chips of v6e).
+            # Replace with your own topology.
+            # There are 2 ways of subslicing a v6e:
+            # 1) 2 slices with 4 TPU chips each; we can do PP=2, TP=1/2/3/4:
+            #    TPU_PROCESS_BOUNDS = "1,1,1"
+            #    TPU_CHIPS_PER_PROCESS_BOUNDS = "1,4,1"
+            #    TPU_VISIBLE_CHIPS = "0,1,2,3" or "4,5,6,7"
+            # 2) 1 chip for each subslice, with at most 8 subslices;
+            #    we can do TP=1, PP=1/2/3/4/5/6/7/8.
+            os.environ[
+                "TPU_PROCESS_BOUNDS"] = tpu_process_bounds \
+                if tpu_process_bounds \
+                else self.pp_config.default_tpu_process_bounds
+            os.environ[
+                "TPU_CHIPS_PER_PROCESS_BOUNDS"] = tpu_chips_per_process_bounds \
+                if tpu_chips_per_process_bounds \
+                else self.pp_config.default_tpu_chips_per_process_bounds
+            os.environ[
+                "TPU_VISIBLE_CHIPS"] = tpu_visible_chips \
+                if tpu_visible_chips \
+                else self.pp_config.default_tpu_visible_chips
+
         if not self.devices:
             sharding_config: ShardingConfigManager = self.vllm_config.sharding_config
             device_indexes = sharding_config.device_indexes
             if device_indexes is not None and len(device_indexes) > 0:
                 # Enforcing the devices sequence to be consistent with the specified device indexes
-                all_devices = jax.devices()
-                device_dict = {device.id: device for device in all_devices}
+                all_local_devices = jax.local_devices()
+                device_dict = {
+                    device.id: device
+                    for device in all_local_devices
+                }
                 self.devices = []
                 for device_index in device_indexes:
                     device = device_dict[device_index]
                     if device is None:
                         raise KeyError(
                             f"Device index {device_index} not found in "
-                            f"jax.devices() with IDs {list(device_dict.keys())}!"
+                            f"jax.local_devices() with IDs {list(device_dict.keys())}!"
                         )
                     self.devices.append(device)
+                assert len(self.devices) >= sharding_config.total_devices
                 self.devices = self.devices[:sharding_config.total_devices]
             else:
-                self.devices = jax.devices()[:sharding_config.total_devices]
-
+                assert jax.local_device_count(
+                ) >= sharding_config.total_devices
+                self.devices = jax.local_devices()[:sharding_config.
+                                                   total_devices]
         # Initialize the vLLM distribution layer as a single chip environment,
         # we'll swap the model's parallel modules with TPU SPMD equivalents.
         with set_current_vllm_config(self.vllm_config):
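For a single-host, non-Ray launch with pipeline_parallel_size > 1, init_device gives each worker process its own libtpu port and, by default, a single visible chip. A sketch of the environment the code above computes for PP=2; BASE_JAX_PORT actually lives in jax_parallel_state, and 8476 is used here purely as an illustrative value:

def single_host_pp_env(rank: int, pp_world_size: int, base_port: int):
    # Mirrors the defaults set in init_device when no explicit bounds are passed.
    ports = [base_port + i for i in range(pp_world_size)]
    return {
        "TPU_PROCESS_ADDRESSES": ",".join(f"localhost:{p}" for p in ports),
        "TPU_PROCESS_PORT": str(ports[rank]),
        "CLOUD_TPU_TASK_ID": str(rank),
        "TPU_PROCESS_BOUNDS": f"1,{pp_world_size},1",
        "TPU_CHIPS_PER_PROCESS_BOUNDS": "1,1,1",
        "TPU_VISIBLE_CHIPS": str(rank),
    }

print(single_host_pp_env(rank=1, pp_world_size=2, base_port=8476))
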
@@ -146,15 +231,31 @@ def init_device(self):
                 tensor_model_parallel_size=1,
                 pipeline_model_parallel_size=1,
             )
+
+        jax_parallel_state.init_pp_distributed_environment(
+            self.pp_config.ip,
+            self.rank,
+            self.parallel_config.pipeline_parallel_size,
+            self.devices[0],
+            need_pp=self.parallel_config.pipeline_parallel_size > 1)
+
         ensure_kv_transfer_initialized(self.vllm_config)
-        self.model_runner = TPUModelRunner(self.vllm_config, self.devices)
+        self.model_runner = TPUModelRunner(
+            self.vllm_config, self.devices, self.rank, self.rank == 0,
+            self.rank == self.pp_config.pp_world_size - 1)
         logger.info(f"Init worker | "
                     f"rank={self.rank} | "
                     f"node_id={get_node_id()} | "
                     f"is_driver_worker={self.is_driver_worker} | "
                     f"hbm={utils.hbm_usage_gb(self.devices)}GiB")
         vllm_utils.report_usage_stats(self.vllm_config)

+    def initialize_pp_transfer_connect(self):
+        if self.rank == 0:
+            return
+        jax_parallel_state.connect(self.pp_config.prev_worker_ip,
+                                   self.rank - 1)
+
     def determine_available_memory(self) -> int:
         gpu_memory_utilization = self.cache_config.gpu_memory_utilization
         hbm_usage = utils.hbm_usage_bytes(self.devices)
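Startup ordering follows from this hunk: every worker first runs init_device (which initializes the PP distributed environment and builds its runner), then initialize_pp_transfer_connect, so each non-zero rank dials its predecessor while rank 0 is a no-op. A hedged sketch of how a launcher might drive this; the workers list and its construction are hypothetical, not part of this commit:

def bring_up_pipeline(workers):
    # 'workers' is one TPUWorker per pipeline stage, ordered by rank.
    for w in workers:
        w.init_device()                     # PP env, devices, model runner
    for w in workers:
        w.initialize_pp_transfer_connect()  # rank r connects to rank r-1; rank 0 skips
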
@@ -194,14 +295,39 @@ def execute_model(
         # deliberate, temporary compromise for the same reasons outlined in
         # the `get_kv_cache_spec` method.

-        output = self.model_runner.execute_model(scheduler_output)
-
-        # With a connector, the scheduler expects output from all workers
-        # TODO(mrjunwan): Figure out if this is ok after https://github.com/vllm-project/vllm/pull/26866
-        if has_kv_transfer_group():
-            return output
-
-        return output if self.is_driver_worker else None
+        if self.parallel_config.pipeline_parallel_size == 1 or self.rank == 0:
+            intermediate_tensors = None
+        else:
+            # Receive intermediate tensors from the previous PP rank.
+            uuid = self.model_runner.get_uuid_for_jax_transfer(
+                scheduler_output, self.rank - 1, self.step_counter)
+            # TODO: this method might only work for vLLM models; not sure
+            # about JAX models.
+            tensor_spec = self.model_runner.get_intermediate_tensor_spec(
+                scheduler_output.total_num_scheduled_tokens)
+            intermediate_tensors_dict = get_pp_group().recv_tensor_dict(
+                uuid, tensor_spec)
+            intermediate_tensors = JaxIntermediateTensors(
+                intermediate_tensors_dict)
+
+        output = self.model_runner.execute_model(scheduler_output,
+                                                 intermediate_tensors)
+
+        if isinstance(output, JaxIntermediateTensors):
+            assert self.parallel_config.pipeline_parallel_size > 1
+            assert not get_pp_group().is_last_rank
+            # Send intermediate tensors to the next PP rank.
+            uuid = self.model_runner.get_uuid_for_jax_transfer(
+                scheduler_output, self.rank, self.step_counter)
+            get_pp_group().send_tensor_dict(uuid, output.tensors)
+            self.step_counter += 1
+            return None
+        else:
+            self.step_counter += 1
+            # With a connector, the scheduler expects output from all workers
+            # TODO(mrjunwan): Figure out if this is ok after https://github.com/vllm-project/vllm/pull/26866
+            if has_kv_transfer_group():
+                return output
+            return output if self.is_driver_worker else None

     def sample_tokens(self,
                       grammar_output: GrammarOutput) -> ModelRunnerOutput:
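Per engine step, each stage's execute_model therefore does: receive intermediate tensors (unless it is the first rank), run its share of the model, then either send the result downstream and return None (non-last stages) or return the ModelRunnerOutput (last stage); step_counter keeps the transfer uuid unique across steps. A simplified restatement of that control flow, detached from the real transfer and runner APIs (recv/run/send here are placeholders, not this repo's functions):

def pp_step(rank, pp_world_size, recv, run, send):
    # recv() / send(obj) stand in for get_pp_group().recv/send_tensor_dict;
    # run(x) stands in for model_runner.execute_model.
    incoming = None if rank == 0 else recv()
    result = run(incoming)
    if rank < pp_world_size - 1:
        send(result)   # pass intermediate tensors to the next stage
        return None    # only the last stage hands output back to the scheduler
    return result
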
