diff --git a/tests/entrypoints/openai/test_responses_streaming_tool_calls.py b/tests/entrypoints/openai/test_responses_streaming_tool_calls.py
new file mode 100644
index 000000000000..d2997babed88
--- /dev/null
+++ b/tests/entrypoints/openai/test_responses_streaming_tool_calls.py
@@ -0,0 +1,101 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+import pytest_asyncio
+from openai.types.responses import (
+    ResponseFunctionCallArgumentsDeltaEvent,
+    ResponseFunctionCallArgumentsDoneEvent,
+    ResponseOutputItemAddedEvent,
+)
+
+from tests.utils import RemoteOpenAIServer
+
+MODEL_NAME = "Qwen/Qwen3-0.6B"
+
+INTEGRATION_TOOLS = [
+    {
+        "type": "function",
+        "name": "get_current_weather",
+        "description": "Get the current weather in a given location",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "city": {"type": "string"},
+                "country": {"type": "string"},
+                "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+            },
+            "required": ["city", "country", "unit"],
+        },
+    }
+]
+
+
+@pytest.fixture(scope="module")
+def responses_server():
+    args = [
+        "--dtype",
+        "half",
+        "--enable-auto-tool-choice",
+        "--structured-outputs-config.backend",
+        "xgrammar",
+        "--tool-call-parser",
+        "hermes",
+        "--reasoning-parser",
+        "qwen3",
+        "--gpu-memory-utilization",
+        "0.3",
+        "--max-model-len",
+        "2048",
+    ]
+    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+        yield remote_server
+
+
+@pytest_asyncio.fixture
+async def responses_client(responses_server):
+    async with responses_server.get_async_client() as async_client:
+        yield async_client
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "tool_choice",
+    [
+        "auto",
+    ],
+)
+async def test_responses_streaming_tool_calls_e2e(responses_client, tool_choice):
+    stream = await responses_client.responses.create(
+        model=MODEL_NAME,
+        input=(
+            "Use the weather tool to get the temperature for Berlin, Germany in "
+            "fahrenheit"
+        ),
+        tools=INTEGRATION_TOOLS,
+        tool_choice=tool_choice,
+        stream=True,
+        temperature=0,
+    )
+
+    added_event = None
+    arg_chunks: list[str] = []
+    final_arguments: str | None = None
+    async for event in stream:
+        if (
+            isinstance(event, ResponseOutputItemAddedEvent)
+            and getattr(event.item, "type", None) == "function_call"
+        ):
+            added_event = event
+        if isinstance(event, ResponseFunctionCallArgumentsDeltaEvent):
+            arg_chunks.append(event.delta)
+        if isinstance(event, ResponseFunctionCallArgumentsDoneEvent):
+            final_arguments = event.arguments
+
+    assert added_event is not None
+    assert final_arguments is not None
+    assert final_arguments == "".join(arg_chunks)
+    args_text = final_arguments.lower()
+    assert "berlin" in args_text
+    assert "germany" in args_text
+    assert "fahrenheit" in args_text
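Note: the invariant this test pins down is the Responses API streaming contract for function calls: one `response.output_item.added` event for the `function_call` item, zero or more `response.function_call_arguments.delta` events, then a `response.function_call_arguments.done` whose `arguments` equals the concatenation of the deltas. A minimal consumer sketch relying only on that invariant (the helper below is illustrative, not part of this PR):

```python
from collections import defaultdict


def accumulate_tool_call_args(events) -> dict[str, str]:
    """Fold streamed Responses events into {item_id: final_arguments}."""
    acc: dict[str, str] = defaultdict(str)
    for event in events:
        if event.type == "response.function_call_arguments.delta":
            acc[event.item_id] += event.delta
        elif event.type == "response.function_call_arguments.done":
            # The terminal event must agree with the accumulated deltas.
            assert acc[event.item_id] == event.arguments
    return dict(acc)
```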
diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py
index f546dbda7fef..e2a4942fe5d5 100644
--- a/vllm/entrypoints/openai/serving_responses.py
+++ b/vllm/entrypoints/openai/serving_responses.py
@@ -3,6 +3,7 @@
 
 import asyncio
 import json
+import re
 import time
 import uuid
 from collections import deque
@@ -13,6 +14,7 @@
 from typing import Final
 
 import jinja2
+import partial_json_parser
 from fastapi import Request
 from openai.types.responses import (
     ResponseCodeInterpreterCallCodeDeltaEvent,
@@ -55,6 +57,7 @@
 from vllm.entrypoints.chat_utils import (
     ChatCompletionMessageParam,
     ChatTemplateContentFormatOption,
+    make_tool_call_id,
 )
 from vllm.entrypoints.context import (
     ConversationContext,
@@ -76,7 +79,9 @@
 )
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (
+    DeltaFunctionCall,
     DeltaMessage,
+    DeltaToolCall,
     ErrorResponse,
     InputTokensDetails,
     OutputTokensDetails,
@@ -1128,6 +1133,129 @@ def _make_store_not_supported_error(self) -> ErrorResponse:
             status_code=HTTPStatus.BAD_REQUEST,
         )
 
+    @staticmethod
+    def _bracket_level(s: str, opening: str = "{", closing: str = "}") -> int:
+        """Calculate the current level of nested brackets in a given string."""
+        level = 0
+        for char in s:
+            if char == opening:
+                level += 1
+            elif char == closing:
+                level -= 1
+        return level
+
+    @staticmethod
+    def _filter_delta_text(delta_text: str, previous_text: str) -> tuple[str, bool]:
+        """Trim tool call JSON to only emit the current in-progress block."""
+        bracket_level = OpenAIServingResponses._bracket_level(previous_text)
+        updated_delta, passed_zero = "", False
+        for c in delta_text:
+            if c == "{":
+                bracket_level += 1
+                passed_zero = bracket_level == 0
+            elif c == "}":
+                bracket_level -= 1
+                passed_zero = bracket_level == 0
+
+            if bracket_level != 0:
+                updated_delta += c
+            else:
+                if c == ",":
+                    break
+        return updated_delta, passed_zero
+
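Note: `_filter_delta_text` is easiest to follow on concrete inputs. Hand-traced examples against the implementation above (doctest-style; illustrative tool-call JSON, not captured output):

```python
>>> # previous_text still has two unclosed "{", so the delta stays inside the
>>> # current "parameters" object and is passed through unchanged:
>>> OpenAIServingResponses._filter_delta_text(
...     ', "country": "Germany"',
...     '[{"name": "f", "parameters": {"city": "Berlin"',
... )
(', "country": "Germany"', False)
>>> # Here the delta closes the outermost object: the outer "}" and the ","
>>> # separator are trimmed, and passed_zero reports that the call finished.
>>> OpenAIServingResponses._filter_delta_text(
...     '}}, {"name"',
...     '[{"name": "f", "parameters": {"x": 1',
... )
('}', True)
```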
+    def _extract_tool_call_required_streaming(
+        self,
+        previous_text: str,
+        current_text: str | None,
+        delta_text: str,
+        function_name_returned: bool,
+    ) -> tuple[DeltaMessage | None, bool]:
+        """Parse streaming tool calls when tool_choice is forced/required."""
+        if current_text is None or current_text == "":
+            return None, function_name_returned
+        try:
+            obj = partial_json_parser.loads(current_text)
+        except partial_json_parser.core.exceptions.MalformedJSON:
+            logger.debug("not enough tokens to parse into JSON yet")
+            obj = None
+
+        if obj is None or not isinstance(obj, list) or len(obj) == 0:
+            return None, False
+
+        _, finishes_previous_tool = self._filter_delta_text(delta_text, previous_text)
+        current_tool_call = obj[-1]
+        if not finishes_previous_tool and (
+            "name" not in current_tool_call or "parameters" not in current_tool_call
+        ):
+            return None, False
+
+        if not function_name_returned:
+            arguments = ""
+            # Extract the parameters field closest to the end of the string to
+            # avoid grabbing earlier tool calls when multiple are present.
+            param_pos = current_text.rfind('"parameters":')
+            if param_pos != -1:
+                param_section = current_text[param_pos:]
+                param_match = re.search(
+                    r'"parameters"\s*:\s*(.*)', param_section, re.DOTALL
+                )
+                if param_match:
+                    arguments = param_match.group(1)
+                    arguments, _ = self._filter_delta_text(arguments, previous_text)
+
+            if (
+                finishes_previous_tool
+                and "parameters" not in current_tool_call
+                and len(obj) > 1
+            ):
+                current_tool_call = obj[-2]
+
+            function_name_returned = True
+            delta_message = DeltaMessage(
+                tool_calls=[
+                    DeltaToolCall(
+                        id=make_tool_call_id(func_name=current_tool_call["name"]),
+                        function=DeltaFunctionCall(
+                            name=current_tool_call["name"], arguments=arguments
+                        ),
+                        index=len(obj) - 1,
+                        type="function",
+                    )
+                ]
+            )
+            return delta_message, function_name_returned
+
+        delta_text, _ = self._filter_delta_text(delta_text, previous_text)
+        if delta_text != "":
+            delta_message = DeltaMessage(
+                tool_calls=[
+                    DeltaToolCall(
+                        function=DeltaFunctionCall(
+                            name=None,
+                            arguments=delta_text,
+                        ),
+                        index=len(obj) - 1,
+                    )
+                ]
+            )
+            return delta_message, function_name_returned
+
+        return None, function_name_returned
+
+    def _should_stream_with_auto_tool_parsing(self, request: ResponsesRequest) -> bool:
+        """
+        Check if streamed tokens should go through the auto tool parser.
+
+        We only do this when tools are provided, auto tool choice is enabled,
+        and a tool parser is configured.
+        """
+        return bool(
+            request.tools
+            and self.tool_parser
+            and self.enable_auto_tools
+            and request.tool_choice in ["auto", None]
+        )
+
     async def _process_simple_streaming_events(
         self,
         request: ResponsesRequest,
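Note: for context on the `partial_json_parser.loads` call in `_extract_tool_call_required_streaming`: with `tool_choice="required"` the model output is constrained to a JSON array of `{"name": ..., "parameters": ...}` objects, and the parser recovers a best-effort value from any prefix of that array. A sketch assuming the library's default permissive mode (expected values shown as comments, not captured output):

```python
import partial_json_parser

prefixes = [
    '[{"name": "get_current_weather"',
    '[{"name": "get_current_weather", "parameters": {"city": "Ber',
    '[{"name": "get_current_weather", "parameters": {"city": "Berlin"}}]',
]
for prefix in prefixes:
    # Roughly: [{'name': 'get_current_weather'}]
    #       -> [{'name': 'get_current_weather', 'parameters': {'city': 'Ber'}}]
    #       -> the fully parsed array
    print(partial_json_parser.loads(prefix))
```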
@@ -1152,29 +1280,165 @@ async def _process_simple_streaming_events(
         previous_token_ids: list[int] = []
         first_delta_sent = False
         previous_delta_messages: list[DeltaMessage] = []
+        function_name_returned = False
+        tool_choice_required = request.tool_choice == "required"
+        tool_choice_auto = self._should_stream_with_auto_tool_parsing(request)
+        if tool_choice_auto and self.tool_parser:
+            try:
+                tool_parser = self.tool_parser(tokenizer)
+            except Exception:
+                logger.exception("Error in tool parser creation.")
+                tool_choice_auto = False
+                tool_parser = None
+        else:
+            tool_parser = None
+
+        tool_call_states: dict[int, dict[str, str | int]] = {}
+        tool_output_start_index: int | None = None
+
         async for ctx in result_generator:
             assert isinstance(ctx, SimpleContext)
             if ctx.last_output is None:
                 continue
             if ctx.last_output.outputs:
                 output = ctx.last_output.outputs[0]
+                delta_text = output.text
+                delta_token_ids = output.token_ids or []
+                prev_text = previous_text
+                prev_token_ids = previous_token_ids
+                current_text = prev_text + delta_text
+                current_token_ids = prev_token_ids + list(delta_token_ids)
+
+                delta_message: DeltaMessage | None = None
+                reasoning_delta: DeltaMessage | None = None
                 if reasoning_parser:
-                    delta_message = reasoning_parser.extract_reasoning_streaming(
-                        previous_text=previous_text,
-                        current_text=previous_text + output.text,
-                        delta_text=output.text,
-                        previous_token_ids=previous_token_ids,
-                        current_token_ids=previous_token_ids + output.token_ids,
-                        delta_token_ids=output.token_ids,
-                    )
-                else:
-                    delta_message = DeltaMessage(
-                        content=output.text,
+                    reasoning_delta = reasoning_parser.extract_reasoning_streaming(
+                        previous_text=prev_text,
+                        current_text=current_text,
+                        delta_text=delta_text,
+                        previous_token_ids=prev_token_ids,
+                        current_token_ids=current_token_ids,
+                        delta_token_ids=delta_token_ids,
                     )
-                previous_text += output.text
-                previous_token_ids += output.token_ids
+                    if reasoning_delta and reasoning_delta.reasoning is not None:
+                        delta_message = reasoning_delta
+
+                if delta_message is None:
+                    if tool_choice_required:
+                        delta_message, function_name_returned = (
+                            self._extract_tool_call_required_streaming(
+                                prev_text,
+                                current_text,
+                                delta_text,
+                                function_name_returned,
+                            )
+                        )
+                    elif tool_choice_auto and tool_parser:
+                        parsed_tool_delta = tool_parser.extract_tool_calls_streaming(
+                            previous_text=prev_text,
+                            current_text=current_text,
+                            delta_text=delta_text,
+                            previous_token_ids=prev_token_ids,
+                            current_token_ids=current_token_ids,
+                            delta_token_ids=delta_token_ids,
+                            request=request,
+                        )
+                        if parsed_tool_delta:
+                            delta_message = parsed_tool_delta
+                        elif reasoning_delta:
+                            delta_message = reasoning_delta
+                    elif reasoning_delta:
+                        delta_message = reasoning_delta
+                    else:
+                        delta_message = DeltaMessage(
+                            content=delta_text,
+                        )
+
+                previous_text = current_text
+                previous_token_ids = current_token_ids
                 if not delta_message:
                     continue
+                if (
+                    delta_message.tool_calls in (None, [])
+                    and delta_message.reasoning is None
+                    and delta_message.content in (None, "")
+                ):
+                    # Ignore empty content-only deltas that would create
+                    # spurious empty message events.
+                    continue
+
+                if delta_message.tool_calls:
+                    if tool_output_start_index is None:
+                        tool_output_start_index = current_output_index + (
+                            1 if first_delta_sent else 0
+                        )
+                    for tc in delta_message.tool_calls:
+                        output_index = (
+                            tool_output_start_index + tc.index
+                            if tool_output_start_index is not None
+                            else current_output_index
+                        )
+                        tool_state = tool_call_states.get(tc.index)
+                        if tool_state is None:
+                            item_id = tc.id or f"fc_{random_uuid()}"
+                            tool_state = {
+                                "item_id": item_id,
+                                "call_id": f"call_{random_uuid()}",
+                                "name": (tc.function.name or "") if tc.function else "",
+                                "arguments": "",
+                                "output_index": output_index,
+                            }
+                            tool_call_states[tc.index] = tool_state
+                            yield _increment_sequence_number_and_return(
+                                ResponseOutputItemAddedEvent(
+                                    type="response.output_item.added",
+                                    sequence_number=-1,
+                                    output_index=output_index,
+                                    item=ResponseFunctionToolCall(
+                                        name=tool_state["name"],
+                                        type="function_call",
+                                        id=tool_state["item_id"],
+                                        call_id=tool_state["call_id"],
+                                        arguments="",
+                                        status="in_progress",
+                                    ),
+                                )
+                            )
+                        if tc.function and tc.function.name and not tool_state["name"]:
+                            tool_state["name"] = tc.function.name
+                        if tc.function and tc.function.arguments:
+                            tool_state["arguments"] += tc.function.arguments
+                            yield _increment_sequence_number_and_return(
+                                ResponseFunctionCallArgumentsDeltaEvent(
+                                    item_id=tool_state["item_id"],
+                                    delta=tc.function.arguments,
+                                    output_index=tool_state["output_index"],
+                                    sequence_number=-1,
+                                    type="response.function_call_arguments.delta",
+                                )
+                            )
+                    if (
+                        delta_message.reasoning is None
+                        and delta_message.content is None
+                    ):
+                        continue
+
+                has_non_tool_content = (
+                    delta_message.reasoning is not None
+                    or delta_message.content is not None
+                )
+                if tool_call_states and has_non_tool_content:
+                    if not first_delta_sent:
+                        # Tool calls were streamed before any content, so
+                        # advance the output index past them; the start index
+                        # is always set while tool_call_states is non-empty.
+                        current_output_index = (
+                            tool_output_start_index + len(tool_call_states)
+                        )
+                    for event in self._finalize_tool_calls(tool_call_states):
+                        yield _increment_sequence_number_and_return(event)
+                    tool_output_start_index = None
+
                 if not first_delta_sent:
                     current_item_id = str(uuid.uuid4())
                     if delta_message.reasoning:
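Note: the per-call bookkeeping above keys `tool_call_states` by the parser's `DeltaToolCall.index`. A hypothetical snapshot mid-stream, to make the field roles concrete (all ids and values are made up):

```python
tool_call_states = {
    0: {
        "item_id": "fc_1a2b3c",    # Responses API output-item id
        "call_id": "call_9x8y7z",  # id the client echoes back with tool output
        "name": "get_current_weather",
        "arguments": '{"city": "Berlin", "country": "Ger',  # grows per delta
        "output_index": 1,         # slot of the function_call output item
    },
}
```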
@@ -1223,7 +1487,6 @@
                     )
                     current_content_index += 1
                 first_delta_sent = True
-                # todo(kebe7jun) tool call support
 
                 # check delta message and previous delta message are
                 # same as content or reasoning content
@@ -1339,7 +1602,11 @@
                     )
                     current_content_index += 1
 
-                previous_delta_messages.append(delta_message)
+                if (
+                    delta_message.reasoning is not None
+                    or delta_message.content is not None
+                ):
+                    previous_delta_messages.append(delta_message)
         if previous_delta_messages:
             if previous_delta_messages[-1].reasoning is not None:
                 reason_content = "".join(
@@ -1430,6 +1697,47 @@
                         item=item,
                     )
                 )
+        for event in self._finalize_tool_calls(tool_call_states):
+            yield _increment_sequence_number_and_return(event)
+        tool_output_start_index = None
+
+    @staticmethod
+    def _finalize_tool_calls(
+        tool_call_states: dict[int, dict[str, str | int]],
+    ) -> list[StreamingResponsesResponse]:
+        """Emit done events for all pending tool calls and clear their state."""
+        if not tool_call_states:
+            return []
+
+        events: list[StreamingResponsesResponse] = []
+        for _, tool_state in sorted(tool_call_states.items()):
+            events.append(
+                ResponseFunctionCallArgumentsDoneEvent(
+                    type="response.function_call_arguments.done",
+                    arguments=str(tool_state.get("arguments", "")),
+                    name=str(tool_state.get("name", "")),
+                    item_id=str(tool_state.get("item_id", "")),
+                    output_index=int(tool_state.get("output_index", 0)),
+                    sequence_number=-1,
+                )
+            )
+            events.append(
+                ResponseOutputItemDoneEvent(
+                    type="response.output_item.done",
+                    sequence_number=-1,
+                    output_index=int(tool_state.get("output_index", 0)),
+                    item=ResponseFunctionToolCall(
+                        type="function_call",
+                        id=str(tool_state.get("item_id", "")),
+                        call_id=str(tool_state.get("call_id", "")),
+                        name=str(tool_state.get("name", "")),
+                        arguments=str(tool_state.get("arguments", "")),
+                        status="completed",
+                    ),
+                )
+            )
+        tool_call_states.clear()
+        return events
 
     async def _process_harmony_streaming_events(
         self,
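Note: per pending call, `_finalize_tool_calls` pairs a `response.function_call_arguments.done` with a `response.output_item.done` carrying the completed `function_call` item. A hypothetical rendering of the pair for one finished call (field values are made up; sequence numbers are filled in later by `_increment_sequence_number_and_return`):

```python
finalize_events = [
    {
        "type": "response.function_call_arguments.done",
        "item_id": "fc_1a2b3c",
        "output_index": 1,
        "name": "get_current_weather",
        "arguments": '{"city": "Berlin", "country": "Germany", "unit": "fahrenheit"}',
    },
    {
        "type": "response.output_item.done",
        "output_index": 1,
        "item": {
            "type": "function_call",
            "id": "fc_1a2b3c",
            "call_id": "call_9x8y7z",
            "name": "get_current_weather",
            "arguments": '{"city": "Berlin", "country": "Germany", "unit": "fahrenheit"}',
            "status": "completed",
        },
    },
]
```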
@@ -1448,28 +1756,33 @@
         current_content_index = -1
         current_output_index = 0
         current_item_id: str = ""
+        current_call_id: str | None = None
         sent_output_item_added = False
         is_first_function_call_delta = False
 
         async for ctx in result_generator:
             assert isinstance(ctx, StreamingHarmonyContext)
             if ctx.is_expecting_start():
-                current_output_index += 1
                 sent_output_item_added = False
                 is_first_function_call_delta = False
-                if len(ctx.parser.messages) > 0:
-                    previous_item = ctx.parser.messages[-1]
+                previous_item = (
+                    ctx.parser.messages[-1] if len(ctx.parser.messages) > 0 else None
+                )
+                if previous_item is not None:
+                    finalized_output_index = current_output_index
                     if previous_item.recipient is not None:
                         # Deal with tool call
                         if previous_item.recipient.startswith("functions."):
                             function_name = previous_item.recipient[len("functions.") :]
+                            item_id = current_item_id or f"fc_{random_uuid()}"
+                            call_id = current_call_id or f"call_{random_uuid()}"
                             yield _increment_sequence_number_and_return(
                                 ResponseFunctionCallArgumentsDoneEvent(
                                     type="response.function_call_arguments.done",
                                     arguments=previous_item.content[0].text,
                                     name=function_name,
-                                    item_id=current_item_id,
-                                    output_index=current_output_index,
+                                    item_id=item_id,
+                                    output_index=finalized_output_index,
                                     sequence_number=-1,
                                 )
                             )
@@ -1477,20 +1790,19 @@
                                 type="function_call",
                                 arguments=previous_item.content[0].text,
                                 name=function_name,
-                                item_id=current_item_id,
-                                output_index=current_output_index,
-                                sequence_number=-1,
-                                call_id=f"fc_{random_uuid()}",
+                                id=item_id,
+                                call_id=call_id,
                                 status="completed",
                             )
                             yield _increment_sequence_number_and_return(
                                 ResponseOutputItemDoneEvent(
                                     type="response.output_item.done",
                                     sequence_number=-1,
-                                    output_index=current_output_index,
+                                    output_index=finalized_output_index,
                                     item=function_call_item,
                                 )
                             )
+                            current_call_id = None
                     elif previous_item.channel == "analysis":
                         content = ResponseReasoningTextContent(
                             text=previous_item.content[0].text,
@@ -1508,7 +1820,7 @@
                                     type="response.reasoning_text.done",
                                     item_id=current_item_id,
                                     sequence_number=-1,
-                                    output_index=current_output_index,
+                                    output_index=finalized_output_index,
                                     content_index=current_content_index,
                                     text=previous_item.content[0].text,
                                 )
@@ -1518,7 +1830,7 @@
                                     type="response.reasoning_part.done",
                                     sequence_number=-1,
                                     item_id=current_item_id,
-                                    output_index=current_output_index,
+                                    output_index=finalized_output_index,
                                     content_index=current_content_index,
                                     part=content,
                                 )
@@ -1527,7 +1839,7 @@
                                 ResponseOutputItemDoneEvent(
                                     type="response.output_item.done",
                                     sequence_number=-1,
-                                    output_index=current_output_index,
+                                    output_index=finalized_output_index,
                                     item=reasoning_item,
                                 )
                             )
@@ -1541,7 +1853,7 @@
                                 ResponseTextDoneEvent(
                                     type="response.output_text.done",
                                     sequence_number=-1,
-                                    output_index=current_output_index,
+                                    output_index=finalized_output_index,
                                     content_index=current_content_index,
                                     text=previous_item.content[0].text,
                                     logprobs=[],
@@ -1553,7 +1865,7 @@
                                     type="response.content_part.done",
                                     sequence_number=-1,
                                     item_id=current_item_id,
-                                    output_index=current_output_index,
+                                    output_index=finalized_output_index,
                                     content_index=current_content_index,
                                     part=text_content,
                                 )
@@ -1562,7 +1874,7 @@
                                 ResponseOutputItemDoneEvent(
                                     type="response.output_item.done",
                                     sequence_number=-1,
-                                    output_index=current_output_index,
+                                    output_index=finalized_output_index,
                                     item=ResponseOutputMessage(
                                         id=current_item_id,
                                         type="message",
@@ -1572,6 +1884,10 @@
                                     ),
                                 )
                             )
+                    current_output_index += 1
+                    current_content_index = -1
+                    current_item_id = ""
+                    current_call_id = None
 
             # stream the output of a harmony message
             if ctx.parser.last_content_delta:
@@ -1856,15 +2172,16 @@
                 if is_first_function_call_delta is False:
                     is_first_function_call_delta = True
                     fc_name = ctx.parser.current_recipient[len("functions.") :]
+                    current_item_id = f"fc_{random_uuid()}"
+                    current_call_id = f"call_{random_uuid()}"
                     tool_call_item = ResponseFunctionToolCall(
                         name=fc_name,
                         type="function_call",
                         id=current_item_id,
-                        call_id=f"call_{random_uuid()}",
+                        call_id=current_call_id,
                         arguments="",
                         status="in_progress",
                     )
-                    current_item_id = f"fc_{random_uuid()}"
                     yield _increment_sequence_number_and_return(
                         ResponseOutputItemAddedEvent(
                             type="response.output_item.added",
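Note: the diff is truncated above, mid `response.output_item.added` emission. For readers unfamiliar with the Harmony path it patches: tool calls are addressed through the message's `recipient` field, and the tool name is recovered by stripping the `functions.` prefix, exactly as the slicing `previous_item.recipient[len("functions.") :]` does. A standalone sketch of that convention (recipient strings are illustrative):

```python
def function_name_from_recipient(recipient: str | None) -> str | None:
    """Return the tool name for a `functions.*` recipient, else None."""
    prefix = "functions."
    if recipient and recipient.startswith(prefix):
        return recipient[len(prefix):]
    return None


assert function_name_from_recipient("functions.get_current_weather") == (
    "get_current_weather"
)
assert function_name_from_recipient("browser.search") is None
```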