-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Toward SearchableToolSet and cross-model ToolSearch #3680
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
df4dea0
980187b
35a65e9
364a58e
8ffdf17
0f754c2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,128 @@ | ||
| import logging | ||
| import re | ||
| from collections.abc import Callable | ||
| from dataclasses import dataclass, field, replace | ||
| from typing import Any, TypedDict | ||
|
|
||
| from pydantic import TypeAdapter | ||
| from typing_extensions import Self | ||
|
|
||
| from .._run_context import AgentDepsT, RunContext | ||
| from ..tools import ToolDefinition | ||
| from .abstract import AbstractToolset, SchemaValidatorProt, ToolsetTool | ||
|
|
||
| _SEARCH_TOOL_NAME = 'load_tools' | ||
|
|
||
|
|
||
| class _SearchToolArgs(TypedDict): | ||
| regex: str | ||
|
|
||
|
|
||
def _search_tool_def() -> ToolDefinition:
    """Build the definition of the built-in tool that searches for and loads more tools."""
    # NOTE(review): two reviewer comments on this function ("Check out ..." on the
    # signature and "I like ..." on the `regex` property) were truncated in the source.
    regex_property = {
        'type': 'string',
        'description': 'Regex pattern to search for relevant tools',
    }
    schema = {
        'type': 'object',
        'properties': {'regex': regex_property},
        'required': ['regex'],
    }
    # NOTE(review): reviewer feedback -- this wording was needed to pass the tests,
    # even for Sonnet 4.5, but tokens are expensive so it is worth another iteration.
    return ToolDefinition(
        name=_SEARCH_TOOL_NAME,
        description="""Search and load additional tools to make them available to the agent.

DO call this to find and load more tools needed for a task.
NEVER ask the user if you should try loading tools, just try.
""",
        parameters_json_schema=schema,
    )
|
|
||
|
|
||
def _search_tool_validator() -> SchemaValidatorProt:
    """Return the validator for the search tool's arguments (`_SearchToolArgs`)."""
    adapter = TypeAdapter(_SearchToolArgs)
    return adapter.validator
|
|
||
|
|
||
@dataclass
class _SearchTool(ToolsetTool[AgentDepsT]):
    """The tool a SearchableToolSet exposes for searching for more relevant tools."""

    # Both fields default to the shared search-tool definition and validator,
    # so callers do not have to supply them when building the tool.
    tool_def: ToolDefinition = field(default_factory=_search_tool_def)
    args_validator: SchemaValidatorProt = field(default_factory=_search_tool_validator)
|
|
||
|
|
||
@dataclass
class SearchableToolset(AbstractToolset[AgentDepsT]):
    """A toolset that implements tool search and deferred tool loading.

    Only the search tool (`load_tools`) is always returned by `get_tools`. A tool
    of the wrapped toolset is returned as well once a search regex has matched
    its name or description.
    """

    # The wrapped toolset whose tools are hidden until a search activates them.
    # NOTE(review): a reviewer comment on this field ("Have a look at ...") was
    # truncated in the source.
    toolset: AbstractToolset[AgentDepsT]

    # Names of wrapped tools that earlier searches have activated.
    # NOTE(review): reviewer feedback -- because this is instance state, the
    # toolset can't be reused across multiple agent runs, even though the same
    # instance is registered to an agent just once.
    _active_tool_names: set[str] = field(default_factory=set)

    @property
    def id(self) -> str | None:
        """This toolset has no ID of its own."""
        return None  # pragma: no cover

    @property
    def label(self) -> str:
        """Class name followed by the wrapped toolset's label in parentheses."""
        return f'{self.__class__.__name__}({self.toolset.label})'  # pragma: no cover

    async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
        """Return the search tool plus every wrapped tool activated so far.

        Raises:
            ValueError: If the wrapped toolset has a tool named like the search tool.
        """
        logging.debug('SearchableToolset.get_tools')
        all_tools: dict[str, ToolsetTool[AgentDepsT]] = {
            # NOTE(review): reviewer feedback -- max_retries may need to be higher, to
            # give the model a few chances to fix an invalid regex.
            _SEARCH_TOOL_NAME: _SearchTool(toolset=self, max_retries=1),
        }

        toolset_tools = await self.toolset.get_tools(ctx)
        for tool_name, tool in toolset_tools.items():
            # An explicit raise instead of `assert`, which is stripped under `python -O`.
            if tool_name == _SEARCH_TOOL_NAME:
                raise ValueError(
                    f'Tool name {_SEARCH_TOOL_NAME!r} is reserved for the search tool of SearchableToolset'
                )
            if tool_name in self._active_tool_names:
                all_tools[tool_name] = tool

        logging.debug('SearchableToolset.get_tools ==> %s', list(all_tools))
        return all_tools

    async def call_tool(
        self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT]
    ) -> Any:
        """Run the search tool itself, or forward any other call to the wrapped toolset."""
        if isinstance(tool, _SearchTool):
            # NOTE(review): a reviewer says arguments will/should already have been
            # validated by this point (comment truncated in the source); validating
            # again is kept so the arguments are still checked here.
            typed_args = TypeAdapter(_SearchToolArgs).validate_python(tool_args)
            result = await self.call_search_tool(typed_args, ctx)
        else:
            result = await self.toolset.call_tool(name, tool_args, ctx, tool)
        logging.debug('SearchableToolset.call_tool(%s, %s) ==> %s', name, tool_args, result)
        return result

    async def call_search_tool(self, args: _SearchToolArgs, ctx: RunContext[AgentDepsT]) -> list[str]:
        """Searches for tools matching the query, activates them and returns their names.

        A tool matches when the regex is found in its name or its description.

        Raises:
            ModelRetry: If `args['regex']` is not a valid regular expression, so the
                model can submit a corrected pattern.
        """
        toolset_tools = await self.toolset.get_tools(ctx)

        # Compile once, outside the loop; the pattern is the same for every tool.
        try:
            rx = re.compile(args['regex'])
        except re.error as exc:
            # Imported locally because the module does not import it at the top.
            from ..exceptions import ModelRetry

            raise ModelRetry(f'Invalid regex {args["regex"]!r}: {exc}') from exc

        matching_tool_names = [
            tool.tool_def.name
            for tool in toolset_tools.values()
            # A missing description is treated as empty; `rx.search(None)` raises TypeError.
            if rx.search(tool.tool_def.name) or rx.search(tool.tool_def.description or '')
        ]

        self._active_tool_names.update(matching_tool_names)
        return matching_tool_names

    def apply(self, visitor: Callable[[AbstractToolset[AgentDepsT]], None]) -> None:
        """Pass the visitor on to the wrapped toolset."""
        self.toolset.apply(visitor)

    def visit_and_replace(
        self, visitor: Callable[[AbstractToolset[AgentDepsT]], AbstractToolset[AgentDepsT]]
    ) -> AbstractToolset[AgentDepsT]:
        """Return a copy of this toolset whose wrapped toolset is the result of its own `visit_and_replace`."""
        return replace(self, toolset=self.toolset.visit_and_replace(visitor))

    async def __aenter__(self) -> Self:
        """Enter the wrapped toolset's async context and return this toolset."""
        await self.toolset.__aenter__()
        return self

    async def __aexit__(self, *args: Any) -> bool | None:
        """Exit the wrapped toolset's async context and return its result."""
        return await self.toolset.__aexit__(*args)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,136 @@ | ||
| """Minimal example to test SearchableToolset functionality. | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It looks like proper tests need to go into:
I just wanted to get something quick to iterate with an actual LLM. This ended up working on Claude but took a few iterations on the prompt. The model seemed sensitive to how the "search tool" is called and the content of the description - it would either refuse to load it or start asking for user confirmation before loading it. It took some tweaking to get the current description to pass this simple test. |
||
|
|
||
| Run with: uv run python test_searchable_example.py | ||
| Make sure you have ANTHROPIC_API_KEY set in your environment. | ||
| """ | ||
|
|
||
| import asyncio | ||
| import logging | ||
| import sys | ||
|
|
||
# Send every log record (DEBUG and above) to stdout
logging.basicConfig(
    stream=sys.stdout,
    level=logging.DEBUG,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)

# Raise the threshold for the chatty loggers so the output stays readable
for _noisy_logger in (
    'asyncio',
    'httpx',
    'httpcore.connection',
    'httpcore.http11',
    'anthropic._base_client',
):
    logging.getLogger(_noisy_logger).setLevel(logging.WARNING)
|
|
||
| from pydantic_ai import Agent | ||
| from pydantic_ai.toolsets import FunctionToolset | ||
| from pydantic_ai.toolsets.searchable import SearchableToolset | ||
|
|
||
|
|
||
| # Create a toolset with various tools | ||
| toolset = FunctionToolset() | ||
|
|
||
|
|
||
@toolset.tool
def get_weather(city: str) -> str:
    """Get the current weather for a given city.

    Args:
        city: The name of the city to get weather for.
    """
    # Canned answer: the same text for every city, only the name is filled in.
    report = f"The weather in {city} is sunny and 72°F"
    return report
|
|
||
|
|
||
@toolset.tool
def calculate_sum(a: float, b: float) -> float:
    """Add two numbers together.

    Args:
        a: The first number.
        b: The second number.
    """
    total = a + b
    return total
|
|
||
|
|
||
@toolset.tool
def calculate_product(a: float, b: float) -> float:
    """Multiply two numbers together.

    Args:
        a: The first number.
        b: The second number.
    """
    product = a * b
    return product
|
|
||
|
|
||
@toolset.tool
def fetch_user_data(user_id: int) -> dict:
    """Fetch user data from the database.

    Args:
        user_id: The ID of the user to fetch.
    """
    # Hard-coded record: only the id reflects the request.
    record = dict(id=user_id, name="John Doe", email="john@example.com")
    return record
|
|
||
|
|
||
@toolset.tool
def send_email(recipient: str, subject: str, body: str) -> str:
    """Send an email to a recipient.

    Args:
        recipient: The email address of the recipient.
        subject: The subject line of the email.
        body: The body content of the email.
    """
    # Nothing is sent; `body` is accepted but unused, and a confirmation string is returned.
    confirmation = f"Email sent to {recipient} with subject '{subject}'"
    return confirmation
|
|
||
|
|
||
@toolset.tool
def list_database_tables() -> list[str]:
    """List all tables in the database."""
    # Fixed table list used by the demo.
    tables = ["users", "orders", "products", "reviews"]
    return tables
|
|
||
|
|
||
# Put the demo tools behind the search wrapper
searchable_toolset = SearchableToolset(toolset=toolset)

# The agent is given only the searchable toolset
agent = Agent(
    'anthropic:claude-sonnet-4-5',
    system_prompt="You are a helpful assistant.",
    toolsets=[searchable_toolset],
)
|
|
||
|
|
||
async def main():
    """Run three demo prompts through the agent and print each result."""
    banner = "=" * 60
    print(banner)
    print("Testing SearchableToolset")
    print(banner)
    print()

    # (heading, prompt) pairs. Each prompt needs a different group of tools:
    # calculation, database, weather. Headings after the first start with a
    # newline, which reproduces the original spacing between sections.
    scenarios = [
        ("Test 1: Calculation task", "What is 123 multiplied by 456?"),
        ("\nTest 2: Database task", "Can you list the database tables and then fetch user 42?"),
        ("\nTest 3: Weather task", "What's the weather like in San Francisco?"),
    ]

    for heading, prompt in scenarios:
        print(heading)
        print("-" * 60)
        result = await agent.run(prompt)
        print(f"Result: {result.output}")
        print()
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| asyncio.run(main()) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Another curious bit is that when the tool was called "more_tools", I hit a crash:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Interesting, that suggests that the model was not calling it correctly (wrong args possibly). I suggest adding https://ai.pydantic.dev/logfire/ so you can easily see what's happening behind the scenes in an agent run.