Conversation

@TheEpicDolphin (Collaborator) commented Aug 12, 2025

Purpose

Continuing with my work in this PR: #20401, where I added support for drafting a tree of speculative tokens that are then validated by the target model. In this PR, I add the class that performs rejection sampling on those draft tokens so that the accepted tokens conform to the target model's output distribution. This class is based on RejectionSampler, but with some key differences needed to support a tree structure of drafted tokens. I added tests for the new class to verify its correctness.
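
As a point of reference, the acceptance rule currently used is greedy (see Next Steps): a drafted node is kept only if all of its ancestors were kept and its token matches the target model's greedy choice for that position, and the deepest such branch wins. Below is a minimal, illustrative sketch of that rule; the function and variable names are hypothetical and this is not the actual TreeRejectionSampler code:

```python
def greedy_tree_accept(parent, draft_tokens, target_argmax):
    """Pick the longest branch of drafted tokens accepted under greedy verification.

    parent[i]        : index of node i's parent, or -1 for children of the root.
    draft_tokens[i]  : token id drafted at node i.
    target_argmax[i] : the target model's argmax for the position node i fills,
                       i.e. the token the target model itself would emit there.
    Nodes are assumed to be listed in level order (parents before children).
    """
    n = len(draft_tokens)
    accepted = [False] * n
    depth = [0] * n
    best_leaf, best_depth = -1, 0
    for i in range(n):
        parent_ok = parent[i] == -1 or accepted[parent[i]]
        if parent_ok and draft_tokens[i] == target_argmax[i]:
            accepted[i] = True
            depth[i] = 1 + (depth[parent[i]] if parent[i] != -1 else 0)
            if depth[i] > best_depth:
                best_leaf, best_depth = i, depth[i]
    # Walk back up from the deepest accepted node to recover the branch.
    branch = []
    node = best_leaf
    while node != -1:
        branch.append(draft_tokens[node])
        node = parent[node]
    return branch[::-1]


# Example on the 2 => 3 tree used in this PR (token ids are made up):
parent        = [-1, -1, 0, 0, 0, 1, 1, 1]
draft_tokens  = [11, 12, 21, 22, 23, 31, 32, 33]
target_argmax = [11, 99, 99, 22, 99, 99, 99, 99]
assert greedy_tree_accept(parent, draft_tokens, target_argmax) == [11, 22]
```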

To make this work with tree-drafted tokens, I also had to "remap" the KV cache immediately after rejection sampling. This is handled by the new _remap_kv_cache method: it takes the KVs for the N tokens of the accepted branch and copies them into the physical KV cache memory corresponding to the first N contiguous paged KV slots for the draft request.
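
Conceptually, the remap looks something like the following. This is a simplified sketch under assumed tensor shapes and hypothetical names, not the actual _remap_kv_cache implementation:

```python
import torch


def remap_accepted_kvs(kv_cache: torch.Tensor,
                       slot_mapping: torch.Tensor,
                       accepted_offsets: torch.Tensor) -> None:
    """Copy the accepted branch's KV entries into the request's first N slots.

    kv_cache         : [num_slots, ...] flattened view of the paged KV memory.
    slot_mapping     : physical slot index for each drafted position of the request.
    accepted_offsets : positions (within the drafted tree) of the N accepted
                       tokens, ordered from shallowest to deepest node.
    """
    src_slots = slot_mapping[accepted_offsets]
    dst_slots = slot_mapping[: accepted_offsets.numel()]
    # Advanced indexing gathers the source entries into a new tensor first,
    # so overlapping source and destination slots do not clobber each other.
    kv_cache[dst_slots] = kv_cache[src_slots]
```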

In addition, I made some refactors to the tree attention parameters to improve readability and performance. I created a new class, TreeDrafterParams, which is constructed during SpeculativeConfig initialization and precomputes several properties of the spec token tree (for example, the attention mask and the number of children per level) so that other tree-attention components can use them without recomputing them.
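
A rough sketch of the kind of precomputation this enables; the class, field, and method names here are illustrative, not the real TreeDrafterParams attributes:

```python
from dataclasses import dataclass, field

import torch


@dataclass
class TreeDrafterParamsSketch:
    # Each tuple is a path of child indices from the root, in the same format
    # as the speculative_token_tree config string.
    tree_paths: list

    num_nodes_per_level: list = field(init=False)
    attn_mask: torch.Tensor = field(init=False)

    def __post_init__(self):
        depth = max(len(p) for p in self.tree_paths)
        self.num_nodes_per_level = [
            sum(1 for p in self.tree_paths if len(p) == level + 1)
            for level in range(depth)
        ]
        # Node i may attend to node j iff j lies on i's path from the root
        # (i.e. j's path is a prefix of i's path, including i itself).
        n = len(self.tree_paths)
        mask = torch.zeros(n, n, dtype=torch.bool)
        for i, pi in enumerate(self.tree_paths):
            for j, pj in enumerate(self.tree_paths):
                mask[i, j] = tuple(pi[: len(pj)]) == tuple(pj)
        self.attn_mask = mask


params = TreeDrafterParamsSketch(
    tree_paths=[(0,), (1,), (0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]
)
print(params.num_nodes_per_level)  # [2, 6]
```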

With this PR, we now have a fully functional tree speculative decoding implementation in V1!

Test Plan

Automated Testing

New tree rejection sampler tests:

```
(py312conda) bash-5.1$ pytest tests/v1/sample/test_tree_rejection_sampler.py
=============================================================================== test session starts ===============================================================================
platform linux -- Python 3.12.9, pytest-8.4.1, pluggy-1.6.0
rootdir: /data/users/gdelfin/gitrepos/vllm
configfile: pyproject.toml
plugins: anyio-4.9.0, asyncio-1.1.0
asyncio: mode=Mode.STRICT, asyncio_default_fixture_loop_scope=None, asyncio_default_test_loop_scope=function
collected 9 items

tests/v1/sample/test_tree_rejection_sampler.py .........                                                                                                                    [100%]

================================================================================ 9 passed in 6.51s ================================================================================
```

Eagle tree proposer test:

```
(py312conda) bash-5.1$ pytest tests/v1/spec_decode/test_eagle.py -k test_propose_tree
=============================================================================== test session starts ===============================================================================
platform linux -- Python 3.12.9, pytest-8.4.1, pluggy-1.6.0
rootdir: /data/users/gdelfin/gitrepos/vllm
configfile: pyproject.toml
plugins: anyio-4.9.0, asyncio-1.1.0
asyncio: mode=Mode.STRICT, asyncio_default_fixture_loop_scope=None, asyncio_default_test_loop_scope=function
collected 47 items / 43 deselected / 4 selected

tests/v1/spec_decode/test_eagle.py ....                                                                                                                                     [100%]

======================================================================== 4 passed, 43 deselected in 10.82s ========================================================================
```

Spec decode e2e test:

```
(py312conda) bash-5.1$ pytest tests/v1/e2e/test_spec_decode.py -k test_tree_eagle_correctness
=============================================================================== test session starts ===============================================================================
platform linux -- Python 3.12.9, pytest-8.4.1, pluggy-1.6.0
rootdir: /data/users/gdelfin/gitrepos/vllm
configfile: pyproject.toml
plugins: anyio-4.9.0, asyncio-1.1.0
asyncio: mode=Mode.STRICT, asyncio_default_fixture_loop_scope=None, asyncio_default_test_loop_scope=function
collected 21 items / 13 deselected / 8 selected

tests/v1/e2e/test_spec_decode.py ........                                                                                                                                   [100%]

================================================================================ warnings summary =================================================================================
tests/v1/e2e/test_spec_decode.py: 16 warnings
  /home/gdelfin/.conda/envs/py312conda/lib/python3.12/multiprocessing/popen_fork.py:66: DeprecationWarning: This process (pid=2475023) is multi-threaded, use of fork() may lead to deadlocks in the child.
    self.pid = os.fork()

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
============================================================== 8 passed, 13 deselected, 16 warnings in 382.20s (0:06:22) ==============================================================
```

Manual Testing

I tested manually with the following tree of draft tokens:

```
ROOT
├── 0
│   ├── 0
│   ├── 1
│   └── 2
└── 1
    ├── 0
    ├── 1
    └── 2
```
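
The speculative_token_tree string used below encodes this tree as a list of root-to-node paths of child indices; for example, (1, 2) is the third child of the root's second child. A small illustrative snippet (not part of the PR) to decode it:

```python
import ast

tree_str = "[(0,), (1,), (0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]"
paths = ast.literal_eval(tree_str)

# Group the paths by depth to recover the levels of the diagram above.
by_level = {}
for path in paths:
    by_level.setdefault(len(path), []).append(path)
for level, nodes in sorted(by_level.items()):
    print(f"level {level}: {nodes}")
# level 1: [(0,), (1,)]
# level 2: [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]
```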

Server

```bash
export VLLM_TORCH_PROFILER_DIR=~/traces/vllm
export LLAMA_MODEL=meta-llama/Llama-3.1-8B-Instruct
export DRAFT_MODEL=yuhuili/EAGLE-LLaMA3.1-Instruct-8B
export VLLM_USE_V1=1
export VLLM_ATTENTION_BACKEND=TREE_ATTN
export SPEC_DEC_CONFIG='{"method": "eagle", "model": "'$DRAFT_MODEL'", "num_speculative_tokens": 8, "draft_tensor_parallel_size": 1, "max_model_len": 2048, "speculative_token_tree": "[(0,), (1,), (0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]"}'
vllm serve $LLAMA_MODEL --disable-log-requests --tensor-parallel-size=1 --max-num-seqs=64 --max-model-len=32768 --no-enable-prefix-caching --speculative-config="$SPEC_DEC_CONFIG" 2>&1 | tee ~/server_logs/vllm_server.log
```

Client

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
response = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    messages=[{"role": "system", "content": "You are a helpful assistant."},
              {"role": "user", "content": "Explain the theory of relativity in simple terms."}],
    temperature=0.0,
)
print(response.choices[0].message.content)
```

Response

The theory of relativity, developed by Albert Einstein, is a fundamental concept in physics that explains how space and time are connected. It's a bit complex, but I'll try to break it down in simple terms.

**Special Relativity (1905)**

Imagine you're on a train, and you throw a ball straight up in the air. From your perspective on the train, the ball goes up and comes down in a straight line. Now, imagine someone is standing outside the train, watching you throw the ball. From their perspective, the ball doesn't just go up and down – it also moves forward, because the train is moving really fast.

Einstein said that how we measure time and space depends on how fast we're moving and where we are. If you're on the train, time and space seem normal. But if you're standing outside, watching the train, time and space seem different. This is because time and space are connected, and they can appear to change depending on your frame of reference.

**Time Dilation**

Here's a key idea: time can appear to slow down or speed up depending on how fast you're moving. If you were to travel close to the speed of light, time would appear to slow down for you relative to someone who is standing still. This is called time dilation.

**General Relativity (1915)**

Imagine you're standing on a trampoline. If you put a heavy object, like a bowling ball, on the trampoline, it will warp and curve, creating a dent. That's kind of like what gravity does to space and time. According to general relativity, massive objects like planets and stars warp the fabric of space and time, creating gravity.

**Key Takeaways**

1. **Time and space are connected**: How we measure time and space depends on how fast we're moving and where we are.
2. **Time can appear to slow down or speed up**: Time dilation occurs when you're moving at high speeds or in strong gravitational fields.
3. **Gravity warps space and time**: Massive objects create curvatures in space and time, which we experience as gravity.

The theory of relativity revolutionized our understanding of the universe and has had a profound impact on modern physics and astronomy.

The response is valid and aligns with what the LLM would output if running with standard decoding or spec decoding with FA3.

Benchmarks

I benchmarked using the philschmid/mt-bench and likaixin/InstructCoder datasets. Below are the commands I used:

Server

```bash
export VLLM_TORCH_PROFILER_DIR=~/traces/vllm
export LLAMA_MODEL=meta-llama/Llama-3.1-8B-Instruct
export DRAFT_MODEL=yuhuili/EAGLE-LLaMA3.1-Instruct-8B
export VLLM_USE_V1=1
export VLLM_ATTENTION_BACKEND=TREE_ATTN
export SPEC_DEC_CONFIG='{"method": "eagle", "model": "'$DRAFT_MODEL'", "num_speculative_tokens": 8, "draft_tensor_parallel_size": 1, "max_model_len": 2048, "speculative_token_tree": "[(0,), (1,), (0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]"}'
vllm serve $LLAMA_MODEL --disable-log-requests --tensor-parallel-size=1 --max-num-seqs=64 --max-model-len=32768 --no-enable-prefix-caching --speculative-config="$SPEC_DEC_CONFIG" 2>&1 | tee ~/server_logs/vllm_server.log
```

Client

```bash
vllm bench serve --model $LLAMA_MODEL --tokenizer $LLAMA_MODEL --host 0.0.0.0 --dataset-name hf --dataset-path <dataset> --ignore-eos --request-rate inf --max-concurrency <batch_size>
```
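
For example, a concrete mt-bench run at max concurrency 8 might look like this (assuming the Hugging Face repo id is passed directly as the dataset path):

```bash
vllm bench serve --model $LLAMA_MODEL --tokenizer $LLAMA_MODEL --host 0.0.0.0 \
  --dataset-name hf --dataset-path philschmid/mt-bench --ignore-eos \
  --request-rate inf --max-concurrency 8
```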

Using both datasets and batch sizes 1, 8, and 16, I compared the serving benchmark results for a chain of 2 spec tokens (baseline) against a 2 => 3 spec token tree (2 children at the first level, 3 per node at the second) that looks like this:

```
ROOT
├── 0
│   ├── 0
│   ├── 1
│   └── 2
└── 1
    ├── 0
    ├── 1
    └── 2
```

For the baseline, I decided to use the flash attention 3 backend, since it's the only other one that supports spec decoding in V1. Below is a short summary of the results:

mt-bench

| Metric | BS=1, FA3 | BS=1, Tree | BS=8, FA3 | BS=8, Tree | BS=16, FA3 | BS=16, Tree |
| --- | --- | --- | --- | --- | --- | --- |
| Duration (s) | 1999.14 | 1990.13 | 269.42 | 265.15 | 118.45 | 131.05 |
| Output Token Throughput (tok/s) | 128.06 | 128.63 | 884.61 | 902.72 | 1813.36 | 1640.95 |
| Mean Acceptance Length | 2.10 | 2.37 | 2.03 | 2.37 | 2.04 | 2.38 |

In all cases, tree spec decoding increases the mean acceptance length. On the mt-bench dataset, it also yields higher output token throughput at batch sizes 1 and 8. As the batch size grows, the benefit of more accepted tokens per decode step is no longer enough to compensate for the added cost of the attention forward pass, and tree spec decoding performs worse. It's important to note that the FA3 kernel is much faster than the Triton attention kernel used by the Tree Attention backend (which applies a query-on-query attention bias).

To isolate what we gain from tree drafts themselves, I also compared chain vs. tree draft structures using the same Tree Attention backend, which eliminates the kernel difference.

mt-bench

| Metric | BS=1, Chain | BS=1, Tree | BS=8, Chain | BS=8, Tree | BS=16, Chain | BS=16, Tree |
| --- | --- | --- | --- | --- | --- | --- |
| Duration (s) | 2240.21 | 1990.13 | 280.96 | 265.15 | 136.86 | 131.05 |
| Output Token Throughput (tok/s) | 114.27 | 128.63 | 847.38 | 902.72 | 1567.5 | 1640.95 |
| Mean Acceptance Length | 2.05 | 2.37 | 2.06 | 2.37 | 2.05 | 2.38 |

Spec decoding with the tree of drafts performs better than with a chain of 2 drafts at batch sizes 1, 8, and 16. I did not test larger batch sizes, but I expect the gain to become less pronounced.
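
As a rough sanity check at batch size 1: the tree accepts about 2.37 / 2.05 ≈ 1.16x as many tokens per verification step as the chain, while the measured throughput gain is 128.63 / 114.27 ≈ 1.13x, so most of the extra acceptances translate directly into end-to-end speedup, with the small remaining gap attributable to the wider tree's extra drafting and attention work.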

A detailed breakdown of the benchmark results can be found in this spreadsheet.

Next Steps

  • Use a different attention kernel, since Triton attention is slow compared to FA3: from benchmarking, each Triton attention forward pass is ~35% slower than the FA3 forward pass. FA2 with a tree attention mask (feat: implement tree attention mask support for FlashAttention-2, flash-attention#81) is a good candidate!
  • As @sgrigory pointed out in the comments, this implementation currently uses greedy acceptance, which differs from the probabilistic approach used in EAGLE. The latter should increase the average number of accepted tokens and improve performance further; see the sketch of that rule below.
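
For context, the standard per-token probabilistic acceptance rule from speculative sampling, which a tree variant would apply along each branch, looks roughly like this. This is a sketch of the general technique with hypothetical names, not code from this PR or from EAGLE:

```python
import torch


def accept_or_resample(token_id: int,
                       p_target: torch.Tensor,
                       q_draft: torch.Tensor) -> tuple:
    """Accept the drafted token with probability min(1, p/q); otherwise
    resample from the residual distribution max(0, p - q), renormalized.

    p_target, q_draft: 1-D probability vectors over the vocabulary for this
    position, from the target and draft models respectively.
    """
    p, q = p_target[token_id], q_draft[token_id]
    if torch.rand(1) <= p / q:
        return True, token_id
    residual = torch.clamp(p_target - q_draft, min=0.0)
    residual = residual / residual.sum()
    return False, int(torch.multinomial(residual, num_samples=1))
```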

@github-actions (bot)

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

mergify bot commented Aug 12, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @TheEpicDolphin.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Aug 12, 2025
@gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request introduces a tree-based sampler for speculative decoding, which is a significant feature for improving performance. The changes are well-structured, with new components like TreeDrafterParams and TreeRejectionSampler to handle the tree-based logic. The Eagle proposer is also updated to support tree drafting and to output draft probabilities. My review has identified a few critical issues regarding the handling of irregular token trees and the correctness of probability calculations, which should be addressed to ensure the feature works as expected.

@TheEpicDolphin TheEpicDolphin force-pushed the tree_sampler_v1 branch 2 times, most recently from 9c59df6 to 3da3a66 on August 16, 2025 19:44
@mergify mergify bot added llama Related to Llama models and removed needs-rebase labels Aug 16, 2025
mergify bot commented Aug 20, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @TheEpicDolphin.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Aug 20, 2025
@TheEpicDolphin TheEpicDolphin force-pushed the tree_sampler_v1 branch 17 times, most recently from 544ec15 to aa5dd93 on August 25, 2025 22:25
@TheEpicDolphin TheEpicDolphin force-pushed the tree_sampler_v1 branch 4 times, most recently from 26e1e34 to cd56f84 on September 12, 2025 18:44
Signed-off-by: Giancarlo Delfin <gdelfin@meta.com>
@LucasWilkinson LucasWilkinson added the ready ONLY add when PR is ready to merge/full CI is needed label Sep 13, 2025
@TheEpicDolphin TheEpicDolphin force-pushed the tree_sampler_v1 branch 11 times, most recently from 529d242 to 44a761a on September 14, 2025 17:02
@LucasWilkinson LucasWilkinson dismissed their stale review September 16, 2025 00:09

Removing review; need green light from Woosuk

mergify bot commented Sep 16, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @TheEpicDolphin.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 16, 2025
@TheEpicDolphin TheEpicDolphin force-pushed the tree_sampler_v1 branch 2 times, most recently from 82bad36 to d0ad7b2 on September 16, 2025 22:51
@TheEpicDolphin (Collaborator, Author) commented:

@LucasWilkinson, @WoosukKwon, I refactored so that there are even fewer changes in the GPU model runner. Basically, I moved the logic for fetching bonus and target logits into RejectionSampler/TreeRejectionSampler. Please let me know what you think.

Signed-off-by: Giancarlo Delfin <gdelfin@meta.com>
Comment on lines -1847 to -1857

```python
bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
sampler_output = self.sampler(
    logits=bonus_logits,
    sampling_metadata=sampling_metadata,
)
bonus_token_ids = sampler_output.sampled_token_ids

# Just like `bonus_logits`, `target_logits` is a new tensor with
# separate storage from the original `logits` tensor. Therefore,
# it is safe to update `target_logits` in place.
target_logits = logits[spec_decode_metadata.target_logits_indices]
```
@TheEpicDolphin (Collaborator, Author):

This was moved into RejectionSampler.forward.


Labels

llama (Related to Llama models), needs-rebase, ready (ONLY add when PR is ready to merge/full CI is needed), speculative-decoding, v1
