20 changes: 14 additions & 6 deletions vllm/v1/spec_decode/medusa.py
@@ -38,16 +38,24 @@ def propose(
         self,
         target_hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> list[list[int]]:
+    ) -> torch.Tensor:
Collaborator:

Are there no consequences to returning a different type from this function? If not, why?

Contributor Author:

I've thoroughly audited usages of MedusaProposer.propose() to ensure the type change is safe.

  • Single call site: only gpu_model_runner.py::propose_draft_token_ids() invokes it.
  • Downstream paths (a rough sketch of both follows below):
    1. GPU path in _prepare_input_ids: it asserts that _draft_token_ids is a torch.Tensor and scatters it on the GPU. With the old list return this path could not be taken; returning a tensor makes Medusa compatible with that optimization.
    2. Return path take_draft_token_ids: it explicitly handles both tensor and list, converting tensor→list when needed, so external consumers still receive Python lists.

The type pattern stays consistent: GPU-based proposers (Medusa/Eagle) return torch.Tensor, CPU-based proposers (ngram/suffix) return list[list[int]], and the union type reflects that.
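
To make the two downstream paths concrete, here is a minimal sketch of the handling described above. Only the function names _prepare_input_ids and take_draft_token_ids come from this thread; the helper signatures, buffer layout, and parameter names below are illustrative assumptions, not the actual gpu_model_runner.py code.

```python
from __future__ import annotations

import torch


def scatter_draft_tokens_on_gpu(
    input_ids: torch.Tensor,        # flat token buffer resident on the GPU
    positions: torch.Tensor,        # target slots, same device, one per draft token
    draft_token_ids: torch.Tensor,  # [batch_size, num_heads] tensor from propose()
) -> None:
    # Path 1 (sketch): because the draft tokens are already a GPU tensor,
    # they can be written into the input buffer without a GPU-CPU sync.
    input_ids[positions] = draft_token_ids.flatten()


def to_token_lists(
    draft_token_ids: torch.Tensor | list[list[int]],
) -> list[list[int]]:
    # Path 2 (sketch): external consumers still receive Python lists;
    # tensors are converted at this boundary, lists pass through unchanged.
    if isinstance(draft_token_ids, torch.Tensor):
        return draft_token_ids.tolist()
    return draft_token_ids
```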

         # Generate blocks and compute logits
         blocks = self.model(target_hidden_states)
         logits = self.model.compute_logits(blocks)

-        # Get draft tokens and transpose the result
-        # TODO(woosuk): OPTIMIZATION: Return GPU tensor without GPU-CPU
-        # synchronization.
-        draft_tokens = [logit.argmax(dim=-1).tolist() for logit in logits]
-        return [list(row) for row in zip(*draft_tokens)]
+        # Compute argmax for each Medusa head and stack into a single tensor
+        # Shape: [batch_size, num_heads]
+        draft_tokens = torch.stack([logit.argmax(dim=-1) for logit in logits], dim=1)
+
+        # Always validate shape (avoid assert being stripped by -O)
+        batch_size = target_hidden_states.shape[0]
+        num_heads = len(logits)
+        if draft_tokens.shape != (batch_size, num_heads):
+            raise ValueError(
+                f"Expected shape ({batch_size}, {num_heads}), got {draft_tokens.shape}"
+            )
+
+        return draft_tokens

     def load_model(self, target_model: nn.Module) -> None:
         from vllm.compilation.backends import set_model_tag
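
For reference, a self-contained snippet (with made-up sizes, purely for illustration) showing that the stacked per-head argmax produces the same tokens as the old list-and-transpose version, in the [batch_size, num_heads] layout the new shape check enforces:

```python
import torch

batch_size, vocab_size, num_heads = 4, 32, 3  # illustrative sizes only
logits = [torch.randn(batch_size, vocab_size) for _ in range(num_heads)]

# Old behaviour: per-head argmax lists, transposed into per-request rows.
per_head = [logit.argmax(dim=-1).tolist() for logit in logits]
old_result = [list(row) for row in zip(*per_head)]

# New behaviour: stack the per-head argmaxes into one [batch_size, num_heads] tensor.
new_result = torch.stack([logit.argmax(dim=-1) for logit in logits], dim=1)

assert new_result.shape == (batch_size, num_heads)
assert new_result.tolist() == old_result  # same tokens, different container
```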