diff --git a/src/agents/items.py b/src/agents/items.py index c43e9f856..6ceaa52a8 100644 --- a/src/agents/items.py +++ b/src/agents/items.py @@ -1,7 +1,6 @@ from __future__ import annotations import abc -import copy from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, Union @@ -41,6 +40,7 @@ from .exceptions import AgentsException, ModelBehaviorError from .usage import Usage +from .util._safe_copy import safe_copy if TYPE_CHECKING: from .agent import Agent @@ -277,7 +277,7 @@ def input_to_new_input_list( "role": "user", } ] - return copy.deepcopy(input) + return safe_copy(input) @classmethod def text_message_outputs(cls, items: list[RunItem]) -> str: diff --git a/src/agents/run.py b/src/agents/run.py index 5f9ec10ac..82ced8f32 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -1,7 +1,6 @@ from __future__ import annotations import asyncio -import copy import inspect from dataclasses import dataclass, field from typing import Any, Callable, Generic, cast @@ -56,6 +55,7 @@ from .tracing.span_data import AgentSpanData from .usage import Usage from .util import _coro, _error_tracing +from .util._safe_copy import safe_copy from .util._types import MaybeAwaitable DEFAULT_MAX_TURNS = 10 @@ -387,7 +387,7 @@ async def run( disabled=run_config.tracing_disabled, ): current_turn = 0 - original_input: str | list[TResponseInputItem] = copy.deepcopy(prepared_input) + original_input: str | list[TResponseInputItem] = safe_copy(prepared_input) generated_items: list[RunItem] = [] model_responses: list[ModelResponse] = [] @@ -446,7 +446,7 @@ async def run( starting_agent, starting_agent.input_guardrails + (run_config.input_guardrails or []), - copy.deepcopy(prepared_input), + safe_copy(prepared_input), context_wrapper, ), self._run_single_turn( @@ -594,7 +594,7 @@ def run_streamed( ) streamed_result = RunResultStreaming( - input=copy.deepcopy(input), + input=safe_copy(input), new_items=[], current_agent=starting_agent, 
raw_responses=[], @@ -647,7 +647,7 @@ async def _maybe_filter_model_input( try: model_input = ModelInputData( - input=copy.deepcopy(effective_input), + input=safe_copy(effective_input), instructions=effective_instructions, ) filter_payload: CallModelData[TContext] = CallModelData( @@ -786,7 +786,7 @@ async def _start_streaming( cls._run_input_guardrails_with_queue( starting_agent, starting_agent.input_guardrails + (run_config.input_guardrails or []), - copy.deepcopy(ItemHelpers.input_to_new_input_list(prepared_input)), + safe_copy(ItemHelpers.input_to_new_input_list(prepared_input)), context_wrapper, streamed_result, current_span, diff --git a/src/agents/util/_safe_copy.py b/src/agents/util/_safe_copy.py new file mode 100644 index 000000000..9dba92585 --- /dev/null +++ b/src/agents/util/_safe_copy.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from typing import Any, TypeVar + +T = TypeVar("T") + + +def safe_copy(obj: T) -> T: + """ + Create a copy of the given object -- it can be either str or list/set/tuple of objects. + This avoids failures like: + TypeError: cannot pickle '...ValidatorIterator' object + because we never call deepcopy() on non-trivial objects. 
+ """ + return _safe_copy_internal(obj) + + +def _safe_copy_internal(obj: T) -> T: + if isinstance(obj, list): + new_list: list[Any] = [] + new_list.extend(_safe_copy_internal(x) for x in obj) + return new_list # type: ignore [return-value] + + if isinstance(obj, tuple): + new_tuple = tuple(_safe_copy_internal(x) for x in obj) + return new_tuple # type: ignore [return-value] + + if isinstance(obj, set): + new_set: set[Any] = set() + for x in obj: + new_set.add(_safe_copy_internal(x)) + return new_set # type: ignore [return-value] + + if isinstance(obj, frozenset): + new_fset = frozenset(_safe_copy_internal(x) for x in obj) + return new_fset # type: ignore + + return obj diff --git a/tests/test_items_helpers.py b/tests/test_items_helpers.py index f711f21e1..a770d6573 100644 --- a/tests/test_items_helpers.py +++ b/tests/test_items_helpers.py @@ -1,5 +1,7 @@ from __future__ import annotations +import json + from openai.types.responses.response_computer_tool_call import ( ActionScreenshot, ResponseComputerToolCall, @@ -20,8 +22,10 @@ from openai.types.responses.response_output_message_param import ResponseOutputMessageParam from openai.types.responses.response_output_refusal import ResponseOutputRefusal from openai.types.responses.response_output_text import ResponseOutputText +from openai.types.responses.response_output_text_param import ResponseOutputTextParam from openai.types.responses.response_reasoning_item import ResponseReasoningItem, Summary from openai.types.responses.response_reasoning_item_param import ResponseReasoningItemParam +from pydantic import TypeAdapter from agents import ( Agent, @@ -109,6 +113,37 @@ def test_input_to_new_input_list_deep_copies_lists() -> None: assert "content" in original[0] and original[0].get("content") == "abc" +def test_input_to_new_input_list_copies_the_ones_produced_by_pydantic() -> None: + # Given a message item validated by pydantic, ensure it can be copied into a new input list. 
+ original = ResponseOutputMessageParam( + id="a75654dc-7492-4d1c-bce0-89e8312fbdd7", + content=[ + ResponseOutputTextParam( + type="output_text", + text="Hey, what's up?", + annotations=[], + ) + ], + role="assistant", + status="completed", + type="message", + ) + original_json = json.dumps(original) + output_item = TypeAdapter(ResponseOutputMessageParam).validate_json(original_json) + new_list = ItemHelpers.input_to_new_input_list([output_item]) + assert len(new_list) == 1 + assert new_list[0]["id"] == original["id"] # type: ignore + size = 0 + for i, item in enumerate(new_list[0]["content"]): # type: ignore + size += 1 # pydantic_core._pydantic_core.ValidatorIterator does not support len() + assert item["type"] == original["content"][i]["type"] # type: ignore + assert item["text"] == original["content"][i]["text"] # type: ignore + assert size == 1 + assert new_list[0]["role"] == original["role"] # type: ignore + assert new_list[0]["status"] == original["status"] # type: ignore + assert new_list[0]["type"] == original["type"] + + def test_text_message_output_concatenates_text_segments() -> None: # Build a message with both text and refusal segments, only text segments are concatenated. 
pieces: list[ResponseOutputText | ResponseOutputRefusal] = [] diff --git a/tests/utils/test_safe_copy.py b/tests/utils/test_safe_copy.py new file mode 100644 index 000000000..1de4281c2 --- /dev/null +++ b/tests/utils/test_safe_copy.py @@ -0,0 +1,137 @@ +# tests/utils/test_safe_copy.py +import datetime as dt +import io +from decimal import Decimal +from fractions import Fraction +from uuid import UUID + +import pytest + +from agents.util._safe_copy import safe_copy + + +class BoomDeepcopy: + """Raises on deepcopy, but shallow copy is fine.""" + + def __init__(self, x=0): + self.x = x + + def __deepcopy__(self, memo): + raise TypeError("no deepcopy") + + def __copy__(self): + # canonical shallow behavior: return self (mutable identity preserved) + return self + + +class NoCopyEither: + """Raises on shallow copy; our safe_copy should return original object.""" + + def __copy__(self): + raise TypeError("no shallow copy") + + def __deepcopy__(self, memo): + raise TypeError("no deepcopy") + + +@pytest.mark.parametrize( + "value", + [ + None, + True, + 123, + 3.14, + complex(1, 2), + "hello", + b"bytes", + Decimal("1.23"), + Fraction(3, 7), + UUID(int=1), + dt.date(2020, 1, 2), + dt.datetime(2020, 1, 2, 3, 4, 5), + dt.time(12, 34, 56), + dt.timedelta(days=2), + range(5), + ], +) +def test_simple_atoms_roundtrip(value): + cpy = safe_copy(value) + assert cpy == value + + +def test_generator_is_preserved_and_not_consumed(): + gen = (i for i in range(3)) + data = {"g": gen} + cpy = safe_copy(data) + + # generator object is reused (no deepcopy attempt) + assert cpy["g"] is gen + + # ensure it hasn't been consumed by copying + assert next(gen) == 0 + assert next(gen) == 1 + + +def test_file_like_object_is_not_deepcopied(): + f = io.StringIO("hello") + data = {"f": f} + cpy = safe_copy(data) + assert cpy["f"] is f # shallow reuse + + +def test_frozenset_and_set_handling(): + class Marker: + pass + + m = Marker() + s = {1, 2, 3, m} + fs = frozenset({1, 2, 3, m}) + + s2 = safe_copy(s) + 
fs2 = safe_copy(fs) + + # containers are rebuilt + assert s2 is not s + assert fs2 is not fs + + # primitive members equal, complex leaf identity preserved + assert 1 in s2 and 1 in fs2 + assert any(x is m for x in s2) + assert any(x is m for x in fs2) + + # mutating original set doesn't affect the copy + s.add(99) + assert 99 not in s2 + + +def test_object_where_deepcopy_would_fail_is_handled_via_shallow_copy(): + b = BoomDeepcopy(7) + c = safe_copy(b) + # safe_copy returns non-container objects as-is, so the same instance comes back + assert c is b + assert c.x == 7 + + +def test_object_where_shallow_copy_also_fails_returns_original(): + o = NoCopyEither() + c = safe_copy(o) + # non-container objects are returned unchanged, without raising + assert c is o + + +def test_tuple_container_is_rebuilt_and_nested_behavior_respected(): + class Box: + def __init__(self, v): + self.v = v + + box = Box(1) + orig = (1, [2, 3], box) + cpy = safe_copy(orig) + + assert cpy is not orig + assert cpy[0] == 1 + assert cpy[1] is not orig[1] # list rebuilt + assert cpy[2] is box # complex leaf shallow + + orig[1][0] = 999 + assert cpy[1][0] == 2