[V1][Spec Decode] Optimize Medusa proposer to avoid GPU-CPU sync #29723
Conversation
Code Review
This pull request introduces a valuable optimization to the Medusa speculative decoding proposer by avoiding GPU-CPU synchronization. The change to keep draft token proposals on the GPU as a tensor is well-implemented and the performance gains are clearly articulated. My review includes one suggestion to enhance the robustness of the new sanity check by replacing an assert statement with a standard exception, ensuring the check remains active in production environments.
vllm/v1/spec_decode/medusa.py (Outdated)

```python
assert draft_tokens.shape == (
    batch_size,
    num_heads,
), f"Expected shape ({batch_size}, {num_heads}), got {draft_tokens.shape}"
```
Using assert for this sanity check means it will be disabled when Python is run with the -O (optimize) flag, which is a common practice in production environments. To ensure this important shape validation is always performed and to make the code more robust against unexpected issues, it's better to use a conditional check that raises a ValueError.
```python
if draft_tokens.shape != (batch_size, num_heads):
    raise ValueError(
        f"Expected shape ({batch_size}, {num_heads}), got {draft_tokens.shape}"
    )
```

Signed-off-by: dongbo910220 <1275604947@qq.com>

Use explicit if/raise instead of assert to ensure shape validation is not stripped when Python runs with the -O flag in production.

Signed-off-by: dongbo910220 <1275604947@qq.com>
Force-pushed from f1aef53 to fc62f28.
Hi @DarkLight1337, could you please help review this PR? Thanks!
```diff
     target_hidden_states: torch.Tensor,
     sampling_metadata: SamplingMetadata,
-) -> list[list[int]]:
+) -> torch.Tensor:
```
Are there no consequences to returning a different type from this function? If not, why?
I've thoroughly audited usages of MedusaProposer.propose() to ensure the type change is safe.
- Single call site: only gpu_model_runner.py::propose_draft_token_ids() invokes it.
- Downstream paths:
  - GPU path in _prepare_input_ids: it asserts _draft_token_ids is a torch.Tensor and scatters it on the GPU. With the old list return this code path could not be taken; returning a tensor makes Medusa compatible with that optimization.
  - Return path take_draft_token_ids: it explicitly handles both tensor and list, converting tensor to list when needed (sketched below), so external consumers still receive Python lists.
- The type pattern stays consistent: GPU-based proposers (Medusa/Eagle) return torch.Tensor; CPU-based proposers (ngram/suffix) return list[list[int]]; the union type reflects that.
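For illustration only, a minimal sketch of the boundary conversion described in the last bullet, assuming a (batch_size, num_heads) tensor from GPU-based proposers; the function body here is an assumption for exposition, not the actual vLLM implementation:

```python
import torch


def take_draft_token_ids(
    draft: torch.Tensor | list[list[int]] | None,
) -> list[list[int]] | None:
    # Illustrative sketch: GPU-based proposers (Medusa/Eagle) hand over a
    # (batch_size, num_heads) tensor, CPU-based proposers a list of lists.
    # The single .tolist() here is the only point where a GPU-CPU sync
    # happens for external consumers, which keep receiving Python lists.
    if draft is None:
        return None
    if isinstance(draft, torch.Tensor):
        return draft.tolist()
    return draft
```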
Purpose
This PR optimizes the Medusa speculative decoding proposer by keeping draft token proposals on GPU instead of converting them to CPU lists, eliminating unnecessary GPU-CPU synchronization overhead.
Performance Impact
Tested with Mistral-7B-Instruct-v0.2 + Medusa heads on a single GPU:
Throughput (.tolist() sync): 1184.2 tok/s

The optimization reduces latency by eliminating the GPU-CPU synchronization in the proposal step.
Draft tokens now remain on GPU throughout the speculative decoding pipeline, which is especially beneficial for larger batch sizes and higher token generation rates.
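As a rough illustration of the on-device approach, under the assumption that each Medusa head yields a logits tensor of shape (batch_size, vocab_size); the function name and shapes are assumptions for the sketch, not the exact code from this PR:

```python
import torch


def propose_on_gpu(head_logits: list[torch.Tensor]) -> torch.Tensor:
    # Greedy pick per head, each result of shape (batch_size,), kept on GPU.
    per_head_tokens = [logits.argmax(dim=-1) for logits in head_logits]
    # Stack into a single (batch_size, num_heads) tensor instead of calling
    # .tolist() per head, so no GPU-CPU synchronization occurs here.
    return torch.stack(per_head_tokens, dim=1)
```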
Test Plan
Test Result