
Commit 44822d7

[BugFix] Preserve spec decoding uniform decode when scheduling (#29759)
Signed-off-by: Nick Hill <nhill@redhat.com>
1 parent 342c4f1 commit 44822d7

File tree

tests/v1/e2e/test_spec_decode.py
vllm/v1/core/sched/async_scheduler.py
vllm/v1/core/sched/scheduler.py

3 files changed: +25 -17 lines changed

tests/v1/e2e/test_spec_decode.py

Lines changed: 2 additions & 2 deletions
@@ -191,8 +191,8 @@ def test_suffix_decoding_acceptance(
     # Expect the acceptance rate to improve.
     assert first_accept_rate < last_accept_rate

-    # Heuristic: expect at least 85% acceptance rate at the end.
-    assert last_accept_rate > 0.85
+    # Heuristic: expect at least 82.5% acceptance rate at the end.
+    assert last_accept_rate > 0.825

     del spec_llm
     torch.cuda.empty_cache()

vllm/v1/core/sched/async_scheduler.py

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ def _update_after_schedule(
             # in this scheduling step.
             request.num_output_placeholders += 1 + cur_num_spec_tokens
             # Add placeholders for the new tokens in spec_token_ids.
-            # Wwe will update the actual spec token ids in the worker process.
+            # We will update the actual spec token ids in the worker process.
             request.spec_token_ids = [-1] * self.num_spec_tokens

         scheduler_output.pending_structured_output_tokens = (

vllm/v1/core/sched/scheduler.py

Lines changed: 22 additions & 14 deletions
@@ -236,6 +236,22 @@ def schedule(self) -> SchedulerOutput:
         while req_index < len(self.running) and token_budget > 0:
             request = self.running[req_index]

+            if (
+                request.num_output_placeholders > 0
+                # This is (num_computed_tokens + 1) - (num_output_placeholders - 1).
+                # Since output placeholders are also included in the computed tokens
+                # count, we subtract (num_output_placeholders - 1) to remove any draft
+                # tokens, so that we can be sure no further steps are needed even if
+                # they are all rejected.
+                and request.num_computed_tokens + 2 - request.num_output_placeholders
+                >= request.num_prompt_tokens + request.max_tokens
+            ):
+                # Async scheduling: Avoid scheduling an extra step when we are sure that
+                # the previous step has reached request.max_tokens. We don't schedule
+                # partial draft tokens since this prevents uniform decode optimizations.
+                req_index += 1
+                continue
+
             num_new_tokens = (
                 request.num_tokens_with_spec
                 + request.num_output_placeholders
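
The inequality in this new guard is easiest to sanity-check in isolation. The sketch below is illustrative only: the helper function and the example numbers are invented, but the comparison is the one added above. It reads num_computed_tokens + 2 - num_output_placeholders as the number of tokens the request is guaranteed to have even if every outstanding draft token from the previous async step is rejected.

# Standalone sketch of the guard added in the hunk above; the function name
# and the example numbers are hypothetical.
def reached_max_tokens(
    num_computed_tokens: int,
    num_output_placeholders: int,
    num_prompt_tokens: int,
    max_tokens: int,
) -> bool:
    # Tokens the request is guaranteed to have even if every draft token
    # from the previous async step is rejected:
    #   (num_computed_tokens + 1) - (num_output_placeholders - 1)
    guaranteed = num_computed_tokens + 2 - num_output_placeholders
    return num_output_placeholders > 0 and guaranteed >= num_prompt_tokens + max_tokens

# Hypothetical request: 10-token prompt, max_tokens=5, and 3 output
# placeholders outstanding (1 sampled token plus 2 draft tokens).
assert reached_max_tokens(16, 3, 10, 5)      # done even if both drafts are rejected
assert not reached_max_tokens(14, 3, 10, 5)  # may still need another step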
@@ -245,18 +261,10 @@ def schedule(self) -> SchedulerOutput:
                 num_new_tokens = self.scheduler_config.long_prefill_token_threshold
             num_new_tokens = min(num_new_tokens, token_budget)

-            num_spec_placeholders = max(0, request.num_output_placeholders - 1)
-            max_total_tokens = min(
-                # Avoid scheduling tokens that we're sure won't will be needed based on
-                # request.max_tokens. For this calculation we assume placeholder
-                # speculated output tokens are rejected.
-                request.num_prompt_tokens + request.max_tokens + num_spec_placeholders,
-                # Make sure the input position does not exceed the max model len.
-                # This is necessary when using spec decoding.
-                self.max_model_len,
-            )
+            # Make sure the input position does not exceed the max model len.
+            # This is necessary when using spec decoding.
             num_new_tokens = min(
-                num_new_tokens, max_total_tokens - 1 - request.num_computed_tokens
+                num_new_tokens, self.max_model_len - 1 - request.num_computed_tokens
             )

             # Schedule encoder inputs.
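
With the max_total_tokens bookkeeping removed, the only cap left at this point is the model-length bound. A minimal sketch of what the surviving min() does, using hypothetical numbers:

# Hypothetical values; not the scheduler's real state.
max_model_len = 4096        # model context length
num_computed_tokens = 4090  # current input position for this request
num_new_tokens = 8          # e.g. 1 target token plus 7 speculative draft tokens

# Keep the input position below the max model length, as in the hunk above.
num_new_tokens = min(num_new_tokens, max_model_len - 1 - num_computed_tokens)
assert num_new_tokens == 5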
@@ -799,15 +807,15 @@ def _make_cached_request_data(
         for idx, req in enumerate(itertools.chain(running_reqs, resumed_reqs)):
             req_id = req.request_id
             req_ids.append(req_id)
-            num_tokens = num_scheduled_tokens[req_id] - len(
-                spec_decode_tokens.get(req_id, ())
-            )
             if self.use_pp:
                 # When using PP, the scheduler sends the sampled tokens back,
                 # because there's no direct communication between the first-
                 # stage worker and the last-stage worker. Otherwise, we don't
                 # need to send the sampled tokens back because the model runner
                 # will cache them.
+                num_tokens = num_scheduled_tokens[req_id] - len(
+                    spec_decode_tokens.get(req_id, ())
+                )
                 token_ids = req.all_token_ids[
                     req.num_computed_tokens : req.num_computed_tokens + num_tokens
                 ]
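
The last hunk only moves the num_tokens computation inside the pipeline-parallel branch, which is the one place the value is used. A hedged illustration of that arithmetic with made-up data rather than the real scheduler structures: subtracting len(spec_decode_tokens) leaves the count of token ids to slice out, excluding the speculative draft slots.

# Made-up data illustrating the arithmetic now done only when use_pp is set.
num_scheduled_tokens = {"req-0": 4}              # 4 tokens scheduled for this request
spec_decode_tokens = {"req-0": [101, 102, 103]}  # 3 of them are speculative drafts

req_id = "req-0"
num_tokens = num_scheduled_tokens[req_id] - len(spec_decode_tokens.get(req_id, ()))
assert num_tokens == 1  # 1 token id to send back in this sketch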
