diff --git a/docs/models/google.md b/docs/models/google.md index 2cc35d9c0..30816715a 100644 --- a/docs/models/google.md +++ b/docs/models/google.md @@ -104,14 +104,14 @@ You can supply a custom `GoogleProvider` instance using the `provider` argument This is useful if you're using a custom-compatible endpoint with the Google Generative Language API. ```python -from google import genai +from google.genai import Client from google.genai.types import HttpOptions from pydantic_ai import Agent from pydantic_ai.models.google import GoogleModel from pydantic_ai.providers.google import GoogleProvider -client = genai.Client( +client = Client( api_key='gemini-custom-api-key', http_options=HttpOptions(base_url='gemini-custom-base-url'), ) diff --git a/docs/tools.md b/docs/tools.md index 4b40e7881..eddffb703 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -770,7 +770,7 @@ from pydantic_ai.ext.langchain import LangChainToolset toolkit = SlackToolkit() -toolset = LangChainToolset(toolkit.get_tools()) +toolset = LangChainToolset(toolkit.get_tools(), id='slack') agent = Agent('openai:gpt-4o', toolsets=[toolset]) # ... @@ -823,6 +823,7 @@ toolset = ACIToolset( 'OPEN_WEATHER_MAP__FORECAST', ], linked_account_owner_id=os.getenv('LINKED_ACCOUNT_OWNER_ID'), + id='open_weather_map', ) agent = Agent('openai:gpt-4o', toolsets=[toolset]) diff --git a/docs/toolsets.md b/docs/toolsets.md index 5caac22c0..add7009f9 100644 --- a/docs/toolsets.md +++ b/docs/toolsets.md @@ -84,7 +84,10 @@ def temperature_fahrenheit(city: str) -> float: return 69.8 -weather_toolset = FunctionToolset(tools=[temperature_celsius, temperature_fahrenheit]) +weather_toolset = FunctionToolset( + tools=[temperature_celsius, temperature_fahrenheit], + id='weather', # (1)! +) @weather_toolset.tool @@ -95,10 +98,10 @@ def conditions(ctx: RunContext, city: str) -> str: return "It's raining" -datetime_toolset = FunctionToolset() +datetime_toolset = FunctionToolset(id='datetime') datetime_toolset.add_function(lambda: datetime.now(), name='now') -test_model = TestModel() # (1)! +test_model = TestModel() # (2)! agent = Agent(test_model) result = agent.run_sync('What tools are available?', toolsets=[weather_toolset]) @@ -110,7 +113,8 @@ print([t.name for t in test_model.last_model_request_parameters.function_tools]) #> ['now'] ``` -1. We're using [`TestModel`][pydantic_ai.models.test.TestModel] here because it makes it easy to see which tools were available on each run. +1. `FunctionToolset` supports an optional `id` argument that can help to identify the toolset in error messages. A toolset also needs to have an ID in order to be used in a durable execution environment like Temporal, in which case the ID will be used to identify the toolset's activities within the workflow. +2. We're using [`TestModel`][pydantic_ai.models.test.TestModel] here because it makes it easy to see which tools were available on each run. _(This example is complete, it can be run "as is")_ @@ -609,7 +613,7 @@ from pydantic_ai.ext.langchain import LangChainToolset toolkit = SlackToolkit() -toolset = LangChainToolset(toolkit.get_tools()) +toolset = LangChainToolset(toolkit.get_tools(), id='slack') agent = Agent('openai:gpt-4o', toolsets=[toolset]) # ... @@ -634,6 +638,7 @@ toolset = ACIToolset( 'OPEN_WEATHER_MAP__FORECAST', ], linked_account_owner_id=os.getenv('LINKED_ACCOUNT_OWNER_ID'), + id='open_weather_map', ) agent = Agent('openai:gpt-4o', toolsets=[toolset]) diff --git a/pydantic_ai_slim/pydantic_ai/_a2a.py b/pydantic_ai_slim/pydantic_ai/_a2a.py index 8a916b5cc..101d3f2e9 100644 --- a/pydantic_ai_slim/pydantic_ai/_a2a.py +++ b/pydantic_ai_slim/pydantic_ai/_a2a.py @@ -27,7 +27,7 @@ VideoUrl, ) -from .agent import Agent, AgentDepsT, OutputDataT +from .agent import AbstractAgent, AgentDepsT, OutputDataT # AgentWorker output type needs to be invariant for use in both parameter and return positions WorkerOutputT = TypeVar('WorkerOutputT') @@ -59,7 +59,9 @@ @asynccontextmanager -async def worker_lifespan(app: FastA2A, worker: Worker, agent: Agent[AgentDepsT, OutputDataT]) -> AsyncIterator[None]: +async def worker_lifespan( + app: FastA2A, worker: Worker, agent: AbstractAgent[AgentDepsT, OutputDataT] +) -> AsyncIterator[None]: """Custom lifespan that runs the worker during application startup. This ensures the worker is started and ready to process tasks as soon as the application starts. @@ -70,7 +72,7 @@ async def worker_lifespan(app: FastA2A, worker: Worker, agent: Agent[AgentDepsT, def agent_to_a2a( - agent: Agent[AgentDepsT, OutputDataT], + agent: AbstractAgent[AgentDepsT, OutputDataT], *, storage: Storage | None = None, broker: Broker | None = None, @@ -116,7 +118,7 @@ def agent_to_a2a( class AgentWorker(Worker[list[ModelMessage]], Generic[WorkerOutputT, AgentDepsT]): """A worker that uses an agent to execute tasks.""" - agent: Agent[AgentDepsT, WorkerOutputT] + agent: AbstractAgent[AgentDepsT, WorkerOutputT] async def run_task(self, params: TaskSendParams) -> None: task = await self.storage.load_task(params['id']) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 312a8a2fc..16df8f9dc 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -303,10 +303,18 @@ async def stream( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]], ) -> AsyncIterator[result.AgentStream[DepsT, T]]: - async with self._stream(ctx) as streamed_response: + assert not self._did_stream, 'stream() should only be called once per node' + + model_settings, model_request_parameters, message_history, run_context = await self._prepare_request(ctx) + async with ctx.deps.model.request_stream( + message_history, model_settings, model_request_parameters, run_context + ) as streamed_response: + self._did_stream = True + ctx.state.usage.requests += 1 agent_stream = result.AgentStream[DepsT, T]( streamed_response, ctx.deps.output_schema, + model_request_parameters, ctx.deps.output_validators, build_run_context(ctx), ctx.deps.usage_limits, @@ -318,28 +326,6 @@ async def stream( async for _ in agent_stream: pass - @asynccontextmanager - async def _stream( - self, - ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]], - ) -> AsyncIterator[models.StreamedResponse]: - assert not self._did_stream, 'stream() should only be called once per node' - - model_settings, model_request_parameters = await self._prepare_request(ctx) - model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters) - message_history = await _process_message_history( - ctx.state.message_history, ctx.deps.history_processors, build_run_context(ctx) - ) - async with ctx.deps.model.request_stream( - message_history, model_settings, model_request_parameters - ) as streamed_response: - self._did_stream = True - ctx.state.usage.requests += 1 - yield streamed_response - # In case the user didn't manually consume the full stream, ensure it is fully consumed here, - # otherwise usage won't be properly counted: - async for _ in streamed_response: - pass model_response = streamed_response.get() self._finish_handling(ctx, model_response) @@ -351,11 +337,7 @@ async def _make_request( if self._result is not None: return self._result # pragma: no cover - model_settings, model_request_parameters = await self._prepare_request(ctx) - model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters) - message_history = await _process_message_history( - ctx.state.message_history, ctx.deps.history_processors, build_run_context(ctx) - ) + model_settings, model_request_parameters, message_history, _ = await self._prepare_request(ctx) model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters) ctx.state.usage.incr(_usage.Usage()) @@ -363,7 +345,7 @@ async def _make_request( async def _prepare_request( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] - ) -> tuple[ModelSettings | None, models.ModelRequestParameters]: + ) -> tuple[ModelSettings | None, models.ModelRequestParameters, list[_messages.ModelMessage], RunContext[DepsT]]: ctx.state.message_history.append(self.request) # Check usage @@ -373,9 +355,18 @@ async def _prepare_request( # Increment run_step ctx.state.run_step += 1 + run_context = build_run_context(ctx) + model_settings = merge_model_settings(ctx.deps.model_settings, None) + model_request_parameters = await _prepare_request_parameters(ctx) - return model_settings, model_request_parameters + model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters) + + message_history = await _process_message_history( + ctx.state.message_history, ctx.deps.history_processors, run_context + ) + + return model_settings, model_request_parameters, message_history, run_context def _finish_handling( self, diff --git a/pydantic_ai_slim/pydantic_ai/_cli.py b/pydantic_ai_slim/pydantic_ai/_cli.py index ae4f6ff6f..3f596405a 100644 --- a/pydantic_ai_slim/pydantic_ai/_cli.py +++ b/pydantic_ai_slim/pydantic_ai/_cli.py @@ -16,7 +16,7 @@ from . import __version__ from ._run_context import AgentDepsT -from .agent import Agent +from .agent import AbstractAgent, Agent from .exceptions import UserError from .messages import ModelMessage from .models import KnownModelName, infer_model @@ -220,7 +220,7 @@ def cli( # noqa: C901 async def run_chat( stream: bool, - agent: Agent[AgentDepsT, OutputDataT], + agent: AbstractAgent[AgentDepsT, OutputDataT], console: Console, code_theme: str, prog_name: str, @@ -263,7 +263,7 @@ async def run_chat( async def ask_agent( - agent: Agent[AgentDepsT, OutputDataT], + agent: AbstractAgent[AgentDepsT, OutputDataT], prompt: str, stream: bool, console: Console, diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 67849b0c7..71e7251fd 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -977,6 +977,10 @@ def __init__( self.max_retries = max_retries self.output_validators = output_validators or [] + @property + def id(self) -> str | None: + return 'output' + async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]: return { tool_def.name: ToolsetTool( diff --git a/pydantic_ai_slim/pydantic_ai/ag_ui.py b/pydantic_ai_slim/pydantic_ai/ag_ui.py index 2dbe6faf3..e4fcb4d81 100644 --- a/pydantic_ai_slim/pydantic_ai/ag_ui.py +++ b/pydantic_ai_slim/pydantic_ai/ag_ui.py @@ -72,7 +72,7 @@ from pydantic import BaseModel, ValidationError from ._agent_graph import CallToolsNode, ModelRequestNode -from .agent import Agent, AgentRun, RunOutputDataT +from .agent import AbstractAgent, AgentRun, RunOutputDataT from .messages import ( AgentStreamEvent, FunctionToolResultEvent, @@ -115,7 +115,7 @@ class AGUIApp(Generic[AgentDepsT, OutputDataT], Starlette): def __init__( self, - agent: Agent[AgentDepsT, OutputDataT], + agent: AbstractAgent[AgentDepsT, OutputDataT], *, # Agent.iter parameters. output_type: OutputSpec[OutputDataT] | None = None, @@ -223,7 +223,7 @@ class _Adapter(Generic[AgentDepsT, OutputDataT]): agent: The Pydantic AI `Agent` to adapt. """ - agent: Agent[AgentDepsT, OutputDataT] = field(repr=False) + agent: AbstractAgent[AgentDepsT, OutputDataT] = field(repr=False) async def run( self, @@ -273,7 +273,8 @@ async def run( parameters_json_schema=tool.parameters, ) for tool in run_input.tools - ] + ], + id='ag_ui_frontend', ) toolsets = [*toolsets, toolset] if toolsets else [toolset] diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 9f6348c6c..c80a328ed 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -4,17 +4,18 @@ import inspect import json import warnings +from abc import ABC, abstractmethod from asyncio import Lock -from collections.abc import AsyncIterator, Awaitable, Iterator, Mapping, Sequence +from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterator, Mapping, Sequence from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager from contextvars import ContextVar from copy import deepcopy from types import FrameType -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, cast, final, overload +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, cast, overload from opentelemetry.trace import NoOpTracer, use_span from pydantic.json_schema import GenerateJsonSchema -from typing_extensions import Literal, Never, Self, TypeIs, TypeVar, deprecated +from typing_extensions import Literal, Never, Self, TypeAlias, TypeIs, TypeVar, deprecated from pydantic_graph import End, Graph, GraphRun, GraphRunContext from pydantic_graph._utils import get_event_loop @@ -96,370 +97,310 @@ RunOutputDataT = TypeVar('RunOutputDataT') """Type variable for the result data of a run where `output_type` was customized on the run call.""" +EventStreamHandler: TypeAlias = Callable[ + [ + RunContext[AgentDepsT], + AsyncIterable[_messages.AgentStreamEvent | _messages.HandleResponseEvent], + ], + Awaitable[None], +] -@final -@dataclasses.dataclass(init=False) -class Agent(Generic[AgentDepsT, OutputDataT]): - """Class for defining "agents" - a way to have a specific type of "conversation" with an LLM. - Agents are generic in the dependency type they take [`AgentDepsT`][pydantic_ai.tools.AgentDepsT] - and the output type they return, [`OutputDataT`][pydantic_ai.output.OutputDataT]. +class AbstractAgent(Generic[AgentDepsT, OutputDataT], ABC): + @property + @abstractmethod + def model(self) -> models.Model | models.KnownModelName | str | None: + raise NotImplementedError - By default, if neither generic parameter is customised, agents have type `Agent[None, str]`. + @property + @abstractmethod + def name(self) -> str | None: + raise NotImplementedError - Minimal usage example: + @name.setter + @abstractmethod + def name(self, value: str | None) -> None: + raise NotImplementedError - ```python - from pydantic_ai import Agent + @property + @abstractmethod + def output_type(self) -> OutputSpec[OutputDataT]: + raise NotImplementedError - agent = Agent('openai:gpt-4o') - result = agent.run_sync('What is the capital of France?') - print(result.output) - #> Paris - ``` - """ + @property + @abstractmethod + def event_stream_handler(self) -> EventStreamHandler[AgentDepsT] | None: + raise NotImplementedError - model: models.Model | models.KnownModelName | str | None - """The default model configured for this agent. + @overload + async def run( + self, + user_prompt: str | Sequence[_messages.UserContent] | None = None, + *, + output_type: None = None, + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + ) -> AgentRunResult[OutputDataT]: ... - We allow `str` here since the actual list of allowed models changes frequently. - """ + @overload + async def run( + self, + user_prompt: str | Sequence[_messages.UserContent] | None = None, + *, + output_type: OutputSpec[RunOutputDataT], + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + ) -> AgentRunResult[RunOutputDataT]: ... - name: str | None - """The name of the agent, used for logging. + @overload + @deprecated('`result_type` is deprecated, use `output_type` instead.') + async def run( + self, + user_prompt: str | Sequence[_messages.UserContent] | None = None, + *, + result_type: type[RunOutputDataT], + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + ) -> AgentRunResult[RunOutputDataT]: ... - If `None`, we try to infer the agent name from the call frame when the agent is first run. - """ - end_strategy: EndStrategy - """Strategy for handling tool calls when a final result is found.""" + async def run( + self, + user_prompt: str | Sequence[_messages.UserContent] | None = None, + *, + output_type: OutputSpec[RunOutputDataT] | None = None, + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + **_deprecated_kwargs: Never, + ) -> AgentRunResult[Any]: + """Run the agent with a user prompt in async mode. - model_settings: ModelSettings | None - """Optional model request settings to use for this agents's runs, by default. + This method builds an internal agent graph (using system prompts, tools and result schemas) and then + runs the graph to completion. The result of the run is returned. - Note, if `model_settings` is provided by `run`, `run_sync`, or `run_stream`, those settings will - be merged with this value, with the runtime argument taking priority. - """ + Example: + ```python + from pydantic_ai import Agent - output_type: OutputSpec[OutputDataT] - """ - The type of data output by agent runs, used to validate the data returned by the model, defaults to `str`. - """ + agent = Agent('openai:gpt-4o') - instrument: InstrumentationSettings | bool | None - """Options to automatically instrument with OpenTelemetry.""" + async def main(): + agent_run = await agent.run('What is the capital of France?') + print(agent_run.output) + #> Paris + ``` - _instrument_default: ClassVar[InstrumentationSettings | bool] = False + Args: + user_prompt: User input to start/continue the conversation. + output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no + output validators since output validators would expect an argument that matches the agent's output type. + message_history: History of the conversation so far. + model: Optional model to use for this run, required if `model` was not set when creating the agent. + deps: Optional dependencies to use for this run. + model_settings: Optional settings to use for this model's request. + usage_limits: Optional limits on model request count or token usage. + usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. + infer_name: Whether to try to infer the agent name from the call frame if it's not set. + toolsets: Optional additional toolsets for this run. - _deps_type: type[AgentDepsT] = dataclasses.field(repr=False) - _deprecated_result_tool_name: str | None = dataclasses.field(repr=False) - _deprecated_result_tool_description: str | None = dataclasses.field(repr=False) - _output_schema: _output.BaseOutputSchema[OutputDataT] = dataclasses.field(repr=False) - _output_validators: list[_output.OutputValidator[AgentDepsT, OutputDataT]] = dataclasses.field(repr=False) - _instructions: str | None = dataclasses.field(repr=False) - _instructions_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(repr=False) - _system_prompts: tuple[str, ...] = dataclasses.field(repr=False) - _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(repr=False) - _system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field( - repr=False - ) - _function_toolset: FunctionToolset[AgentDepsT] = dataclasses.field(repr=False) - _output_toolset: OutputToolset[AgentDepsT] | None = dataclasses.field(repr=False) - _user_toolsets: Sequence[AbstractToolset[AgentDepsT]] = dataclasses.field(repr=False) - _prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False) - _prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False) - _max_result_retries: int = dataclasses.field(repr=False) + Returns: + The result of the run. + """ + if infer_name and self.name is None: + self._infer_name(inspect.currentframe()) - _enter_lock: Lock = dataclasses.field(repr=False) - _entered_count: int = dataclasses.field(repr=False) - _exit_stack: AsyncExitStack | None = dataclasses.field(repr=False) + if 'result_type' in _deprecated_kwargs: # pragma: no cover + if output_type is not str: + raise TypeError('`result_type` and `output_type` cannot be set at the same time.') + warnings.warn('`result_type` is deprecated, use `output_type` instead.', DeprecationWarning, stacklevel=2) + output_type = _deprecated_kwargs.pop('result_type') + + _utils.validate_empty_kwargs(_deprecated_kwargs) + + async with self.iter( + user_prompt=user_prompt, + output_type=output_type, + message_history=message_history, + model=model, + deps=deps, + model_settings=model_settings, + usage_limits=usage_limits, + usage=usage, + toolsets=toolsets, + ) as agent_run: + async for node in agent_run: + if self.event_stream_handler is not None and ( + self.is_model_request_node(node) or self.is_call_tools_node(node) + ): + async with node.stream(agent_run.ctx) as stream: + await self.event_stream_handler(_agent_graph.build_run_context(agent_run.ctx), stream) + + assert agent_run.result is not None, 'The graph run did not finish properly' + return agent_run.result @overload - def __init__( + def run_sync( self, - model: models.Model | models.KnownModelName | str | None = None, + user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[OutputDataT] = str, - instructions: str - | _system_prompt.SystemPromptFunc[AgentDepsT] - | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]] - | None = None, - system_prompt: str | Sequence[str] = (), - deps_type: type[AgentDepsT] = NoneType, - name: str | None = None, + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, model_settings: ModelSettings | None = None, - retries: int = 1, - output_retries: int | None = None, - tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), - prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, - prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - defer_model_check: bool = False, - end_strategy: EndStrategy = 'early', - instrument: InstrumentationSettings | bool | None = None, - history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, - ) -> None: ... + ) -> AgentRunResult[OutputDataT]: ... @overload - @deprecated( - '`result_type`, `result_tool_name` & `result_tool_description` are deprecated, use `output_type` instead. `result_retries` is deprecated, use `output_retries` instead.' - ) - def __init__( + def run_sync( self, - model: models.Model | models.KnownModelName | str | None = None, + user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - result_type: type[OutputDataT] = str, - instructions: str - | _system_prompt.SystemPromptFunc[AgentDepsT] - | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]] - | None = None, - system_prompt: str | Sequence[str] = (), - deps_type: type[AgentDepsT] = NoneType, - name: str | None = None, + output_type: OutputSpec[RunOutputDataT] | None = None, + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, model_settings: ModelSettings | None = None, - retries: int = 1, - result_tool_name: str = _output.DEFAULT_OUTPUT_TOOL_NAME, - result_tool_description: str | None = None, - result_retries: int | None = None, - tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), - prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, - prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - defer_model_check: bool = False, - end_strategy: EndStrategy = 'early', - instrument: InstrumentationSettings | bool | None = None, - history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, - ) -> None: ... + ) -> AgentRunResult[RunOutputDataT]: ... @overload - @deprecated('`mcp_servers` is deprecated, use `toolsets` instead.') - def __init__( + @deprecated('`result_type` is deprecated, use `output_type` instead.') + def run_sync( self, - model: models.Model | models.KnownModelName | str | None = None, + user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - result_type: type[OutputDataT] = str, - instructions: str - | _system_prompt.SystemPromptFunc[AgentDepsT] - | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]] - | None = None, - system_prompt: str | Sequence[str] = (), - deps_type: type[AgentDepsT] = NoneType, - name: str | None = None, + result_type: type[RunOutputDataT], + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, model_settings: ModelSettings | None = None, - retries: int = 1, - result_tool_name: str = _output.DEFAULT_OUTPUT_TOOL_NAME, - result_tool_description: str | None = None, - result_retries: int | None = None, - tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), - prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, - prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None, - mcp_servers: Sequence[MCPServer] = (), - defer_model_check: bool = False, - end_strategy: EndStrategy = 'early', - instrument: InstrumentationSettings | bool | None = None, - history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, - ) -> None: ... + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + ) -> AgentRunResult[RunOutputDataT]: ... - def __init__( + def run_sync( self, - model: models.Model | models.KnownModelName | str | None = None, + user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - # TODO change this back to `output_type: _output.OutputType[OutputDataT] = str,` when we remove the overloads - output_type: Any = str, - instructions: str - | _system_prompt.SystemPromptFunc[AgentDepsT] - | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]] - | None = None, - system_prompt: str | Sequence[str] = (), - deps_type: type[AgentDepsT] = NoneType, - name: str | None = None, + output_type: OutputSpec[RunOutputDataT] | None = None, + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, model_settings: ModelSettings | None = None, - retries: int = 1, - output_retries: int | None = None, - tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), - prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, - prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - defer_model_check: bool = False, - end_strategy: EndStrategy = 'early', - instrument: InstrumentationSettings | bool | None = None, - history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, - **_deprecated_kwargs: Any, - ): - """Create an agent. - - Args: - model: The default model to use for this agent, if not provide, - you must provide the model when calling it. We allow `str` here since the actual list of allowed models changes frequently. - output_type: The type of the output data, used to validate the data returned by the model, - defaults to `str`. - instructions: Instructions to use for this agent, you can also register instructions via a function with - [`instructions`][pydantic_ai.Agent.instructions]. - system_prompt: Static system prompts to use for this agent, you can also register system - prompts via a function with [`system_prompt`][pydantic_ai.Agent.system_prompt]. - deps_type: The type used for dependency injection, this parameter exists solely to allow you to fully - parameterize the agent, and therefore get the best out of static type checking. - If you're not using deps, but want type checking to pass, you can set `deps=None` to satisfy Pyright - or add a type hint `: Agent[None, ]`. - name: The name of the agent, used for logging. If `None`, we try to infer the agent name from the call frame - when the agent is first run. - model_settings: Optional model request settings to use for this agent's runs, by default. - retries: The default number of retries to allow before raising an error. - output_retries: The maximum number of retries to allow for output validation, defaults to `retries`. - tools: Tools to register with the agent, you can also register tools via the decorators - [`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain]. - prepare_tools: Custom function to prepare the tool definition of all tools for each step, except output tools. - This is useful if you want to customize the definition of multiple tools or you want to register - a subset of tools for a given step. See [`ToolsPrepareFunc`][pydantic_ai.tools.ToolsPrepareFunc] - prepare_output_tools: Custom function to prepare the tool definition of all output tools for each step. - This is useful if you want to customize the definition of multiple output tools or you want to register - a subset of output tools for a given step. See [`ToolsPrepareFunc`][pydantic_ai.tools.ToolsPrepareFunc] - toolsets: Toolsets to register with the agent, including MCP servers. - defer_model_check: by default, if you provide a [named][pydantic_ai.models.KnownModelName] model, - it's evaluated to create a [`Model`][pydantic_ai.models.Model] instance immediately, - which checks for the necessary environment variables. Set this to `false` - to defer the evaluation until the first run. Useful if you want to - [override the model][pydantic_ai.Agent.override] for testing. - end_strategy: Strategy for handling tool calls that are requested alongside a final result. - See [`EndStrategy`][pydantic_ai.agent.EndStrategy] for more information. - instrument: Set to True to automatically instrument with OpenTelemetry, - which will use Logfire if it's configured. - Set to an instance of [`InstrumentationSettings`][pydantic_ai.agent.InstrumentationSettings] to customize. - If this isn't set, then the last value set by - [`Agent.instrument_all()`][pydantic_ai.Agent.instrument_all] - will be used, which defaults to False. - See the [Debugging and Monitoring guide](https://ai.pydantic.dev/logfire/) for more info. - history_processors: Optional list of callables to process the message history before sending it to the model. - Each processor takes a list of messages and returns a modified list of messages. - Processors can be sync or async and are applied in sequence. - """ - if model is None or defer_model_check: - self.model = model - else: - self.model = models.infer_model(model) - - self.end_strategy = end_strategy - self.name = name - self.model_settings = model_settings + **_deprecated_kwargs: Never, + ) -> AgentRunResult[Any]: + """Synchronously run the agent with a user prompt. - if 'result_type' in _deprecated_kwargs: - if output_type is not str: # pragma: no cover - raise TypeError('`result_type` and `output_type` cannot be set at the same time.') - warnings.warn('`result_type` is deprecated, use `output_type` instead', DeprecationWarning, stacklevel=2) - output_type = _deprecated_kwargs.pop('result_type') + This is a convenience method that wraps [`self.run`][pydantic_ai.Agent.run] with `loop.run_until_complete(...)`. + You therefore can't use this method inside async code or if there's an active event loop. - self.output_type = output_type + Example: + ```python + from pydantic_ai import Agent - self.instrument = instrument + agent = Agent('openai:gpt-4o') - self._deps_type = deps_type + result_sync = agent.run_sync('What is the capital of Italy?') + print(result_sync.output) + #> Rome + ``` - self._deprecated_result_tool_name = _deprecated_kwargs.pop('result_tool_name', None) - if self._deprecated_result_tool_name is not None: - warnings.warn( - '`result_tool_name` is deprecated, use `output_type` with `ToolOutput` instead', - DeprecationWarning, - stacklevel=2, - ) + Args: + user_prompt: User input to start/continue the conversation. + output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no + output validators since output validators would expect an argument that matches the agent's output type. + message_history: History of the conversation so far. + model: Optional model to use for this run, required if `model` was not set when creating the agent. + deps: Optional dependencies to use for this run. + model_settings: Optional settings to use for this model's request. + usage_limits: Optional limits on model request count or token usage. + usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. + infer_name: Whether to try to infer the agent name from the call frame if it's not set. + toolsets: Optional additional toolsets for this run. - self._deprecated_result_tool_description = _deprecated_kwargs.pop('result_tool_description', None) - if self._deprecated_result_tool_description is not None: - warnings.warn( - '`result_tool_description` is deprecated, use `output_type` with `ToolOutput` instead', - DeprecationWarning, - stacklevel=2, - ) - result_retries = _deprecated_kwargs.pop('result_retries', None) - if result_retries is not None: - if output_retries is not None: # pragma: no cover - raise TypeError('`output_retries` and `result_retries` cannot be set at the same time.') - warnings.warn( - '`result_retries` is deprecated, use `max_result_retries` instead', DeprecationWarning, stacklevel=2 - ) - output_retries = result_retries + Returns: + The result of the run. + """ + if infer_name and self.name is None: + self._infer_name(inspect.currentframe()) - if mcp_servers := _deprecated_kwargs.pop('mcp_servers', None): - if toolsets is not None: # pragma: no cover - raise TypeError('`mcp_servers` and `toolsets` cannot be set at the same time.') - warnings.warn('`mcp_servers` is deprecated, use `toolsets` instead', DeprecationWarning) - toolsets = mcp_servers + if 'result_type' in _deprecated_kwargs: # pragma: no cover + if output_type is not str: + raise TypeError('`result_type` and `output_type` cannot be set at the same time.') + warnings.warn('`result_type` is deprecated, use `output_type` instead.', DeprecationWarning, stacklevel=2) + output_type = _deprecated_kwargs.pop('result_type') _utils.validate_empty_kwargs(_deprecated_kwargs) - default_output_mode = ( - self.model.profile.default_structured_output_mode if isinstance(self.model, models.Model) else None - ) - - self._output_schema = _output.OutputSchema[OutputDataT].build( - output_type, - default_mode=default_output_mode, - name=self._deprecated_result_tool_name, - description=self._deprecated_result_tool_description, - ) - self._output_validators = [] - - self._instructions = '' - self._instructions_functions = [] - if isinstance(instructions, (str, Callable)): - instructions = [instructions] - for instruction in instructions or []: - if isinstance(instruction, str): - self._instructions += instruction + '\n' - else: - self._instructions_functions.append(_system_prompt.SystemPromptRunner(instruction)) - self._instructions = self._instructions.strip() or None - - self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt) - self._system_prompt_functions = [] - self._system_prompt_dynamic_functions = {} - - self._max_result_retries = output_retries if output_retries is not None else retries - self._prepare_tools = prepare_tools - self._prepare_output_tools = prepare_output_tools - - self._output_toolset = self._output_schema.toolset - if self._output_toolset: - self._output_toolset.max_retries = self._max_result_retries - - self._function_toolset = FunctionToolset(tools, max_retries=retries) - self._user_toolsets = toolsets or () - - self.history_processors = history_processors or [] - - self._override_deps: ContextVar[_utils.Option[AgentDepsT]] = ContextVar('_override_deps', default=None) - self._override_model: ContextVar[_utils.Option[models.Model]] = ContextVar('_override_model', default=None) - self._override_toolsets: ContextVar[_utils.Option[Sequence[AbstractToolset[AgentDepsT]]]] = ContextVar( - '_override_toolsets', default=None + return get_event_loop().run_until_complete( + self.run( + user_prompt, + output_type=output_type, + message_history=message_history, + model=model, + deps=deps, + model_settings=model_settings, + usage_limits=usage_limits, + usage=usage, + infer_name=False, + toolsets=toolsets, + ) ) - self._enter_lock = _utils.get_async_lock() - self._entered_count = 0 - self._exit_stack = None - - @staticmethod - def instrument_all(instrument: InstrumentationSettings | bool = True) -> None: - """Set the instrumentation options for all agents where `instrument` is not set.""" - Agent._instrument_default = instrument - @overload - async def run( + def run_stream( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: None = None, message_history: list[_messages.ModelMessage] | None = None, - model: models.Model | models.KnownModelName | str | None = None, + model: models.Model | models.KnownModelName | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - ) -> AgentRunResult[OutputDataT]: ... + ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, OutputDataT]]: ... @overload - async def run( + def run_stream( self, - user_prompt: str | Sequence[_messages.UserContent] | None = None, + user_prompt: str | Sequence[_messages.UserContent], *, output_type: OutputSpec[RunOutputDataT], message_history: list[_messages.ModelMessage] | None = None, @@ -470,11 +411,11 @@ async def run( usage: _usage.Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - ) -> AgentRunResult[RunOutputDataT]: ... + ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunOutputDataT]]: ... @overload @deprecated('`result_type` is deprecated, use `output_type` instead.') - async def run( + def run_stream( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, @@ -487,9 +428,10 @@ async def run( usage: _usage.Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - ) -> AgentRunResult[RunOutputDataT]: ... + ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunOutputDataT]]: ... - async def run( + @asynccontextmanager + async def run_stream( # noqa C901 self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, @@ -503,11 +445,8 @@ async def run( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, **_deprecated_kwargs: Never, - ) -> AgentRunResult[Any]: - """Run the agent with a user prompt in async mode. - - This method builds an internal agent graph (using system prompts, tools and result schemas) and then - runs the graph to completion. The result of the run is returned. + ) -> AsyncIterator[result.StreamedRunResult[AgentDepsT, Any]]: + """Run the agent with a user prompt in async mode, returning a streamed response. Example: ```python @@ -516,9 +455,9 @@ async def run( agent = Agent('openai:gpt-4o') async def main(): - agent_run = await agent.run('What is the capital of France?') - print(agent_run.output) - #> Paris + async with agent.run_stream('What is the capital of the UK?') as response: + print(await response.get_output()) + #> London ``` Args: @@ -537,8 +476,12 @@ async def main(): Returns: The result of the run. """ + # TODO: We need to deprecate this now that we have the `iter` method. + # Before that, though, we should add an event for when we reach the final result of the stream. if infer_name and self.name is None: - self._infer_name(inspect.currentframe()) + # f_back because `asynccontextmanager` adds one frame + if frame := inspect.currentframe(): # pragma: no branch + self._infer_name(frame.f_back) if 'result_type' in _deprecated_kwargs: # pragma: no cover if output_type is not str: @@ -548,8 +491,9 @@ async def main(): _utils.validate_empty_kwargs(_deprecated_kwargs) + yielded = False async with self.iter( - user_prompt=user_prompt, + user_prompt, output_type=output_type, message_history=message_history, model=model, @@ -557,13 +501,71 @@ async def main(): model_settings=model_settings, usage_limits=usage_limits, usage=usage, + infer_name=False, toolsets=toolsets, ) as agent_run: - async for _ in agent_run: - pass + first_node = agent_run.next_node # start with the first node + assert isinstance(first_node, _agent_graph.UserPromptNode) # the first node should be a user prompt node + node = first_node + while True: + if self.is_model_request_node(node): + graph_ctx = agent_run.ctx + async with node.stream(graph_ctx) as stream: - assert agent_run.result is not None, 'The graph run did not finish properly' - return agent_run.result + async def stream_to_final(s: AgentStream) -> FinalResult[AgentStream] | None: + async for event in stream: + if isinstance(event, _messages.FinalResultEvent): + return FinalResult(s, event.tool_name, event.tool_call_id) + return None + + final_result = await stream_to_final(stream) + if final_result is not None: + if yielded: + raise exceptions.AgentRunError('Agent run produced final results') # pragma: no cover + yielded = True + + messages = graph_ctx.state.message_history.copy() + + async def on_complete() -> None: + """Called when the stream has completed. + + The model response will have been added to messages by now + by `StreamedRunResult._marked_completed`. + """ + last_message = messages[-1] + assert isinstance(last_message, _messages.ModelResponse) + tool_calls = [ + part for part in last_message.parts if isinstance(part, _messages.ToolCallPart) + ] + + parts: list[_messages.ModelRequestPart] = [] + async for _event in _agent_graph.process_function_tools( + graph_ctx.deps.tool_manager, + tool_calls, + final_result, + graph_ctx, + parts, + ): + pass + if parts: + messages.append(_messages.ModelRequest(parts)) + + yield StreamedRunResult( + messages, + graph_ctx.deps.new_message_index, + stream, + on_complete, + ) + break + next_node = await agent_run.next(node) + if not isinstance(next_node, _agent_graph.AgentNode): + raise exceptions.AgentRunError( # pragma: no cover + 'Should have produced a StreamedRunResult before getting here' + ) + node = cast(_agent_graph.AgentNode[Any, Any], next_node) + + if not yielded: + raise exceptions.AgentRunError('Agent run finished without producing a final result') # pragma: no cover @overload def iter( @@ -617,6 +619,7 @@ def iter( ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, Any]]: ... @asynccontextmanager + @abstractmethod async def iter( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, @@ -709,315 +712,912 @@ async def main(): Returns: The result of the run. """ - if infer_name and self.name is None: - self._infer_name(inspect.currentframe()) - model_used = self._get_model(model) - del model + raise NotImplementedError + yield - if 'result_type' in _deprecated_kwargs: # pragma: no cover - if output_type is not str: - raise TypeError('`result_type` and `output_type` cannot be set at the same time.') - warnings.warn('`result_type` is deprecated, use `output_type` instead.', DeprecationWarning, stacklevel=2) - output_type = _deprecated_kwargs.pop('result_type') + @contextmanager + @abstractmethod + def override( + self, + *, + deps: AgentDepsT | _utils.Unset = _utils.UNSET, + model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET, + tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET, + ) -> Iterator[None]: + """Context manager to temporarily override agent dependencies, model, or toolsets. - _utils.validate_empty_kwargs(_deprecated_kwargs) + This is particularly useful when testing. + You can find an example of this [here](../testing.md#overriding-model-via-pytest-fixtures). - deps = self._get_deps(deps) - new_message_index = len(message_history) if message_history else 0 - output_schema = self._prepare_output_schema(output_type, model_used.profile) + Args: + deps: The dependencies to use instead of the dependencies passed to the agent run. + model: The model to use instead of the model passed to the agent run. + toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run. + tools: The tools to use instead of the tools registered with the agent. + """ + raise NotImplementedError + yield - output_type_ = output_type or self.output_type + def _infer_name(self, function_frame: FrameType | None) -> None: + """Infer the agent name from the call frame. - # We consider it a user error if a user tries to restrict the result type while having an output validator that - # may change the result type from the restricted type to something else. Therefore, we consider the following - # typecast reasonable, even though it is possible to violate it with otherwise-type-checked code. - output_validators = cast(list[_output.OutputValidator[AgentDepsT, RunOutputDataT]], self._output_validators) + Usage should be `self._infer_name(inspect.currentframe())`. + """ + assert self.name is None, 'Name already set' + if function_frame is not None: # pragma: no branch + if parent_frame := function_frame.f_back: # pragma: no branch + for name, item in parent_frame.f_locals.items(): + if item is self: + self.name = name + return + if parent_frame.f_locals != parent_frame.f_globals: # pragma: no branch + # if we couldn't find the agent in locals and globals are a different dict, try globals + for name, item in parent_frame.f_globals.items(): + if item is self: + self.name = name + return - output_toolset = self._output_toolset - if output_schema != self._output_schema or output_validators: - output_toolset = cast(OutputToolset[AgentDepsT], output_schema.toolset) - if output_toolset: - output_toolset.max_retries = self._max_result_retries - output_toolset.output_validators = output_validators + @staticmethod + def is_model_request_node( + node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]], + ) -> TypeIs[_agent_graph.ModelRequestNode[T, S]]: + """Check if the node is a `ModelRequestNode`, narrowing the type if it is. - # Build the graph - graph: Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[Any]] = ( - _agent_graph.build_agent_graph(self.name, self._deps_type, output_type_) - ) + This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`. + """ + return isinstance(node, _agent_graph.ModelRequestNode) - # Build the initial state - usage = usage or _usage.Usage() - state = _agent_graph.GraphAgentState( - message_history=message_history[:] if message_history else [], - usage=usage, - retries=0, - run_step=0, - ) + @staticmethod + def is_call_tools_node( + node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]], + ) -> TypeIs[_agent_graph.CallToolsNode[T, S]]: + """Check if the node is a `CallToolsNode`, narrowing the type if it is. - if isinstance(model_used, InstrumentedModel): - instrumentation_settings = model_used.instrumentation_settings - tracer = model_used.instrumentation_settings.tracer - else: - instrumentation_settings = None - tracer = NoOpTracer() + This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`. + """ + return isinstance(node, _agent_graph.CallToolsNode) - run_context = RunContext[AgentDepsT]( - deps=deps, - model=model_used, - usage=usage, - prompt=user_prompt, - messages=state.message_history, - tracer=tracer, - trace_include_content=instrumentation_settings is not None and instrumentation_settings.include_content, - run_step=state.run_step, - ) + @staticmethod + def is_user_prompt_node( + node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]], + ) -> TypeIs[_agent_graph.UserPromptNode[T, S]]: + """Check if the node is a `UserPromptNode`, narrowing the type if it is. - toolset = self._get_toolset(output_toolset=output_toolset, additional_toolsets=toolsets) - # This will raise errors for any name conflicts - run_toolset = await ToolManager[AgentDepsT].build(toolset, run_context) + This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`. + """ + return isinstance(node, _agent_graph.UserPromptNode) - # Merge model settings in order of precedence: run > agent > model - merged_settings = merge_model_settings(model_used.settings, self.model_settings) - model_settings = merge_model_settings(merged_settings, model_settings) - usage_limits = usage_limits or _usage.UsageLimits() - agent_name = self.name or 'agent' - run_span = tracer.start_span( - 'agent run', - attributes={ - 'model_name': model_used.model_name if model_used else 'no-model', - 'agent_name': agent_name, - 'logfire.msg': f'{agent_name} run', - }, - ) + @staticmethod + def is_end_node( + node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]], + ) -> TypeIs[End[result.FinalResult[S]]]: + """Check if the node is a `End`, narrowing the type if it is. - async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: - parts = [ - self._instructions, - *[await func.run(run_context) for func in self._instructions_functions], - ] + This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`. + """ + return isinstance(node, End) - model_profile = model_used.profile - if isinstance(output_schema, _output.PromptedOutputSchema): - instructions = output_schema.instructions(model_profile.prompted_output_template) - parts.append(instructions) + @abstractmethod + async def __aenter__(self) -> AbstractAgent[AgentDepsT, OutputDataT]: + raise NotImplementedError - parts = [p for p in parts if p] - if not parts: - return None - return '\n\n'.join(parts).strip() + @abstractmethod + async def __aexit__(self, *args: Any) -> bool | None: + raise NotImplementedError - graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT]( - user_deps=deps, - prompt=user_prompt, - new_message_index=new_message_index, - model=model_used, - model_settings=model_settings, - usage_limits=usage_limits, - max_result_retries=self._max_result_retries, - end_strategy=self.end_strategy, - output_schema=output_schema, - output_validators=output_validators, - history_processors=self.history_processors, - tool_manager=run_toolset, - tracer=tracer, - get_instructions=get_instructions, - instrumentation_settings=instrumentation_settings, + def to_ag_ui( + self, + *, + # Agent.iter parameters + output_type: OutputSpec[OutputDataT] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: UsageLimits | None = None, + usage: Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + # Starlette + debug: bool = False, + routes: Sequence[BaseRoute] | None = None, + middleware: Sequence[Middleware] | None = None, + exception_handlers: Mapping[Any, ExceptionHandler] | None = None, + on_startup: Sequence[Callable[[], Any]] | None = None, + on_shutdown: Sequence[Callable[[], Any]] | None = None, + lifespan: Lifespan[AGUIApp[AgentDepsT, OutputDataT]] | None = None, + ) -> AGUIApp[AgentDepsT, OutputDataT]: + """Convert the agent to an AG-UI application. + + This allows you to use the agent with a compatible AG-UI frontend. + + Example: + ```python + from pydantic_ai import Agent + + agent = Agent('openai:gpt-4o') + app = agent.to_ag_ui() + ``` + + The `app` is an ASGI application that can be used with any ASGI server. + + To run the application, you can use the following command: + + ```bash + uvicorn app:app --host 0.0.0.0 --port 8000 + ``` + + See [AG-UI docs](../ag-ui.md) for more information. + + Args: + output_type: Custom output type to use for this run, `output_type` may only be used if the agent has + no output validators since output validators would expect an argument that matches the agent's + output type. + model: Optional model to use for this run, required if `model` was not set when creating the agent. + deps: Optional dependencies to use for this run. + model_settings: Optional settings to use for this model's request. + usage_limits: Optional limits on model request count or token usage. + usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. + infer_name: Whether to try to infer the agent name from the call frame if it's not set. + toolsets: Optional list of toolsets to use for this agent, defaults to the agent's toolset. + + debug: Boolean indicating if debug tracebacks should be returned on errors. + routes: A list of routes to serve incoming HTTP and WebSocket requests. + middleware: A list of middleware to run for every request. A starlette application will always + automatically include two middleware classes. `ServerErrorMiddleware` is added as the very + outermost middleware, to handle any uncaught errors occurring anywhere in the entire stack. + `ExceptionMiddleware` is added as the very innermost middleware, to deal with handled + exception cases occurring in the routing or endpoints. + exception_handlers: A mapping of either integer status codes, or exception class types onto + callables which handle the exceptions. Exception handler callables should be of the form + `handler(request, exc) -> response` and may be either standard functions, or async functions. + on_startup: A list of callables to run on application startup. Startup handler callables do not + take any arguments, and may be either standard functions, or async functions. + on_shutdown: A list of callables to run on application shutdown. Shutdown handler callables do + not take any arguments, and may be either standard functions, or async functions. + lifespan: A lifespan context function, which can be used to perform startup and shutdown tasks. + This is a newer style that replaces the `on_startup` and `on_shutdown` handlers. Use one or + the other, not both. + + Returns: + An ASGI application for running Pydantic AI agents with AG-UI protocol support. + """ + from .ag_ui import AGUIApp + + return AGUIApp( + agent=self, + # Agent.iter parameters + output_type=output_type, + model=model, + deps=deps, + model_settings=model_settings, + usage_limits=usage_limits, + usage=usage, + infer_name=infer_name, + toolsets=toolsets, + # Starlette + debug=debug, + routes=routes, + middleware=middleware, + exception_handlers=exception_handlers, + on_startup=on_startup, + on_shutdown=on_shutdown, + lifespan=lifespan, + ) + + def to_a2a( + self, + *, + storage: Storage | None = None, + broker: Broker | None = None, + # Agent card + name: str | None = None, + url: str = 'http://localhost:8000', + version: str = '1.0.0', + description: str | None = None, + provider: AgentProvider | None = None, + skills: list[Skill] | None = None, + # Starlette + debug: bool = False, + routes: Sequence[Route] | None = None, + middleware: Sequence[Middleware] | None = None, + exception_handlers: dict[Any, ExceptionHandler] | None = None, + lifespan: Lifespan[FastA2A] | None = None, + ) -> FastA2A: + """Convert the agent to a FastA2A application. + + Example: + ```python + from pydantic_ai import Agent + + agent = Agent('openai:gpt-4o') + app = agent.to_a2a() + ``` + + The `app` is an ASGI application that can be used with any ASGI server. + + To run the application, you can use the following command: + + ```bash + uvicorn app:app --host 0.0.0.0 --port 8000 + ``` + """ + from ._a2a import agent_to_a2a + + return agent_to_a2a( + self, + storage=storage, + broker=broker, + name=name, + url=url, + version=version, + description=description, + provider=provider, + skills=skills, + debug=debug, + routes=routes, + middleware=middleware, + exception_handlers=exception_handlers, + lifespan=lifespan, + ) + + async def to_cli(self: Self, deps: AgentDepsT = None, prog_name: str = 'pydantic-ai') -> None: + """Run the agent in a CLI chat interface. + + Args: + deps: The dependencies to pass to the agent. + prog_name: The name of the program to use for the CLI. Defaults to 'pydantic-ai'. + + Example: + ```python {title="agent_to_cli.py" test="skip"} + from pydantic_ai import Agent + + agent = Agent('openai:gpt-4o', instructions='You always respond in Italian.') + + async def main(): + await agent.to_cli() + ``` + """ + from rich.console import Console + + from pydantic_ai._cli import run_chat + + await run_chat(stream=True, agent=self, deps=deps, console=Console(), code_theme='monokai', prog_name=prog_name) + + def to_cli_sync(self: Self, deps: AgentDepsT = None, prog_name: str = 'pydantic-ai') -> None: + """Run the agent in a CLI chat interface with the non-async interface. + + Args: + deps: The dependencies to pass to the agent. + prog_name: The name of the program to use for the CLI. Defaults to 'pydantic-ai'. + + ```python {title="agent_to_cli_sync.py" test="skip"} + from pydantic_ai import Agent + + agent = Agent('openai:gpt-4o', instructions='You always respond in Italian.') + agent.to_cli_sync() + agent.to_cli_sync(prog_name='assistant') + ``` + """ + return get_event_loop().run_until_complete(self.to_cli(deps=deps, prog_name=prog_name)) + + +class WrapperAgent(AbstractAgent[AgentDepsT, OutputDataT]): + def __init__(self, wrapped: AbstractAgent[AgentDepsT, OutputDataT]): + self.wrapped = wrapped + + @property + def model(self) -> models.Model | models.KnownModelName | str | None: + return self.wrapped.model + + @property + def name(self) -> str | None: + return self.wrapped.name + + @name.setter + def name(self, value: str | None) -> None: + self.wrapped.name = value + + @property + def output_type(self) -> OutputSpec[OutputDataT]: + return self.wrapped.output_type + + @property + def event_stream_handler(self) -> EventStreamHandler[AgentDepsT] | None: + return self.wrapped.event_stream_handler + + async def __aenter__(self) -> AbstractAgent[AgentDepsT, OutputDataT]: + return await self.wrapped.__aenter__() + + async def __aexit__(self, *args: Any) -> bool | None: + return await self.wrapped.__aexit__(*args) + + def __getattr__(self, name: str) -> Any: + return getattr(self._agent, name) + + @overload + def iter( + self, + user_prompt: str | Sequence[_messages.UserContent] | None = None, + *, + output_type: None = None, + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + **_deprecated_kwargs: Never, + ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, OutputDataT]]: ... + + @overload + def iter( + self, + user_prompt: str | Sequence[_messages.UserContent] | None = None, + *, + output_type: OutputSpec[RunOutputDataT], + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + **_deprecated_kwargs: Never, + ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ... + + @overload + @deprecated('`result_type` is deprecated, use `output_type` instead.') + def iter( + self, + user_prompt: str | Sequence[_messages.UserContent] | None = None, + *, + result_type: type[RunOutputDataT], + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, Any]]: ... + + @asynccontextmanager + async def iter( + self, + user_prompt: str | Sequence[_messages.UserContent] | None = None, + *, + output_type: OutputSpec[RunOutputDataT] | None = None, + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + **_deprecated_kwargs: Never, + ) -> AsyncIterator[AgentRun[AgentDepsT, Any]]: + """A contextmanager which can be used to iterate over the agent graph's nodes as they are executed. + + This method builds an internal agent graph (using system prompts, tools and output schemas) and then returns an + `AgentRun` object. The `AgentRun` can be used to async-iterate over the nodes of the graph as they are + executed. This is the API to use if you want to consume the outputs coming from each LLM model response, or the + stream of events coming from the execution of tools. + + The `AgentRun` also provides methods to access the full message history, new messages, and usage statistics, + and the final result of the run once it has completed. + + For more details, see the documentation of `AgentRun`. + + Example: + ```python + from pydantic_ai import Agent + + agent = Agent('openai:gpt-4o') + + async def main(): + nodes = [] + async with agent.iter('What is the capital of France?') as agent_run: + async for node in agent_run: + nodes.append(node) + print(nodes) + ''' + [ + UserPromptNode( + user_prompt='What is the capital of France?', + instructions=None, + instructions_functions=[], + system_prompts=(), + system_prompt_functions=[], + system_prompt_dynamic_functions={}, + ), + ModelRequestNode( + request=ModelRequest( + parts=[ + UserPromptPart( + content='What is the capital of France?', + timestamp=datetime.datetime(...), + ) + ] + ) + ), + CallToolsNode( + model_response=ModelResponse( + parts=[TextPart(content='Paris')], + usage=Usage( + requests=1, request_tokens=56, response_tokens=1, total_tokens=57 + ), + model_name='gpt-4o', + timestamp=datetime.datetime(...), + ) + ), + End(data=FinalResult(output='Paris')), + ] + ''' + print(agent_run.result.output) + #> Paris + ``` + + Args: + user_prompt: User input to start/continue the conversation. + output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no + output validators since output validators would expect an argument that matches the agent's output type. + message_history: History of the conversation so far. + model: Optional model to use for this run, required if `model` was not set when creating the agent. + deps: Optional dependencies to use for this run. + model_settings: Optional settings to use for this model's request. + usage_limits: Optional limits on model request count or token usage. + usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. + infer_name: Whether to try to infer the agent name from the call frame if it's not set. + toolsets: Optional additional toolsets for this run. + + Returns: + The result of the run. + """ + async with self.wrapped.iter( + user_prompt=user_prompt, + output_type=output_type, + message_history=message_history, + model=model, + deps=deps, + model_settings=model_settings, + usage_limits=usage_limits, + usage=usage, + infer_name=infer_name, + toolsets=toolsets, + ) as result: + yield result + + @contextmanager + def override( + self, + *, + deps: AgentDepsT | _utils.Unset = _utils.UNSET, + model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET, + tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET, + ) -> Iterator[None]: + """Context manager to temporarily override agent dependencies, model, or toolsets. + + This is particularly useful when testing. + You can find an example of this [here](../testing.md#overriding-model-via-pytest-fixtures). + + Args: + deps: The dependencies to use instead of the dependencies passed to the agent run. + model: The model to use instead of the model passed to the agent run. + toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run. + tools: The tools to use instead of the tools registered with the agent. + """ + with self.wrapped.override(deps=deps, model=model, toolsets=toolsets, tools=tools): + yield + + +@dataclasses.dataclass(init=False) +class Agent(AbstractAgent[AgentDepsT, OutputDataT]): + """Class for defining "agents" - a way to have a specific type of "conversation" with an LLM. + + Agents are generic in the dependency type they take [`AgentDepsT`][pydantic_ai.tools.AgentDepsT] + and the output type they return, [`OutputDataT`][pydantic_ai.output.OutputDataT]. + + By default, if neither generic parameter is customised, agents have type `Agent[None, str]`. + + Minimal usage example: + + ```python + from pydantic_ai import Agent + + agent = Agent('openai:gpt-4o') + result = agent.run_sync('What is the capital of France?') + print(result.output) + #> Paris + ``` + """ + + _model: models.Model | models.KnownModelName | str | None + """The default model configured for this agent. + + We allow `str` here since the actual list of allowed models changes frequently. + """ + + _name: str | None + """The name of the agent, used for logging. + + If `None`, we try to infer the agent name from the call frame when the agent is first run. + """ + end_strategy: EndStrategy + """Strategy for handling tool calls when a final result is found.""" + + model_settings: ModelSettings | None + """Optional model request settings to use for this agents's runs, by default. + + Note, if `model_settings` is provided by `run`, `run_sync`, or `run_stream`, those settings will + be merged with this value, with the runtime argument taking priority. + """ + + _output_type: OutputSpec[OutputDataT] + """ + The type of data output by agent runs, used to validate the data returned by the model, defaults to `str`. + """ + + instrument: InstrumentationSettings | bool | None + """Options to automatically instrument with OpenTelemetry.""" + + _instrument_default: ClassVar[InstrumentationSettings | bool] = False + + _deps_type: type[AgentDepsT] = dataclasses.field(repr=False) + _deprecated_result_tool_name: str | None = dataclasses.field(repr=False) + _deprecated_result_tool_description: str | None = dataclasses.field(repr=False) + _output_schema: _output.BaseOutputSchema[OutputDataT] = dataclasses.field(repr=False) + _output_validators: list[_output.OutputValidator[AgentDepsT, OutputDataT]] = dataclasses.field(repr=False) + _instructions: str | None = dataclasses.field(repr=False) + _instructions_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(repr=False) + _system_prompts: tuple[str, ...] = dataclasses.field(repr=False) + _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(repr=False) + _system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field( + repr=False + ) + _function_toolset: FunctionToolset[AgentDepsT] = dataclasses.field(repr=False) + _output_toolset: OutputToolset[AgentDepsT] | None = dataclasses.field(repr=False) + _user_toolsets: Sequence[AbstractToolset[AgentDepsT]] = dataclasses.field(repr=False) + _prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False) + _prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False) + _max_result_retries: int = dataclasses.field(repr=False) + + _event_stream_handler: EventStreamHandler[AgentDepsT] | None = dataclasses.field(repr=False) + + _enter_lock: Lock = dataclasses.field(repr=False) + _entered_count: int = dataclasses.field(repr=False) + _exit_stack: AsyncExitStack | None = dataclasses.field(repr=False) + + @overload + def __init__( + self, + model: models.Model | models.KnownModelName | str | None = None, + *, + output_type: OutputSpec[OutputDataT] = str, + instructions: str + | _system_prompt.SystemPromptFunc[AgentDepsT] + | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]] + | None = None, + system_prompt: str | Sequence[str] = (), + deps_type: type[AgentDepsT] = NoneType, + name: str | None = None, + model_settings: ModelSettings | None = None, + retries: int = 1, + output_retries: int | None = None, + tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), + prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, + prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + defer_model_check: bool = False, + end_strategy: EndStrategy = 'early', + instrument: InstrumentationSettings | bool | None = None, + history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + ) -> None: ... + + @overload + @deprecated( + '`result_type`, `result_tool_name` & `result_tool_description` are deprecated, use `output_type` instead. `result_retries` is deprecated, use `output_retries` instead.' + ) + def __init__( + self, + model: models.Model | models.KnownModelName | str | None = None, + *, + result_type: type[OutputDataT] = str, + instructions: str + | _system_prompt.SystemPromptFunc[AgentDepsT] + | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]] + | None = None, + system_prompt: str | Sequence[str] = (), + deps_type: type[AgentDepsT] = NoneType, + name: str | None = None, + model_settings: ModelSettings | None = None, + retries: int = 1, + result_tool_name: str = _output.DEFAULT_OUTPUT_TOOL_NAME, + result_tool_description: str | None = None, + result_retries: int | None = None, + tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), + prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, + prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + defer_model_check: bool = False, + end_strategy: EndStrategy = 'early', + instrument: InstrumentationSettings | bool | None = None, + history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, + ) -> None: ... + + @overload + @deprecated('`mcp_servers` is deprecated, use `toolsets` instead.') + def __init__( + self, + model: models.Model | models.KnownModelName | str | None = None, + *, + result_type: type[OutputDataT] = str, + instructions: str + | _system_prompt.SystemPromptFunc[AgentDepsT] + | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]] + | None = None, + system_prompt: str | Sequence[str] = (), + deps_type: type[AgentDepsT] = NoneType, + name: str | None = None, + model_settings: ModelSettings | None = None, + retries: int = 1, + result_tool_name: str = _output.DEFAULT_OUTPUT_TOOL_NAME, + result_tool_description: str | None = None, + result_retries: int | None = None, + tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), + prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, + prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None, + mcp_servers: Sequence[MCPServer] = (), + defer_model_check: bool = False, + end_strategy: EndStrategy = 'early', + instrument: InstrumentationSettings | bool | None = None, + history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + ) -> None: ... + + def __init__( + self, + model: models.Model | models.KnownModelName | str | None = None, + *, + # TODO change this back to `output_type: _output.OutputType[OutputDataT] = str,` when we remove the overloads + output_type: Any = str, + instructions: str + | _system_prompt.SystemPromptFunc[AgentDepsT] + | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]] + | None = None, + system_prompt: str | Sequence[str] = (), + deps_type: type[AgentDepsT] = NoneType, + name: str | None = None, + model_settings: ModelSettings | None = None, + retries: int = 1, + output_retries: int | None = None, + tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), + prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, + prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + defer_model_check: bool = False, + end_strategy: EndStrategy = 'early', + instrument: InstrumentationSettings | bool | None = None, + history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + **_deprecated_kwargs: Any, + ): + """Create an agent. + + Args: + model: The default model to use for this agent, if not provide, + you must provide the model when calling it. We allow `str` here since the actual list of allowed models changes frequently. + output_type: The type of the output data, used to validate the data returned by the model, + defaults to `str`. + instructions: Instructions to use for this agent, you can also register instructions via a function with + [`instructions`][pydantic_ai.Agent.instructions]. + system_prompt: Static system prompts to use for this agent, you can also register system + prompts via a function with [`system_prompt`][pydantic_ai.Agent.system_prompt]. + deps_type: The type used for dependency injection, this parameter exists solely to allow you to fully + parameterize the agent, and therefore get the best out of static type checking. + If you're not using deps, but want type checking to pass, you can set `deps=None` to satisfy Pyright + or add a type hint `: Agent[None, ]`. + name: The name of the agent, used for logging. If `None`, we try to infer the agent name from the call frame + when the agent is first run. + model_settings: Optional model request settings to use for this agent's runs, by default. + retries: The default number of retries to allow before raising an error. + output_retries: The maximum number of retries to allow for output validation, defaults to `retries`. + tools: Tools to register with the agent, you can also register tools via the decorators + [`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain]. + prepare_tools: Custom function to prepare the tool definition of all tools for each step, except output tools. + This is useful if you want to customize the definition of multiple tools or you want to register + a subset of tools for a given step. See [`ToolsPrepareFunc`][pydantic_ai.tools.ToolsPrepareFunc] + prepare_output_tools: Custom function to prepare the tool definition of all output tools for each step. + This is useful if you want to customize the definition of multiple output tools or you want to register + a subset of output tools for a given step. See [`ToolsPrepareFunc`][pydantic_ai.tools.ToolsPrepareFunc] + toolsets: Toolsets to register with the agent, including MCP servers. + defer_model_check: by default, if you provide a [named][pydantic_ai.models.KnownModelName] model, + it's evaluated to create a [`Model`][pydantic_ai.models.Model] instance immediately, + which checks for the necessary environment variables. Set this to `false` + to defer the evaluation until the first run. Useful if you want to + [override the model][pydantic_ai.Agent.override] for testing. + end_strategy: Strategy for handling tool calls that are requested alongside a final result. + See [`EndStrategy`][pydantic_ai.agent.EndStrategy] for more information. + instrument: Set to True to automatically instrument with OpenTelemetry, + which will use Logfire if it's configured. + Set to an instance of [`InstrumentationSettings`][pydantic_ai.agent.InstrumentationSettings] to customize. + If this isn't set, then the last value set by + [`Agent.instrument_all()`][pydantic_ai.Agent.instrument_all] + will be used, which defaults to False. + See the [Debugging and Monitoring guide](https://ai.pydantic.dev/logfire/) for more info. + history_processors: Optional list of callables to process the message history before sending it to the model. + Each processor takes a list of messages and returns a modified list of messages. + Processors can be sync or async and are applied in sequence. + event_stream_handler: TODO: Optional handler for events from the agent stream. + """ + if model is None or defer_model_check: + self._model = model + else: + self._model = models.infer_model(model) + + self._name = name + self.end_strategy = end_strategy + self.model_settings = model_settings + + if 'result_type' in _deprecated_kwargs: + if output_type is not str: # pragma: no cover + raise TypeError('`result_type` and `output_type` cannot be set at the same time.') + warnings.warn('`result_type` is deprecated, use `output_type` instead', DeprecationWarning, stacklevel=2) + output_type = _deprecated_kwargs.pop('result_type') + + self._output_type = output_type + + self.instrument = instrument + + self._deps_type = deps_type + + self._deprecated_result_tool_name = _deprecated_kwargs.pop('result_tool_name', None) + if self._deprecated_result_tool_name is not None: + warnings.warn( + '`result_tool_name` is deprecated, use `output_type` with `ToolOutput` instead', + DeprecationWarning, + stacklevel=2, + ) + + self._deprecated_result_tool_description = _deprecated_kwargs.pop('result_tool_description', None) + if self._deprecated_result_tool_description is not None: + warnings.warn( + '`result_tool_description` is deprecated, use `output_type` with `ToolOutput` instead', + DeprecationWarning, + stacklevel=2, + ) + result_retries = _deprecated_kwargs.pop('result_retries', None) + if result_retries is not None: + if output_retries is not None: # pragma: no cover + raise TypeError('`output_retries` and `result_retries` cannot be set at the same time.') + warnings.warn( + '`result_retries` is deprecated, use `max_result_retries` instead', DeprecationWarning, stacklevel=2 + ) + output_retries = result_retries + + if mcp_servers := _deprecated_kwargs.pop('mcp_servers', None): + if toolsets is not None: # pragma: no cover + raise TypeError('`mcp_servers` and `toolsets` cannot be set at the same time.') + warnings.warn('`mcp_servers` is deprecated, use `toolsets` instead', DeprecationWarning) + toolsets = mcp_servers + + _utils.validate_empty_kwargs(_deprecated_kwargs) + + default_output_mode = ( + self.model.profile.default_structured_output_mode if isinstance(self.model, models.Model) else None ) - start_node = _agent_graph.UserPromptNode[AgentDepsT]( - user_prompt=user_prompt, - instructions=self._instructions, - instructions_functions=self._instructions_functions, - system_prompts=self._system_prompts, - system_prompt_functions=self._system_prompt_functions, - system_prompt_dynamic_functions=self._system_prompt_dynamic_functions, + + self._output_schema = _output.OutputSchema[OutputDataT].build( + output_type, + default_mode=default_output_mode, + name=self._deprecated_result_tool_name, + description=self._deprecated_result_tool_description, ) + self._output_validators = [] - try: - async with graph.iter( - start_node, - state=state, - deps=graph_deps, - span=use_span(run_span) if run_span.is_recording() else None, - infer_name=False, - ) as graph_run: - agent_run = AgentRun(graph_run) - yield agent_run - if (final_result := agent_run.result) is not None and run_span.is_recording(): - if instrumentation_settings and instrumentation_settings.include_content: - run_span.set_attribute( - 'final_result', - ( - final_result.output - if isinstance(final_result.output, str) - else json.dumps(InstrumentedModel.serialize_any(final_result.output)) - ), - ) - finally: - try: - if instrumentation_settings and run_span.is_recording(): - run_span.set_attributes(self._run_span_end_attributes(state, usage, instrumentation_settings)) - finally: - run_span.end() + self._instructions = '' + self._instructions_functions = [] + if isinstance(instructions, (str, Callable)): + instructions = [instructions] + for instruction in instructions or []: + if isinstance(instruction, str): + self._instructions += instruction + '\n' + else: + self._instructions_functions.append(_system_prompt.SystemPromptRunner(instruction)) + self._instructions = self._instructions.strip() or None - def _run_span_end_attributes( - self, state: _agent_graph.GraphAgentState, usage: _usage.Usage, settings: InstrumentationSettings - ): - return { - **usage.opentelemetry_attributes(), - 'all_messages_events': json.dumps( - [InstrumentedModel.event_to_dict(e) for e in settings.messages_to_otel_events(state.message_history)] - ), - 'logfire.json_schema': json.dumps( - { - 'type': 'object', - 'properties': { - 'all_messages_events': {'type': 'array'}, - 'final_result': {'type': 'object'}, - }, - } - ), - } + self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt) + self._system_prompt_functions = [] + self._system_prompt_dynamic_functions = {} - @overload - def run_sync( - self, - user_prompt: str | Sequence[_messages.UserContent] | None = None, - *, - message_history: list[_messages.ModelMessage] | None = None, - model: models.Model | models.KnownModelName | str | None = None, - deps: AgentDepsT = None, - model_settings: ModelSettings | None = None, - usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, - infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - ) -> AgentRunResult[OutputDataT]: ... + self._max_result_retries = output_retries if output_retries is not None else retries + self._prepare_tools = prepare_tools + self._prepare_output_tools = prepare_output_tools - @overload - def run_sync( - self, - user_prompt: str | Sequence[_messages.UserContent] | None = None, - *, - output_type: OutputSpec[RunOutputDataT] | None = None, - message_history: list[_messages.ModelMessage] | None = None, - model: models.Model | models.KnownModelName | str | None = None, - deps: AgentDepsT = None, - model_settings: ModelSettings | None = None, - usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, - infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - ) -> AgentRunResult[RunOutputDataT]: ... + self._output_toolset = self._output_schema.toolset + if self._output_toolset: + self._output_toolset.max_retries = self._max_result_retries - @overload - @deprecated('`result_type` is deprecated, use `output_type` instead.') - def run_sync( - self, - user_prompt: str | Sequence[_messages.UserContent] | None = None, - *, - result_type: type[RunOutputDataT], - message_history: list[_messages.ModelMessage] | None = None, - model: models.Model | models.KnownModelName | str | None = None, - deps: AgentDepsT = None, - model_settings: ModelSettings | None = None, - usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, - infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - ) -> AgentRunResult[RunOutputDataT]: ... + self._function_toolset = FunctionToolset(tools, max_retries=retries, id='agent') + self._user_toolsets = toolsets or () - def run_sync( - self, - user_prompt: str | Sequence[_messages.UserContent] | None = None, - *, - output_type: OutputSpec[RunOutputDataT] | None = None, - message_history: list[_messages.ModelMessage] | None = None, - model: models.Model | models.KnownModelName | str | None = None, - deps: AgentDepsT = None, - model_settings: ModelSettings | None = None, - usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, - infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - **_deprecated_kwargs: Never, - ) -> AgentRunResult[Any]: - """Synchronously run the agent with a user prompt. + self.history_processors = history_processors or [] - This is a convenience method that wraps [`self.run`][pydantic_ai.Agent.run] with `loop.run_until_complete(...)`. - You therefore can't use this method inside async code or if there's an active event loop. + self._event_stream_handler = event_stream_handler - Example: - ```python - from pydantic_ai import Agent + self._override_deps: ContextVar[_utils.Option[AgentDepsT]] = ContextVar('_override_deps', default=None) + self._override_model: ContextVar[_utils.Option[models.Model]] = ContextVar('_override_model', default=None) + self._override_toolsets: ContextVar[_utils.Option[Sequence[AbstractToolset[AgentDepsT]]]] = ContextVar( + '_override_toolsets', default=None + ) + self._override_tools: ContextVar[ + _utils.Option[Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]]] + ] = ContextVar('_override_tools', default=None) - agent = Agent('openai:gpt-4o') + self._enter_lock = _utils.get_async_lock() + self._entered_count = 0 + self._exit_stack = None - result_sync = agent.run_sync('What is the capital of Italy?') - print(result_sync.output) - #> Rome - ``` + @staticmethod + def instrument_all(instrument: InstrumentationSettings | bool = True) -> None: + """Set the instrumentation options for all agents where `instrument` is not set.""" + Agent._instrument_default = instrument - Args: - user_prompt: User input to start/continue the conversation. - output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no - output validators since output validators would expect an argument that matches the agent's output type. - message_history: History of the conversation so far. - model: Optional model to use for this run, required if `model` was not set when creating the agent. - deps: Optional dependencies to use for this run. - model_settings: Optional settings to use for this model's request. - usage_limits: Optional limits on model request count or token usage. - usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. - infer_name: Whether to try to infer the agent name from the call frame if it's not set. - toolsets: Optional additional toolsets for this run. + @property + def model(self) -> models.Model | models.KnownModelName | str | None: + return self._model - Returns: - The result of the run. - """ - if infer_name and self.name is None: - self._infer_name(inspect.currentframe()) + @model.setter + def model(self, value: models.Model | models.KnownModelName | str | None) -> None: + self._model = value - if 'result_type' in _deprecated_kwargs: # pragma: no cover - if output_type is not str: - raise TypeError('`result_type` and `output_type` cannot be set at the same time.') - warnings.warn('`result_type` is deprecated, use `output_type` instead.', DeprecationWarning, stacklevel=2) - output_type = _deprecated_kwargs.pop('result_type') + @property + def name(self) -> str | None: + return self._name - _utils.validate_empty_kwargs(_deprecated_kwargs) + @name.setter + def name(self, value: str | None) -> None: + self._name = value - return get_event_loop().run_until_complete( - self.run( - user_prompt, - output_type=output_type, - message_history=message_history, - model=model, - deps=deps, - model_settings=model_settings, - usage_limits=usage_limits, - usage=usage, - infer_name=False, - toolsets=toolsets, - ) - ) + @property + def output_type(self) -> OutputSpec[OutputDataT]: + return self._output_type + + @property + def event_stream_handler(self) -> EventStreamHandler[AgentDepsT] | None: + return self._event_stream_handler + + def __repr__(self) -> str: # pragma: no cover + return f'{type(self).__name__}(model={self.model!r}, name={self.name!r}, end_strategy={self.end_strategy!r}, model_settings={self.model_settings!r}, output_type={self.output_type!r}, instrument={self.instrument!r})' @overload - def run_stream( + def iter( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, + output_type: None = None, message_history: list[_messages.ModelMessage] | None = None, - model: models.Model | models.KnownModelName | None = None, + model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, OutputDataT]]: ... + **_deprecated_kwargs: Never, + ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, OutputDataT]]: ... @overload - def run_stream( + def iter( self, - user_prompt: str | Sequence[_messages.UserContent], + user_prompt: str | Sequence[_messages.UserContent] | None = None, *, output_type: OutputSpec[RunOutputDataT], message_history: list[_messages.ModelMessage] | None = None, @@ -1028,11 +1628,12 @@ def run_stream( usage: _usage.Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunOutputDataT]]: ... + **_deprecated_kwargs: Never, + ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ... @overload @deprecated('`result_type` is deprecated, use `output_type` instead.') - def run_stream( + def iter( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, @@ -1045,10 +1646,10 @@ def run_stream( usage: _usage.Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunOutputDataT]]: ... + ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, Any]]: ... @asynccontextmanager - async def run_stream( # noqa C901 + async def iter( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, @@ -1062,8 +1663,18 @@ async def run_stream( # noqa C901 infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, **_deprecated_kwargs: Never, - ) -> AsyncIterator[result.StreamedRunResult[AgentDepsT, Any]]: - """Run the agent with a user prompt in async mode, returning a streamed response. + ) -> AsyncIterator[AgentRun[AgentDepsT, Any]]: + """A contextmanager which can be used to iterate over the agent graph's nodes as they are executed. + + This method builds an internal agent graph (using system prompts, tools and output schemas) and then returns an + `AgentRun` object. The `AgentRun` can be used to async-iterate over the nodes of the graph as they are + executed. This is the API to use if you want to consume the outputs coming from each LLM model response, or the + stream of events coming from the execution of tools. + + The `AgentRun` also provides methods to access the full message history, new messages, and usage statistics, + and the final result of the run once it has completed. + + For more details, see the documentation of `AgentRun`. Example: ```python @@ -1072,9 +1683,46 @@ async def run_stream( # noqa C901 agent = Agent('openai:gpt-4o') async def main(): - async with agent.run_stream('What is the capital of the UK?') as response: - print(await response.get_output()) - #> London + nodes = [] + async with agent.iter('What is the capital of France?') as agent_run: + async for node in agent_run: + nodes.append(node) + print(nodes) + ''' + [ + UserPromptNode( + user_prompt='What is the capital of France?', + instructions=None, + instructions_functions=[], + system_prompts=(), + system_prompt_functions=[], + system_prompt_dynamic_functions={}, + ), + ModelRequestNode( + request=ModelRequest( + parts=[ + UserPromptPart( + content='What is the capital of France?', + timestamp=datetime.datetime(...), + ) + ] + ) + ), + CallToolsNode( + model_response=ModelResponse( + parts=[TextPart(content='Paris')], + usage=Usage( + requests=1, request_tokens=56, response_tokens=1, total_tokens=57 + ), + model_name='gpt-4o', + timestamp=datetime.datetime(...), + ) + ), + End(data=FinalResult(output='Paris')), + ] + ''' + print(agent_run.result.output) + #> Paris ``` Args: @@ -1093,12 +1741,10 @@ async def main(): Returns: The result of the run. """ - # TODO: We need to deprecate this now that we have the `iter` method. - # Before that, though, we should add an event for when we reach the final result of the stream. if infer_name and self.name is None: - # f_back because `asynccontextmanager` adds one frame - if frame := inspect.currentframe(): # pragma: no branch - self._infer_name(frame.f_back) + self._infer_name(inspect.currentframe()) + model_used = self._get_model(model) + del model if 'result_type' in _deprecated_kwargs: # pragma: no cover if output_type is not str: @@ -1108,81 +1754,161 @@ async def main(): _utils.validate_empty_kwargs(_deprecated_kwargs) - yielded = False - async with self.iter( - user_prompt, - output_type=output_type, - message_history=message_history, - model=model, + deps = self._get_deps(deps) + new_message_index = len(message_history) if message_history else 0 + output_schema = self._prepare_output_schema(output_type, model_used.profile) + + output_type_ = output_type or self.output_type + + # We consider it a user error if a user tries to restrict the result type while having an output validator that + # may change the result type from the restricted type to something else. Therefore, we consider the following + # typecast reasonable, even though it is possible to violate it with otherwise-type-checked code. + output_validators = cast(list[_output.OutputValidator[AgentDepsT, RunOutputDataT]], self._output_validators) + + output_toolset = self._output_toolset + if output_schema != self._output_schema or output_validators: + output_toolset = cast(OutputToolset[AgentDepsT], output_schema.toolset) + if output_toolset: + output_toolset.max_retries = self._max_result_retries + output_toolset.output_validators = output_validators + + # Build the graph + graph: Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[Any]] = ( + _agent_graph.build_agent_graph(self.name, self._deps_type, output_type_) + ) + + # Build the initial state + usage = usage or _usage.Usage() + state = _agent_graph.GraphAgentState( + message_history=message_history[:] if message_history else [], + usage=usage, + retries=0, + run_step=0, + ) + + if isinstance(model_used, InstrumentedModel): + instrumentation_settings = model_used.instrumentation_settings + tracer = model_used.instrumentation_settings.tracer + else: + instrumentation_settings = None + tracer = NoOpTracer() + + run_context = RunContext[AgentDepsT]( deps=deps, - model_settings=model_settings, - usage_limits=usage_limits, + model=model_used, usage=usage, - infer_name=False, - toolsets=toolsets, - ) as agent_run: - first_node = agent_run.next_node # start with the first node - assert isinstance(first_node, _agent_graph.UserPromptNode) # the first node should be a user prompt node - node = first_node - while True: - if self.is_model_request_node(node): - graph_ctx = agent_run.ctx - async with node.stream(graph_ctx) as stream: + prompt=user_prompt, + messages=state.message_history, + tracer=tracer, + trace_include_content=instrumentation_settings is not None and instrumentation_settings.include_content, + run_step=state.run_step, + ) - async def stream_to_final(s: AgentStream) -> FinalResult[AgentStream] | None: - async for event in stream: - if isinstance(event, _messages.FinalResultEvent): - return FinalResult(s, event.tool_name, event.tool_call_id) - return None + toolset = self._get_toolset(output_toolset=output_toolset, additional_toolsets=toolsets) + # This will raise errors for any name conflicts + tool_manager = await ToolManager[AgentDepsT].build(toolset, run_context) - final_result = await stream_to_final(stream) - if final_result is not None: - if yielded: - raise exceptions.AgentRunError('Agent run produced final results') # pragma: no cover - yielded = True + # Merge model settings in order of precedence: run > agent > model + merged_settings = merge_model_settings(model_used.settings, self.model_settings) + model_settings = merge_model_settings(merged_settings, model_settings) + usage_limits = usage_limits or _usage.UsageLimits() + agent_name = self.name or 'agent' + run_span = tracer.start_span( + 'agent run', + attributes={ + 'model_name': model_used.model_name if model_used else 'no-model', + 'agent_name': agent_name, + 'logfire.msg': f'{agent_name} run', + }, + ) - messages = graph_ctx.state.message_history.copy() + async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: + parts = [ + self._instructions, + *[await func.run(run_context) for func in self._instructions_functions], + ] - async def on_complete() -> None: - """Called when the stream has completed. + model_profile = model_used.profile + if isinstance(output_schema, _output.PromptedOutputSchema): + instructions = output_schema.instructions(model_profile.prompted_output_template) + parts.append(instructions) - The model response will have been added to messages by now - by `StreamedRunResult._marked_completed`. - """ - last_message = messages[-1] - assert isinstance(last_message, _messages.ModelResponse) - tool_calls = [ - part for part in last_message.parts if isinstance(part, _messages.ToolCallPart) - ] + parts = [p for p in parts if p] + if not parts: + return None + return '\n\n'.join(parts).strip() - parts: list[_messages.ModelRequestPart] = [] - async for _event in _agent_graph.process_function_tools( - graph_ctx.deps.tool_manager, - tool_calls, - final_result, - graph_ctx, - parts, - ): - pass - if parts: - messages.append(_messages.ModelRequest(parts)) + graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT]( + user_deps=deps, + prompt=user_prompt, + new_message_index=new_message_index, + model=model_used, + model_settings=model_settings, + usage_limits=usage_limits, + max_result_retries=self._max_result_retries, + end_strategy=self.end_strategy, + output_schema=output_schema, + output_validators=output_validators, + history_processors=self.history_processors, + tool_manager=tool_manager, + tracer=tracer, + get_instructions=get_instructions, + instrumentation_settings=instrumentation_settings, + ) + start_node = _agent_graph.UserPromptNode[AgentDepsT]( + user_prompt=user_prompt, + instructions=self._instructions, + instructions_functions=self._instructions_functions, + system_prompts=self._system_prompts, + system_prompt_functions=self._system_prompt_functions, + system_prompt_dynamic_functions=self._system_prompt_dynamic_functions, + ) - yield StreamedRunResult( - messages, - graph_ctx.deps.new_message_index, - stream, - on_complete, - ) - break - next_node = await agent_run.next(node) - if not isinstance(next_node, _agent_graph.AgentNode): - raise exceptions.AgentRunError( # pragma: no cover - 'Should have produced a StreamedRunResult before getting here' - ) - node = cast(_agent_graph.AgentNode[Any, Any], next_node) + try: + async with graph.iter( + start_node, + state=state, + deps=graph_deps, + span=use_span(run_span) if run_span.is_recording() else None, + infer_name=False, + ) as graph_run: + agent_run = AgentRun(graph_run) + yield agent_run + if (final_result := agent_run.result) is not None and run_span.is_recording(): + if instrumentation_settings and instrumentation_settings.include_content: + run_span.set_attribute( + 'final_result', + ( + final_result.output + if isinstance(final_result.output, str) + else json.dumps(InstrumentedModel.serialize_any(final_result.output)) + ), + ) + finally: + try: + if instrumentation_settings and run_span.is_recording(): + run_span.set_attributes(self._run_span_end_attributes(state, usage, instrumentation_settings)) + finally: + run_span.end() - if not yielded: - raise exceptions.AgentRunError('Agent run finished without producing a final result') # pragma: no cover + def _run_span_end_attributes( + self, state: _agent_graph.GraphAgentState, usage: _usage.Usage, settings: InstrumentationSettings + ): + return { + **usage.opentelemetry_attributes(), + 'all_messages_events': json.dumps( + [InstrumentedModel.event_to_dict(e) for e in settings.messages_to_otel_events(state.message_history)] + ), + 'logfire.json_schema': json.dumps( + { + 'type': 'object', + 'properties': { + 'all_messages_events': {'type': 'array'}, + 'final_result': {'type': 'object'}, + }, + } + ), + } @contextmanager def override( @@ -1191,6 +1917,7 @@ def override( deps: AgentDepsT | _utils.Unset = _utils.UNSET, model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET, toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET, + tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET, ) -> Iterator[None]: """Context manager to temporarily override agent dependencies, model, or toolsets. @@ -1201,6 +1928,7 @@ def override( deps: The dependencies to use instead of the dependencies passed to the agent run. model: The model to use instead of the model passed to the agent run. toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run. + tools: The tools to use instead of the tools registered with the agent. """ if _utils.is_set(deps): deps_token = self._override_deps.set(_utils.Some(deps)) @@ -1215,7 +1943,12 @@ def override( if _utils.is_set(toolsets): toolsets_token = self._override_toolsets.set(_utils.Some(toolsets)) else: - toolsets_token = None + toolsets_token = None + + if _utils.is_set(tools): + tools_token = self._override_tools.set(_utils.Some(tools)) + else: + tools_token = None try: yield @@ -1226,6 +1959,8 @@ def override( self._override_model.reset(model_token) if toolsets_token is not None: self._override_toolsets.reset(toolsets_token) + if tools_token is not None: + self._override_tools.reset(tools_token) @overload def instructions( @@ -1678,6 +2413,13 @@ def _get_toolset( output_toolset: The output toolset to use instead of the one built at agent construction time. additional_toolsets: Additional toolsets to add. """ + if some_tools := self._override_tools.get(): + function_toolset = FunctionToolset( + some_tools.value, max_retries=self._function_toolset.max_retries, id='agent' + ) + else: + function_toolset = self._function_toolset + if some_user_toolsets := self._override_toolsets.get(): user_toolsets = some_user_toolsets.value elif additional_toolsets is not None: @@ -1685,7 +2427,7 @@ def _get_toolset( else: user_toolsets = self._user_toolsets - all_toolsets = [self._function_toolset, *user_toolsets] + all_toolsets = [function_toolset, *user_toolsets] if self._prepare_tools: all_toolsets = [PreparedToolset(CombinedToolset(all_toolsets), self._prepare_tools)] @@ -1698,24 +2440,13 @@ def _get_toolset( return CombinedToolset(all_toolsets) - def _infer_name(self, function_frame: FrameType | None) -> None: - """Infer the agent name from the call frame. + @property + def toolset(self) -> AbstractToolset[AgentDepsT]: + """The complete toolset that will be available to the model during an agent run. - Usage should be `self._infer_name(inspect.currentframe())`. + This will include function tools registered directly to the agent, output tools, and user-provided toolsets including MCP servers. """ - assert self.name is None, 'Name already set' - if function_frame is not None: # pragma: no branch - if parent_frame := function_frame.f_back: # pragma: no branch - for name, item in parent_frame.f_locals.items(): - if item is self: - self.name = name - return - if parent_frame.f_locals != parent_frame.f_globals: # pragma: no branch - # if we couldn't find the agent in locals and globals are a different dict, try globals - for name, item in parent_frame.f_globals.items(): - if item is self: - self.name = name - return + return self._get_toolset() @property @deprecated( @@ -1743,47 +2474,7 @@ def _prepare_output_schema( return schema # pyright: ignore[reportReturnType] - @staticmethod - def is_model_request_node( - node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]], - ) -> TypeIs[_agent_graph.ModelRequestNode[T, S]]: - """Check if the node is a `ModelRequestNode`, narrowing the type if it is. - - This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`. - """ - return isinstance(node, _agent_graph.ModelRequestNode) - - @staticmethod - def is_call_tools_node( - node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]], - ) -> TypeIs[_agent_graph.CallToolsNode[T, S]]: - """Check if the node is a `CallToolsNode`, narrowing the type if it is. - - This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`. - """ - return isinstance(node, _agent_graph.CallToolsNode) - - @staticmethod - def is_user_prompt_node( - node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]], - ) -> TypeIs[_agent_graph.UserPromptNode[T, S]]: - """Check if the node is a `UserPromptNode`, narrowing the type if it is. - - This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`. - """ - return isinstance(node, _agent_graph.UserPromptNode) - - @staticmethod - def is_end_node( - node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]], - ) -> TypeIs[End[result.FinalResult[S]]]: - """Check if the node is a `End`, narrowing the type if it is. - - This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`. - """ - return isinstance(node, End) - - async def __aenter__(self) -> Self: + async def __aenter__(self) -> AbstractAgent[AgentDepsT, OutputDataT]: """Enter the agent context. This will start all [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] registered as `toolsets` so they are ready to be used. @@ -1846,201 +2537,6 @@ async def run_mcp_servers( async with self: yield - def to_ag_ui( - self, - *, - # Agent.iter parameters - output_type: OutputSpec[OutputDataT] | None = None, - model: models.Model | models.KnownModelName | str | None = None, - deps: AgentDepsT = None, - model_settings: ModelSettings | None = None, - usage_limits: UsageLimits | None = None, - usage: Usage | None = None, - infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - # Starlette - debug: bool = False, - routes: Sequence[BaseRoute] | None = None, - middleware: Sequence[Middleware] | None = None, - exception_handlers: Mapping[Any, ExceptionHandler] | None = None, - on_startup: Sequence[Callable[[], Any]] | None = None, - on_shutdown: Sequence[Callable[[], Any]] | None = None, - lifespan: Lifespan[AGUIApp[AgentDepsT, OutputDataT]] | None = None, - ) -> AGUIApp[AgentDepsT, OutputDataT]: - """Convert the agent to an AG-UI application. - - This allows you to use the agent with a compatible AG-UI frontend. - - Example: - ```python - from pydantic_ai import Agent - - agent = Agent('openai:gpt-4o') - app = agent.to_ag_ui() - ``` - - The `app` is an ASGI application that can be used with any ASGI server. - - To run the application, you can use the following command: - - ```bash - uvicorn app:app --host 0.0.0.0 --port 8000 - ``` - - See [AG-UI docs](../ag-ui.md) for more information. - - Args: - output_type: Custom output type to use for this run, `output_type` may only be used if the agent has - no output validators since output validators would expect an argument that matches the agent's - output type. - model: Optional model to use for this run, required if `model` was not set when creating the agent. - deps: Optional dependencies to use for this run. - model_settings: Optional settings to use for this model's request. - usage_limits: Optional limits on model request count or token usage. - usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. - infer_name: Whether to try to infer the agent name from the call frame if it's not set. - toolsets: Optional list of toolsets to use for this agent, defaults to the agent's toolset. - - debug: Boolean indicating if debug tracebacks should be returned on errors. - routes: A list of routes to serve incoming HTTP and WebSocket requests. - middleware: A list of middleware to run for every request. A starlette application will always - automatically include two middleware classes. `ServerErrorMiddleware` is added as the very - outermost middleware, to handle any uncaught errors occurring anywhere in the entire stack. - `ExceptionMiddleware` is added as the very innermost middleware, to deal with handled - exception cases occurring in the routing or endpoints. - exception_handlers: A mapping of either integer status codes, or exception class types onto - callables which handle the exceptions. Exception handler callables should be of the form - `handler(request, exc) -> response` and may be either standard functions, or async functions. - on_startup: A list of callables to run on application startup. Startup handler callables do not - take any arguments, and may be either standard functions, or async functions. - on_shutdown: A list of callables to run on application shutdown. Shutdown handler callables do - not take any arguments, and may be either standard functions, or async functions. - lifespan: A lifespan context function, which can be used to perform startup and shutdown tasks. - This is a newer style that replaces the `on_startup` and `on_shutdown` handlers. Use one or - the other, not both. - - Returns: - An ASGI application for running Pydantic AI agents with AG-UI protocol support. - """ - from .ag_ui import AGUIApp - - return AGUIApp( - agent=self, - # Agent.iter parameters - output_type=output_type, - model=model, - deps=deps, - model_settings=model_settings, - usage_limits=usage_limits, - usage=usage, - infer_name=infer_name, - toolsets=toolsets, - # Starlette - debug=debug, - routes=routes, - middleware=middleware, - exception_handlers=exception_handlers, - on_startup=on_startup, - on_shutdown=on_shutdown, - lifespan=lifespan, - ) - - def to_a2a( - self, - *, - storage: Storage | None = None, - broker: Broker | None = None, - # Agent card - name: str | None = None, - url: str = 'http://localhost:8000', - version: str = '1.0.0', - description: str | None = None, - provider: AgentProvider | None = None, - skills: list[Skill] | None = None, - # Starlette - debug: bool = False, - routes: Sequence[Route] | None = None, - middleware: Sequence[Middleware] | None = None, - exception_handlers: dict[Any, ExceptionHandler] | None = None, - lifespan: Lifespan[FastA2A] | None = None, - ) -> FastA2A: - """Convert the agent to a FastA2A application. - - Example: - ```python - from pydantic_ai import Agent - - agent = Agent('openai:gpt-4o') - app = agent.to_a2a() - ``` - - The `app` is an ASGI application that can be used with any ASGI server. - - To run the application, you can use the following command: - - ```bash - uvicorn app:app --host 0.0.0.0 --port 8000 - ``` - """ - from ._a2a import agent_to_a2a - - return agent_to_a2a( - self, - storage=storage, - broker=broker, - name=name, - url=url, - version=version, - description=description, - provider=provider, - skills=skills, - debug=debug, - routes=routes, - middleware=middleware, - exception_handlers=exception_handlers, - lifespan=lifespan, - ) - - async def to_cli(self: Self, deps: AgentDepsT = None, prog_name: str = 'pydantic-ai') -> None: - """Run the agent in a CLI chat interface. - - Args: - deps: The dependencies to pass to the agent. - prog_name: The name of the program to use for the CLI. Defaults to 'pydantic-ai'. - - Example: - ```python {title="agent_to_cli.py" test="skip"} - from pydantic_ai import Agent - - agent = Agent('openai:gpt-4o', instructions='You always respond in Italian.') - - async def main(): - await agent.to_cli() - ``` - """ - from rich.console import Console - - from pydantic_ai._cli import run_chat - - await run_chat(stream=True, agent=self, deps=deps, console=Console(), code_theme='monokai', prog_name=prog_name) - - def to_cli_sync(self: Self, deps: AgentDepsT = None, prog_name: str = 'pydantic-ai') -> None: - """Run the agent in a CLI chat interface with the non-async interface. - - Args: - deps: The dependencies to pass to the agent. - prog_name: The name of the program to use for the CLI. Defaults to 'pydantic-ai'. - - ```python {title="agent_to_cli_sync.py" test="skip"} - from pydantic_ai import Agent - - agent = Agent('openai:gpt-4o', instructions='You always respond in Italian.') - agent.to_cli_sync() - agent.to_cli_sync(prog_name='assistant') - ``` - """ - return get_event_loop().run_until_complete(self.to_cli(deps=deps, prog_name=prog_name)) - @dataclasses.dataclass(repr=False) class AgentRun(Generic[AgentDepsT, OutputDataT]): diff --git a/pydantic_ai_slim/pydantic_ai/ext/aci.py b/pydantic_ai_slim/pydantic_ai/ext/aci.py index 6cd43402a..ef686d134 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/aci.py +++ b/pydantic_ai_slim/pydantic_ai/ext/aci.py @@ -71,5 +71,7 @@ def implementation(*args: Any, **kwargs: Any) -> str: class ACIToolset(FunctionToolset): """A toolset that wraps ACI.dev tools.""" - def __init__(self, aci_functions: Sequence[str], linked_account_owner_id: str): - super().__init__([tool_from_aci(aci_function, linked_account_owner_id) for aci_function in aci_functions]) + def __init__(self, aci_functions: Sequence[str], linked_account_owner_id: str, id: str | None = None): + super().__init__( + [tool_from_aci(aci_function, linked_account_owner_id) for aci_function in aci_functions], id=id + ) diff --git a/pydantic_ai_slim/pydantic_ai/ext/langchain.py b/pydantic_ai_slim/pydantic_ai/ext/langchain.py index 3fb407938..3782c0b9d 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/langchain.py +++ b/pydantic_ai_slim/pydantic_ai/ext/langchain.py @@ -65,5 +65,5 @@ def proxy(*args: Any, **kwargs: Any) -> str: class LangChainToolset(FunctionToolset): """A toolset that wraps LangChain tools.""" - def __init__(self, tools: list[LangChainTool]): - super().__init__([tool_from_langchain(tool) for tool in tools]) + def __init__(self, tools: list[LangChainTool], id: str | None = None): + super().__init__([tool_from_langchain(tool) for tool in tools], id=id) diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/__init__.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/__init__.py new file mode 100644 index 000000000..cd7ba738c --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/__init__.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +import warnings +from collections.abc import Sequence +from dataclasses import replace +from typing import Any, Callable + +from temporalio.client import ClientConfig, Plugin as ClientPlugin +from temporalio.contrib.pydantic import PydanticPayloadConverter, pydantic_data_converter +from temporalio.converter import DefaultPayloadConverter +from temporalio.worker import Plugin as WorkerPlugin, WorkerConfig +from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner + +from ._agent import TemporalAgent +from ._logfire import LogfirePlugin +from ._run_context import TemporalRunContext, TemporalRunContextWithDeps + +__all__ = [ + 'TemporalRunContext', + 'TemporalRunContextWithDeps', + 'PydanticAIPlugin', + 'LogfirePlugin', + 'AgentPlugin', + 'TemporalAgent', +] + + +class PydanticAIPlugin(ClientPlugin, WorkerPlugin): + """Temporal client and worker plugin for Pydantic AI.""" + + def configure_client(self, config: ClientConfig) -> ClientConfig: + if (data_converter := config.get('data_converter')) and data_converter.payload_converter_class not in ( + DefaultPayloadConverter, + PydanticPayloadConverter, + ): + warnings.warn( + 'A non-default Temporal data converter was used which has been replaced with the Pydantic data converter.' + ) + + config['data_converter'] = pydantic_data_converter + return super().configure_client(config) + + def configure_worker(self, config: WorkerConfig) -> WorkerConfig: + runner = config.get('workflow_runner') # pyright: ignore[reportUnknownMemberType] + if isinstance(runner, SandboxedWorkflowRunner): + config['workflow_runner'] = replace( + runner, + restrictions=runner.restrictions.with_passthrough_modules( + 'pydantic_ai', + 'logfire', + # Imported inside `logfire._internal.json_encoder` when running `logfire.info` inside an activity with attributes to serialize + 'attrs', + # Imported inside `logfire._internal.json_schema` when running `logfire.info` inside an activity with attributes to serialize + 'numpy', + 'pandas', + ), + ) + return super().configure_worker(config) + + +class AgentPlugin(WorkerPlugin): + """Temporal worker plugin for a specific Pydantic AI agent.""" + + def __init__(self, agent: TemporalAgent[Any, Any]): + self.agent = agent + + def configure_worker(self, config: WorkerConfig) -> WorkerConfig: + activities: Sequence[Callable[..., Any]] = config.get('activities', []) # pyright: ignore[reportUnknownMemberType] + config['activities'] = [*activities, *self.agent.temporal_activities] + return super().configure_worker(config) diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py new file mode 100644 index 000000000..7ef9717a1 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py @@ -0,0 +1,330 @@ +from __future__ import annotations + +from collections.abc import AsyncIterator, Iterator, Sequence +from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager +from typing import Any, Callable, Literal, overload + +from temporalio import workflow +from temporalio.workflow import ActivityConfig +from typing_extensions import Never, deprecated + +from pydantic_ai import ( + _utils, + messages as _messages, + models, + usage as _usage, +) +from pydantic_ai._run_context import AgentDepsT +from pydantic_ai.agent import AbstractAgent, Agent, AgentRun, RunOutputDataT, WrapperAgent +from pydantic_ai.exceptions import UserError +from pydantic_ai.ext.temporal._run_context import TemporalRunContext +from pydantic_ai.models import Model +from pydantic_ai.output import OutputDataT, OutputSpec +from pydantic_ai.settings import ModelSettings +from pydantic_ai.tools import ( + Tool, + ToolFuncEither, +) +from pydantic_ai.toolsets import AbstractToolset + +from ._model import TemporalModel +from ._toolset import TemporalWrapperToolset, temporalize_toolset + + +class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]): + def __init__( + self, + wrapped: AbstractAgent[AgentDepsT, OutputDataT], + activity_config: ActivityConfig = {}, + toolset_activity_config: dict[str, ActivityConfig] = {}, + tool_activity_config: dict[str, dict[str, ActivityConfig | Literal[False]]] = {}, + run_context_type: type[TemporalRunContext] = TemporalRunContext, + temporalize_toolset_func: Callable[ + [ + AbstractToolset[Any], + ActivityConfig, + dict[str, ActivityConfig | Literal[False]], + type[TemporalRunContext], + ], + AbstractToolset[Any], + ] = temporalize_toolset, + ): + """Wrap an agent to make it compatible with Temporal. + + Args: + wrapped: The agent to wrap. + activity_config: The Temporal activity config to use. + toolset_activity_config: The Temporal activity config to use for specific toolsets identified by ID. + tool_activity_config: The Temporal activity config to use for specific tools identified by toolset ID and tool name. + run_context_type: The type of run context to use to serialize and deserialize the run context. + temporalize_toolset_func: The function to use to prepare the toolsets for Temporal. + """ + super().__init__(wrapped) + + # TODO: Make this work with any AbstractAgent + assert isinstance(wrapped, Agent) + agent = wrapped + + activities: list[Callable[..., Any]] = [] + if not isinstance(agent.model, Model): + raise UserError( + 'Model cannot be set at agent run time when using Temporal, it must be set at agent creation time.' + ) + + temporal_model = TemporalModel(agent.model, activity_config, agent.event_stream_handler, run_context_type) + activities.extend(temporal_model.temporal_activities) + + def temporalize_toolset(toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset[AgentDepsT]: + id = toolset.id + if not id: + raise UserError( + "Toolsets that implement their own tool calling need to have an ID in order to be used with Temporal. The ID will be used to identify the toolset's activities within the workflow." + ) + toolset = temporalize_toolset_func( + toolset, + activity_config | toolset_activity_config.get(id, {}), + tool_activity_config.get(id, {}), + run_context_type, + ) + if isinstance(toolset, TemporalWrapperToolset): + activities.extend(toolset.temporal_activities) + return toolset + + # TODO: Use public methods so others can replicate this + temporal_toolsets = [ + temporalize_toolset(toolset) + for toolset in [agent._function_toolset, *agent._user_toolsets] # pyright: ignore[reportPrivateUsage] + ] + + self._model = temporal_model + self._toolsets = temporal_toolsets + self._temporal_activities = activities + + @property + def model(self) -> Model: + return self._model + + @property + def temporal_activities(self) -> list[Callable[..., Any]]: + return self._temporal_activities + + @overload + def iter( + self, + user_prompt: str | Sequence[_messages.UserContent] | None = None, + *, + output_type: None = None, + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + **_deprecated_kwargs: Never, + ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, OutputDataT]]: ... + + @overload + def iter( + self, + user_prompt: str | Sequence[_messages.UserContent] | None = None, + *, + output_type: OutputSpec[RunOutputDataT], + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + **_deprecated_kwargs: Never, + ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ... + + @overload + @deprecated('`result_type` is deprecated, use `output_type` instead.') + def iter( + self, + user_prompt: str | Sequence[_messages.UserContent] | None = None, + *, + result_type: type[RunOutputDataT], + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, Any]]: ... + + @asynccontextmanager + async def iter( + self, + user_prompt: str | Sequence[_messages.UserContent] | None = None, + *, + output_type: OutputSpec[RunOutputDataT] | None = None, + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + **_deprecated_kwargs: Never, + ) -> AsyncIterator[AgentRun[AgentDepsT, Any]]: + """A contextmanager which can be used to iterate over the agent graph's nodes as they are executed. + + This method builds an internal agent graph (using system prompts, tools and output schemas) and then returns an + `AgentRun` object. The `AgentRun` can be used to async-iterate over the nodes of the graph as they are + executed. This is the API to use if you want to consume the outputs coming from each LLM model response, or the + stream of events coming from the execution of tools. + + The `AgentRun` also provides methods to access the full message history, new messages, and usage statistics, + and the final result of the run once it has completed. + + For more details, see the documentation of `AgentRun`. + + Example: + ```python + from pydantic_ai import Agent + + agent = Agent('openai:gpt-4o') + + async def main(): + nodes = [] + async with agent.iter('What is the capital of France?') as agent_run: + async for node in agent_run: + nodes.append(node) + print(nodes) + ''' + [ + UserPromptNode( + user_prompt='What is the capital of France?', + instructions=None, + instructions_functions=[], + system_prompts=(), + system_prompt_functions=[], + system_prompt_dynamic_functions={}, + ), + ModelRequestNode( + request=ModelRequest( + parts=[ + UserPromptPart( + content='What is the capital of France?', + timestamp=datetime.datetime(...), + ) + ] + ) + ), + CallToolsNode( + model_response=ModelResponse( + parts=[TextPart(content='Paris')], + usage=Usage( + requests=1, request_tokens=56, response_tokens=1, total_tokens=57 + ), + model_name='gpt-4o', + timestamp=datetime.datetime(...), + ) + ), + End(data=FinalResult(output='Paris')), + ] + ''' + print(agent_run.result.output) + #> Paris + ``` + + Args: + user_prompt: User input to start/continue the conversation. + output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no + output validators since output validators would expect an argument that matches the agent's output type. + message_history: History of the conversation so far. + model: Optional model to use for this run, required if `model` was not set when creating the agent. + deps: Optional dependencies to use for this run. + model_settings: Optional settings to use for this model's request. + usage_limits: Optional limits on model request count or token usage. + usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. + infer_name: Whether to try to infer the agent name from the call frame if it's not set. + toolsets: Optional additional toolsets for this run. + + Returns: + The result of the run. + """ + if not workflow.in_workflow(): + async with super().iter( + user_prompt=user_prompt, + output_type=output_type, + message_history=message_history, + model=model, + deps=deps, + model_settings=model_settings, + usage_limits=usage_limits, + usage=usage, + infer_name=infer_name, + toolsets=toolsets, + ) as result: + yield result + + if model is not None: + raise UserError( + 'Model cannot be set at agent run time when using Temporal, it must be set at agent creation time.' + ) + if toolsets is not None: + raise UserError( + 'Toolsets cannot be set at agent run time when using Temporal, it must be set at agent creation time.' + ) + + # We reset tools here as the temporalized function toolset is already in self._toolsets. + with super().override(model=self._model, toolsets=self._toolsets, tools=[]): + async with super().iter( + user_prompt=user_prompt, + output_type=output_type, + message_history=message_history, + model=model, + deps=deps, + model_settings=model_settings, + usage_limits=usage_limits, + usage=usage, + infer_name=infer_name, + toolsets=toolsets, + ) as result: + yield result + + @contextmanager + def override( + self, + *, + deps: AgentDepsT | _utils.Unset = _utils.UNSET, + model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET, + tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET, + ) -> Iterator[None]: + """Context manager to temporarily override agent dependencies, model, or toolsets. + + This is particularly useful when testing. + You can find an example of this [here](../testing.md#overriding-model-via-pytest-fixtures). + + Args: + deps: The dependencies to use instead of the dependencies passed to the agent run. + model: The model to use instead of the model passed to the agent run. + toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run. + tools: The tools to use instead of the tools registered with the agent. + """ + if workflow.in_workflow(): + if _utils.is_set(model): + raise UserError( + 'Model cannot be contextually overridden when using Temporal, it must be set at agent creation time.' + ) + if _utils.is_set(toolsets): + raise UserError( + 'Toolsets cannot be contextually overridden when using Temporal, they must be set at agent creation time.' + ) + if _utils.is_set(tools): + raise UserError( + 'Tools cannot be contextually overridden when using Temporal, they must be set at agent creation time.' + ) + + with super().override(deps=deps, model=model, toolsets=toolsets, tools=tools): + yield diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_function_toolset.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_function_toolset.py new file mode 100644 index 000000000..e46477541 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_function_toolset.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable, Literal + +from pydantic import ConfigDict, with_config +from temporalio import activity, workflow +from temporalio.workflow import ActivityConfig + +from pydantic_ai._run_context import RunContext +from pydantic_ai.exceptions import UserError +from pydantic_ai.toolsets import FunctionToolset, ToolsetTool +from pydantic_ai.toolsets.function import _FunctionToolsetTool # pyright: ignore[reportPrivateUsage] + +from ._run_context import TemporalRunContext +from ._toolset import TemporalWrapperToolset + + +@dataclass +@with_config(ConfigDict(arbitrary_types_allowed=True)) +class _CallToolParams: + name: str + tool_args: dict[str, Any] + serialized_run_context: Any + + +class TemporalFunctionToolset(TemporalWrapperToolset): + def __init__( + self, + toolset: FunctionToolset, + activity_config: ActivityConfig = {}, + tool_activity_config: dict[str, ActivityConfig | Literal[False]] = {}, + run_context_type: type[TemporalRunContext] = TemporalRunContext, + ): + super().__init__(toolset) + self.activity_config = activity_config + self.tool_activity_config = tool_activity_config + self.run_context_type = run_context_type + + id = toolset.id + assert id is not None + + @activity.defn(name=f'function_toolset__{id}__call_tool') + async def call_tool_activity(params: _CallToolParams) -> Any: + name = params.name + ctx = self.run_context_type.deserialize_run_context(params.serialized_run_context) + try: + tool = (await toolset.get_tools(ctx))[name] + except KeyError as e: + raise UserError( + f'Tool {name!r} not found in toolset {toolset.id!r}. ' + 'Removing or renaming tools during an agent run is not supported with Temporal.' + ) from e + + return await self.wrapped.call_tool(name, params.tool_args, ctx, tool) + + self.call_tool_activity = call_tool_activity + + @property + def wrapped_function_toolset(self) -> FunctionToolset: + assert isinstance(self.wrapped, FunctionToolset) + return self.wrapped + + @property + def temporal_activities(self) -> list[Callable[..., Any]]: + return [self.call_tool_activity] + + async def call_tool(self, name: str, tool_args: dict[str, Any], ctx: RunContext, tool: ToolsetTool) -> Any: + if not workflow.in_workflow(): + return await super().call_tool(name, tool_args, ctx, tool) + + tool_activity_config = self.tool_activity_config.get(name, {}) + if tool_activity_config is False: + assert isinstance(tool, _FunctionToolsetTool) + if not tool.is_async: + raise UserError( + 'Disabling running a non-async tool in a Temporal activity is not possible. Make the tool function async instead.' + ) + return await super().call_tool(name, tool_args, ctx, tool) + + tool_activity_config = self.activity_config | tool_activity_config + serialized_run_context = self.run_context_type.serialize_run_context(ctx) + return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] + activity=self.call_tool_activity, + arg=_CallToolParams(name=name, tool_args=tool_args, serialized_run_context=serialized_run_context), + **tool_activity_config, + ) diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_logfire.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_logfire.py new file mode 100644 index 000000000..bb307b990 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_logfire.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from typing import Callable + +from opentelemetry.trace import get_tracer +from temporalio.client import ClientConfig, Plugin as ClientPlugin +from temporalio.contrib.opentelemetry import TracingInterceptor +from temporalio.runtime import OpenTelemetryConfig, Runtime, TelemetryConfig +from temporalio.service import ConnectConfig, ServiceClient + + +def _default_setup_logfire(): + import logfire + + logfire.configure(console=False) + logfire.instrument_pydantic_ai() + + +class LogfirePlugin(ClientPlugin): + """Temporal client plugin for Logfire.""" + + def __init__(self, setup_logfire: Callable[[], None] = _default_setup_logfire): + self.setup_logfire = setup_logfire + + def configure_client(self, config: ClientConfig) -> ClientConfig: + interceptors = config.get('interceptors', []) + config['interceptors'] = [*interceptors, TracingInterceptor(get_tracer('temporal'))] + return super().configure_client(config) + + async def connect_service_client(self, config: ConnectConfig) -> ServiceClient: + self.setup_logfire() + + config.runtime = Runtime(telemetry=TelemetryConfig(metrics=OpenTelemetryConfig(url='http://localhost:4318'))) + return await super().connect_service_client(config) diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py new file mode 100644 index 000000000..b89f66ba3 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable, Literal + +from pydantic import ConfigDict, with_config +from temporalio import activity, workflow +from temporalio.workflow import ActivityConfig + +from pydantic_ai._run_context import RunContext +from pydantic_ai.exceptions import UserError +from pydantic_ai.mcp import MCPServer, ToolResult +from pydantic_ai.tools import ToolDefinition +from pydantic_ai.toolsets.abstract import ToolsetTool + +from ._run_context import TemporalRunContext +from ._toolset import TemporalWrapperToolset + + +@dataclass +@with_config(ConfigDict(arbitrary_types_allowed=True)) +class _GetToolsParams: + serialized_run_context: Any + + +@dataclass +@with_config(ConfigDict(arbitrary_types_allowed=True)) +class _CallToolParams: + name: str + tool_args: dict[str, Any] + serialized_run_context: Any + tool_def: ToolDefinition + + +class TemporalMCPServer(TemporalWrapperToolset): + def __init__( + self, + server: MCPServer, + activity_config: ActivityConfig = {}, + tool_activity_config: dict[str, ActivityConfig | Literal[False]] = {}, + run_context_type: type[TemporalRunContext] = TemporalRunContext, + ): + super().__init__(server) + self.activity_config = activity_config + self.tool_activity_config = tool_activity_config + self.run_context_type = run_context_type + + id = server.id + assert id is not None + + @activity.defn(name=f'mcp_server__{id}__get_tools') + async def get_tools_activity(params: _GetToolsParams) -> dict[str, ToolDefinition]: + run_context = self.run_context_type.deserialize_run_context(params.serialized_run_context) + tools = await self.wrapped.get_tools(run_context) + # ToolsetTool is not serializable as its holds a SchemaValidator (which is also the same for every MCP tool so unnecessary to pass along the wire every time), + # so we just return the ToolDefinitions and wrap them in ToolsetTool outside of the activity. + return {name: tool.tool_def for name, tool in tools.items()} + + self.get_tools_activity = get_tools_activity + + @activity.defn(name=f'mcp_server__{id}__call_tool') + async def call_tool_activity(params: _CallToolParams) -> ToolResult: + run_context = self.run_context_type.deserialize_run_context(params.serialized_run_context) + return await self.wrapped.call_tool( + params.name, + params.tool_args, + run_context, + self.wrapped_server.tool_for_tool_def(params.tool_def), + ) + + self.call_tool_activity = call_tool_activity + + @property + def wrapped_server(self) -> MCPServer: + assert isinstance(self.wrapped, MCPServer) + return self.wrapped + + @property + def temporal_activities(self) -> list[Callable[..., Any]]: + return [self.get_tools_activity, self.call_tool_activity] + + async def get_tools(self, ctx: RunContext[Any]) -> dict[str, ToolsetTool[Any]]: + if not workflow.in_workflow(): + return await super().get_tools(ctx) + + serialized_run_context = self.run_context_type.serialize_run_context(ctx) + tool_defs = await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] + activity=self.get_tools_activity, + arg=_GetToolsParams(serialized_run_context=serialized_run_context), + **self.activity_config, + ) + return {name: self.wrapped_server.tool_for_tool_def(tool_def) for name, tool_def in tool_defs.items()} + + async def call_tool( + self, + name: str, + tool_args: dict[str, Any], + ctx: RunContext[Any], + tool: ToolsetTool[Any], + ) -> ToolResult: + if not workflow.in_workflow(): + return await super().call_tool(name, tool_args, ctx, tool) + + tool_activity_config = self.tool_activity_config.get(name, {}) + if tool_activity_config is False: + raise UserError('Disabling running an MCP tool in a Temporal activity is not possible.') + + tool_activity_config = self.activity_config | tool_activity_config + serialized_run_context = self.run_context_type.serialize_run_context(ctx) + return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] + activity=self.call_tool_activity, + arg=_CallToolParams( + name=name, + tool_args=tool_args, + serialized_run_context=serialized_run_context, + tool_def=tool.tool_def, + ), + **tool_activity_config, + ) diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_model.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_model.py new file mode 100644 index 000000000..97f0a7041 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_model.py @@ -0,0 +1,169 @@ +from __future__ import annotations + +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from dataclasses import dataclass +from datetime import datetime +from typing import Any, Callable + +from pydantic import ConfigDict, with_config +from temporalio import activity, workflow +from temporalio.workflow import ActivityConfig + +from pydantic_ai._run_context import RunContext +from pydantic_ai.agent import EventStreamHandler +from pydantic_ai.exceptions import UserError +from pydantic_ai.messages import ( + ModelMessage, + ModelResponse, + ModelResponseStreamEvent, +) +from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse +from pydantic_ai.models.wrapper import WrapperModel +from pydantic_ai.settings import ModelSettings +from pydantic_ai.usage import Usage + +from ._run_context import TemporalRunContext + + +@dataclass +@with_config(ConfigDict(arbitrary_types_allowed=True)) +class _RequestParams: + messages: list[ModelMessage] + model_settings: ModelSettings | None + model_request_parameters: ModelRequestParameters + serialized_run_context: Any + + +class _TemporalStreamedResponse(StreamedResponse): + def __init__(self, response: ModelResponse): + self.response = response + + async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: + return + # noinspection PyUnreachableCode + yield + + def get(self) -> ModelResponse: + """Build a [`ModelResponse`][pydantic_ai.messages.ModelResponse] from the data received from the stream so far.""" + return self.response + + def usage(self) -> Usage: + """Get the usage of the response so far. This will not be the final usage until the stream is exhausted.""" + return self.response.usage + + @property + def model_name(self) -> str: + """Get the model name of the response.""" + return self.response.model_name or '' + + @property + def timestamp(self) -> datetime: + """Get the timestamp of the response.""" + return self.response.timestamp + + +class TemporalModel(WrapperModel): + def __init__( + self, + model: Model, + activity_config: ActivityConfig = {}, + event_stream_handler: EventStreamHandler[Any] | None = None, + run_context_type: type[TemporalRunContext] = TemporalRunContext, + ): + super().__init__(model) + self.activity_config = activity_config + self.event_stream_handler = event_stream_handler + self.run_context_type = run_context_type + + id = '_'.join([model.system, model.model_name]) + + @activity.defn(name=f'model__{id}__request') + async def request_activity(params: _RequestParams) -> ModelResponse: + return await self.wrapped.request(params.messages, params.model_settings, params.model_request_parameters) + + self.request_activity = request_activity + + @activity.defn(name=f'model__{id}__request_stream') + async def request_stream_activity(params: _RequestParams) -> ModelResponse: + run_context = self.run_context_type.deserialize_run_context(params.serialized_run_context) + async with self.wrapped.request_stream( + params.messages, params.model_settings, params.model_request_parameters, run_context + ) as streamed_response: + # Keep in sync with `AgentStream.__aiter__` + async def aiter(): + # `AgentStream.__aiter__`, which this is based on, calls `_get_usage_checking_stream_response` here, + # but we don't have access to the `_usage_limits`. + + # TODO: Create new stream wrapper that does this + async for event in streamed_response: + yield event + if ( + final_result_event := params.model_request_parameters.get_final_result_event(event) + ) is not None: + yield final_result_event + break + + # If we broke out of the above loop, we need to yield the rest of the events + # If we didn't, this will just be a no-op + async for event in streamed_response: + yield event + + assert event_stream_handler is not None + await event_stream_handler(run_context, aiter()) + + async for _ in streamed_response: + pass + return streamed_response.get() + + self.request_stream_activity = request_stream_activity + + @property + def temporal_activities(self) -> list[Callable[..., Any]]: + return [self.request_activity, self.request_stream_activity] + + async def request( + self, + messages: list[ModelMessage], + model_settings: ModelSettings | None, + model_request_parameters: ModelRequestParameters, + ) -> ModelResponse: + if not workflow.in_workflow(): + return await super().request(messages, model_settings, model_request_parameters) + + return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] + activity=self.request_activity, + arg=_RequestParams( + messages=messages, + model_settings=model_settings, + model_request_parameters=model_request_parameters, + serialized_run_context=None, + ), + **self.activity_config, + ) + + @asynccontextmanager + async def request_stream( + self, + messages: list[ModelMessage], + model_settings: ModelSettings | None, + model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, + ) -> AsyncIterator[StreamedResponse]: + if self.event_stream_handler is None: + raise UserError('Streaming with Temporal requires `Agent` to have an `event_stream_handler`') + if run_context is None: + raise UserError('Streaming with Temporal requires `request_stream` to be called with a `run_context`') + + serialized_run_context = self.run_context_type.serialize_run_context(run_context) + response = await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] + activity=self.request_stream_activity, + arg=_RequestParams( + messages=messages, + model_settings=model_settings, + model_request_parameters=model_request_parameters, + serialized_run_context=serialized_run_context, + ), + **self.activity_config, + ) + yield _TemporalStreamedResponse(response) diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_run_context.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_run_context.py new file mode 100644 index 000000000..cec240b3d --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_run_context.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +from typing import Any + +from pydantic_ai._run_context import RunContext +from pydantic_ai.exceptions import UserError + + +class TemporalRunContext(RunContext[Any]): + def __init__(self, **kwargs: Any): + self.__dict__ = kwargs + setattr( + self, + '__dataclass_fields__', + {name: field for name, field in RunContext.__dataclass_fields__.items() if name in kwargs}, + ) + + def __getattribute__(self, name: str) -> Any: + try: + return super().__getattribute__(name) + except AttributeError as e: + if name in RunContext.__dataclass_fields__: + raise AttributeError( + f'{self.__class__.__name__!r} object has no attribute {name!r}. ' + 'To make the attribute available, create a `TemporalRunContext` subclass with a custom `serialize_run_context` class method that returns a dictionary that includes the attribute and pass it to `TemporalAgent`.' + ) + else: + raise e + + @classmethod + def serialize_run_context(cls, ctx: RunContext[Any]) -> dict[str, Any]: + return { + 'retries': ctx.retries, + 'tool_call_id': ctx.tool_call_id, + 'tool_name': ctx.tool_name, + 'retry': ctx.retry, + 'run_step': ctx.run_step, + } + + @classmethod + def deserialize_run_context(cls, ctx: dict[str, Any]) -> RunContext[Any]: + return cls(**ctx) + + +class TemporalRunContextWithDeps(TemporalRunContext): + @classmethod + def serialize_run_context(cls, ctx: RunContext[Any]) -> dict[str, Any]: + if not isinstance(ctx.deps, dict): + raise UserError( + 'The `deps` object must be a JSON-serializable dictionary in order to be used with Temporal. ' + 'To use a different type, pass a `TemporalRunContext` subclass to `TemporalAgent` with custom `serialize_run_context` and `deserialize_run_context` class methods.' + ) + return {**super().serialize_run_context(ctx), 'deps': ctx.deps} # pyright: ignore[reportUnknownMemberType] diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_toolset.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_toolset.py new file mode 100644 index 000000000..c568ef793 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_toolset.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Callable, Literal + +from temporalio.workflow import ActivityConfig + +from pydantic_ai.ext.temporal._run_context import TemporalRunContext +from pydantic_ai.mcp import MCPServer +from pydantic_ai.toolsets.abstract import AbstractToolset +from pydantic_ai.toolsets.function import FunctionToolset +from pydantic_ai.toolsets.wrapper import WrapperToolset + + +class TemporalWrapperToolset(WrapperToolset[Any], ABC): + @property + @abstractmethod + def temporal_activities(self) -> list[Callable[..., Any]]: + raise NotImplementedError + + +def temporalize_toolset( + toolset: AbstractToolset[Any], + activity_config: ActivityConfig = {}, + tool_activity_config: dict[str, ActivityConfig | Literal[False]] = {}, + run_context_type: type[TemporalRunContext] = TemporalRunContext, +) -> AbstractToolset[Any]: + """Temporalize a toolset. + + Args: + toolset: The toolset to temporalize. + activity_config: The Temporal activity config to use. + tool_activity_config: The Temporal activity config to use for specific tools identified by tool name. + run_context_type: The type of run context to use to serialize and deserialize the run context. + """ + if isinstance(toolset, FunctionToolset): + from ._function_toolset import TemporalFunctionToolset + + return TemporalFunctionToolset(toolset, activity_config, tool_activity_config, run_context_type) + elif isinstance(toolset, MCPServer): + from ._mcp_server import TemporalMCPServer + + return TemporalMCPServer(toolset, activity_config, tool_activity_config, run_context_type) + else: + return toolset diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index c84f4b10b..b3426ed36 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -64,10 +64,12 @@ class MCPServer(AbstractToolset[Any], ABC): read_timeout: float = 5 * 60 process_tool_call: ProcessToolCallback | None = None allow_sampling: bool = True - max_retries: int = 1 sampling_model: models.Model | None = None + max_retries: int = 1 # } end of "abstract fields" + _id: str | None = field(init=False, default=None) + _enter_lock: Lock = field(compare=False) _running_count: int _exit_stack: AsyncExitStack | None @@ -76,7 +78,31 @@ class MCPServer(AbstractToolset[Any], ABC): _read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] _write_stream: MemoryObjectSendStream[SessionMessage] - def __post_init__(self): + def __init__( + self, + tool_prefix: str | None = None, + log_level: mcp_types.LoggingLevel | None = None, + log_handler: LoggingFnT | None = None, + timeout: float = 5, + read_timeout: float = 5 * 60, + process_tool_call: ProcessToolCallback | None = None, + allow_sampling: bool = True, + sampling_model: models.Model | None = None, + max_retries: int = 1, + id: str | None = None, + ): + self.tool_prefix = tool_prefix + self.log_level = log_level + self.log_handler = log_handler + self.timeout = timeout + self.read_timeout = read_timeout + self.process_tool_call = process_tool_call + self.allow_sampling = allow_sampling + self.sampling_model = sampling_model + self.max_retries = max_retries + + self._id = id or tool_prefix + self._enter_lock = Lock() self._running_count = 0 self._exit_stack = None @@ -96,12 +122,16 @@ async def client_streams( yield @property - def name(self) -> str: + def id(self) -> str | None: + return self._id + + @property + def label(self) -> str: return repr(self) @property def tool_name_conflict_hint(self) -> str: - return 'Consider setting `tool_prefix` to avoid name conflicts.' + return 'Set the `tool_prefix` attribute to avoid name conflicts.' async def list_tools(self) -> list[mcp_types.Tool]: """Retrieve tools that are currently active on the server. @@ -177,20 +207,25 @@ async def call_tool( async def get_tools(self, ctx: RunContext[Any]) -> dict[str, ToolsetTool[Any]]: return { - name: ToolsetTool( - toolset=self, - tool_def=ToolDefinition( + name: self.tool_for_tool_def( + ToolDefinition( name=name, description=mcp_tool.description, parameters_json_schema=mcp_tool.inputSchema, ), - max_retries=self.max_retries, - args_validator=TOOL_SCHEMA_VALIDATOR, ) for mcp_tool in await self.list_tools() if (name := f'{self.tool_prefix}_{mcp_tool.name}' if self.tool_prefix else mcp_tool.name) } + def tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[Any]: + return ToolsetTool( + toolset=self, + tool_def=tool_def, + max_retries=self.max_retries, + args_validator=TOOL_SCHEMA_VALIDATOR, + ) + async def __aenter__(self) -> Self: """Enter the MCP server context. @@ -309,7 +344,7 @@ def _get_content( assert_never(resource) -@dataclass +@dataclass(init=False) class MCPServerStdio(MCPServer): """Runs an MCP server in a subprocess and communicates with it over stdin/stdout. @@ -387,17 +422,78 @@ async def main(): timeout: float = 5 """The timeout in seconds to wait for the client to initialize.""" + read_timeout: float = 5 * 60 + """Maximum time in seconds to wait for new messages before timing out. + + This timeout applies to the long-lived connection after it's established. + If no new messages are received within this time, the connection will be considered stale + and may be closed. Defaults to 5 minutes (300 seconds). + """ + process_tool_call: ProcessToolCallback | None = None """Hook to customize tool calling and optionally pass extra metadata.""" allow_sampling: bool = True """Whether to allow MCP sampling through this client.""" + sampling_model: models.Model | None = None + """The model to use for sampling.""" + max_retries: int = 1 """The maximum number of times to retry a tool call.""" - sampling_model: models.Model | None = None - """The model to use for sampling.""" + def __init__( + self, + command: str, + args: Sequence[str], + env: dict[str, str] | None = None, + cwd: str | Path | None = None, + id: str | None = None, + tool_prefix: str | None = None, + log_level: mcp_types.LoggingLevel | None = None, + log_handler: LoggingFnT | None = None, + timeout: float = 5, + read_timeout: float = 5 * 60, + process_tool_call: ProcessToolCallback | None = None, + allow_sampling: bool = True, + sampling_model: models.Model | None = None, + max_retries: int = 1, + ): + """Build a new MCP server. + + Args: + command: The command to run. + args: The arguments to pass to the command. + env: The environment variables to set in the subprocess. + cwd: The working directory to use when spawning the process. + id: An optional unique ID for the MCP server. An MCP server needs to have an ID in order to be used in a durable execution environment like Temporal, in which case the ID will be used to identify the server's activities within the workflow. + tool_prefix: A prefix to add to all tools that are registered with the server. + log_level: The log level to set when connecting to the server, if any. + log_handler: A handler for logging messages from the server. + timeout: The timeout in seconds to wait for the client to initialize. + read_timeout: Maximum time in seconds to wait for new messages before timing out. + process_tool_call: Hook to customize tool calling and optionally pass extra metadata. + allow_sampling: Whether to allow MCP sampling through this client. + sampling_model: The model to use for sampling. + max_retries: The maximum number of times to retry a tool call. + """ + self.command = command + self.args = args + self.env = env + self.cwd = cwd + + super().__init__( + tool_prefix, + log_level, + log_handler, + timeout, + read_timeout, + process_tool_call, + allow_sampling, + sampling_model, + max_retries, + id, + ) @asynccontextmanager async def client_streams( @@ -413,7 +509,10 @@ async def client_streams( yield read_stream, write_stream def __repr__(self) -> str: - return f'MCPServerStdio(command={self.command!r}, args={self.args!r}, tool_prefix={self.tool_prefix!r})' + if self.id: + return f'{self.__class__.__name__} {self.id!r}' + else: + return f'{self.__class__.__name__}(command={self.command!r}, args={self.args!r})' @dataclass(init=False) @@ -453,14 +552,6 @@ class _MCPServerHTTP(MCPServer): ``` """ - read_timeout: float = 5 * 60 - """Maximum time in seconds to wait for new messages before timing out. - - This timeout applies to the long-lived connection after it's established. - If no new messages are received within this time, the connection will be considered stale - and may be closed. Defaults to 5 minutes (300 seconds). - """ - # last fields are re-defined from the parent class so they appear as fields tool_prefix: str | None = None """A prefix to add to all tools that are registered with the server. @@ -488,46 +579,71 @@ class _MCPServerHTTP(MCPServer): If the connection cannot be established within this time, the operation will fail. """ + read_timeout: float = 5 * 60 + """Maximum time in seconds to wait for new messages before timing out. + + This timeout applies to the long-lived connection after it's established. + If no new messages are received within this time, the connection will be considered stale + and may be closed. Defaults to 5 minutes (300 seconds). + """ + process_tool_call: ProcessToolCallback | None = None """Hook to customize tool calling and optionally pass extra metadata.""" allow_sampling: bool = True """Whether to allow MCP sampling through this client.""" - max_retries: int = 1 - """The maximum number of times to retry a tool call.""" - sampling_model: models.Model | None = None """The model to use for sampling.""" + max_retries: int = 1 + """The maximum number of times to retry a tool call.""" + def __init__( self, *, url: str, headers: dict[str, str] | None = None, http_client: httpx.AsyncClient | None = None, - read_timeout: float | None = None, + id: str | None = None, tool_prefix: str | None = None, log_level: mcp_types.LoggingLevel | None = None, log_handler: LoggingFnT | None = None, timeout: float = 5, + read_timeout: float | None = None, process_tool_call: ProcessToolCallback | None = None, allow_sampling: bool = True, - max_retries: int = 1, sampling_model: models.Model | None = None, - **kwargs: Any, + max_retries: int = 1, + **_deprecated_kwargs: Any, ): - # Handle deprecated sse_read_timeout parameter - if 'sse_read_timeout' in kwargs: + """Build a new MCP server. + + Args: + url: The URL of the endpoint on the MCP server. + headers: Optional HTTP headers to be sent with each request to the endpoint. + http_client: An `httpx.AsyncClient` to use with the endpoint. + id: An optional unique ID for the MCP server. An MCP server needs to have an ID in order to be used in a durable execution environment like Temporal, in which case the ID will be used to identify the server's activities within the workflow. + tool_prefix: A prefix to add to all tools that are registered with the server. + log_level: The log level to set when connecting to the server, if any. + log_handler: A handler for logging messages from the server. + timeout: The timeout in seconds to wait for the client to initialize. + read_timeout: Maximum time in seconds to wait for new messages before timing out. + process_tool_call: Hook to customize tool calling and optionally pass extra metadata. + allow_sampling: Whether to allow MCP sampling through this client. + sampling_model: The model to use for sampling. + max_retries: The maximum number of times to retry a tool call. + """ + if 'sse_read_timeout' in _deprecated_kwargs: if read_timeout is not None: raise TypeError("'read_timeout' and 'sse_read_timeout' cannot be set at the same time.") warnings.warn( "'sse_read_timeout' is deprecated, use 'read_timeout' instead.", DeprecationWarning, stacklevel=2 ) - read_timeout = kwargs.pop('sse_read_timeout') + read_timeout = _deprecated_kwargs.pop('sse_read_timeout') - _utils.validate_empty_kwargs(kwargs) + _utils.validate_empty_kwargs(_deprecated_kwargs) if read_timeout is None: read_timeout = 5 * 60 @@ -535,15 +651,19 @@ def __init__( self.url = url self.headers = headers self.http_client = http_client - self.tool_prefix = tool_prefix - self.log_level = log_level - self.log_handler = log_handler - self.timeout = timeout - self.process_tool_call = process_tool_call - self.allow_sampling = allow_sampling - self.max_retries = max_retries - self.sampling_model = sampling_model - self.read_timeout = read_timeout + + super().__init__( + tool_prefix, + log_level, + log_handler, + timeout, + read_timeout, + process_tool_call, + allow_sampling, + sampling_model, + max_retries, + id, + ) @property @abstractmethod @@ -606,7 +726,10 @@ def httpx_client_factory( yield read_stream, write_stream def __repr__(self) -> str: # pragma: no cover - return f'{self.__class__.__name__}(url={self.url!r}, tool_prefix={self.tool_prefix!r})' + if self.id: + return f'{self.__class__.__name__} {self.id!r}' + else: + return f'{self.__class__.__name__}(url={self.url!r})' @dataclass(init=False) diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 6cdcbfbd6..fea76fd75 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -13,19 +13,30 @@ from dataclasses import dataclass, field, replace from datetime import datetime from functools import cache, cached_property -from typing import Generic, TypeVar, overload +from typing import Any, Generic, TypeVar, overload import httpx from typing_extensions import Literal, TypeAliasType, TypedDict -from pydantic_ai.profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec - from .. import _utils from .._output import OutputObjectDefinition from .._parts_manager import ModelResponsePartsManager +from .._run_context import RunContext from ..exceptions import UserError -from ..messages import FileUrl, ModelMessage, ModelRequest, ModelResponse, ModelResponseStreamEvent, VideoUrl +from ..messages import ( + FileUrl, + FinalResultEvent, + ModelMessage, + ModelRequest, + ModelResponse, + ModelResponseStreamEvent, + PartStartEvent, + TextPart, + ToolCallPart, + VideoUrl, +) from ..output import OutputMode +from ..profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec from ..profiles._json_schema import JsonSchemaTransformer from ..settings import ModelSettings from ..tools import ToolDefinition @@ -343,6 +354,22 @@ class ModelRequestParameters: output_tools: list[ToolDefinition] = field(default_factory=list) allow_text_output: bool = True + @cached_property + def tool_defs(self) -> dict[str, ToolDefinition]: + return {tool_def.name: tool_def for tool_def in [*self.function_tools, *self.output_tools]} + + def get_final_result_event(self, e: ModelResponseStreamEvent) -> FinalResultEvent | None: + """Return an appropriate FinalResultEvent if `e` corresponds to a part that will produce a final result.""" + if isinstance(e, PartStartEvent): + new_part = e.part + if isinstance(new_part, TextPart) and self.allow_text_output: # pragma: no branch + return FinalResultEvent(tool_name=None, tool_call_id=None) + elif isinstance(new_part, ToolCallPart) and (tool_def := self.tool_defs.get(new_part.tool_name)): + if tool_def.kind == 'output': + return FinalResultEvent(tool_name=new_part.tool_name, tool_call_id=new_part.tool_call_id) + elif tool_def.kind == 'deferred': + return FinalResultEvent(tool_name=None, tool_call_id=None) + __repr__ = _utils.dataclasses_no_defaults_repr @@ -388,6 +415,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: """Make a request to the model and return a streaming response.""" # This method is not required, but you need to implement it if you want to support streamed responses diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index a62741568..5ea1f5347 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -11,6 +11,7 @@ from typing_extensions import assert_never from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage +from .._run_context import RunContext from .._utils import guard_tool_call_id as _guard_tool_call_id from ..messages import ( BinaryContent, @@ -171,6 +172,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() response = await self._messages_create( @@ -297,10 +299,7 @@ async def _process_streamed_response(self, response: AsyncStream[BetaRawMessageS ) def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[BetaToolParam]: - tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools] - if model_request_parameters.output_tools: - tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools] - return tools + return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()] async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[BetaMessageParam]]: # noqa: C901 """Just maps a `pydantic_ai.Message` to a `anthropic.types.MessageParam`.""" diff --git a/pydantic_ai_slim/pydantic_ai/models/bedrock.py b/pydantic_ai_slim/pydantic_ai/models/bedrock.py index f16f9d111..00eeed4ae 100644 --- a/pydantic_ai_slim/pydantic_ai/models/bedrock.py +++ b/pydantic_ai_slim/pydantic_ai/models/bedrock.py @@ -15,6 +15,7 @@ from typing_extensions import ParamSpec, assert_never from pydantic_ai import _utils, usage +from pydantic_ai._run_context import RunContext from pydantic_ai.messages import ( AudioUrl, BinaryContent, @@ -225,10 +226,7 @@ def __init__( super().__init__(settings=settings, profile=profile or provider.model_profile) def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolTypeDef]: - tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools] - if model_request_parameters.output_tools: - tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools] - return tools + return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()] @staticmethod def _map_tool_definition(f: ToolDefinition) -> ToolTypeDef: @@ -264,6 +262,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: settings = cast(BedrockModelSettings, model_settings or {}) response = await self._messages_create(messages, True, settings, model_request_parameters) diff --git a/pydantic_ai_slim/pydantic_ai/models/cohere.py b/pydantic_ai_slim/pydantic_ai/models/cohere.py index 4243ef492..652b6740f 100644 --- a/pydantic_ai_slim/pydantic_ai/models/cohere.py +++ b/pydantic_ai_slim/pydantic_ai/models/cohere.py @@ -238,10 +238,7 @@ def _map_messages(self, messages: list[ModelMessage]) -> list[ChatMessageV2]: return cohere_messages def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolV2]: - tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools] - if model_request_parameters.output_tools: - tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools] - return tools + return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()] @staticmethod def _map_tool_call(t: ToolCallPart) -> ToolCallV2: diff --git a/pydantic_ai_slim/pydantic_ai/models/fallback.py b/pydantic_ai_slim/pydantic_ai/models/fallback.py index 4455defce..498d8e1bd 100644 --- a/pydantic_ai_slim/pydantic_ai/models/fallback.py +++ b/pydantic_ai_slim/pydantic_ai/models/fallback.py @@ -3,10 +3,11 @@ from collections.abc import AsyncIterator from contextlib import AsyncExitStack, asynccontextmanager, suppress from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING, Any, Callable from opentelemetry.trace import get_current_span +from pydantic_ai._run_context import RunContext from pydantic_ai.models.instrumented import InstrumentedModel from ..exceptions import FallbackExceptionGroup, ModelHTTPError @@ -83,6 +84,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: """Try each model in sequence until one succeeds.""" exceptions: list[Exception] = [] @@ -92,7 +94,7 @@ async def request_stream( async with AsyncExitStack() as stack: try: response = await stack.enter_async_context( - model.request_stream(messages, model_settings, customized_model_request_parameters) + model.request_stream(messages, model_settings, customized_model_request_parameters, run_context) ) except Exception as exc: if self._fallback_on(exc): diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index 4efc6d710..5db489b7d 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -7,13 +7,12 @@ from dataclasses import dataclass, field from datetime import datetime from itertools import chain -from typing import Callable, Union +from typing import Any, Callable, Union from typing_extensions import TypeAlias, assert_never, overload -from pydantic_ai.profiles import ModelProfileSpec - from .. import _utils, usage +from .._run_context import RunContext from .._utils import PeekableAsyncStream from ..messages import ( BinaryContent, @@ -30,6 +29,7 @@ UserContent, UserPromptPart, ) +from ..profiles import ModelProfileSpec from ..settings import ModelSettings from ..tools import ToolDefinition from . import Model, ModelRequestParameters, StreamedResponse @@ -145,6 +145,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: agent_info = AgentInfo( model_request_parameters.function_tools, diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 7c371a943..94ed3255f 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -13,10 +13,9 @@ from httpx import USE_CLIENT_DEFAULT, Response as HTTPResponse from typing_extensions import NotRequired, TypedDict, assert_never -from pydantic_ai.providers import Provider, infer_provider - from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage from .._output import OutputObjectDefinition +from .._run_context import RunContext from ..exceptions import UserError from ..messages import ( BinaryContent, @@ -36,6 +35,7 @@ VideoUrl, ) from ..profiles import ModelProfileSpec +from ..providers import Provider, infer_provider from ..settings import ModelSettings from ..tools import ToolDefinition from . import ( @@ -164,6 +164,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() async with self._make_request( @@ -182,9 +183,7 @@ def system(self) -> str: return self._system def _get_tools(self, model_request_parameters: ModelRequestParameters) -> _GeminiTools | None: - tools = [_function_from_abstract_tool(t) for t in model_request_parameters.function_tools] - if model_request_parameters.output_tools: - tools += [_function_from_abstract_tool(t) for t in model_request_parameters.output_tools] + tools = [_function_from_abstract_tool(t) for t in model_request_parameters.tool_defs.values()] return _GeminiTools(function_declarations=tools) if tools else None def _get_tool_config( diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index ad5da243c..d6dd82b4a 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -12,6 +12,7 @@ from .. import UnexpectedModelBehavior, _utils, usage from .._output import OutputObjectDefinition +from .._run_context import RunContext from ..exceptions import UserError from ..messages import ( BinaryContent, @@ -44,7 +45,7 @@ ) try: - from google import genai + from google.genai import Client from google.genai.types import ( ContentDict, ContentUnionDict, @@ -130,10 +131,10 @@ class GoogleModel(Model): Apart from `__init__`, all methods are private or match those of the base class. """ - client: genai.Client = field(repr=False) + client: Client = field(repr=False) _model_name: GoogleModelName = field(repr=False) - _provider: Provider[genai.Client] = field(repr=False) + _provider: Provider[Client] = field(repr=False) _url: str | None = field(repr=False) _system: str = field(default='google', repr=False) @@ -141,7 +142,7 @@ def __init__( self, model_name: GoogleModelName, *, - provider: Literal['google-gla', 'google-vertex'] | Provider[genai.Client] = 'google-gla', + provider: Literal['google-gla', 'google-vertex'] | Provider[Client] = 'google-gla', profile: ModelProfileSpec | None = None, settings: ModelSettings | None = None, ): @@ -187,6 +188,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() model_settings = cast(GoogleModelSettings, model_settings or {}) @@ -204,16 +206,10 @@ def system(self) -> str: return self._system def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolDict] | None: - tools: list[ToolDict] = [ + return [ ToolDict(function_declarations=[_function_declaration_from_tool(t)]) - for t in model_request_parameters.function_tools + for t in model_request_parameters.tool_defs.values() ] - if model_request_parameters.output_tools: - tools += [ - ToolDict(function_declarations=[_function_declaration_from_tool(t)]) - for t in model_request_parameters.output_tools - ] - return tools or None def _get_tool_config( self, model_request_parameters: ModelRequestParameters, tools: list[ToolDict] | None diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index ffca84b44..22dc69d56 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -5,13 +5,13 @@ from contextlib import asynccontextmanager from dataclasses import dataclass, field from datetime import datetime -from typing import Literal, Union, cast, overload +from typing import Any, Literal, Union, cast, overload from typing_extensions import assert_never -from pydantic_ai._thinking_part import split_content_into_text_and_thinking - from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage +from .._run_context import RunContext +from .._thinking_part import split_content_into_text_and_thinking from .._utils import guard_tool_call_id as _guard_tool_call_id, number_to_datetime from ..messages import ( BinaryContent, @@ -166,6 +166,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() response = await self._completions_create( @@ -285,10 +286,7 @@ async def _process_streamed_response(self, response: AsyncStream[chat.ChatComple ) def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]: - tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools] - if model_request_parameters.output_tools: - tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools] - return tools + return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()] def _map_messages(self, messages: list[ModelMessage]) -> list[chat.ChatCompletionMessageParam]: """Just maps a `pydantic_ai.Message` to a `groq.types.ChatCompletionMessageParam`.""" diff --git a/pydantic_ai_slim/pydantic_ai/models/huggingface.py b/pydantic_ai_slim/pydantic_ai/models/huggingface.py index 4b3c2ff40..ba20f552e 100644 --- a/pydantic_ai_slim/pydantic_ai/models/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/models/huggingface.py @@ -5,14 +5,13 @@ from contextlib import asynccontextmanager from dataclasses import dataclass, field from datetime import datetime, timezone -from typing import Literal, Union, cast, overload +from typing import Any, Literal, Union, cast, overload from typing_extensions import assert_never -from pydantic_ai._thinking_part import split_content_into_text_and_thinking -from pydantic_ai.providers import Provider, infer_provider - from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage +from .._run_context import RunContext +from .._thinking_part import split_content_into_text_and_thinking from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc from ..messages import ( AudioUrl, @@ -33,6 +32,7 @@ UserPromptPart, VideoUrl, ) +from ..providers import Provider, infer_provider from ..settings import ModelSettings from ..tools import ToolDefinition from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests @@ -146,6 +146,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() response = await self._completions_create( @@ -272,10 +273,7 @@ async def _process_streamed_response(self, response: AsyncIterable[ChatCompletio ) def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ChatCompletionInputTool]: - tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools] - if model_request_parameters.output_tools: - tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools] - return tools + return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()] async def _map_messages( self, messages: list[ModelMessage] diff --git a/pydantic_ai_slim/pydantic_ai/models/instrumented.py b/pydantic_ai_slim/pydantic_ai/models/instrumented.py index 233020f6f..49a97bd49 100644 --- a/pydantic_ai_slim/pydantic_ai/models/instrumented.py +++ b/pydantic_ai_slim/pydantic_ai/models/instrumented.py @@ -18,6 +18,7 @@ from opentelemetry.util.types import AttributeValue from pydantic import TypeAdapter +from .._run_context import RunContext from ..messages import ( ModelMessage, ModelRequest, @@ -222,12 +223,13 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: with self._instrument(messages, model_settings, model_request_parameters) as finish: response_stream: StreamedResponse | None = None try: async with super().request_stream( - messages, model_settings, model_request_parameters + messages, model_settings, model_request_parameters, run_context ) as response_stream: yield response_stream finally: diff --git a/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py b/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py index ebfaac92d..a4f649786 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py +++ b/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py @@ -3,9 +3,10 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager from dataclasses import dataclass -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Any, cast from .. import _mcp, exceptions, usage +from .._run_context import RunContext from ..messages import ModelMessage, ModelResponse from ..settings import ModelSettings from . import Model, ModelRequestParameters, StreamedResponse @@ -76,6 +77,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: raise NotImplementedError('MCP Sampling does not support streaming') yield diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index 0104e2055..c5e6211b2 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -11,9 +11,9 @@ from httpx import Timeout from typing_extensions import assert_never -from pydantic_ai._thinking_part import split_content_into_text_and_thinking - from .. import ModelHTTPError, UnexpectedModelBehavior, _utils +from .._run_context import RunContext +from .._thinking_part import split_content_into_text_and_thinking from .._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc, number_to_datetime from ..messages import ( BinaryContent, @@ -173,6 +173,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: """Make a streaming request to the model from Pydantic AI call.""" check_allow_model_requests() @@ -233,11 +234,7 @@ async def _stream_completions_create( response: MistralEventStreamAsync[MistralCompletionEvent] | None mistral_messages = self._map_messages(messages) - if ( - model_request_parameters.output_tools - and model_request_parameters.function_tools - or model_request_parameters.function_tools - ): + if model_request_parameters.function_tools: # Function Calling response = await self.client.chat.stream_async( model=str(self._model_name), @@ -305,16 +302,13 @@ def _map_function_and_output_tools_definition( Returns None if both function_tools and output_tools are empty. """ - all_tools: list[ToolDefinition] = ( - model_request_parameters.function_tools + model_request_parameters.output_tools - ) tools = [ MistralTool( function=MistralFunction( name=r.name, parameters=r.parameters_json_schema, description=r.description or '' ) ) - for r in all_tools + for r in model_request_parameters.tool_defs.values() ] return tools if tools else None diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 1881862cd..6124bb964 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -11,12 +11,10 @@ from pydantic import ValidationError from typing_extensions import assert_never -from pydantic_ai._thinking_part import split_content_into_text_and_thinking -from pydantic_ai.profiles.openai import OpenAIModelProfile -from pydantic_ai.providers import Provider, infer_provider - from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage from .._output import DEFAULT_OUTPUT_TOOL_NAME, OutputObjectDefinition +from .._run_context import RunContext +from .._thinking_part import split_content_into_text_and_thinking from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc, number_to_datetime from ..messages import ( AudioUrl, @@ -38,6 +36,8 @@ VideoUrl, ) from ..profiles import ModelProfileSpec +from ..profiles.openai import OpenAIModelProfile +from ..providers import Provider, infer_provider from ..settings import ModelSettings from ..tools import ToolDefinition from . import ( @@ -254,6 +254,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() response = await self._completions_create( @@ -438,10 +439,7 @@ async def _process_streamed_response(self, response: AsyncStream[ChatCompletionC ) def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]: - tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools] - if model_request_parameters.output_tools: - tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools] - return tools + return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()] async def _map_messages(self, messages: list[ModelMessage]) -> list[chat.ChatCompletionMessageParam]: """Just maps a `pydantic_ai.Message` to a `openai.types.ChatCompletionMessageParam`.""" @@ -678,6 +676,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() response = await self._responses_create( @@ -834,10 +833,7 @@ def _get_reasoning(self, model_settings: OpenAIResponsesModelSettings) -> Reason return Reasoning(effort=reasoning_effort, summary=reasoning_summary) def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[responses.FunctionToolParam]: - tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools] - if model_request_parameters.output_tools: - tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools] - return tools + return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()] def _map_tool_definition(self, f: ToolDefinition) -> responses.FunctionToolParam: return { diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py index a80d551ff..d1b015078 100644 --- a/pydantic_ai_slim/pydantic_ai/models/test.py +++ b/pydantic_ai_slim/pydantic_ai/models/test.py @@ -12,6 +12,7 @@ from typing_extensions import assert_never from .. import _utils +from .._run_context import RunContext from ..messages import ( ModelMessage, ModelRequest, @@ -118,6 +119,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: self.last_model_request_parameters = model_request_parameters diff --git a/pydantic_ai_slim/pydantic_ai/models/wrapper.py b/pydantic_ai_slim/pydantic_ai/models/wrapper.py index cc91f9c72..9818ad603 100644 --- a/pydantic_ai_slim/pydantic_ai/models/wrapper.py +++ b/pydantic_ai_slim/pydantic_ai/models/wrapper.py @@ -6,6 +6,7 @@ from functools import cached_property from typing import Any +from .._run_context import RunContext from ..messages import ModelMessage, ModelResponse from ..profiles import ModelProfile from ..settings import ModelSettings @@ -35,8 +36,11 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: - async with self.wrapped.request_stream(messages, model_settings, model_request_parameters) as response_stream: + async with self.wrapped.request_stream( + messages, model_settings, model_request_parameters, run_context + ) as response_stream: yield response_stream def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters: diff --git a/pydantic_ai_slim/pydantic_ai/providers/google.py b/pydantic_ai_slim/pydantic_ai/providers/google.py index fc876fcff..70eaa6d86 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/google.py +++ b/pydantic_ai_slim/pydantic_ai/providers/google.py @@ -10,8 +10,8 @@ from pydantic_ai.providers import Provider try: - from google import genai from google.auth.credentials import Credentials + from google.genai import Client except ImportError as _import_error: raise ImportError( 'Please install the `google-genai` package to use the Google provider, ' @@ -19,7 +19,7 @@ ) from _import_error -class GoogleProvider(Provider[genai.Client]): +class GoogleProvider(Provider[Client]): """Provider for Google.""" @property @@ -31,7 +31,7 @@ def base_url(self) -> str: return str(self._client._api_client._http_options.base_url) # type: ignore[reportPrivateUsage] @property - def client(self) -> genai.Client: + def client(self) -> Client: return self._client def model_profile(self, model_name: str) -> ModelProfile | None: @@ -50,7 +50,7 @@ def __init__( ) -> None: ... @overload - def __init__(self, *, client: genai.Client) -> None: ... + def __init__(self, *, client: Client) -> None: ... @overload def __init__(self, *, vertexai: bool = False) -> None: ... @@ -62,7 +62,7 @@ def __init__( credentials: Credentials | None = None, project: str | None = None, location: VertexAILocation | Literal['global'] | None = None, - client: genai.Client | None = None, + client: Client | None = None, vertexai: bool | None = None, ) -> None: """Create a new Google provider. @@ -95,13 +95,13 @@ def __init__( 'Set the `GOOGLE_API_KEY` environment variable or pass it via `GoogleProvider(api_key=...)`' 'to use the Google Generative Language API.' ) - self._client = genai.Client( + self._client = Client( vertexai=vertexai, api_key=api_key, http_options={'headers': {'User-Agent': get_user_agent()}}, ) else: - self._client = genai.Client( + self._client = Client( vertexai=vertexai, project=project or os.environ.get('GOOGLE_CLOUD_PROJECT'), # From https://github.com/pydantic/pydantic-ai/pull/2031/files#r2169682149: diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index e640302b2..5261f75c8 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -46,6 +46,7 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]): _raw_stream_response: models.StreamedResponse _output_schema: OutputSchema[OutputDataT] + _model_request_parameters: models.ModelRequestParameters _output_validators: list[OutputValidator[AgentDepsT, OutputDataT]] _run_ctx: RunContext[AgentDepsT] _usage_limits: UsageLimits | None @@ -230,32 +231,12 @@ def __aiter__(self) -> AsyncIterator[AgentStreamEvent]: return self._agent_stream_iterator async def aiter(): - output_schema = self._output_schema - - def _get_final_result_event(e: _messages.ModelResponseStreamEvent) -> _messages.FinalResultEvent | None: - """Return an appropriate FinalResultEvent if `e` corresponds to a part that will produce a final result.""" - if isinstance(e, _messages.PartStartEvent): - new_part = e.part - if isinstance(new_part, _messages.TextPart) and isinstance( - output_schema, TextOutputSchema - ): # pragma: no branch - return _messages.FinalResultEvent(tool_name=None, tool_call_id=None) - elif isinstance(new_part, _messages.ToolCallPart) and ( - tool_def := self._tool_manager.get_tool_def(new_part.tool_name) - ): - if tool_def.kind == 'output': - return _messages.FinalResultEvent( - tool_name=new_part.tool_name, tool_call_id=new_part.tool_call_id - ) - elif tool_def.kind == 'deferred': - return _messages.FinalResultEvent(tool_name=None, tool_call_id=None) - usage_checking_stream = _get_usage_checking_stream_response( self._raw_stream_response, self._usage_limits, self.usage ) async for event in usage_checking_stream: yield event - if (final_result_event := _get_final_result_event(event)) is not None: + if (final_result_event := self._model_request_parameters.get_final_result_event(event)) is not None: self._final_result_event = final_result_event yield final_result_event break diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/abstract.py b/pydantic_ai_slim/pydantic_ai/toolsets/abstract.py index 455336418..7f44c2be6 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/abstract.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/abstract.py @@ -70,9 +70,23 @@ class AbstractToolset(ABC, Generic[AgentDepsT]): """ @property - def name(self) -> str: + @abstractmethod + def id(self) -> str | None: + """An ID for the toolset that is unique among all toolsets registered with the same agent. + + If you're implementing a concrete implementation that users can instantiate more than once, you should let them optionally pass a custom ID to the constructor and return that here. + + A toolset needs to have an ID in order to be used in a durable execution environment like Temporal, in which case the ID will be used to identify the toolset's activities within the workflow. + """ + raise NotImplementedError() + + @property + def label(self) -> str: """The name of the toolset for use in error messages.""" - return self.__class__.__name__.replace('Toolset', ' toolset') + label = self.__class__.__name__ + if self.id: + label += f' {self.id!r}' + return label @property def tool_name_conflict_hint(self) -> str: @@ -116,6 +130,11 @@ def apply(self, visitor: Callable[[AbstractToolset[AgentDepsT]], None]) -> None: """Run a visitor function on all concrete toolsets that are not wrappers (i.e. they implement their own tool listing and calling).""" visitor(self) + def visit_and_replace( + self, visitor: Callable[[AbstractToolset[AgentDepsT]], AbstractToolset[AgentDepsT]] + ) -> AbstractToolset[AgentDepsT]: + return visitor(self) + def filtered( self, filter_func: Callable[[RunContext[AgentDepsT], ToolDefinition], bool] ) -> FilteredToolset[AgentDepsT]: diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/combined.py b/pydantic_ai_slim/pydantic_ai/toolsets/combined.py index 4b1511fae..43e2d0557 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/combined.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/combined.py @@ -40,6 +40,14 @@ def __post_init__(self): self._entered_count = 0 self._exit_stack = None + @property + def id(self) -> str | None: + return None + + @property + def label(self) -> str: + return f'{self.__class__.__name__}({", ".join(toolset.label for toolset in self.toolsets)})' + async def __aenter__(self) -> Self: async with self._enter_lock: if self._entered_count == 0: @@ -64,7 +72,7 @@ async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[ for name, tool in tools.items(): if existing_tools := all_tools.get(name): raise UserError( - f'{toolset.name} defines a tool whose name conflicts with existing tool from {existing_tools.toolset.name}: {name!r}. {toolset.tool_name_conflict_hint}' + f'{toolset.label} defines a tool whose name conflicts with existing tool from {existing_tools.toolset.label}: {name!r}. {toolset.tool_name_conflict_hint}' ) all_tools[name] = _CombinedToolsetTool( @@ -86,3 +94,8 @@ async def call_tool( def apply(self, visitor: Callable[[AbstractToolset[AgentDepsT]], None]) -> None: for toolset in self.toolsets: toolset.apply(visitor) + + def visit_and_replace( + self, visitor: Callable[[AbstractToolset[AgentDepsT]], AbstractToolset[AgentDepsT]] + ) -> AbstractToolset[AgentDepsT]: + return CombinedToolset(toolsets=[visitor(toolset) for toolset in self.toolsets]) diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/deferred.py b/pydantic_ai_slim/pydantic_ai/toolsets/deferred.py index 3ad2e976b..a67c3b0ad 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/deferred.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/deferred.py @@ -1,6 +1,6 @@ from __future__ import annotations -from dataclasses import dataclass, replace +from dataclasses import dataclass, field, replace from typing import Any from pydantic_core import SchemaValidator, core_schema @@ -12,7 +12,7 @@ TOOL_SCHEMA_VALIDATOR = SchemaValidator(schema=core_schema.any_schema()) -@dataclass +@dataclass(init=False) class DeferredToolset(AbstractToolset[AgentDepsT]): """A toolset that holds deferred tools whose results will be produced outside of the Pydantic AI agent run in which they were called. @@ -20,6 +20,15 @@ class DeferredToolset(AbstractToolset[AgentDepsT]): """ tool_defs: list[ToolDefinition] + _id: str | None = field(init=False, default=None) + + def __init__(self, tool_defs: list[ToolDefinition], id: str | None = None): + self._id = id + self.tool_defs = tool_defs + + @property + def id(self) -> str | None: + return self._id async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]: return { diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/function.py b/pydantic_ai_slim/pydantic_ai/toolsets/function.py index 63f44a1f0..c206b22c5 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/function.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/function.py @@ -24,6 +24,7 @@ class _FunctionToolsetTool(ToolsetTool[AgentDepsT]): """A tool definition for a function toolset tool that keeps track of the function to call.""" call_func: Callable[[dict[str, Any], RunContext[AgentDepsT]], Awaitable[Any]] + is_async: bool @dataclass(init=False) @@ -35,14 +36,22 @@ class FunctionToolset(AbstractToolset[AgentDepsT]): max_retries: int = field(default=1) tools: dict[str, Tool[Any]] = field(default_factory=dict) + _id: str | None = field(init=False, default=None) - def __init__(self, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = [], max_retries: int = 1): + def __init__( + self, + tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = [], + max_retries: int = 1, + id: str | None = None, + ): """Build a new function toolset. Args: tools: The tools to add to the toolset. max_retries: The maximum number of retries for each tool during a run. + id: An optional unique ID for the toolset. A toolset needs to have an ID in order to be used in a durable execution environment like Temporal, in which case the ID will be used to identify the toolset's activities within the workflow. """ + self._id = id self.max_retries = max_retries self.tools = {} for tool in tools: @@ -51,6 +60,10 @@ def __init__(self, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, else: self.add_function(tool) + @property + def id(self) -> str | None: + return self._id + @overload def tool(self, func: ToolFuncEither[AgentDepsT, ToolParams], /) -> ToolFuncEither[AgentDepsT, ToolParams]: ... @@ -228,6 +241,7 @@ async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[ max_retries=tool.max_retries if tool.max_retries is not None else self.max_retries, args_validator=tool.function_schema.validator, call_func=tool.function_schema.call, + is_async=tool.function_schema.is_async, ) return tools diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/prefixed.py b/pydantic_ai_slim/pydantic_ai/toolsets/prefixed.py index be70ed4f0..a430b0dab 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/prefixed.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/prefixed.py @@ -17,6 +17,10 @@ class PrefixedToolset(WrapperToolset[AgentDepsT]): prefix: str + @property + def tool_name_conflict_hint(self) -> str: + return 'Change the `prefix` attribute to avoid name conflicts.' + async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]: return { new_name: replace( diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py b/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py index 8440f1c46..fdbde6baa 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py @@ -1,6 +1,6 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, replace from typing import Any, Callable from typing_extensions import Self @@ -18,6 +18,14 @@ class WrapperToolset(AbstractToolset[AgentDepsT]): wrapped: AbstractToolset[AgentDepsT] + @property + def id(self) -> str | None: + return None + + @property + def label(self) -> str: + return f'{self.__class__.__name__}({self.wrapped.label})' + async def __aenter__(self) -> Self: await self.wrapped.__aenter__() return self @@ -35,3 +43,8 @@ async def call_tool( def apply(self, visitor: Callable[[AbstractToolset[AgentDepsT]], None]) -> None: self.wrapped.apply(visitor) + + def visit_and_replace( + self, visitor: Callable[[AbstractToolset[AgentDepsT]], AbstractToolset[AgentDepsT]] + ) -> AbstractToolset[AgentDepsT]: + return replace(self, wrapped=visitor(self.wrapped)) diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index 0e0801232..a0fed0b2f 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -84,6 +84,8 @@ evals = ["pydantic-evals=={{ version }}"] a2a = ["fasta2a>=0.4.1"] # AG-UI ag-ui = ["ag-ui-protocol>=0.1.8", "starlette>=0.45.3"] +# Temporal +temporal = ["temporalio>=1.15.0"] [dependency-groups] dev = [ diff --git a/pyproject.toml b/pyproject.toml index 841f186ef..fdd67f183 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,7 @@ requires-python = ">=3.9" [tool.hatch.metadata.hooks.uv-dynamic-versioning] dependencies = [ - "pydantic-ai-slim[openai,vertexai,google,groq,anthropic,mistral,cohere,bedrock,huggingface,cli,mcp,evals,ag-ui]=={{ version }}", + "pydantic-ai-slim[openai,vertexai,google,groq,anthropic,mistral,cohere,bedrock,huggingface,cli,mcp,evals,ag-ui,temporal]=={{ version }}", ] [tool.hatch.metadata.hooks.uv-dynamic-versioning.optional-dependencies] diff --git a/temporal.py b/temporal.py new file mode 100644 index 000000000..acb6d476c --- /dev/null +++ b/temporal.py @@ -0,0 +1,118 @@ +import asyncio +import random +from collections.abc import AsyncIterable +from datetime import timedelta + +import logfire +from temporalio import workflow +from temporalio.client import Client +from temporalio.worker import Worker +from temporalio.workflow import ActivityConfig +from typing_extensions import TypedDict + +from pydantic_ai import Agent, RunContext +from pydantic_ai.ext.temporal import ( + AgentPlugin, + LogfirePlugin, + PydanticAIPlugin, + TemporalAgent, + TemporalRunContextWithDeps, +) +from pydantic_ai.mcp import MCPServerStdio +from pydantic_ai.messages import AgentStreamEvent, HandleResponseEvent +from pydantic_ai.toolsets.function import FunctionToolset + + +class Deps(TypedDict): + country: str + + +async def event_stream_handler(ctx: RunContext[Deps], stream: AsyncIterable[AgentStreamEvent | HandleResponseEvent]): + logfire.info(f'{ctx.run_step=}') + async for event in stream: + logfire.info(f'{event=}') + + +async def get_country(ctx: RunContext[Deps]) -> str: + return ctx.deps['country'] + + +def get_weather(city: str) -> str: + return 'sunny' + + +agent = Agent( + 'openai:gpt-4o', + deps_type=Deps, + toolsets=[ + FunctionToolset[Deps](tools=[get_weather], id='toolset'), + MCPServerStdio('python', ['-m', 'tests.mcp_server'], timeout=20, id='mcp'), + ], + tools=[get_country], + event_stream_handler=event_stream_handler, +) + +# This needs to be called in the same scope where the `agent` is bound to the workflow, +# as it modifies the `agent` object in place to swap out methods that use IO for ones that use Temporal activities. +temporal_agent = TemporalAgent( + agent, + activity_config=ActivityConfig(start_to_close_timeout=timedelta(seconds=60)), + toolset_activity_config={ + 'country': ActivityConfig(start_to_close_timeout=timedelta(seconds=120)), + }, + tool_activity_config={ + 'toolset': { + 'get_country': False, + 'get_weather': ActivityConfig(start_to_close_timeout=timedelta(seconds=180)), + }, + }, + run_context_type=TemporalRunContextWithDeps, +) + +with workflow.unsafe.imports_passed_through(): + import pandas # noqa: F401 + + +@workflow.defn +class MyAgentWorkflow: + @workflow.run + async def run(self, prompt: str, deps: Deps) -> str: + result = await temporal_agent.run(prompt, deps=deps) + return result.output + + +TASK_QUEUE = 'pydantic-ai-agent-task-queue' + + +def setup_logfire(): + logfire.configure(console=False) + logfire.instrument_pydantic_ai() + logfire.instrument_httpx(capture_all=True) + + +async def main(): + client = await Client.connect( + 'localhost:7233', + plugins=[PydanticAIPlugin(), LogfirePlugin(setup_logfire)], + ) + + async with Worker( + client, + task_queue=TASK_QUEUE, + workflows=[MyAgentWorkflow], + plugins=[AgentPlugin(temporal_agent)], + ): + output = await client.execute_workflow( # pyright: ignore[reportUnknownMemberType] + MyAgentWorkflow.run, + args=[ + 'what is the capital of the country? what is the weather there? what is the product name?', + Deps(country='Mexico'), + ], + id=f'my-agent-workflow-id-{random.random()}', + task_queue=TASK_QUEUE, + ) + print(output) + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/tests/conftest.py b/tests/conftest.py index 0214100c9..d4d4a93ec 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -422,7 +422,7 @@ async def vertex_provider(): pytest.skip('Requires properly configured local google vertex config to pass') try: - from google import genai + from google.genai import Client from pydantic_ai.providers.google import GoogleProvider except ImportError: # pragma: lax no cover @@ -430,7 +430,7 @@ async def vertex_provider(): project = os.getenv('GOOGLE_PROJECT', 'pydantic-ai') location = os.getenv('GOOGLE_LOCATION', 'us-central1') - client = genai.Client(vertexai=True, project=project, location=location) + client = Client(vertexai=True, project=project, location=location) try: yield GoogleProvider(client=client) diff --git a/tests/models/test_instrumented.py b/tests/models/test_instrumented.py index b952bf716..831f6f339 100644 --- a/tests/models/test_instrumented.py +++ b/tests/models/test_instrumented.py @@ -11,6 +11,7 @@ from opentelemetry._events import NoOpEventLoggerProvider from opentelemetry.trace import NoOpTracerProvider +from pydantic_ai._run_context import RunContext from pydantic_ai.messages import ( AudioUrl, BinaryContent, @@ -89,6 +90,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext | None = None, ) -> AsyncIterator[StreamedResponse]: yield MyResponseStream() diff --git a/tests/test_examples.py b/tests/test_examples.py index 20a7ece7b..6e480cb0a 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -269,6 +269,10 @@ def rich_prompt_ask(prompt: str, *_args: Any, **_kwargs: Any) -> str: class MockMCPServer(AbstractToolset[Any]): + @property + def id(self) -> str | None: + return None + async def __aenter__(self) -> MockMCPServer: return self diff --git a/tests/test_mcp.py b/tests/test_mcp.py index de77b3587..f39510d27 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -247,7 +247,7 @@ def get_none() -> None: # pragma: no cover with pytest.raises( UserError, match=re.escape( - "MCPServerStdio(command='python', args=['-m', 'tests.mcp_server'], tool_prefix=None) defines a tool whose name conflicts with existing tool from Function toolset: 'get_none'. Consider setting `tool_prefix` to avoid name conflicts." + "MCPServerStdio(command='python', args=['-m', 'tests.mcp_server']) defines a tool whose name conflicts with existing tool from FunctionToolset 'agent': 'get_none'. Set the `tool_prefix` attribute to avoid name conflicts." ), ): await agent.run('Get me a conflict') diff --git a/tests/test_tools.py b/tests/test_tools.py index e6a21a891..051c2310c 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1,4 +1,5 @@ import json +import re from dataclasses import dataclass, replace from typing import Annotated, Any, Callable, Literal, Union @@ -606,7 +607,9 @@ def test_tool_return_conflict(): # this raises an error with pytest.raises( UserError, - match="Function toolset defines a tool whose name conflicts with existing tool from Output toolset: 'ctx_tool'. Rename the tool or wrap the toolset in a `PrefixedToolset` to avoid name conflicts.", + match=re.escape( + "FunctionToolset 'agent' defines a tool whose name conflicts with existing tool from OutputToolset 'output': 'ctx_tool'. Rename the tool or wrap the toolset in a `PrefixedToolset` to avoid name conflicts." + ), ): Agent('test', tools=[ctx_tool], deps_type=int, output_type=ToolOutput(int, name='ctx_tool')).run_sync( '', deps=0 @@ -616,7 +619,9 @@ def test_tool_return_conflict(): def test_tool_name_conflict_hint(): with pytest.raises( UserError, - match="Prefixed toolset defines a tool whose name conflicts with existing tool from Function toolset: 'foo_tool'. Rename the tool or wrap the toolset in a `PrefixedToolset` to avoid name conflicts.", + match=re.escape( + "PrefixedToolset(FunctionToolset 'tool') defines a tool whose name conflicts with existing tool from FunctionToolset 'agent': 'foo_tool'. Change the `prefix` attribute to avoid name conflicts." + ), ): def tool(x: int) -> int: @@ -625,7 +630,7 @@ def tool(x: int) -> int: def foo_tool(x: str) -> str: return x + 'foo' # pragma: no cover - function_toolset = FunctionToolset([tool]) + function_toolset = FunctionToolset([tool], id='tool') prefixed_toolset = PrefixedToolset(function_toolset, 'foo') Agent('test', tools=[foo_tool], toolsets=[prefixed_toolset]).run_sync('') diff --git a/uv.lock b/uv.lock index e4d40ca65..f33e27a01 100644 --- a/uv.lock +++ b/uv.lock @@ -2286,6 +2286,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2a/e2/5d3f6ada4297caebe1a2add3b126fe800c96f56dbe5d1988a2cbe0b267aa/mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d", size = 4695, upload-time = "2023-02-04T12:11:25.002Z" }, ] +[[package]] +name = "nexus-rpc" +version = "1.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ef/66/540687556bd28cf1ec370cc6881456203dfddb9dab047b8979c6865b5984/nexus_rpc-1.1.0.tar.gz", hash = "sha256:d65ad6a2f54f14e53ebe39ee30555eaeb894102437125733fb13034a04a44553", size = 77383, upload-time = "2025-07-07T19:03:58.368Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bf/2f/9e9d0dcaa4c6ffa22b7aa31069a8a264c753ff8027b36af602cce038c92f/nexus_rpc-1.1.0-py3-none-any.whl", hash = "sha256:d1b007af2aba186a27e736f8eaae39c03aed05b488084ff6c3d1785c9ba2ad38", size = 27743, upload-time = "2025-07-07T19:03:57.556Z" }, +] + [[package]] name = "nodeenv" version = "1.9.1" @@ -3021,7 +3033,7 @@ wheels = [ name = "pydantic-ai" source = { editable = "." } dependencies = [ - { name = "pydantic-ai-slim", extra = ["ag-ui", "anthropic", "bedrock", "cli", "cohere", "evals", "google", "groq", "huggingface", "mcp", "mistral", "openai", "vertexai"] }, + { name = "pydantic-ai-slim", extra = ["ag-ui", "anthropic", "bedrock", "cli", "cohere", "evals", "google", "groq", "huggingface", "mcp", "mistral", "openai", "temporal", "vertexai"] }, ] [package.optional-dependencies] @@ -3060,7 +3072,7 @@ requires-dist = [ { name = "fasta2a", marker = "extra == 'a2a'", specifier = ">=0.4.1" }, { name = "logfire", marker = "extra == 'logfire'", specifier = ">=3.11.0" }, { name = "pydantic-ai-examples", marker = "extra == 'examples'", editable = "examples" }, - { name = "pydantic-ai-slim", extras = ["ag-ui", "anthropic", "bedrock", "cli", "cohere", "evals", "google", "groq", "huggingface", "mcp", "mistral", "openai", "vertexai"], editable = "pydantic_ai_slim" }, + { name = "pydantic-ai-slim", extras = ["ag-ui", "anthropic", "bedrock", "cli", "cohere", "evals", "google", "groq", "huggingface", "mcp", "mistral", "openai", "temporal", "vertexai"], editable = "pydantic_ai_slim" }, ] provides-extras = ["a2a", "examples", "logfire"] @@ -3184,6 +3196,9 @@ openai = [ tavily = [ { name = "tavily-python" }, ] +temporal = [ + { name = "temporalio" }, +] vertexai = [ { name = "google-auth" }, { name = "requests" }, @@ -3240,9 +3255,10 @@ requires-dist = [ { name = "rich", marker = "extra == 'cli'", specifier = ">=13" }, { name = "starlette", marker = "extra == 'ag-ui'", specifier = ">=0.45.3" }, { name = "tavily-python", marker = "extra == 'tavily'", specifier = ">=0.5.0" }, + { name = "temporalio", marker = "extra == 'temporal'", specifier = ">=1.15.0" }, { name = "typing-inspection", specifier = ">=0.4.0" }, ] -provides-extras = ["a2a", "ag-ui", "anthropic", "bedrock", "cli", "cohere", "duckduckgo", "evals", "google", "groq", "huggingface", "logfire", "mcp", "mistral", "openai", "tavily", "vertexai"] +provides-extras = ["a2a", "ag-ui", "anthropic", "bedrock", "cli", "cohere", "duckduckgo", "evals", "google", "groq", "huggingface", "logfire", "mcp", "mistral", "openai", "tavily", "temporal", "vertexai"] [package.metadata.requires-dev] dev = [ @@ -4154,6 +4170,26 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a5/cd/71088461d7720128c78802289b3b36298f42745e5f8c334b0ffc157b881e/tavily_python-0.5.1-py3-none-any.whl", hash = "sha256:169601f703c55cf338758dcacfa7102473b479a9271d65a3af6fc3668990f757", size = 43767, upload-time = "2025-02-07T00:22:04.99Z" }, ] +[[package]] +name = "temporalio" +version = "1.15.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nexus-rpc" }, + { name = "protobuf" }, + { name = "python-dateutil", marker = "python_full_version < '3.11'" }, + { name = "types-protobuf" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0b/af/1a3619fc62333d0acbdf90cfc5ada97e68e8c0f79610363b2dbb30871d83/temporalio-1.15.0.tar.gz", hash = "sha256:a4bc6ca01717880112caab75d041713aacc8263dc66e41f5019caef68b344fa0", size = 1684485, upload-time = "2025-07-29T03:44:09.071Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0e/2d/0153f2bc459e0cb59d41d4dd71da46bf9a98ca98bc37237576c258d6696b/temporalio-1.15.0-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:74bc5cc0e6bdc161a43015538b0821b8713f5faa716c4209971c274b528e0d47", size = 12703607, upload-time = "2025-07-29T03:43:30.083Z" }, + { url = "https://files.pythonhosted.org/packages/e4/39/1b867ec698c8987aef3b7a7024b5c0c732841112fa88d021303d0fc69bea/temporalio-1.15.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:ee8001304dae5723d79797516cfeebe04b966fdbdf348e658fce3b43afdda3cd", size = 12232853, upload-time = "2025-07-29T03:43:38.909Z" }, + { url = "https://files.pythonhosted.org/packages/5e/3e/647d9a7c8b2f638f639717404c0bcbdd7d54fddd7844fdb802e3f40dc55f/temporalio-1.15.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8febd1ac36720817e69c2176aa4aca14a97fe0b83f0d2449c0c730b8f0174d02", size = 12636700, upload-time = "2025-07-29T03:43:49.066Z" }, + { url = "https://files.pythonhosted.org/packages/9a/13/7aa9ec694fec9fba39efdbf61d892bccf7d2b1aa3d9bd359544534c1d309/temporalio-1.15.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:202d81a42cafaed9ccc7ccbea0898838e3b8bf92fee65394f8790f37eafbaa63", size = 12860186, upload-time = "2025-07-29T03:43:57.644Z" }, + { url = "https://files.pythonhosted.org/packages/9f/2b/ba962401324892236148046dbffd805d4443d6df7a7dc33cc7964b566bf9/temporalio-1.15.0-cp39-abi3-win_amd64.whl", hash = "sha256:aae5b18d7c9960238af0f3ebf6b7e5959e05f452106fc0d21a8278d78724f780", size = 12932800, upload-time = "2025-07-29T03:44:06.271Z" }, +] + [[package]] name = "tenacity" version = "8.5.0" @@ -4375,6 +4411,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b5/63/2463d89481e811f007b0e1cd0a91e52e141b47f9de724d20db7b861dcfec/types_certifi-2021.10.8.3-py3-none-any.whl", hash = "sha256:b2d1e325e69f71f7c78e5943d410e650b4707bb0ef32e4ddf3da37f54176e88a", size = 2136, upload-time = "2022-06-09T15:19:03.127Z" }, ] +[[package]] +name = "types-protobuf" +version = "6.30.2.20250516" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ac/6c/5cf088aaa3927d1cc39910f60f220f5ff573ab1a6485b2836e8b26beb58c/types_protobuf-6.30.2.20250516.tar.gz", hash = "sha256:aecd1881770a9bb225ede66872ef7f0da4505edd0b193108edd9892e48d49a41", size = 62254, upload-time = "2025-05-16T03:06:50.794Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c0/66/06a9c161f5dd5deb4f5c016ba29106a8f1903eb9a1ba77d407dd6588fecb/types_protobuf-6.30.2.20250516-py3-none-any.whl", hash = "sha256:8c226d05b5e8b2623111765fa32d6e648bbc24832b4c2fddf0fa340ba5d5b722", size = 76480, upload-time = "2025-05-16T03:06:49.444Z" }, +] + [[package]] name = "types-requests" version = "2.31.0.6"