From ee9007b959980416047c170a3bba0b12f3978fc2 Mon Sep 17 00:00:00 2001
From: Lihao Ran
Date: Tue, 28 Oct 2025 06:05:58 +0000
Subject: [PATCH] [Spec Decoding] Reduce TPU <-> CPU data transfer

Signed-off-by: Lihao Ran
---
 .../runner/speculative_decoding_manager.py | 17 +++---------
 tpu_inference/runner/tpu_jax_runner.py     | 26 +++++++++++++++----
 tpu_inference/spec_decode/jax/eagle3.py    | 22 +++++++---------
 3 files changed, 34 insertions(+), 31 deletions(-)

diff --git a/tpu_inference/runner/speculative_decoding_manager.py b/tpu_inference/runner/speculative_decoding_manager.py
index 9dbf060d7..5ce18d9a1 100644
--- a/tpu_inference/runner/speculative_decoding_manager.py
+++ b/tpu_inference/runner/speculative_decoding_manager.py
@@ -11,7 +11,6 @@
 from tpu_inference.runner import utils as runner_utils
 from tpu_inference.spec_decode.jax.eagle3 import Eagle3Proposer
-from tpu_inference.utils import device_array
 
 if TYPE_CHECKING:
     from tpu_inference.layers.common.attention_metadata import \
@@ -109,8 +108,7 @@ def propose_eagle3_draft_token_ids(
             assert pad_len >= 0
             next_token_ids += [0] * pad_len
 
-        next_token_ids = device_array(
-            self.runner.mesh, np.array(next_token_ids, dtype=jnp.int32))
+        next_token_ids = np.array(next_token_ids, dtype=jnp.int32)
 
         if spec_decode_metadata is None:
             num_rejected_tokens = None
@@ -123,9 +121,8 @@ def propose_eagle3_draft_token_ids(
             pad_len = self.runner.max_num_reqs - len(num_rejected_tokens)
             num_rejected_tokens += [0] * pad_len
-            num_rejected_tokens = device_array(
-                self.runner.mesh, np.array(num_rejected_tokens,
-                                           dtype=jnp.int32))
+            num_rejected_tokens = np.array(num_rejected_tokens,
+                                           dtype=jnp.int32)
 
         target_hidden_states, input_ids, last_token_indices, attn_metadata = self.runner.drafter.prepare_inputs(
             attn_metadata,
@@ -228,14 +225,6 @@ def get_spec_decode_metadata(
         ])
         padded_num_draft_tokens_cpu = padded_num_draft_tokens
 
-        # CPU -> TPU copy.
-        (padded_num_draft_tokens, padded_draft_token_ids,
-         padded_logits_indices, padded_target_logits_indices,
-         padded_bonus_logits_indices) = device_array(
-             self.runner.mesh,
-             (padded_num_draft_tokens, padded_draft_token_ids,
-              padded_logits_indices, padded_target_logits_indices,
-              padded_bonus_logits_indices))
 
         metadata = SpecDecodeMetadata(
             draft_token_ids=padded_draft_token_ids,
diff --git a/tpu_inference/runner/tpu_jax_runner.py b/tpu_inference/runner/tpu_jax_runner.py
index 988e61920..95afb1d8c 100644
--- a/tpu_inference/runner/tpu_jax_runner.py
+++ b/tpu_inference/runner/tpu_jax_runner.py
@@ -683,7 +683,6 @@ def _prepare_inputs(self, scheduler_output: "VllmSchedulerOutput"):
             spec_decode_metadata = self.speculative_decoding_manager.get_spec_decode_metadata(
                 num_draft_tokens, self.query_start_loc_cpu[1:num_reqs + 1],
                 padded_num_reqs)
-            logits_indices = spec_decode_metadata.final_logits_indices
 
         # Put to device
         sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
@@ -696,10 +695,27 @@ def _prepare_inputs(self, scheduler_output: "VllmSchedulerOutput"):
         query_start_loc_cpu = query_start_loc
         seq_lens_cpu = seq_lens
 
-        (input_ids, positions, block_tables, query_start_loc, seq_lens,
-         logits_indices, request_distribution) = device_array(
-             self.mesh, (input_ids, positions, block_tables, query_start_loc,
-                         seq_lens, logits_indices, request_distribution))
+        if not spec_decode_metadata:
+            (input_ids, positions, block_tables, query_start_loc, seq_lens,
+             logits_indices, request_distribution) = device_array(
+                 self.mesh,
+                 (input_ids, positions, block_tables, query_start_loc,
+                  seq_lens, logits_indices, request_distribution))
+        else:
+            (input_ids, positions, block_tables, query_start_loc, seq_lens,
+             request_distribution, spec_decode_metadata.draft_token_ids,
+             spec_decode_metadata.draft_lengths,
+             spec_decode_metadata.target_logits_indices,
+             spec_decode_metadata.bonus_logits_indices,
+             spec_decode_metadata.final_logits_indices) = device_array(
+                 self.mesh, (input_ids, positions, block_tables,
+                             query_start_loc, seq_lens, request_distribution,
+                             spec_decode_metadata.draft_token_ids,
+                             spec_decode_metadata.draft_lengths,
+                             spec_decode_metadata.target_logits_indices,
+                             spec_decode_metadata.bonus_logits_indices,
+                             spec_decode_metadata.final_logits_indices))
+            logits_indices = spec_decode_metadata.final_logits_indices
 
         if self.lora_config is not None:
             self.lora_utils.set_active_loras(
diff --git a/tpu_inference/spec_decode/jax/eagle3.py b/tpu_inference/spec_decode/jax/eagle3.py
index 86a460943..72995b8b2 100644
--- a/tpu_inference/spec_decode/jax/eagle3.py
+++ b/tpu_inference/spec_decode/jax/eagle3.py
@@ -142,8 +142,8 @@ def prepare_inputs(
         attn_metadata: AttentionMetadata,
         input_ids: jax.Array,
         aux_hidden_states: tuple[jax.Array, ...],
-        next_token_ids: jax.Array,
-        num_rejected_tokens: Optional[jax.Array] = None,
+        next_token_ids: np.array,
+        num_rejected_tokens: Optional[np.array] = None,
     ) -> tuple[jax.Array, jax.Array, jax.Array, AttentionMetadata]:
         """Prepare drafter inputs based on target forward outputs.
@@ -169,8 +169,9 @@ def prepare_inputs(
         num_reqs = self.runner.input_batch.num_reqs
 
         if num_rejected_tokens is None:
-            num_reqs = device_array(self.mesh,
-                                    np.asarray([num_reqs], dtype=jnp.int32))
+            num_reqs, next_token_ids = device_array(
+                self.mesh,
+                (np.asarray([num_reqs], dtype=jnp.int32), next_token_ids))
             # block_tables = device_array(self.mesh, block_tables)
             attn_metadata = replace(attn_metadata,
                                     block_tables=device_array(
@@ -185,10 +186,6 @@ def prepare_inputs(
         seq_lens_cpu = attn_metadata.seq_lens_cpu
         assert query_start_loc_cpu is not None and seq_lens_cpu is not None
 
-        # Rejection-aware path: compute new per-request lengths and token indices.
-        # Convert to host numpy for efficient prefix-sum and repeat ops.
-        nrt_cpu = jax.device_get(num_rejected_tokens).astype("int32")
-
         # query_len_per_req = [q1, q2, ...]
         query_len_per_req = (query_start_loc_cpu[1:] -
                              query_start_loc_cpu[:-1])
@@ -197,7 +194,7 @@ def prepare_inputs(
         # For padded requests, the query length should be 0.
         query_len_per_req[num_reqs:] = 1
         # num_tokens_per_req = [q1 - n1, q2 - n2, ...]
-        num_tokens_per_req = (query_len_per_req - nrt_cpu)
+        num_tokens_per_req = (query_len_per_req - num_rejected_tokens)
 
         assert (num_tokens_per_req >= 0).all(), (
             "num_tokens_per_req must be non-negative")
@@ -232,12 +229,13 @@ def prepare_inputs(
             "constant", constant_values=0)
 
         # Update seq_lens for active requests only: new_seq_lens = s - n.
-        new_seq_lens_cpu = seq_lens_cpu - nrt_cpu
+        new_seq_lens_cpu = seq_lens_cpu - num_rejected_tokens
 
-        query_start_loc, seq_lens, token_indices, num_reqs, block_tables = device_array(
+        query_start_loc, seq_lens, token_indices, num_reqs, block_tables, next_token_ids = device_array(
            self.mesh,
            (new_query_start_loc_cpu, new_seq_lens_cpu, token_indices_cpu,
-            np.asarray([num_reqs], dtype=jnp.int32), block_tables))
+            np.asarray([num_reqs],
+                       dtype=jnp.int32), block_tables, next_token_ids))
 
         attn_metadata = replace(attn_metadata, block_tables=block_tables)
         return self._filter_token_and_prepare_initial_inputs(
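
Note on the pattern this patch applies (not part of the commit itself): instead of copying each speculative-decoding array to the TPU separately, the metadata is kept as host NumPy arrays and folded into one batched host-to-device transfer. The sketch below illustrates the idea with plain jax.device_put on a tuple; it does not use tpu_inference's device_array helper or mesh sharding, and the array names are made up for the example.

    import jax
    import numpy as np

    def batched_host_to_device_example():
        # Stage metadata on the host with NumPy: cheap to build, pad and
        # slice on CPU before anything touches the accelerator.
        draft_token_ids = np.zeros((8, 4), dtype=np.int32)
        draft_lengths = np.zeros((8, ), dtype=np.int32)
        logits_indices = np.arange(8, dtype=np.int32)

        # One device_put over the whole tuple moves everything in a single
        # call, rather than issuing a separate transfer per array.
        (draft_token_ids, draft_lengths,
         logits_indices) = jax.device_put(
             (draft_token_ids, draft_lengths, logits_indices))
        return draft_token_ids, draft_lengths, logits_indices

Grouping the arrays this way keeps the number of host-to-device dispatches constant as more metadata fields are added, which is the same motivation behind folding the spec-decode arrays into the existing device_array call in _prepare_inputs.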