Skip to content

Commit 4859a93

Browse files
committed
Simplify the code
1 parent 75c872e commit 4859a93

File tree

2 files changed

+7
-127
lines changed

2 files changed

+7
-127
lines changed

src/agents/util/_safe_copy.py

Lines changed: 7 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,104 +1,38 @@
11
from __future__ import annotations
22

3-
import copy
4-
import datetime as _dt
5-
from decimal import Decimal
6-
from fractions import Fraction
7-
from pathlib import PurePath
83
from typing import Any, TypeVar
9-
from uuid import UUID
104

115
T = TypeVar("T")
126

137

148
def safe_copy(obj: T) -> T:
159
"""
16-
Copy 'obj' without triggering deepcopy on complex/fragile objects.
17-
18-
Rules:
19-
- Primitive/simple atoms (ints, strs, datetimes, etc.): deepcopy (cheap and safe).
20-
- Built-in containers (dict, list, tuple, set, frozenset): recurse element-wise.
21-
- Everything else (framework objects, iterators, models, file handles, etc.):
22-
shallow copy if possible; otherwise return as-is.
23-
10+
Craete a copy of the given object -- it can be either str or list/set/tuple of objects.
2411
This avoids failures like:
2512
TypeError: cannot pickle '...ValidatorIterator' object
2613
because we never call deepcopy() on non-trivial objects.
2714
"""
28-
memo: dict[int, Any] = {}
29-
return _safe_copy_internal(obj, memo)
30-
31-
32-
_SIMPLE_ATOMS = (
33-
# basics
34-
type(None),
35-
bool,
36-
int,
37-
float,
38-
complex,
39-
str,
40-
bytes,
41-
# small buffers/scalars
42-
bytearray,
43-
memoryview,
44-
range,
45-
# "value" types
46-
Decimal,
47-
Fraction,
48-
UUID,
49-
PurePath,
50-
_dt.date,
51-
_dt.datetime,
52-
_dt.time,
53-
_dt.timedelta,
54-
)
55-
56-
57-
def _is_simple_atom(o: Any) -> bool:
58-
return isinstance(o, _SIMPLE_ATOMS)
59-
60-
61-
def _safe_copy_internal(obj: T, memo: dict[int, Any]) -> T:
62-
oid = id(obj)
63-
if oid in memo:
64-
return memo[oid] # type: ignore [no-any-return]
65-
66-
# 1) Simple "atoms": safe to deepcopy (cheap, predictable).
67-
if _is_simple_atom(obj):
68-
return copy.deepcopy(obj)
15+
return _safe_copy_internal(obj)
6916

70-
# 2) Containers: rebuild and recurse.
71-
if isinstance(obj, dict):
72-
new_dict: dict[Any, Any] = {}
73-
memo[oid] = new_dict
74-
for k, v in obj.items():
75-
# preserve key identity/value, only copy the value
76-
new_dict[k] = _safe_copy_internal(v, memo)
77-
return new_dict # type: ignore [return-value]
7817

18+
def _safe_copy_internal(obj: T) -> T:
7919
if isinstance(obj, list):
8020
new_list: list[Any] = []
81-
memo[oid] = new_list
82-
new_list.extend(_safe_copy_internal(x, memo) for x in obj)
21+
new_list.extend(_safe_copy_internal(x) for x in obj)
8322
return new_list # type: ignore [return-value]
8423

8524
if isinstance(obj, tuple):
86-
new_tuple = tuple(_safe_copy_internal(x, memo) for x in obj)
87-
memo[oid] = new_tuple
25+
new_tuple = tuple(_safe_copy_internal(x) for x in obj)
8826
return new_tuple # type: ignore [return-value]
8927

9028
if isinstance(obj, set):
9129
new_set: set[Any] = set()
92-
memo[oid] = new_set
9330
for x in obj:
94-
new_set.add(_safe_copy_internal(x, memo))
31+
new_set.add(_safe_copy_internal(x))
9532
return new_set # type: ignore [return-value]
9633

9734
if isinstance(obj, frozenset):
98-
new_fset = frozenset(_safe_copy_internal(x, memo) for x in obj)
99-
memo[oid] = new_fset
35+
new_fset = frozenset(_safe_copy_internal(x) for x in obj)
10036
return new_fset # type: ignore
10137

102-
# 3) Unknown/complex leaf: return as-is (identity preserved).
103-
memo[oid] = obj
10438
return obj

tests/utils/test_safe_copy.py

Lines changed: 0 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,6 @@ def __deepcopy__(self, memo):
3535
raise TypeError("no deepcopy")
3636

3737

38-
def test_primitives_are_copied_independently_for_mutable_bytes():
39-
orig = bytearray(b"abc")
40-
cpy = safe_copy(orig)
41-
assert bytes(cpy) == b"abc"
42-
orig[0] = ord("z")
43-
assert bytes(orig) == b"zbc"
44-
assert bytes(cpy) == b"abc" # unaffected
45-
46-
4738
@pytest.mark.parametrize(
4839
"value",
4940
[
@@ -69,37 +60,6 @@ def test_simple_atoms_roundtrip(value):
6960
assert cpy == value
7061

7162

72-
def test_deep_copy_for_nested_containers_of_primitives():
73-
orig = {"a": [1, 2, {"z": (3, 4)}]}
74-
cpy = safe_copy(orig)
75-
76-
# mutate original deeply
77-
orig["a"][2]["z"] = (99, 100) # type: ignore
78-
79-
assert cpy == {"a": [1, 2, {"z": (3, 4)}]} # unaffected
80-
81-
82-
def test_complex_leaf_is_only_shallow_copied():
83-
class Leaf:
84-
def __init__(self):
85-
self.val = 1
86-
87-
leaf = Leaf()
88-
obj = {"k": leaf, "arr": [1, 2, 3]}
89-
cpy = safe_copy(obj)
90-
91-
# container structure is new
92-
assert cpy is not obj
93-
assert cpy["arr"] is not obj["arr"]
94-
95-
# complex leaf is shallow: identity preserved
96-
assert cpy["k"] is leaf
97-
98-
# mutating the leaf reflects in the copied structure
99-
leaf.val = 42
100-
assert cpy["k"].val == 42 # type: ignore [attr-defined]
101-
102-
10363
def test_generator_is_preserved_and_not_consumed():
10464
gen = (i for i in range(3))
10565
data = {"g": gen}
@@ -145,20 +105,6 @@ class Marker:
145105
assert 99 not in s2
146106

147107

148-
def test_cycles_are_handled_without_recursion_error():
149-
# a -> (a,)
150-
a: list[Any] = []
151-
t = (a,)
152-
a.append(t)
153-
154-
c = safe_copy(a)
155-
# structure cloned:
156-
assert c is not a
157-
assert isinstance(c[0], tuple)
158-
# cycle preserved: the tuple's first element points back to the list
159-
assert c[0][0] is c
160-
161-
162108
def test_object_where_deepcopy_would_fail_is_handled_via_shallow_copy():
163109
b = BoomDeepcopy(7)
164110
c = safe_copy(b)

0 commit comments

Comments
 (0)