Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion src/agents/items.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

import abc
from dataclasses import dataclass
import weakref
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, Union

import pydantic
Expand Down Expand Up @@ -84,6 +85,22 @@ class RunItemBase(Generic[T], abc.ABC):
(i.e. `openai.types.responses.ResponseInputItemParam`).
"""

_agent_ref: weakref.ReferenceType[Agent[Any]] | None = field(
init=False,
repr=False,
default=None,
)

def __post_init__(self) -> None:
# Store the producing agent weakly to avoid keeping it alive after the run.
self._agent_ref = weakref.ref(self.agent)
object.__delattr__(self, "agent")

def __getattr__(self, name: str) -> Any:
if name == "agent":
return self._agent_ref() if self._agent_ref else None
raise AttributeError(name)

def to_input_item(self) -> TResponseInputItem:
"""Converts this item into an input item suitable for passing to the model."""
if isinstance(self.raw_item, dict):
Expand Down Expand Up @@ -131,6 +148,32 @@ class HandoffOutputItem(RunItemBase[TResponseInputItem]):

type: Literal["handoff_output_item"] = "handoff_output_item"

_source_agent_ref: weakref.ReferenceType[Agent[Any]] | None = field(
init=False,
repr=False,
default=None,
)
_target_agent_ref: weakref.ReferenceType[Agent[Any]] | None = field(
init=False,
repr=False,
default=None,
)

def __post_init__(self) -> None:
super().__post_init__()
# Handoff metadata should not hold strong references to the agents either.
self._source_agent_ref = weakref.ref(self.source_agent)
self._target_agent_ref = weakref.ref(self.target_agent)
object.__delattr__(self, "source_agent")
object.__delattr__(self, "target_agent")

def __getattr__(self, name: str) -> Any:
if name == "source_agent":
return self._source_agent_ref() if self._source_agent_ref else None
if name == "target_agent":
return self._target_agent_ref() if self._target_agent_ref else None
return super().__getattr__(name)


ToolCallItemTypes: TypeAlias = Union[
ResponseFunctionToolCall,
Expand Down
34 changes: 34 additions & 0 deletions tests/test_agent_memory_leak.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from __future__ import annotations

import gc
import weakref

import pytest
from openai.types.responses import ResponseOutputMessage, ResponseOutputText

from agents import Agent, Runner
from tests.fake_model import FakeModel


def _make_message(text: str) -> ResponseOutputMessage:
return ResponseOutputMessage(
id="msg-1",
content=[ResponseOutputText(annotations=[], text=text, type="output_text")],
role="assistant",
status="completed",
type="message",
)


@pytest.mark.asyncio
async def test_agent_is_released_after_run() -> None:
fake_model = FakeModel(initial_output=[_make_message("Paris")])
agent = Agent(name="leaker", instructions="Answer questions.", model=fake_model)
agent_ref = weakref.ref(agent)

await Runner.run(agent, "What is the capital of France?")

del agent
gc.collect()

assert agent_ref() is None