[V1] implement tree sampler for draft token acceptance #22752
Conversation
Code Review
This pull request introduces a tree-based sampler for speculative decoding, which is a significant feature for improving performance. The changes are well-structured, with new components like TreeDrafterParams and TreeRejectionSampler to handle the tree-based logic. The Eagle proposer is also updated to support tree drafting and to output draft probabilities. My review has identified a few critical issues regarding the handling of irregular token trees and the correctness of probability calculations, which should be addressed to ensure the feature works as expected.
Removing review; need green light from Woosuk
@LucasWilkinson, @WoosukKwon, I refactored so that there are even fewer changes in the GPU model runner. Basically, I moved the logic of fetching bonus and target logits into `RejectionSampler.forward`.
```python
bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
sampler_output = self.sampler(
    logits=bonus_logits,
    sampling_metadata=sampling_metadata,
)
bonus_token_ids = sampler_output.sampled_token_ids

# Just like `bonus_logits`, `target_logits` is a new tensor with
# separate storage from the original `logits` tensor. Therefore,
# it is safe to update `target_logits` in place.
target_logits = logits[spec_decode_metadata.target_logits_indices]
```
This was moved into RejectionSampler.forward.
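For reference, a minimal self-contained sketch of the slicing step that this refactor moves out of the model runner, under the assumption that the sampler now receives the full logits tensor plus the spec-decode index metadata; the function name and exact signature here are illustrative, not the PR's API.

```python
import torch

# Illustrative only: mirrors the bonus/target logits slicing shown in the
# snippet above, now performed inside the sampler rather than the model runner.
def split_spec_decode_logits(
    logits: torch.Tensor,
    bonus_logits_indices: torch.Tensor,
    target_logits_indices: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Advanced indexing returns new tensors with separate storage, so both
    # slices can be modified in place without touching the original logits.
    bonus_logits = logits[bonus_logits_indices]
    target_logits = logits[target_logits_indices]
    return bonus_logits, target_logits
```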
Purpose
This continues my work in PR #20401, where I added support for drafting a tree of speculative tokens that are then validated by the target model. In this PR, I add the class that performs rejection sampling for those draft tokens, so that the accepted tokens conform to the target model's output distribution. This class is based on RejectionSampler, but with some key differences necessary to support a tree structure for drafted tokens. I added tests for this new class to verify its correctness.
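To make the acceptance criterion concrete, below is a minimal sketch of the standard per-token rejection-sampling rule that keeps accepted tokens distributed according to the target model; the tree sampler applies this along each root-to-leaf path of the draft tree, but this standalone function (names illustrative) only shows the core accept/reject step.

```python
import torch

def accept_draft_token(
    target_probs: torch.Tensor,  # [vocab_size] target-model probabilities
    draft_probs: torch.Tensor,   # [vocab_size] draft-model probabilities
    draft_token_id: int,
) -> bool:
    """Accept the draft token with probability min(1, p_target / p_draft)."""
    p = target_probs[draft_token_id]
    q = draft_probs[draft_token_id]
    # On rejection, a replacement token would be resampled from the normalized
    # residual distribution max(0, p_target - p_draft); that step is omitted here.
    return bool(torch.rand(()) < torch.clamp(p / q, max=1.0))
```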
In order to make this work with tree-drafted tokens, I also had to "remap" the KV cache immediately after rejection sampling. This is handled by the new `_remap_kv_cache` method I added. It takes the KVs for the N accepted tokens along the chosen branch and copies them into the physical KV cache memory corresponding to the first N contiguous paged KV slots for the draft request (a sketch of this step follows below).
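Below is a hedged sketch of what this remapping can look like, assuming a flat slot-indexed KV cache; the tensor layout, names, and slot bookkeeping are illustrative and not the exact implementation of `_remap_kv_cache`.

```python
import torch

def remap_accepted_branch_kv(
    kv_cache: torch.Tensor,           # [num_slots, 2, num_heads, head_dim] K/V per slot
    accepted_slot_ids: torch.Tensor,  # slots of the accepted branch, root to leaf
    first_slot: int,                  # start of the request's contiguous slot range
) -> None:
    # Copy the KV entries of the N accepted tokens into the first N contiguous
    # paged slots for this request, so the next decode step sees a plain chain.
    n = accepted_slot_ids.numel()
    dest = torch.arange(first_slot, first_slot + n, device=kv_cache.device)
    kv_cache[dest] = kv_cache[accepted_slot_ids]
```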
In addition, I made some refactors to the tree attention parameters to improve readability and performance. I created a new class called `TreeDrafterParams`, which is created during `SpeculativeConfig` initialization and precomputes several properties of the spec token tree (for example, the attention mask and the number of children per level) so that other tree-attention systems can use them without recomputing them.

With this PR, we now have a fully functional tree speculative decoding implementation in V1!
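For illustration, here is a minimal sketch of the kind of precomputation such a class can do, assuming the tree is given as root-to-node paths of child indices (e.g. `(0,)`, `(1,)`, `(0, 0)`); the actual `TreeDrafterParams` fields and tree encoding may differ.

```python
import torch

def precompute_tree_params(tree: list[tuple[int, ...]]):
    depth = max(len(node) for node in tree)
    # Number of draft nodes at each tree level (level 1 = children of the root).
    nodes_per_level = [sum(1 for n in tree if len(n) == d + 1) for d in range(depth)]
    # Tree attention mask: node i may attend to node j iff j is i or an ancestor of i.
    size = len(tree)
    mask = torch.zeros(size, size, dtype=torch.bool)
    for i, a in enumerate(tree):
        for j, b in enumerate(tree):
            mask[i, j] = a[: len(b)] == b
    return nodes_per_level, mask

# Example: a "2 => 3" tree with 2 first-level drafts and 3 second-level drafts.
nodes_per_level, mask = precompute_tree_params([(0,), (1,), (0, 0), (0, 1), (1, 0)])
```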
Test Plan
Automated Testing
New tree rejection sampler tests:
Eagle tree proposer test:
Spec decode e2e test:
Manual Testing
I tested manually with the following tree of draft tokens:
Server
Client
Response
The response is valid and aligns with what the LLM would output if running with standard decoding or spec decoding with FA3.
Benchmarks
I benchmarked using the philschmid/mt-bench and likaixin/InstructCoder datasets. Below are the commands I used:
Server
Client
Using both datasets and batch sizes 1 and 8, I compared the serving benchmark results for 2 spec tokens (baseline), and a 2 => 3 spec token tree that looks like this:
For the baseline, I decided to use the flash attention 3 backend, since it's the only other one that supports spec decoding in V1. Below is a short summary of the results:
mt-bench
In all cases, tree spec decoding increases the mean acceptance length. In the mt-bench dataset, tree spec decoding results in a higher output token throughput for batch sizes 1 and 8. As batch size gets larger, the benefit of more accepted tokens per decode is not enough to compensate for the added cost during forward attention, and tree spec decoding performs worse. It's important to note that the FA3 kernel is much faster than triton attention, which is used in the Tree Attention backend (with a query-on-query attention bias).
To isolate what we are gaining from tree drafts, I also compared chain vs. tree draft structures using the same Tree Attention backend. This eliminates differences in kernel performance between the two configurations.
mt-bench
Spec decoding with the tree of drafts performs better than with a chain of 2 drafts for batch sizes 1, 8, and 16. I did not test further with larger batch sizes, but I assume the gain becomes less pronounced.
A detailed breakdown of the benchmark results can be found in this spreadsheet.
Next Steps