diff --git a/src/agents/items.py b/src/agents/items.py index 24defb22d..991a7f877 100644 --- a/src/agents/items.py +++ b/src/agents/items.py @@ -1,7 +1,8 @@ from __future__ import annotations import abc -from dataclasses import dataclass +import weakref +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, Union, cast import pydantic @@ -72,6 +73,9 @@ T = TypeVar("T", bound=Union[TResponseOutputItem, TResponseInputItem]) +# Distinguish a missing dict entry from an explicit None value. +_MISSING_ATTR_SENTINEL = object() + @dataclass class RunItemBase(Generic[T], abc.ABC): @@ -84,6 +88,49 @@ class RunItemBase(Generic[T], abc.ABC): (i.e. `openai.types.responses.ResponseInputItemParam`). """ + _agent_ref: weakref.ReferenceType[Agent[Any]] | None = field( + init=False, + repr=False, + default=None, + ) + + def __post_init__(self) -> None: + # Store a weak reference so we can release the strong reference later if desired. + self._agent_ref = weakref.ref(self.agent) + + def __getattribute__(self, name: str) -> Any: + if name == "agent": + return self._get_agent_via_weakref("agent", "_agent_ref") + return super().__getattribute__(name) + + def release_agent(self) -> None: + """Release the strong reference to the agent while keeping a weak reference.""" + if "agent" not in self.__dict__: + return + agent = self.__dict__["agent"] + if agent is None: + return + self._agent_ref = weakref.ref(agent) if agent is not None else None + # Set to None instead of deleting so dataclass repr/asdict keep working. + self.__dict__["agent"] = None + + def _get_agent_via_weakref(self, attr_name: str, ref_name: str) -> Any: + # Preserve the dataclass field so repr/asdict still read it, but lazily resolve the weakref + # when the stored value is None (meaning release_agent already dropped the strong ref). + # If the attribute was never overridden we fall back to the default descriptor chain. + data = object.__getattribute__(self, "__dict__") + value = data.get(attr_name, _MISSING_ATTR_SENTINEL) + if value is _MISSING_ATTR_SENTINEL: + return object.__getattribute__(self, attr_name) + if value is not None: + return value + ref = object.__getattribute__(self, ref_name) + if ref is not None: + agent = ref() + if agent is not None: + return agent + return None + def to_input_item(self) -> TResponseInputItem: """Converts this item into an input item suitable for passing to the model.""" if isinstance(self.raw_item, dict): @@ -131,6 +178,48 @@ class HandoffOutputItem(RunItemBase[TResponseInputItem]): type: Literal["handoff_output_item"] = "handoff_output_item" + _source_agent_ref: weakref.ReferenceType[Agent[Any]] | None = field( + init=False, + repr=False, + default=None, + ) + _target_agent_ref: weakref.ReferenceType[Agent[Any]] | None = field( + init=False, + repr=False, + default=None, + ) + + def __post_init__(self) -> None: + super().__post_init__() + # Maintain weak references so downstream code can release the strong references when safe. + self._source_agent_ref = weakref.ref(self.source_agent) + self._target_agent_ref = weakref.ref(self.target_agent) + + def __getattribute__(self, name: str) -> Any: + if name == "source_agent": + # Provide lazy weakref access like the base `agent` field so HandoffOutputItem + # callers keep seeing the original agent until GC occurs. + return self._get_agent_via_weakref("source_agent", "_source_agent_ref") + if name == "target_agent": + # Same as above but for the target of the handoff. + return self._get_agent_via_weakref("target_agent", "_target_agent_ref") + return super().__getattribute__(name) + + def release_agent(self) -> None: + super().release_agent() + if "source_agent" in self.__dict__: + source_agent = self.__dict__["source_agent"] + if source_agent is not None: + self._source_agent_ref = weakref.ref(source_agent) + # Preserve dataclass fields for repr/asdict while dropping strong refs. + self.__dict__["source_agent"] = None + if "target_agent" in self.__dict__: + target_agent = self.__dict__["target_agent"] + if target_agent is not None: + self._target_agent_ref = weakref.ref(target_agent) + # Preserve dataclass fields for repr/asdict while dropping strong refs. + self.__dict__["target_agent"] = None + ToolCallItemTypes: TypeAlias = Union[ ResponseFunctionToolCall, diff --git a/src/agents/result.py b/src/agents/result.py index 3fe20cfa5..438d53af2 100644 --- a/src/agents/result.py +++ b/src/agents/result.py @@ -2,6 +2,7 @@ import abc import asyncio +import weakref from collections.abc import AsyncIterator from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Literal, cast @@ -74,6 +75,35 @@ class RunResultBase(abc.ABC): def last_agent(self) -> Agent[Any]: """The last agent that was run.""" + def release_agents(self, *, release_new_items: bool = True) -> None: + """ + Release strong references to agents held by this result. After calling this method, + accessing `item.agent` or `last_agent` may return `None` if the agent has been garbage + collected. Callers can use this when they are done inspecting the result and want to + eagerly drop any associated agent graph. + """ + if release_new_items: + for item in self.new_items: + release = getattr(item, "release_agent", None) + if callable(release): + release() + self._release_last_agent_reference() + + def __del__(self) -> None: + try: + # Fall back to releasing agents automatically in case the caller never invoked + # `release_agents()` explicitly so GC of the RunResult drops the last strong reference. + # We pass `release_new_items=False` so RunItems that the user intentionally keeps + # continue exposing their originating agent until that agent itself is collected. + self.release_agents(release_new_items=False) + except Exception: + # Avoid raising from __del__. + pass + + @abc.abstractmethod + def _release_last_agent_reference(self) -> None: + """Release stored agent reference specific to the concrete result type.""" + def final_output_as(self, cls: type[T], raise_if_incorrect_type: bool = False) -> T: """A convenience method to cast the final output to a specific type. By default, the cast is only for the typechecker. If you set `raise_if_incorrect_type` to True, we'll raise a @@ -111,11 +141,34 @@ def last_response_id(self) -> str | None: @dataclass class RunResult(RunResultBase): _last_agent: Agent[Any] + _last_agent_ref: weakref.ReferenceType[Agent[Any]] | None = field( + init=False, + repr=False, + default=None, + ) + + def __post_init__(self) -> None: + self._last_agent_ref = weakref.ref(self._last_agent) @property def last_agent(self) -> Agent[Any]: """The last agent that was run.""" - return self._last_agent + agent = cast("Agent[Any] | None", self.__dict__.get("_last_agent")) + if agent is not None: + return agent + if self._last_agent_ref: + agent = self._last_agent_ref() + if agent is not None: + return agent + raise AgentsException("Last agent reference is no longer available.") + + def _release_last_agent_reference(self) -> None: + agent = cast("Agent[Any] | None", self.__dict__.get("_last_agent")) + if agent is None: + return + self._last_agent_ref = weakref.ref(agent) + # Preserve dataclass field so repr/asdict continue to succeed. + self.__dict__["_last_agent"] = None def __str__(self) -> str: return pretty_print_result(self) @@ -150,6 +203,12 @@ class RunResultStreaming(RunResultBase): is_complete: bool = False """Whether the agent has finished running.""" + _current_agent_ref: weakref.ReferenceType[Agent[Any]] | None = field( + init=False, + repr=False, + default=None, + ) + # Queues that the background run_loop writes to _event_queue: asyncio.Queue[StreamEvent | QueueCompleteSentinel] = field( default_factory=asyncio.Queue, repr=False @@ -167,12 +226,30 @@ class RunResultStreaming(RunResultBase): # Soft cancel state _cancel_mode: Literal["none", "immediate", "after_turn"] = field(default="none", repr=False) + def __post_init__(self) -> None: + self._current_agent_ref = weakref.ref(self.current_agent) + @property def last_agent(self) -> Agent[Any]: """The last agent that was run. Updates as the agent run progresses, so the true last agent is only available after the agent run is complete. """ - return self.current_agent + agent = cast("Agent[Any] | None", self.__dict__.get("current_agent")) + if agent is not None: + return agent + if self._current_agent_ref: + agent = self._current_agent_ref() + if agent is not None: + return agent + raise AgentsException("Last agent reference is no longer available.") + + def _release_last_agent_reference(self) -> None: + agent = cast("Agent[Any] | None", self.__dict__.get("current_agent")) + if agent is None: + return + self._current_agent_ref = weakref.ref(agent) + # Preserve dataclass field so repr/asdict continue to succeed. + self.__dict__["current_agent"] = None def cancel(self, mode: Literal["immediate", "after_turn"] = "immediate") -> None: """Cancel the streaming run. diff --git a/src/agents/run.py b/src/agents/run.py index c14f13e3f..5ea1dbab1 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -648,51 +648,60 @@ async def run( tool_input_guardrail_results.extend(turn_result.tool_input_guardrail_results) tool_output_guardrail_results.extend(turn_result.tool_output_guardrail_results) - if isinstance(turn_result.next_step, NextStepFinalOutput): - output_guardrail_results = await self._run_output_guardrails( - current_agent.output_guardrails + (run_config.output_guardrails or []), - current_agent, - turn_result.next_step.output, - context_wrapper, - ) - result = RunResult( - input=original_input, - new_items=generated_items, - raw_responses=model_responses, - final_output=turn_result.next_step.output, - _last_agent=current_agent, - input_guardrail_results=input_guardrail_results, - output_guardrail_results=output_guardrail_results, - tool_input_guardrail_results=tool_input_guardrail_results, - tool_output_guardrail_results=tool_output_guardrail_results, - context_wrapper=context_wrapper, - ) - if not any( - guardrail_result.output.tripwire_triggered - for guardrail_result in input_guardrail_results - ): - await self._save_result_to_session( - session, [], turn_result.new_step_items + try: + if isinstance(turn_result.next_step, NextStepFinalOutput): + output_guardrail_results = await self._run_output_guardrails( + current_agent.output_guardrails + + (run_config.output_guardrails or []), + current_agent, + turn_result.next_step.output, + context_wrapper, + ) + result = RunResult( + input=original_input, + new_items=generated_items, + raw_responses=model_responses, + final_output=turn_result.next_step.output, + _last_agent=current_agent, + input_guardrail_results=input_guardrail_results, + output_guardrail_results=output_guardrail_results, + tool_input_guardrail_results=tool_input_guardrail_results, + tool_output_guardrail_results=tool_output_guardrail_results, + context_wrapper=context_wrapper, ) + if not any( + guardrail_result.output.tripwire_triggered + for guardrail_result in input_guardrail_results + ): + await self._save_result_to_session( + session, [], turn_result.new_step_items + ) - return result - elif isinstance(turn_result.next_step, NextStepHandoff): - current_agent = cast(Agent[TContext], turn_result.next_step.new_agent) - current_span.finish(reset_current=True) - current_span = None - should_run_agent_start_hooks = True - elif isinstance(turn_result.next_step, NextStepRunAgain): - if not any( - guardrail_result.output.tripwire_triggered - for guardrail_result in input_guardrail_results - ): - await self._save_result_to_session( - session, [], turn_result.new_step_items + return result + elif isinstance(turn_result.next_step, NextStepHandoff): + current_agent = cast(Agent[TContext], turn_result.next_step.new_agent) + current_span.finish(reset_current=True) + current_span = None + should_run_agent_start_hooks = True + elif isinstance(turn_result.next_step, NextStepRunAgain): + if not any( + guardrail_result.output.tripwire_triggered + for guardrail_result in input_guardrail_results + ): + await self._save_result_to_session( + session, [], turn_result.new_step_items + ) + else: + raise AgentsException( + f"Unknown next step type: {type(turn_result.next_step)}" ) - else: - raise AgentsException( - f"Unknown next step type: {type(turn_result.next_step)}" - ) + finally: + # RunImpl.execute_tools_and_side_effects returns a SingleStepResult that + # stores direct references to the `pre_step_items` and `new_step_items` + # lists it manages internally. Clear them here so the next turn does not + # hold on to items from previous turns and to avoid leaking agent refs. + turn_result.pre_step_items.clear() + turn_result.new_step_items.clear() except AgentsException as exc: exc.run_data = RunErrorDetails( input=original_input, diff --git a/tests/test_agent_memory_leak.py b/tests/test_agent_memory_leak.py new file mode 100644 index 000000000..424aa399d --- /dev/null +++ b/tests/test_agent_memory_leak.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +import gc +import weakref + +import pytest +from openai.types.responses import ResponseOutputMessage, ResponseOutputText + +from agents import Agent, Runner +from tests.fake_model import FakeModel + + +def _make_message(text: str) -> ResponseOutputMessage: + return ResponseOutputMessage( + id="msg-1", + content=[ResponseOutputText(annotations=[], text=text, type="output_text")], + role="assistant", + status="completed", + type="message", + ) + + +@pytest.mark.asyncio +async def test_agent_is_released_after_run() -> None: + fake_model = FakeModel(initial_output=[_make_message("Paris")]) + agent = Agent(name="leak-test-agent", instructions="Answer questions.", model=fake_model) + agent_ref = weakref.ref(agent) + + # Running the agent should not leave behind strong references once the result goes out of scope. + await Runner.run(agent, "What is the capital of France?") + + del agent + gc.collect() + + assert agent_ref() is None diff --git a/tests/test_items_helpers.py b/tests/test_items_helpers.py index 0bead32db..ad8da2266 100644 --- a/tests/test_items_helpers.py +++ b/tests/test_items_helpers.py @@ -1,6 +1,8 @@ from __future__ import annotations +import gc import json +import weakref from openai.types.responses.response_computer_tool_call import ( ActionScreenshot, @@ -29,6 +31,7 @@ from agents import ( Agent, + HandoffOutputItem, ItemHelpers, MessageOutputItem, ModelResponse, @@ -148,6 +151,64 @@ def test_text_message_outputs_across_list_of_runitems() -> None: assert ItemHelpers.text_message_outputs([item1, non_message_item, item2]) == "foobar" +def test_message_output_item_retains_agent_until_release() -> None: + # Construct the run item with an inline agent to ensure the run item keeps a strong reference. + message = make_message([ResponseOutputText(annotations=[], text="hello", type="output_text")]) + agent = Agent(name="inline") + item = MessageOutputItem(agent=agent, raw_item=message) + assert item.agent is agent + assert item.agent.name == "inline" + + # Releasing the agent should keep the weak reference alive while strong refs remain. + item.release_agent() + assert item.agent is agent + + agent_ref = weakref.ref(agent) + del agent + gc.collect() + + # Once the original agent is collected, the weak reference should drop. + assert agent_ref() is None + assert item.agent is None + + +def test_handoff_output_item_retains_agents_until_gc() -> None: + raw_item: TResponseInputItem = { + "call_id": "call1", + "output": "handoff", + "type": "function_call_output", + } + owner_agent = Agent(name="owner") + source_agent = Agent(name="source") + target_agent = Agent(name="target") + item = HandoffOutputItem( + agent=owner_agent, + raw_item=raw_item, + source_agent=source_agent, + target_agent=target_agent, + ) + + item.release_agent() + assert item.agent is owner_agent + assert item.source_agent is source_agent + assert item.target_agent is target_agent + + owner_ref = weakref.ref(owner_agent) + source_ref = weakref.ref(source_agent) + target_ref = weakref.ref(target_agent) + del owner_agent + del source_agent + del target_agent + gc.collect() + + assert owner_ref() is None + assert source_ref() is None + assert target_ref() is None + assert item.agent is None + assert item.source_agent is None + assert item.target_agent is None + + def test_tool_call_output_item_constructs_function_call_output_dict(): # Build a simple ResponseFunctionToolCall. call = ResponseFunctionToolCall( diff --git a/tests/test_result_cast.py b/tests/test_result_cast.py index 4ef1a293d..e919171ae 100644 --- a/tests/test_result_cast.py +++ b/tests/test_result_cast.py @@ -1,9 +1,14 @@ +import dataclasses +import gc +import weakref from typing import Any import pytest +from openai.types.responses import ResponseOutputMessage, ResponseOutputText from pydantic import BaseModel -from agents import Agent, RunContextWrapper, RunResult +from agents import Agent, MessageOutputItem, RunContextWrapper, RunResult, RunResultStreaming +from agents.exceptions import AgentsException def create_run_result(final_output: Any) -> RunResult: @@ -25,6 +30,16 @@ class Foo(BaseModel): bar: int +def _create_message(text: str) -> ResponseOutputMessage: + return ResponseOutputMessage( + id="msg", + content=[ResponseOutputText(annotations=[], text=text, type="output_text")], + role="assistant", + status="completed", + type="message", + ) + + def test_result_cast_typechecks(): """Correct casts should work fine.""" result = create_run_result(1) @@ -59,3 +74,203 @@ def test_bad_cast_with_param_raises(): result = create_run_result(Foo(bar=1)) with pytest.raises(TypeError): result.final_output_as(int, raise_if_incorrect_type=True) + + +def test_run_result_release_agents_breaks_strong_refs() -> None: + message = _create_message("hello") + agent = Agent(name="leak-test-agent") + item = MessageOutputItem(agent=agent, raw_item=message) + result = RunResult( + input="test", + new_items=[item], + raw_responses=[], + final_output=None, + input_guardrail_results=[], + output_guardrail_results=[], + tool_input_guardrail_results=[], + tool_output_guardrail_results=[], + _last_agent=agent, + context_wrapper=RunContextWrapper(context=None), + ) + assert item.agent is not None + assert item.agent.name == "leak-test-agent" + + agent_ref = weakref.ref(agent) + result.release_agents() + del agent + gc.collect() + + assert agent_ref() is None + assert item.agent is None + with pytest.raises(AgentsException): + _ = result.last_agent + + +def test_run_item_retains_agent_when_result_is_garbage_collected() -> None: + def build_item() -> tuple[MessageOutputItem, weakref.ReferenceType[RunResult]]: + message = _create_message("persist") + agent = Agent(name="persisted-agent") + item = MessageOutputItem(agent=agent, raw_item=message) + result = RunResult( + input="test", + new_items=[item], + raw_responses=[], + final_output=None, + input_guardrail_results=[], + output_guardrail_results=[], + tool_input_guardrail_results=[], + tool_output_guardrail_results=[], + _last_agent=agent, + context_wrapper=RunContextWrapper(context=None), + ) + return item, weakref.ref(result) + + item, result_ref = build_item() + gc.collect() + + assert result_ref() is None + assert item.agent is not None + assert item.agent.name == "persisted-agent" + + +def test_run_item_repr_and_asdict_after_release() -> None: + message = _create_message("repr") + agent = Agent(name="repr-agent") + item = MessageOutputItem(agent=agent, raw_item=message) + + item.release_agent() + assert item.agent is agent + + text = repr(item) + assert "MessageOutputItem" in text + + serialized = dataclasses.asdict(item) + assert isinstance(serialized["agent"], dict) + assert serialized["agent"]["name"] == "repr-agent" + + agent_ref = weakref.ref(agent) + del agent + gc.collect() + + assert agent_ref() is None + assert item.agent is None + + serialized_after_gc = dataclasses.asdict(item) + assert serialized_after_gc["agent"] is None + + +def test_run_result_repr_and_asdict_after_release_agents() -> None: + agent = Agent(name="repr-result-agent") + result = RunResult( + input="test", + new_items=[], + raw_responses=[], + final_output=None, + input_guardrail_results=[], + output_guardrail_results=[], + tool_input_guardrail_results=[], + tool_output_guardrail_results=[], + _last_agent=agent, + context_wrapper=RunContextWrapper(context=None), + ) + + result.release_agents() + + text = repr(result) + assert "RunResult" in text + + serialized = dataclasses.asdict(result) + assert serialized["_last_agent"] is None + + +def test_run_result_release_agents_without_releasing_new_items() -> None: + message = _create_message("keep") + item_agent = Agent(name="item-agent") + last_agent = Agent(name="last-agent") + item = MessageOutputItem(agent=item_agent, raw_item=message) + result = RunResult( + input="test", + new_items=[item], + raw_responses=[], + final_output=None, + input_guardrail_results=[], + output_guardrail_results=[], + tool_input_guardrail_results=[], + tool_output_guardrail_results=[], + _last_agent=last_agent, + context_wrapper=RunContextWrapper(context=None), + ) + + result.release_agents(release_new_items=False) + + assert item.agent is item_agent + + last_agent_ref = weakref.ref(last_agent) + del last_agent + gc.collect() + + assert last_agent_ref() is None + with pytest.raises(AgentsException): + _ = result.last_agent + + +def test_run_result_release_agents_is_idempotent() -> None: + message = _create_message("idempotent") + agent = Agent(name="idempotent-agent") + item = MessageOutputItem(agent=agent, raw_item=message) + result = RunResult( + input="test", + new_items=[item], + raw_responses=[], + final_output=None, + input_guardrail_results=[], + output_guardrail_results=[], + tool_input_guardrail_results=[], + tool_output_guardrail_results=[], + _last_agent=agent, + context_wrapper=RunContextWrapper(context=None), + ) + + result.release_agents() + result.release_agents() + + assert item.agent is agent + + agent_ref = weakref.ref(agent) + del agent + gc.collect() + + assert agent_ref() is None + assert item.agent is None + with pytest.raises(AgentsException): + _ = result.last_agent + + +def test_run_result_streaming_release_agents_releases_current_agent() -> None: + agent = Agent(name="streaming-agent") + streaming_result = RunResultStreaming( + input="stream", + new_items=[], + raw_responses=[], + final_output=None, + input_guardrail_results=[], + output_guardrail_results=[], + tool_input_guardrail_results=[], + tool_output_guardrail_results=[], + context_wrapper=RunContextWrapper(context=None), + current_agent=agent, + current_turn=0, + max_turns=1, + _current_agent_output_schema=None, + trace=None, + ) + + streaming_result.release_agents(release_new_items=False) + + agent_ref = weakref.ref(agent) + del agent + gc.collect() + + assert agent_ref() is None + with pytest.raises(AgentsException): + _ = streaming_result.last_agent