From a62e29f08ba9522e534884f04d75624cf667c688 Mon Sep 17 00:00:00 2001
From: dongbo910220 <1275604947@qq.com>
Date: Sat, 29 Nov 2025 18:06:57 +0800
Subject: [PATCH 1/2] Medusa: keep draft proposals on GPU

Signed-off-by: dongbo910220 <1275604947@qq.com>
---
 vllm/v1/spec_decode/medusa.py | 20 ++++++++++++++------
 1 file changed, 14 insertions(+), 6 deletions(-)

diff --git a/vllm/v1/spec_decode/medusa.py b/vllm/v1/spec_decode/medusa.py
index 12b903ccaca9..15afe7bc3c84 100644
--- a/vllm/v1/spec_decode/medusa.py
+++ b/vllm/v1/spec_decode/medusa.py
@@ -38,16 +38,24 @@ def propose(
         self,
         target_hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> list[list[int]]:
+    ) -> torch.Tensor:
         # Generate blocks and compute logits
         blocks = self.model(target_hidden_states)
         logits = self.model.compute_logits(blocks)
 
-        # Get draft tokens and transpose the result
-        # TODO(woosuk): OPTIMIZATION: Return GPU tensor without GPU-CPU
-        # synchronization.
-        draft_tokens = [logit.argmax(dim=-1).tolist() for logit in logits]
-        return [list(row) for row in zip(*draft_tokens)]
+        # Compute argmax for each Medusa head and stack into a single tensor
+        # Shape: [batch_size, num_heads]
+        draft_tokens = torch.stack([logit.argmax(dim=-1) for logit in logits], dim=1)
+
+        # Sanity check to catch any unexpected shape mismatch early
+        batch_size = target_hidden_states.shape[0]
+        num_heads = len(logits)
+        assert draft_tokens.shape == (
+            batch_size,
+            num_heads,
+        ), f"Expected shape ({batch_size}, {num_heads}), got {draft_tokens.shape}"
+
+        return draft_tokens
 
     def load_model(self, target_model: nn.Module) -> None:
         from vllm.compilation.backends import set_model_tag

From fc62f28224e9351b79b89c378dead7e524b1d6b6 Mon Sep 17 00:00:00 2001
From: dongbo910220 <1275604947@qq.com>
Date: Sat, 29 Nov 2025 18:32:29 +0800
Subject: [PATCH 2/2] Replace assert with explicit ValueError for shape
 validation

Use explicit if/raise instead of assert to ensure shape validation is
not stripped when Python runs with the -O flag in production.

Signed-off-by: dongbo910220 <1275604947@qq.com>
---
 vllm/v1/spec_decode/medusa.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/vllm/v1/spec_decode/medusa.py b/vllm/v1/spec_decode/medusa.py
index 15afe7bc3c84..c565af753f46 100644
--- a/vllm/v1/spec_decode/medusa.py
+++ b/vllm/v1/spec_decode/medusa.py
@@ -47,13 +47,13 @@ def propose(
         # Shape: [batch_size, num_heads]
         draft_tokens = torch.stack([logit.argmax(dim=-1) for logit in logits], dim=1)
 
-        # Sanity check to catch any unexpected shape mismatch early
+        # Always validate shape (avoid assert being stripped by -O)
         batch_size = target_hidden_states.shape[0]
         num_heads = len(logits)
-        assert draft_tokens.shape == (
-            batch_size,
-            num_heads,
-        ), f"Expected shape ({batch_size}, {num_heads}), got {draft_tokens.shape}"
+        if draft_tokens.shape != (batch_size, num_heads):
+            raise ValueError(
+                f"Expected shape ({batch_size}, {num_heads}), got {draft_tokens.shape}"
+            )
 
         return draft_tokens
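
A minimal standalone sketch (not part of the patch series) of the two
behaviors the series relies on; the head count, batch size, and vocab
size below are hypothetical, only the mechanics are illustrated.
torch.stack of per-head argmaxes keeps the result on the device,
whereas .tolist() forces a device-to-host copy and synchronization;
and an assert is compiled out under `python -O`, while an explicit
raise always runs.

    import torch

    # Hypothetical sizes: 4 Medusa-style heads, batch of 8, vocab of 1000.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logits = [torch.randn(8, 1000, device=device) for _ in range(4)]

    # Old path: per-head argmax followed by .tolist() copies each result
    # to the host and synchronizes with the device on every call.
    cpu_tokens = [logit.argmax(dim=-1).tolist() for logit in logits]

    # New path: stack the per-head argmaxes into one [batch_size, num_heads]
    # tensor that never leaves the device.
    gpu_tokens = torch.stack([logit.argmax(dim=-1) for logit in logits], dim=1)

    # Validation: an assert statement is stripped when Python runs with -O,
    # so an explicit if/raise is the check that survives in production.
    if gpu_tokens.shape != (8, 4):
        raise ValueError(f"Expected shape (8, 4), got {tuple(gpu_tokens.shape)}")

The same reasoning explains the order of the series: patch 1 removes
the per-step device sync, and patch 2 hardens the resulting shape
check so it cannot be optimized away.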