
Commit b45ad8c

experimental multi-host device_put
1 parent 4d85f58 commit b45ad8c

2 files changed (+5, -21 lines)

tpu_inference/runner/tpu_runner.py

Lines changed: 4 additions & 20 deletions

@@ -1350,21 +1350,11 @@ def _prepare_inputs_dp(self, scheduler_output: "VllmSchedulerOutput"):
         logits_indices_cpu = logits_indices
         seq_lens_cpu = seq_lens
 
-        # First, put arrays on a single device.
-        # JAX will then handle efficient device-to-device transfer.
-        input_tuple_single_device = jax.device_put(
-            (input_ids, positions, block_tables, query_start_loc, seq_lens,
-             logits_indices, request_distribution),
-            device=self.devices[0],
-        )
-
-        print(f'{input_tuple_single_device=}')
-
-        # Then, distribute from that single device to all devices in the mesh.
         (input_ids, positions, block_tables, query_start_loc, seq_lens, logits_indices,
          request_distribution) = device_array(
             self.mesh,
-            input_tuple_single_device,
+            (input_ids, positions, block_tables, query_start_loc, seq_lens,
+             logits_indices, request_distribution),
             sharding=data_parallel_attn_sharding,
         )
         # Async scheduling: substitute placeholder tokens for DP
@@ -1553,16 +1543,10 @@ def _prepare_inputs_non_dp(self, scheduler_output: "VllmSchedulerOutput"):
         seq_lens_cpu = seq_lens
 
 
-        logger.info(f"{self.devices=}")
-        logger.info(f"{jax.local_devices()=}")
-        input_tuple_single_device = jax.device_put(
-            (input_ids, positions, block_tables, query_start_loc, seq_lens,
-             logits_indices, request_distribution),
-            device=jax.local_devices()[0],
-        )
         (input_ids, positions, block_tables, query_start_loc, seq_lens,
          logits_indices, request_distribution) = device_array(
-            self.mesh, input_tuple_single_device)
+            self.mesh, (input_ids, positions, block_tables, query_start_loc, seq_lens,
+            logits_indices, request_distribution))
 
         if self.scheduler_config.async_scheduling and len(
                 token_in_tpu_cur_input_indices) > 0:
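
For reference, both hunks converge on the same pattern: the raw step inputs go straight into device_array in a single host-to-mesh transfer, instead of being staged on self.devices[0] (or jax.local_devices()[0]) and redistributed. A minimal runnable sketch of that single-step pattern, using a toy one-axis mesh; the axis name "data" and the stand-in inputs are illustrative assumptions, not taken from the repository:

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Toy stand-ins for the runner's host-side step inputs (shapes illustrative).
input_ids = np.arange(16, dtype=np.int32)
positions = np.arange(16, dtype=np.int32)

# One-axis mesh over all visible devices; the axis name is an assumption.
mesh = Mesh(np.array(jax.devices()), ("data",))

# Replicated placement, matching the helper's PartitionSpec(None) default.
sharding = NamedSharding(mesh, PartitionSpec(None))

# Single-step transfer: every process contributes its local copy and JAX
# assembles one global jax.Array per input, with no single-device staging.
input_ids, positions = (
    jax.make_array_from_process_local_data(sharding, x)
    for x in (input_ids, positions))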

tpu_inference/utils.py

Lines changed: 1 addition & 1 deletion

@@ -243,7 +243,7 @@ def device_array(mesh: Mesh, *args, sharding=None, **kwargs) -> jax.Array:
     """
     if sharding is None:
         sharding = NamedSharding(mesh, PartitionSpec(None))
-    return jax.device_put(*args, device=sharding, **kwargs)
+    return jax.make_array_from_process_local_data(sharding=sharding, *args, **kwargs)
 
 
 def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
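
This one-line swap is the heart of the commit: jax.device_put copies data from the calling process to devices that process can address, which is awkward on multi-host TPU slices, whereas jax.make_array_from_process_local_data builds one global array out of whatever slice each process holds locally. A rough sketch of the distinction; the "data" axis, the row counts, and the even split across devices are assumptions for illustration:

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), ("data",))
row_sharded = NamedSharding(mesh, PartitionSpec("data"))

# Each process holds only its own rows; the global array spans all hosts.
# (Assumes the global row count divides evenly across the mesh devices.)
rows_per_process = 8
global_shape = (rows_per_process * jax.process_count(),)
local = (np.arange(rows_per_process, dtype=np.int32)
         + rows_per_process * jax.process_index())

# jax.device_put would need the target devices to be addressable from this
# process; this call instead stitches the per-process slices together into
# a single global jax.Array.
x = jax.make_array_from_process_local_data(row_sharded, local, global_shape)

print(x.shape)     # global shape, e.g. (8,) on one host, (16,) on two
print(x.sharding)  # NamedSharding partitioned along the "data" axis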
