From 186160f54860b83242317ffc26f20bc163a4a040 Mon Sep 17 00:00:00 2001
From: abhishekbhakat
Date: Wed, 1 Oct 2025 15:52:40 +0000
Subject: [PATCH 1/2] feat(models): add native OpenRouter model class

---
 docs/api/models/openrouter.md                 |   7 +
 docs/models/openai.md                         |   5 +-
 docs/models/openrouter.md                     |  94 +++
 docs/models/overview.md                       |   2 +-
 mkdocs.yml                                    |   2 +
 .../pydantic_ai/models/_openai_compat.py      | 489 +++
 pydantic_ai_slim/pydantic_ai/models/openai.py | 517 +++-------
 .../pydantic_ai/models/openrouter.py          | 242 ++++
 tests/models/test_openrouter.py               | 575 ++++++
 tests/test_examples.py                        |   1 +
 10 files changed, 1487 insertions(+), 447 deletions(-)
 create mode 100644 docs/api/models/openrouter.md
 create mode 100644 docs/models/openrouter.md
 create mode 100644 pydantic_ai_slim/pydantic_ai/models/_openai_compat.py
 create mode 100644 pydantic_ai_slim/pydantic_ai/models/openrouter.py
 create mode 100644 tests/models/test_openrouter.py

diff --git a/docs/api/models/openrouter.md b/docs/api/models/openrouter.md
new file mode 100644
index 0000000000..1e502dbb89
--- /dev/null
+++ b/docs/api/models/openrouter.md
@@ -0,0 +1,7 @@
+# `pydantic_ai.models.openrouter`
+
+## Setup
+
+For details on how to set up authentication with this model, see [model configuration for OpenRouter](../models/openrouter.md).
+
+::: pydantic_ai.models.openrouter
diff --git a/docs/models/openai.md b/docs/models/openai.md
index 12b7fd659b..108b550f89 100644
--- a/docs/models/openai.md
+++ b/docs/models/openai.md
@@ -352,9 +352,10 @@ agent = Agent(model)
 
 ### OpenRouter
 
-To use [OpenRouter](https://openrouter.ai), first create an API key at [openrouter.ai/keys](https://openrouter.ai/keys).
+[OpenRouter](https://openrouter.ai) now has dedicated support in Pydantic AI with the [`OpenRouterModel`][pydantic_ai.models.openrouter.OpenRouterModel].
+For detailed documentation and examples, see the [OpenRouter documentation](openrouter.md).
 
-Once you have the API key, you can use it with the [`OpenRouterProvider`][pydantic_ai.providers.openrouter.OpenRouterProvider]:
+You can also still use OpenRouter through the OpenAI-compatible interface:
 
 ```python
 from pydantic_ai import Agent
diff --git a/docs/models/openrouter.md b/docs/models/openrouter.md
new file mode 100644
index 0000000000..9e5efe122e
--- /dev/null
+++ b/docs/models/openrouter.md
@@ -0,0 +1,94 @@
+# OpenRouter
+
+## Install
+
+To use `OpenRouterModel`, you need to either install `pydantic-ai`, or install `pydantic-ai-slim` with the `openrouter` optional group:
+
+```bash
+pip/uv-add "pydantic-ai-slim[openrouter]"
+```
+
+## Configuration
+
+To use [OpenRouter](https://openrouter.ai/) through their API, go to [openrouter.ai/keys](https://openrouter.ai/keys) and follow your nose until you find the place to generate an API key.
+
+`OpenRouterModelName` is a plain string, since OpenRouter exposes models from many upstream providers; you can use any model name from the [OpenRouter model catalog](https://openrouter.ai/models), for example `openai/gpt-4o` or `google/gemini-2.5-flash-lite`.
+
+## Environment variable
+
+Once you have the API key, you can set it as an environment variable:
+
+```bash
+export OPENROUTER_API_KEY='your-api-key'
+```
+
+You can then use `OpenRouterModel` with the default provider:
+
+```python
+from pydantic_ai import Agent
+from pydantic_ai.models.openrouter import OpenRouterModel
+
+model = OpenRouterModel('google/gemini-2.5-flash-lite')
+agent = Agent(model)
+...
+``` + +Or initialise the model with an explicit provider: + +```python +from pydantic_ai import Agent +from pydantic_ai.models.openrouter import OpenRouterModel +from pydantic_ai.providers.openrouter import OpenRouterProvider + +provider = OpenRouterProvider(api_key='your-api-key') +model = OpenRouterModel('google/gemini-2.5-flash-lite', provider=provider) +agent = Agent(model) +... +``` + +## Custom HTTP Client + +You can customize the HTTP client by using the `OpenRouterProvider`: + +```python +from httpx import AsyncClient + +from pydantic_ai import Agent +from pydantic_ai.models.openrouter import OpenRouterModel +from pydantic_ai.providers.openrouter import OpenRouterProvider + +custom_http_client = AsyncClient(timeout=30) +provider = OpenRouterProvider( + api_key='your-api-key', + http_client=custom_http_client, +) +model = OpenRouterModel('google/gemini-2.5-flash-lite', provider=provider) +agent = Agent(model) +... +``` + +## Structured Output + +You can use OpenRouter models with structured output by providing a Pydantic model as the `output_type`: + +```python {noqa="I001"} +from pydantic import BaseModel + +from pydantic_ai import Agent +from pydantic_ai.models.openrouter import OpenRouterModel + +class OlympicsLocation(BaseModel): + city: str + country: str + +model = OpenRouterModel('google/gemini-2.5-flash-lite') +agent = Agent(model, output_type=OlympicsLocation) + +result = agent.run_sync('Where were the olympics held in 2012?') +print(f'City: {result.output.city}') +#> City: London +print(f'Country: {result.output.country}') +#> Country: United Kingdom +``` + +The model will validate and parse the response into your specified Pydantic model, allowing type-safe access to structured data fields via `result.output.field_name`. diff --git a/docs/models/overview.md b/docs/models/overview.md index 36137db6e5..6bb78fb26a 100644 --- a/docs/models/overview.md +++ b/docs/models/overview.md @@ -10,6 +10,7 @@ Pydantic AI is model-agnostic and has built-in support for multiple model provid * [Cohere](cohere.md) * [Bedrock](bedrock.md) * [Hugging Face](huggingface.md) +* [OpenRouter](openrouter.md) ## OpenAI-compatible Providers @@ -18,7 +19,6 @@ In addition, many providers are compatible with the OpenAI API, and can be used - [DeepSeek](openai.md#deepseek) - [Grok (xAI)](openai.md#grok-xai) - [Ollama](openai.md#ollama) -- [OpenRouter](openai.md#openrouter) - [Vercel AI Gateway](openai.md#vercel-ai-gateway) - [Perplexity](openai.md#perplexity) - [Fireworks AI](openai.md#fireworks-ai) diff --git a/mkdocs.yml b/mkdocs.yml index 58c60a717d..e635e246ee 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -34,6 +34,7 @@ nav: - models/groq.md - models/mistral.md - models/huggingface.md + - models/openrouter.md - Tools & Toolsets: - tools.md - tools-advanced.md @@ -123,6 +124,7 @@ nav: - api/models/huggingface.md - api/models/instrumented.md - api/models/mistral.md + - api/models/openrouter.md - api/models/test.md - api/models/function.md - api/models/fallback.md diff --git a/pydantic_ai_slim/pydantic_ai/models/_openai_compat.py b/pydantic_ai_slim/pydantic_ai/models/_openai_compat.py new file mode 100644 index 0000000000..95cf95250c --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/models/_openai_compat.py @@ -0,0 +1,489 @@ +"""Shared OpenAI compatibility helpers (in-progress). + +This module is a working scaffold. Implementations will be ported in small, +covered steps from `_openai_compat_ref.py` to preserve coverage. 
+""" + +from __future__ import annotations + +from collections.abc import AsyncIterable, AsyncIterator, Callable, Mapping +from dataclasses import dataclass, field, replace +from datetime import datetime +from typing import Any, Literal, overload + +from pydantic import ValidationError +from typing_extensions import assert_never + +from pydantic_ai.messages import ( + FinishReason, + ModelMessage, + ModelRequest, + ModelResponse, + ModelResponsePart, + ModelResponseStreamEvent, + PartStartEvent, + TextPart, + ThinkingPart, + ToolCallPart, +) + +from .. import UnexpectedModelBehavior, _utils, usage +from .._output import OutputObjectDefinition +from .._thinking_part import split_content_into_text_and_thinking +from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc, number_to_datetime +from ..exceptions import UserError +from ..profiles import ModelProfile +from ..profiles.openai import OpenAIModelProfile +from ..settings import ModelSettings +from ..tools import ToolDefinition +from . import ModelRequestParameters, StreamedResponse, get_user_agent + +try: + from openai import NOT_GIVEN, APIStatusError, AsyncStream + from openai.types import chat + from openai.types.chat import ( + ChatCompletionChunk, + ChatCompletionMessageCustomToolCall, + ChatCompletionMessageFunctionToolCall, + ) + from openai.types.chat.completion_create_params import ResponseFormat, WebSearchOptions +except ImportError as _import_error: + raise ImportError( + 'Please install `openai` to use the OpenAI model, ' + 'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`' + ) from _import_error + +__all__ = ( + 'OpenAICompatStreamedResponse', + 'completions_create', + 'map_messages', + 'map_tool_definition', + 'map_usage', + 'process_response', + 'process_streamed_response', +) + + +def _map_tool_call(t: ToolCallPart) -> Any: + """Map a ToolCallPart to OpenAI ChatCompletionMessageFunctionToolCallParam.""" + return { + 'id': _guard_tool_call_id(t=t), + 'type': 'function', + 'function': {'name': t.tool_name, 'arguments': t.args_as_json_str()}, + } + + +def map_tool_definition(model_profile: ModelProfile, f: ToolDefinition) -> Any: + """Map a ToolDefinition to OpenAI ChatCompletionToolParam.""" + tool_param: dict[str, Any] = { + 'type': 'function', + 'function': { + 'name': f.name, + 'description': f.description or '', + 'parameters': f.parameters_json_schema, + }, + } + if f.strict and OpenAIModelProfile.from_profile(model_profile).openai_supports_strict_tool_definition: + tool_param['function']['strict'] = f.strict + return tool_param + + +async def map_messages(model: Any, messages: list[ModelMessage]) -> list[Any]: + """Async mapping of internal ModelMessage list to OpenAI chat messages.""" + openai_messages: list[Any] = [] + for message in messages: + if isinstance(message, ModelRequest): + async for item in model._map_user_message(message): + openai_messages.append(item) + elif isinstance(message, ModelResponse): + texts: list[str] = [] + tool_calls: list[Any] = [] + for item in message.parts: + if isinstance(item, TextPart): + texts.append(item.content) + elif isinstance(item, ToolCallPart): + tool_calls.append(_map_tool_call(item)) + message_param: dict[str, Any] = {'role': 'assistant'} + if texts: + message_param['content'] = '\n\n'.join(texts) + else: + message_param['content'] = None + if tool_calls: + message_param['tool_calls'] = tool_calls + openai_messages.append(message_param) + else: + assert_never(message) + + return openai_messages + + +def 
get_tools(model_profile: ModelProfile, tool_defs: dict[str, ToolDefinition]) -> list[Any]: + """Get OpenAI tools from tool definitions.""" + return [map_tool_definition(model_profile, r) for r in tool_defs.values()] + + +def _map_json_schema(model_profile: ModelProfile, o: OutputObjectDefinition) -> ResponseFormat: + """Map an OutputObjectDefinition to OpenAI ResponseFormatJSONSchema.""" + response_format_param: ResponseFormat = { + 'type': 'json_schema', + 'json_schema': {'name': o.name or 'output', 'schema': o.json_schema}, + } + if o.description: + response_format_param['json_schema']['description'] = o.description + profile = OpenAIModelProfile.from_profile(model_profile) + if profile.openai_supports_strict_tool_definition: # pragma: no branch + response_format_param['json_schema']['strict'] = bool(o.strict) + return response_format_param + + +def _get_web_search_options(model_profile: ModelProfile, builtin_tools: list[Any]) -> WebSearchOptions | None: + """Extract WebSearchOptions from builtin_tools if WebSearchTool is present.""" + for tool in builtin_tools: + if tool.__class__.__name__ == 'WebSearchTool': + if not OpenAIModelProfile.from_profile(model_profile).openai_chat_supports_web_search: + raise UserError( + f'WebSearchTool is not supported with `OpenAIChatModel` and model {getattr(model_profile, "model_name", None) or ""!r}. ' + f'Please use `OpenAIResponsesModel` instead.' + ) + if tool.user_location: + from openai.types.chat.completion_create_params import ( + WebSearchOptionsUserLocation, + WebSearchOptionsUserLocationApproximate, + ) + + return WebSearchOptions( + search_context_size=tool.search_context_size, + user_location=WebSearchOptionsUserLocation( + type='approximate', + approximate=WebSearchOptionsUserLocationApproximate(**tool.user_location), + ), + ) + return WebSearchOptions(search_context_size=tool.search_context_size) + else: + raise UserError( + f'`{tool.__class__.__name__}` is not supported by `OpenAIChatModel`. If it should be, please file an issue.' + ) + return None + + +@overload +async def completions_create( + model: Any, + messages: list[ModelMessage], + stream: Literal[True], + model_settings: ModelSettings, + model_request_parameters: ModelRequestParameters, +) -> AsyncStream[ChatCompletionChunk]: ... + + +@overload +async def completions_create( + model: Any, + messages: list[ModelMessage], + stream: Literal[False], + model_settings: ModelSettings, + model_request_parameters: ModelRequestParameters, +) -> chat.ChatCompletion: ... + + +async def completions_create( + model: Any, + messages: list[ModelMessage], + stream: bool, + model_settings: ModelSettings, + model_request_parameters: ModelRequestParameters, +) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]: + """Create a chat completion using OpenAI SDK with compat helpers. + + Handles tool mapping, response-format mapping, unsupported-setting pruning, + and SDK invocation with error translation. 
+ """ + tools = get_tools(model.profile, model_request_parameters.tool_defs) + web_search_options = _get_web_search_options(model.profile, model_request_parameters.builtin_tools) + + if not tools: + tool_choice: Literal['none', 'required', 'auto'] | None = None + elif ( + not model_request_parameters.allow_text_output + and OpenAIModelProfile.from_profile(model.profile).openai_supports_tool_choice_required + ): + tool_choice = 'required' + else: + tool_choice = 'auto' + + openai_messages = await map_messages(model, messages) + + response_format: ResponseFormat | None = None + if model_request_parameters.output_mode == 'native': + output_object = model_request_parameters.output_object + assert output_object is not None + response_format = _map_json_schema(model.profile, output_object) + elif ( + model_request_parameters.output_mode == 'prompted' and model.profile.supports_json_object_output + ): # pragma: no branch + response_format = {'type': 'json_object'} + + unsupported_model_settings = OpenAIModelProfile.from_profile(model.profile).openai_unsupported_model_settings + for setting in unsupported_model_settings: + model_settings.pop(setting, None) + + try: + extra_headers = model_settings.get('extra_headers', {}) + extra_headers.setdefault('User-Agent', get_user_agent()) + return await model.client.chat.completions.create( + model=model._model_name, + messages=openai_messages, + parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN), + tools=tools or NOT_GIVEN, + tool_choice=tool_choice or NOT_GIVEN, + stream=stream, + stream_options={'include_usage': True} if stream else NOT_GIVEN, + stop=model_settings.get('stop_sequences', NOT_GIVEN), + max_completion_tokens=model_settings.get('max_tokens', NOT_GIVEN), + timeout=model_settings.get('timeout', NOT_GIVEN), + response_format=response_format or NOT_GIVEN, + seed=model_settings.get('seed', NOT_GIVEN), + reasoning_effort=model_settings.get('openai_reasoning_effort', NOT_GIVEN), + user=model_settings.get('openai_user', NOT_GIVEN), + web_search_options=web_search_options or NOT_GIVEN, + service_tier=model_settings.get('openai_service_tier', NOT_GIVEN), + prediction=model_settings.get('openai_prediction', NOT_GIVEN), + temperature=model_settings.get('temperature', NOT_GIVEN), + top_p=model_settings.get('top_p', NOT_GIVEN), + presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN), + frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN), + logit_bias=model_settings.get('logit_bias', NOT_GIVEN), + logprobs=model_settings.get('openai_logprobs', NOT_GIVEN), + top_logprobs=model_settings.get('openai_top_logprobs', NOT_GIVEN), + extra_headers=extra_headers, + extra_body=model_settings.get('extra_body'), + ) + except APIStatusError as e: + if (status_code := e.status_code) >= 400: + from .. 
import ModelHTTPError + + raise ModelHTTPError(status_code=status_code, model_name=model.model_name, body=e.body) from e + raise # pragma: lax no cover + + +def process_response( + model: Any, + response: chat.ChatCompletion | str, + *, + map_usage_fn: Callable[[chat.ChatCompletion], usage.RequestUsage], + finish_reason_map: Mapping[str, FinishReason], +) -> ModelResponse: + """Process a non-streamed chat completion response into a ModelResponse.""" + if not isinstance(response, chat.ChatCompletion): + raise UnexpectedModelBehavior('Invalid response from OpenAI chat completions endpoint, expected JSON data') + + if response.created: + timestamp = number_to_datetime(response.created) + else: + timestamp = _now_utc() + response.created = int(timestamp.timestamp()) + + # Workaround for local Ollama which sometimes returns a `None` finish reason. + if response.choices and (choice := response.choices[0]) and choice.finish_reason is None: # pyright: ignore[reportUnnecessaryComparison] + choice.finish_reason = 'stop' + + try: + response = chat.ChatCompletion.model_validate(response.model_dump()) + except ValidationError as e: # pragma: no cover + raise UnexpectedModelBehavior(f'Invalid response from OpenAI chat completions endpoint: {e}') from e + + choice = response.choices[0] + items: list[ModelResponsePart] = [] + + # OpenRouter uses 'reasoning', OpenAI previously used 'reasoning_content' (removed Feb 2025) + reasoning_content = getattr(choice.message, 'reasoning', None) or getattr(choice.message, 'reasoning_content', None) + if reasoning_content: + items.append(ThinkingPart(id='reasoning_content', content=reasoning_content, provider_name=model.system)) + + vendor_details: dict[str, Any] = {} + + if choice.logprobs is not None and choice.logprobs.content: + vendor_details['logprobs'] = [ + { + 'token': lp.token, + 'bytes': lp.bytes, + 'logprob': lp.logprob, + 'top_logprobs': [ + {'token': tlp.token, 'bytes': tlp.bytes, 'logprob': tlp.logprob} for tlp in lp.top_logprobs + ], + } + for lp in choice.logprobs.content + ] + + if choice.message.content is not None: + items.extend( + (replace(part, id='content', provider_name=model.system) if isinstance(part, ThinkingPart) else part) + for part in split_content_into_text_and_thinking(choice.message.content, model.profile.thinking_tags) + ) + + if choice.message.tool_calls: + for tool_call in choice.message.tool_calls: + if isinstance(tool_call, ChatCompletionMessageFunctionToolCall): + part = ToolCallPart(tool_call.function.name, tool_call.function.arguments, tool_call_id=tool_call.id) + elif isinstance(tool_call, ChatCompletionMessageCustomToolCall): # pragma: no cover + raise RuntimeError('Custom tool calls are not supported') + else: + assert_never(tool_call) + part.tool_call_id = _guard_tool_call_id(part) + items.append(part) + + raw_finish_reason = choice.finish_reason + vendor_details['finish_reason'] = raw_finish_reason + finish_reason = finish_reason_map.get(raw_finish_reason) + + return ModelResponse( + parts=items, + usage=map_usage_fn(response), + model_name=response.model, + timestamp=timestamp, + provider_details=vendor_details or None, + provider_response_id=response.id, + provider_name=model.system, + finish_reason=finish_reason, + ) + + +async def process_streamed_response( + model: Any, + response: AsyncStream[ChatCompletionChunk], + model_request_parameters: ModelRequestParameters, + *, + map_usage_fn: Callable[[ChatCompletionChunk], usage.RequestUsage], + finish_reason_map: Mapping[str, FinishReason], +) -> 
OpenAICompatStreamedResponse: + """Wrap a streamed chat completion response with compat handling.""" + peekable_response = _utils.PeekableAsyncStream(response) + first_chunk = await peekable_response.peek() + if isinstance(first_chunk, _utils.Unset): # pragma: no cover + raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') + + model_name = first_chunk.model or model.model_name + + return OpenAICompatStreamedResponse( + model_request_parameters=model_request_parameters, + _model_name=model_name, + _model_profile=model.profile, + _response=peekable_response, + _timestamp=number_to_datetime(first_chunk.created), + _provider_name=model.system, + _map_usage_fn=map_usage_fn, + _finish_reason_map=finish_reason_map, + ) + + +@dataclass +class OpenAICompatStreamedResponse(StreamedResponse): + """Streaming response wrapper for OpenAI chat completions.""" + + model_request_parameters: ModelRequestParameters + _model_name: str + _model_profile: ModelProfile + _response: AsyncIterable[ChatCompletionChunk] + _timestamp: datetime + _provider_name: str + _map_usage_fn: Callable[[ChatCompletionChunk], usage.RequestUsage] = field(repr=False) + _finish_reason_map: Mapping[str, FinishReason] = field(repr=False) + + async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: + async for chunk in self._response: + self._usage += self._map_usage_fn(chunk) + + if chunk.id: # pragma: no branch + self.provider_response_id = chunk.id + + if chunk.model: + self._model_name = chunk.model + + try: + choice = chunk.choices[0] + except IndexError: + continue + + if choice.delta is None: # pyright: ignore[reportUnnecessaryComparison] + continue + + if raw_finish_reason := choice.finish_reason: + self.provider_details = {'finish_reason': raw_finish_reason} + self.finish_reason = self._finish_reason_map.get(raw_finish_reason) + + content = choice.delta.content + if content is not None: + maybe_event = self._parts_manager.handle_text_delta( + vendor_part_id='content', + content=content, + thinking_tags=self._model_profile.thinking_tags, + ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace, + ) + if maybe_event is not None: + if isinstance(maybe_event, PartStartEvent) and isinstance(maybe_event.part, ThinkingPart): + maybe_event.part.id = 'content' + maybe_event.part.provider_name = self.provider_name + yield maybe_event + + # OpenRouter uses 'reasoning', OpenAI previously used 'reasoning_content' (removed Feb 2025) + reasoning_content = getattr(choice.delta, 'reasoning', None) or getattr( + choice.delta, 'reasoning_content', None + ) + if reasoning_content: + yield self._parts_manager.handle_thinking_delta( + vendor_part_id='reasoning_content', + id='reasoning_content', + content=reasoning_content, + provider_name=self.provider_name, + ) + + for dtc in choice.delta.tool_calls or []: + maybe_event = self._parts_manager.handle_tool_call_delta( + vendor_part_id=dtc.index, + tool_name=dtc.function and dtc.function.name, + args=dtc.function and dtc.function.arguments, + tool_call_id=dtc.id, + ) + if maybe_event is not None: + yield maybe_event + + @property + def model_name(self) -> str: + return self._model_name + + @property + def provider_name(self) -> str: + return self._provider_name + + @property + def timestamp(self) -> datetime: + return self._timestamp + + +def map_usage( + response: chat.ChatCompletion | ChatCompletionChunk, +) -> usage.RequestUsage: + response_usage = response.usage + if response_usage is None: + return 
usage.RequestUsage() + else: + details = { + key: value + for key, value in response_usage.model_dump( + exclude_none=True, + exclude={'prompt_tokens', 'completion_tokens', 'total_tokens'}, + ).items() + if isinstance(value, int) + } + result = usage.RequestUsage( + input_tokens=response_usage.prompt_tokens, + output_tokens=response_usage.completion_tokens, + details=details, + ) + if response_usage.completion_tokens_details is not None: + details.update(response_usage.completion_tokens_details.model_dump(exclude_none=True)) + result.output_audio_tokens = response_usage.completion_tokens_details.audio_tokens or 0 + if response_usage.prompt_tokens_details is not None: + result.input_audio_tokens = response_usage.prompt_tokens_details.audio_tokens or 0 + result.cache_read_tokens = response_usage.prompt_tokens_details.cached_tokens or 0 + return result diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 843716fcef..efdb8daf14 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -6,17 +6,16 @@ from contextlib import asynccontextmanager from dataclasses import dataclass, field, replace from datetime import datetime +from functools import partialmethod from typing import Any, Literal, cast, overload -from pydantic import ValidationError from pydantic_core import to_json from typing_extensions import assert_never, deprecated from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage from .._output import DEFAULT_OUTPUT_TOOL_NAME, OutputObjectDefinition from .._run_context import RunContext -from .._thinking_part import split_content_into_text_and_thinking -from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc, number_to_datetime +from .._utils import guard_tool_call_id as _guard_tool_call_id, number_to_datetime from ..builtin_tools import CodeExecutionTool, WebSearchTool from ..exceptions import UserError from ..messages import ( @@ -32,7 +31,6 @@ ModelResponse, ModelResponsePart, ModelResponseStreamEvent, - PartStartEvent, RetryPromptPart, SystemPromptPart, TextPart, @@ -42,15 +40,23 @@ UserPromptPart, VideoUrl, ) -from ..profiles import ModelProfile, ModelProfileSpec +from ..profiles import ModelProfileSpec from ..profiles.openai import OpenAIModelProfile, OpenAISystemPromptRole from ..providers import Provider, infer_provider from ..settings import ModelSettings from ..tools import ToolDefinition from . 
import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent +from ._openai_compat import ( + OpenAICompatStreamedResponse, + completions_create as _compat_completions_create, + map_tool_definition, + map_usage, + process_response, + process_streamed_response, +) try: - from openai import NOT_GIVEN, APIStatusError, AsyncOpenAI, AsyncStream, NotGiven + from openai import APIStatusError, AsyncOpenAI, AsyncStream, NotGiven from openai.types import AllModels, chat, responses from openai.types.chat import ( ChatCompletionChunk, @@ -62,21 +68,15 @@ from openai.types.chat.chat_completion_content_part_image_param import ImageURL from openai.types.chat.chat_completion_content_part_input_audio_param import InputAudio from openai.types.chat.chat_completion_content_part_param import File, FileFile - from openai.types.chat.chat_completion_message_custom_tool_call import ChatCompletionMessageCustomToolCall - from openai.types.chat.chat_completion_message_function_tool_call import ChatCompletionMessageFunctionToolCall - from openai.types.chat.chat_completion_message_function_tool_call_param import ( - ChatCompletionMessageFunctionToolCallParam, - ) from openai.types.chat.chat_completion_prediction_content_param import ChatCompletionPredictionContentParam - from openai.types.chat.completion_create_params import ( - WebSearchOptions, - WebSearchOptionsUserLocation, - WebSearchOptionsUserLocationApproximate, + from openai.types.responses import ( + ComputerToolParam, + FileSearchToolParam, + ResponseStatus, + WebSearchToolParam, ) - from openai.types.responses import ComputerToolParam, FileSearchToolParam, WebSearchToolParam from openai.types.responses.response_input_param import FunctionCallOutput, Message from openai.types.responses.response_reasoning_item_param import Summary - from openai.types.responses.response_status import ResponseStatus from openai.types.shared import ReasoningEffort from openai.types.shared_params import Reasoning except ImportError as _import_error: @@ -85,6 +85,7 @@ 'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`' ) from _import_error + __all__ = ( 'OpenAIModel', 'OpenAIChatModel', @@ -439,286 +440,28 @@ async def _completions_create( model_settings: OpenAIChatModelSettings, model_request_parameters: ModelRequestParameters, ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]: - tools = self._get_tools(model_request_parameters) - web_search_options = self._get_web_search_options(model_request_parameters) - - if not tools: - tool_choice: Literal['none', 'required', 'auto'] | None = None - elif ( - not model_request_parameters.allow_text_output - and OpenAIModelProfile.from_profile(self.profile).openai_supports_tool_choice_required - ): - tool_choice = 'required' - else: - tool_choice = 'auto' - - openai_messages = await self._map_messages(messages) - - response_format: chat.completion_create_params.ResponseFormat | None = None - if model_request_parameters.output_mode == 'native': - output_object = model_request_parameters.output_object - assert output_object is not None - response_format = self._map_json_schema(output_object) - elif ( - model_request_parameters.output_mode == 'prompted' and self.profile.supports_json_object_output - ): # pragma: no branch - response_format = {'type': 'json_object'} - - unsupported_model_settings = OpenAIModelProfile.from_profile(self.profile).openai_unsupported_model_settings - for setting in unsupported_model_settings: - model_settings.pop(setting, 
None) - - try: - extra_headers = model_settings.get('extra_headers', {}) - extra_headers.setdefault('User-Agent', get_user_agent()) - return await self.client.chat.completions.create( - model=self._model_name, - messages=openai_messages, - parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN), - tools=tools or NOT_GIVEN, - tool_choice=tool_choice or NOT_GIVEN, - stream=stream, - stream_options={'include_usage': True} if stream else NOT_GIVEN, - stop=model_settings.get('stop_sequences', NOT_GIVEN), - max_completion_tokens=model_settings.get('max_tokens', NOT_GIVEN), - timeout=model_settings.get('timeout', NOT_GIVEN), - response_format=response_format or NOT_GIVEN, - seed=model_settings.get('seed', NOT_GIVEN), - reasoning_effort=model_settings.get('openai_reasoning_effort', NOT_GIVEN), - user=model_settings.get('openai_user', NOT_GIVEN), - web_search_options=web_search_options or NOT_GIVEN, - service_tier=model_settings.get('openai_service_tier', NOT_GIVEN), - prediction=model_settings.get('openai_prediction', NOT_GIVEN), - temperature=model_settings.get('temperature', NOT_GIVEN), - top_p=model_settings.get('top_p', NOT_GIVEN), - presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN), - frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN), - logit_bias=model_settings.get('logit_bias', NOT_GIVEN), - logprobs=model_settings.get('openai_logprobs', NOT_GIVEN), - top_logprobs=model_settings.get('openai_top_logprobs', NOT_GIVEN), - extra_headers=extra_headers, - extra_body=model_settings.get('extra_body'), - ) - except APIStatusError as e: - if (status_code := e.status_code) >= 400: - raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e - raise # pragma: lax no cover - - def _process_response(self, response: chat.ChatCompletion | str) -> ModelResponse: - """Process a non-streamed response, and prepare a message to return.""" - # Although the OpenAI SDK claims to return a Pydantic model (`ChatCompletion`) from the chat completions function: - # * it hasn't actually performed validation (presumably they're creating the model with `model_construct` or something?!) - # * if the endpoint returns plain text, the return type is a string - # Thus we validate it fully here. - if not isinstance(response, chat.ChatCompletion): - raise UnexpectedModelBehavior('Invalid response from OpenAI chat completions endpoint, expected JSON data') - - if response.created: - timestamp = number_to_datetime(response.created) - else: - timestamp = _now_utc() - response.created = int(timestamp.timestamp()) - - # Workaround for local Ollama which sometimes returns a `None` finish reason. - if response.choices and (choice := response.choices[0]) and choice.finish_reason is None: # pyright: ignore[reportUnnecessaryComparison] - choice.finish_reason = 'stop' - - try: - response = chat.ChatCompletion.model_validate(response.model_dump()) - except ValidationError as e: - raise UnexpectedModelBehavior(f'Invalid response from OpenAI chat completions endpoint: {e}') from e - - choice = response.choices[0] - items: list[ModelResponsePart] = [] - - # The `reasoning_content` field is only present in DeepSeek models. - # https://api-docs.deepseek.com/guides/reasoning_model - if reasoning_content := getattr(choice.message, 'reasoning_content', None): - items.append(ThinkingPart(id='reasoning_content', content=reasoning_content, provider_name=self.system)) - - # The `reasoning` field is only present in gpt-oss via Ollama and OpenRouter. 
- # - https://cookbook.openai.com/articles/gpt-oss/handle-raw-cot#chat-completions-api - # - https://openrouter.ai/docs/use-cases/reasoning-tokens#basic-usage-with-reasoning-tokens - if reasoning := getattr(choice.message, 'reasoning', None): - items.append(ThinkingPart(id='reasoning', content=reasoning, provider_name=self.system)) - - # NOTE: We don't currently handle OpenRouter `reasoning_details`: - # - https://openrouter.ai/docs/use-cases/reasoning-tokens#preserving-reasoning-blocks - # If you need this, please file an issue. - - vendor_details: dict[str, Any] = {} - - # Add logprobs to vendor_details if available - if choice.logprobs is not None and choice.logprobs.content: - # Convert logprobs to a serializable format - vendor_details['logprobs'] = [ - { - 'token': lp.token, - 'bytes': lp.bytes, - 'logprob': lp.logprob, - 'top_logprobs': [ - {'token': tlp.token, 'bytes': tlp.bytes, 'logprob': tlp.logprob} for tlp in lp.top_logprobs - ], - } - for lp in choice.logprobs.content - ] - - if choice.message.content is not None: - items.extend( - (replace(part, id='content', provider_name=self.system) if isinstance(part, ThinkingPart) else part) - for part in split_content_into_text_and_thinking(choice.message.content, self.profile.thinking_tags) - ) - if choice.message.tool_calls is not None: - for c in choice.message.tool_calls: - if isinstance(c, ChatCompletionMessageFunctionToolCall): - part = ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id) - elif isinstance(c, ChatCompletionMessageCustomToolCall): # pragma: no cover - # NOTE: Custom tool calls are not supported. - # See for more details. - raise RuntimeError('Custom tool calls are not supported') - else: - assert_never(c) - part.tool_call_id = _guard_tool_call_id(part) - items.append(part) - - raw_finish_reason = choice.finish_reason - vendor_details['finish_reason'] = raw_finish_reason - finish_reason = _CHAT_FINISH_REASON_MAP.get(raw_finish_reason) - - return ModelResponse( - parts=items, - usage=_map_usage(response), - model_name=response.model, - timestamp=timestamp, - provider_details=vendor_details or None, - provider_response_id=response.id, - provider_name=self._provider.name, - finish_reason=finish_reason, + return await _compat_completions_create( + self, + messages, + stream, + model_settings, + model_request_parameters, ) - async def _process_streamed_response( - self, response: AsyncStream[ChatCompletionChunk], model_request_parameters: ModelRequestParameters - ) -> OpenAIStreamedResponse: - """Process a streamed response, and prepare a streaming response to return.""" - peekable_response = _utils.PeekableAsyncStream(response) - first_chunk = await peekable_response.peek() - if isinstance(first_chunk, _utils.Unset): - raise UnexpectedModelBehavior( # pragma: no cover - 'Streamed response ended without content or tool calls' - ) - - # When using Azure OpenAI and a content filter is enabled, the first chunk will contain a `''` model name, - # so we set it from a later chunk in `OpenAIChatStreamedResponse`. 
- model_name = first_chunk.model or self._model_name - - return OpenAIStreamedResponse( - model_request_parameters=model_request_parameters, - _model_name=model_name, - _model_profile=self.profile, - _response=peekable_response, - _timestamp=number_to_datetime(first_chunk.created), - _provider_name=self._provider.name, - ) - - def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]: - return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()] - - def _get_web_search_options(self, model_request_parameters: ModelRequestParameters) -> WebSearchOptions | None: - for tool in model_request_parameters.builtin_tools: - if isinstance(tool, WebSearchTool): # pragma: no branch - if not OpenAIModelProfile.from_profile(self.profile).openai_chat_supports_web_search: - raise UserError( - f'WebSearchTool is not supported with `OpenAIChatModel` and model {self.model_name!r}. ' - f'Please use `OpenAIResponsesModel` instead.' - ) - - if tool.user_location: - return WebSearchOptions( - search_context_size=tool.search_context_size, - user_location=WebSearchOptionsUserLocation( - type='approximate', - approximate=WebSearchOptionsUserLocationApproximate(**tool.user_location), - ), - ) - return WebSearchOptions(search_context_size=tool.search_context_size) - else: - raise UserError( - f'`{tool.__class__.__name__}` is not supported by `OpenAIChatModel`. If it should be, please file an issue.' - ) - - async def _map_messages(self, messages: list[ModelMessage]) -> list[chat.ChatCompletionMessageParam]: - """Just maps a `pydantic_ai.Message` to a `openai.types.ChatCompletionMessageParam`.""" - openai_messages: list[chat.ChatCompletionMessageParam] = [] - for message in messages: - if isinstance(message, ModelRequest): - async for item in self._map_user_message(message): - openai_messages.append(item) - elif isinstance(message, ModelResponse): - texts: list[str] = [] - tool_calls: list[ChatCompletionMessageFunctionToolCallParam] = [] - for item in message.parts: - if isinstance(item, TextPart): - texts.append(item.content) - elif isinstance(item, ThinkingPart): - # NOTE: DeepSeek `reasoning_content` field should NOT be sent back per https://api-docs.deepseek.com/guides/reasoning_model, - # but we currently just send it in `` tags anyway as we don't want DeepSeek-specific checks here. - # If you need this changed, please file an issue. 
- start_tag, end_tag = self.profile.thinking_tags - texts.append('\n'.join([start_tag, item.content, end_tag])) - elif isinstance(item, ToolCallPart): - tool_calls.append(self._map_tool_call(item)) - # OpenAI doesn't return built-in tool calls - elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart): # pragma: no cover - pass - else: - assert_never(item) - message_param = chat.ChatCompletionAssistantMessageParam(role='assistant') - if texts: - # Note: model responses from this model should only have one text item, so the following - # shouldn't merge multiple texts into one unless you switch models between runs: - message_param['content'] = '\n\n'.join(texts) - else: - message_param['content'] = None - if tool_calls: - message_param['tool_calls'] = tool_calls - openai_messages.append(message_param) - else: - assert_never(message) - if instructions := self._get_instructions(messages): - openai_messages.insert(0, chat.ChatCompletionSystemMessageParam(content=instructions, role='system')) - return openai_messages - - @staticmethod - def _map_tool_call(t: ToolCallPart) -> ChatCompletionMessageFunctionToolCallParam: - return ChatCompletionMessageFunctionToolCallParam( - id=_guard_tool_call_id(t=t), - type='function', - function={'name': t.tool_name, 'arguments': t.args_as_json_str()}, - ) + _process_response = partialmethod( + process_response, + map_usage_fn=map_usage, + finish_reason_map=_CHAT_FINISH_REASON_MAP, + ) - def _map_json_schema(self, o: OutputObjectDefinition) -> chat.completion_create_params.ResponseFormat: - response_format_param: chat.completion_create_params.ResponseFormatJSONSchema = { # pyright: ignore[reportPrivateImportUsage] - 'type': 'json_schema', - 'json_schema': {'name': o.name or DEFAULT_OUTPUT_TOOL_NAME, 'schema': o.json_schema}, - } - if o.description: - response_format_param['json_schema']['description'] = o.description - if OpenAIModelProfile.from_profile(self.profile).openai_supports_strict_tool_definition: # pragma: no branch - response_format_param['json_schema']['strict'] = o.strict - return response_format_param + _process_streamed_response = partialmethod( + process_streamed_response, + map_usage_fn=map_usage, + finish_reason_map=_CHAT_FINISH_REASON_MAP, + ) def _map_tool_definition(self, f: ToolDefinition) -> chat.ChatCompletionToolParam: - tool_param: chat.ChatCompletionToolParam = { - 'type': 'function', - 'function': { - 'name': f.name, - 'description': f.description or '', - 'parameters': f.parameters_json_schema, - }, - } - if f.strict and OpenAIModelProfile.from_profile(self.profile).openai_supports_strict_tool_definition: - tool_param['function']['strict'] = f.strict - return tool_param + return map_tool_definition(self.profile, f) async def _map_user_message(self, message: ModelRequest) -> AsyncIterable[chat.ChatCompletionMessageParam]: for part in message.parts: @@ -1115,7 +858,7 @@ async def _responses_create( # Apparently they're only checking input messages for "JSON", not instructions. 
assert isinstance(instructions, str) openai_messages.insert(0, responses.EasyInputMessageParam(role='system', content=instructions)) - instructions = NOT_GIVEN + instructions = None if verbosity := model_settings.get('openai_text_verbosity'): text = text or {} @@ -1141,21 +884,21 @@ async def _responses_create( input=openai_messages, model=self._model_name, instructions=instructions, - parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN), - tools=tools or NOT_GIVEN, - tool_choice=tool_choice or NOT_GIVEN, - max_output_tokens=model_settings.get('max_tokens', NOT_GIVEN), + parallel_tool_calls=model_settings.get('parallel_tool_calls', NotGiven()), + tools=tools or NotGiven(), + tool_choice=tool_choice or NotGiven(), + max_output_tokens=model_settings.get('max_tokens', NotGiven()), stream=stream, - temperature=model_settings.get('temperature', NOT_GIVEN), - top_p=model_settings.get('top_p', NOT_GIVEN), - truncation=model_settings.get('openai_truncation', NOT_GIVEN), - timeout=model_settings.get('timeout', NOT_GIVEN), - service_tier=model_settings.get('openai_service_tier', NOT_GIVEN), + temperature=model_settings.get('temperature', NotGiven()), + top_p=model_settings.get('top_p', NotGiven()), + truncation=model_settings.get('openai_truncation', NotGiven()), + timeout=model_settings.get('timeout', NotGiven()), + service_tier=model_settings.get('openai_service_tier', NotGiven()), previous_response_id=previous_response_id, reasoning=reasoning, - user=model_settings.get('openai_user', NOT_GIVEN), - text=text or NOT_GIVEN, - include=include or NOT_GIVEN, + user=model_settings.get('openai_user', NotGiven()), + text=text or NotGiven(), + include=include or NotGiven(), extra_headers=extra_headers, extra_body=model_settings.get('extra_body'), ) @@ -1180,7 +923,7 @@ def _get_reasoning(self, model_settings: OpenAIResponsesModelSettings) -> Reason reasoning_summary = reasoning_generate_summary if reasoning_effort is None and reasoning_summary is None: - return NOT_GIVEN + return NotGiven() return Reasoning(effort=reasoning_effort, summary=reasoning_summary) def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[responses.FunctionToolParam]: @@ -1412,7 +1155,7 @@ async def _map_messages( # noqa: C901 assert_never(item) else: assert_never(message) - instructions = self._get_instructions(messages) or NOT_GIVEN + instructions = self._get_instructions(messages) or NotGiven() return instructions, openai_messages def _map_json_schema(self, o: OutputObjectDefinition) -> responses.ResponseFormatTextJSONSchemaConfigParam: @@ -1504,100 +1247,7 @@ async def _map_user_prompt(part: UserPromptPart) -> responses.EasyInputMessagePa return responses.EasyInputMessageParam(role='user', content=content) -@dataclass -class OpenAIStreamedResponse(StreamedResponse): - """Implementation of `StreamedResponse` for OpenAI models.""" - - _model_name: OpenAIModelName - _model_profile: ModelProfile - _response: AsyncIterable[ChatCompletionChunk] - _timestamp: datetime - _provider_name: str - - async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: - async for chunk in self._response: - print(chunk) - self._usage += _map_usage(chunk) - - if chunk.id: # pragma: no branch - self.provider_response_id = chunk.id - - if chunk.model: - self._model_name = chunk.model - - try: - choice = chunk.choices[0] - except IndexError: - continue - - # When using Azure OpenAI and an async content filter is enabled, the openai SDK can return None deltas. 
- if choice.delta is None: # pyright: ignore[reportUnnecessaryComparison] - continue - - if raw_finish_reason := choice.finish_reason: - self.provider_details = {'finish_reason': raw_finish_reason} - self.finish_reason = _CHAT_FINISH_REASON_MAP.get(raw_finish_reason) - - # Handle the text part of the response - content = choice.delta.content - if content is not None: - maybe_event = self._parts_manager.handle_text_delta( - vendor_part_id='content', - content=content, - thinking_tags=self._model_profile.thinking_tags, - ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace, - ) - if maybe_event is not None: # pragma: no branch - if isinstance(maybe_event, PartStartEvent) and isinstance(maybe_event.part, ThinkingPart): - maybe_event.part.id = 'content' - maybe_event.part.provider_name = self.provider_name - yield maybe_event - - # The `reasoning_content` field is only present in DeepSeek models. - # https://api-docs.deepseek.com/guides/reasoning_model - if reasoning_content := getattr(choice.delta, 'reasoning_content', None): - yield self._parts_manager.handle_thinking_delta( - vendor_part_id='reasoning_content', - id='reasoning_content', - content=reasoning_content, - provider_name=self.provider_name, - ) - - # The `reasoning` field is only present in gpt-oss via Ollama and OpenRouter. - # - https://cookbook.openai.com/articles/gpt-oss/handle-raw-cot#chat-completions-api - # - https://openrouter.ai/docs/use-cases/reasoning-tokens#basic-usage-with-reasoning-tokens - if reasoning := getattr(choice.delta, 'reasoning', None): # pragma: no cover - yield self._parts_manager.handle_thinking_delta( - vendor_part_id='reasoning', - id='reasoning', - content=reasoning, - provider_name=self.provider_name, - ) - - for dtc in choice.delta.tool_calls or []: - maybe_event = self._parts_manager.handle_tool_call_delta( - vendor_part_id=dtc.index, - tool_name=dtc.function and dtc.function.name, - args=dtc.function and dtc.function.arguments, - tool_call_id=dtc.id, - ) - if maybe_event is not None: - yield maybe_event - - @property - def model_name(self) -> OpenAIModelName: - """Get the model name of the response.""" - return self._model_name - - @property - def provider_name(self) -> str: - """Get the provider name.""" - return self._provider_name - - @property - def timestamp(self) -> datetime: - """Get the timestamp of the response.""" - return self._timestamp +OpenAIStreamedResponse = OpenAICompatStreamedResponse @dataclass @@ -1815,55 +1465,34 @@ def timestamp(self) -> datetime: return self._timestamp -def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk | responses.Response) -> usage.RequestUsage: +def _map_usage(response: responses.Response) -> usage.RequestUsage: + """Map usage from OpenAI Responses API response.""" response_usage = response.usage if response_usage is None: return usage.RequestUsage() - elif isinstance(response_usage, responses.ResponseUsage): - details: dict[str, int] = { - key: value - for key, value in response_usage.model_dump( - exclude={'input_tokens', 'output_tokens', 'total_tokens'} - ).items() - if isinstance(value, int) - } - # Handle vLLM compatibility - some providers don't include token details - if getattr(response_usage, 'input_tokens_details', None) is not None: - cache_read_tokens = response_usage.input_tokens_details.cached_tokens - else: - cache_read_tokens = 0 - if getattr(response_usage, 'output_tokens_details', None) is not None: - details['reasoning_tokens'] = response_usage.output_tokens_details.reasoning_tokens - else: - 
details['reasoning_tokens'] = 0 + details: dict[str, int] = { + key: value + for key, value in response_usage.model_dump(exclude={'input_tokens', 'output_tokens', 'total_tokens'}).items() + if isinstance(value, int) + } + # Handle vLLM compatibility - some providers don't include token details + if getattr(response_usage, 'input_tokens_details', None) is not None: + cache_read_tokens = response_usage.input_tokens_details.cached_tokens + else: + cache_read_tokens = 0 - return usage.RequestUsage( - input_tokens=response_usage.input_tokens, - output_tokens=response_usage.output_tokens, - cache_read_tokens=cache_read_tokens, - details=details, - ) + if getattr(response_usage, 'output_tokens_details', None) is not None: + details['reasoning_tokens'] = response_usage.output_tokens_details.reasoning_tokens else: - details = { - key: value - for key, value in response_usage.model_dump( - exclude_none=True, exclude={'prompt_tokens', 'completion_tokens', 'total_tokens'} - ).items() - if isinstance(value, int) - } - u = usage.RequestUsage( - input_tokens=response_usage.prompt_tokens, - output_tokens=response_usage.completion_tokens, - details=details, - ) - if response_usage.completion_tokens_details is not None: - details.update(response_usage.completion_tokens_details.model_dump(exclude_none=True)) - u.output_audio_tokens = response_usage.completion_tokens_details.audio_tokens or 0 - if response_usage.prompt_tokens_details is not None: - u.input_audio_tokens = response_usage.prompt_tokens_details.audio_tokens or 0 - u.cache_read_tokens = response_usage.prompt_tokens_details.cached_tokens or 0 - return u + details['reasoning_tokens'] = 0 + + return usage.RequestUsage( + input_tokens=response_usage.input_tokens, + output_tokens=response_usage.output_tokens, + cache_read_tokens=cache_read_tokens, + details=details, + ) def _combine_tool_call_ids(call_id: str, id: str | None) -> str: diff --git a/pydantic_ai_slim/pydantic_ai/models/openrouter.py b/pydantic_ai_slim/pydantic_ai/models/openrouter.py new file mode 100644 index 0000000000..8e2595b047 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/models/openrouter.py @@ -0,0 +1,242 @@ +from __future__ import annotations + +from collections.abc import AsyncIterable, AsyncIterator +from contextlib import asynccontextmanager +from dataclasses import dataclass, field +from typing import Any, Literal, cast, overload + +from typing_extensions import assert_never + +from pydantic_ai.messages import ( + FinishReason, + ModelMessage, + ModelRequest, + ModelResponse, + RetryPromptPart, + SystemPromptPart, + ToolReturnPart, + UserPromptPart, +) +from pydantic_ai.models import ( + Model, + ModelRequestParameters, + StreamedResponse, + check_allow_model_requests, +) +from pydantic_ai.models._openai_compat import ( + completions_create, + map_usage, + process_response, + process_streamed_response, +) +from pydantic_ai.profiles import ModelProfileSpec +from pydantic_ai.providers import Provider +from pydantic_ai.providers.openrouter import OpenRouterProvider +from pydantic_ai.settings import ModelSettings + +try: + from openai import AsyncOpenAI, AsyncStream + from openai.types import chat + from openai.types.chat import ChatCompletionChunk +except ImportError as _import_error: + raise ImportError( + 'Please install `openai` to use the OpenRouter model, you can use the `openai` optional group - `pip install "pydantic-ai-slim[openai]"' + ) from _import_error + +OpenRouterModelName = str + + +__all__ = ['OpenRouterModel'] + + +_OPENROUTER_CHAT_FINISH_REASON_MAP: dict[str, 
FinishReason] = { + 'stop': 'stop', + 'length': 'length', + 'tool_calls': 'tool_call', + 'content_filter': 'content_filter', + 'function_call': 'tool_call', + 'error': 'error', +} + + +@dataclass(init=False) +class OpenRouterModel(Model): + """Model integration for OpenRouter's OpenAI-compatible chat completions API.""" + + client: AsyncOpenAI = field(repr=False) + _model_name: OpenRouterModelName = field(repr=False) + _system: str = field(default='openrouter', repr=False) + + def __init__( + self, + model_name: OpenRouterModelName, + *, + provider: Literal['openrouter'] | Provider[AsyncOpenAI] = 'openrouter', + profile: ModelProfileSpec | None = None, + settings: ModelSettings | None = None, + ): + """Initialize an OpenRouter model. + + Args: + model_name: The name of the OpenRouter model to use (e.g., 'openai/gpt-4o', 'google/gemini-2.5-flash-lite'). + provider: The provider to use for authentication and API access. Can be either the string + 'openrouter' or an instance of `Provider[AsyncOpenAI]`. If not provided, a new provider will be + created using environment variables. + profile: The model profile to use. Defaults to a profile picked by the provider based on the model name. + settings: Default model settings for this model instance. + """ + self._model_name = model_name + + if isinstance(provider, str): + provider = OpenRouterProvider() + self._provider = provider + self.client = provider.client + + super().__init__(settings=settings, profile=profile or provider.model_profile(model_name)) + + @property + def model_name(self) -> str: + return self._model_name + + @property + def system(self) -> str: + return self._system + + async def request( + self, + messages: list[ModelMessage], + model_settings: ModelSettings | None, + model_request_parameters: ModelRequestParameters, + ) -> ModelResponse: + check_allow_model_requests() + response = await self._completions_create(messages, False, model_settings, model_request_parameters) + model_response = self._process_response(response) + return model_response + + @asynccontextmanager + async def request_stream( + self, + messages: list[ModelMessage], + model_settings: ModelSettings | None, + model_request_parameters: ModelRequestParameters, + run_context: Any | None = None, + ) -> AsyncIterator[StreamedResponse]: + check_allow_model_requests() + response = await self._completions_create(messages, True, model_settings, model_request_parameters) + async with response: + yield await self._process_streamed_response(response, model_request_parameters) + + @overload + async def _completions_create( + self, + messages: list[ModelMessage], + stream: Literal[True], + model_settings: ModelSettings | None, + model_request_parameters: ModelRequestParameters, + ) -> AsyncStream[ChatCompletionChunk]: ... + + @overload + async def _completions_create( + self, + messages: list[ModelMessage], + stream: Literal[False], + model_settings: ModelSettings | None, + model_request_parameters: ModelRequestParameters, + ) -> chat.ChatCompletion: ... 
+
+    async def _completions_create(
+        self,
+        messages: list[ModelMessage],
+        stream: bool,
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
+        settings_to_use: ModelSettings = model_settings or {}
+        reasoning_param = self._build_reasoning_param(settings_to_use)
+        if reasoning_param:
+            settings_dict = dict(settings_to_use)
+            extra_body_raw = settings_dict.get('extra_body')
+            extra_body: dict[str, Any] = (
+                dict(cast(dict[str, Any], extra_body_raw)) if isinstance(extra_body_raw, dict) else {}
+            )
+            extra_body['reasoning'] = reasoning_param
+            settings_dict['extra_body'] = extra_body
+            settings_to_use = cast(ModelSettings, settings_dict)
+
+        return await completions_create(
+            self,
+            messages,
+            stream,
+            settings_to_use,
+            model_request_parameters,
+        )
+
+    def _build_reasoning_param(self, model_settings: ModelSettings) -> dict[str, Any] | None:
+        reasoning_config: dict[str, Any] = {}
+
+        if 'openrouter_reasoning_effort' in model_settings:
+            reasoning_config['effort'] = model_settings['openrouter_reasoning_effort']
+        elif 'openrouter_reasoning_max_tokens' in model_settings:
+            reasoning_config['max_tokens'] = model_settings['openrouter_reasoning_max_tokens']
+        elif 'openrouter_reasoning_enabled' in model_settings:
+            reasoning_config['enabled'] = model_settings['openrouter_reasoning_enabled']
+
+        if 'openrouter_reasoning_exclude' in model_settings:
+            reasoning_config['exclude'] = model_settings['openrouter_reasoning_exclude']
+
+        return reasoning_config if reasoning_config else None
+
+    def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
+        return process_response(
+            self,
+            response,
+            map_usage_fn=map_usage,
+            finish_reason_map=_OPENROUTER_CHAT_FINISH_REASON_MAP,
+        )
+
+    async def _process_streamed_response(
+        self, response: AsyncStream[ChatCompletionChunk], model_request_parameters: ModelRequestParameters
+    ) -> StreamedResponse:
+        return await process_streamed_response(
+            self,
+            response,
+            model_request_parameters,
+            map_usage_fn=map_usage,
+            finish_reason_map=_OPENROUTER_CHAT_FINISH_REASON_MAP,
+        )
+
+    @staticmethod
+    def _map_user_prompt(part: UserPromptPart) -> chat.ChatCompletionUserMessageParam:
+        if isinstance(part.content, str):
+            return chat.ChatCompletionUserMessageParam(role='user', content=part.content)
+        else:
+            content_parts: list[str] = []
+            for item in part.content:
+                if isinstance(item, str):
+                    content_parts.append(item)
+            return chat.ChatCompletionUserMessageParam(role='user', content=' '.join(content_parts))
+
+    @classmethod
+    async def _map_user_message(cls, message: ModelRequest) -> AsyncIterable[chat.ChatCompletionMessageParam]:
+        for part in message.parts:
+            if isinstance(part, SystemPromptPart):
+                yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
+            elif isinstance(part, UserPromptPart):
+                yield cls._map_user_prompt(part)
+            elif isinstance(part, ToolReturnPart):
+                yield chat.ChatCompletionToolMessageParam(
+                    role='tool',
+                    tool_call_id=part.tool_call_id,
+                    content=part.model_response_str(),
+                )
+            elif isinstance(part, RetryPromptPart):
+                if part.tool_name is None:
+                    yield chat.ChatCompletionUserMessageParam(role='user', content=part.model_response())
+                else:
+                    yield chat.ChatCompletionToolMessageParam(
+                        role='tool',
+                        tool_call_id=part.tool_call_id,
+                        content=part.model_response(),
+                    )
+            else:
+                assert_never(part)
diff --git a/tests/models/test_openrouter.py b/tests/models/test_openrouter.py
new file mode 100644
index 0000000000..698f3719e6
--- /dev/null
+++ b/tests/models/test_openrouter.py
@@ -0,0 +1,575 @@
+import json
+import os
+from typing import Any, Literal, cast
+from unittest.mock import patch
+
+import pydantic_core
+import pytest
+from inline_snapshot import snapshot
+
+from pydantic_ai import Agent
+from pydantic_ai.messages import (
+    ImageUrl,
+    ModelMessage,
+    ModelRequest,
+    RetryPromptPart,
+    SystemPromptPart,
+    TextPart,
+    ThinkingPart,
+    UserPromptPart,
+)
+from pydantic_ai.models import ModelRequestParameters
+from pydantic_ai.settings import ModelSettings
+from pydantic_ai.tools import ToolDefinition
+
+from ..conftest import try_import
+from .mock_openai import (
+    MockOpenAI,
+    completion_message,
+    get_mock_chat_completion_kwargs,
+)
+
+with try_import() as imports_successful:
+    from openai.types.chat import (
+        ChatCompletionChunk,
+        ChatCompletionMessage,
+    )
+    from openai.types.chat.chat_completion_chunk import (
+        Choice as ChunkChoice,
+        ChoiceDelta,
+        ChoiceDeltaToolCall,
+        ChoiceDeltaToolCallFunction,
+    )
+    from openai.types.chat.chat_completion_message_tool_call import (
+        ChatCompletionMessageToolCall,
+        Function as ChatCompletionMessageFunctionToolCall,
+    )
+    from openai.types.completion_usage import CompletionUsage
+
+    from pydantic_ai.models.openrouter import OpenRouterModel
+    from pydantic_ai.providers.openrouter import OpenRouterProvider
+
+    def create_openrouter_model(model_name: str, mock_client: Any) -> OpenRouterModel:
+        """Helper to create OpenRouterModel with mock client using provider pattern."""
+        provider = OpenRouterProvider(openai_client=mock_client)
+        return OpenRouterModel(model_name, provider=provider)
+
+    def text_chunk(
+        text: str,
+        finish_reason: Literal['stop', 'length', 'tool_calls', 'content_filter', 'function_call'] | None = None,
+    ) -> ChatCompletionChunk:
+        """Create a streaming chunk with text content."""
+        return ChatCompletionChunk(
+            id='test-chunk',
+            choices=[
+                ChunkChoice(
+                    delta=ChoiceDelta(content=text, role='assistant'),
+                    finish_reason=finish_reason,
+                    index=0,
+                )
+            ],
+            created=1234567890,
+            model='google/gemini-2.5-flash-lite',
+            object='chat.completion.chunk',
+        )
+
+    def chunk(choices: list[ChunkChoice]) -> ChatCompletionChunk:
+        """Create a custom streaming chunk."""
+        return ChatCompletionChunk(
+            id='test-chunk',
+            choices=choices,
+            created=1234567890,
+            model='google/gemini-2.5-flash-lite',
+            object='chat.completion.chunk',
+        )
+
+
+pytestmark = [
+    pytest.mark.skipif(not imports_successful(), reason='openai not installed'),
+    pytest.mark.anyio,
+]
+
+
+def test_openrouter_model_init():
+    c = completion_message(ChatCompletionMessage(content='test', role='assistant'))
+    mock_client = MockOpenAI.create_mock(c)
+    from pydantic_ai.providers.openrouter import OpenRouterProvider
+
+    provider = OpenRouterProvider(openai_client=mock_client)
+    model = OpenRouterModel('google/gemini-2.5-flash-lite', provider=provider)
+    assert model.model_name == 'google/gemini-2.5-flash-lite'
+    assert model.system == 'openrouter'
+
+
+def test_openrouter_model_init_with_string_provider():
+    with patch.dict(os.environ, {'OPENROUTER_API_KEY': 'test-api-key'}, clear=False):
+        model = OpenRouterModel('google/gemini-2.5-flash-lite', provider='openrouter')
+        assert model.model_name == 'google/gemini-2.5-flash-lite'
+        assert model.system == 'openrouter'
+        assert model.client is not None
+
+
+async def test_openrouter_basic_request(allow_model_requests: None):
+    c = completion_message(ChatCompletionMessage(content='Hello from OpenRouter!', role='assistant'))
+    mock_client = MockOpenAI.create_mock(c)
+    model = create_openrouter_model('google/gemini-2.5-flash-lite', mock_client)
+
+    messages: list[ModelMessage] = [ModelRequest([UserPromptPart(content='Hello')])]
+    model_settings = None
+    model_request_parameters = ModelRequestParameters(
+        function_tools=[],
+        output_tools=[],
+        allow_text_output=True,
+    )
+
+    response = await model.request(messages, model_settings, model_request_parameters)
+
+    assert isinstance(response.parts[0], TextPart)
+    assert response.parts[0].content == 'Hello from OpenRouter!'
+
+
+async def test_openrouter_no_reasoning_extra_body(allow_model_requests: None):
+    c = completion_message(ChatCompletionMessage(content='No reasoning', role='assistant'))
+    mock_client = MockOpenAI.create_mock(c)
+    model = create_openrouter_model('google/gemini-2.5-flash-lite', mock_client)
+
+    messages: list[ModelMessage] = [ModelRequest([UserPromptPart(content='hi')])]
+    params = ModelRequestParameters(function_tools=[], output_tools=[], allow_text_output=True)
+
+    response = await model.request(messages, None, params)
+    assert isinstance(response.parts[0], TextPart)
+
+    kwargs = cast(MockOpenAI, mock_client).chat_completion_kwargs[0]
+    extra_body = cast(dict[str, Any] | None, kwargs.get('extra_body'))
+    assert not extra_body or 'reasoning' not in extra_body
+
+
+async def test_openrouter_thinking_part_response():
+    message = ChatCompletionMessage(content='Final answer after thinking', role='assistant')
+    setattr(cast(Any, message), 'reasoning', 'Let me think about this step by step...')
+
+    c = completion_message(message)
+    mock_client = MockOpenAI.create_mock(c)
+    model = create_openrouter_model('anthropic/claude-3.7-sonnet', mock_client)
+
+    processed_response = model._process_response(c)  # type: ignore[reportPrivateUsage]
+
+    assert len(processed_response.parts) == 2
+    assert isinstance(processed_response.parts[0], ThinkingPart)
+    assert processed_response.parts[0].content == 'Let me think about this step by step...'
+    assert isinstance(processed_response.parts[1], TextPart)
+    assert processed_response.parts[1].content == 'Final answer after thinking'
+
+
+def test_openrouter_reasoning_param_building():
+    c = completion_message(ChatCompletionMessage(content='Test', role='assistant'))
+    mock_client = MockOpenAI.create_mock(c)
+    model = create_openrouter_model('anthropic/claude-3.7-sonnet', mock_client)
+
+    settings = cast(ModelSettings, {'openrouter_reasoning_effort': 'high'})
+    reasoning_param = model._build_reasoning_param(settings)  # type: ignore[reportPrivateUsage]
+    assert reasoning_param == {'effort': 'high'}
+
+    settings = cast(ModelSettings, {'openrouter_reasoning_max_tokens': 2000})
+    reasoning_param = model._build_reasoning_param(settings)  # type: ignore[reportPrivateUsage]
+    assert reasoning_param == {'max_tokens': 2000}
+
+    settings = cast(ModelSettings, {'openrouter_reasoning_enabled': True})
+    reasoning_param = model._build_reasoning_param(settings)  # type: ignore[reportPrivateUsage]
+    assert reasoning_param == {'enabled': True}
+
+    settings = cast(ModelSettings, {'openrouter_reasoning_effort': 'medium', 'openrouter_reasoning_exclude': True})
+    reasoning_param = model._build_reasoning_param(settings)  # type: ignore[reportPrivateUsage]
+    assert reasoning_param == {'effort': 'medium', 'exclude': True}
+
+    settings = cast(ModelSettings, {})
+    reasoning_param = model._build_reasoning_param(settings)  # type: ignore[reportPrivateUsage]
+    assert reasoning_param is None
+
+
+async def test_openrouter_stream_text(allow_model_requests: None):
+    """Test basic text streaming."""
+    stream = [text_chunk('Hello '), text_chunk('from '), text_chunk('OpenRouter!'), chunk([])]
+    mock_client = MockOpenAI.create_mock_stream(stream)
+    model = create_openrouter_model('google/gemini-2.5-flash-lite', mock_client)
+    agent = Agent(model)
+
+    async with agent.run_stream('test prompt') as result:
+        assert not result.is_complete
+        chunks = [c async for c in result.stream_text(debounce_by=None)]
+        assert chunks == snapshot(['Hello ', 'Hello from ', 'Hello from OpenRouter!'])
+        assert result.is_complete
+
+
+async def test_openrouter_stream_with_finish_reason(allow_model_requests: None):
+    """Test streaming with finish_reason."""
+    stream = [
+        text_chunk('Response '),
+        text_chunk('complete', finish_reason='stop'),
+    ]
+    mock_client = MockOpenAI.create_mock_stream(stream)
+    model = create_openrouter_model('anthropic/claude-3.7-sonnet', mock_client)
+    agent = Agent(model)
+
+    async with agent.run_stream('test') as result:
+        chunks = [c async for c in result.stream_text(debounce_by=None)]
+        assert chunks == snapshot(['Response ', 'Response complete'])
+        assert result.is_complete
+
+
+async def test_openrouter_tool_call(allow_model_requests: None):
+    """Test single tool call."""
+    responses = [
+        completion_message(
+            ChatCompletionMessage(
+                content=None,
+                role='assistant',
+                tool_calls=[
+                    ChatCompletionMessageToolCall(
+                        id='call_1',
+                        function=ChatCompletionMessageFunctionToolCall(
+                            arguments='{"location": "San Francisco"}',
+                            name='get_weather',
+                        ),
+                        type='function',
+                    )
+                ],
+            ),
+        ),
+        completion_message(ChatCompletionMessage(content='The weather is sunny!', role='assistant')),
+    ]
+    mock_client = MockOpenAI.create_mock(responses)
+    model = create_openrouter_model('openai/gpt-4o', mock_client)
+    agent = Agent(model)
+
+    @agent.tool_plain
+    async def get_weather(location: str) -> str:
+        return f'Weather data for {location}'
+
+    result = await agent.run('What is the weather?')
+    assert result.output == 'The weather is sunny!'
+
+
+async def test_openrouter_multiple_tool_calls(allow_model_requests: None):
+    """Test multiple sequential tool calls."""
+    responses = [
+        completion_message(
+            ChatCompletionMessage(
+                content=None,
+                role='assistant',
+                tool_calls=[
+                    ChatCompletionMessageToolCall(
+                        id='call_1',
+                        function=ChatCompletionMessageFunctionToolCall(
+                            arguments='{"city": "London"}',
+                            name='get_location',
+                        ),
+                        type='function',
+                    )
+                ],
+            ),
+        ),
+        completion_message(
+            ChatCompletionMessage(
+                content=None,
+                role='assistant',
+                tool_calls=[
+                    ChatCompletionMessageToolCall(
+                        id='call_2',
+                        function=ChatCompletionMessageFunctionToolCall(
+                            arguments='{"city": "Paris"}',
+                            name='get_location',
+                        ),
+                        type='function',
+                    )
+                ],
+            ),
+        ),
+        completion_message(ChatCompletionMessage(content='Both locations found!', role='assistant')),
+    ]
+    mock_client = MockOpenAI.create_mock(responses)
+    model = create_openrouter_model('anthropic/claude-3.5-sonnet', mock_client)
+    agent = Agent(model)
+
+    @agent.tool_plain
+    async def get_location(city: str) -> str:
+        return json.dumps({'city': city, 'lat': 0, 'lng': 0})
+
+    result = await agent.run('Get locations')
+    assert result.output == 'Both locations found!'
+
+
+async def test_openrouter_stream_tool_call(allow_model_requests: None):
+    """Test streaming with tool calls."""
+    stream = [
+        chunk(
+            [
+                ChunkChoice(
+                    delta=ChoiceDelta(
+                        role='assistant',
+                        tool_calls=[
+                            ChoiceDeltaToolCall(
+                                index=0,
+                                id='call_1',
+                                function=ChoiceDeltaToolCallFunction(name='calculator', arguments=''),
+                                type='function',
+                            )
+                        ],
+                    ),
+                    finish_reason=None,
+                    index=0,
+                )
+            ]
+        ),
+        chunk(
+            [
+                ChunkChoice(
+                    delta=ChoiceDelta(
+                        tool_calls=[
+                            ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments='{"num": 5}'))
+                        ]
+                    ),
+                    finish_reason=None,
+                    index=0,
+                )
+            ]
+        ),
+        chunk([ChunkChoice(delta=ChoiceDelta(), finish_reason='tool_calls', index=0)]),
+    ]
+    mock_client = MockOpenAI.create_mock_stream(stream)
+    model = create_openrouter_model('google/gemini-2.5-flash-lite', mock_client)
+
+    messages: list[ModelMessage] = [ModelRequest([UserPromptPart(content='Calculate 5')])]
+    tool_def = ToolDefinition(
+        name='calculator',
+        description='Do math',
+        parameters_json_schema={'type': 'object', 'properties': {'num': {'type': 'number'}}},
+        outer_typed_dict_key=None,
+    )
+    params = ModelRequestParameters(
+        function_tools=[tool_def],
+        output_tools=[],
+        allow_text_output=True,
+    )
+
+    async with model.request_stream(messages, None, params) as response:
+        events = [event async for event in response]
+        assert len(events) > 0
+
+
+async def test_openrouter_structured_response(allow_model_requests: None):
+    """Test structured/native output."""
+    response_content = '{"name": "John", "age": 30}'
+    mock_response = completion_message(ChatCompletionMessage(content=response_content, role='assistant'))
+    mock_client = MockOpenAI.create_mock(mock_response)
+    model = create_openrouter_model('openai/gpt-4o', mock_client)
+
+    messages: list[ModelMessage] = [ModelRequest([UserPromptPart(content='Get user info')])]
+    params = ModelRequestParameters(
+        function_tools=[],
+        output_tools=[],
+        allow_text_output=True,
+    )
+
+    response = await model.request(messages, None, params)
+    assert isinstance(response.parts[0], TextPart)
+    assert 'John' in response.parts[0].content
+
+
+async def test_openrouter_usage_tracking(allow_model_requests: None):
+    """Test usage metrics are tracked correctly."""
+    mock_response = completion_message(
+        ChatCompletionMessage(content='Response', role='assistant'),
+        usage=CompletionUsage(
+            completion_tokens=10,
+            prompt_tokens=20,
+            total_tokens=30,
+        ),
+    )
+    mock_client = MockOpenAI.create_mock(mock_response)
+    model = create_openrouter_model('google/gemini-2.5-flash-lite', mock_client)
+    agent = Agent(model)
+
+    result = await agent.run('test')
+    usage = result.usage()
+    assert usage.input_tokens == 20
+    assert usage.output_tokens == 10
+    assert usage.total_tokens == 30
+
+
+async def test_openrouter_with_reasoning_settings(allow_model_requests: None):
+    """Test OpenRouter-specific reasoning settings."""
+    mock_response = completion_message(ChatCompletionMessage(content='Answer', role='assistant'))
+    mock_client = MockOpenAI.create_mock(mock_response)
+    model = create_openrouter_model('openai/o1-preview', mock_client)
+
+    messages: list[ModelMessage] = [ModelRequest([UserPromptPart(content='Think about this')])]
+    settings = cast(ModelSettings, {'openrouter_reasoning_effort': 'high'})
+    params = ModelRequestParameters(function_tools=[], output_tools=[], allow_text_output=True)
+
+    response = await model.request(messages, settings, params)
+    assert isinstance(response.parts[0], TextPart)
+
+    kwargs = get_mock_chat_completion_kwargs(mock_client)[0]
+    extra_body = kwargs.get('extra_body', {})
+    assert 'reasoning' in extra_body
+    assert extra_body['reasoning'] == {'effort': 'high'}
+
+
+async def test_openrouter_model_custom_base_url(allow_model_requests: None):
+    """Test OpenRouterModel with provider."""
+    # Test with provider using default base URL
+    from pydantic_ai.providers.openrouter import OpenRouterProvider
+
+    provider = OpenRouterProvider(api_key='test-key')
+    model = OpenRouterModel('openai/gpt-4o', provider=provider)
+    assert model.model_name == 'openai/gpt-4o'
+    assert model.system == 'openrouter'
+    assert str(model.client.base_url) == 'https://openrouter.ai/api/v1/'
+
+
+async def test_openrouter_model_list_content(allow_model_requests: None):
+    """Test OpenRouterModel with list content in UserPromptPart."""
+    c = completion_message(ChatCompletionMessage(content='Response', role='assistant'))
+    mock_client = MockOpenAI.create_mock(c)
+    model = create_openrouter_model('google/gemini-2.5-flash-lite', mock_client)
+
+    # Create a UserPromptPart with list content (not just a string)
+    messages: list[ModelMessage] = [ModelRequest([UserPromptPart(content=['Hello', 'world', '!'])])]
+    model_settings = None
+    model_request_parameters = ModelRequestParameters(
+        function_tools=[],
+        output_tools=[],
+        allow_text_output=True,
+    )
+
+    response = await model.request(messages, model_settings, model_request_parameters)
+    assert isinstance(response.parts[0], TextPart)
+    assert response.parts[0].content == 'Response'
+
+    # Verify the list content was properly joined
+    kwargs = cast(MockOpenAI, mock_client).chat_completion_kwargs[0]
+    assert kwargs['messages'][0]['content'] == 'Hello world !'
+
+
+async def test_openrouter_system_prompt_in_user_message(allow_model_requests: None):
+    """Test OpenRouterModel with SystemPromptPart in user message."""
+    c = completion_message(ChatCompletionMessage(content='Response with system prompt', role='assistant'))
+    mock_client = MockOpenAI.create_mock(c)
+    model = create_openrouter_model('google/gemini-2.5-flash-lite', mock_client)
+
+    messages: list[ModelMessage] = [
+        ModelRequest([SystemPromptPart(content='You are a helpful assistant.'), UserPromptPart(content='Hello')])
+    ]
+    model_settings = None
+    model_request_parameters = ModelRequestParameters(
+        function_tools=[],
+        output_tools=[],
+        allow_text_output=True,
+    )
+
+    response = await model.request(messages, model_settings, model_request_parameters)
+    assert isinstance(response.parts[0], TextPart)
+    assert response.parts[0].content == 'Response with system prompt'
+
+    kwargs = cast(MockOpenAI, mock_client).chat_completion_kwargs[0]
+    messages_sent = kwargs['messages']
+    assert len(messages_sent) == 2
+    assert messages_sent[0]['role'] == 'system'
+    assert messages_sent[0]['content'] == 'You are a helpful assistant.'
+    assert messages_sent[1]['role'] == 'user'
+    assert messages_sent[1]['content'] == 'Hello'
+
+
+async def test_openrouter_retry_prompt_scenarios(allow_model_requests: None):
+    """Test RetryPromptPart handling for different retry scenarios."""
+    mock_client = MockOpenAI.create_mock([])
+    model = create_openrouter_model('openai/gpt-4o', mock_client)
+
+    retry_part_no_tool = RetryPromptPart(
+        content='Invalid input, please try again',
+        tool_name=None,
+    )
+
+    request_no_tool = ModelRequest(
+        parts=[retry_part_no_tool],
+    )
+
+    messages_no_tool: list[dict[str, Any]] = []
+    async for msg in model._map_user_message(request_no_tool):  # type: ignore[reportPrivateUsage]
+        messages_no_tool.append(msg)  # type: ignore[reportUnknownMemberType]
+
+    assert len(messages_no_tool) == 1
+    assert isinstance(messages_no_tool[0], dict)
+    assert messages_no_tool[0]['role'] == 'user'
+    assert 'Invalid input, please try again' in messages_no_tool[0]['content']
+    assert 'Fix the errors and try again.' in messages_no_tool[0]['content']
+
+    retry_part_with_tool = RetryPromptPart(
+        content='Tool execution failed', tool_name='get_weather', tool_call_id='call_12345'
+    )
+
+    request_with_tool = ModelRequest(
+        parts=[retry_part_with_tool],
+    )
+
+    messages_with_tool: list[dict[str, Any]] = []
+    async for msg in model._map_user_message(request_with_tool):  # type: ignore[reportPrivateUsage]
+        messages_with_tool.append(msg)  # type: ignore[reportUnknownMemberType]
+
+    assert len(messages_with_tool) == 1
+    assert isinstance(messages_with_tool[0], dict)
+    assert messages_with_tool[0]['role'] == 'tool'
+    assert messages_with_tool[0]['tool_call_id'] == 'call_12345'
+    assert 'Tool execution failed' in messages_with_tool[0]['content']
+    assert 'Fix the errors and try again.' in messages_with_tool[0]['content']
+
+    validation_errors = [
+        pydantic_core.ErrorDetails(
+            type='string_type',
+            loc=('field_name',),
+            msg='Input should be a valid string',
+            input=123,
+        )
+    ]
+
+    retry_part_validation = RetryPromptPart(
+        content=validation_errors, tool_name='validate_input', tool_call_id='call_67890'
+    )
+
+    request_validation = ModelRequest(
+        parts=[retry_part_validation],
+    )
+
+    messages_validation: list[dict[str, Any]] = []
+    async for msg in model._map_user_message(request_validation):  # type: ignore[reportPrivateUsage]
+        messages_validation.append(msg)  # type: ignore[reportUnknownMemberType]
+
+    assert len(messages_validation) == 1
+    assert isinstance(messages_validation[0], dict)
+    assert messages_validation[0]['role'] == 'tool'
+    assert messages_validation[0]['tool_call_id'] == 'call_67890'
+    content: str = messages_validation[0]['content']
+    assert '1 validation errors' in content
+    assert 'Input should be a valid string' in content
+    assert 'Fix the errors and try again.' in content
+
+
+async def test_openrouter_user_prompt_mixed_content(allow_model_requests: None):
+    """Test UserPromptPart with mixed string and non-string content."""
+
+    mock_client = MockOpenAI.create_mock([])
+    model = create_openrouter_model('openai/gpt-4o', mock_client)
+
+    user_prompt_mixed = UserPromptPart(
+        content=[
+            'Hello, here is an image: ',
+            ImageUrl(url='https://example.com/image.jpg'),
+            ' and some more text.',
+        ]
+    )
+
+    result = model._map_user_prompt(user_prompt_mixed)  # type: ignore[reportPrivateUsage]
+
+    assert result['role'] == 'user'
+    assert result['content'] == 'Hello, here is an image: and some more text.'
diff --git a/tests/test_examples.py b/tests/test_examples.py
index df390519d7..ad6565acaa 100644
--- a/tests/test_examples.py
+++ b/tests/test_examples.py
@@ -157,6 +157,7 @@ def print(self, *args: Any, **kwargs: Any) -> None:
     env.set('CO_API_KEY', 'testing')
     env.set('MISTRAL_API_KEY', 'testing')
     env.set('ANTHROPIC_API_KEY', 'testing')
+    env.set('OPENROUTER_API_KEY', 'testing')
     env.set('HF_TOKEN', 'hf_testing')
     env.set('AWS_ACCESS_KEY_ID', 'testing')
     env.set('AWS_SECRET_ACCESS_KEY', 'testing')

From fc5be63840a4bc5dd264c06a07e8099c61df60bf Mon Sep 17 00:00:00 2001
From: abhishekbhakat
Date: Wed, 1 Oct 2025 21:08:55 +0000
Subject: [PATCH 2/2] fix: resolve type errors in OpenRouter error response handling

---
 .../pydantic_ai/models/_openai_compat.py | 16 +++-
 tests/models/test_openrouter.py          | 75 +++++++++++++++++--
 2 files changed, 82 insertions(+), 9 deletions(-)

diff --git a/pydantic_ai_slim/pydantic_ai/models/_openai_compat.py b/pydantic_ai_slim/pydantic_ai/models/_openai_compat.py
index 07cdaf3c33..97fb6e0f89 100644
--- a/pydantic_ai_slim/pydantic_ai/models/_openai_compat.py
+++ b/pydantic_ai_slim/pydantic_ai/models/_openai_compat.py
@@ -9,7 +9,7 @@
 from collections.abc import AsyncIterable, AsyncIterator, Callable, Mapping
 from dataclasses import dataclass, field, replace
 from datetime import datetime
-from typing import Any, Literal, overload
+from typing import Any, Literal, cast, overload
 
 from pydantic import ValidationError
 from typing_extensions import assert_never
@@ -27,7 +27,7 @@
     ToolCallPart,
 )
 
-from .. import UnexpectedModelBehavior, _utils, usage
+from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
 from .._output import OutputObjectDefinition
 from .._thinking_part import split_content_into_text_and_thinking
 from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc, number_to_datetime
@@ -277,6 +277,18 @@ def process_response(
     if not isinstance(response, chat.ChatCompletion):
         raise UnexpectedModelBehavior('Invalid response from OpenAI chat completions endpoint, expected JSON data')
 
+    if hasattr(response, 'error'):
+        error_attr = getattr(response, 'error', None)
+        if error_attr and isinstance(error_attr, dict):
+            error_dict = cast(dict[str, Any], error_attr)
+            error_code = error_dict.get('code')
+            status_code = error_code if isinstance(error_code, int) else 500
+            raise ModelHTTPError(
+                status_code=status_code,
+                model_name=getattr(model, 'model_name', 'unknown'),
+                body={'error': error_dict},
+            )
+
     if response.created:
         timestamp = number_to_datetime(response.created)
     else:
diff --git a/tests/models/test_openrouter.py b/tests/models/test_openrouter.py
index 698f3719e6..26bd4bdbd8 100644
--- a/tests/models/test_openrouter.py
+++ b/tests/models/test_openrouter.py
@@ -1,13 +1,13 @@
 import json
 import os
 from typing import Any, Literal, cast
-from unittest.mock import patch
+from unittest.mock import AsyncMock, MagicMock, patch
 
 import pydantic_core
 import pytest
 from inline_snapshot import snapshot
 
-from pydantic_ai import Agent
+from pydantic_ai import Agent, ModelHTTPError
 from pydantic_ai.messages import (
     ImageUrl,
     ModelMessage,
@@ -31,6 +31,7 @@
 with try_import() as imports_successful:
     from openai.types.chat import (
+        ChatCompletion,
         ChatCompletionChunk,
         ChatCompletionMessage,
     )
@@ -93,8 +94,6 @@ def chunk(choices: list[ChunkChoice]) -> ChatCompletionChunk:
 def test_openrouter_model_init():
     c = completion_message(ChatCompletionMessage(content='test', role='assistant'))
     mock_client = MockOpenAI.create_mock(c)
-    from pydantic_ai.providers.openrouter import OpenRouterProvider
-
     provider = OpenRouterProvider(openai_client=mock_client)
     model = OpenRouterModel('google/gemini-2.5-flash-lite', provider=provider)
     assert model.model_name == 'google/gemini-2.5-flash-lite'
@@ -418,9 +417,6 @@ async def test_openrouter_with_reasoning_settings(allow_model_requests: None):
 async def test_openrouter_model_custom_base_url(allow_model_requests: None):
     """Test OpenRouterModel with provider."""
-    # Test with provider using default base URL
-    from pydantic_ai.providers.openrouter import OpenRouterProvider
-
     provider = OpenRouterProvider(api_key='test-key')
     model = OpenRouterModel('openai/gpt-4o', provider=provider)
     assert model.model_name == 'openai/gpt-4o'
@@ -573,3 +569,68 @@ async def test_openrouter_user_prompt_mixed_content(allow_model_requests: None):
 
     assert result['role'] == 'user'
     assert result['content'] == 'Hello, here is an image: and some more text.'
+
+
+async def test_openrouter_error_response_with_error_key(allow_model_requests: None):
+    """Test that OpenRouter error responses with 'error' key are properly handled.
+
+    Regression test for issue #2323 where OpenRouter returns HTTP 200 with an error
+    object in the body (e.g., from upstream provider failures like Chutes).
+    """
+    with patch('pydantic_ai.models.openrouter.OpenRouterProvider') as mock_provider_class:
+        mock_provider = MagicMock()
+        mock_provider.client = AsyncMock()
+        mock_provider.model_profile = MagicMock(return_value=MagicMock())
+        mock_provider_class.return_value = mock_provider
+
+        model = OpenRouterModel('deepseek/deepseek-chat-v3-0324:free')
+
+        error_response = ChatCompletion.model_construct(
+            id='error-response',
+            choices=[],
+            created=1234567890,
+            model='deepseek/deepseek-chat-v3-0324:free',
+            object='chat.completion',
+        )
+        setattr(error_response, 'error', {'message': 'Internal Server Error', 'code': 500})
+
+        with patch('pydantic_ai.models.openrouter.completions_create', new_callable=AsyncMock) as mock_create:
+            mock_create.return_value = error_response
+
+            messages: list[ModelMessage] = [ModelRequest(parts=[UserPromptPart(content='test')])]
+            model_params = cast(ModelRequestParameters, {})
+
+            with pytest.raises(ModelHTTPError, match='status_code: 500.*Internal Server Error'):
+                await model.request(messages, None, model_params)
+
+
+async def test_openrouter_error_response_with_none_error(allow_model_requests: None):
+    """Test that responses with error=None are handled gracefully."""
+    c = completion_message(ChatCompletionMessage(content='Success', role='assistant'))
+    setattr(c, 'error', None)
+
+    mock_client = MockOpenAI.create_mock(c)
+    model = create_openrouter_model('openai/gpt-4o', mock_client)
+
+    messages: list[ModelMessage] = [ModelRequest([UserPromptPart(content='test')])]
+    model_request_parameters = ModelRequestParameters(function_tools=[], output_tools=[], allow_text_output=True)
+
+    response = await model.request(messages, None, model_request_parameters)
+    assert isinstance(response.parts[0], TextPart)
+    assert response.parts[0].content == 'Success'
+
+
+async def test_openrouter_error_response_with_non_dict_error(allow_model_requests: None):
+    """Test that responses with non-dict error values are handled gracefully."""
+    c = completion_message(ChatCompletionMessage(content='Success', role='assistant'))
+    setattr(c, 'error', 'some error string')
+
+    mock_client = MockOpenAI.create_mock(c)
+    model = create_openrouter_model('openai/gpt-4o', mock_client)
+
+    messages: list[ModelMessage] = [ModelRequest([UserPromptPart(content='test')])]
+    model_request_parameters = ModelRequestParameters(function_tools=[], output_tools=[], allow_text_output=True)
+
+    response = await model.request(messages, None, model_request_parameters)
+    assert isinstance(response.parts[0], TextPart)
+    assert response.parts[0].content == 'Success'
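
A minimal end-to-end sketch of the two behaviours these patches exercise, assuming a real `OPENROUTER_API_KEY` is set: the `openrouter_*` reasoning settings are forwarded to OpenRouter via `extra_body['reasoning']`, and an error payload returned inside an HTTP 200 body surfaces as `ModelHTTPError`. The plain-dict `ModelSettings` cast mirrors the tests above; it is illustrative only, not a documented public interface.

```python
from typing import cast

from pydantic_ai import Agent, ModelHTTPError
from pydantic_ai.models.openrouter import OpenRouterModel
from pydantic_ai.settings import ModelSettings

model = OpenRouterModel('anthropic/claude-3.7-sonnet')
agent = Agent(model)

# OpenRouter-specific settings; _build_reasoning_param copies them into
# extra_body={'reasoning': {...}} on the chat completions request.
settings = cast(ModelSettings, {'openrouter_reasoning_effort': 'high', 'openrouter_reasoning_exclude': True})

try:
    result = agent.run_sync('Think carefully: what is 17 * 24?', model_settings=settings)
    print(result.output)
except ModelHTTPError as exc:
    # Raised when OpenRouter returns {'error': {...}} in the response body,
    # even though the HTTP status was 200.
    print(f'OpenRouter error {exc.status_code}: {exc.body}')
```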