Changes from all commits (27 commits)
8d52d65: Clarify usage of agent factories (dsfaccini, Nov 28, 2025)
3128b4a: Merge branch 'pydantic:main' into main (dsfaccini, Nov 30, 2025)
6d942f5: implement tool choice resolution per model (dsfaccini, Dec 1, 2025)
0585347: centralize logic in utility and add tests for all providers (dsfaccini, Dec 2, 2025)
96681ac: coverage? (dsfaccini, Dec 2, 2025)
e71dc86: coverage (dsfaccini, Dec 2, 2025)
5c387fd: imrpove tests (dsfaccini, Dec 2, 2025)
4dcfbe4: Merge branch 'main' into tool-choice (dsfaccini, Dec 4, 2025)
363c718: improvde code quality (dsfaccini, Dec 5, 2025)
31bb4e1: deduplicate openai logic (dsfaccini, Dec 5, 2025)
338a073: remove cast (dsfaccini, Dec 5, 2025)
07fcb6b: re-run existent cassettes and record new ones for new tool choice tests (dsfaccini, Dec 8, 2025)
51cada5: fix snapshots (dsfaccini, Dec 8, 2025)
6597e0b: fix tests (dsfaccini, Dec 9, 2025)
914748f: Merge branch 'main' into tool-choice (dsfaccini, Dec 9, 2025)
57ff6bf: upgrade to newer models (dsfaccini, Dec 9, 2025)
8a24d41: Merge upstream/main into tool-choice (dsfaccini, Dec 9, 2025)
1a46a7b: support tool choice callable to force tools on first request or arbit… (dsfaccini, Dec 9, 2025)
5e4cfb7: Merge branch 'main' into tool-choice (dsfaccini, Dec 9, 2025)
70dc917: skip tests (dsfaccini, Dec 9, 2025)
f924378: fix: skip lint/test for docstring examples with RunContext (dsfaccini, Dec 9, 2025)
80ece45: Merge branch 'main' into tool-choice (dsfaccini, Dec 10, 2025)
24962c0: add note about serialization obligation (dsfaccini, Dec 9, 2025)
dc97d4e: revert: remove callable tool_choice and force_first_request (dsfaccini, Dec 9, 2025)
88884f5: fix: align tool_choice tests with warning strategy and resolve merge … (dsfaccini, Dec 10, 2025)
b55bac9: covergae (dsfaccini, Dec 11, 2025)
50c9db1: simplify branches (dsfaccini, Dec 11, 2025)
84 changes: 84 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/__init__.py
@@ -545,6 +545,90 @@ def prompted_output_instructions(self) -> str | None:
__repr__ = _utils.dataclasses_no_defaults_repr


@dataclass
class _ResolvedToolChoice:
"""Provider-agnostic resolved tool choice.

This is the result of validating and resolving the user's `tool_choice` setting.
Providers should map this to their API-specific format.
"""

mode: Literal['none', 'auto', 'required', 'specific']
"""The resolved tool choice mode."""

tool_names: list[str] = field(default_factory=list)
Review comment (Collaborator): Instead of this class, we could pass around the `Literal['none', 'required', 'auto'] | list[str]` value that the user provides, and have `filter_tools` be a helper method, right? I'm not sure a separate type is warranted here.

"""For 'specific' mode, the list of tool names to force. Empty for other modes."""

def filter_tools(
self,
function_tools: list[ToolDefinition],
output_tools: list[ToolDefinition],
Review comment (Collaborator): People may also want to use `tool_choice` with builtin tools, which I know is supported by at least OpenAI with `{type: 'web_search_tool'}`. Can you check if Anthropic or Google also let you require a builtin tool to be used? Because in that case we should probably support `type[AbstractBuiltinTool]` in the `list[str]` as well.

) -> list[ToolDefinition]:
"""Filter tools based on the resolved mode.

- 'none': only output_tools
- 'required': only function_tools
- 'specific': specified function_tools + output_tools
- 'auto': all tools
"""
if self.mode == 'none':
return list(output_tools)
elif self.mode == 'required':
return list(function_tools)
elif self.mode == 'specific':
allowed = set(self.tool_names)
return [t for t in function_tools if t.name in allowed] + list(output_tools)
Review comment (Collaborator): I don't think we should automatically include all output tools in this case, because the user is saying that they require one of the named tools to be called, so the model shouldn't be able to skip that and go straight for output.

That is, unless the user intentionally included an output tool name, which I don't see a strong reason not to support, so let's.

else: # 'auto'
return [*function_tools, *output_tools]
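The four filtering modes above can be exercised in isolation. Below is a minimal standalone sketch of the same branch logic, where `FakeTool` is a hypothetical stand-in for pydantic_ai's real `ToolDefinition`:

```python
from dataclasses import dataclass


@dataclass
class FakeTool:
    """Hypothetical stand-in for pydantic_ai's ToolDefinition."""
    name: str


def filter_tools(mode, tool_names, function_tools, output_tools):
    """Mirrors the branch logic of _ResolvedToolChoice.filter_tools."""
    if mode == 'none':
        return list(output_tools)
    elif mode == 'required':
        return list(function_tools)
    elif mode == 'specific':
        allowed = set(tool_names)
        return [t for t in function_tools if t.name in allowed] + list(output_tools)
    else:  # 'auto'
        return [*function_tools, *output_tools]


funcs = [FakeTool('search'), FakeTool('calc')]
outs = [FakeTool('final_result')]
print([t.name for t in filter_tools('specific', ['calc'], funcs, outs)])
# ['calc', 'final_result']
```

Note that 'specific' keeps the output tools alongside the named function tools, which is exactly the behavior the review comment above questions.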


def _resolve_tool_choice( # pyright: ignore[reportUnusedFunction]
model_settings: ModelSettings | None,
model_request_parameters: ModelRequestParameters,
) -> _ResolvedToolChoice | None:
"""Resolve and validate tool_choice from model settings.

This centralizes the common logic for handling tool_choice across all providers:
- Validates tool names in list[str] against available function_tools
- Returns a provider-agnostic _ResolvedToolChoice for the provider to map to their API format

Args:
model_settings: The model settings containing tool_choice.
model_request_parameters: The request parameters containing tool definitions.

Returns:
_ResolvedToolChoice if an explicit tool_choice was provided and validated,
None if tool_choice was not set (provider should use default behavior based on allow_text_output).

Raises:
UserError: If tool names in list[str] are invalid.
"""
user_tool_choice = (model_settings or {}).get('tool_choice')
Review comment (Collaborator): The method only uses this field of model_settings, so let's just pass in the tool choice value instead.


if user_tool_choice is None:
return None

if user_tool_choice == 'none':
return _ResolvedToolChoice(mode='none')

if user_tool_choice in ('auto', 'required'):
return _ResolvedToolChoice(mode=user_tool_choice)

if isinstance(user_tool_choice, list):
if not user_tool_choice:
return _ResolvedToolChoice(mode='none')
function_tool_names = {t.name for t in model_request_parameters.function_tools}
Review comment (Collaborator): See above; let's support output tool names here as well.

invalid_names = set(user_tool_choice) - function_tool_names
if invalid_names:
raise UserError(
f'Invalid tool names in `tool_choice`: {invalid_names}. '
f'Available function tools: {function_tool_names or "none"}'
)
return _ResolvedToolChoice(mode='specific', tool_names=list(user_tool_choice))
Review comment (Collaborator): As mentioned above, I don't think we really need the _ResolvedToolChoice type, so then this becomes a method that just validates the specified tool names in case of a list.


return None # pragma: no cover
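The resolution rules fit in a few lines. This is a simplified, hypothetical re-implementation of the branching for illustration only: it returns a plain `(mode, names)` tuple instead of `_ResolvedToolChoice`, and raises `ValueError` where the real code raises pydantic_ai's `UserError`:

```python
def resolve_tool_choice(user_tool_choice, function_tool_names):
    """Simplified sketch of _resolve_tool_choice's branching."""
    if user_tool_choice is None:
        return None  # provider falls back to its allow_text_output-based default
    if user_tool_choice == 'none':
        return ('none', [])
    if user_tool_choice in ('auto', 'required'):
        return (user_tool_choice, [])
    if isinstance(user_tool_choice, list):
        if not user_tool_choice:
            return ('none', [])  # an empty list behaves like 'none'
        invalid = set(user_tool_choice) - set(function_tool_names)
        if invalid:
            raise ValueError(f'Invalid tool names in `tool_choice`: {invalid}')
        return ('specific', list(user_tool_choice))
    return None


print(resolve_tool_choice(['search'], {'search', 'calc'}))
# ('specific', ['search'])
```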


class Model(ABC):
"""Abstract class for a model."""

123 changes: 89 additions & 34 deletions pydantic_ai_slim/pydantic_ai/models/anthropic.py
@@ -1,6 +1,7 @@
from __future__ import annotations as _annotations

import io
import warnings
from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator
from contextlib import asynccontextmanager
from dataclasses import dataclass, field, replace
@@ -42,7 +43,15 @@
from ..providers.anthropic import AsyncAnthropicClient
from ..settings import ModelSettings, merge_model_settings
from ..tools import ToolDefinition
from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent
from . import (
Model,
ModelRequestParameters,
StreamedResponse,
_resolve_tool_choice, # pyright: ignore[reportPrivateUsage]
Review comment (Collaborator): Hmm, I suppose this should be public since people may have custom model classes that need to use it.

check_allow_model_requests,
download_item,
get_user_agent,
)

_FINISH_REASON_MAP: dict[BetaStopReason, FinishReason] = {
'end_turn': 'stop',
@@ -386,11 +395,9 @@ async def _messages_create(
This is the last step before sending the request to the API.
Most preprocessing has happened in `prepare_request()`.
"""
tools = self._get_tools(model_request_parameters, model_settings)
tools, tool_choice = self._infer_tool_choice(model_settings, model_request_parameters)
tools, mcp_servers, builtin_tool_betas = self._add_builtin_tools(tools, model_request_parameters)

tool_choice = self._infer_tool_choice(tools, model_settings, model_request_parameters)

system_prompt, anthropic_messages = await self._map_message(messages, model_request_parameters, model_settings)
self._limit_cache_points(system_prompt, anthropic_messages, tools)
output_format = self._native_output_format(model_request_parameters)
@@ -474,11 +481,9 @@ async def _messages_count_tokens(
raise UserError('AsyncAnthropicBedrock client does not support `count_tokens` api.')

# standalone function to make it easier to override
tools = self._get_tools(model_request_parameters, model_settings)
tools, tool_choice = self._infer_tool_choice(model_settings, model_request_parameters)
tools, mcp_servers, builtin_tool_betas = self._add_builtin_tools(tools, model_request_parameters)

tool_choice = self._infer_tool_choice(tools, model_settings, model_request_parameters)

system_prompt, anthropic_messages = await self._map_message(messages, model_request_parameters, model_settings)
self._limit_cache_points(system_prompt, anthropic_messages, tools)
output_format = self._native_output_format(model_request_parameters)
@@ -584,22 +589,6 @@ async def _process_streamed_response(
_provider_url=self._provider.base_url,
)

def _get_tools(
self, model_request_parameters: ModelRequestParameters, model_settings: AnthropicModelSettings
) -> list[BetaToolUnionParam]:
tools: list[BetaToolUnionParam] = [
self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()
]

# Add cache_control to the last tool if enabled
if tools and (cache_tool_defs := model_settings.get('anthropic_cache_tool_definitions')):
# If True, use '5m'; otherwise use the specified ttl value
ttl: Literal['5m', '1h'] = '5m' if cache_tool_defs is True else cache_tool_defs
last_tool = tools[-1]
last_tool['cache_control'] = self._build_cache_control(ttl)

return tools

def _add_builtin_tools(
self, tools: list[BetaToolUnionParam], model_request_parameters: ModelRequestParameters
) -> tuple[list[BetaToolUnionParam], list[BetaRequestMCPServerURLDefinitionParam], set[str]]:
@@ -663,26 +652,91 @@ def _add_builtin_tools(
)
return tools, mcp_servers, beta_features

def _infer_tool_choice(
def _infer_tool_choice( # noqa: C901
Review comment (Collaborator): Now that this method also returns the tools, it needs a new name for sure.

self,
tools: list[BetaToolUnionParam],
model_settings: AnthropicModelSettings,
model_request_parameters: ModelRequestParameters,
) -> BetaToolChoiceParam | None:
if not tools:
return None
) -> tuple[list[BetaToolUnionParam], BetaToolChoiceParam | None]:
"""Determine which tools to send and the API tool_choice value.

Returns:
A tuple of (filtered_tools, tool_choice).
"""
thinking_enabled = model_settings.get('anthropic_thinking') is not None
function_tools = model_request_parameters.function_tools
output_tools = model_request_parameters.output_tools

resolved = _resolve_tool_choice(model_settings, model_request_parameters)

if resolved is None:
tool_defs_to_send = [*function_tools, *output_tools]
Review comment (Collaborator): This is the same thing that resolved.filter_tools returns for 'auto', so we can likely dedupe this by treating None and 'auto' the same at another level.

else:
tool_choice: BetaToolChoiceParam
tool_defs_to_send = resolved.filter_tools(function_tools, output_tools)

# Map ToolDefinitions to Anthropic format
tools: list[BetaToolUnionParam] = [self._map_tool_definition(t) for t in tool_defs_to_send]

# Add cache_control to the last tool if enabled
if tools and (cache_tool_defs := model_settings.get('anthropic_cache_tool_definitions')):
ttl: Literal['5m', '1h'] = '5m' if cache_tool_defs is True else cache_tool_defs
last_tool = tools[-1]
last_tool['cache_control'] = self._build_cache_control(ttl)

if not tools:
return tools, None
Review comment (Collaborator): We can move this immediately after tool_defs_to_send right?


tool_choice: BetaToolChoiceParam

if resolved is None:
if not model_request_parameters.allow_text_output:
tool_choice = {'type': 'any'}
else:
tool_choice = {'type': 'auto'}

if 'parallel_tool_calls' in model_settings:
tool_choice['disable_parallel_tool_use'] = not model_settings['parallel_tool_calls']
elif resolved.mode == 'auto':
Review comment (Collaborator): This branch and the one above are identical :)

if not model_request_parameters.allow_text_output:
tool_choice = {'type': 'any'}
else:
tool_choice = {'type': 'auto'}

elif resolved.mode == 'required':
if thinking_enabled:
raise UserError(
"tool_choice='required' is not supported with Anthropic thinking mode. "
Review comment (Collaborator): All other user errors like this take the format of "[Provider] does not support..."; can we follow that here (and below) as well?

'Use `output_type=NativeOutput(...)` or `PromptedOutput(...)` instead.'
Review comment (Collaborator): This is confusing, how would changing the output type help when using tool_choice=required? Is there another recommendation we could make?

Also, let's double check if this thinking + {'type': 'any'} restriction still exists; I just did a (Gemini assisted) Google search and it told me this may have been lifted recently.

)
tool_choice = {'type': 'any'}

elif resolved.mode == 'none':
if len(output_tools) == 1:
tool_choice = {'type': 'tool', 'name': output_tools[0].name}
Review comment (Collaborator): Right now, when there's a specific tool that needs to be called, we're sending this arg AND filtering the tools, but it'd be much better not to do the filtering because it breaks the cache.

So we should only filter if we absolutely have to: if there are multiple specific tools to be called, and the API doesn't have a "require one of multiple tools" feature.

else:
Review comment (Collaborator): What if there are 0 output tools? Then I think we should send {type: 'none'}, so that the model is forced to generate text. And we should assert model_request_parameters.allow_text_output just to be sure (I don't think it's possible to get into a situation with 0 output tools and NOT allowing text).

warnings.warn(
"Anthropic only supports forcing a single tool. Falling back to 'auto' for multiple output tools."
Review comment (Collaborator): Can't we limit the tool defs we send + tool_choice = {'type': 'any'}?

)
tool_choice = {'type': 'auto'}

elif resolved.mode == 'specific':
if thinking_enabled:
raise UserError(
'Forcing specific tools is not supported with Anthropic thinking mode. '
'Use `output_type=NativeOutput(...)` or `PromptedOutput(...)` instead.'
Review comment (Collaborator): See above; this is a confusing recommendation.

)
if len(resolved.tool_names) == 1:
tool_choice = {'type': 'tool', 'name': resolved.tool_names[0]}
else:
warnings.warn(
"Anthropic only supports forcing a single tool. Falling back to 'any' for multiple specific tools."
)
tool_choice = {'type': 'any'}

else:
assert_never(resolved.mode)

if 'parallel_tool_calls' in model_settings:
tool_choice['disable_parallel_tool_use'] = not model_settings['parallel_tool_calls']

return tool_choice
return tools, tool_choice
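Condensed, the Anthropic mode-to-`tool_choice` mapping implemented above looks like the sketch below. This is an illustrative simplification, not the real method: the thinking-mode `UserError` checks and the `parallel_tool_calls` handling are omitted, and the dict shapes follow Anthropic's `auto`/`any`/`tool` choice types as used in the diff:

```python
import warnings


def map_anthropic_tool_choice(mode, tool_names, output_tool_names, allow_text_output):
    """Sketch of the mode -> Anthropic tool_choice mapping in _infer_tool_choice."""
    if mode is None or mode == 'auto':
        # Default behavior: force a tool call only when text output is not allowed
        return {'type': 'auto'} if allow_text_output else {'type': 'any'}
    if mode == 'required':
        return {'type': 'any'}
    if mode == 'none':
        # Tools were already filtered down to output tools; force the single one if possible
        if len(output_tool_names) == 1:
            return {'type': 'tool', 'name': output_tool_names[0]}
        warnings.warn("Anthropic only supports forcing a single tool; falling back to 'auto'.")
        return {'type': 'auto'}
    # mode == 'specific'
    if len(tool_names) == 1:
        return {'type': 'tool', 'name': tool_names[0]}
    warnings.warn("Anthropic only supports forcing a single tool; falling back to 'any'.")
    return {'type': 'any'}


print(map_anthropic_tool_choice('specific', ['calc'], ['final_result'], True))
# {'type': 'tool', 'name': 'calc'}
```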

async def _map_message( # noqa: C901
self,
@@ -887,9 +941,10 @@ async def _map_message( # noqa: C901
system_prompt_parts.insert(0, instructions)
system_prompt = '\n\n'.join(system_prompt_parts)

ttl: Literal['5m', '1h']
# Add cache_control to the last message content if anthropic_cache_messages is enabled
if anthropic_messages and (cache_messages := model_settings.get('anthropic_cache_messages')):
ttl: Literal['5m', '1h'] = '5m' if cache_messages is True else cache_messages
ttl = '5m' if cache_messages is True else cache_messages
m = anthropic_messages[-1]
content = m['content']
if isinstance(content, str):
@@ -909,7 +964,7 @@ async def _map_message( # noqa: C901
# If anthropic_cache_instructions is enabled, return system prompt as a list with cache_control
if system_prompt and (cache_instructions := model_settings.get('anthropic_cache_instructions')):
# If True, use '5m'; otherwise use the specified ttl value
ttl: Literal['5m', '1h'] = '5m' if cache_instructions is True else cache_instructions
ttl = '5m' if cache_instructions is True else cache_instructions
system_prompt_blocks = [
BetaTextBlockParam(
type='text',
65 changes: 55 additions & 10 deletions pydantic_ai_slim/pydantic_ai/models/bedrock.py
@@ -2,6 +2,7 @@

import functools
import typing
import warnings
from collections.abc import AsyncIterator, Iterable, Iterator, Mapping
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
@@ -41,7 +42,13 @@
)
from pydantic_ai._run_context import RunContext
from pydantic_ai.exceptions import ModelAPIError, ModelHTTPError, UserError
from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse, download_item
from pydantic_ai.models import (
Model,
ModelRequestParameters,
StreamedResponse,
_resolve_tool_choice, # pyright: ignore[reportPrivateUsage]
download_item,
)
from pydantic_ai.providers import Provider, infer_provider
from pydantic_ai.providers.bedrock import BEDROCK_GEO_PREFIXES, BedrockModelProfile
from pydantic_ai.settings import ModelSettings
@@ -254,9 +261,6 @@ def system(self) -> str:
"""The model provider."""
return self._provider.name

def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolTypeDef]:
return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]

@staticmethod
def _map_tool_definition(f: ToolDefinition) -> ToolTypeDef:
tool_spec: ToolSpecificationTypeDef = {'name': f.name, 'inputSchema': {'json': f.parameters_json_schema}}
@@ -422,7 +426,7 @@ async def _messages_create(
'inferenceConfig': inference_config,
}

tool_config = self._map_tool_config(model_request_parameters)
tool_config = self._map_tool_config(model_request_parameters, model_settings)
if tool_config:
params['toolConfig'] = tool_config

@@ -478,17 +482,58 @@ def _map_inference_config(

return inference_config

def _map_tool_config(self, model_request_parameters: ModelRequestParameters) -> ToolConfigurationTypeDef | None:
tools = self._get_tools(model_request_parameters)
if not tools:
def _map_tool_config(
self,
model_request_parameters: ModelRequestParameters,
model_settings: BedrockModelSettings | None,
) -> ToolConfigurationTypeDef | None:
resolved = _resolve_tool_choice(model_settings, model_request_parameters)
function_tools = model_request_parameters.function_tools
output_tools = model_request_parameters.output_tools

if resolved is None:
tool_defs_to_send = [*function_tools, *output_tools]
else:
tool_defs_to_send = resolved.filter_tools(function_tools, output_tools)

if not tool_defs_to_send:
return None

tools = [self._map_tool_definition(t) for t in tool_defs_to_send]
tool_choice: ToolChoiceTypeDef
if not model_request_parameters.allow_text_output:

if resolved is None:
# Default behavior: infer from allow_text_output
if not model_request_parameters.allow_text_output:
tool_choice = {'any': {}}
else:
tool_choice = {'auto': {}}

elif resolved.mode == 'auto':
Review comment (Collaborator): Same as above; branches are duplicated and could be merged.

if not model_request_parameters.allow_text_output:
tool_choice = {'any': {}}
else:
tool_choice = {'auto': {}}

elif resolved.mode == 'required':
tool_choice = {'any': {}}
else:

elif resolved.mode == 'none':
# We've already filtered to only output tools, use 'auto' to let model choose
tool_choice = {'auto': {}}
Review comment (Collaborator): If not model_request_parameters.allow_text_output:, shouldn't we send tool_choice = {'any': {}} as above, so we force an output tool to be used? 'auto' would allow text output which we don't want.


elif resolved.mode == 'specific':
if not resolved.tool_names: # pragma: no cover
raise RuntimeError('Internal error: resolved.tool_names is empty for specific tool choice.')
Review comment (Collaborator): This should be an assert as it should not be possible to get into this situation, or better: if we just use the list[str] as the direct value we pass around, we don't have this situation with 2 vars that could be in disagreement.

if len(resolved.tool_names) == 1:
tool_choice = {'tool': {'name': resolved.tool_names[0]}}
else:
warnings.warn("Bedrock only supports forcing a single tool. Falling back to 'any'.")
Review comment (Collaborator): Didn't we implement tool filtering exactly so we don't have to do this anymore? 😄

tool_choice = {'any': {}}

else:
assert_never(resolved.mode)

tool_config: ToolConfigurationTypeDef = {'tools': tools}
if tool_choice and BedrockModelProfile.from_profile(self.profile).bedrock_supports_tool_choice:
tool_config['toolChoice'] = tool_choice
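Bedrock's Converse API expresses the choice as a single-key dict (`auto`, `any`, or `tool`), so the mapping above condenses to the following sketch. As with the Anthropic example, this is a simplified illustration assuming the same `(mode, names)` pair from the resolver; the warning and profile gating from the real code are reduced to comments:

```python
def map_bedrock_tool_choice(mode, tool_names, allow_text_output):
    """Sketch of the mode -> Bedrock toolChoice mapping in _map_tool_config."""
    if mode is None or mode == 'auto':
        # Default behavior: force a tool call only when text output is not allowed
        return {'auto': {}} if allow_text_output else {'any': {}}
    if mode == 'required':
        return {'any': {}}
    if mode == 'none':
        return {'auto': {}}  # tools were already filtered to output tools only
    # mode == 'specific'
    if len(tool_names) == 1:
        return {'tool': {'name': tool_names[0]}}
    return {'any': {}}  # the real code also warns that only one tool can be forced


print(map_bedrock_tool_choice('required', [], True))
# {'any': {}}
```

Note that in the diff, the resulting dict is only attached to `toolConfig` when the model profile reports `bedrock_supports_tool_choice`.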