diff --git a/vllm/v1/spec_decode/medusa.py b/vllm/v1/spec_decode/medusa.py
index 12b903ccaca9..c565af753f46 100644
--- a/vllm/v1/spec_decode/medusa.py
+++ b/vllm/v1/spec_decode/medusa.py
@@ -38,16 +38,24 @@ def propose(
         self,
         target_hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> list[list[int]]:
+    ) -> torch.Tensor:
         # Generate blocks and compute logits
         blocks = self.model(target_hidden_states)
         logits = self.model.compute_logits(blocks)
-        # Get draft tokens and transpose the result
-        # TODO(woosuk): OPTIMIZATION: Return GPU tensor without GPU-CPU
-        # synchronization.
-        draft_tokens = [logit.argmax(dim=-1).tolist() for logit in logits]
-        return [list(row) for row in zip(*draft_tokens)]
+        # Compute argmax for each Medusa head and stack into a single tensor
+        # Shape: [batch_size, num_heads]
+        draft_tokens = torch.stack(
+            [logit.argmax(dim=-1) for logit in logits], dim=1)
+
+        # Always validate shape (avoid assert being stripped by -O)
+        batch_size = target_hidden_states.shape[0]
+        num_heads = len(logits)
+        if draft_tokens.shape != (batch_size, num_heads):
+            raise ValueError(
+                f"Expected shape ({batch_size}, {num_heads}), "
+                f"got {draft_tokens.shape}"
+            )
+
+        return draft_tokens

     def load_model(self, target_model: nn.Module) -> None:
         from vllm.compilation.backends import set_model_tag
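
A minimal standalone sketch of the stack-and-validate logic introduced above, run on dummy per-head logits; the batch size, vocabulary size, and head count below are illustrative and not taken from the diff:

# Illustrative sketch (not part of the diff): shows how stacking the per-head
# argmax yields a single [batch_size, num_heads] tensor of draft tokens.
import torch

batch_size, vocab_size, num_heads = 4, 32000, 3

# One logits tensor per Medusa head, each of shape [batch_size, vocab_size].
logits = [torch.randn(batch_size, vocab_size) for _ in range(num_heads)]

# Argmax per head, stacked along dim=1 -> [batch_size, num_heads],
# mirroring the new return value of propose().
draft_tokens = torch.stack([logit.argmax(dim=-1) for logit in logits], dim=1)

# Same explicit shape check as in the diff (an if/raise survives python -O,
# unlike an assert statement).
if draft_tokens.shape != (batch_size, num_heads):
    raise ValueError(
        f"Expected shape ({batch_size}, {num_heads}), got {draft_tokens.shape}"
    )

print(draft_tokens.shape)  # torch.Size([4, 3])

The result stays on the same device as the logits, so no GPU-CPU synchronization is forced at this point; a caller that needs Python lists can still call .tolist() on the returned tensor.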