@@ -1350,21 +1350,11 @@ def _prepare_inputs_dp(self, scheduler_output: "VllmSchedulerOutput"):
13501350 logits_indices_cpu = logits_indices
13511351 seq_lens_cpu = seq_lens
13521352
1353- # First, put arrays on a single device.
1354- # JAX will then handle efficient device-to-device transfer.
1355- input_tuple_single_device = jax .device_put (
1356- (input_ids , positions , block_tables , query_start_loc , seq_lens ,
1357- logits_indices , request_distribution ),
1358- device = self .devices [0 ],
1359- )
1360-
1361- print (f'{ input_tuple_single_device = } ' )
1362-
1363- # Then, distribute from that single device to all devices in the mesh.
13641353 (input_ids , positions , block_tables , query_start_loc , seq_lens , logits_indices ,
13651354 request_distribution ) = device_array (
13661355 self .mesh ,
1367- input_tuple_single_device ,
1356+ (input_ids , positions , block_tables , query_start_loc , seq_lens ,
1357+ logits_indices , request_distribution ),
13681358 sharding = data_parallel_attn_sharding ,
13691359 )
13701360 # Async scheduling: substitute placeholder tokens for DP
@@ -1553,16 +1543,10 @@ def _prepare_inputs_non_dp(self, scheduler_output: "VllmSchedulerOutput"):
15531543 seq_lens_cpu = seq_lens
15541544
15551545
1556- logger .info (f"{ self .devices = } " )
1557- logger .info (f"{ jax .local_devices ()= } " )
1558- input_tuple_single_device = jax .device_put (
1559- (input_ids , positions , block_tables , query_start_loc , seq_lens ,
1560- logits_indices , request_distribution ),
1561- device = jax .local_devices ()[0 ],
1562- )
15631546 (input_ids , positions , block_tables , query_start_loc , seq_lens ,
15641547 logits_indices , request_distribution ) = device_array (
1565- self .mesh , input_tuple_single_device )
1548+ self .mesh , (input_ids , positions , block_tables , query_start_loc , seq_lens ,
1549+ logits_indices , request_distribution ))
15661550
15671551 if self .scheduler_config .async_scheduling and len (
15681552 token_in_tpu_cur_input_indices ) > 0 :
0 commit comments