diff --git a/tests/e2e/test_pipeline_parallel.py b/tests/e2e/test_pipeline_parallel.py
new file mode 100644
index 000000000..c358e2ebf
--- /dev/null
+++ b/tests/e2e/test_pipeline_parallel.py
@@ -0,0 +1,265 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import os
+import time
+from dataclasses import asdict
+
+import pytest
+from vllm import LLM, EngineArgs, SamplingParams
+
+
+@pytest.fixture
+def model_name():
+    """Use Llama 3.1 8B as the test model since it supports PP in the JAX model impl."""
+    return "meta-llama/Llama-3.1-8B-Instruct"
+
+
+@pytest.fixture
+def test_prompts():
+    """Simple test prompts for pipeline parallelism testing."""
+    return [
+        "Hello, my name is",
+        "The capital of France is",
+        "The colors of the rainbow are",
+        "The future of AI is",
+        "The president of the United States is",
+        "How many players are on a standard soccer team?",
+        "In Greek mythology, who is the god of the sea?",
+        "What is the capital of Australia?",
+        "What is the largest planet in our solar system?",
+        "Who developed the theory of general relativity?",
+    ]
+
+
+@pytest.fixture
+def sampling_params():
+    """Standard sampling parameters for testing."""
+    return SamplingParams(
+        temperature=0.0,
+        max_tokens=32,
+        ignore_eos=True,
+        logprobs=1,
+    )
+
+
+def _run_inference_with_config(model_name: str,
+                               test_prompts: list,
+                               sampling_params: SamplingParams,
+                               tensor_parallel_size: int = 1,
+                               pipeline_parallel_size: int = 1,
+                               additional_config: dict = {},
+                               kv_cache_dtype: str = "auto",
+                               enable_prefix_caching: bool = False) -> list:
+    """Helper function to run inference with the specified configuration."""
+
+    # Build engine args with the same settings as examples/offline_inference.py
+    engine_args = EngineArgs(
+        model=model_name,
+        max_model_len=128,
+        tensor_parallel_size=tensor_parallel_size,
+        pipeline_parallel_size=pipeline_parallel_size,
+        gpu_memory_utilization=0.95,
+        max_num_batched_tokens=128,
+        max_num_seqs=16,
+        enable_prefix_caching=enable_prefix_caching,
+        additional_config=additional_config,
+        kv_cache_dtype=kv_cache_dtype,
+    )
+
+    engine_args_dict = asdict(engine_args)
+    llm = LLM(**engine_args_dict)
+
+    try:
+        outputs = llm.generate(test_prompts, sampling_params)
+        return outputs
+    finally:
+        del llm
+        # Wait for TPUs to be released
+        time.sleep(5)
+
+
+@pytest.mark.skip(reason="PP is not fully enabled.")
+def test_pipeline_parallelism_jax_model(
+    model_name: str,
+    test_prompts: list,
+    sampling_params: SamplingParams,
+):
+    """
+    Test that pipeline parallelism works on JAX models.
+
+    Equivalent to:
+    python examples/offline_inference.py --tensor_parallel_size=1 --pipeline_parallel_size=2
+    """
+    # Run with pipeline parallelism enabled
+    outputs = _run_inference_with_config(
+        model_name=model_name,
+        test_prompts=test_prompts,
+        sampling_params=sampling_params,
+        tensor_parallel_size=1,
+        pipeline_parallel_size=2,
+    )
+
+    # Verify we got outputs for all prompts
+    assert len(outputs) == len(test_prompts)
+
+    # Verify each output has generated text
+    for output in outputs:
+        assert len(output.outputs) > 0
+        assert len(output.outputs[0].text.strip()) > 0
+
+    print(
+        f"✓ Pipeline Parallelism Jax model test passed with {len(outputs)} outputs"
+    )
+
+
+@pytest.mark.skip(reason="PP is not fully enabled.")
+def test_pipeline_parallelism_vllm_model(
+    model_name: str,
+    test_prompts: list,
+    sampling_params: SamplingParams,
+):
+    """
+    Test that pipeline parallelism works on vLLM models and that it composes
+    with tensor parallelism.
+
+    Equivalent to:
+    MODEL_IMPL_TYPE=vllm python examples/offline_inference.py --tensor_parallel_size=2 --pipeline_parallel_size=2
+    """
+
+    os.environ['MODEL_IMPL_TYPE'] = 'vllm'
+    # Run with pipeline parallelism enabled
+    outputs = _run_inference_with_config(
+        model_name=model_name,
+        test_prompts=test_prompts,
+        sampling_params=sampling_params,
+        tensor_parallel_size=2,
+        pipeline_parallel_size=2,
+    )
+
+    # Verify we got outputs for all prompts
+    assert len(outputs) == len(test_prompts)
+
+    # Verify each output has generated text
+    for output in outputs:
+        assert len(output.outputs) > 0
+        assert len(output.outputs[0].text.strip()) > 0
+
+    print(
+        f"✓ Pipeline Parallelism vLLM model test passed with {len(outputs)} outputs"
+    )
+
+
+@pytest.mark.skip(reason="PP is not fully enabled.")
+def test_pipeline_parallelism_jax_model_correctness(
+    model_name: str,
+    test_prompts: list,
+    sampling_params: SamplingParams,
+):
+    """
+    Test that pipeline parallelism produces consistent results compared to a baseline.
+    This test compares outputs from a single-device run with pipeline parallel runs
+    to ensure correctness, including log probabilities.
+    """
+    os.environ['SKIP_JAX_PRECOMPILE'] = '1'
+    os.environ['VLLM_XLA_CHECK_RECOMPILATION'] = '0'
+
+    # Use a small, fixed set of prompts for correctness testing
+    small_prompts = test_prompts[:10]
+
+    # Run baseline (no PP)
+    baseline_outputs = _run_inference_with_config(
+        model_name=model_name,
+        test_prompts=small_prompts,
+        sampling_params=sampling_params,
+        tensor_parallel_size=1,
+        pipeline_parallel_size=1,
+    )
+
+    # Run with pipeline parallelism enabled
+    pp_outputs = _run_inference_with_config(
+        model_name=model_name,
+        test_prompts=small_prompts,
+        sampling_params=sampling_params,
+        tensor_parallel_size=1,
+        pipeline_parallel_size=2,
+    )
+
+    # Compare outputs. In theory they should be identical with greedy sampling;
+    # in practice there may be small differences, but overall the outputs should
+    # be very similar.
+
+    # An example:
+    # prompt: What is the capital of Australia?
+    # Both answers below should be acceptable.
+    # The capital of Australia is Canberra. It is located in the Australian Capital Territory (ACT) and is home to many
+    # Canberra is the capital of Australia. It is located in the Australian Capital Territory (ACT) and is home to
+    assert len(baseline_outputs) == len(pp_outputs)
+
+    text_matches = 0
+    text_mismatches = 0
+    logprob_mismatches = 0
+    max_logprob_diff = 0.0
+
+    for i, (baseline, pp_result) in enumerate(zip(baseline_outputs,
+                                                  pp_outputs)):
+        baseline_text = baseline.outputs[0].text.strip()
+        pp_text = pp_result.outputs[0].text.strip()
+
+        # Check text output
+        if baseline_text == pp_text:
+            text_matches += 1
+        else:
+            text_mismatches += 1
+            print(f"Text mismatch found in prompt {i}:")
+            print(f"  Baseline: {baseline_text}")
+            print(f"  Pipeline Parallel: {pp_text}")
+
+        # Check log probabilities
+        baseline_logprobs = baseline.outputs[0].logprobs
+        pp_logprobs = pp_result.outputs[0].logprobs
+        if baseline_logprobs is not None and pp_logprobs is not None:
+            # Compare log probabilities for each token
+            assert len(baseline_logprobs) == len(pp_logprobs), \
+                f"Logprobs length mismatch: {len(baseline_logprobs)} vs {len(pp_logprobs)}"
+            for token_idx, (base_lp, pp_lp) in enumerate(
+                    zip(baseline_logprobs, pp_logprobs)):
+                # Get the top logprob value for the selected token
+                if base_lp and pp_lp:
+                    # Get the top token's logprob from each
+                    base_top_token = list(base_lp.keys())[0]
+                    pp_top_token = list(pp_lp.keys())[0]
+
+                    base_logprob_val = base_lp[base_top_token].logprob
+                    pp_logprob_val = pp_lp[pp_top_token].logprob
+
+                    # Calculate absolute difference
+                    diff = abs(base_logprob_val - pp_logprob_val)
+                    max_logprob_diff = max(max_logprob_diff, diff)
+
+                    # Allow small numerical differences (e.g., 1e-3)
+                    if diff > 1e-3:
+                        logprob_mismatches += 1
+                        print(
+                            f"Logprob mismatch in prompt {i}, token {token_idx}:"
+                        )
+                        print(
+                            f"  Baseline token: {base_top_token}, logprob: {base_logprob_val:.6f}"
+                        )
+                        print(
+                            f"  PP token: {pp_top_token}, logprob: {pp_logprob_val:.6f}"
+                        )
+                        print(f"  Difference: {diff:.6f}")
+
+    print("✓ Correctness test results:")
+    print(f"  Text: {text_matches} matches, {text_mismatches} mismatches")
+    print(f"  Max logprob difference: {max_logprob_diff:.6e}")
+    print(f"  Significant logprob mismatches (>1e-3): {logprob_mismatches}")
+
+    # Allow for some variance due to potential numerical differences,
+    # but most outputs should match with greedy sampling
+    text_match_rate = text_matches / len(baseline_outputs)
+    assert text_match_rate >= 0.9, f"Text match rate {text_match_rate:.2%} is too low"
+
+    # Log probabilities should be very close (allow small numerical errors)
+    assert max_logprob_diff < 1, f"Max logprob difference {max_logprob_diff} is too large"
diff --git a/tests/runner/test_persistent_batch_manager.py b/tests/runner/test_persistent_batch_manager.py
new file mode 100644
index 000000000..dd2d6454a
--- /dev/null
+++ b/tests/runner/test_persistent_batch_manager.py
@@ -0,0 +1,70 @@
+import unittest
+from unittest.mock import MagicMock
+
+import numpy as np
+
+from tpu_inference.runner.persistent_batch_manager import \
+    PersistentBatchManager
+
+
+class TestPersistentBatchManager(unittest.TestCase):
+
+    def test_update_states_pp_non_last_rank(self):
+        """
+        Test update_states when the current PP rank is not the last rank.
+
+        This test verifies that when new tokens are received from the scheduler,
+        the internal state of the PersistentBatchManager (including request
+        states and the input batch) is correctly updated.
+ """ + + req_id = 101 + initial_output_tokens = [10, 20] + + req_state = MagicMock() + req_state.num_tokens = 2 + req_state.output_token_ids = list(initial_output_tokens) + + requests = {req_id: req_state} + + input_batch = MagicMock() + input_batch.req_id_to_index = {req_id: 0} + input_batch.num_prompt_tokens = np.array([2], dtype=np.int32) + input_batch.token_ids_cpu = np.zeros((1, 10), dtype=np.int32) + input_batch.num_tokens = np.array([2], dtype=np.int32) + input_batch.num_tokens_no_spec = np.array([2], dtype=np.int32) + input_batch.num_reqs = 1 + + encoder_cache = MagicMock() + model_config = MagicMock() + + manager = PersistentBatchManager(requests, + input_batch, + encoder_cache, + False, + model_config, + is_last_rank=False) + + scheduler_output = MagicMock() + req_data = MagicMock() + req_data.req_ids = [req_id] + req_data.num_computed_tokens = [2] + new_token_id = [30] + req_data.new_token_ids = [new_token_id] + req_data.new_block_ids = [None] + req_data.resumed_from_preemption = [False] + req_data.num_output_tokens = [len(initial_output_tokens) + 1] + scheduler_output.scheduled_cached_reqs = req_data + scheduler_output.scheduled_spec_decode_tokens = {} + + manager.update_states(scheduler_output, None) + + expected_output_token_ids = initial_output_tokens + new_token_id + self.assertEqual(req_state.output_token_ids, expected_output_token_ids) + + np.testing.assert_array_equal( + manager.input_batch.token_ids_cpu[0, 2:3], + np.array(new_token_id, dtype=np.int32)) + + self.assertEqual(manager.input_batch.num_tokens[0], 3) + self.assertEqual(manager.input_batch.num_tokens_no_spec[0], 3) diff --git a/tpu_inference/models/common/model_loader.py b/tpu_inference/models/common/model_loader.py index 238b123b3..4a95c8503 100644 --- a/tpu_inference/models/common/model_loader.py +++ b/tpu_inference/models/common/model_loader.py @@ -236,7 +236,9 @@ def get_flax_model( hidden_states_sharding, # aux hidden states ), donate_argnums=2, # 0 is graphdef, 1 is state, 2 is kv_cache - static_argnums=7, #7 is layer_name_to_kvcache_index + static_argnums=( + 7, 10, 11 + ), #7 is layer_name_to_kvcache_index, 10 is is_first_rank, 11 is is_last_rank ) def run_model(graphdef, state, *args): model = nnx.merge(graphdef, state) diff --git a/tpu_inference/runner/compilation_manager.py b/tpu_inference/runner/compilation_manager.py index 0f8d1cfde..3bb590366 100644 --- a/tpu_inference/runner/compilation_manager.py +++ b/tpu_inference/runner/compilation_manager.py @@ -15,6 +15,8 @@ from tpu_inference.layers.jax.sample.sampling_metadata import \ TPUSupportedSamplingMetadata from tpu_inference.logger import init_logger +from tpu_inference.models.jax.jax_intermediate_tensor import \ + JaxIntermediateTensors from tpu_inference.utils import device_array if TYPE_CHECKING: @@ -81,6 +83,8 @@ def capture_model(self) -> None: self._precompile_backbone_with_inputs_embeds() if self.runner.scheduler_config.async_scheduling: self._precompile_substitute_placeholder_token() + if not self.runner.is_last_rank: + return self._precompile_select_from_array() self._precompile_compute_logits() self._precompile_disagg_utils() @@ -120,8 +124,15 @@ def _precompile_input_embeddings_merger(self) -> None: num_tokens=num_tokens, ) - def _precompile_backbone_helper(self, name, *, input_ids, positions, - inputs_embeds) -> None: + def _precompile_backbone_helper(self, + name, + *, + input_ids, + positions, + inputs_embeds, + intermediate_tensors=None, + is_first_rank=True, + is_last_rank=True) -> None: num_tokens = None if input_ids is not 
None: num_tokens = input_ids.shape[0] @@ -181,10 +192,14 @@ def model_fn_wrapper( inputs_embeds, layer_name_to_kvcache_index, lora_metadata, + intermediate_tensors, + is_first_rank, + is_last_rank, ): kv_caches, hidden_states, _ = self.runner.model_fn( state, kv_caches, input_ids, attention_metadata, inputs_embeds, - positions, layer_name_to_kvcache_index, lora_metadata) + positions, layer_name_to_kvcache_index, lora_metadata, + intermediate_tensors, is_first_rank, is_last_rank) self.runner.kv_caches = kv_caches return hidden_states @@ -207,6 +222,9 @@ def model_fn_wrapper( inputs_embeds, tuple(self.runner.layer_name_to_kvcache_index.items()), lora_metadata, + intermediate_tensors, + is_first_rank, + is_last_rank, num_tokens=num_tokens, ) @@ -257,6 +275,7 @@ def _precompile_substitute_placeholder_token(self) -> None: ) def _precompile_backbone_text_only(self) -> None: + hidden_size = self.runner.model_config.get_hidden_size() for num_tokens in self.runner.num_tokens_paddings: dp_sharding = NamedSharding( self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, ) @@ -266,10 +285,28 @@ def _precompile_backbone_text_only(self) -> None: dp_sharding) positions = self._create_dummy_tensor((num_tokens, ), jnp.int32, dp_sharding) - self._precompile_backbone_helper("backbone", - input_ids=input_ids, - positions=positions, - inputs_embeds=None) + is_first_rank = self.runner.is_first_rank + is_last_rank = self.runner.is_last_rank + if not is_first_rank: + hidden_states = self._create_dummy_tensor( + (num_tokens, hidden_size), jnp.bfloat16) + residual = self._create_dummy_tensor((num_tokens, hidden_size), + jnp.bfloat16) + intermediate_tensors = JaxIntermediateTensors( + tensors={ + "hidden_states": hidden_states, + "residual": residual + }) + else: + intermediate_tensors = None + self._precompile_backbone_helper( + f"worker{self.runner.rank} backbone", + input_ids=input_ids, + positions=positions, + inputs_embeds=None, + intermediate_tensors=intermediate_tensors, + is_first_rank=is_first_rank, + is_last_rank=is_last_rank) def _precompile_backbone_with_inputs_embeds(self) -> None: hidden_size = self.runner.model_config.get_hidden_size() @@ -283,10 +320,28 @@ def _precompile_backbone_with_inputs_embeds(self) -> None: else: positions = self._create_dummy_tensor((num_tokens, ), jnp.int32) - self._precompile_backbone_helper("backbone with embeds", - input_ids=None, - positions=positions, - inputs_embeds=inputs_embeds) + is_first_rank = self.runner.is_first_rank + is_last_rank = self.runner.is_last_rank + if not is_first_rank: + hidden_states = self._create_dummy_tensor( + (num_tokens, hidden_size), jnp.bfloat16) + residual = self._create_dummy_tensor((num_tokens, hidden_size), + jnp.bfloat16) + intermediate_tensors = JaxIntermediateTensors( + tensors={ + "hidden_states": hidden_states, + "residual": residual + }) + else: + intermediate_tensors = None + self._precompile_backbone_helper( + f"worker{self.runner.rank} backbone with embeds", + input_ids=None, + positions=positions, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors, + is_first_rank=is_first_rank, + is_last_rank=is_last_rank) def _precompile_select_from_array_helper( self, @@ -354,7 +409,7 @@ def _precompile_select_from_array(self) -> None: self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, None)) dp_size = self.runner.vllm_config.sharding_config.total_dp_size self._precompile_select_from_array_helper( - name="select all logits", + name=f"worker{self.runner.rank} select all logits", 
source_paddings=self.runner.num_tokens_paddings, indices_paddings=index_paddings, hidden_dim=hsize, @@ -365,7 +420,8 @@ def _precompile_select_from_array(self) -> None: if self.runner.speculative_config: vocab_size = self.runner.model_config.get_vocab_size() self._precompile_select_from_array_helper( - name="select bonus tokens for spec decoding", + name= + f"worker{self.runner.rank} select bonus tokens for spec decoding", source_paddings=self.runner.num_logits_paddings, indices_paddings=self.runner.num_reqs_paddings, hidden_dim=vocab_size, @@ -373,7 +429,8 @@ def _precompile_select_from_array(self) -> None: PartitionSpec(None, "model")), ) self._precompile_select_from_array_helper( - name="select target tokens for spec decoding", + name= + f"worker{self.runner.rank} select target tokens for spec decoding", source_paddings=self.runner.num_logits_paddings, indices_paddings=self.runner.num_logits_paddings, hidden_dim=vocab_size, @@ -396,7 +453,7 @@ def _precompile_compute_logits(self) -> None: np.array([num_reqs], dtype=np.int32)): lora_metadata = self.runner.lora_utils.extract_lora_metadata() self._run_compilation( - "compute_logits", + f"worker{self.runner.rank} compute_logits", self.runner.compute_logits_fn, self.runner.state, hidden_states, @@ -438,7 +495,7 @@ def _precompile_sampling(self) -> None: do_sampling=do_sampling, ) self._run_compilation( - "sample", + f"worker{self.runner.rank} sample", sample, self.runner.rng_params_for_sampling, self.runner.mesh, @@ -479,7 +536,7 @@ def _precompile_gather_logprobs(self) -> None: logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16) token_ids = self._create_dummy_tensor((num_reqs, ), jnp.int32) self._run_compilation( - "gather_logprobs", + f"worker{self.runner.rank} gather_logprobs", self.runner._compute_and_gather_logprobs, logits, token_ids, @@ -531,7 +588,7 @@ def _precompile_rejection_sampler(self) -> None: do_sampling=do_sampling) self._run_compilation( - compilation_name, + f"worker{self.runner.rank} {compilation_name}", self.runner.rejection_sampler, draft_token_ids, num_draft_tokens, diff --git a/tpu_inference/runner/persistent_batch_manager.py b/tpu_inference/runner/persistent_batch_manager.py index c8c093315..9d8be7048 100644 --- a/tpu_inference/runner/persistent_batch_manager.py +++ b/tpu_inference/runner/persistent_batch_manager.py @@ -14,12 +14,13 @@ class PersistentBatchManager: def __init__(self, requests: Dict[str, CachedRequestState], input_batch: InputBatch, encoder_cache: Dict[str, 'jax.Array'], - uses_mrope: bool, model_config): + uses_mrope: bool, model_config, is_last_rank: bool): self.requests = requests self.input_batch = input_batch self.encoder_cache = encoder_cache self.uses_mrope = uses_mrope self.model_config = model_config + self.is_last_rank = is_last_rank def _reorder_batch(self, scheduler_output: "VllmSchedulerOutput") -> int: """ Reorder the sheduled requests to RPA kernel friendly distribution @@ -179,9 +180,35 @@ def update_states(self, scheduler_output: "VllmSchedulerOutput", num_computed_tokens = req_data.num_computed_tokens[i] new_block_ids = req_data.new_block_ids[i] resumed_from_preemption = req_data.resumed_from_preemption[i] + num_output_tokens = req_data.num_output_tokens[i] # Update the cached states. 
req_state.num_computed_tokens = num_computed_tokens + req_index = self.input_batch.req_id_to_index.get(req_id) + + if not self.is_last_rank: + # When using PP, the scheduler sends the sampled tokens back, + # because there's no direct communication between the first- + # stage worker and the last-stage worker. + new_token_ids = req_data.new_token_ids[i] + # Add the sampled token(s) from the previous step (if any). + # This doesn't include "unverified" tokens like spec tokens. + num_new_tokens = (num_computed_tokens + len(new_token_ids) - + req_state.num_tokens) + if num_new_tokens == 1: + req_state.output_token_ids.append(new_token_ids[-1]) + elif num_new_tokens > 0: + req_state.output_token_ids.extend( + new_token_ids[-num_new_tokens:]) + elif num_output_tokens < len(req_state.output_token_ids): + del req_state.output_token_ids[num_output_tokens:] + if req_index is not None: + end_idx = (self.input_batch.num_prompt_tokens[req_index] + + num_output_tokens) + self.input_batch.num_tokens[req_index] = end_idx + self.input_batch.num_tokens_no_spec[req_index] = end_idx + + # Update the block IDs. if not resumed_from_preemption: if new_block_ids is not None: # Append the new blocks to the existing block IDs. @@ -194,7 +221,6 @@ def update_states(self, scheduler_output: "VllmSchedulerOutput", # Replace the existing block IDs with the new ones. req_state.block_ids = new_block_ids - req_index = self.input_batch.req_id_to_index.get(req_id) if req_index is None: # The request is not in the persistent batch. # The request was either preempted and resumed later, or was not @@ -209,6 +235,18 @@ def update_states(self, scheduler_output: "VllmSchedulerOutput", self.input_batch.block_table.append_row( new_block_ids, req_index) + # For the last rank, we don't need to update the token_ids_cpu + # because the sampled tokens are already cached. + if not self.is_last_rank: + start_token_index = num_computed_tokens + end_token_index = num_computed_tokens + len(new_token_ids) + self.input_batch.token_ids_cpu[ + req_index, + start_token_index:end_token_index] = new_token_ids + self.input_batch.num_tokens_no_spec[ + req_index] = end_token_index + self.input_batch.num_tokens[req_index] = end_token_index + # Add spec_token_ids to token_ids_cpu. 
spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( req_id, ()) diff --git a/tpu_inference/runner/tpu_runner.py b/tpu_inference/runner/tpu_runner.py index 256798093..60b3c5174 100644 --- a/tpu_inference/runner/tpu_runner.py +++ b/tpu_inference/runner/tpu_runner.py @@ -15,12 +15,12 @@ from flax import nnx from jax.experimental import mesh_utils from jax.sharding import NamedSharding, PartitionSpec -from torchax.ops.mappings import j2t_dtype +from torchax.ops.mappings import j2t_dtype, t2j_dtype from vllm.config import VllmConfig +from vllm.distributed import get_pp_group from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group) from vllm.forward_context import set_forward_context -from vllm.sequence import IntermediateTensors from vllm.tasks import SupportedTask from vllm.utils.math_utils import cdiv from vllm.v1.core.sched.output import GrammarOutput @@ -48,6 +48,8 @@ TPUSupportedSamplingMetadata from tpu_inference.logger import init_logger from tpu_inference.models.common.model_loader import get_model +from tpu_inference.models.jax.jax_intermediate_tensor import \ + JaxIntermediateTensors from tpu_inference.models.jax.utils.weight_utils import ( shard_put, transfer_state_with_mappings) from tpu_inference.runner import utils as runner_utils @@ -243,6 +245,9 @@ def __init__( self.maybe_forbid_compile = runner_utils.ForbidCompile( ) if envs.VLLM_XLA_CHECK_RECOMPILATION else nullcontext() self.dp_size = self.vllm_config.sharding_config.total_dp_size + self.rank = rank + self.is_first_rank = is_first_rank + self.is_last_rank = is_last_rank self._init_random() self._init_mesh() @@ -253,13 +258,15 @@ def __init__( # Delegate functions to specific manager classes. self.compilation_manager = CompilationManager(self) - self.speculative_decoding_manager = SpeculativeDecodingManager(self) - self.structured_decoding_manager = StructuredDecodingManager(self) + if self.is_last_rank: + self.speculative_decoding_manager = SpeculativeDecodingManager( + self) + self.structured_decoding_manager = StructuredDecodingManager(self) self.kv_cache_manager = KVCacheManager(self) self.mm_manager = MultiModalManager(self) self.persistent_batch_manager = PersistentBatchManager( self.requests, self.input_batch, self.encoder_cache, - self.uses_mrope, self.model_config) + self.uses_mrope, self.model_config, self.is_last_rank) self.lora_utils = LoraUtils(self) cache_config = self.cache_config @@ -555,12 +562,12 @@ def capture_model(self) -> None: def execute_model( self, scheduler_output: "VllmSchedulerOutput", - intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> ModelRunnerOutput | None: + intermediate_tensors: Optional[JaxIntermediateTensors] = None, + ) -> ModelRunnerOutput | JaxIntermediateTensors | None: if self.execute_model_state is not None: raise RuntimeError("State error: sample_tokens() must be called " "after execute_model() returns None.") - _, output = self._execute_model(scheduler_output) + _, output = self._execute_model(scheduler_output, intermediate_tensors) return output def sample_tokens( @@ -686,7 +693,9 @@ def _update_placeholder(self, def _execute_model( self, scheduler_output: "VllmSchedulerOutput", - ) -> tuple[AttentionMetadata, ModelRunnerOutput | None]: + intermediate_tensors: Optional[JaxIntermediateTensors] = None, + ) -> tuple[AttentionMetadata, JaxIntermediateTensors | ModelRunnerOutput + | None]: self.persistent_batch_manager.update_states( scheduler_output, self.get_mrope_input_positions_fn) if not 
scheduler_output.total_num_scheduled_tokens: @@ -764,7 +773,6 @@ def _execute_model( scheduler_output) as kv_connector_output: # NOTE(Wenlong): It takes both `input_ids` and `inputs_embeds`, # but one of them would be `None` - (self.kv_caches, hidden_states, aux_hidden_states) = self.model_fn( self.state, @@ -775,8 +783,14 @@ def _execute_model( input_positions, tuple(self.layer_name_to_kvcache_index.items()), lora_metadata, + intermediate_tensors, + self.is_first_rank, + self.is_last_rank, ) - + if not get_pp_group().is_last_rank: + assert isinstance(hidden_states, JaxIntermediateTensors) + hidden_states.kv_connector_output = kv_connector_output + return attn_metadata, hidden_states hidden_states = self._select_from_array_fn(hidden_states, logits_indices) logits = self.compute_logits_fn( @@ -1706,3 +1720,35 @@ def _sync_weights( mappings=mappings, transpose_keys=transpose_keys, shard=shard) + + def get_intermediate_tensor_spec(self, num_tokens: int): + impl = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower() + jax_dtype = t2j_dtype(self.dtype) if impl == "vllm" else self.dtype + num_padded_tokens = runner_utils.get_padded_token_len( + self.num_tokens_paddings, num_tokens) + sharding = NamedSharding(self.mesh, PartitionSpec()) + hidden_size = self.model_config.get_hidden_size() + spec = jax.ShapeDtypeStruct(shape=(num_padded_tokens, hidden_size), + dtype=jax_dtype, + sharding=sharding) + tensor_spec = {"hidden_states": spec, "residual": spec} + return tensor_spec + + def get_uuid_for_jax_transfer(self, + scheduler_output: "VllmSchedulerOutput", + rank: int, step: int) -> int: + ''' + Get a uuid for jax.transfer, here we use the hash of + scheduler_output + counter_step + sender's rank + ''' + scheduler_output_str = "" + if not scheduler_output.num_scheduled_tokens: + scheduler_output_str = "empty_batch" + else: + scheduler_output_str = str( + sorted(scheduler_output.num_scheduled_tokens.items())) + unique_str = f'{scheduler_output_str} {step} {rank}' + import hashlib + hasher = hashlib.sha1() + hasher.update(unique_str.encode('utf-8')) + return int.from_bytes(hasher.digest()[:8], 'big')
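
Note on the transfer-ID scheme in `get_uuid_for_jax_transfer`: because sender and receiver stages both see the same scheduler output, step counter, and sender rank, each side can presumably derive the same 64-bit ID locally instead of negotiating one. The standalone sketch below mirrors the hashing logic from the patch; the function and variable names (`derive_transfer_id`, the example request IDs) are illustrative, not part of the codebase, and only `hashlib` is used.

```python
import hashlib


def derive_transfer_id(num_scheduled_tokens: dict, step: int, sender_rank: int) -> int:
    """Hash the scheduled-token map, step counter, and sender rank into a stable 64-bit ID."""
    if not num_scheduled_tokens:
        scheduler_output_str = "empty_batch"
    else:
        scheduler_output_str = str(sorted(num_scheduled_tokens.items()))
    unique_str = f"{scheduler_output_str} {step} {sender_rank}"
    digest = hashlib.sha1(unique_str.encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "big")


# Two ranks compute the ID independently and still agree, even if they
# enumerate the per-request token counts in different orders, because the
# items are sorted before hashing.
scheduled = {"req-1": 3, "req-0": 17}
assert derive_transfer_id(scheduled, step=42, sender_rank=0) == \
    derive_transfer_id(dict(sorted(scheduled.items())), step=42, sender_rank=0)
```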
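Similarly, the precompile path fabricates `hidden_states`/`residual` buffers for non-first PP ranks, and `get_intermediate_tensor_spec` describes the same pair of `(padded_num_tokens, hidden_size)` arrays for the actual stage-to-stage hand-off. A minimal sketch of that layout, assuming bfloat16 activations and example dimensions (in the runner these come from `num_tokens_paddings`, `model_config.get_hidden_size()`, and the configured dtype, and the spec additionally carries a `NamedSharding`); a plain dict stands in for `JaxIntermediateTensors` here:

```python
import jax
import jax.numpy as jnp

# Example dimensions; assumptions for illustration only.
num_padded_tokens, hidden_size = 128, 4096

# Shape/dtype spec for the tensors exchanged between pipeline stages,
# mirroring the {"hidden_states", "residual"} layout in the patch.
spec = {
    name: jax.ShapeDtypeStruct((num_padded_tokens, hidden_size), jnp.bfloat16)
    for name in ("hidden_states", "residual")
}

# A non-first rank receives buffers matching this spec at runtime; during
# precompilation the compilation manager fabricates equivalent dummies.
dummy_intermediates = {
    name: jnp.zeros(s.shape, s.dtype) for name, s in spec.items()
}
print({name: (v.shape, v.dtype) for name, v in dummy_intermediates.items()})
```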