|
6 | 6 | from collections.abc import AsyncIterable, Callable |
7 | 7 | from dataclasses import dataclass, replace |
8 | 8 | from datetime import timezone |
9 | | -from typing import Any, Literal, Union |
| 9 | +from typing import Any, Generic, Literal, TypeVar, Union |
10 | 10 |
|
11 | 11 | import httpx |
12 | 12 | import pytest |
@@ -96,6 +96,21 @@ class Person(BaseModel): |
96 | 96 | name: str |
97 | 97 |
|
98 | 98 |
|
| 99 | +# Generic classes for testing tool name sanitization with generic types |
| 100 | +T = TypeVar('T') |
| 101 | + |
| 102 | + |
| 103 | +class ResultGeneric(BaseModel, Generic[T]): |
| 104 | + """A generic result class.""" |
| 105 | + |
| 106 | + value: T |
| 107 | + success: bool |
| 108 | + |
| 109 | + |
| 110 | +class StringData(BaseModel): |
| 111 | + text: str |
| 112 | + |
| 113 | + |
99 | 114 | def test_result_list_of_models_with_stringified_response(): |
100 | 115 | def return_list(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: |
101 | 116 | assert info.output_tools is not None |
@@ -635,6 +650,24 @@ def validate_output(ctx: RunContext[None], o: Any) -> Any: |
635 | 650 | assert got_tool_call_name == snapshot('final_result_Bar') |
636 | 651 |
|
637 | 652 |
|
| 653 | +def test_output_type_generic_class_name_sanitization(): |
| 654 | + """Test that generic class names with brackets are properly sanitized.""" |
| 655 | + # This will have a name like "ResultGeneric[StringData]" which needs sanitization |
| 656 | + output_type = [ResultGeneric[StringData], ResultGeneric[int]] |
| 657 | + |
| 658 | + m = TestModel() |
| 659 | + agent = Agent(m, output_type=output_type) |
| 660 | + agent.run_sync('Hello') |
| 661 | + |
| 662 | + # The sanitizer should remove brackets from the generic type name |
| 663 | + assert m.last_model_request_parameters is not None |
| 664 | + assert m.last_model_request_parameters.output_tools is not None |
| 665 | + assert len(m.last_model_request_parameters.output_tools) == 2 |
| 666 | + |
| 667 | + tool_names = [tool.name for tool in m.last_model_request_parameters.output_tools] |
| 668 | + assert tool_names == snapshot(['final_result_ResultGenericStringData', 'final_result_ResultGenericint']) |
| 669 | + |
| 670 | + |
638 | 671 | def test_output_type_with_two_descriptions(): |
639 | 672 | class MyOutput(BaseModel): |
640 | 673 | """Description from docstring""" |
|
0 commit comments