
Conversation

@dongbo910220 (Contributor) commented on Nov 29, 2025

Purpose

This PR optimizes the Medusa speculative decoding proposer by keeping draft token proposals on GPU instead of converting them to CPU lists, eliminating unnecessary GPU-CPU synchronization overhead.
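
For illustration, here is a minimal sketch of the before/after pattern, using hypothetical names (head_logits, propose_as_lists, propose_as_tensor) rather than the actual MedusaProposer code:

    # Minimal sketch, not the actual vLLM code: take the argmax per Medusa head and
    # either sync the result to Python lists (old pattern) or keep it on the GPU (new pattern).
    import torch

    def propose_as_lists(head_logits: list[torch.Tensor]) -> list[list[int]]:
        # Old pattern: stack to (batch_size, num_heads), then .tolist() forces a
        # GPU-to-CPU copy and a host synchronization point.
        draft = torch.stack([logits.argmax(dim=-1) for logits in head_logits], dim=1)
        return draft.tolist()

    def propose_as_tensor(head_logits: list[torch.Tensor]) -> torch.Tensor:
        # New pattern: return the (batch_size, num_heads) tensor directly; it stays
        # on the device and later GPU stages can consume it without a sync.
        return torch.stack([logits.argmax(dim=-1) for logits in head_logits], dim=1)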

Performance Impact

Tested with Mistral-7B-Instruct-v0.2 + Medusa heads on a single GPU:

  • Test configuration: batch_size=32, max_tokens=256, 64 diverse prompts
  • A/B comparison by git checkout:
    • Old version (.tolist() sync): 1184.2 tok/s
    • Optimized version (GPU tensor): 1195.0 tok/s
    • This optimization contributes ~1% throughput improvement by avoiding GPU-CPU synchronization

The optimization reduces latency by eliminating:

  • GPU-to-CPU memory transfer for each batch of draft tokens
  • Python list operations and transpose
  • Synchronization points between GPU and CPU

Draft tokens now remain on GPU throughout the speculative decoding pipeline, which is especially beneficial for larger batch sizes and higher token generation rates.
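
As an independent sanity check (not part of this PR's test plan), a hypothetical micro-benchmark along these lines could isolate the cost of the .tolist() sync; the (32, 4) shape is an assumption mirroring batch_size=32 with a handful of Medusa heads:

    # Hypothetical micro-benchmark; requires a CUDA device.
    import time
    import torch

    def avg_time(fn, iters: int = 100) -> float:
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            fn()
        torch.cuda.synchronize()
        return (time.perf_counter() - start) / iters

    if torch.cuda.is_available():
        draft = torch.randint(0, 32_000, (32, 4), device="cuda")
        host_sync = avg_time(lambda: draft.tolist())  # GPU-to-CPU copy + host sync
        on_device = avg_time(lambda: draft.clone())   # stays on the GPU
        print(f"tolist: {host_sync * 1e6:.1f} us, clone: {on_device * 1e6:.1f} us")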

Test Plan

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@chatgpt-codex-connector commented:

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.

@gemini-code-assist (bot) left a comment:

Code Review

This pull request introduces a valuable optimization to the Medusa speculative decoding proposer by avoiding GPU-CPU synchronization. The change to keep draft token proposals on the GPU as a tensor is well-implemented and the performance gains are clearly articulated. My review includes one suggestion to enhance the robustness of the new sanity check by replacing an assert statement with a standard exception, ensuring the check remains active in production environments.

Comment on lines 53 to 56:

    assert draft_tokens.shape == (
        batch_size,
        num_heads,
    ), f"Expected shape ({batch_size}, {num_heads}), got {draft_tokens.shape}"
@gemini-code-assist (bot) commented (severity: high):

Using assert for this sanity check means it will be disabled when Python is run with the -O (optimize) flag, which is a common practice in production environments. To ensure this important shape validation is always performed and to make the code more robust against unexpected issues, it's better to use a conditional check that raises a ValueError.

if draft_tokens.shape != (batch_size, num_heads):
    raise ValueError(
        f"Expected shape ({batch_size}, {num_heads}), got {draft_tokens.shape}"
    )

Signed-off-by: dongbo910220 <1275604947@qq.com>
Use explicit if/raise instead of assert to ensure shape validation
is not stripped when Python runs with -O flag in production.

Signed-off-by: dongbo910220 <1275604947@qq.com>
@dongbo910220 force-pushed the optimize/medusa-gpu-tensor branch from f1aef53 to fc62f28 on November 29, 2025 at 17:16
@dongbo910220 (Contributor, Author) commented:

Hi @DarkLight1337, could you please help review this PR? Thanks

  target_hidden_states: torch.Tensor,
  sampling_metadata: SamplingMetadata,
- ) -> list[list[int]]:
+ ) -> torch.Tensor:
Collaborator commented:

Are there no consequences to returning a different type from this function? If not, why?

@dongbo910220 (Contributor, Author) replied:

I've thoroughly audited usages of MedusaProposer.propose() to ensure the type change is safe.

  • Single call site: only gpu_model_runner.py::propose_draft_token_ids() invokes it.
  • Downstream paths:
    1. GPU path in _prepare_input_ids: it asserts that _draft_token_ids is a torch.Tensor and scatters it on the GPU. With the old list return this code path couldn't be taken; returning a tensor makes Medusa compatible with that optimization.
    2. Return path via take_draft_token_ids: it explicitly handles both tensor and list, converting tensor→list when needed, so external consumers still receive Python lists.

The type pattern stays consistent: GPU-based proposers (Medusa/Eagle) return torch.Tensor, CPU-based proposers (ngram/suffix) return list[list[int]], and the union type reflects that.
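
For illustration, a minimal sketch of the dual handling described above; the body is hypothetical, not the actual gpu_model_runner.py implementation:

    import torch

    def take_draft_token_ids(draft: "torch.Tensor | list[list[int]] | None"):
        if draft is None:
            return None
        if isinstance(draft, torch.Tensor):
            # Tensor path (Medusa/Eagle): convert once, only at the external boundary.
            return draft.tolist()
        # List path (ngram/suffix proposers): already in the format external consumers expect.
        return draft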
