Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pydantic_ai_slim/pydantic_ai/profiles/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ class ModelProfile:
This is currently only used by `OpenAIChatModel`, `HuggingFaceModel`, and `GroqModel`.
"""

supports_tool_search: bool = False
"""Whether the model has native support for tool search and defer loading tools."""

@classmethod
def from_profile(cls, profile: ModelProfile | None) -> Self:
"""Build a ModelProfile subclass instance from a ModelProfile instance."""
Expand Down
1 change: 1 addition & 0 deletions pydantic_ai_slim/pydantic_ai/profiles/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def anthropic_model_profile(model_name: str) -> ModelProfile | None:
thinking_tags=('<thinking>', '</thinking>'),
supports_json_schema_output=supports_json_schema_output,
json_schema_transformer=AnthropicJsonSchemaTransformer,
supports_tool_search=True,
)


Expand Down
128 changes: 128 additions & 0 deletions pydantic_ai_slim/pydantic_ai/toolsets/searchable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import logging
import re
from collections.abc import Callable
from dataclasses import dataclass, field, replace
from typing import Any, TypedDict, cast

from pydantic import TypeAdapter
from typing_extensions import Self

from .._run_context import AgentDepsT, RunContext
from ..tools import ToolDefinition
from .abstract import AbstractToolset, SchemaValidatorProt, ToolsetTool

_SEARCH_TOOL_NAME = 'load_tools'
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another curious bit is that when tool was called "more_tools", I hit a crash:



Traceback (most recent call last):
  File "/Users/anton/code/pydantic-ai/test_searchable_example.py", line 136, in <module>
    asyncio.run(main())
  File "/Users/anton/.local/share/uv/python/cpython-3.12.11-macos-aarch64-none/lib/python3.12/asyncio/runners.py", line 195, in run
    return runner.run(main)
           ^^^^^^^^^^^^^^^^
  File "/Users/anton/.local/share/uv/python/cpython-3.12.11-macos-aarch64-none/lib/python3.12/asyncio/runners.py", line 118, in run
    return self._loop.run_until_complete(task)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/anton/.local/share/uv/python/cpython-3.12.11-macos-aarch64-none/lib/python3.12/asyncio/base_events.py", line 691, in run_until_complete
    return future.result()
           ^^^^^^^^^^^^^^^
  File "/Users/anton/code/pydantic-ai/test_searchable_example.py", line 123, in main
    result = await agent.run("Can you list the database tables and then fetch user 42?")
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/anton/code/pydantic-ai/pydantic_ai_slim/pydantic_ai/agent/abstract.py", line 226, in run
    async with self.iter(
               ^^^^^^^^^^
  File "/Users/anton/.local/share/uv/python/cpython-3.12.11-macos-aarch64-none/lib/python3.12/contextlib.py", line 231, in __aexit__
    await self.gen.athrow(value)
  File "/Users/anton/code/pydantic-ai/pydantic_ai_slim/pydantic_ai/agent/__init__.py", line 658, in iter
    async with graph.iter(
               ^^^^^^^^^^^
  File "/Users/anton/.local/share/uv/python/cpython-3.12.11-macos-aarch64-none/lib/python3.12/contextlib.py", line 231, in __aexit__
    await self.gen.athrow(value)
  File "/Users/anton/code/pydantic-ai/pydantic_graph/pydantic_graph/beta/graph.py", line 270, in iter
    async with GraphRun[StateT, DepsT, OutputT](
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/anton/code/pydantic-ai/pydantic_graph/pydantic_graph/beta/graph.py", line 423, in __aexit__
    await self._async_exit_stack.__aexit__(exc_type, exc_val, exc_tb)
  File "/Users/anton/.local/share/uv/python/cpython-3.12.11-macos-aarch64-none/lib/python3.12/contextlib.py", line 754, in __aexit__
    raise exc_details[1]
  File "/Users/anton/.local/share/uv/python/cpython-3.12.11-macos-aarch64-none/lib/python3.12/contextlib.py", line 735, in __aexit__
    cb_suppress = cb(*exc_details)
                  ^^^^^^^^^^^^^^^^
  File "/Users/anton/.local/share/uv/python/cpython-3.12.11-macos-aarch64-none/lib/python3.12/contextlib.py", line 158, in __exit__
    self.gen.throw(value)
  File "/Users/anton/code/pydantic-ai/pydantic_graph/pydantic_graph/beta/graph.py", line 978, in _unwrap_exception_groups
    raise exception
  File "/Users/anton/code/pydantic-ai/pydantic_graph/pydantic_graph/beta/graph.py", line 750, in _run_tracked_task
    result = await self._run_task(t_)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/anton/code/pydantic-ai/pydantic_graph/pydantic_graph/beta/graph.py", line 779, in _run_task
    output = await node.call(step_context)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/anton/code/pydantic-ai/pydantic_graph/pydantic_graph/beta/step.py", line 253, in _call_node
    return await node.run(GraphRunContext(state=ctx.state, deps=ctx.deps))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/anton/code/pydantic-ai/pydantic_ai_slim/pydantic_ai/_agent_graph.py", line 576, in run
    async with self.stream(ctx):
               ^^^^^^^^^^^^^^^^
  File "/Users/anton/.local/share/uv/python/cpython-3.12.11-macos-aarch64-none/lib/python3.12/contextlib.py", line 217, in __aexit__
    await anext(self.gen)
  File "/Users/anton/code/pydantic-ai/pydantic_ai_slim/pydantic_ai/_agent_graph.py", line 590, in stream
    async for _event in stream:
  File "/Users/anton/code/pydantic-ai/pydantic_ai_slim/pydantic_ai/_agent_graph.py", line 716, in _run_stream
    async for event in self._events_iterator:
  File "/Users/anton/code/pydantic-ai/pydantic_ai_slim/pydantic_ai/_agent_graph.py", line 677, in _run_stream
    async for event in self._handle_tool_calls(ctx, tool_calls):
  File "/Users/anton/code/pydantic-ai/pydantic_ai_slim/pydantic_ai/_agent_graph.py", line 732, in _handle_tool_calls
    async for event in process_tool_calls(
  File "/Users/anton/code/pydantic-ai/pydantic_ai_slim/pydantic_ai/_agent_graph.py", line 925, in process_tool_calls
    ctx.state.increment_retries(ctx.deps.max_result_retries, model_settings=ctx.deps.model_settings)
  File "/Users/anton/code/pydantic-ai/pydantic_ai_slim/pydantic_ai/_agent_graph.py", line 127, in increment_retries
    raise exceptions.UnexpectedModelBehavior(message)
pydantic_ai.exceptions.UnexpectedModelBehavior: Exceeded maximum retries (1) for output validation

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting, that suggests that the model was not calling it correctly (wrong args possibly). I suggest adding https://ai.pydantic.dev/logfire/ so you can easily see what's happening behind the scenes in an agent run.



class _SearchToolArgs(TypedDict):
    """Validated arguments for the built-in tool-search tool."""

    regex: str  # regex pattern matched against tool names and descriptions


def _search_tool_def() -> ToolDefinition:
    """Build the `ToolDefinition` for the built-in tool-search tool.

    The tool takes a single required ``regex`` argument that the model uses to
    search the wrapped toolset's tool names and descriptions.
    """
    # NOTE(review): the imperative phrasing below was found necessary in
    # practice — with softer wording, models would ask the user for
    # confirmation before loading tools instead of just loading them.
    # TODO(review): iterate on shortening this; tokens are expensive.
    return ToolDefinition(
        name=_SEARCH_TOOL_NAME,
        description="""Search and load additional tools to make them available to the agent.

DO call this to find and load more tools needed for a task.
NEVER ask the user if you should try loading tools, just try.
""",
        parameters_json_schema={
            'type': 'object',
            'properties': {
                'regex': {
                    'type': 'string',
                    'description': 'Regex pattern to search for relevant tools',
                }
            },
            'required': ['regex'],
        },
    )


def _search_tool_validator() -> SchemaValidatorProt:
    """Return a pydantic validator for the search tool's arguments."""
    adapter = TypeAdapter(_SearchToolArgs)
    return adapter.validator


@dataclass
class _SearchTool(ToolsetTool[AgentDepsT]):
    """A tool that searches for more relevant tools from a SearchableToolSet."""

    # Factories (not shared instances) so each tool gets its own definition/validator.
    tool_def: ToolDefinition = field(default_factory=_search_tool_def)
    args_validator: SchemaValidatorProt = field(default_factory=_search_tool_validator)


@dataclass
class SearchableToolset(AbstractToolset[AgentDepsT]):
    """A toolset that implements tool search and deferred tool loading.

    Wraps another toolset and initially exposes only a single search tool
    (``load_tools``). When the model calls it with a regex, matching tools
    from the wrapped toolset are activated and exposed on subsequent
    ``get_tools`` calls.

    NOTE(review): `_active_tool_names` is per-instance mutable state, so a
    single instance cannot safely be reused across concurrent agent runs —
    see how `DynamicToolset` scopes similar state to a run. Also consider
    building on `WrapperToolset`, which already forwards the wrapped
    toolset and its async context methods.
    """

    toolset: AbstractToolset[AgentDepsT]  # the wrapped toolset being searched
    _active_tool_names: set[str] = field(default_factory=set)  # tool names activated by searches so far

    @property
    def id(self) -> str | None:
        return None  # pragma: no cover

    @property
    def label(self) -> str:
        return f'{self.__class__.__name__}({self.toolset.label})'  # pragma: no cover

    async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
        """Return the search tool plus any tools activated by earlier searches."""
        logging.debug('SearchableToolset.get_tools')
        all_tools: dict[str, ToolsetTool[AgentDepsT]] = {
            _SEARCH_TOOL_NAME: _SearchTool(
                toolset=self,
                # TODO(review): consider allowing a few retries so the model
                # can fix an invalid regex.
                max_retries=1,
            )
        }

        toolset_tools = await self.toolset.get_tools(ctx)
        for tool_name, tool in toolset_tools.items():
            # The wrapped toolset must not shadow the reserved search tool name.
            # (Was an `assert`, which is stripped under `python -O`.)
            if tool_name == _SEARCH_TOOL_NAME:
                raise ValueError(f'Wrapped toolset defines a tool named {_SEARCH_TOOL_NAME!r}, which is reserved.')

            if tool_name in self._active_tool_names:
                all_tools[tool_name] = tool

        # Lazy %-formatting so the message is only built when DEBUG is enabled.
        logging.debug('SearchableToolset.get_tools ==> %s', list(all_tools))
        return all_tools

    async def call_tool(
        self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT]
    ) -> Any:
        """Dispatch the search tool, or delegate to the wrapped toolset."""
        if isinstance(tool, _SearchTool):
            # Arguments were already validated against `args_validator` by the
            # ToolManager before reaching this point, so a cast is sufficient.
            result = await self.call_search_tool(cast(_SearchToolArgs, tool_args), ctx)
        else:
            result = await self.toolset.call_tool(name, tool_args, ctx, tool)
        logging.debug('SearchableToolset.call_tool(%s, %s) ==> %s', name, tool_args, result)
        return result

    async def call_search_tool(self, args: _SearchToolArgs, ctx: RunContext[AgentDepsT]) -> list[str]:
        """Search for tools matching the regex, activate them, and return their names."""
        # Compiled once, outside the loop. TODO(review): surface `re.error`
        # as a `ModelRetry` so the model gets a chance to correct its regex.
        rx = re.compile(args['regex'])

        toolset_tools = await self.toolset.get_tools(ctx)
        matching_tool_names = [
            tool.tool_def.name
            for tool in toolset_tools.values()
            # NOTE(review): guards against an empty/missing description — confirm
            # whether `ToolDefinition.description` can be None.
            if rx.search(tool.tool_def.name) or rx.search(tool.tool_def.description or '')
        ]

        self._active_tool_names.update(matching_tool_names)
        return matching_tool_names

    def apply(self, visitor: Callable[[AbstractToolset[AgentDepsT]], None]) -> None:
        self.toolset.apply(visitor)

    def visit_and_replace(
        self, visitor: Callable[[AbstractToolset[AgentDepsT]], AbstractToolset[AgentDepsT]]
    ) -> AbstractToolset[AgentDepsT]:
        return replace(self, toolset=self.toolset.visit_and_replace(visitor))

    async def __aenter__(self) -> Self:
        await self.toolset.__aenter__()
        return self

    async def __aexit__(self, *args: Any) -> bool | None:
        return await self.toolset.__aexit__(*args)
136 changes: 136 additions & 0 deletions test_searchable_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
"""Minimal example to test SearchableToolset functionality.
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like proper tests need to go into:

  • test_toolsets.py has space for unit tests
  • somewhere there are VCR cassettes that record an interaction with an LLM could be useful here

I just wanted to get something quick to iterate with an actual LLM. This ended up working on Claude but took a few iterations on the prompt. The model seemed sensitive to how the "search tool" is called and the content of the description - it would either refuse to load it or start asking for user confirmation before loading it. It took some tweaking to get the current description to pass this simple test.

❯ uv run python test_searchable_example.py
============================================================
Testing SearchableToolset
============================================================

Test 1: Calculation task
------------------------------------------------------------
2025-12-11 07:20:48,189 - root - DEBUG - SearchableToolset.get_tools
2025-12-11 07:20:48,189 - root - DEBUG - SearchableToolset.get_tools ==> ['load_tools']
Result: I can calculate that for you directly.

123 multiplied by 456 equals **56,088**.


Test 2: Database task
------------------------------------------------------------
2025-12-11 07:20:50,983 - root - DEBUG - SearchableToolset.get_tools
2025-12-11 07:20:50,984 - root - DEBUG - SearchableToolset.get_tools ==> ['load_tools']
2025-12-11 07:20:54,254 - root - DEBUG - SearchableToolset.call_tool(load_tools, {'regex': 'database|sql|table|query'}) ==> ['fetch_user_data', 'list_database_tables']
2025-12-11 07:20:54,255 - root - DEBUG - SearchableToolset.get_tools
2025-12-11 07:20:54,255 - root - DEBUG - SearchableToolset.get_tools ==> ['load_tools', 'fetch_user_data', 'list_database_tables']
2025-12-11 07:20:57,735 - root - DEBUG - SearchableToolset.call_tool(list_database_tables, {}) ==> ['users', 'orders', 'products', 'reviews']
2025-12-11 07:20:57,735 - root - DEBUG - SearchableToolset.call_tool(fetch_user_data, {'user_id': 42}) ==> {'id': 42, 'name': 'John Doe', 'email': 'john@example.com'}
2025-12-11 07:20:57,735 - root - DEBUG - SearchableToolset.get_tools
2025-12-11 07:20:57,736 - root - DEBUG - SearchableToolset.get_tools ==> ['load_tools', 'fetch_user_data', 'list_database_tables']
Result: Perfect! Here are the results:

**Database Tables:**
- users
- orders
- products
- reviews

**User 42 Data:**
- ID: 42
- Name: John Doe
- Email: john@example.com


Test 3: Weather task
------------------------------------------------------------
2025-12-11 07:21:00,605 - root - DEBUG - SearchableToolset.get_tools
2025-12-11 07:21:00,607 - root - DEBUG - SearchableToolset.get_tools ==> ['load_tools', 'fetch_user_data', 'list_database_tables']
2025-12-11 07:21:04,597 - root - DEBUG - SearchableToolset.call_tool(load_tools, {'regex': 'weather'}) ==> ['get_weather']
2025-12-11 07:21:04,598 - root - DEBUG - SearchableToolset.get_tools
2025-12-11 07:21:04,599 - root - DEBUG - SearchableToolset.get_tools ==> ['load_tools', 'get_weather', 'fetch_user_data', 'list_database_tables']
2025-12-11 07:21:07,769 - root - DEBUG - SearchableToolset.call_tool(get_weather, {'city': 'San Francisco'}) ==> The weather in San Francisco is sunny and 72°F
2025-12-11 07:21:07,770 - root - DEBUG - SearchableToolset.get_tools
2025-12-11 07:21:07,771 - root - DEBUG - SearchableToolset.get_tools ==> ['load_tools', 'get_weather', 'fetch_user_data', 'list_database_tables']
Result: The weather in San Francisco is currently sunny and 72°F - a beautiful day!


Run with: uv run python test_searchable_example.py
Make sure you have ANTHROPIC_API_KEY set in your environment.
"""

import asyncio
import logging
import sys

# Configure logging to print to stdout
logging.basicConfig(
level=logging.DEBUG,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
stream=sys.stdout
)

# Silence noisy loggers
logging.getLogger('asyncio').setLevel(logging.WARNING)
logging.getLogger('httpx').setLevel(logging.WARNING)
logging.getLogger('httpcore.connection').setLevel(logging.WARNING)
logging.getLogger('httpcore.http11').setLevel(logging.WARNING)
logging.getLogger('anthropic._base_client').setLevel(logging.WARNING)

from pydantic_ai import Agent
from pydantic_ai.toolsets import FunctionToolset
from pydantic_ai.toolsets.searchable import SearchableToolset


# Create a toolset with various tools
toolset = FunctionToolset()


@toolset.tool
def get_weather(city: str) -> str:
    """Get the current weather for a given city.

    Args:
        city: The name of the city to get weather for.
    """
    # Canned response; the docstring above doubles as the tool description.
    weather_report = f"The weather in {city} is sunny and 72°F"
    return weather_report


@toolset.tool
def calculate_sum(a: float, b: float) -> float:
    """Add two numbers together.

    Args:
        a: The first number.
        b: The second number.
    """
    total = a + b
    return total


@toolset.tool
def calculate_product(a: float, b: float) -> float:
    """Multiply two numbers together.

    Args:
        a: The first number.
        b: The second number.
    """
    product = a * b
    return product


@toolset.tool
def fetch_user_data(user_id: int) -> dict:
    """Fetch user data from the database.

    Args:
        user_id: The ID of the user to fetch.
    """
    # Stub record standing in for a real database lookup.
    user_record = {"id": user_id, "name": "John Doe", "email": "john@example.com"}
    return user_record


@toolset.tool
def send_email(recipient: str, subject: str, body: str) -> str:
    """Send an email to a recipient.

    Args:
        recipient: The email address of the recipient.
        subject: The subject line of the email.
        body: The body content of the email.
    """
    # No actual email is sent; only a confirmation string is produced.
    confirmation = f"Email sent to {recipient} with subject '{subject}'"
    return confirmation


@toolset.tool
def list_database_tables() -> list[str]:
    """List all tables in the database."""
    tables = ["users", "orders", "products", "reviews"]
    return tables


# Wrap the toolset with SearchableToolset
searchable_toolset = SearchableToolset(toolset=toolset)

# Create an agent with the searchable toolset.
# NOTE: requires ANTHROPIC_API_KEY in the environment (see module docstring).
agent = Agent(
    'anthropic:claude-sonnet-4-5',
    toolsets=[searchable_toolset],
    system_prompt=(
        "You are a helpful assistant."
    ),
)


async def main():
    """Run three prompts exercising tool search, printing each result."""
    banner = "=" * 60
    divider = "-" * 60
    print(banner)
    print("Testing SearchableToolset")
    print(banner)
    print()

    # (title, prompt) pairs; each should force a different tool search.
    scenarios = [
        ("Test 1: Calculation task", "What is 123 multiplied by 456?"),
        ("\nTest 2: Database task", "Can you list the database tables and then fetch user 42?"),
        ("\nTest 3: Weather task", "What's the weather like in San Francisco?"),
    ]
    for title, prompt in scenarios:
        print(title)
        print(divider)
        result = await agent.run(prompt)
        print(f"Result: {result.output}")
        print()


if __name__ == "__main__":
asyncio.run(main())