Skip to content

Commit efa1e26

Browse files
authored
Sanitize auto-generated output tool name to support generic types (#2979)
1 parent 1d3bb01 commit efa1e26

File tree

2 files changed

+39
-2
lines changed

2 files changed

+39
-2
lines changed

pydantic_ai_slim/pydantic_ai/_output.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import inspect
44
import json
5+
import re
56
from abc import ABC, abstractmethod
67
from collections.abc import Awaitable, Callable, Sequence
78
from dataclasses import dataclass, field
@@ -70,6 +71,7 @@
7071

7172
DEFAULT_OUTPUT_TOOL_NAME = 'final_result'
7273
DEFAULT_OUTPUT_TOOL_DESCRIPTION = 'The final response which ends this conversation'
74+
OUTPUT_TOOL_NAME_SANITIZER = re.compile(r'[^a-zA-Z0-9-_]')
7375

7476

7577
async def execute_traced_output_function(
@@ -997,7 +999,9 @@ def build(
997999
if name is None:
9981000
name = default_name
9991001
if multiple:
1000-
name += f'_{object_def.name}'
1002+
# strip unsupported characters like "[" and "]" from generic class names
1003+
safe_name = OUTPUT_TOOL_NAME_SANITIZER.sub('', object_def.name or '')
1004+
name += f'_{safe_name}'
10011005

10021006
i = 1
10031007
original_name = name

tests/test_agent.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from collections.abc import AsyncIterable, Callable
77
from dataclasses import dataclass, replace
88
from datetime import timezone
9-
from typing import Any, Literal, Union
9+
from typing import Any, Generic, Literal, TypeVar, Union
1010

1111
import httpx
1212
import pytest
@@ -96,6 +96,21 @@ class Person(BaseModel):
9696
name: str
9797

9898

99+
# Generic classes for testing tool name sanitization with generic types
100+
T = TypeVar('T')
101+
102+
103+
class ResultGeneric(BaseModel, Generic[T]):
104+
"""A generic result class."""
105+
106+
value: T
107+
success: bool
108+
109+
110+
class StringData(BaseModel):
111+
text: str
112+
113+
99114
def test_result_list_of_models_with_stringified_response():
100115
def return_list(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
101116
assert info.output_tools is not None
@@ -635,6 +650,24 @@ def validate_output(ctx: RunContext[None], o: Any) -> Any:
635650
assert got_tool_call_name == snapshot('final_result_Bar')
636651

637652

653+
def test_output_type_generic_class_name_sanitization():
654+
"""Test that generic class names with brackets are properly sanitized."""
655+
# This will have a name like "ResultGeneric[StringData]" which needs sanitization
656+
output_type = [ResultGeneric[StringData], ResultGeneric[int]]
657+
658+
m = TestModel()
659+
agent = Agent(m, output_type=output_type)
660+
agent.run_sync('Hello')
661+
662+
# The sanitizer should remove brackets from the generic type name
663+
assert m.last_model_request_parameters is not None
664+
assert m.last_model_request_parameters.output_tools is not None
665+
assert len(m.last_model_request_parameters.output_tools) == 2
666+
667+
tool_names = [tool.name for tool in m.last_model_request_parameters.output_tools]
668+
assert tool_names == snapshot(['final_result_ResultGenericStringData', 'final_result_ResultGenericint'])
669+
670+
638671
def test_output_type_with_two_descriptions():
639672
class MyOutput(BaseModel):
640673
"""Description from docstring"""

0 commit comments

Comments
 (0)