@@ -236,6 +236,22 @@ def schedule(self) -> SchedulerOutput:
         while req_index < len(self.running) and token_budget > 0:
             request = self.running[req_index]

+            if (
+                request.num_output_placeholders > 0
+                # This is (num_computed_tokens + 1) - (num_output_placeholders - 1).
+                # Since output placeholders are also included in the computed tokens
+                # count, we subtract (num_output_placeholders - 1) to remove any draft
+                # tokens, so that we can be sure no further steps are needed even if
+                # they are all rejected.
+                and request.num_computed_tokens + 2 - request.num_output_placeholders
+                >= request.num_prompt_tokens + request.max_tokens
+            ):
+                # Async scheduling: Avoid scheduling an extra step when we are sure that
+                # the previous step has reached request.max_tokens. We don't schedule
+                # partial draft tokens since this prevents uniform decode optimizations.
+                req_index += 1
+                continue
+
             num_new_tokens = (
                 request.num_tokens_with_spec
                 + request.num_output_placeholders
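
To make the arithmetic in the new early-exit check easier to follow, here is a minimal, self-contained sketch. `FakeRequest` and `reached_max_tokens` are hypothetical stand-ins for the fields the diff touches (`num_output_placeholders`, `num_computed_tokens`, `num_prompt_tokens`, `max_tokens`), not vLLM's actual `Request` API:

```python
# Hypothetical sketch of the early-exit condition added in the hunk above.
from dataclasses import dataclass


@dataclass
class FakeRequest:
    num_prompt_tokens: int
    max_tokens: int
    num_computed_tokens: int      # includes in-flight output placeholders
    num_output_placeholders: int  # 1 sampled token + (n - 1) draft tokens


def reached_max_tokens(req: FakeRequest) -> bool:
    # Guaranteed progress after the in-flight step, assuming every draft
    # (placeholder) token is rejected: (num_computed_tokens + 1) minus the
    # (num_output_placeholders - 1) drafts that are already counted.
    guaranteed = req.num_computed_tokens + 2 - req.num_output_placeholders
    return (
        req.num_output_placeholders > 0
        and guaranteed >= req.num_prompt_tokens + req.max_tokens
    )


# Prompt of 4 tokens with max_tokens=6, so the request is done at 10 tokens.
# The in-flight step holds 3 placeholders (1 sampled + 2 drafts), all of
# which are already counted in num_computed_tokens.
req = FakeRequest(
    num_prompt_tokens=4,
    max_tokens=6,
    num_computed_tokens=11,
    num_output_placeholders=3,
)
# Even if both drafts are rejected, the request still reaches 10 tokens,
# so scheduling another step would be wasted work.
assert reached_max_tokens(req)
```
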
@@ -245,18 +261,10 @@ def schedule(self) -> SchedulerOutput:
                 num_new_tokens = self.scheduler_config.long_prefill_token_threshold
             num_new_tokens = min(num_new_tokens, token_budget)

-            num_spec_placeholders = max(0, request.num_output_placeholders - 1)
-            max_total_tokens = min(
-                # Avoid scheduling tokens that we're sure won't be needed based on
-                # request.max_tokens. For this calculation we assume placeholder
-                # speculated output tokens are rejected.
-                request.num_prompt_tokens + request.max_tokens + num_spec_placeholders,
-                # Make sure the input position does not exceed the max model len.
-                # This is necessary when using spec decoding.
-                self.max_model_len,
-            )
+            # Make sure the input position does not exceed the max model len.
+            # This is necessary when using spec decoding.
             num_new_tokens = min(
-                num_new_tokens, max_total_tokens - 1 - request.num_computed_tokens
+                num_new_tokens, self.max_model_len - 1 - request.num_computed_tokens
             )

             # Schedule encoder inputs.
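
After this hunk, the per-request upper bound is reduced to a single clamp against the context window (the max_tokens case is now handled by the early-exit check in the first hunk). A hypothetical helper, not part of vLLM, that mirrors the new bound:

```python
# Hypothetical sketch of the clamp that replaces the old max_total_tokens
# logic: besides the token budget, the only remaining upper bound is the
# model's context window.
def clamp_to_model_len(
    num_new_tokens: int, num_computed_tokens: int, max_model_len: int
) -> int:
    # Keep num_computed_tokens + num_new_tokens <= max_model_len - 1 so the
    # input positions never exceed the max model length. This matters with
    # spec decoding, where several draft positions are appended in one step.
    return min(num_new_tokens, max_model_len - 1 - num_computed_tokens)


# e.g. 3 speculative positions requested near the end of a 4096-token window
assert clamp_to_model_len(3, 4094, 4096) == 1
```
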
@@ -799,15 +807,15 @@ def _make_cached_request_data(
         for idx, req in enumerate(itertools.chain(running_reqs, resumed_reqs)):
             req_id = req.request_id
             req_ids.append(req_id)
-            num_tokens = num_scheduled_tokens[req_id] - len(
-                spec_decode_tokens.get(req_id, ())
-            )
             if self.use_pp:
                 # When using PP, the scheduler sends the sampled tokens back,
                 # because there's no direct communication between the first-
                 # stage worker and the last-stage worker. Otherwise, we don't
                 # need to send the sampled tokens back because the model runner
                 # will cache them.
+                num_tokens = num_scheduled_tokens[req_id] - len(
+                    spec_decode_tokens.get(req_id, ())
+                )
                 token_ids = req.all_token_ids[
                     req.num_computed_tokens : req.num_computed_tokens + num_tokens
                 ]
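
As a rough illustration of the relocated `num_tokens` computation, the sketch below uses plain lists and a hypothetical `tokens_to_send` helper rather than the scheduler's actual `_make_cached_request_data` path. It shows why the slice is now only built under `use_pp`: that is the only case where the scheduler must ship sampled token ids back.

```python
# Hypothetical sketch of the pipeline-parallel branch above.
def tokens_to_send(
    use_pp: bool,
    all_token_ids: list[int],
    num_computed_tokens: int,
    num_scheduled: int,
    num_spec_tokens: int,
) -> list[int]:
    if not use_pp:
        # Without PP the model runner already caches the sampled tokens,
        # so nothing needs to be sent back (and num_tokens is never needed).
        return []
    # Scheduled count minus speculative draft tokens gives the number of
    # token ids to return for this step.
    num_tokens = num_scheduled - num_spec_tokens
    return all_token_ids[num_computed_tokens : num_computed_tokens + num_tokens]


# 10 computed tokens; 3 tokens scheduled this step, 2 of them draft tokens.
ids = list(range(20))
assert tokens_to_send(True, ids, 10, 3, 2) == [10]
assert tokens_to_send(False, ids, 10, 3, 2) == []
```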