-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Add tool_choice setting
#3611
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add tool_choice setting
#3611
Changes from all commits
8d52d65
3128b4a
6d942f5
0585347
96681ac
e71dc86
5c387fd
4dcfbe4
363c718
31bb4e1
338a073
07fcb6b
51cada5
6597e0b
914748f
57ff6bf
8a24d41
1a46a7b
5e4cfb7
70dc917
f924378
80ece45
24962c0
dc97d4e
88884f5
b55bac9
50c9db1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -545,6 +545,90 @@ def prompted_output_instructions(self) -> str | None: | |
| __repr__ = _utils.dataclasses_no_defaults_repr | ||
|
|
||
|
|
||
| @dataclass | ||
| class _ResolvedToolChoice: | ||
| """Provider-agnostic resolved tool choice. | ||
|
|
||
| This is the result of validating and resolving the user's `tool_choice` setting. | ||
| Providers should map this to their API-specific format. | ||
| """ | ||
|
|
||
| mode: Literal['none', 'auto', 'required', 'specific'] | ||
| """The resolved tool choice mode.""" | ||
|
|
||
| tool_names: list[str] = field(default_factory=list) | ||
| """For 'specific' mode, the list of tool names to force. Empty for other modes.""" | ||
|
|
||
| def filter_tools( | ||
| self, | ||
| function_tools: list[ToolDefinition], | ||
| output_tools: list[ToolDefinition], | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. People may also want to use |
||
| ) -> list[ToolDefinition]: | ||
| """Filter tools based on the resolved mode. | ||
|
|
||
| - 'none': only output_tools | ||
| - 'required': only function_tools | ||
| - 'specific': specified function_tools + output_tools | ||
| - 'auto': all tools | ||
| """ | ||
| if self.mode == 'none': | ||
| return list(output_tools) | ||
| elif self.mode == 'required': | ||
| return list(function_tools) | ||
| elif self.mode == 'specific': | ||
| allowed = set(self.tool_names) | ||
| return [t for t in function_tools if t.name in allowed] + list(output_tools) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think we should automatically include all output tools in this case, because the user is saying that they require one of the named tools to be called, so the model shouldn't be able to skip that and go straight for output. That is, unless the user intentionally included an output tool name, which I don't see a strong reason not to support, so let's. |
||
| else: # 'auto' | ||
| return [*function_tools, *output_tools] | ||
|
|
||
|
|
||
| def _resolve_tool_choice( # pyright: ignore[reportUnusedFunction] | ||
| model_settings: ModelSettings | None, | ||
| model_request_parameters: ModelRequestParameters, | ||
| ) -> _ResolvedToolChoice | None: | ||
| """Resolve and validate tool_choice from model settings. | ||
|
|
||
| This centralizes the common logic for handling tool_choice across all providers: | ||
| - Validates tool names in list[str] against available function_tools | ||
| - Returns a provider-agnostic _ResolvedToolChoice for the provider to map to their API format | ||
|
|
||
| Args: | ||
| model_settings: The model settings containing tool_choice. | ||
| model_request_parameters: The request parameters containing tool definitions. | ||
|
|
||
| Returns: | ||
| _ResolvedToolChoice if an explicit tool_choice was provided and validated, | ||
| None if tool_choice was not set (provider should use default behavior based on allow_text_output). | ||
|
|
||
| Raises: | ||
| UserError: If tool names in list[str] are invalid. | ||
| """ | ||
| user_tool_choice = (model_settings or {}).get('tool_choice') | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The method only uses this field of |
||
|
|
||
| if user_tool_choice is None: | ||
| return None | ||
|
|
||
| if user_tool_choice == 'none': | ||
| return _ResolvedToolChoice(mode='none') | ||
|
|
||
| if user_tool_choice in ('auto', 'required'): | ||
| return _ResolvedToolChoice(mode=user_tool_choice) | ||
|
|
||
| if isinstance(user_tool_choice, list): | ||
| if not user_tool_choice: | ||
| return _ResolvedToolChoice(mode='none') | ||
| function_tool_names = {t.name for t in model_request_parameters.function_tools} | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See above; let's support output tool names here as well |
||
| invalid_names = set(user_tool_choice) - function_tool_names | ||
| if invalid_names: | ||
DouweM marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| raise UserError( | ||
| f'Invalid tool names in `tool_choice`: {invalid_names}. ' | ||
| f'Available function tools: {function_tool_names or "none"}' | ||
| ) | ||
| return _ResolvedToolChoice(mode='specific', tool_names=list(user_tool_choice)) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As mentioned above, I don't think we really need the |
||
|
|
||
| return None # pragma: no cover | ||
|
|
||
|
|
||
| class Model(ABC): | ||
| """Abstract class for a model.""" | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,7 @@ | ||
| from __future__ import annotations as _annotations | ||
|
|
||
| import io | ||
| import warnings | ||
| from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator | ||
| from contextlib import asynccontextmanager | ||
| from dataclasses import dataclass, field, replace | ||
|
|
@@ -42,7 +43,15 @@ | |
| from ..providers.anthropic import AsyncAnthropicClient | ||
| from ..settings import ModelSettings, merge_model_settings | ||
| from ..tools import ToolDefinition | ||
| from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent | ||
| from . import ( | ||
| Model, | ||
| ModelRequestParameters, | ||
| StreamedResponse, | ||
| _resolve_tool_choice, # pyright: ignore[reportPrivateUsage] | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm, I suppose this should be public since people may have custom model classes that need to use it |
||
| check_allow_model_requests, | ||
| download_item, | ||
| get_user_agent, | ||
| ) | ||
|
|
||
| _FINISH_REASON_MAP: dict[BetaStopReason, FinishReason] = { | ||
| 'end_turn': 'stop', | ||
|
|
@@ -386,11 +395,9 @@ async def _messages_create( | |
| This is the last step before sending the request to the API. | ||
| Most preprocessing has happened in `prepare_request()`. | ||
| """ | ||
| tools = self._get_tools(model_request_parameters, model_settings) | ||
| tools, tool_choice = self._infer_tool_choice(model_settings, model_request_parameters) | ||
| tools, mcp_servers, builtin_tool_betas = self._add_builtin_tools(tools, model_request_parameters) | ||
|
|
||
| tool_choice = self._infer_tool_choice(tools, model_settings, model_request_parameters) | ||
|
|
||
| system_prompt, anthropic_messages = await self._map_message(messages, model_request_parameters, model_settings) | ||
| self._limit_cache_points(system_prompt, anthropic_messages, tools) | ||
| output_format = self._native_output_format(model_request_parameters) | ||
|
|
@@ -474,11 +481,9 @@ async def _messages_count_tokens( | |
| raise UserError('AsyncAnthropicBedrock client does not support `count_tokens` api.') | ||
|
|
||
| # standalone function to make it easier to override | ||
| tools = self._get_tools(model_request_parameters, model_settings) | ||
| tools, tool_choice = self._infer_tool_choice(model_settings, model_request_parameters) | ||
| tools, mcp_servers, builtin_tool_betas = self._add_builtin_tools(tools, model_request_parameters) | ||
|
|
||
| tool_choice = self._infer_tool_choice(tools, model_settings, model_request_parameters) | ||
|
|
||
| system_prompt, anthropic_messages = await self._map_message(messages, model_request_parameters, model_settings) | ||
| self._limit_cache_points(system_prompt, anthropic_messages, tools) | ||
| output_format = self._native_output_format(model_request_parameters) | ||
|
|
@@ -584,22 +589,6 @@ async def _process_streamed_response( | |
| _provider_url=self._provider.base_url, | ||
| ) | ||
|
|
||
| def _get_tools( | ||
| self, model_request_parameters: ModelRequestParameters, model_settings: AnthropicModelSettings | ||
| ) -> list[BetaToolUnionParam]: | ||
| tools: list[BetaToolUnionParam] = [ | ||
| self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values() | ||
| ] | ||
|
|
||
| # Add cache_control to the last tool if enabled | ||
| if tools and (cache_tool_defs := model_settings.get('anthropic_cache_tool_definitions')): | ||
| # If True, use '5m'; otherwise use the specified ttl value | ||
| ttl: Literal['5m', '1h'] = '5m' if cache_tool_defs is True else cache_tool_defs | ||
| last_tool = tools[-1] | ||
| last_tool['cache_control'] = self._build_cache_control(ttl) | ||
|
|
||
| return tools | ||
|
|
||
| def _add_builtin_tools( | ||
| self, tools: list[BetaToolUnionParam], model_request_parameters: ModelRequestParameters | ||
| ) -> tuple[list[BetaToolUnionParam], list[BetaRequestMCPServerURLDefinitionParam], set[str]]: | ||
|
|
@@ -663,26 +652,91 @@ def _add_builtin_tools( | |
| ) | ||
| return tools, mcp_servers, beta_features | ||
|
|
||
| def _infer_tool_choice( | ||
| def _infer_tool_choice( # noqa: C901 | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Now that this method also returns the tools, it needs a new name for sure |
||
| self, | ||
| tools: list[BetaToolUnionParam], | ||
| model_settings: AnthropicModelSettings, | ||
| model_request_parameters: ModelRequestParameters, | ||
| ) -> BetaToolChoiceParam | None: | ||
| if not tools: | ||
| return None | ||
| ) -> tuple[list[BetaToolUnionParam], BetaToolChoiceParam | None]: | ||
| """Determine which tools to send and the API tool_choice value. | ||
|
|
||
| Returns: | ||
| A tuple of (filtered_tools, tool_choice). | ||
| """ | ||
| thinking_enabled = model_settings.get('anthropic_thinking') is not None | ||
| function_tools = model_request_parameters.function_tools | ||
| output_tools = model_request_parameters.output_tools | ||
|
|
||
| resolved = _resolve_tool_choice(model_settings, model_request_parameters) | ||
|
|
||
| if resolved is None: | ||
| tool_defs_to_send = [*function_tools, *output_tools] | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the same thing that |
||
| else: | ||
| tool_choice: BetaToolChoiceParam | ||
| tool_defs_to_send = resolved.filter_tools(function_tools, output_tools) | ||
|
|
||
| # Map ToolDefinitions to Anthropic format | ||
| tools: list[BetaToolUnionParam] = [self._map_tool_definition(t) for t in tool_defs_to_send] | ||
|
|
||
| # Add cache_control to the last tool if enabled | ||
| if tools and (cache_tool_defs := model_settings.get('anthropic_cache_tool_definitions')): | ||
| ttl: Literal['5m', '1h'] = '5m' if cache_tool_defs is True else cache_tool_defs | ||
| last_tool = tools[-1] | ||
| last_tool['cache_control'] = self._build_cache_control(ttl) | ||
|
|
||
| if not tools: | ||
| return tools, None | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can move this immediately after |
||
|
|
||
| tool_choice: BetaToolChoiceParam | ||
|
|
||
| if resolved is None: | ||
| if not model_request_parameters.allow_text_output: | ||
| tool_choice = {'type': 'any'} | ||
| else: | ||
| tool_choice = {'type': 'auto'} | ||
|
|
||
| if 'parallel_tool_calls' in model_settings: | ||
| tool_choice['disable_parallel_tool_use'] = not model_settings['parallel_tool_calls'] | ||
| elif resolved.mode == 'auto': | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This branch and the one above are identical :) |
||
| if not model_request_parameters.allow_text_output: | ||
| tool_choice = {'type': 'any'} | ||
| else: | ||
| tool_choice = {'type': 'auto'} | ||
|
|
||
| elif resolved.mode == 'required': | ||
| if thinking_enabled: | ||
| raise UserError( | ||
| "tool_choice='required' is not supported with Anthropic thinking mode. " | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. All other user errors like this take the format of |
||
| 'Use `output_type=NativeOutput(...)` or `PromptedOutput(...)` instead.' | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is confusing, how would changing the output type help when using Also, let's double check if this thinking + |
||
| ) | ||
| tool_choice = {'type': 'any'} | ||
|
|
||
| elif resolved.mode == 'none': | ||
| if len(output_tools) == 1: | ||
| tool_choice = {'type': 'tool', 'name': output_tools[0].name} | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Right now, when there's a specific tool that needs to be called, we're sending this arg AND filtering the tools, but it'd be much better not to do the filtering because it breaks the cache. So we should only filter if we absolutely have to: if there are multiple specific tools to be called, and the API doesn't have a "require one of multiple tools" feature. |
||
| else: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What if there are 0 output tools? Then I think we should send |
||
| warnings.warn( | ||
| "Anthropic only supports forcing a single tool. Falling back to 'auto' for multiple output tools." | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can't we limit the tool defs we send + |
||
| ) | ||
| tool_choice = {'type': 'auto'} | ||
|
|
||
| elif resolved.mode == 'specific': | ||
| if thinking_enabled: | ||
DouweM marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| raise UserError( | ||
| 'Forcing specific tools is not supported with Anthropic thinking mode. ' | ||
| 'Use `output_type=NativeOutput(...)` or `PromptedOutput(...)` instead.' | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See above; this is a confusing recommendation |
||
| ) | ||
| if len(resolved.tool_names) == 1: | ||
| tool_choice = {'type': 'tool', 'name': resolved.tool_names[0]} | ||
| else: | ||
| warnings.warn( | ||
| "Anthropic only supports forcing a single tool. Falling back to 'any' for multiple specific tools." | ||
| ) | ||
| tool_choice = {'type': 'any'} | ||
|
|
||
| else: | ||
| assert_never(resolved.mode) | ||
|
|
||
| if 'parallel_tool_calls' in model_settings: | ||
| tool_choice['disable_parallel_tool_use'] = not model_settings['parallel_tool_calls'] | ||
|
|
||
| return tool_choice | ||
| return tools, tool_choice | ||
|
|
||
| async def _map_message( # noqa: C901 | ||
| self, | ||
|
|
@@ -887,9 +941,10 @@ async def _map_message( # noqa: C901 | |
| system_prompt_parts.insert(0, instructions) | ||
| system_prompt = '\n\n'.join(system_prompt_parts) | ||
|
|
||
| ttl: Literal['5m', '1h'] | ||
| # Add cache_control to the last message content if anthropic_cache_messages is enabled | ||
| if anthropic_messages and (cache_messages := model_settings.get('anthropic_cache_messages')): | ||
| ttl: Literal['5m', '1h'] = '5m' if cache_messages is True else cache_messages | ||
| ttl = '5m' if cache_messages is True else cache_messages | ||
| m = anthropic_messages[-1] | ||
| content = m['content'] | ||
| if isinstance(content, str): | ||
|
|
@@ -909,7 +964,7 @@ async def _map_message( # noqa: C901 | |
| # If anthropic_cache_instructions is enabled, return system prompt as a list with cache_control | ||
| if system_prompt and (cache_instructions := model_settings.get('anthropic_cache_instructions')): | ||
| # If True, use '5m'; otherwise use the specified ttl value | ||
| ttl: Literal['5m', '1h'] = '5m' if cache_instructions is True else cache_instructions | ||
| ttl = '5m' if cache_instructions is True else cache_instructions | ||
| system_prompt_blocks = [ | ||
| BetaTextBlockParam( | ||
| type='text', | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,6 +2,7 @@ | |
|
|
||
| import functools | ||
| import typing | ||
| import warnings | ||
| from collections.abc import AsyncIterator, Iterable, Iterator, Mapping | ||
| from contextlib import asynccontextmanager | ||
| from dataclasses import dataclass, field | ||
|
|
@@ -41,7 +42,13 @@ | |
| ) | ||
| from pydantic_ai._run_context import RunContext | ||
| from pydantic_ai.exceptions import ModelAPIError, ModelHTTPError, UserError | ||
| from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse, download_item | ||
| from pydantic_ai.models import ( | ||
| Model, | ||
| ModelRequestParameters, | ||
| StreamedResponse, | ||
| _resolve_tool_choice, # pyright: ignore[reportPrivateUsage] | ||
| download_item, | ||
| ) | ||
| from pydantic_ai.providers import Provider, infer_provider | ||
| from pydantic_ai.providers.bedrock import BEDROCK_GEO_PREFIXES, BedrockModelProfile | ||
| from pydantic_ai.settings import ModelSettings | ||
|
|
@@ -254,9 +261,6 @@ def system(self) -> str: | |
| """The model provider.""" | ||
| return self._provider.name | ||
|
|
||
| def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolTypeDef]: | ||
| return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()] | ||
|
|
||
| @staticmethod | ||
| def _map_tool_definition(f: ToolDefinition) -> ToolTypeDef: | ||
| tool_spec: ToolSpecificationTypeDef = {'name': f.name, 'inputSchema': {'json': f.parameters_json_schema}} | ||
|
|
@@ -422,7 +426,7 @@ async def _messages_create( | |
| 'inferenceConfig': inference_config, | ||
| } | ||
|
|
||
| tool_config = self._map_tool_config(model_request_parameters) | ||
| tool_config = self._map_tool_config(model_request_parameters, model_settings) | ||
| if tool_config: | ||
| params['toolConfig'] = tool_config | ||
|
|
||
|
|
@@ -478,17 +482,58 @@ def _map_inference_config( | |
|
|
||
| return inference_config | ||
|
|
||
| def _map_tool_config(self, model_request_parameters: ModelRequestParameters) -> ToolConfigurationTypeDef | None: | ||
| tools = self._get_tools(model_request_parameters) | ||
| if not tools: | ||
| def _map_tool_config( | ||
| self, | ||
| model_request_parameters: ModelRequestParameters, | ||
| model_settings: BedrockModelSettings | None, | ||
| ) -> ToolConfigurationTypeDef | None: | ||
| resolved = _resolve_tool_choice(model_settings, model_request_parameters) | ||
| function_tools = model_request_parameters.function_tools | ||
| output_tools = model_request_parameters.output_tools | ||
|
|
||
| if resolved is None: | ||
| tool_defs_to_send = [*function_tools, *output_tools] | ||
| else: | ||
| tool_defs_to_send = resolved.filter_tools(function_tools, output_tools) | ||
|
|
||
| if not tool_defs_to_send: | ||
| return None | ||
|
|
||
| tools = [self._map_tool_definition(t) for t in tool_defs_to_send] | ||
| tool_choice: ToolChoiceTypeDef | ||
| if not model_request_parameters.allow_text_output: | ||
|
|
||
| if resolved is None: | ||
| # Default behavior: infer from allow_text_output | ||
| if not model_request_parameters.allow_text_output: | ||
| tool_choice = {'any': {}} | ||
| else: | ||
| tool_choice = {'auto': {}} | ||
|
|
||
| elif resolved.mode == 'auto': | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as above; branches are duplicated and could be merged |
||
| if not model_request_parameters.allow_text_output: | ||
| tool_choice = {'any': {}} | ||
| else: | ||
| tool_choice = {'auto': {}} | ||
|
|
||
| elif resolved.mode == 'required': | ||
| tool_choice = {'any': {}} | ||
DouweM marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| else: | ||
|
|
||
| elif resolved.mode == 'none': | ||
| # We've already filtered to only output tools, use 'auto' to let model choose | ||
| tool_choice = {'auto': {}} | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If |
||
|
|
||
| elif resolved.mode == 'specific': | ||
| if not resolved.tool_names: # pragma: no cover | ||
| raise RuntimeError('Internal error: resolved.tool_names is empty for specific tool choice.') | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should be an |
||
| if len(resolved.tool_names) == 1: | ||
| tool_choice = {'tool': {'name': resolved.tool_names[0]}} | ||
| else: | ||
| warnings.warn("Bedrock only supports forcing a single tool. Falling back to 'any'.") | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Didn't we implement tool filtering exactly so we don't have to do this anymore? 😄 |
||
| tool_choice = {'any': {}} | ||
|
|
||
| else: | ||
| assert_never(resolved.mode) | ||
|
|
||
| tool_config: ToolConfigurationTypeDef = {'tools': tools} | ||
| if tool_choice and BedrockModelProfile.from_profile(self.profile).bedrock_supports_tool_choice: | ||
| tool_config['toolChoice'] = tool_choice | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of this class, we could pass around the
Literal['none', 'required', 'auto'] | list[str]value that the user provides, and havefilter_toolsbe a helper method right? I'm not sure a separate type is warranted here