diff --git a/examples/basic/stream_function_call_args.py b/examples/basic/stream_function_call_args.py
index 46e72896c..3c3538772 100644
--- a/examples/basic/stream_function_call_args.py
+++ b/examples/basic/stream_function_call_args.py
@@ -35,7 +35,7 @@ async def main():
 
     result = Runner.run_streamed(
         agent,
-        input="Create a Python web project called 'my-app' with FastAPI. Version 1.0.0, dependencies: fastapi, uvicorn"
+        input="Create a Python web project called 'my-app' with FastAPI. Version 1.0.0, dependencies: fastapi, uvicorn",
     )
 
     # Track function calls for detailed output
@@ -50,10 +50,7 @@ async def main():
                 function_name = getattr(event.data.item, "name", "unknown")
                 call_id = getattr(event.data.item, "call_id", "unknown")
 
-                function_calls[call_id] = {
-                    'name': function_name,
-                    'arguments': ""
-                }
+                function_calls[call_id] = {"name": function_name, "arguments": ""}
                 current_active_call_id = call_id
                 print(f"\n📞 Function call streaming started: {function_name}()")
                 print("📝 Arguments building...")
@@ -61,12 +58,12 @@ async def main():
             # Real-time argument streaming
             elif isinstance(event.data, ResponseFunctionCallArgumentsDeltaEvent):
                 if current_active_call_id and current_active_call_id in function_calls:
-                    function_calls[current_active_call_id]['arguments'] += event.data.delta
+                    function_calls[current_active_call_id]["arguments"] += event.data.delta
                     print(event.data.delta, end="", flush=True)
 
             # Function call completed
             elif event.data.type == "response.output_item.done":
-                if hasattr(event.data.item, 'call_id'):
+                if hasattr(event.data.item, "call_id"):
                     call_id = getattr(event.data.item, "call_id", "unknown")
                     if call_id in function_calls:
                         function_info = function_calls[call_id]
diff --git a/examples/customer_service/main.py b/examples/customer_service/main.py
index 8ed218536..266a7e611 100644
--- a/examples/customer_service/main.py
+++ b/examples/customer_service/main.py
@@ -40,7 +40,10 @@ class AirlineAgentContext(BaseModel):
 )
 async def faq_lookup_tool(question: str) -> str:
     question_lower = question.lower()
-    if any(keyword in question_lower for keyword in ["bag", "baggage", "luggage", "carry-on", "hand luggage", "hand carry"]):
+    if any(
+        keyword in question_lower
+        for keyword in ["bag", "baggage", "luggage", "carry-on", "hand luggage", "hand carry"]
+    ):
         return (
             "You are allowed to bring one bag on the plane. "
             "It must be under 50 pounds and 22 inches x 14 inches x 9 inches."
         )
@@ -52,7 +55,10 @@ async def faq_lookup_tool(question: str) -> str:
             "Exit rows are rows 4 and 16. "
             "Rows 5-8 are Economy Plus, with extra legroom. "
         )
-    elif any(keyword in question_lower for keyword in ["wifi", "internet", "wireless", "connectivity", "network", "online"]):
+    elif any(
+        keyword in question_lower
+        for keyword in ["wifi", "internet", "wireless", "connectivity", "network", "online"]
+    ):
         return "We have free wifi on the plane, join Airline-Wifi"
     return "I'm sorry, I don't know the answer to that question."
 
diff --git a/examples/realtime/app/server.py b/examples/realtime/app/server.py
index 73fcf3e56..04f3def43 100644
--- a/examples/realtime/app/server.py
+++ b/examples/realtime/app/server.py
@@ -4,11 +4,12 @@
 import logging
 import struct
 from contextlib import asynccontextmanager
-from typing import TYPE_CHECKING, Any, assert_never
+from typing import TYPE_CHECKING, Any
 
 from fastapi import FastAPI, WebSocket, WebSocketDisconnect
 from fastapi.responses import FileResponse
 from fastapi.staticfiles import StaticFiles
+from typing_extensions import assert_never
 
 from agents.realtime import RealtimeRunner, RealtimeSession, RealtimeSessionEvent
 
diff --git a/src/agents/handoffs.py b/src/agents/handoffs.py
index 4d70f6058..2c52737ad 100644
--- a/src/agents/handoffs.py
+++ b/src/agents/handoffs.py
@@ -119,9 +119,9 @@ class Handoff(Generic[TContext, TAgent]):
     True, as it increases the likelihood of correct JSON input.
     """
 
-    is_enabled: bool | Callable[
-        [RunContextWrapper[Any], AgentBase[Any]], MaybeAwaitable[bool]
-    ] = True
+    is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase[Any]], MaybeAwaitable[bool]] = (
+        True
+    )
     """Whether the handoff is enabled. Either a bool or a Callable that takes the run context and
     agent and returns whether the handoff is enabled. You can use this to dynamically enable/disable
     a handoff based on your context/state."""
diff --git a/src/agents/model_settings.py b/src/agents/model_settings.py
index 267f320c1..47161dd18 100644
--- a/src/agents/model_settings.py
+++ b/src/agents/model_settings.py
@@ -55,7 +55,6 @@ class MCPToolChoice:
 ToolChoice: TypeAlias = Union[Literal["auto", "required", "none"], str, MCPToolChoice, None]
 
 
-
 @dataclass
 class ModelSettings:
     """Settings to use when calling an LLM.
diff --git a/src/agents/run.py b/src/agents/run.py
index 3945e5131..cefdd12ee 100644
--- a/src/agents/run.py
+++ b/src/agents/run.py
@@ -3,9 +3,12 @@
 import asyncio
 import inspect
 from dataclasses import dataclass, field
-from typing import Any, Callable, Generic, cast
+from typing import Any, Callable, Generic, cast, get_args
 
-from openai.types.responses import ResponseCompletedEvent
+from openai.types.responses import (
+    ResponseCompletedEvent,
+    ResponseOutputItemAddedEvent,
+)
 from openai.types.responses.response_prompt_param import (
     ResponsePromptParam,
 )
@@ -40,7 +43,14 @@
     OutputGuardrailResult,
 )
 from .handoffs import Handoff, HandoffInputFilter, handoff
-from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem
+from .items import (
+    ItemHelpers,
+    ModelResponse,
+    RunItem,
+    ToolCallItem,
+    ToolCallItemTypes,
+    TResponseInputItem,
+)
 from .lifecycle import RunHooks
 from .logger import logger
 from .memory import Session
@@ -49,7 +59,7 @@
 from .models.multi_provider import MultiProvider
 from .result import RunResult, RunResultStreaming
 from .run_context import RunContextWrapper, TContext
-from .stream_events import AgentUpdatedStreamEvent, RawResponsesStreamEvent
+from .stream_events import AgentUpdatedStreamEvent, RawResponsesStreamEvent, RunItemStreamEvent
 from .tool import Tool
 from .tracing import Span, SpanError, agent_span, get_current_trace, trace
 from .tracing.span_data import AgentSpanData
@@ -905,6 +915,8 @@ async def _run_single_turn_streamed(
         all_tools: list[Tool],
         previous_response_id: str | None,
     ) -> SingleStepResult:
+        emitted_tool_call_ids: set[str] = set()
+
         if should_run_agent_start_hooks:
             await asyncio.gather(
                 hooks.on_agent_start(context_wrapper, agent),
@@ -977,6 +989,25 @@
                 )
 
                 context_wrapper.usage.add(usage)
+            if isinstance(event, ResponseOutputItemAddedEvent):
+                output_item = event.item
+
+                if isinstance(output_item, _TOOL_CALL_TYPES):
+                    call_id: str | None = getattr(
+                        output_item, "call_id", getattr(output_item, "id", None)
+                    )
+
+                    if call_id and call_id not in emitted_tool_call_ids:
+                        emitted_tool_call_ids.add(call_id)
+
+                        tool_item = ToolCallItem(
+                            raw_item=cast(ToolCallItemTypes, output_item),
+                            agent=agent,
+                        )
+                        streamed_result._event_queue.put_nowait(
+                            RunItemStreamEvent(item=tool_item, name="tool_called")
+                        )
+
             streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event))
 
             # 2. At this point, the streaming is complete for this turn of the agent loop.
@@ -984,9 +1015,10 @@
             raise ModelBehaviorError("Model did not produce a final response!")
 
         # 3. Now, we can process the turn as we do in the non-streaming case
-        return await cls._get_single_step_result_from_streamed_response(
+        single_step_result = await cls._get_single_step_result_from_response(
             agent=agent,
-            streamed_result=streamed_result,
+            original_input=streamed_result.input,
+            pre_step_items=streamed_result.new_items,
             new_response=final_response,
             output_schema=output_schema,
             all_tools=all_tools,
@@ -997,6 +1029,34 @@
             tool_use_tracker=tool_use_tracker,
         )
 
+        if emitted_tool_call_ids:
+            import dataclasses as _dc
+
+            filtered_items = [
+                item
+                for item in single_step_result.new_step_items
+                if not (
+                    isinstance(item, ToolCallItem)
+                    and (
+                        call_id := getattr(
+                            item.raw_item, "call_id", getattr(item.raw_item, "id", None)
+                        )
+                    )
+                    and call_id in emitted_tool_call_ids
+                )
+            ]
+
+            single_step_result_filtered = _dc.replace(
+                single_step_result, new_step_items=filtered_items
+            )
+
+            RunImpl.stream_step_result_to_queue(
+                single_step_result_filtered, streamed_result._event_queue
+            )
+        else:
+            RunImpl.stream_step_result_to_queue(single_step_result, streamed_result._event_queue)
+        return single_step_result
+
     @classmethod
     async def _run_single_turn(
         cls,
@@ -1375,9 +1435,11 @@ async def _save_result_to_session(
 
 
 DEFAULT_AGENT_RUNNER = AgentRunner()
+_TOOL_CALL_TYPES: tuple[type, ...] = get_args(ToolCallItemTypes)
 
 
 def _copy_str_or_list(input: str | list[TResponseInputItem]) -> str | list[TResponseInputItem]:
     if isinstance(input, str):
         return input
     return input.copy()
+
diff --git a/src/agents/tracing/processors.py b/src/agents/tracing/processors.py
index 32fd290ec..126c71498 100644
--- a/src/agents/tracing/processors.py
+++ b/src/agents/tracing/processors.py
@@ -70,8 +70,8 @@ def set_api_key(self, api_key: str):
         client.
         """
         # Clear the cached property if it exists
-        if 'api_key' in self.__dict__:
-            del self.__dict__['api_key']
+        if "api_key" in self.__dict__:
+            del self.__dict__["api_key"]
 
         # Update the private attribute
         self._api_key = api_key
diff --git a/tests/test_agent_clone_shallow_copy.py b/tests/test_agent_clone_shallow_copy.py
index fdf9e0247..44b41bd3d 100644
--- a/tests/test_agent_clone_shallow_copy.py
+++ b/tests/test_agent_clone_shallow_copy.py
@@ -5,6 +5,7 @@
 def greet(name: str) -> str:
     return f"Hello, {name}!"
 
+
 def test_agent_clone_shallow_copy():
     """Test that clone creates shallow copy with tools.copy() workaround"""
     target_agent = Agent(name="Target")
@@ -16,9 +17,7 @@ def test_agent_clone_shallow_copy():
     )
 
     cloned = original.clone(
-        name="Cloned",
-        tools=original.tools.copy(),
-        handoffs=original.handoffs.copy()
+        name="Cloned", tools=original.tools.copy(), handoffs=original.handoffs.copy()
     )
 
     # Basic assertions
diff --git a/tests/test_stream_events.py b/tests/test_stream_events.py
index 11feb9fe0..0f85b63f8 100644
--- a/tests/test_stream_events.py
+++ b/tests/test_stream_events.py
@@ -14,6 +14,7 @@ async def foo() -> str:
     await asyncio.sleep(3)
     return "success!"
 
+
 @pytest.mark.asyncio
 async def test_stream_events_main():
     model = FakeModel()
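
Note: the src/agents/run.py change above makes Runner.run_streamed emit the tool_called run-item event as soon as the model begins a tool call, instead of only after the turn finishes. Below is a minimal consumer sketch against the public streaming API; the agent name, instructions, and the get_weather tool are hypothetical and exist only to trigger a tool call.

import asyncio

from agents import Agent, Runner, function_tool
from agents.stream_events import RunItemStreamEvent


@function_tool
def get_weather(city: str) -> str:
    # Hypothetical tool; it exists only so the model emits a tool call.
    return f"It is sunny in {city}."


async def main() -> None:
    agent = Agent(
        name="Assistant",
        instructions="Use the get_weather tool when asked about the weather.",
        tools=[get_weather],
    )
    result = Runner.run_streamed(agent, input="What's the weather in Tokyo?")

    async for event in result.stream_events():
        # With the run.py change above, this event arrives as soon as the model
        # starts the tool call, while its arguments are still streaming.
        if isinstance(event, RunItemStreamEvent) and event.name == "tool_called":
            print("tool call streamed:", type(event.item.raw_item).__name__)

    print("final output:", result.final_output)


if __name__ == "__main__":
    asyncio.run(main())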