46 changes: 46 additions & 0 deletions docs/models/openrouter.md
@@ -73,3 +73,49 @@ model = OpenRouterModel('openai/gpt-5')
agent = Agent(model, model_settings=settings)
...
```

## Image Generation

You can use OpenRouter models that support image generation by including `BinaryImage` in the agent's output type:

```python {test="skip"}
from pydantic_ai import Agent, BinaryImage

agent = Agent(
model='openrouter:google/gemini-2.5-flash-image-preview',
output_type=str | BinaryImage,
)

result = agent.run_sync('A cat')
assert isinstance(result.output, BinaryImage)
```
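
The returned `BinaryImage` exposes the raw bytes as `data` and the MIME type as `media_type` (both used elsewhere in this PR). As a minimal sketch extending the example above, you could persist the generated image to disk:

```python {test="skip"}
from pydantic_ai import Agent, BinaryImage

agent = Agent(
    model='openrouter:google/gemini-2.5-flash-image-preview',
    output_type=str | BinaryImage,
)

result = agent.run_sync('A cat')
if isinstance(result.output, BinaryImage):
    # Derive a file extension from the media type, e.g. 'image/png' -> 'png'
    extension = result.output.media_type.split('/')[-1]
    with open(f'cat.{extension}', 'wb') as f:
        f.write(result.output.data)
```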

You can further customize image generation using the `ImageGenerationTool` built-in tool:

```python
from pydantic_ai import Agent, BinaryImage, ImageGenerationTool

agent = Agent(
    model='openrouter:google/gemini-2.5-flash-image-preview',
    output_type=str | BinaryImage,
    builtin_tools=[ImageGenerationTool(aspect_ratio='3:2')],
)
```

> Available aspect ratios: `'1:1'`, `'2:3'`, `'3:2'`, `'3:4'`, `'4:3'`, `'4:5'`, `'5:4'`, `'9:16'`, `'16:9'`, `'21:9'`.
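
Under the hood, these settings are forwarded to OpenRouter through the OpenAI-compatible `extra_body` request parameter (see `_openrouter_settings_to_openai_settings` in the diff below), so the tool configured above translates to roughly:

```python
# Approximate request extras produced for ImageGenerationTool(aspect_ratio='3:2')
extra_body = {
    'modalities': ['text', 'image'],
    'image_config': {'aspect_ratio': '3:2'},
}
```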

Image generation also works with streaming:

```python {test="skip"}
from pydantic_ai import Agent, BinaryImage, ImageGenerationTool

agent = Agent(
model='openrouter:google/gemini-2.5-flash-image-preview',
output_type=str | BinaryImage,
builtin_tools=[ImageGenerationTool(aspect_ratio='3:2')],
)

response = agent.run_stream_sync('A dog')
for output in response.stream_output():
if isinstance(output, str):
print(output)
elif isinstance(output, BinaryImage):
# Handle the generated image
print(f'Generated image: {output.media_type}')
```
69 changes: 40 additions & 29 deletions pydantic_ai_slim/pydantic_ai/models/openai.py
@@ -631,31 +631,9 @@ def _process_response(self, response: chat.ChatCompletion | str) -> ModelRespons
raise UnexpectedModelBehavior(f'Invalid response from {self.system} chat completions endpoint: {e}') from e

choice = response.choices[0]
-        items: list[ModelResponsePart] = []
-
-        if thinking_parts := self._process_thinking(choice.message):
-            items.extend(thinking_parts)
-
-        if choice.message.content:
-            items.extend(
-                (replace(part, id='content', provider_name=self.system) if isinstance(part, ThinkingPart) else part)
-                for part in split_content_into_text_and_thinking(choice.message.content, self.profile.thinking_tags)
-            )
-        if choice.message.tool_calls is not None:
-            for c in choice.message.tool_calls:
-                if isinstance(c, ChatCompletionMessageFunctionToolCall):
-                    part = ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id)
-                elif isinstance(c, ChatCompletionMessageCustomToolCall):  # pragma: no cover
-                    # NOTE: Custom tool calls are not supported.
-                    # See <https://github.com/pydantic/pydantic-ai/issues/2513> for more details.
-                    raise RuntimeError('Custom tool calls are not supported')
-                else:
-                    assert_never(c)
-                part.tool_call_id = _guard_tool_call_id(part)
-                items.append(part)
-
        return ModelResponse(
-            parts=items,
+            parts=list(self._process_parts(choice.message)),
usage=self._map_usage(response),
model_name=response.model,
timestamp=timestamp,
@@ -666,14 +644,13 @@ def _process_response(self, response: chat.ChatCompletion | str) -> ModelRespons
finish_reason=self._map_finish_reason(choice.finish_reason),
)

-    def _process_thinking(self, message: chat.ChatCompletionMessage) -> list[ThinkingPart] | None:
+    def _process_thinking(self, message: chat.ChatCompletionMessage) -> Iterable[ThinkingPart]:
"""Hook that maps reasoning tokens to thinking parts.

This method may be overridden by subclasses of `OpenAIChatModel` to apply custom mappings.
"""
profile = OpenAIModelProfile.from_profile(self.profile)
custom_field = profile.openai_chat_thinking_field
-        items: list[ThinkingPart] = []

# Prefer the configured custom reasoning field, if present in profile.
# Fall back to built-in fields if no custom field result was found.
@@ -689,10 +666,44 @@ def _process_thinking(self, message: chat.ChatCompletionMessage) -> list[Thinkin
continue
reasoning: str | None = getattr(message, field_name, None)
if reasoning: # pragma: no branch
-                items.append(ThinkingPart(id=field_name, content=reasoning, provider_name=self.system))
-                return items
+                yield ThinkingPart(id=field_name, content=reasoning, provider_name=self.system)
break

def _process_content(self, message: chat.ChatCompletionMessage) -> Iterable[TextPart | ThinkingPart]:
"""Hook that maps the message content to thinking or text parts.

This method may be overridden by subclasses of `OpenAIChatModel` to apply custom mappings.
"""
if message.content:
for part in split_content_into_text_and_thinking(message.content, self.profile.thinking_tags):
yield replace(part, id='content', provider_name=self.system) if isinstance(part, ThinkingPart) else part

-        return items or None
def _process_tool_calls(self, message: chat.ChatCompletionMessage) -> Iterable[ToolCallPart]:
"""Hook that maps tool calls to tool call parts.

This method may be overridden by subclasses of `OpenAIChatModel` to apply custom mappings.
"""
if message.tool_calls is not None:
for c in message.tool_calls:
if isinstance(c, ChatCompletionMessageFunctionToolCall):
part = ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id)
elif isinstance(c, ChatCompletionMessageCustomToolCall): # pragma: no cover
# NOTE: Custom tool calls are not supported.
# See <https://github.com/pydantic/pydantic-ai/issues/2513> for more details.
raise RuntimeError('Custom tool calls are not supported')
else:
assert_never(c)
part.tool_call_id = _guard_tool_call_id(part)
yield part

def _process_parts(self, message: chat.ChatCompletionMessage) -> Iterable[ModelResponsePart]:
"""Hook that defines the mappings to transform message contents to response parts.

This method may be overridden by subclasses of `OpenAIChatModel` to apply custom mappings.
"""
return itertools.chain(
self._process_thinking(message), self._process_content(message), self._process_tool_calls(message)
)

async def _process_streamed_response(
self, response: AsyncStream[ChatCompletionChunk], model_request_parameters: ModelRequestParameters
@@ -781,7 +792,7 @@ def map_assistant_message(self, message: ModelResponse) -> chat.ChatCompletionAs
self._map_response_tool_call_part(item)
elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart): # pragma: no cover
self._map_response_builtin_part(item)
-            elif isinstance(item, FilePart):  # pragma: no cover
+            elif isinstance(item, FilePart):
self._map_response_file_part(item)
else:
assert_never(item)
105 changes: 100 additions & 5 deletions pydantic_ai_slim/pydantic_ai/models/openrouter.py
@@ -1,15 +1,21 @@
from __future__ import annotations as _annotations

import base64
import itertools
from collections.abc import Iterable
from dataclasses import dataclass, field
from typing import Annotated, Any, Literal, TypeAlias, cast

from pydantic import BaseModel, Discriminator
from typing_extensions import TypedDict, assert_never, override

from ..builtin_tools import ImageGenerationTool
from ..exceptions import ModelHTTPError
from ..messages import (
BinaryImage,
FilePart,
FinishReason,
ModelResponsePart,
ModelResponseStreamEvent,
ThinkingPart,
)
@@ -381,6 +387,14 @@ class _OpenRouterChatCompletionMessageFunctionToolCall(chat.ChatCompletionMessag
]


class _OpenRouterImageUrl(BaseModel):
url: str


class _OpenRouterImageGeneration(BaseModel):
image_url: _OpenRouterImageUrl


class _OpenRouterCompletionMessage(chat.ChatCompletionMessage):
"""Wrapped chat completion message with OpenRouter specific attributes."""

@@ -393,6 +407,9 @@ class _OpenRouterCompletionMessage(chat.ChatCompletionMessage):
tool_calls: list[_OpenRouterChatCompletionMessageToolCallUnion] | None = None # type: ignore[reportIncompatibleVariableOverride]
"""The tool calls generated by the model, such as function calls."""

images: list[_OpenRouterImageGeneration] | None = None
"""The images generated by the model, if any."""

annotations: list[_OpenRouterAnnotation] | None = None # type: ignore[reportIncompatibleVariableOverride]
"""Annotations associated with the message, supporting both url_citation and file types."""

@@ -486,11 +503,14 @@ def _map_openrouter_provider_details(
return provider_details


-def _openrouter_settings_to_openai_settings(model_settings: OpenRouterModelSettings) -> OpenAIChatModelSettings:
+def _openrouter_settings_to_openai_settings(
+    model_settings: OpenRouterModelSettings, model_request_parameters: ModelRequestParameters
+) -> OpenAIChatModelSettings:
"""Transforms a 'OpenRouterModelSettings' object into an 'OpenAIChatModelSettings' object.

Args:
model_settings: The 'OpenRouterModelSettings' object to transform.
model_request_parameters: The 'ModelRequestParameters' object to use for the transformation.

Returns:
An 'OpenAIChatModelSettings' object with equivalent settings.
@@ -510,6 +530,18 @@ def _openrouter_settings_to_openai_settings(model_settings: OpenRouterModelSetti
if usage := model_settings.pop('openrouter_usage', None):
extra_body['usage'] = usage

for builtin_tool in model_request_parameters.builtin_tools:
if isinstance(builtin_tool, ImageGenerationTool): # pragma: lax no cover
extra_body['modalities'] = ['text', 'image']

image_config: dict[str, str] = {}
if aspect_ratio := builtin_tool.aspect_ratio:
image_config['aspect_ratio'] = aspect_ratio
extra_body['image_config'] = image_config

if isinstance(model_request_parameters.output_object, BinaryImage): # pragma: lax no cover
extra_body['modalities'] = ['text', 'image']

model_settings['extra_body'] = extra_body

return OpenAIChatModelSettings(**model_settings) # type: ignore[reportCallIssue]
@@ -543,9 +575,16 @@ def prepare_request(
model_request_parameters: ModelRequestParameters,
) -> tuple[ModelSettings | None, ModelRequestParameters]:
merged_settings, customized_parameters = super().prepare_request(model_settings, model_request_parameters)
-        new_settings = _openrouter_settings_to_openai_settings(cast(OpenRouterModelSettings, merged_settings or {}))
+        new_settings = _openrouter_settings_to_openai_settings(
+            cast(OpenRouterModelSettings, merged_settings or {}), model_request_parameters
+        )
return new_settings, customized_parameters

@override
def _get_web_search_options(self, model_request_parameters: ModelRequestParameters):
"""This method is nullified because OpenRouter handles web search through a different parameter."""
return None

@override
def _validate_completion(self, response: chat.ChatCompletion) -> _OpenRouterChatCompletion:
response = _OpenRouterChatCompletion.model_validate(response.model_dump())
@@ -556,13 +595,27 @@ def _validate_completion(self, response: chat.ChatCompletion) -> _OpenRouterChat
return response

@override
-    def _process_thinking(self, message: chat.ChatCompletionMessage) -> list[ThinkingPart] | None:
+    def _process_thinking(self, message: chat.ChatCompletionMessage) -> Iterable[ThinkingPart]:
assert isinstance(message, _OpenRouterCompletionMessage)

if reasoning_details := message.reasoning_details:
-            return [_from_reasoning_detail(detail) for detail in reasoning_details]
+            for detail in reasoning_details:
+                yield _from_reasoning_detail(detail)
        else:
-            return super()._process_thinking(message)
+            yield from super()._process_thinking(message)

def _process_image(self, message: chat.ChatCompletionMessage) -> Iterable[FilePart]:
assert isinstance(message, _OpenRouterCompletionMessage)

if images := message.images:
for image in images:
yield FilePart(
content=BinaryImage.from_data_uri(image.image_url.url),
)

@override
def _process_parts(self, message: chat.ChatCompletionMessage) -> Iterable[ModelResponsePart]:
return itertools.chain(super()._process_parts(message), self._process_image(message))

@override
def _process_provider_details(self, response: chat.ChatCompletion) -> dict[str, Any]:
@@ -575,11 +628,20 @@ def _process_provider_details(self, response: chat.ChatCompletion) -> dict[str,
@dataclass
class _MapModelResponseContext(OpenAIChatModel._MapModelResponseContext): # type: ignore[reportPrivateUsage]
reasoning_details: list[dict[str, Any]] = field(default_factory=list)
file_inputs: list[dict[str, dict[str, str]]] = field(default_factory=list)

def _into_message_param(self) -> chat.ChatCompletionAssistantMessageParam:
message_param = super()._into_message_param()
if self.reasoning_details:
message_param['reasoning_details'] = self.reasoning_details # type: ignore[reportGeneralTypeIssues]
if self.file_inputs:
content = message_param.get('content')
if isinstance(content, str): # pragma: lax no cover
message_param['content'] = [{'type': 'text', 'text': content}] + self.file_inputs # type: ignore[reportGeneralTypeIssues]
elif isinstance(content, list): # pragma: lax no cover
message_param['content'] = content + self.file_inputs # type: ignore[reportGeneralTypeIssues]
else:
message_param['content'] = self.file_inputs # type: ignore[reportGeneralTypeIssues]
return message_param

@override
@@ -591,6 +653,17 @@ def _map_response_thinking_part(self, item: ThinkingPart) -> None:
else: # pragma: lax no cover
super()._map_response_thinking_part(item)

@override
def _map_response_file_part(self, item: FilePart) -> None:
if item.content.media_type in (
'image/png',
'image/jpeg',
'image/webp',
'image/gif',
): # pragma: lax no cover
encoding = base64.b64encode(item.content.data).decode('utf-8')
self.file_inputs.append({'image_url': {'url': encoding}})

@property
@override
def _streamed_response_cls(self):
@@ -612,6 +685,9 @@ class _OpenRouterChoiceDelta(chat_completion_chunk.ChoiceDelta):
reasoning_details: list[_OpenRouterReasoningDetail] | None = None
"""The reasoning details associated with the message, if any."""

images: list[_OpenRouterImageGeneration] | None = None
"""The images generated by the model, if any."""

annotations: list[_OpenRouterAnnotation] | None = None
"""Annotations associated with the message, supporting both url_citation and file types."""

@@ -682,6 +758,25 @@ def _map_thinking_delta(self, choice: chat_completion_chunk.Choice) -> Iterable[
else:
return super()._map_thinking_delta(choice)

def _map_file_delta(self, choice: chat_completion_chunk.Choice) -> Iterable[ModelResponseStreamEvent]:
assert isinstance(choice, _OpenRouterChunkChoice)

if images := choice.delta.images:
for image in images:
yield self._parts_manager.handle_part(
vendor_part_id=None,
part=FilePart(
content=BinaryImage.from_data_uri(image.image_url.url),
),
)

@override
def _map_part_delta(self, choice: chat_completion_chunk.Choice) -> Iterable[ModelResponseStreamEvent]:
return itertools.chain(
super()._map_part_delta(choice),
self._map_file_delta(choice),
)

@override
def _map_provider_details(self, chunk: chat.ChatCompletionChunk) -> dict[str, Any] | None:
assert isinstance(chunk, _OpenRouterChatCompletionChunk)