
Conversation

@seven-mile
Contributor

@seven-mile seven-mile commented Oct 4, 2025

Purpose

When using EagleProposer with DP, vLLM hangs because of mismatched communication behavior between DP ranks. The communications for non-tree proposing follow the pattern below:

class EagleProposer:
    def propose(...):
        with set_forward_context(...):   # First comm.
            ... = self.model(...)

        for token_index in range(self.num_speculative_tokens - 1):
            with set_forward_context(...):  # Next n-1 comms.
                ... = self.model(...)

where set_forward_context eventually leads to an all_reduce among the DP group.

num_tokens_across_dp = [0] * dp_size
num_tokens_across_dp[dp_rank] = num_tokens
num_tokens_tensor = torch.tensor(num_tokens_across_dp,
                                 device=device,
                                 dtype=torch.int32)
dist.all_reduce(num_tokens_tensor, group=group)
return num_tokens_tensor.cpu()

Therefore, the root cause is that EagleProposer implements dummy_run as a single forward invocation, while propose runs the forward self.num_speculative_tokens times. We should simply align dummy_run's behavior with the propose method here.
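To make the failure mode concrete, here is a minimal standalone repro of the hang mechanism (illustrative only, not vLLM code; the backend, port, and call counts are arbitrary). The rank that issues the extra all_reduce blocks because its peer never posts a matching collective:

import time

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int) -> None:
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29501",
                            rank=rank, world_size=world_size)
    # Rank 0 plays the role of dummy_run (one forward -> one all_reduce),
    # rank 1 plays the role of propose() (num_speculative_tokens forwards).
    num_calls = 1 if rank == 0 else 3
    for _ in range(num_calls):
        dist.all_reduce(torch.ones(1))  # rank 1 stalls on its second call
    time.sleep(60)  # keep rank 0 alive so the stall on rank 1 is observable
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)  # stalls; interrupt with Ctrl+C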


Since the end-to-end tree sampler is still WIP (#22752), we don't eagerly integrate its logic into dummy_run now. An assertion about it is added instead. It should be easy to align similarly after that PR lands.
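For reference, a minimal sketch of the intended dummy_run shape, mirroring the propose pseudocode above (details elided; _is_tree_attention is the helper discussed in the review comment below):

class EagleProposer:
    def dummy_run(self, ...):
        # Tree proposing is still WIP (#22752); assert it is not in use yet.
        assert not self._is_tree_attention()

        # Issue the same number of forward passes -- and therefore the same
        # number of DP all_reduce calls via set_forward_context -- as
        # propose() does.
        for _ in range(self.num_speculative_tokens):
            with set_forward_context(...):
                ... = self.model(...)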

Test Plan

vllm serve \
    meta-llama/Meta-Llama-3-8B-Instruct \
    --host 0.0.0.0 \
    --port 7000 \
    --seed 42 \
    --disable-log-requests \
    --no-enable-prefix-caching \
    -dp 2 \
    --max-model-len 8192 \
    --max-num-seqs 64 \
    --gpu_memory_utilization 0.8 \
    --speculative-config '{"model":"yuhuili/EAGLE-LLaMA3-Instruct-8B","num_speculative_tokens":8,"max_model_len": 2048}'

Basic functionality

vllm bench serve \
  --backend vllm --model meta-llama/Meta-Llama-3-8B-Instruct \
  --dataset-name sharegpt \
  --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \
  --num-prompts 200 \
  --host localhost --port 7000

Correctness

lm_eval --model local-completions \
  --tasks gsm8k \
  --model_args model=meta-llama/Meta-Llama-3-8B-Instruct,base_url=http://0.0.0.0:7000/v1/completions,num_concurrent=1,max_retries=3,tokenized_requests=False

Test Result

Basic functionality

============ Serving Benchmark Result ============
Successful requests:                     200       
Benchmark duration (s):                  12.97     
Total input tokens:                      42659     
Total generated tokens:                  40258     
Request throughput (req/s):              15.42     
Output token throughput (tok/s):         3104.04   
Peak output token throughput (tok/s):    2533.00   
Peak concurrent requests:                200.00    
Total Token throughput (tok/s):          6393.21   
---------------Time to First Token----------------
Mean TTFT (ms):                          1647.35   
Median TTFT (ms):                        1490.01   
P99 TTFT (ms):                           3865.45   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          31.44     
Median TPOT (ms):                        23.57     
P99 TPOT (ms):                           161.47    
---------------Inter-token Latency----------------
Mean ITL (ms):                           44.05     
Median ITL (ms):                         40.25     
P99 ITL (ms):                            87.76     
==================================================

No more hang.

Correctness

Tasks  Version  Filter            n-shot  Metric       Value   Stderr
gsm8k  3        flexible-extract  5       exact_match  0.7551  ± 0.0118
                strict-match      5       exact_match  0.7589  ± 0.0118

Reasonable results for LLaMA3 8B.


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: seven-mile <i@7li.moe>
@github-actions

github-actions bot commented Oct 4, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which runs a small and essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses a hang issue in speculative decoding with data parallelism by ensuring the dummy_run method in EagleProposer performs the same number of communication operations as the propose method. The fix involves looping num_speculative_tokens times in dummy_run. The change is logical and well-supported by the PR description. I've identified a potential critical issue in the newly added _is_tree_attention helper function that could lead to an IndexError and have provided a suggestion to make the check more robust.

Comment on lines +589 to +591
if not hasattr(self.runner,
               "attn_groups") or not self.runner.attn_groups:
    return False

critical

The current check for self.runner.attn_groups is not sufficient to prevent a potential IndexError. If self.runner.attn_groups is [[]], the condition not self.runner.attn_groups will evaluate to False, but the subsequent access self.runner.attn_groups[0][0] on line 594 will raise an IndexError because the inner list is empty. This could lead to a server crash. To prevent this, you should also check if the inner list is empty.

Suggested change

if not hasattr(self.runner,
               "attn_groups") or not self.runner.attn_groups:
    return False

if (not hasattr(self.runner, "attn_groups") or
        not self.runner.attn_groups or not self.runner.attn_groups[0]):
    return False

@seven-mile
Contributor Author

Already covered by #26086.

@seven-mile seven-mile closed this Oct 7, 2025
