Commit a62e29f

Medusa: keep draft proposals on GPU
Signed-off-by: dongbo910220 <1275604947@qq.com>
1 parent b9d0504 commit a62e29f

1 file changed: +14 additions, −6 deletions
vllm/v1/spec_decode/medusa.py

Lines changed: 14 additions & 6 deletions
@@ -38,16 +38,24 @@ def propose(
         self,
         target_hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> list[list[int]]:
+    ) -> torch.Tensor:
         # Generate blocks and compute logits
         blocks = self.model(target_hidden_states)
         logits = self.model.compute_logits(blocks)
 
-        # Get draft tokens and transpose the result
-        # TODO(woosuk): OPTIMIZATION: Return GPU tensor without GPU-CPU
-        # synchronization.
-        draft_tokens = [logit.argmax(dim=-1).tolist() for logit in logits]
-        return [list(row) for row in zip(*draft_tokens)]
+        # Compute argmax for each Medusa head and stack into a single tensor
+        # Shape: [batch_size, num_heads]
+        draft_tokens = torch.stack([logit.argmax(dim=-1) for logit in logits], dim=1)
+
+        # Sanity check to catch any unexpected shape mismatch early
+        batch_size = target_hidden_states.shape[0]
+        num_heads = len(logits)
+        assert draft_tokens.shape == (
+            batch_size,
+            num_heads,
+        ), f"Expected shape ({batch_size}, {num_heads}), got {draft_tokens.shape}"
+
+        return draft_tokens
 
     def load_model(self, target_model: nn.Module) -> None:
         from vllm.compilation.backends import set_model_tag
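The point of the change: the old code called .tolist() on each head's argmax, which copies the draft tokens to the host and forces a GPU-CPU synchronization for every proposal step; the new code stacks the per-head argmax tensors along dim=1 so the [batch_size, num_heads] result stays on the GPU. Below is a minimal, self-contained sketch (not vLLM code) contrasting the two behaviors; the toy logits, batch_size, vocab_size, and num_heads values are illustrative assumptions, not values from the commit.

import torch

# Illustrative toy setup (assumed values, not from the commit).
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size, vocab_size, num_heads = 4, 32, 3

# One logits tensor per Medusa head, each of shape [batch_size, vocab_size].
logits = [torch.randn(batch_size, vocab_size, device=device) for _ in range(num_heads)]

# Old behavior: .tolist() copies each argmax result to the CPU (a device sync),
# then the per-head lists are transposed into per-request rows on the host.
draft_tokens_old = [logit.argmax(dim=-1).tolist() for logit in logits]
per_request_old = [list(row) for row in zip(*draft_tokens_old)]

# New behavior: argmax results stay on the device and are stacked along dim=1,
# yielding a [batch_size, num_heads] tensor with no host synchronization.
draft_tokens_new = torch.stack([logit.argmax(dim=-1) for logit in logits], dim=1)
assert draft_tokens_new.shape == (batch_size, num_heads)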
