diff --git a/pydantic_ai_slim/pydantic_ai/profiles/__init__.py b/pydantic_ai_slim/pydantic_ai/profiles/__init__.py index 84a1c04012..8c0d773832 100644 --- a/pydantic_ai_slim/pydantic_ai/profiles/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/profiles/__init__.py @@ -65,6 +65,9 @@ class ModelProfile: This is currently only used by `OpenAIChatModel`, `HuggingFaceModel`, and `GroqModel`. """ + supports_tool_search: bool = False + """Whether the model has native support for tool search and defer loading tools.""" + @classmethod def from_profile(cls, profile: ModelProfile | None) -> Self: """Build a ModelProfile subclass instance from a ModelProfile instance.""" diff --git a/pydantic_ai_slim/pydantic_ai/profiles/anthropic.py b/pydantic_ai_slim/pydantic_ai/profiles/anthropic.py index 6a59ab2dec..bc76b4d5a9 100644 --- a/pydantic_ai_slim/pydantic_ai/profiles/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/profiles/anthropic.py @@ -23,6 +23,7 @@ def anthropic_model_profile(model_name: str) -> ModelProfile | None: thinking_tags=('', ''), supports_json_schema_output=supports_json_schema_output, json_schema_transformer=AnthropicJsonSchemaTransformer, + supports_tool_search=True, ) diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/searchable.py b/pydantic_ai_slim/pydantic_ai/toolsets/searchable.py new file mode 100644 index 0000000000..7702104688 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/searchable.py @@ -0,0 +1,128 @@ +import logging +import re +from collections.abc import Callable +from dataclasses import dataclass, field, replace +from typing import Any, TypedDict + +from pydantic import TypeAdapter +from typing_extensions import Self + +from .._run_context import AgentDepsT, RunContext +from ..tools import ToolDefinition +from .abstract import AbstractToolset, SchemaValidatorProt, ToolsetTool + +_SEARCH_TOOL_NAME = 'load_tools' + + +class _SearchToolArgs(TypedDict): + regex: str + + +def _search_tool_def() -> ToolDefinition: + return ToolDefinition( + name=_SEARCH_TOOL_NAME, + description="""Search and load additional tools to make them available to the agent. + +DO call this to find and load more tools needed for a task. +NEVER ask the user if you should try loading tools, just try. +""", + parameters_json_schema={ + 'type': 'object', + 'properties': { + 'regex': { + 'type': 'string', + 'description': 'Regex pattern to search for relevant tools', + } + }, + 'required': ['regex'], + }, + ) + + +def _search_tool_validator() -> SchemaValidatorProt: + return TypeAdapter(_SearchToolArgs).validator + + +@dataclass +class _SearchTool(ToolsetTool[AgentDepsT]): + """A tool that searches for more relevant tools from a SearchableToolSet.""" + + tool_def: ToolDefinition = field(default_factory=_search_tool_def) + args_validator: SchemaValidatorProt = field(default_factory=_search_tool_validator) + + +@dataclass +class SearchableToolset(AbstractToolset[AgentDepsT]): + """A toolset that implements tool search and deferred tool loading.""" + + toolset: AbstractToolset[AgentDepsT] + _active_tool_names: set[str] = field(default_factory=set) + + @property + def id(self) -> str | None: + return None # pragma: no cover + + @property + def label(self) -> str: + return f'{self.__class__.__name__}({self.toolset.label})' # pragma: no cover + + async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]: + logging.debug("SearchableToolset.get_tools") + all_tools: dict[str, ToolsetTool[AgentDepsT]] = {} + all_tools[_SEARCH_TOOL_NAME] = _SearchTool( + toolset=self, + max_retries=1, + ) + + toolset_tools = await self.toolset.get_tools(ctx) + for tool_name, tool in toolset_tools.items(): + # TODO proper error handling + assert tool_name != _SEARCH_TOOL_NAME + + if tool_name in self._active_tool_names: + all_tools[tool_name] = tool + + logging.debug(f"SearchableToolset.get_tools ==> {[t for t in all_tools]}") + return all_tools + + async def call_tool( + self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT] + ) -> Any: + if isinstance(tool, _SearchTool): + adapter = TypeAdapter(_SearchToolArgs) + typed_args = adapter.validate_python(tool_args) + result = await self.call_search_tool(typed_args, ctx) + logging.debug(f"SearchableToolset.call_tool({name}, {tool_args}) ==> {result}") + return result + else: + result = await self.toolset.call_tool(name, tool_args, ctx, tool) + logging.debug(f"SearchableToolset.call_tool({name}, {tool_args}) ==> {result}") + return result + + async def call_search_tool(self, args: _SearchToolArgs, ctx: RunContext[AgentDepsT]) -> list[str]: + """Searches for tools matching the query, activates them and returns their names.""" + toolset_tools = await self.toolset.get_tools(ctx) + matching_tool_names: list[str] = [] + + for tool_name, tool in toolset_tools.items(): + rx = re.compile(args['regex']) + if rx.search(tool.tool_def.name) or rx.search(tool.tool_def.description): + matching_tool_names.append(tool.tool_def.name) + + self._active_tool_names.update(matching_tool_names) + return matching_tool_names + + def apply(self, visitor: Callable[[AbstractToolset[AgentDepsT]], None]) -> None: + self.toolset.apply(visitor) + + def visit_and_replace( + self, visitor: Callable[[AbstractToolset[AgentDepsT]], AbstractToolset[AgentDepsT]] + ) -> AbstractToolset[AgentDepsT]: + return replace(self, toolset=self.toolset.visit_and_replace(visitor)) + + async def __aenter__(self) -> Self: + await self.toolset.__aenter__() + return self + + async def __aexit__(self, *args: Any) -> bool | None: + return await self.toolset.__aexit__(*args) diff --git a/test_searchable_example.py b/test_searchable_example.py new file mode 100644 index 0000000000..2f9a0fe96c --- /dev/null +++ b/test_searchable_example.py @@ -0,0 +1,136 @@ +"""Minimal example to test SearchableToolset functionality. + +Run with: uv run python test_searchable_example.py +Make sure you have ANTHROPIC_API_KEY set in your environment. +""" + +import asyncio +import logging +import sys + +# Configure logging to print to stdout +logging.basicConfig( + level=logging.DEBUG, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + stream=sys.stdout +) + +# Silence noisy loggers +logging.getLogger('asyncio').setLevel(logging.WARNING) +logging.getLogger('httpx').setLevel(logging.WARNING) +logging.getLogger('httpcore.connection').setLevel(logging.WARNING) +logging.getLogger('httpcore.http11').setLevel(logging.WARNING) +logging.getLogger('anthropic._base_client').setLevel(logging.WARNING) + +from pydantic_ai import Agent +from pydantic_ai.toolsets import FunctionToolset +from pydantic_ai.toolsets.searchable import SearchableToolset + + +# Create a toolset with various tools +toolset = FunctionToolset() + + +@toolset.tool +def get_weather(city: str) -> str: + """Get the current weather for a given city. + + Args: + city: The name of the city to get weather for. + """ + return f"The weather in {city} is sunny and 72°F" + + +@toolset.tool +def calculate_sum(a: float, b: float) -> float: + """Add two numbers together. + + Args: + a: The first number. + b: The second number. + """ + return a + b + + +@toolset.tool +def calculate_product(a: float, b: float) -> float: + """Multiply two numbers together. + + Args: + a: The first number. + b: The second number. + """ + return a * b + + +@toolset.tool +def fetch_user_data(user_id: int) -> dict: + """Fetch user data from the database. + + Args: + user_id: The ID of the user to fetch. + """ + return {"id": user_id, "name": "John Doe", "email": "john@example.com"} + + +@toolset.tool +def send_email(recipient: str, subject: str, body: str) -> str: + """Send an email to a recipient. + + Args: + recipient: The email address of the recipient. + subject: The subject line of the email. + body: The body content of the email. + """ + return f"Email sent to {recipient} with subject '{subject}'" + + +@toolset.tool +def list_database_tables() -> list[str]: + """List all tables in the database.""" + return ["users", "orders", "products", "reviews"] + + +# Wrap the toolset with SearchableToolset +searchable_toolset = SearchableToolset(toolset=toolset) + +# Create an agent with the searchable toolset +agent = Agent( + 'anthropic:claude-sonnet-4-5', + toolsets=[searchable_toolset], + system_prompt=( + "You are a helpful assistant." + ), +) + + +async def main(): + print("=" * 60) + print("Testing SearchableToolset") + print("=" * 60) + print() + + # Test 1: Ask something that requires searching for calculation tools + print("Test 1: Calculation task") + print("-" * 60) + result = await agent.run("What is 123 multiplied by 456?") + print(f"Result: {result.output}") + print() + + # Test 2: Ask something that requires searching for database tools + print("\nTest 2: Database task") + print("-" * 60) + result = await agent.run("Can you list the database tables and then fetch user 42?") + print(f"Result: {result.output}") + print() + + # Test 3: Ask something that requires weather tool + print("\nTest 3: Weather task") + print("-" * 60) + result = await agent.run("What's the weather like in San Francisco?") + print(f"Result: {result.output}") + print() + + +if __name__ == "__main__": + asyncio.run(main())