diff --git a/vllm/entrypoints/context.py b/vllm/entrypoints/context.py
index 7a41c668d764..e907d6f8a6d8 100644
--- a/vllm/entrypoints/context.py
+++ b/vllm/entrypoints/context.py
@@ -214,6 +214,8 @@ def _update_num_reasoning_tokens(self):
 
     def append_output(self, output: RequestOutput) -> None:
         output_token_ids = output.outputs[0].token_ids
+        # Reset parser for each append_output call to handle multi-turn scenarios
+        # where the parser needs to start fresh for each assistant response
         self.parser = get_streamable_parser_for_assistant()
         for token_id in output_token_ids:
             self.parser.process(token_id)
@@ -504,6 +506,8 @@ def __init__(self, *args, **kwargs):
         self.encoding = get_encoding()
         self.last_tok = None
         self.first_tok_of_message = True
+        # Track how many tokens have been processed to avoid a buggy token search
+        self.processed_token_count = 0
 
     @property
     def messages(self) -> list:
@@ -519,8 +523,10 @@ def append_output(self, output: RequestOutput) -> None:
         # (finished=True), then the next token processed will mark the
         # beginning of a new message
         self.first_tok_of_message = output.finished
-        for tok in output.outputs[0].token_ids:
+        token_ids = output.outputs[0].token_ids
+        for tok in token_ids:
             self.parser.process(tok)
+            self.processed_token_count += 1
         self._update_decode_token_usage(output)
         # For streaming, update previous turn when message is complete
         if output.finished:
@@ -529,7 +535,9 @@ def append_output(self, output: RequestOutput) -> None:
             self.current_turn_metrics.reset()
         # Check if the current token is part of reasoning content
         self._update_num_reasoning_tokens()
-        self.last_tok = tok
+        # Only update last_tok if we actually processed tokens
+        if token_ids:
+            self.last_tok = tok
         if len(self._messages) - self.num_init_messages < len(self.parser.messages):
             self._messages.extend(
                 self.parser.messages[len(self._messages) - self.num_init_messages :]
@@ -546,8 +554,13 @@ def append_tool_output(self, output: list[Message]) -> None:
         toks = self.encoding.render(msg)
         for tok in toks:
             self.parser.process(tok)
+            self.processed_token_count += 1
         self.last_tok = toks[-1]
-        # TODO: add tool_output messages to self._messages
+        # Add tool output messages from parser to self._messages
+        # (same pattern as append_output)
+        msg_count = len(self._messages) - self.num_init_messages
+        if msg_count < len(self.parser.messages):
+            self._messages.extend(self.parser.messages[msg_count:])
 
     def is_expecting_start(self) -> bool:
         return self.parser.state == StreamState.EXPECT_START
@@ -556,17 +569,15 @@ def is_assistant_action_turn(self) -> bool:
         return self.last_tok in self.encoding.stop_tokens_for_assistant_actions()
 
     def render_for_completion(self) -> list[int]:
-        # now this list of tokens as next turn's starting tokens
-        # `<|start|>assistant`,
-        # we need to process them in parser.
+        # Render all messages, including the next turn's start tokens,
+        # e.g. [...previous tokens...] [<|start|>] [assistant]
         rendered_tokens = super().render_for_completion()
 
-        last_n = -1
-        to_process = []
-        while rendered_tokens[last_n] != self.last_tok:
-            to_process.append(rendered_tokens[last_n])
-            last_n -= 1
-        for tok in reversed(to_process):
+        # Process only the NEW tokens that the parser hasn't seen before.
+        # This avoids the buggy backward search, which could match at the wrong position.
+        to_process = rendered_tokens[self.processed_token_count :]
+        for tok in to_process:
             self.parser.process(tok)
+            self.processed_token_count += 1
 
         return rendered_tokens
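
Why the `processed_token_count` slice is safer than the backward search: the old loop walked `rendered_tokens` from the end until it hit `self.last_tok`, so any token in the newly rendered suffix that happens to equal `last_tok` makes the scan stop early and silently drop tokens. A minimal standalone sketch of the failure mode (not part of the patch; token values, including 200006 for `<|start|>`, are illustrative):

def new_tokens_by_search(rendered: list[int], last_tok: int) -> list[int]:
    # Old approach: scan backwards until last_tok is found again.
    last_n = -1
    to_process: list[int] = []
    while rendered[last_n] != last_tok:
        to_process.append(rendered[last_n])
        last_n -= 1
    return list(reversed(to_process))

def new_tokens_by_count(rendered: list[int], processed_token_count: int) -> list[int]:
    # New approach: everything past the counter is new by construction.
    return rendered[processed_token_count:]

seen = [1, 2, 3, 7, 9]          # already fed to the parser; last_tok == 9
rendered = seen + [9, 200006]   # next turn's prefix happens to start with another 9

print(new_tokens_by_search(rendered, last_tok=9))   # [200006] -- the new 9 is skipped
print(new_tokens_by_count(rendered, len(seen)))     # [9, 200006] -- correct

The counter cannot misfire because it depends only on how many tokens have already been fed to the parser, not on token values; the same reasoning covers the `if token_ids:` guard, which keeps `last_tok` from being read when an empty output would leave the loop variable unset.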