From 06f88f68489cf8dddd9dcb5d9af486db12ab800d Mon Sep 17 00:00:00 2001 From: Chenyaaang Date: Sat, 8 Nov 2025 01:21:21 +0000 Subject: [PATCH] enable pp on runner Signed-off-by: Chenyaaang --- tpu_inference/runner/compilation_manager.py | 93 +++++++++++++++---- .../runner/persistent_batch_manager.py | 41 +++++++- tpu_inference/runner/tpu_jax_runner.py | 69 ++++++++++++-- 3 files changed, 174 insertions(+), 29 deletions(-) diff --git a/tpu_inference/runner/compilation_manager.py b/tpu_inference/runner/compilation_manager.py index c50c4bc86..f102d06d3 100644 --- a/tpu_inference/runner/compilation_manager.py +++ b/tpu_inference/runner/compilation_manager.py @@ -15,6 +15,8 @@ TPUSupportedSamplingMetadata from tpu_inference.layers.jax.sharding import ShardingAxisName from tpu_inference.logger import init_logger +from tpu_inference.models.jax.jax_intermediate_tensor import \ + JaxIntermediateTensors from tpu_inference.utils import device_array if TYPE_CHECKING: @@ -81,6 +83,8 @@ def capture_model(self) -> None: self._precompile_backbone_with_inputs_embeds() if self.runner.scheduler_config.async_scheduling: self._precompile_substitute_placeholder_token() + if not self.runner.is_last_rank: + return self._precompile_select_from_array() self._precompile_compute_logits() self._precompile_disagg_utils() @@ -120,8 +124,15 @@ def _precompile_input_embeddings_merger(self) -> None: num_tokens=num_tokens, ) - def _precompile_backbone_helper(self, name, *, input_ids, positions, - inputs_embeds) -> None: + def _precompile_backbone_helper(self, + name, + *, + input_ids, + positions, + inputs_embeds, + intermediate_tensors=None, + is_first_rank=True, + is_last_rank=True) -> None: num_tokens = None if input_ids is not None: num_tokens = input_ids.shape[0] @@ -168,10 +179,14 @@ def model_fn_wrapper( inputs_embeds, layer_name_to_kvcache_index, lora_metadata, + intermediate_tensors, + is_first_rank, + is_last_rank, ): kv_caches, hidden_states, _ = self.runner.model_fn( state, kv_caches, input_ids, attention_metadata, inputs_embeds, - layer_name_to_kvcache_index, lora_metadata) + layer_name_to_kvcache_index, lora_metadata, + intermediate_tensors, is_first_rank, is_last_rank) self.runner.kv_caches = kv_caches return hidden_states @@ -189,6 +204,9 @@ def model_fn_wrapper( inputs_embeds, tuple(self.runner.layer_name_to_kvcache_index.items()), lora_metadata, + intermediate_tensors, + is_first_rank, + is_last_rank, num_tokens=num_tokens, ) @@ -238,6 +256,7 @@ def _precompile_substitute_placeholder_token(self) -> None: ) def _precompile_backbone_text_only(self) -> None: + hidden_size = self.runner.model_config.get_hidden_size() for num_tokens in self.runner.num_tokens_paddings: dp_sharding = NamedSharding( self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, ) @@ -247,10 +266,28 @@ def _precompile_backbone_text_only(self) -> None: dp_sharding) positions = self._create_dummy_tensor((num_tokens, ), jnp.int32, dp_sharding) - self._precompile_backbone_helper("backbone", - input_ids=input_ids, - positions=positions, - inputs_embeds=None) + is_first_rank = self.runner.is_first_rank + is_last_rank = self.runner.is_last_rank + if not is_first_rank: + hidden_states = self._create_dummy_tensor( + (num_tokens, hidden_size), jnp.bfloat16) + residual = self._create_dummy_tensor((num_tokens, hidden_size), + jnp.bfloat16) + intermediate_tensors = JaxIntermediateTensors( + tensors={ + "hidden_states": hidden_states, + "residual": residual + }) + else: + intermediate_tensors = None + self._precompile_backbone_helper( + 
f"worker{self.runner.rank} backbone", + input_ids=input_ids, + positions=positions, + inputs_embeds=None, + intermediate_tensors=intermediate_tensors, + is_first_rank=is_first_rank, + is_last_rank=is_last_rank) def _precompile_backbone_with_inputs_embeds(self) -> None: hidden_size = self.runner.model_config.get_hidden_size() @@ -264,10 +301,28 @@ def _precompile_backbone_with_inputs_embeds(self) -> None: else: positions = self._create_dummy_tensor((num_tokens, ), jnp.int32) - self._precompile_backbone_helper("backbone with embeds", - input_ids=None, - positions=positions, - inputs_embeds=inputs_embeds) + is_first_rank = self.runner.is_first_rank + is_last_rank = self.runner.is_last_rank + if not is_first_rank: + hidden_states = self._create_dummy_tensor( + (num_tokens, hidden_size), jnp.bfloat16) + residual = self._create_dummy_tensor((num_tokens, hidden_size), + jnp.bfloat16) + intermediate_tensors = JaxIntermediateTensors( + tensors={ + "hidden_states": hidden_states, + "residual": residual + }) + else: + intermediate_tensors = None + self._precompile_backbone_helper( + f"worker{self.runner.rank} backbone with embeds", + input_ids=None, + positions=positions, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors, + is_first_rank=is_first_rank, + is_last_rank=is_last_rank) def _precompile_select_from_array_helper( self, @@ -333,7 +388,7 @@ def _precompile_select_from_array(self) -> None: PartitionSpec(ShardingAxisName.ATTN_DATA)) dp_size = self.runner.vllm_config.sharding_config.total_dp_size self._precompile_select_from_array_helper( - name="select all logits", + name=f"worker{self.runner.rank} select all logits", source_paddings=self.runner.num_tokens_paddings, indices_paddings=index_paddings, hidden_dim=hsize, @@ -344,7 +399,8 @@ def _precompile_select_from_array(self) -> None: if self.runner.speculative_config: vocab_size = self.runner.model_config.get_vocab_size() self._precompile_select_from_array_helper( - name="select bonus tokens for spec decoding", + name= + f"worker{self.runner.rank} select bonus tokens for spec decoding", source_paddings=self.runner.num_logits_paddings, indices_paddings=self.runner.num_reqs_paddings, hidden_dim=vocab_size, @@ -352,7 +408,8 @@ def _precompile_select_from_array(self) -> None: PartitionSpec(None, "model")), ) self._precompile_select_from_array_helper( - name="select target tokens for spec decoding", + name= + f"worker{self.runner.rank} select target tokens for spec decoding", source_paddings=self.runner.num_logits_paddings, indices_paddings=self.runner.num_logits_paddings, hidden_dim=vocab_size, @@ -375,7 +432,7 @@ def _precompile_compute_logits(self) -> None: np.array([num_reqs], dtype=np.int32)): lora_metadata = self.runner.lora_utils.extract_lora_metadata() self._run_compilation( - "compute_logits", + f"worker{self.runner.rank} compute_logits", self.runner.compute_logits_fn, self.runner.state, hidden_states, @@ -417,7 +474,7 @@ def _precompile_sampling(self) -> None: do_sampling=do_sampling, ) self._run_compilation( - "sample", + f"worker{self.runner.rank} sample", sample, self.runner.rng_params_for_sampling, self.runner.mesh, @@ -458,7 +515,7 @@ def _precompile_gather_logprobs(self) -> None: logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16) token_ids = self._create_dummy_tensor((num_reqs, ), jnp.int32) self._run_compilation( - "gather_logprobs", + f"worker{self.runner.rank} gather_logprobs", self.runner._compute_and_gather_logprobs, logits, token_ids, @@ -510,7 +567,7 @@ def 
_precompile_rejection_sampler(self) -> None: do_sampling=do_sampling) self._run_compilation( - compilation_name, + f"worker{self.runner.rank} {compilation_name}", self.runner.rejection_sampler, draft_token_ids, num_draft_tokens, diff --git a/tpu_inference/runner/persistent_batch_manager.py b/tpu_inference/runner/persistent_batch_manager.py index c325c772a..ddcb34fdf 100644 --- a/tpu_inference/runner/persistent_batch_manager.py +++ b/tpu_inference/runner/persistent_batch_manager.py @@ -1,6 +1,7 @@ from typing import Dict import jax +from vllm.distributed import get_pp_group from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput from tpu_inference.logger import init_logger @@ -173,15 +174,42 @@ def update_states(self, scheduler_output: "VllmSchedulerOutput", ) # Update the states of the running/resumed requests. + is_last_rank = get_pp_group().is_last_rank req_data = scheduler_output.scheduled_cached_reqs for i, req_id in enumerate(req_data.req_ids): req_state = self.requests[req_id] num_computed_tokens = req_data.num_computed_tokens[i] new_block_ids = req_data.new_block_ids[i] resumed_from_preemption = req_data.resumed_from_preemption[i] + num_output_tokens = req_data.num_output_tokens[i] # Update the cached states. req_state.num_computed_tokens = num_computed_tokens + req_index = self.input_batch.req_id_to_index.get(req_id) + + if not is_last_rank: + # When using PP, the scheduler sends the sampled tokens back, + # because there's no direct communication between the first- + # stage worker and the last-stage worker. + new_token_ids = req_data.new_token_ids[i] + # Add the sampled token(s) from the previous step (if any). + # This doesn't include "unverified" tokens like spec tokens. + num_new_tokens = (num_computed_tokens + len(new_token_ids) - + req_state.num_tokens) + if num_new_tokens == 1: + req_state.output_token_ids.append(new_token_ids[-1]) + elif num_new_tokens > 0: + req_state.output_token_ids.extend( + new_token_ids[-num_new_tokens:]) + elif num_output_tokens < len(req_state.output_token_ids): + del req_state.output_token_ids[num_output_tokens:] + if req_index is not None: + end_idx = (self.input_batch.num_prompt_tokens[req_index] + + num_output_tokens) + self.input_batch.num_tokens[req_index] = end_idx + self.input_batch.num_tokens_no_spec[req_index] = end_idx + + # Update the block IDs. if not resumed_from_preemption: if new_block_ids is not None: # Append the new blocks to the existing block IDs. @@ -194,7 +222,6 @@ def update_states(self, scheduler_output: "VllmSchedulerOutput", # Replace the existing block IDs with the new ones. req_state.block_ids = new_block_ids - req_index = self.input_batch.req_id_to_index.get(req_id) if req_index is None: # The request is not in the persistent batch. # The request was either preempted and resumed later, or was not @@ -209,6 +236,18 @@ def update_states(self, scheduler_output: "VllmSchedulerOutput", self.input_batch.block_table.append_row( new_block_ids, req_index) + # For the last rank, we don't need to update the token_ids_cpu + # because the sampled tokens are already cached. + if not is_last_rank: + start_token_index = num_computed_tokens + end_token_index = num_computed_tokens + len(new_token_ids) + self.input_batch.token_ids_cpu[ + req_index, + start_token_index:end_token_index] = new_token_ids + self.input_batch.num_tokens_no_spec[ + req_index] = end_token_index + self.input_batch.num_tokens[req_index] = end_token_index + # Add spec_token_ids to token_ids_cpu. 
spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( req_id, ()) diff --git a/tpu_inference/runner/tpu_jax_runner.py b/tpu_inference/runner/tpu_jax_runner.py index cfb4da82a..1b035501a 100644 --- a/tpu_inference/runner/tpu_jax_runner.py +++ b/tpu_inference/runner/tpu_jax_runner.py @@ -14,12 +14,12 @@ import vllm.envs as envs from flax import nnx from jax.sharding import NamedSharding, PartitionSpec -from torchax.ops.mappings import j2t_dtype +from torchax.ops.mappings import j2t_dtype, t2j_dtype from vllm.config import VllmConfig +from vllm.distributed import get_pp_group from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group) from vllm.forward_context import set_forward_context -from vllm.sequence import IntermediateTensors from vllm.tasks import SupportedTask from vllm.utils.math_utils import cdiv from vllm.v1.core.sched.output import GrammarOutput @@ -45,6 +45,8 @@ ShardingConfigManager) from tpu_inference.logger import init_logger from tpu_inference.models.common.model_loader import get_model +from tpu_inference.models.jax.jax_intermediate_tensor import \ + JaxIntermediateTensors from tpu_inference.models.jax.utils.weight_utils import ( shard_put, transfer_state_with_mappings) from tpu_inference.runner import utils as runner_utils @@ -193,6 +195,9 @@ def __init__( self, vllm_config: VllmConfig, devices: List[Any], + rank: int, + is_first_rank: bool = True, + is_last_rank: bool = True, ): self.vllm_config = vllm_config self.model_config = vllm_config.model_config @@ -211,6 +216,9 @@ def __init__( self.maybe_forbid_compile = runner_utils.ForbidCompile( ) if envs.VLLM_XLA_CHECK_RECOMPILATION else nullcontext() self.dp_size = self.vllm_config.sharding_config.total_dp_size + self.rank = rank + self.is_first_rank = is_first_rank + self.is_last_rank = is_last_rank self._init_random() self._init_mesh() @@ -221,8 +229,10 @@ def __init__( # Delegate functions to specific manager classes. 
self.compilation_manager = CompilationManager(self) - self.speculative_decoding_manager = SpeculativeDecodingManager(self) - self.structured_decoding_manager = StructuredDecodingManager(self) + if self.is_last_rank: + self.speculative_decoding_manager = SpeculativeDecodingManager( + self) + self.structured_decoding_manager = StructuredDecodingManager(self) self.kv_cache_manager = KVCacheManager(self) self.mm_manager = MultiModalManager(self) self.persistent_batch_manager = PersistentBatchManager( @@ -462,12 +472,12 @@ def capture_model(self) -> None: def execute_model( self, scheduler_output: "VllmSchedulerOutput", - intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> ModelRunnerOutput | None: + intermediate_tensors: Optional[JaxIntermediateTensors] = None, + ) -> ModelRunnerOutput | JaxIntermediateTensors | None: if self.execute_model_state is not None: raise RuntimeError("State error: sample_tokens() must be called " "after execute_model() returns None.") - _, output = self._execute_model(scheduler_output) + _, output = self._execute_model(scheduler_output, intermediate_tensors) return output def sample_tokens( @@ -594,7 +604,9 @@ def _update_placeholder(self, def _execute_model( self, scheduler_output: "VllmSchedulerOutput", - ) -> tuple[AttentionMetadata, ModelRunnerOutput | None]: + intermediate_tensors: Optional[JaxIntermediateTensors] = None, + ) -> tuple[AttentionMetadata, JaxIntermediateTensors | ModelRunnerOutput + | None]: self.persistent_batch_manager.update_states( scheduler_output, self.get_mrope_input_positions_fn) if not scheduler_output.total_num_scheduled_tokens: @@ -655,7 +667,6 @@ def _execute_model( scheduler_output) as kv_connector_output: # NOTE(Wenlong): It takes both `input_ids` and `inputs_embeds`, # but one of them would be `None` - (self.kv_caches, hidden_states, aux_hidden_states) = self.model_fn( self.state, @@ -665,8 +676,14 @@ def _execute_model( inputs_embeds, tuple(self.layer_name_to_kvcache_index.items()), lora_metadata, + intermediate_tensors, + self.is_first_rank, + self.is_last_rank, ) - + if not get_pp_group().is_last_rank: + assert isinstance(hidden_states, JaxIntermediateTensors) + hidden_states.kv_connector_output = kv_connector_output + return attn_metadata, hidden_states hidden_states = self._select_from_array_fn(hidden_states, logits_indices) logits = self.compute_logits_fn( @@ -1539,3 +1556,35 @@ def _sync_weights( mappings=mappings, transpose_keys=transpose_keys, shard=shard) + + def get_intermediate_tensor_spec(self, num_tokens: int): + impl = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower() + jax_dtype = t2j_dtype(self.dtype) if impl == "vllm" else self.dtype + num_padded_tokens = runner_utils.get_padded_token_len( + self.num_tokens_paddings, num_tokens) + sharding = NamedSharding(self.mesh, PartitionSpec()) + hidden_size = self.model_config.get_hidden_size() + spec = jax.ShapeDtypeStruct(shape=(num_padded_tokens, hidden_size), + dtype=jax_dtype, + sharding=sharding) + tensor_spec = {"hidden_states": spec, "residual": spec} + return tensor_spec + + def get_uuid_for_jax_transfer(self, + scheduler_output: "VllmSchedulerOutput", + rank: int, step: int) -> int: + ''' + Get a uuid for jax.transfer, here we use the hash of + scheduler_output + counter_step + sender's rank + ''' + scheduler_output_str = "" + if not scheduler_output.num_scheduled_tokens: + scheduler_output_str = "empty_batch" + else: + scheduler_output_str = str( + sorted(scheduler_output.num_scheduled_tokens.items())) + unique_str = f'{scheduler_output_str} {step} 
{rank}' + import hashlib + hasher = hashlib.sha1() + hasher.update(unique_str.encode('utf-8')) + return int.from_bytes(hasher.digest()[:8], 'big')
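
Note for reviewers (not part of the patch): the transfer id introduced in get_uuid_for_jax_transfer is derived only from data every pipeline stage already has (the scheduler output, the step counter, and the sender's rank), so each stage can compute the same id independently. Below is a minimal, self-contained Python sketch that mirrors that hashing; the free-standing name uuid_for_transfer is illustrative only and does not exist in the tree.

    # Standalone sketch of the id derivation used by get_uuid_for_jax_transfer.
    import hashlib

    def uuid_for_transfer(num_scheduled_tokens: dict, step: int, sender_rank: int) -> int:
        # Hash of scheduler output + counter step + sender's rank.
        if not num_scheduled_tokens:
            scheduler_output_str = "empty_batch"
        else:
            scheduler_output_str = str(sorted(num_scheduled_tokens.items()))
        unique_str = f"{scheduler_output_str} {step} {sender_rank}"
        digest = hashlib.sha1(unique_str.encode("utf-8")).digest()
        # Keep the first 8 bytes so the id fits in a 64-bit unsigned integer.
        return int.from_bytes(digest[:8], "big")

    # Both stages derive the same id from the same inputs, with no extra exchange.
    send_id = uuid_for_transfer({"req-0": 16, "req-1": 8}, step=3, sender_rank=0)
    recv_id = uuid_for_transfer({"req-0": 16, "req-1": 8}, step=3, sender_rank=0)
    assert send_id == recv_id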