From 6263510c6986126751952f41c7baacb70b04eefd Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Mon, 21 Jul 2025 21:44:24 +0000 Subject: [PATCH 01/41] Add optional `id` field to toolsets --- docs/tools.md | 3 +- docs/toolsets.md | 15 +- pydantic_ai_slim/pydantic_ai/_output.py | 4 + pydantic_ai_slim/pydantic_ai/ag_ui.py | 3 +- pydantic_ai_slim/pydantic_ai/agent.py | 2 +- pydantic_ai_slim/pydantic_ai/ext/aci.py | 6 +- pydantic_ai_slim/pydantic_ai/ext/langchain.py | 4 +- pydantic_ai_slim/pydantic_ai/mcp.py | 154 ++++++++++++++++-- .../pydantic_ai/toolsets/abstract.py | 18 +- .../pydantic_ai/toolsets/combined.py | 10 +- .../pydantic_ai/toolsets/deferred.py | 13 +- .../pydantic_ai/toolsets/function.py | 14 +- .../pydantic_ai/toolsets/prefixed.py | 4 + .../pydantic_ai/toolsets/wrapper.py | 8 + tests/test_examples.py | 4 + tests/test_mcp.py | 2 +- tests/test_tools.py | 11 +- 17 files changed, 243 insertions(+), 32 deletions(-) diff --git a/docs/tools.md b/docs/tools.md index 4b40e7881..eddffb703 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -770,7 +770,7 @@ from pydantic_ai.ext.langchain import LangChainToolset toolkit = SlackToolkit() -toolset = LangChainToolset(toolkit.get_tools()) +toolset = LangChainToolset(toolkit.get_tools(), id='slack') agent = Agent('openai:gpt-4o', toolsets=[toolset]) # ... @@ -823,6 +823,7 @@ toolset = ACIToolset( 'OPEN_WEATHER_MAP__FORECAST', ], linked_account_owner_id=os.getenv('LINKED_ACCOUNT_OWNER_ID'), + id='open_weather_map', ) agent = Agent('openai:gpt-4o', toolsets=[toolset]) diff --git a/docs/toolsets.md b/docs/toolsets.md index 5caac22c0..add7009f9 100644 --- a/docs/toolsets.md +++ b/docs/toolsets.md @@ -84,7 +84,10 @@ def temperature_fahrenheit(city: str) -> float: return 69.8 -weather_toolset = FunctionToolset(tools=[temperature_celsius, temperature_fahrenheit]) +weather_toolset = FunctionToolset( + tools=[temperature_celsius, temperature_fahrenheit], + id='weather', # (1)! +) @weather_toolset.tool @@ -95,10 +98,10 @@ def conditions(ctx: RunContext, city: str) -> str: return "It's raining" -datetime_toolset = FunctionToolset() +datetime_toolset = FunctionToolset(id='datetime') datetime_toolset.add_function(lambda: datetime.now(), name='now') -test_model = TestModel() # (1)! +test_model = TestModel() # (2)! agent = Agent(test_model) result = agent.run_sync('What tools are available?', toolsets=[weather_toolset]) @@ -110,7 +113,8 @@ print([t.name for t in test_model.last_model_request_parameters.function_tools]) #> ['now'] ``` -1. We're using [`TestModel`][pydantic_ai.models.test.TestModel] here because it makes it easy to see which tools were available on each run. +1. `FunctionToolset` supports an optional `id` argument that can help to identify the toolset in error messages. A toolset also needs to have an ID in order to be used in a durable execution environment like Temporal, in which case the ID will be used to identify the toolset's activities within the workflow. +2. We're using [`TestModel`][pydantic_ai.models.test.TestModel] here because it makes it easy to see which tools were available on each run. _(This example is complete, it can be run "as is")_ @@ -609,7 +613,7 @@ from pydantic_ai.ext.langchain import LangChainToolset toolkit = SlackToolkit() -toolset = LangChainToolset(toolkit.get_tools()) +toolset = LangChainToolset(toolkit.get_tools(), id='slack') agent = Agent('openai:gpt-4o', toolsets=[toolset]) # ... @@ -634,6 +638,7 @@ toolset = ACIToolset( 'OPEN_WEATHER_MAP__FORECAST', ], linked_account_owner_id=os.getenv('LINKED_ACCOUNT_OWNER_ID'), + id='open_weather_map', ) agent = Agent('openai:gpt-4o', toolsets=[toolset]) diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 67849b0c7..71e7251fd 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -977,6 +977,10 @@ def __init__( self.max_retries = max_retries self.output_validators = output_validators or [] + @property + def id(self) -> str | None: + return 'output' + async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]: return { tool_def.name: ToolsetTool( diff --git a/pydantic_ai_slim/pydantic_ai/ag_ui.py b/pydantic_ai_slim/pydantic_ai/ag_ui.py index a43d8bda4..13201abb0 100644 --- a/pydantic_ai_slim/pydantic_ai/ag_ui.py +++ b/pydantic_ai_slim/pydantic_ai/ag_ui.py @@ -273,7 +273,8 @@ async def run( parameters_json_schema=tool.parameters, ) for tool in run_input.tools - ] + ], + id='ag_ui_frontend', ) toolsets = [*toolsets, toolset] if toolsets else [toolset] diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 9f6348c6c..c3deb779c 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -420,7 +420,7 @@ def __init__( if self._output_toolset: self._output_toolset.max_retries = self._max_result_retries - self._function_toolset = FunctionToolset(tools, max_retries=retries) + self._function_toolset = FunctionToolset(tools, max_retries=retries, id='agent') self._user_toolsets = toolsets or () self.history_processors = history_processors or [] diff --git a/pydantic_ai_slim/pydantic_ai/ext/aci.py b/pydantic_ai_slim/pydantic_ai/ext/aci.py index 6cd43402a..ef686d134 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/aci.py +++ b/pydantic_ai_slim/pydantic_ai/ext/aci.py @@ -71,5 +71,7 @@ def implementation(*args: Any, **kwargs: Any) -> str: class ACIToolset(FunctionToolset): """A toolset that wraps ACI.dev tools.""" - def __init__(self, aci_functions: Sequence[str], linked_account_owner_id: str): - super().__init__([tool_from_aci(aci_function, linked_account_owner_id) for aci_function in aci_functions]) + def __init__(self, aci_functions: Sequence[str], linked_account_owner_id: str, id: str | None = None): + super().__init__( + [tool_from_aci(aci_function, linked_account_owner_id) for aci_function in aci_functions], id=id + ) diff --git a/pydantic_ai_slim/pydantic_ai/ext/langchain.py b/pydantic_ai_slim/pydantic_ai/ext/langchain.py index 3fb407938..3782c0b9d 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/langchain.py +++ b/pydantic_ai_slim/pydantic_ai/ext/langchain.py @@ -65,5 +65,5 @@ def proxy(*args: Any, **kwargs: Any) -> str: class LangChainToolset(FunctionToolset): """A toolset that wraps LangChain tools.""" - def __init__(self, tools: list[LangChainTool]): - super().__init__([tool_from_langchain(tool) for tool in tools]) + def __init__(self, tools: list[LangChainTool], id: str | None = None): + super().__init__([tool_from_langchain(tool) for tool in tools], id=id) diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index 2ca7950b3..efdf2eb40 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -61,10 +61,12 @@ class MCPServer(AbstractToolset[Any], ABC): timeout: float = 5 process_tool_call: ProcessToolCallback | None = None allow_sampling: bool = True - max_retries: int = 1 sampling_model: models.Model | None = None + max_retries: int = 1 # } end of "abstract fields" + _id: str | None = field(init=False, default=None) + _enter_lock: Lock = field(compare=False) _running_count: int _exit_stack: AsyncExitStack | None @@ -73,7 +75,29 @@ class MCPServer(AbstractToolset[Any], ABC): _read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] _write_stream: MemoryObjectSendStream[SessionMessage] - def __post_init__(self): + def __init__( + self, + tool_prefix: str | None = None, + log_level: mcp_types.LoggingLevel | None = None, + log_handler: LoggingFnT | None = None, + timeout: float = 5, + process_tool_call: ProcessToolCallback | None = None, + allow_sampling: bool = True, + sampling_model: models.Model | None = None, + max_retries: int = 1, + id: str | None = None, + ): + self.tool_prefix = tool_prefix + self.log_level = log_level + self.log_handler = log_handler + self.timeout = timeout + self.process_tool_call = process_tool_call + self.allow_sampling = allow_sampling + self.sampling_model = sampling_model + self.max_retries = max_retries + + self._id = id or tool_prefix + self._enter_lock = Lock() self._running_count = 0 self._exit_stack = None @@ -93,7 +117,11 @@ async def client_streams( yield @property - def name(self) -> str: + def id(self) -> str | None: + return self._id + + @property + def label(self) -> str: return repr(self) @property @@ -294,7 +322,7 @@ def _map_tool_result_part( assert_never(part) -@dataclass +@dataclass(init=False) class MCPServerStdio(MCPServer): """Runs an MCP server in a subprocess and communicates with it over stdin/stdout. @@ -378,11 +406,61 @@ async def main(): allow_sampling: bool = True """Whether to allow MCP sampling through this client.""" + sampling_model: models.Model | None = None + """The model to use for sampling.""" + max_retries: int = 1 """The maximum number of times to retry a tool call.""" - sampling_model: models.Model | None = None - """The model to use for sampling.""" + def __init__( + self, + command: str, + args: Sequence[str], + env: dict[str, str] | None = None, + cwd: str | Path | None = None, + id: str | None = None, + tool_prefix: str | None = None, + log_level: mcp_types.LoggingLevel | None = None, + log_handler: LoggingFnT | None = None, + timeout: float = 5, + process_tool_call: ProcessToolCallback | None = None, + allow_sampling: bool = True, + sampling_model: models.Model | None = None, + max_retries: int = 1, + ): + """Build a new MCP server. + + Args: + command: The command to run. + args: The arguments to pass to the command. + env: The environment variables to set in the subprocess. + cwd: The working directory to use when spawning the process. + id: An optional unique ID for the MCP server. An MCP server needs to have an ID in order to be used in a durable execution environment like Temporal, in which case the ID will be used to identify the server's activities within the workflow. + tool_prefix: A prefix to add to all tools that are registered with the server. + log_level: The log level to set when connecting to the server, if any. + log_handler: A handler for logging messages from the server. + timeout: The timeout in seconds to wait for the client to initialize. + process_tool_call: Hook to customize tool calling and optionally pass extra metadata. + allow_sampling: Whether to allow MCP sampling through this client. + sampling_model: The model to use for sampling. + max_retries: The maximum number of times to retry a tool call. + """ + self.command = command + self.args = args + self.env = env + self.cwd = cwd + + super().__init__( + tool_prefix, + log_level, + log_handler, + timeout, + process_tool_call, + allow_sampling, + sampling_model, + max_retries, + id, + ) @asynccontextmanager async def client_streams( @@ -398,7 +476,10 @@ async def client_streams( yield read_stream, write_stream def __repr__(self) -> str: - return f'MCPServerStdio(command={self.command!r}, args={self.args!r}, tool_prefix={self.tool_prefix!r})' + if self.id: + return f'{self.__class__.__name__} {self.id!r}' + else: + return f'{self.__class__.__name__}(command={self.command!r}, args={self.args!r})' @dataclass @@ -479,11 +560,61 @@ class _MCPServerHTTP(MCPServer): allow_sampling: bool = True """Whether to allow MCP sampling through this client.""" + sampling_model: models.Model | None = None + """The model to use for sampling.""" + max_retries: int = 1 """The maximum number of times to retry a tool call.""" - sampling_model: models.Model | None = None - """The model to use for sampling.""" + def __init__( + self, + url: str, + headers: dict[str, Any] | None = None, + http_client: httpx.AsyncClient | None = None, + sse_read_timeout: float = 5 * 60, + id: str | None = None, + tool_prefix: str | None = None, + log_level: mcp_types.LoggingLevel | None = None, + log_handler: LoggingFnT | None = None, + timeout: float = 5, + process_tool_call: ProcessToolCallback | None = None, + allow_sampling: bool = True, + sampling_model: models.Model | None = None, + max_retries: int = 1, + ): + """Build a new MCP server. + + Args: + url: The URL of the endpoint on the MCP server. + headers: Optional HTTP headers to be sent with each request to the endpoint. + http_client: An `httpx.AsyncClient` to use with the endpoint. + sse_read_timeout: Maximum time in seconds to wait for new SSE messages before timing out. + id: An optional unique ID for the MCP server. An MCP server needs to have an ID in order to be used in a durable execution environment like Temporal, in which case the ID will be used to identify the server's activities within the workflow. + tool_prefix: A prefix to add to all tools that are registered with the server. + log_level: The log level to set when connecting to the server, if any. + log_handler: A handler for logging messages from the server. + timeout: The timeout in seconds to wait for the client to initialize. + process_tool_call: Hook to customize tool calling and optionally pass extra metadata. + allow_sampling: Whether to allow MCP sampling through this client. + sampling_model: The model to use for sampling. + max_retries: The maximum number of times to retry a tool call. + """ + self.url = url + self.headers = headers + self.http_client = http_client + self.sse_read_timeout = sse_read_timeout + + super().__init__( + tool_prefix, + log_level, + log_handler, + timeout, + process_tool_call, + allow_sampling, + sampling_model, + max_retries, + id, + ) @property @abstractmethod @@ -546,7 +677,10 @@ def httpx_client_factory( yield read_stream, write_stream def __repr__(self) -> str: # pragma: no cover - return f'{self.__class__.__name__}(url={self.url!r}, tool_prefix={self.tool_prefix!r})' + if self.id: + return f'{self.__class__.__name__} {self.id!r}' + else: + return f'{self.__class__.__name__}(url={self.url!r})' @dataclass diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/abstract.py b/pydantic_ai_slim/pydantic_ai/toolsets/abstract.py index 455336418..d73119e58 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/abstract.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/abstract.py @@ -70,9 +70,23 @@ class AbstractToolset(ABC, Generic[AgentDepsT]): """ @property - def name(self) -> str: + @abstractmethod + def id(self) -> str | None: + """An ID for the toolset that is unique among all toolsets registered with the same agent. + + If you're implementing a concrete implementation that users can instantiate more than once, you should let them optionally pass a custom ID to the constructor and return that here. + + A toolset needs to have an ID in order to be used in a durable execution environment like Temporal, in which case the ID will be used to identify the toolset's activities within the workflow. + """ + raise NotImplementedError() + + @property + def label(self) -> str: """The name of the toolset for use in error messages.""" - return self.__class__.__name__.replace('Toolset', ' toolset') + label = self.__class__.__name__ + if self.id: + label += f' {self.id!r}' + return label @property def tool_name_conflict_hint(self) -> str: diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/combined.py b/pydantic_ai_slim/pydantic_ai/toolsets/combined.py index 4b1511fae..750c54b8e 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/combined.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/combined.py @@ -40,6 +40,14 @@ def __post_init__(self): self._entered_count = 0 self._exit_stack = None + @property + def id(self) -> str | None: + return None + + @property + def label(self) -> str: + return f'{self.__class__.__name__}({", ".join(toolset.label for toolset in self.toolsets)})' + async def __aenter__(self) -> Self: async with self._enter_lock: if self._entered_count == 0: @@ -64,7 +72,7 @@ async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[ for name, tool in tools.items(): if existing_tools := all_tools.get(name): raise UserError( - f'{toolset.name} defines a tool whose name conflicts with existing tool from {existing_tools.toolset.name}: {name!r}. {toolset.tool_name_conflict_hint}' + f'{toolset.label} defines a tool whose name conflicts with existing tool from {existing_tools.toolset.label}: {name!r}. {toolset.tool_name_conflict_hint}' ) all_tools[name] = _CombinedToolsetTool( diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/deferred.py b/pydantic_ai_slim/pydantic_ai/toolsets/deferred.py index 3ad2e976b..a67c3b0ad 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/deferred.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/deferred.py @@ -1,6 +1,6 @@ from __future__ import annotations -from dataclasses import dataclass, replace +from dataclasses import dataclass, field, replace from typing import Any from pydantic_core import SchemaValidator, core_schema @@ -12,7 +12,7 @@ TOOL_SCHEMA_VALIDATOR = SchemaValidator(schema=core_schema.any_schema()) -@dataclass +@dataclass(init=False) class DeferredToolset(AbstractToolset[AgentDepsT]): """A toolset that holds deferred tools whose results will be produced outside of the Pydantic AI agent run in which they were called. @@ -20,6 +20,15 @@ class DeferredToolset(AbstractToolset[AgentDepsT]): """ tool_defs: list[ToolDefinition] + _id: str | None = field(init=False, default=None) + + def __init__(self, tool_defs: list[ToolDefinition], id: str | None = None): + self._id = id + self.tool_defs = tool_defs + + @property + def id(self) -> str | None: + return self._id async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]: return { diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/function.py b/pydantic_ai_slim/pydantic_ai/toolsets/function.py index 63f44a1f0..81c667a9e 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/function.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/function.py @@ -35,14 +35,22 @@ class FunctionToolset(AbstractToolset[AgentDepsT]): max_retries: int = field(default=1) tools: dict[str, Tool[Any]] = field(default_factory=dict) + _id: str | None = field(init=False, default=None) - def __init__(self, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = [], max_retries: int = 1): + def __init__( + self, + tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = [], + max_retries: int = 1, + id: str | None = None, + ): """Build a new function toolset. Args: tools: The tools to add to the toolset. max_retries: The maximum number of retries for each tool during a run. + id: An optional unique ID for the toolset. A toolset needs to have an ID in order to be used in a durable execution environment like Temporal, in which case the ID will be used to identify the toolset's activities within the workflow. """ + self._id = id self.max_retries = max_retries self.tools = {} for tool in tools: @@ -51,6 +59,10 @@ def __init__(self, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, else: self.add_function(tool) + @property + def id(self) -> str | None: + return self._id + @overload def tool(self, func: ToolFuncEither[AgentDepsT, ToolParams], /) -> ToolFuncEither[AgentDepsT, ToolParams]: ... diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/prefixed.py b/pydantic_ai_slim/pydantic_ai/toolsets/prefixed.py index be70ed4f0..a430b0dab 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/prefixed.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/prefixed.py @@ -17,6 +17,10 @@ class PrefixedToolset(WrapperToolset[AgentDepsT]): prefix: str + @property + def tool_name_conflict_hint(self) -> str: + return 'Change the `prefix` attribute to avoid name conflicts.' + async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]: return { new_name: replace( diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py b/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py index 8440f1c46..6d5a409a5 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py @@ -18,6 +18,14 @@ class WrapperToolset(AbstractToolset[AgentDepsT]): wrapped: AbstractToolset[AgentDepsT] + @property + def id(self) -> str | None: + return None + + @property + def label(self) -> str: + return f'{self.__class__.__name__}({self.wrapped.label})' + async def __aenter__(self) -> Self: await self.wrapped.__aenter__() return self diff --git a/tests/test_examples.py b/tests/test_examples.py index 4b6bc27bc..380c56d66 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -263,6 +263,10 @@ def rich_prompt_ask(prompt: str, *_args: Any, **_kwargs: Any) -> str: class MockMCPServer(AbstractToolset[Any]): + @property + def id(self) -> str | None: + return None + async def __aenter__(self) -> MockMCPServer: return self diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 94528b40a..b5ad7e1ae 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -237,7 +237,7 @@ def get_none() -> None: # pragma: no cover with pytest.raises( UserError, match=re.escape( - "MCPServerStdio(command='python', args=['-m', 'tests.mcp_server'], tool_prefix=None) defines a tool whose name conflicts with existing tool from Function toolset: 'get_none'. Consider setting `tool_prefix` to avoid name conflicts." + "MCPServerStdio(command='python', args=['-m', 'tests.mcp_server']) defines a tool whose name conflicts with existing tool from FunctionToolset 'agent': 'get_none'. Set the `tool_prefix` attribute to avoid name conflicts." ), ): await agent.run('Get me a conflict') diff --git a/tests/test_tools.py b/tests/test_tools.py index c72cd1e08..e5dac6d3f 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1,4 +1,5 @@ import json +import re from dataclasses import dataclass, replace from typing import Annotated, Any, Callable, Literal, Union @@ -586,7 +587,9 @@ def test_tool_return_conflict(): # this raises an error with pytest.raises( UserError, - match="Function toolset defines a tool whose name conflicts with existing tool from Output toolset: 'ctx_tool'. Rename the tool or wrap the toolset in a `PrefixedToolset` to avoid name conflicts.", + match=re.escape( + "FunctionToolset 'agent' defines a tool whose name conflicts with existing tool from OutputToolset 'output': 'ctx_tool'. Rename the tool or wrap the toolset in a `PrefixedToolset` to avoid name conflicts." + ), ): Agent('test', tools=[ctx_tool], deps_type=int, output_type=ToolOutput(int, name='ctx_tool')).run_sync( '', deps=0 @@ -596,7 +599,9 @@ def test_tool_return_conflict(): def test_tool_name_conflict_hint(): with pytest.raises( UserError, - match="Prefixed toolset defines a tool whose name conflicts with existing tool from Function toolset: 'foo_tool'. Rename the tool or wrap the toolset in a `PrefixedToolset` to avoid name conflicts.", + match=re.escape( + "PrefixedToolset(FunctionToolset 'tool') defines a tool whose name conflicts with existing tool from FunctionToolset 'agent': 'foo_tool'. Change the `prefix` attribute to avoid name conflicts." + ), ): def tool(x: int) -> int: @@ -605,7 +610,7 @@ def tool(x: int) -> int: def foo_tool(x: str) -> str: return x + 'foo' # pragma: no cover - function_toolset = FunctionToolset([tool]) + function_toolset = FunctionToolset([tool], id='tool') prefixed_toolset = PrefixedToolset(function_toolset, 'foo') Agent('test', tools=[foo_tool], toolsets=[prefixed_toolset]).run_sync('') From d02361a3a453dc784f61ae5680b22b006d828590 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Mon, 21 Jul 2025 22:05:36 +0000 Subject: [PATCH 02/41] WIP: temporalize_agent --- pydantic_ai_slim/pydantic_ai/agent.py | 8 ++ pydantic_ai_slim/pydantic_ai/mcp.py | 2 +- .../pydantic_ai/temporal/__init__.py | 104 +++++++++++++++ .../pydantic_ai/temporal/agent.py | 66 ++++++++++ .../pydantic_ai/temporal/function_toolset.py | 100 +++++++++++++++ .../pydantic_ai/temporal/mcp_server.py | 121 ++++++++++++++++++ .../pydantic_ai/temporal/model.py | 116 +++++++++++++++++ pydantic_ai_slim/pyproject.toml | 2 + pyproject.toml | 2 +- temporal.py | 99 ++++++++++++++ uv.lock | 38 +++++- 11 files changed, 653 insertions(+), 5 deletions(-) create mode 100644 pydantic_ai_slim/pydantic_ai/temporal/__init__.py create mode 100644 pydantic_ai_slim/pydantic_ai/temporal/agent.py create mode 100644 pydantic_ai_slim/pydantic_ai/temporal/function_toolset.py create mode 100644 pydantic_ai_slim/pydantic_ai/temporal/mcp_server.py create mode 100644 pydantic_ai_slim/pydantic_ai/temporal/model.py create mode 100644 temporal.py diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index c3deb779c..ad3012d60 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -1698,6 +1698,14 @@ def _get_toolset( return CombinedToolset(all_toolsets) + @property + def toolset(self) -> AbstractToolset[AgentDepsT]: + """The complete toolset that will be available to the model during an agent run. + + This will include function tools registered directly to the agent, output tools, and user-provided toolsets including MCP servers. + """ + return self._get_toolset() + def _infer_name(self, function_frame: FrameType | None) -> None: """Infer the agent name from the call frame. diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index efdf2eb40..2bbcb7e75 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -126,7 +126,7 @@ def label(self) -> str: @property def tool_name_conflict_hint(self) -> str: - return 'Consider setting `tool_prefix` to avoid name conflicts.' + return 'Set the `tool_prefix` attribute to avoid name conflicts.' async def list_tools(self) -> list[mcp_types.Tool]: """Retrieve tools that are currently active on the server. diff --git a/pydantic_ai_slim/pydantic_ai/temporal/__init__.py b/pydantic_ai_slim/pydantic_ai/temporal/__init__.py new file mode 100644 index 000000000..21f0ca90d --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/temporal/__init__.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +from dataclasses import dataclass +from datetime import timedelta +from typing import Any, Callable + +from temporalio.common import Priority, RetryPolicy +from temporalio.workflow import ActivityCancellationType, VersioningIntent + +from pydantic_ai._run_context import AgentDepsT, RunContext + + +class _TemporalRunContext(RunContext[AgentDepsT]): + _data: dict[str, Any] + + def __init__(self, **kwargs: Any): + self._data = kwargs + setattr( + self, + '__dataclass_fields__', + {name: field for name, field in RunContext.__dataclass_fields__.items() if name in kwargs}, + ) + + def __getattribute__(self, name: str) -> Any: + try: + return super().__getattribute__(name) + except AttributeError as e: + data = super().__getattribute__('_data') + if name in data: + return data[name] + raise e # TODO: Explain how to make a new run context attribute available + + @classmethod + def serialize_run_context(cls, ctx: RunContext[AgentDepsT]) -> dict[str, Any]: + return { + 'deps': ctx.deps, + 'retries': ctx.retries, + 'tool_call_id': ctx.tool_call_id, + 'tool_name': ctx.tool_name, + 'retry': ctx.retry, + 'run_step': ctx.run_step, + } + + @classmethod + def deserialize_run_context(cls, ctx: dict[str, Any]) -> RunContext[AgentDepsT]: + return cls(**ctx) + + +@dataclass +class TemporalSettings: + """Settings for Temporal `execute_activity` and Pydantic AI-specific Temporal activity behavior.""" + + # Temporal settings + task_queue: str | None = None + schedule_to_close_timeout: timedelta | None = None + schedule_to_start_timeout: timedelta | None = None + start_to_close_timeout: timedelta | None = None + heartbeat_timeout: timedelta | None = None + retry_policy: RetryPolicy | None = None + cancellation_type: ActivityCancellationType = ActivityCancellationType.TRY_CANCEL + activity_id: str | None = None + versioning_intent: VersioningIntent | None = None + summary: str | None = None + priority: Priority = Priority.default + + # Pydantic AI specific + tool_settings: dict[str, dict[str, TemporalSettings]] | None = None + + def for_tool(self, toolset_id: str, tool_id: str) -> TemporalSettings: + if self.tool_settings is None: + return self + return self.tool_settings.get(toolset_id, {}).get(tool_id, self) + + serialize_run_context: Callable[[RunContext], Any] = _TemporalRunContext.serialize_run_context + deserialize_run_context: Callable[[dict[str, Any]], RunContext] = _TemporalRunContext.deserialize_run_context + + @property + def execute_activity_kwargs(self) -> dict[str, Any]: + return { + 'task_queue': self.task_queue, + 'schedule_to_close_timeout': self.schedule_to_close_timeout, + 'schedule_to_start_timeout': self.schedule_to_start_timeout, + 'start_to_close_timeout': self.start_to_close_timeout, + 'heartbeat_timeout': self.heartbeat_timeout, + 'retry_policy': self.retry_policy, + 'cancellation_type': self.cancellation_type, + 'activity_id': self.activity_id, + 'versioning_intent': self.versioning_intent, + 'summary': self.summary, + 'priority': self.priority, + } + + +def initialize_temporal(): + """Explicitly import types without which Temporal will not be able to serialize/deserialize `ModelMessage`s.""" + from pydantic_ai.messages import ( # noqa F401 + ModelResponse, # pyright: ignore[reportUnusedImport] + ImageUrl, # pyright: ignore[reportUnusedImport] + AudioUrl, # pyright: ignore[reportUnusedImport] + DocumentUrl, # pyright: ignore[reportUnusedImport] + VideoUrl, # pyright: ignore[reportUnusedImport] + BinaryContent, # pyright: ignore[reportUnusedImport] + UserContent, # pyright: ignore[reportUnusedImport] + ) diff --git a/pydantic_ai_slim/pydantic_ai/temporal/agent.py b/pydantic_ai_slim/pydantic_ai/temporal/agent.py new file mode 100644 index 000000000..b8a0fc538 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/temporal/agent.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from typing import Any, Callable + +from pydantic_ai.agent import Agent +from pydantic_ai.mcp import MCPServer +from pydantic_ai.toolsets.abstract import AbstractToolset +from pydantic_ai.toolsets.function import FunctionToolset + +from ..models import Model +from . import TemporalSettings +from .function_toolset import temporalize_function_toolset +from .mcp_server import temporalize_mcp_server +from .model import temporalize_model + + +def temporalize_toolset(toolset: AbstractToolset, settings: TemporalSettings | None) -> list[Callable[..., Any]]: + """Temporalize a toolset. + + Args: + toolset: The toolset to temporalize. + settings: The temporal settings to use. + """ + if isinstance(toolset, FunctionToolset): + return temporalize_function_toolset(toolset, settings) + elif isinstance(toolset, MCPServer): + return temporalize_mcp_server(toolset, settings) + else: + return [] + + +def temporalize_agent( + agent: Agent, + settings: TemporalSettings | None = None, + temporalize_toolset_func: Callable[ + [AbstractToolset, TemporalSettings | None], list[Callable[..., Any]] + ] = temporalize_toolset, +) -> list[Callable[..., Any]]: + """Temporalize an agent. + + Args: + agent: The agent to temporalize. + settings: The temporal settings to use. + temporalize_toolset_func: The function to use to temporalize the toolsets. + """ + if existing_activities := getattr(agent, '__temporal_activities', None): + return existing_activities + + settings = settings or TemporalSettings() + + # TODO: Doesn't consider model/toolsets passed at iter time. + + activities: list[Callable[..., Any]] = [] + if isinstance(agent.model, Model): + activities.extend(temporalize_model(agent.model, settings)) + + def temporalize_toolset(toolset: AbstractToolset) -> None: + activities.extend(temporalize_toolset_func(toolset, settings)) + + agent.toolset.apply(temporalize_toolset) + + setattr(agent, '__temporal_activities', activities) + return activities + + +# TODO: untemporalize_agent diff --git a/pydantic_ai_slim/pydantic_ai/temporal/function_toolset.py b/pydantic_ai_slim/pydantic_ai/temporal/function_toolset.py new file mode 100644 index 000000000..842e86aaf --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/temporal/function_toolset.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable + +from pydantic import ConfigDict, with_config +from temporalio import activity, workflow + +from pydantic_ai.toolsets.function import FunctionToolset + +from .._run_context import RunContext +from ..toolsets import ToolsetTool +from . import TemporalSettings + + +@dataclass +@with_config(ConfigDict(arbitrary_types_allowed=True)) +class _CallToolParams: + name: str + tool_args: dict[str, Any] + serialized_run_context: Any + + +def temporalize_function_toolset( + toolset: FunctionToolset, + settings: TemporalSettings | None = None, +) -> list[Callable[..., Any]]: + """Temporalize a function toolset. + + Args: + toolset: The function toolset to temporalize. + settings: The temporal settings to use. + """ + if activities := getattr(toolset, '__temporal_activities', None): + return activities + + id = toolset.id + if not id: + raise ValueError( + "A function toolset needs to have an ID in order to be used in a durable execution environment like Temporal. The ID will be used to identify the toolset's activities within the workflow." + ) + + settings = settings or TemporalSettings() + + original_call_tool = toolset.call_tool + + @activity.defn(name=f'function_toolset__{id}__call_tool') + async def call_tool_activity(params: _CallToolParams) -> Any: + name = params.name + ctx = settings.for_tool(id, name).deserialize_run_context(params.serialized_run_context) + tool = (await toolset.get_tools(ctx))[name] + return await original_call_tool(name, params.tool_args, ctx, tool) + + async def call_tool(name: str, tool_args: dict[str, Any], ctx: RunContext, tool: ToolsetTool) -> Any: + tool_settings = settings.for_tool(id, name) + serialized_run_context = tool_settings.serialize_run_context(ctx) + return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] + activity=call_tool_activity, + arg=_CallToolParams(name=name, tool_args=tool_args, serialized_run_context=serialized_run_context), + **tool_settings.execute_activity_kwargs, + ) + + toolset.call_tool = call_tool + + activities = [call_tool_activity] + setattr(toolset, '__temporal_activities', activities) + return activities + + +# class TemporalFunctionToolset(FunctionToolset[AgentDepsT]): +# def __init__( +# self, +# tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = [], +# max_retries: int = 1, +# temporal_settings: TemporalSettings | None = None, +# serialize_run_context: Callable[[RunContext[AgentDepsT]], Any] | None = None, +# deserialize_run_context: Callable[[Any], RunContext[AgentDepsT]] | None = None, +# ): +# super().__init__(tools, max_retries) +# self.temporal_settings = temporal_settings or TemporalSettings() +# self.serialize_run_context = serialize_run_context or TemporalRunContext[AgentDepsT].serialize_run_context +# self.deserialize_run_context = deserialize_run_context or TemporalRunContext[AgentDepsT].deserialize_run_context + +# @activity.defn(name='function_toolset_call_tool') +# async def call_tool_activity(params: FunctionCallToolParams) -> Any: +# ctx = self.deserialize_run_context(params.serialized_run_context) +# tool = (await self.get_tools(ctx))[params.name] +# return await FunctionToolset[AgentDepsT].call_tool(self, params.name, params.tool_args, ctx, tool) + +# self.call_tool_activity = call_tool_activity + +# async def call_tool( +# self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT] +# ) -> Any: +# serialized_run_context = self.serialize_run_context(ctx) +# return await workflow.execute_activity( +# activity=self.call_tool_activity, +# arg=FunctionCallToolParams(name=name, tool_args=tool_args, serialized_run_context=serialized_run_context), +# **self.temporal_settings.__dict__, +# ) diff --git a/pydantic_ai_slim/pydantic_ai/temporal/mcp_server.py b/pydantic_ai_slim/pydantic_ai/temporal/mcp_server.py new file mode 100644 index 000000000..8d3be7624 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/temporal/mcp_server.py @@ -0,0 +1,121 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable + +from mcp import types as mcp_types +from pydantic import ConfigDict, with_config +from temporalio import activity, workflow + +from pydantic_ai.mcp import MCPServer, ToolResult + +from . import TemporalSettings + + +@dataclass +@with_config(ConfigDict(arbitrary_types_allowed=True)) +class _CallToolParams: + name: str + tool_args: dict[str, Any] + metadata: dict[str, Any] | None = None + + +def temporalize_mcp_server( + server: MCPServer, + settings: TemporalSettings | None = None, +) -> list[Callable[..., Any]]: + """Temporalize an MCP server. + + Args: + server: The MCP server to temporalize. + settings: The temporal settings to use. + """ + if activities := getattr(server, '__temporal_activities', None): + return activities + + id = server.id + if not id: + raise ValueError( + "An MCP server needs to have an ID in order to be used in a durable execution environment like Temporal. The ID will be used to identify the server's activities within the workflow." + ) + + settings = settings or TemporalSettings() + + original_list_tools = server.list_tools + original_direct_call_tool = server.direct_call_tool + + @activity.defn(name=f'mcp_server__{id}__list_tools') + async def list_tools_activity() -> list[mcp_types.Tool]: + return await original_list_tools() + + @activity.defn(name=f'mcp_server__{id}__call_tool') + async def call_tool_activity(params: _CallToolParams) -> ToolResult: + return await original_direct_call_tool(params.name, params.tool_args, params.metadata) + + async def list_tools() -> list[mcp_types.Tool]: + return await workflow.execute_activity( # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType] + activity=list_tools_activity, + **settings.execute_activity_kwargs, + ) + + async def direct_call_tool( + name: str, + args: dict[str, Any], + metadata: dict[str, Any] | None = None, + ) -> ToolResult: + return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] + activity=call_tool_activity, + arg=_CallToolParams(name=name, tool_args=args, metadata=metadata), + **settings.for_tool(id, name).execute_activity_kwargs, + ) + + server.list_tools = list_tools + server.direct_call_tool = direct_call_tool + + activities = [list_tools_activity, call_tool_activity] + setattr(server, '__temporal_activities', activities) + return activities + + +# class TemporalMCPServer(WrapperToolset[Any]): +# temporal_settings: TemporalSettings + +# @property +# def wrapped_server(self) -> MCPServer: +# assert isinstance(self.wrapped, MCPServer) +# return self.wrapped + +# def __init__(self, wrapped: MCPServer, temporal_settings: TemporalSettings | None = None): +# assert isinstance(self.wrapped, MCPServer) +# super().__init__(wrapped) +# self.temporal_settings = temporal_settings or TemporalSettings() + +# @activity.defn(name='mcp_server_list_tools') +# async def list_tools_activity() -> list[mcp_types.Tool]: +# return await self.wrapped_server.list_tools() + +# self.list_tools_activity = list_tools_activity + +# @activity.defn(name='mcp_server_call_tool') +# async def call_tool_activity(params: MCPCallToolParams) -> ToolResult: +# return await self.wrapped_server.direct_call_tool(params.name, params.tool_args, params.metadata) + +# self.call_tool_activity = call_tool_activity + +# async def list_tools(self) -> list[mcp_types.Tool]: +# return await workflow.execute_activity( +# activity=self.list_tools_activity, +# **self.temporal_settings.__dict__, +# ) + +# async def direct_call_tool( +# self, +# name: str, +# args: dict[str, Any], +# metadata: dict[str, Any] | None = None, +# ) -> ToolResult: +# return await workflow.execute_activity( +# activity=self.call_tool_activity, +# arg=MCPCallToolParams(name=name, tool_args=args, metadata=metadata), +# **self.temporal_settings.__dict__, +# ) diff --git a/pydantic_ai_slim/pydantic_ai/temporal/model.py b/pydantic_ai_slim/pydantic_ai/temporal/model.py new file mode 100644 index 000000000..7f8a25e71 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/temporal/model.py @@ -0,0 +1,116 @@ +from __future__ import annotations + +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from dataclasses import dataclass +from typing import Any, Callable + +from pydantic import ConfigDict, with_config +from temporalio import activity, workflow + +from ..messages import ( + ModelMessage, + ModelResponse, +) +from ..models import Model, ModelRequestParameters, StreamedResponse +from ..settings import ModelSettings +from . import TemporalSettings + + +@dataclass +@with_config(ConfigDict(arbitrary_types_allowed=True)) +class _RequestParams: + messages: list[ModelMessage] + model_settings: ModelSettings | None + model_request_parameters: ModelRequestParameters + + +def temporalize_model(model: Model, settings: TemporalSettings | None = None) -> list[Callable[..., Any]]: + """Temporalize a model. + + Args: + model: The model to temporalize. + settings: The temporal settings to use. + """ + if activities := getattr(model, '__temporal_activities', None): + return activities + + settings = settings or TemporalSettings() + + original_request = model.request + + @activity.defn(name='model_request') + async def request_activity(params: _RequestParams) -> ModelResponse: + return await original_request(params.messages, params.model_settings, params.model_request_parameters) + + async def request( + messages: list[ModelMessage], + model_settings: ModelSettings | None, + model_request_parameters: ModelRequestParameters, + ) -> ModelResponse: + return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] + activity=request_activity, + arg=_RequestParams( + messages=messages, model_settings=model_settings, model_request_parameters=model_request_parameters + ), + **settings.execute_activity_kwargs, + ) + + @asynccontextmanager + async def request_stream( + messages: list[ModelMessage], + model_settings: ModelSettings | None, + model_request_parameters: ModelRequestParameters, + ) -> AsyncIterator[StreamedResponse]: + raise NotImplementedError('Cannot stream with temporal yet') + yield + + model.request = request + model.request_stream = request_stream + + activities = [request_activity] + setattr(model, '__temporal_activities', activities) + return activities + + +# @dataclass +# class TemporalModel(WrapperModel): +# temporal_settings: TemporalSettings + +# def __init__( +# self, +# wrapped: Model | KnownModelName, +# temporal_settings: TemporalSettings | None = None, +# ) -> None: +# super().__init__(wrapped) +# self.temporal_settings = temporal_settings or TemporalSettings() + +# @activity.defn +# async def request_activity(params: ModelRequestParams) -> ModelResponse: +# return await self.wrapped.request(params.messages, params.model_settings, params.model_request_parameters) + +# self.request_activity = request_activity + +# async def request( +# self, +# messages: list[ModelMessage], +# model_settings: ModelSettings | None, +# model_request_parameters: ModelRequestParameters, +# ) -> ModelResponse: +# return await workflow.execute_activity( +# activity=self.request_activity, +# arg=ModelRequestParams( +# messages=messages, model_settings=model_settings, model_request_parameters=model_request_parameters +# ), +# **self.temporal_settings.__dict__, +# ) + +# @asynccontextmanager +# async def request_stream( +# self, +# messages: list[ModelMessage], +# model_settings: ModelSettings | None, +# model_request_parameters: ModelRequestParameters, +# ) -> AsyncIterator[StreamedResponse]: +# raise NotImplementedError('Cannot stream with temporal yet') +# yield diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index 3dfc4a766..10d36d6b7 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -84,6 +84,8 @@ evals = ["pydantic-evals=={{ version }}"] a2a = ["fasta2a>=0.4.1"] # AG-UI ag-ui = ["ag-ui-protocol>=0.1.8", "starlette>=0.45.3"] +# Temporal +temporal = ["temporalio>=1.13.0"] [dependency-groups] dev = [ diff --git a/pyproject.toml b/pyproject.toml index 841f186ef..fdd67f183 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,7 @@ requires-python = ">=3.9" [tool.hatch.metadata.hooks.uv-dynamic-versioning] dependencies = [ - "pydantic-ai-slim[openai,vertexai,google,groq,anthropic,mistral,cohere,bedrock,huggingface,cli,mcp,evals,ag-ui]=={{ version }}", + "pydantic-ai-slim[openai,vertexai,google,groq,anthropic,mistral,cohere,bedrock,huggingface,cli,mcp,evals,ag-ui,temporal]=={{ version }}", ] [tool.hatch.metadata.hooks.uv-dynamic-versioning.optional-dependencies] diff --git a/temporal.py b/temporal.py new file mode 100644 index 000000000..340019b8c --- /dev/null +++ b/temporal.py @@ -0,0 +1,99 @@ +# /// script +# dependencies = [ +# "temporalio", +# "logfire", +# ] +# /// +import asyncio +import random +from datetime import timedelta + +from temporalio import workflow +from temporalio.client import Client +from temporalio.contrib.opentelemetry import TracingInterceptor +from temporalio.contrib.pydantic import pydantic_data_converter +from temporalio.runtime import OpenTelemetryConfig, Runtime, TelemetryConfig +from temporalio.worker import Worker + +with workflow.unsafe.imports_passed_through(): + from pydantic_ai import Agent + from pydantic_ai.mcp import MCPServerStdio + from pydantic_ai.models.openai import OpenAIModel + from pydantic_ai.temporal import ( + TemporalSettings, + initialize_temporal, + ) + from pydantic_ai.temporal.agent import temporalize_agent + from pydantic_ai.toolsets.function import FunctionToolset + + initialize_temporal() + + def get_uv_index(location: str) -> int: + return 3 + + toolset = FunctionToolset(tools=[get_uv_index], id='uv_index') + mcp_server = MCPServerStdio( + 'python', + ['-m', 'tests.mcp_server'], + timeout=20, + id='test', + ) + + model = OpenAIModel('gpt-4o') + my_agent = Agent(model=model, instructions='be helpful', toolsets=[toolset, mcp_server]) + + temporal_settings = TemporalSettings( + start_to_close_timeout=timedelta(seconds=60), + tool_settings={ + 'uv_index': { + 'get_uv_index': TemporalSettings(start_to_close_timeout=timedelta(seconds=110)), + }, + }, + ) + activities = temporalize_agent(my_agent, temporal_settings) + + +def init_runtime_with_telemetry() -> Runtime: + # import logfire + + # logfire.configure(send_to_logfire=True, service_version='0.0.1', console=False) + # logfire.instrument_pydantic_ai() + # logfire.instrument_httpx(capture_all=True) + + # Setup SDK metrics to OTel endpoint + return Runtime(telemetry=TelemetryConfig(metrics=OpenTelemetryConfig(url='http://localhost:4318'))) + + +# Basic workflow that logs and invokes an activity +@workflow.defn +class MyAgentWorkflow: + @workflow.run + async def run(self, prompt: str) -> str: + return (await my_agent.run(prompt)).output + + +async def main(): + client = await Client.connect( + 'localhost:7233', + interceptors=[TracingInterceptor()], + data_converter=pydantic_data_converter, + runtime=init_runtime_with_telemetry(), + ) + + async with Worker( + client, + task_queue='my-agent-task-queue', + workflows=[MyAgentWorkflow], + activities=activities, + ): + output = await client.execute_workflow( # pyright: ignore[reportUnknownMemberType] + MyAgentWorkflow.run, + 'what is 2 plus the UV Index in Mexico City? and what is the product name?', + id=f'my-agent-workflow-id-{random.random()}', + task_queue='my-agent-task-queue', + ) + print(output) + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/uv.lock b/uv.lock index a47bd1445..a44ee9b56 100644 --- a/uv.lock +++ b/uv.lock @@ -3000,7 +3000,7 @@ wheels = [ name = "pydantic-ai" source = { editable = "." } dependencies = [ - { name = "pydantic-ai-slim", extra = ["ag-ui", "anthropic", "bedrock", "cli", "cohere", "evals", "google", "groq", "huggingface", "mcp", "mistral", "openai", "vertexai"] }, + { name = "pydantic-ai-slim", extra = ["ag-ui", "anthropic", "bedrock", "cli", "cohere", "evals", "google", "groq", "huggingface", "mcp", "mistral", "openai", "temporal", "vertexai"] }, ] [package.optional-dependencies] @@ -3039,7 +3039,7 @@ requires-dist = [ { name = "fasta2a", marker = "extra == 'a2a'", specifier = ">=0.4.1" }, { name = "logfire", marker = "extra == 'logfire'", specifier = ">=3.11.0" }, { name = "pydantic-ai-examples", marker = "extra == 'examples'", editable = "examples" }, - { name = "pydantic-ai-slim", extras = ["ag-ui", "anthropic", "bedrock", "cli", "cohere", "evals", "google", "groq", "huggingface", "mcp", "mistral", "openai", "vertexai"], editable = "pydantic_ai_slim" }, + { name = "pydantic-ai-slim", extras = ["ag-ui", "anthropic", "bedrock", "cli", "cohere", "evals", "google", "groq", "huggingface", "mcp", "mistral", "openai", "temporal", "vertexai"], editable = "pydantic_ai_slim" }, ] provides-extras = ["a2a", "examples", "logfire"] @@ -3163,6 +3163,9 @@ openai = [ tavily = [ { name = "tavily-python" }, ] +temporal = [ + { name = "temporalio" }, +] vertexai = [ { name = "google-auth" }, { name = "requests" }, @@ -3219,9 +3222,10 @@ requires-dist = [ { name = "rich", marker = "extra == 'cli'", specifier = ">=13" }, { name = "starlette", marker = "extra == 'ag-ui'", specifier = ">=0.45.3" }, { name = "tavily-python", marker = "extra == 'tavily'", specifier = ">=0.5.0" }, + { name = "temporalio", marker = "extra == 'temporal'", specifier = ">=1.13.0" }, { name = "typing-inspection", specifier = ">=0.4.0" }, ] -provides-extras = ["a2a", "ag-ui", "anthropic", "bedrock", "cli", "cohere", "duckduckgo", "evals", "google", "groq", "huggingface", "logfire", "mcp", "mistral", "openai", "tavily", "vertexai"] +provides-extras = ["a2a", "ag-ui", "anthropic", "bedrock", "cli", "cohere", "duckduckgo", "evals", "google", "groq", "huggingface", "logfire", "mcp", "mistral", "openai", "tavily", "temporal", "vertexai"] [package.metadata.requires-dev] dev = [ @@ -3931,6 +3935,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a5/cd/71088461d7720128c78802289b3b36298f42745e5f8c334b0ffc157b881e/tavily_python-0.5.1-py3-none-any.whl", hash = "sha256:169601f703c55cf338758dcacfa7102473b479a9271d65a3af6fc3668990f757", size = 43767, upload-time = "2025-02-07T00:22:04.99Z" }, ] +[[package]] +name = "temporalio" +version = "1.13.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "protobuf" }, + { name = "python-dateutil", marker = "python_full_version < '3.11'" }, + { name = "types-protobuf" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3e/a3/a76477b523937f47a21941188c16b3c6b1eef6baadc7c8efeea497d909de/temporalio-1.13.0.tar.gz", hash = "sha256:5a979eee5433da6ab5d8a2bcde25a1e7d454e91920acb0bf7ca93d415750828b", size = 1558745, upload-time = "2025-06-20T19:57:26.944Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f5/f4/a5a74284c671bd50ce7353ad1dad7dab1a795f891458454049e95bc5378f/temporalio-1.13.0-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:7ee14cab581352e77171d1e4ce01a899231abfe75c5f7233e3e260f361a344cc", size = 12086961, upload-time = "2025-06-20T19:57:15.25Z" }, + { url = "https://files.pythonhosted.org/packages/1f/b7/5dc6e34f4e9a3da8b75cb3fe0d32edca1d9201d598c38d022501d38650a9/temporalio-1.13.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:575a0c57dbb089298b4775f3aca86ebaf8d58d5ba155e7fc5509877c25e6bb44", size = 11745239, upload-time = "2025-06-20T19:57:17.934Z" }, + { url = "https://files.pythonhosted.org/packages/04/30/4b9b15af87c181fd9364b61971faa0faa07d199320d7ff1712b5d51b5bbb/temporalio-1.13.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6bf099a27f22c0dbc22f3d86dba76d59be5da812ff044ba3fa183e3e14bd5e9a", size = 12119197, upload-time = "2025-06-20T19:57:20.509Z" }, + { url = "https://files.pythonhosted.org/packages/46/9f/a5b627d773974c654b6cd22ed3937e7e2471023af244ea417f0e917e617b/temporalio-1.13.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7e20c711f41c66877b9d54ab33c79a14ccaac9ed498a174274f6129110f4d84", size = 12413459, upload-time = "2025-06-20T19:57:22.816Z" }, + { url = "https://files.pythonhosted.org/packages/a3/73/efb6957212eb8c8dfff26c7c2c6ddf745aa5990a3b722cff17c8feaa66fc/temporalio-1.13.0-cp39-abi3-win_amd64.whl", hash = "sha256:9286cb84c1e078b2bcc6e8c6bd0be878d8ed395be991ac0d7cff555e3a82ac0b", size = 12440644, upload-time = "2025-06-20T19:57:25.175Z" }, +] + [[package]] name = "tenacity" version = "8.5.0" @@ -4121,6 +4144,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b5/63/2463d89481e811f007b0e1cd0a91e52e141b47f9de724d20db7b861dcfec/types_certifi-2021.10.8.3-py3-none-any.whl", hash = "sha256:b2d1e325e69f71f7c78e5943d410e650b4707bb0ef32e4ddf3da37f54176e88a", size = 2136, upload-time = "2022-06-09T15:19:03.127Z" }, ] +[[package]] +name = "types-protobuf" +version = "6.30.2.20250516" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ac/6c/5cf088aaa3927d1cc39910f60f220f5ff573ab1a6485b2836e8b26beb58c/types_protobuf-6.30.2.20250516.tar.gz", hash = "sha256:aecd1881770a9bb225ede66872ef7f0da4505edd0b193108edd9892e48d49a41", size = 62254, upload-time = "2025-05-16T03:06:50.794Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c0/66/06a9c161f5dd5deb4f5c016ba29106a8f1903eb9a1ba77d407dd6588fecb/types_protobuf-6.30.2.20250516-py3-none-any.whl", hash = "sha256:8c226d05b5e8b2623111765fa32d6e648bbc24832b4c2fddf0fa340ba5d5b722", size = 76480, upload-time = "2025-05-16T03:06:49.444Z" }, +] + [[package]] name = "types-requests" version = "2.31.0.6" From f65770936648fc365165e96411b3b51156e04431 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 22 Jul 2025 23:29:23 +0000 Subject: [PATCH 03/41] Add Agent event_stream_handler --- pydantic_ai_slim/pydantic_ai/agent.py | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index ad3012d60..a888204f2 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -5,12 +5,12 @@ import json import warnings from asyncio import Lock -from collections.abc import AsyncIterator, Awaitable, Iterator, Mapping, Sequence +from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterator, Mapping, Sequence from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager from contextvars import ContextVar from copy import deepcopy from types import FrameType -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, cast, final, overload +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, TypeAlias, cast, final, overload from opentelemetry.trace import NoOpTracer, use_span from pydantic.json_schema import GenerateJsonSchema @@ -96,6 +96,14 @@ RunOutputDataT = TypeVar('RunOutputDataT') """Type variable for the result data of a run where `output_type` was customized on the run call.""" +EventStreamHandler: TypeAlias = Callable[ + [ + RunContext[AgentDepsT], + AsyncIterable[_messages.AgentStreamEvent | _messages.HandleResponseEvent], + ], + Awaitable[None], +] + @final @dataclasses.dataclass(init=False) @@ -168,6 +176,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]): _prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False) _prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False) _max_result_retries: int = dataclasses.field(repr=False) + _event_stream_handler: EventStreamHandler[AgentDepsT] | None = dataclasses.field(repr=False) _enter_lock: Lock = dataclasses.field(repr=False) _entered_count: int = dataclasses.field(repr=False) @@ -197,6 +206,7 @@ def __init__( end_strategy: EndStrategy = 'early', instrument: InstrumentationSettings | bool | None = None, history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, ) -> None: ... @overload @@ -257,6 +267,7 @@ def __init__( end_strategy: EndStrategy = 'early', instrument: InstrumentationSettings | bool | None = None, history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, ) -> None: ... def __init__( @@ -283,6 +294,7 @@ def __init__( end_strategy: EndStrategy = 'early', instrument: InstrumentationSettings | bool | None = None, history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, **_deprecated_kwargs: Any, ): """Create an agent. @@ -331,6 +343,7 @@ def __init__( history_processors: Optional list of callables to process the message history before sending it to the model. Each processor takes a list of messages and returns a modified list of messages. Processors can be sync or async and are applied in sequence. + event_stream_handler: TODO: Optional handler for events from the agent stream. """ if model is None or defer_model_check: self.model = model @@ -425,6 +438,8 @@ def __init__( self.history_processors = history_processors or [] + self._event_stream_handler = event_stream_handler + self._override_deps: ContextVar[_utils.Option[AgentDepsT]] = ContextVar('_override_deps', default=None) self._override_model: ContextVar[_utils.Option[models.Model]] = ContextVar('_override_model', default=None) self._override_toolsets: ContextVar[_utils.Option[Sequence[AbstractToolset[AgentDepsT]]]] = ContextVar( @@ -559,8 +574,12 @@ async def main(): usage=usage, toolsets=toolsets, ) as agent_run: - async for _ in agent_run: - pass + async for node in agent_run: + if self._event_stream_handler is not None and ( + self.is_model_request_node(node) or self.is_call_tools_node(node) + ): + async with node.stream(agent_run.ctx) as stream: + await self._event_stream_handler(_agent_graph.build_run_context(agent_run.ctx), stream) assert agent_run.result is not None, 'The graph run did not finish properly' return agent_run.result From 2f0489481e492e90d5d097469b32b9314145333e Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 22 Jul 2025 23:33:00 +0000 Subject: [PATCH 04/41] Pass run_context to Model.request_stream for Temporal --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 27 ++++++++++--------- .../pydantic_ai/models/__init__.py | 7 ++--- .../pydantic_ai/models/anthropic.py | 2 ++ .../pydantic_ai/models/bedrock.py | 2 ++ .../pydantic_ai/models/fallback.py | 6 +++-- .../pydantic_ai/models/function.py | 7 ++--- pydantic_ai_slim/pydantic_ai/models/gemini.py | 5 ++-- pydantic_ai_slim/pydantic_ai/models/google.py | 2 ++ pydantic_ai_slim/pydantic_ai/models/groq.py | 7 ++--- .../pydantic_ai/models/huggingface.py | 9 ++++--- .../pydantic_ai/models/instrumented.py | 4 ++- .../pydantic_ai/models/mcp_sampling.py | 4 ++- .../pydantic_ai/models/mistral.py | 5 ++-- pydantic_ai_slim/pydantic_ai/models/openai.py | 10 ++++--- pydantic_ai_slim/pydantic_ai/models/test.py | 2 ++ .../pydantic_ai/models/wrapper.py | 6 ++++- tests/models/test_instrumented.py | 2 ++ 17 files changed, 68 insertions(+), 39 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 312a8a2fc..0dea48dab 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -325,13 +325,9 @@ async def _stream( ) -> AsyncIterator[models.StreamedResponse]: assert not self._did_stream, 'stream() should only be called once per node' - model_settings, model_request_parameters = await self._prepare_request(ctx) - model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters) - message_history = await _process_message_history( - ctx.state.message_history, ctx.deps.history_processors, build_run_context(ctx) - ) + model_settings, model_request_parameters, message_history, run_context = await self._prepare_request(ctx) async with ctx.deps.model.request_stream( - message_history, model_settings, model_request_parameters + message_history, model_settings, model_request_parameters, run_context ) as streamed_response: self._did_stream = True ctx.state.usage.requests += 1 @@ -351,11 +347,7 @@ async def _make_request( if self._result is not None: return self._result # pragma: no cover - model_settings, model_request_parameters = await self._prepare_request(ctx) - model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters) - message_history = await _process_message_history( - ctx.state.message_history, ctx.deps.history_processors, build_run_context(ctx) - ) + model_settings, model_request_parameters, message_history, _ = await self._prepare_request(ctx) model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters) ctx.state.usage.incr(_usage.Usage()) @@ -363,7 +355,7 @@ async def _make_request( async def _prepare_request( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] - ) -> tuple[ModelSettings | None, models.ModelRequestParameters]: + ) -> tuple[ModelSettings | None, models.ModelRequestParameters, list[_messages.ModelMessage], RunContext[DepsT]]: ctx.state.message_history.append(self.request) # Check usage @@ -373,9 +365,18 @@ async def _prepare_request( # Increment run_step ctx.state.run_step += 1 + run_context = build_run_context(ctx) + model_settings = merge_model_settings(ctx.deps.model_settings, None) + model_request_parameters = await _prepare_request_parameters(ctx) - return model_settings, model_request_parameters + model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters) + + message_history = await _process_message_history( + ctx.state.message_history, ctx.deps.history_processors, run_context + ) + + return model_settings, model_request_parameters, message_history, run_context def _finish_handling( self, diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 6193d1d41..f998d9007 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -13,19 +13,19 @@ from dataclasses import dataclass, field, replace from datetime import datetime from functools import cache, cached_property -from typing import Generic, TypeVar, overload +from typing import Any, Generic, TypeVar, overload import httpx from typing_extensions import Literal, TypeAliasType, TypedDict -from pydantic_ai.profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec - from .. import _utils from .._output import OutputObjectDefinition from .._parts_manager import ModelResponsePartsManager +from .._run_context import RunContext from ..exceptions import UserError from ..messages import FileUrl, ModelMessage, ModelRequest, ModelResponse, ModelResponseStreamEvent, VideoUrl from ..output import OutputMode +from ..profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec from ..profiles._json_schema import JsonSchemaTransformer from ..settings import ModelSettings from ..tools import ToolDefinition @@ -379,6 +379,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: """Make a request to the model and return a streaming response.""" # This method is not required, but you need to implement it if you want to support streamed responses diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index a62741568..3006cfde6 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -11,6 +11,7 @@ from typing_extensions import assert_never from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage +from .._run_context import RunContext from .._utils import guard_tool_call_id as _guard_tool_call_id from ..messages import ( BinaryContent, @@ -171,6 +172,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() response = await self._messages_create( diff --git a/pydantic_ai_slim/pydantic_ai/models/bedrock.py b/pydantic_ai_slim/pydantic_ai/models/bedrock.py index f16f9d111..3fde8d252 100644 --- a/pydantic_ai_slim/pydantic_ai/models/bedrock.py +++ b/pydantic_ai_slim/pydantic_ai/models/bedrock.py @@ -15,6 +15,7 @@ from typing_extensions import ParamSpec, assert_never from pydantic_ai import _utils, usage +from pydantic_ai._run_context import RunContext from pydantic_ai.messages import ( AudioUrl, BinaryContent, @@ -264,6 +265,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: settings = cast(BedrockModelSettings, model_settings or {}) response = await self._messages_create(messages, True, settings, model_request_parameters) diff --git a/pydantic_ai_slim/pydantic_ai/models/fallback.py b/pydantic_ai_slim/pydantic_ai/models/fallback.py index 4455defce..498d8e1bd 100644 --- a/pydantic_ai_slim/pydantic_ai/models/fallback.py +++ b/pydantic_ai_slim/pydantic_ai/models/fallback.py @@ -3,10 +3,11 @@ from collections.abc import AsyncIterator from contextlib import AsyncExitStack, asynccontextmanager, suppress from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING, Any, Callable from opentelemetry.trace import get_current_span +from pydantic_ai._run_context import RunContext from pydantic_ai.models.instrumented import InstrumentedModel from ..exceptions import FallbackExceptionGroup, ModelHTTPError @@ -83,6 +84,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: """Try each model in sequence until one succeeds.""" exceptions: list[Exception] = [] @@ -92,7 +94,7 @@ async def request_stream( async with AsyncExitStack() as stack: try: response = await stack.enter_async_context( - model.request_stream(messages, model_settings, customized_model_request_parameters) + model.request_stream(messages, model_settings, customized_model_request_parameters, run_context) ) except Exception as exc: if self._fallback_on(exc): diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index e8476b554..0fcf9e819 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -7,13 +7,12 @@ from dataclasses import dataclass, field from datetime import datetime from itertools import chain -from typing import Callable, Union +from typing import Any, Callable, Union from typing_extensions import TypeAlias, assert_never, overload -from pydantic_ai.profiles import ModelProfileSpec - from .. import _utils, usage +from .._run_context import RunContext from .._utils import PeekableAsyncStream from ..messages import ( AudioUrl, @@ -32,6 +31,7 @@ UserContent, UserPromptPart, ) +from ..profiles import ModelProfileSpec from ..settings import ModelSettings from ..tools import ToolDefinition from . import Model, ModelRequestParameters, StreamedResponse @@ -147,6 +147,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: agent_info = AgentInfo( model_request_parameters.function_tools, diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 7c371a943..feb6cf10e 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -13,10 +13,9 @@ from httpx import USE_CLIENT_DEFAULT, Response as HTTPResponse from typing_extensions import NotRequired, TypedDict, assert_never -from pydantic_ai.providers import Provider, infer_provider - from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage from .._output import OutputObjectDefinition +from .._run_context import RunContext from ..exceptions import UserError from ..messages import ( BinaryContent, @@ -36,6 +35,7 @@ VideoUrl, ) from ..profiles import ModelProfileSpec +from ..providers import Provider, infer_provider from ..settings import ModelSettings from ..tools import ToolDefinition from . import ( @@ -164,6 +164,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() async with self._make_request( diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index 9ec1260d4..7eac84113 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -12,6 +12,7 @@ from .. import UnexpectedModelBehavior, _utils, usage from .._output import OutputObjectDefinition +from .._run_context import RunContext from ..exceptions import UserError from ..messages import ( BinaryContent, @@ -187,6 +188,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() model_settings = cast(GoogleModelSettings, model_settings or {}) diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index 92376b44d..23a1fe18c 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -5,13 +5,13 @@ from contextlib import asynccontextmanager from dataclasses import dataclass, field from datetime import datetime -from typing import Literal, Union, cast, overload +from typing import Any, Literal, Union, cast, overload from typing_extensions import assert_never -from pydantic_ai._thinking_part import split_content_into_text_and_thinking - from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage +from .._run_context import RunContext +from .._thinking_part import split_content_into_text_and_thinking from .._utils import guard_tool_call_id as _guard_tool_call_id, number_to_datetime from ..messages import ( BinaryContent, @@ -166,6 +166,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() response = await self._completions_create( diff --git a/pydantic_ai_slim/pydantic_ai/models/huggingface.py b/pydantic_ai_slim/pydantic_ai/models/huggingface.py index 41d53ca62..fa6b2b57d 100644 --- a/pydantic_ai_slim/pydantic_ai/models/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/models/huggingface.py @@ -5,14 +5,13 @@ from contextlib import asynccontextmanager from dataclasses import dataclass, field from datetime import datetime, timezone -from typing import Literal, Union, cast, overload +from typing import Any, Literal, Union, cast, overload from typing_extensions import assert_never -from pydantic_ai._thinking_part import split_content_into_text_and_thinking -from pydantic_ai.providers import Provider, infer_provider - from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage +from .._run_context import RunContext +from .._thinking_part import split_content_into_text_and_thinking from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc from ..messages import ( AudioUrl, @@ -33,6 +32,7 @@ UserPromptPart, VideoUrl, ) +from ..providers import Provider, infer_provider from ..settings import ModelSettings from ..tools import ToolDefinition from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests @@ -146,6 +146,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() response = await self._completions_create( diff --git a/pydantic_ai_slim/pydantic_ai/models/instrumented.py b/pydantic_ai_slim/pydantic_ai/models/instrumented.py index 233020f6f..49a97bd49 100644 --- a/pydantic_ai_slim/pydantic_ai/models/instrumented.py +++ b/pydantic_ai_slim/pydantic_ai/models/instrumented.py @@ -18,6 +18,7 @@ from opentelemetry.util.types import AttributeValue from pydantic import TypeAdapter +from .._run_context import RunContext from ..messages import ( ModelMessage, ModelRequest, @@ -222,12 +223,13 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: with self._instrument(messages, model_settings, model_request_parameters) as finish: response_stream: StreamedResponse | None = None try: async with super().request_stream( - messages, model_settings, model_request_parameters + messages, model_settings, model_request_parameters, run_context ) as response_stream: yield response_stream finally: diff --git a/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py b/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py index ebfaac92d..a4f649786 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py +++ b/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py @@ -3,9 +3,10 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager from dataclasses import dataclass -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Any, cast from .. import _mcp, exceptions, usage +from .._run_context import RunContext from ..messages import ModelMessage, ModelResponse from ..settings import ModelSettings from . import Model, ModelRequestParameters, StreamedResponse @@ -76,6 +77,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: raise NotImplementedError('MCP Sampling does not support streaming') yield diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index 46b627826..b475e3363 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -11,9 +11,9 @@ from httpx import Timeout from typing_extensions import assert_never -from pydantic_ai._thinking_part import split_content_into_text_and_thinking - from .. import ModelHTTPError, UnexpectedModelBehavior, _utils +from .._run_context import RunContext +from .._thinking_part import split_content_into_text_and_thinking from .._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc, number_to_datetime from ..messages import ( BinaryContent, @@ -172,6 +172,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: """Make a streaming request to the model from Pydantic AI call.""" check_allow_model_requests() diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 61d280fb2..0df5055e2 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -11,12 +11,10 @@ from pydantic import ValidationError from typing_extensions import assert_never -from pydantic_ai._thinking_part import split_content_into_text_and_thinking -from pydantic_ai.profiles.openai import OpenAIModelProfile -from pydantic_ai.providers import Provider, infer_provider - from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage from .._output import DEFAULT_OUTPUT_TOOL_NAME, OutputObjectDefinition +from .._run_context import RunContext +from .._thinking_part import split_content_into_text_and_thinking from .._utils import guard_tool_call_id as _guard_tool_call_id, number_to_datetime from ..messages import ( AudioUrl, @@ -38,6 +36,8 @@ VideoUrl, ) from ..profiles import ModelProfileSpec +from ..profiles.openai import OpenAIModelProfile +from ..providers import Provider, infer_provider from ..settings import ModelSettings from ..tools import ToolDefinition from . import ( @@ -244,6 +244,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() response = await self._completions_create( @@ -659,6 +660,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() response = await self._responses_create( diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py index a80d551ff..d1b015078 100644 --- a/pydantic_ai_slim/pydantic_ai/models/test.py +++ b/pydantic_ai_slim/pydantic_ai/models/test.py @@ -12,6 +12,7 @@ from typing_extensions import assert_never from .. import _utils +from .._run_context import RunContext from ..messages import ( ModelMessage, ModelRequest, @@ -118,6 +119,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: self.last_model_request_parameters = model_request_parameters diff --git a/pydantic_ai_slim/pydantic_ai/models/wrapper.py b/pydantic_ai_slim/pydantic_ai/models/wrapper.py index cc91f9c72..9818ad603 100644 --- a/pydantic_ai_slim/pydantic_ai/models/wrapper.py +++ b/pydantic_ai_slim/pydantic_ai/models/wrapper.py @@ -6,6 +6,7 @@ from functools import cached_property from typing import Any +from .._run_context import RunContext from ..messages import ModelMessage, ModelResponse from ..profiles import ModelProfile from ..settings import ModelSettings @@ -35,8 +36,11 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: - async with self.wrapped.request_stream(messages, model_settings, model_request_parameters) as response_stream: + async with self.wrapped.request_stream( + messages, model_settings, model_request_parameters, run_context + ) as response_stream: yield response_stream def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters: diff --git a/tests/models/test_instrumented.py b/tests/models/test_instrumented.py index b952bf716..831f6f339 100644 --- a/tests/models/test_instrumented.py +++ b/tests/models/test_instrumented.py @@ -11,6 +11,7 @@ from opentelemetry._events import NoOpEventLoggerProvider from opentelemetry.trace import NoOpTracerProvider +from pydantic_ai._run_context import RunContext from pydantic_ai.messages import ( AudioUrl, BinaryContent, @@ -89,6 +90,7 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext | None = None, ) -> AsyncIterator[StreamedResponse]: yield MyResponseStream() From 5f6cfa77c8652a4046645e04023f39a3c5a87580 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 22 Jul 2025 23:36:54 +0000 Subject: [PATCH 05/41] Streaming with Temporal --- .../pydantic_ai/temporal/__init__.py | 16 +- .../pydantic_ai/temporal/agent.py | 4 +- .../pydantic_ai/temporal/function_toolset.py | 35 +--- .../pydantic_ai/temporal/mcp_server.py | 48 +---- .../pydantic_ai/temporal/model.py | 171 +++++++++++++----- temporal.py | 46 +++-- 6 files changed, 169 insertions(+), 151 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/temporal/__init__.py b/pydantic_ai_slim/pydantic_ai/temporal/__init__.py index 21f0ca90d..d6a060860 100644 --- a/pydantic_ai_slim/pydantic_ai/temporal/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/temporal/__init__.py @@ -11,10 +11,8 @@ class _TemporalRunContext(RunContext[AgentDepsT]): - _data: dict[str, Any] - def __init__(self, **kwargs: Any): - self._data = kwargs + self.__dict__ = kwargs setattr( self, '__dataclass_fields__', @@ -25,10 +23,12 @@ def __getattribute__(self, name: str) -> Any: try: return super().__getattribute__(name) except AttributeError as e: - data = super().__getattribute__('_data') - if name in data: - return data[name] - raise e # TODO: Explain how to make a new run context attribute available + if name in RunContext.__dataclass_fields__: + raise AttributeError( + f'Temporalized {RunContext.__name__!r} object has no attribute {name!r}. To make the attribute available, pass a `TemporalSettings` object to `temporalize_agent` that has a custom `serialize_run_context` function that returns a dictionary that includes the attribute.' + ) + else: + raise e @classmethod def serialize_run_context(cls, ctx: RunContext[AgentDepsT]) -> dict[str, Any]: @@ -75,7 +75,7 @@ def for_tool(self, toolset_id: str, tool_id: str) -> TemporalSettings: deserialize_run_context: Callable[[dict[str, Any]], RunContext] = _TemporalRunContext.deserialize_run_context @property - def execute_activity_kwargs(self) -> dict[str, Any]: + def execute_activity_options(self) -> dict[str, Any]: return { 'task_queue': self.task_queue, 'schedule_to_close_timeout': self.schedule_to_close_timeout, diff --git a/pydantic_ai_slim/pydantic_ai/temporal/agent.py b/pydantic_ai_slim/pydantic_ai/temporal/agent.py index b8a0fc538..5f59d51fc 100644 --- a/pydantic_ai_slim/pydantic_ai/temporal/agent.py +++ b/pydantic_ai_slim/pydantic_ai/temporal/agent.py @@ -30,7 +30,7 @@ def temporalize_toolset(toolset: AbstractToolset, settings: TemporalSettings | N def temporalize_agent( - agent: Agent, + agent: Agent[Any, Any], settings: TemporalSettings | None = None, temporalize_toolset_func: Callable[ [AbstractToolset, TemporalSettings | None], list[Callable[..., Any]] @@ -52,7 +52,7 @@ def temporalize_agent( activities: list[Callable[..., Any]] = [] if isinstance(agent.model, Model): - activities.extend(temporalize_model(agent.model, settings)) + activities.extend(temporalize_model(agent.model, settings, agent._event_stream_handler)) # pyright: ignore[reportPrivateUsage] def temporalize_toolset(toolset: AbstractToolset) -> None: activities.extend(temporalize_toolset_func(toolset, settings)) diff --git a/pydantic_ai_slim/pydantic_ai/temporal/function_toolset.py b/pydantic_ai_slim/pydantic_ai/temporal/function_toolset.py index 842e86aaf..8b34a8fb5 100644 --- a/pydantic_ai_slim/pydantic_ai/temporal/function_toolset.py +++ b/pydantic_ai_slim/pydantic_ai/temporal/function_toolset.py @@ -57,7 +57,7 @@ async def call_tool(name: str, tool_args: dict[str, Any], ctx: RunContext, tool: return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] activity=call_tool_activity, arg=_CallToolParams(name=name, tool_args=tool_args, serialized_run_context=serialized_run_context), - **tool_settings.execute_activity_kwargs, + **tool_settings.execute_activity_options, ) toolset.call_tool = call_tool @@ -65,36 +65,3 @@ async def call_tool(name: str, tool_args: dict[str, Any], ctx: RunContext, tool: activities = [call_tool_activity] setattr(toolset, '__temporal_activities', activities) return activities - - -# class TemporalFunctionToolset(FunctionToolset[AgentDepsT]): -# def __init__( -# self, -# tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = [], -# max_retries: int = 1, -# temporal_settings: TemporalSettings | None = None, -# serialize_run_context: Callable[[RunContext[AgentDepsT]], Any] | None = None, -# deserialize_run_context: Callable[[Any], RunContext[AgentDepsT]] | None = None, -# ): -# super().__init__(tools, max_retries) -# self.temporal_settings = temporal_settings or TemporalSettings() -# self.serialize_run_context = serialize_run_context or TemporalRunContext[AgentDepsT].serialize_run_context -# self.deserialize_run_context = deserialize_run_context or TemporalRunContext[AgentDepsT].deserialize_run_context - -# @activity.defn(name='function_toolset_call_tool') -# async def call_tool_activity(params: FunctionCallToolParams) -> Any: -# ctx = self.deserialize_run_context(params.serialized_run_context) -# tool = (await self.get_tools(ctx))[params.name] -# return await FunctionToolset[AgentDepsT].call_tool(self, params.name, params.tool_args, ctx, tool) - -# self.call_tool_activity = call_tool_activity - -# async def call_tool( -# self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT] -# ) -> Any: -# serialized_run_context = self.serialize_run_context(ctx) -# return await workflow.execute_activity( -# activity=self.call_tool_activity, -# arg=FunctionCallToolParams(name=name, tool_args=tool_args, serialized_run_context=serialized_run_context), -# **self.temporal_settings.__dict__, -# ) diff --git a/pydantic_ai_slim/pydantic_ai/temporal/mcp_server.py b/pydantic_ai_slim/pydantic_ai/temporal/mcp_server.py index 8d3be7624..6a7248468 100644 --- a/pydantic_ai_slim/pydantic_ai/temporal/mcp_server.py +++ b/pydantic_ai_slim/pydantic_ai/temporal/mcp_server.py @@ -55,7 +55,7 @@ async def call_tool_activity(params: _CallToolParams) -> ToolResult: async def list_tools() -> list[mcp_types.Tool]: return await workflow.execute_activity( # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType] activity=list_tools_activity, - **settings.execute_activity_kwargs, + **settings.execute_activity_options, ) async def direct_call_tool( @@ -66,7 +66,7 @@ async def direct_call_tool( return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] activity=call_tool_activity, arg=_CallToolParams(name=name, tool_args=args, metadata=metadata), - **settings.for_tool(id, name).execute_activity_kwargs, + **settings.for_tool(id, name).execute_activity_options, ) server.list_tools = list_tools @@ -75,47 +75,3 @@ async def direct_call_tool( activities = [list_tools_activity, call_tool_activity] setattr(server, '__temporal_activities', activities) return activities - - -# class TemporalMCPServer(WrapperToolset[Any]): -# temporal_settings: TemporalSettings - -# @property -# def wrapped_server(self) -> MCPServer: -# assert isinstance(self.wrapped, MCPServer) -# return self.wrapped - -# def __init__(self, wrapped: MCPServer, temporal_settings: TemporalSettings | None = None): -# assert isinstance(self.wrapped, MCPServer) -# super().__init__(wrapped) -# self.temporal_settings = temporal_settings or TemporalSettings() - -# @activity.defn(name='mcp_server_list_tools') -# async def list_tools_activity() -> list[mcp_types.Tool]: -# return await self.wrapped_server.list_tools() - -# self.list_tools_activity = list_tools_activity - -# @activity.defn(name='mcp_server_call_tool') -# async def call_tool_activity(params: MCPCallToolParams) -> ToolResult: -# return await self.wrapped_server.direct_call_tool(params.name, params.tool_args, params.metadata) - -# self.call_tool_activity = call_tool_activity - -# async def list_tools(self) -> list[mcp_types.Tool]: -# return await workflow.execute_activity( -# activity=self.list_tools_activity, -# **self.temporal_settings.__dict__, -# ) - -# async def direct_call_tool( -# self, -# name: str, -# args: dict[str, Any], -# metadata: dict[str, Any] | None = None, -# ) -> ToolResult: -# return await workflow.execute_activity( -# activity=self.call_tool_activity, -# arg=MCPCallToolParams(name=name, tool_args=args, metadata=metadata), -# **self.temporal_settings.__dict__, -# ) diff --git a/pydantic_ai_slim/pydantic_ai/temporal/model.py b/pydantic_ai_slim/pydantic_ai/temporal/model.py index 7f8a25e71..32b4967c0 100644 --- a/pydantic_ai_slim/pydantic_ai/temporal/model.py +++ b/pydantic_ai_slim/pydantic_ai/temporal/model.py @@ -3,17 +3,27 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager from dataclasses import dataclass +from datetime import datetime from typing import Any, Callable from pydantic import ConfigDict, with_config from temporalio import activity, workflow +from .._run_context import RunContext +from ..agent import EventStreamHandler +from ..exceptions import UserError from ..messages import ( + FinalResultEvent, ModelMessage, ModelResponse, + ModelResponseStreamEvent, + PartStartEvent, + TextPart, + ToolCallPart, ) from ..models import Model, ModelRequestParameters, StreamedResponse from ..settings import ModelSettings +from ..usage import Usage from . import TemporalSettings @@ -23,14 +33,48 @@ class _RequestParams: messages: list[ModelMessage] model_settings: ModelSettings | None model_request_parameters: ModelRequestParameters + serialized_run_context: Any -def temporalize_model(model: Model, settings: TemporalSettings | None = None) -> list[Callable[..., Any]]: +class _TemporalStreamedResponse(StreamedResponse): + def __init__(self, response: ModelResponse): + self.response = response + + async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: + return + # noinspection PyUnreachableCode + yield + + def get(self) -> ModelResponse: + """Build a [`ModelResponse`][pydantic_ai.messages.ModelResponse] from the data received from the stream so far.""" + return self.response + + def usage(self) -> Usage: + """Get the usage of the response so far. This will not be the final usage until the stream is exhausted.""" + return self.response.usage + + @property + def model_name(self) -> str: + """Get the model name of the response.""" + return self.response.model_name or '' + + @property + def timestamp(self) -> datetime: + """Get the timestamp of the response.""" + return self.response.timestamp + + +def temporalize_model( # noqa: C901 + model: Model, + settings: TemporalSettings | None = None, + event_stream_handler: EventStreamHandler | None = None, +) -> list[Callable[..., Any]]: """Temporalize a model. Args: model: The model to temporalize. settings: The temporal settings to use. + event_stream_handler: The event stream handler to use. """ if activities := getattr(model, '__temporal_activities', None): return activities @@ -38,11 +82,64 @@ def temporalize_model(model: Model, settings: TemporalSettings | None = None) -> settings = settings or TemporalSettings() original_request = model.request + original_request_stream = model.request_stream @activity.defn(name='model_request') async def request_activity(params: _RequestParams) -> ModelResponse: return await original_request(params.messages, params.model_settings, params.model_request_parameters) + @activity.defn(name='model_request_stream') + async def request_stream_activity(params: _RequestParams) -> ModelResponse: + run_context = settings.deserialize_run_context(params.serialized_run_context) + async with original_request_stream( + params.messages, params.model_settings, params.model_request_parameters, run_context + ) as streamed_response: + tool_defs = { + tool_def.name: tool_def + for tool_def in [ + *params.model_request_parameters.output_tools, + *params.model_request_parameters.function_tools, + ] + } + + async def aiter(): + def _get_final_result_event(e: ModelResponseStreamEvent) -> FinalResultEvent | None: + """Return an appropriate FinalResultEvent if `e` corresponds to a part that will produce a final result.""" + if isinstance(e, PartStartEvent): + new_part = e.part + if ( + isinstance(new_part, TextPart) and params.model_request_parameters.allow_text_output + ): # pragma: no branch + return FinalResultEvent(tool_name=None, tool_call_id=None) + elif isinstance(new_part, ToolCallPart) and (tool_def := tool_defs.get(new_part.tool_name)): + if tool_def.kind == 'output': + return FinalResultEvent( + tool_name=new_part.tool_name, tool_call_id=new_part.tool_call_id + ) + elif tool_def.kind == 'deferred': + return FinalResultEvent(tool_name=None, tool_call_id=None) + + # TODO: usage_checking_stream = _get_usage_checking_stream_response( + # self._raw_stream_response, self._usage_limits, self.usage + # ) + async for event in streamed_response: + yield event + if (final_result_event := _get_final_result_event(event)) is not None: + yield final_result_event + break + + # If we broke out of the above loop, we need to yield the rest of the events + # If we didn't, this will just be a no-op + async for event in streamed_response: + yield event + + assert event_stream_handler is not None + await event_stream_handler(run_context, aiter()) + + async for _ in streamed_response: + pass + return streamed_response.get() + async def request( messages: list[ModelMessage], model_settings: ModelSettings | None, @@ -51,9 +148,12 @@ async def request( return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] activity=request_activity, arg=_RequestParams( - messages=messages, model_settings=model_settings, model_request_parameters=model_request_parameters + messages=messages, + model_settings=model_settings, + model_request_parameters=model_request_parameters, + serialized_run_context=None, ), - **settings.execute_activity_kwargs, + **settings.execute_activity_options, ) @asynccontextmanager @@ -61,56 +161,29 @@ async def request_stream( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: - raise NotImplementedError('Cannot stream with temporal yet') - yield + if event_stream_handler is None: + raise UserError('Streaming with Temporal requires `Agent` to have an `event_stream_handler`') + if run_context is None: + raise UserError('Streaming with Temporal requires `request_stream` to be called with a `run_context`') + + serialized_run_context = settings.serialize_run_context(run_context) + response = await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] + activity=request_stream_activity, + arg=_RequestParams( + messages=messages, + model_settings=model_settings, + model_request_parameters=model_request_parameters, + serialized_run_context=serialized_run_context, + ), + **settings.execute_activity_options, + ) + yield _TemporalStreamedResponse(response) model.request = request model.request_stream = request_stream - activities = [request_activity] + activities = [request_activity, request_stream_activity] setattr(model, '__temporal_activities', activities) return activities - - -# @dataclass -# class TemporalModel(WrapperModel): -# temporal_settings: TemporalSettings - -# def __init__( -# self, -# wrapped: Model | KnownModelName, -# temporal_settings: TemporalSettings | None = None, -# ) -> None: -# super().__init__(wrapped) -# self.temporal_settings = temporal_settings or TemporalSettings() - -# @activity.defn -# async def request_activity(params: ModelRequestParams) -> ModelResponse: -# return await self.wrapped.request(params.messages, params.model_settings, params.model_request_parameters) - -# self.request_activity = request_activity - -# async def request( -# self, -# messages: list[ModelMessage], -# model_settings: ModelSettings | None, -# model_request_parameters: ModelRequestParameters, -# ) -> ModelResponse: -# return await workflow.execute_activity( -# activity=self.request_activity, -# arg=ModelRequestParams( -# messages=messages, model_settings=model_settings, model_request_parameters=model_request_parameters -# ), -# **self.temporal_settings.__dict__, -# ) - -# @asynccontextmanager -# async def request_stream( -# self, -# messages: list[ModelMessage], -# model_settings: ModelSettings | None, -# model_request_parameters: ModelRequestParameters, -# ) -> AsyncIterator[StreamedResponse]: -# raise NotImplementedError('Cannot stream with temporal yet') -# yield diff --git a/temporal.py b/temporal.py index 340019b8c..057b427be 100644 --- a/temporal.py +++ b/temporal.py @@ -6,6 +6,7 @@ # /// import asyncio import random +from collections.abc import AsyncIterable from datetime import timedelta from temporalio import workflow @@ -14,11 +15,13 @@ from temporalio.contrib.pydantic import pydantic_data_converter from temporalio.runtime import OpenTelemetryConfig, Runtime, TelemetryConfig from temporalio.worker import Worker +from typing_extensions import TypedDict with workflow.unsafe.imports_passed_through(): from pydantic_ai import Agent + from pydantic_ai._run_context import RunContext from pydantic_ai.mcp import MCPServerStdio - from pydantic_ai.models.openai import OpenAIModel + from pydantic_ai.messages import AgentStreamEvent, HandleResponseEvent from pydantic_ai.temporal import ( TemporalSettings, initialize_temporal, @@ -28,10 +31,13 @@ initialize_temporal() - def get_uv_index(location: str) -> int: - return 3 + class Deps(TypedDict): + country: str - toolset = FunctionToolset(tools=[get_uv_index], id='uv_index') + def get_country(ctx: RunContext[Deps]) -> str: + return ctx.deps['country'] + + toolset = FunctionToolset[Deps](tools=[get_country], id='country') mcp_server = MCPServerStdio( 'python', ['-m', 'tests.mcp_server'], @@ -39,14 +45,26 @@ def get_uv_index(location: str) -> int: id='test', ) - model = OpenAIModel('gpt-4o') - my_agent = Agent(model=model, instructions='be helpful', toolsets=[toolset, mcp_server]) + async def event_stream_handler( + ctx: RunContext[Deps], + stream: AsyncIterable[AgentStreamEvent | HandleResponseEvent], + ): + print(f'{ctx.run_step=}') + async for event in stream: + print(event) + + my_agent = Agent( + 'openai:gpt-4o', + toolsets=[toolset, mcp_server], + event_stream_handler=event_stream_handler, + deps_type=Deps, + ) temporal_settings = TemporalSettings( start_to_close_timeout=timedelta(seconds=60), - tool_settings={ - 'uv_index': { - 'get_uv_index': TemporalSettings(start_to_close_timeout=timedelta(seconds=110)), + tool_settings={ # TODO: Allow default temporal settings to be set for an entire toolset + 'country': { + 'get_country': TemporalSettings(start_to_close_timeout=timedelta(seconds=110)), }, }, ) @@ -68,8 +86,9 @@ def init_runtime_with_telemetry() -> Runtime: @workflow.defn class MyAgentWorkflow: @workflow.run - async def run(self, prompt: str) -> str: - return (await my_agent.run(prompt)).output + async def run(self, prompt: str, deps: Deps) -> str: + result = await my_agent.run(prompt, deps=deps) + return result.output async def main(): @@ -88,7 +107,10 @@ async def main(): ): output = await client.execute_workflow( # pyright: ignore[reportUnknownMemberType] MyAgentWorkflow.run, - 'what is 2 plus the UV Index in Mexico City? and what is the product name?', + args=[ + 'what is the capital of the capital of the country? and what is the product name?', + Deps(country='Mexico'), + ], id=f'my-agent-workflow-id-{random.random()}', task_queue='my-agent-task-queue', ) From a1e96e6ff69c90df1471df68364e7d8bac30f7f4 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 24 Jul 2025 14:01:58 +0000 Subject: [PATCH 06/41] Fix google types issues by importing only google.genai.Client --- docs/models/google.md | 4 ++-- pydantic_ai_slim/pydantic_ai/models/google.py | 8 ++++---- pydantic_ai_slim/pydantic_ai/providers/google.py | 14 +++++++------- tests/conftest.py | 4 ++-- 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/docs/models/google.md b/docs/models/google.md index 2cc35d9c0..30816715a 100644 --- a/docs/models/google.md +++ b/docs/models/google.md @@ -104,14 +104,14 @@ You can supply a custom `GoogleProvider` instance using the `provider` argument This is useful if you're using a custom-compatible endpoint with the Google Generative Language API. ```python -from google import genai +from google.genai import Client from google.genai.types import HttpOptions from pydantic_ai import Agent from pydantic_ai.models.google import GoogleModel from pydantic_ai.providers.google import GoogleProvider -client = genai.Client( +client = Client( api_key='gemini-custom-api-key', http_options=HttpOptions(base_url='gemini-custom-base-url'), ) diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index 7eac84113..80ec6af85 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -45,7 +45,7 @@ ) try: - from google import genai + from google.genai import Client from google.genai.types import ( ContentDict, ContentUnionDict, @@ -131,10 +131,10 @@ class GoogleModel(Model): Apart from `__init__`, all methods are private or match those of the base class. """ - client: genai.Client = field(repr=False) + client: Client = field(repr=False) _model_name: GoogleModelName = field(repr=False) - _provider: Provider[genai.Client] = field(repr=False) + _provider: Provider[Client] = field(repr=False) _url: str | None = field(repr=False) _system: str = field(default='google', repr=False) @@ -142,7 +142,7 @@ def __init__( self, model_name: GoogleModelName, *, - provider: Literal['google-gla', 'google-vertex'] | Provider[genai.Client] = 'google-gla', + provider: Literal['google-gla', 'google-vertex'] | Provider[Client] = 'google-gla', profile: ModelProfileSpec | None = None, settings: ModelSettings | None = None, ): diff --git a/pydantic_ai_slim/pydantic_ai/providers/google.py b/pydantic_ai_slim/pydantic_ai/providers/google.py index fc876fcff..70eaa6d86 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/google.py +++ b/pydantic_ai_slim/pydantic_ai/providers/google.py @@ -10,8 +10,8 @@ from pydantic_ai.providers import Provider try: - from google import genai from google.auth.credentials import Credentials + from google.genai import Client except ImportError as _import_error: raise ImportError( 'Please install the `google-genai` package to use the Google provider, ' @@ -19,7 +19,7 @@ ) from _import_error -class GoogleProvider(Provider[genai.Client]): +class GoogleProvider(Provider[Client]): """Provider for Google.""" @property @@ -31,7 +31,7 @@ def base_url(self) -> str: return str(self._client._api_client._http_options.base_url) # type: ignore[reportPrivateUsage] @property - def client(self) -> genai.Client: + def client(self) -> Client: return self._client def model_profile(self, model_name: str) -> ModelProfile | None: @@ -50,7 +50,7 @@ def __init__( ) -> None: ... @overload - def __init__(self, *, client: genai.Client) -> None: ... + def __init__(self, *, client: Client) -> None: ... @overload def __init__(self, *, vertexai: bool = False) -> None: ... @@ -62,7 +62,7 @@ def __init__( credentials: Credentials | None = None, project: str | None = None, location: VertexAILocation | Literal['global'] | None = None, - client: genai.Client | None = None, + client: Client | None = None, vertexai: bool | None = None, ) -> None: """Create a new Google provider. @@ -95,13 +95,13 @@ def __init__( 'Set the `GOOGLE_API_KEY` environment variable or pass it via `GoogleProvider(api_key=...)`' 'to use the Google Generative Language API.' ) - self._client = genai.Client( + self._client = Client( vertexai=vertexai, api_key=api_key, http_options={'headers': {'User-Agent': get_user_agent()}}, ) else: - self._client = genai.Client( + self._client = Client( vertexai=vertexai, project=project or os.environ.get('GOOGLE_CLOUD_PROJECT'), # From https://github.com/pydantic/pydantic-ai/pull/2031/files#r2169682149: diff --git a/tests/conftest.py b/tests/conftest.py index 3ae576c63..16c39380b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -383,7 +383,7 @@ async def vertex_provider(): pytest.skip('Requires properly configured local google vertex config to pass') try: - from google import genai + from google.genai import Client from pydantic_ai.providers.google import GoogleProvider except ImportError: # pragma: lax no cover @@ -391,7 +391,7 @@ async def vertex_provider(): project = os.getenv('GOOGLE_PROJECT', 'pydantic-ai') location = os.getenv('GOOGLE_LOCATION', 'us-central1') - client = genai.Client(vertexai=True, project=project, location=location) + client = Client(vertexai=True, project=project, location=location) try: yield GoogleProvider(client=client) From 966e7f8e148d72b633ad09ec3decfc5e95a56400 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 24 Jul 2025 14:03:20 +0000 Subject: [PATCH 07/41] Import TypeAlias from typing_extensions for Python 3.9 --- pydantic_ai_slim/pydantic_ai/agent.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index a888204f2..be9adef30 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -10,11 +10,11 @@ from contextvars import ContextVar from copy import deepcopy from types import FrameType -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, TypeAlias, cast, final, overload +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, cast, final, overload from opentelemetry.trace import NoOpTracer, use_span from pydantic.json_schema import GenerateJsonSchema -from typing_extensions import Literal, Never, Self, TypeIs, TypeVar, deprecated +from typing_extensions import Literal, Never, Self, TypeAlias, TypeIs, TypeVar, deprecated from pydantic_graph import End, Graph, GraphRun, GraphRunContext from pydantic_graph._utils import get_event_loop From 090ec23d0e1abe89dc56b772fbd36a4b876b865c Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 24 Jul 2025 23:51:09 +0000 Subject: [PATCH 08/41] Start cleaning up temporal integration --- .../pydantic_ai/temporal/__init__.py | 155 +++++++----------- ...nction_toolset.py => _function_toolset.py} | 2 +- .../{mcp_server.py => _mcp_server.py} | 2 +- .../temporal/{model.py => _model.py} | 2 +- .../pydantic_ai/temporal/_run_context.py | 41 +++++ .../pydantic_ai/temporal/_settings.py | 57 +++++++ .../pydantic_ai/temporal/_toolset.py | 26 +++ .../pydantic_ai/temporal/agent.py | 66 -------- temporal.py | 136 ++++++++------- 9 files changed, 262 insertions(+), 225 deletions(-) rename pydantic_ai_slim/pydantic_ai/temporal/{function_toolset.py => _function_toolset.py} (98%) rename pydantic_ai_slim/pydantic_ai/temporal/{mcp_server.py => _mcp_server.py} (98%) rename pydantic_ai_slim/pydantic_ai/temporal/{model.py => _model.py} (99%) create mode 100644 pydantic_ai_slim/pydantic_ai/temporal/_run_context.py create mode 100644 pydantic_ai_slim/pydantic_ai/temporal/_settings.py create mode 100644 pydantic_ai_slim/pydantic_ai/temporal/_toolset.py delete mode 100644 pydantic_ai_slim/pydantic_ai/temporal/agent.py diff --git a/pydantic_ai_slim/pydantic_ai/temporal/__init__.py b/pydantic_ai_slim/pydantic_ai/temporal/__init__.py index d6a060860..f01028c17 100644 --- a/pydantic_ai_slim/pydantic_ai/temporal/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/temporal/__init__.py @@ -1,104 +1,65 @@ from __future__ import annotations -from dataclasses import dataclass -from datetime import timedelta +import contextlib from typing import Any, Callable -from temporalio.common import Priority, RetryPolicy -from temporalio.workflow import ActivityCancellationType, VersioningIntent - -from pydantic_ai._run_context import AgentDepsT, RunContext - - -class _TemporalRunContext(RunContext[AgentDepsT]): - def __init__(self, **kwargs: Any): - self.__dict__ = kwargs - setattr( - self, - '__dataclass_fields__', - {name: field for name, field in RunContext.__dataclass_fields__.items() if name in kwargs}, - ) - - def __getattribute__(self, name: str) -> Any: - try: - return super().__getattribute__(name) - except AttributeError as e: - if name in RunContext.__dataclass_fields__: - raise AttributeError( - f'Temporalized {RunContext.__name__!r} object has no attribute {name!r}. To make the attribute available, pass a `TemporalSettings` object to `temporalize_agent` that has a custom `serialize_run_context` function that returns a dictionary that includes the attribute.' - ) - else: - raise e - - @classmethod - def serialize_run_context(cls, ctx: RunContext[AgentDepsT]) -> dict[str, Any]: - return { - 'deps': ctx.deps, - 'retries': ctx.retries, - 'tool_call_id': ctx.tool_call_id, - 'tool_name': ctx.tool_name, - 'retry': ctx.retry, - 'run_step': ctx.run_step, - } - - @classmethod - def deserialize_run_context(cls, ctx: dict[str, Any]) -> RunContext[AgentDepsT]: - return cls(**ctx) - - -@dataclass -class TemporalSettings: - """Settings for Temporal `execute_activity` and Pydantic AI-specific Temporal activity behavior.""" - - # Temporal settings - task_queue: str | None = None - schedule_to_close_timeout: timedelta | None = None - schedule_to_start_timeout: timedelta | None = None - start_to_close_timeout: timedelta | None = None - heartbeat_timeout: timedelta | None = None - retry_policy: RetryPolicy | None = None - cancellation_type: ActivityCancellationType = ActivityCancellationType.TRY_CANCEL - activity_id: str | None = None - versioning_intent: VersioningIntent | None = None - summary: str | None = None - priority: Priority = Priority.default - - # Pydantic AI specific - tool_settings: dict[str, dict[str, TemporalSettings]] | None = None - - def for_tool(self, toolset_id: str, tool_id: str) -> TemporalSettings: - if self.tool_settings is None: - return self - return self.tool_settings.get(toolset_id, {}).get(tool_id, self) - - serialize_run_context: Callable[[RunContext], Any] = _TemporalRunContext.serialize_run_context - deserialize_run_context: Callable[[dict[str, Any]], RunContext] = _TemporalRunContext.deserialize_run_context - - @property - def execute_activity_options(self) -> dict[str, Any]: - return { - 'task_queue': self.task_queue, - 'schedule_to_close_timeout': self.schedule_to_close_timeout, - 'schedule_to_start_timeout': self.schedule_to_start_timeout, - 'start_to_close_timeout': self.start_to_close_timeout, - 'heartbeat_timeout': self.heartbeat_timeout, - 'retry_policy': self.retry_policy, - 'cancellation_type': self.cancellation_type, - 'activity_id': self.activity_id, - 'versioning_intent': self.versioning_intent, - 'summary': self.summary, - 'priority': self.priority, - } +from temporalio import workflow + +from pydantic_ai.agent import Agent +from pydantic_ai.toolsets.abstract import AbstractToolset + +from ..models import Model +from ._model import temporalize_model +from ._run_context import TemporalRunContext +from ._settings import TemporalSettings +from ._toolset import temporalize_toolset + +__all__ = [ + 'initialize_temporal', + 'TemporalSettings', + 'TemporalRunContext', +] def initialize_temporal(): - """Explicitly import types without which Temporal will not be able to serialize/deserialize `ModelMessage`s.""" - from pydantic_ai.messages import ( # noqa F401 - ModelResponse, # pyright: ignore[reportUnusedImport] - ImageUrl, # pyright: ignore[reportUnusedImport] - AudioUrl, # pyright: ignore[reportUnusedImport] - DocumentUrl, # pyright: ignore[reportUnusedImport] - VideoUrl, # pyright: ignore[reportUnusedImport] - BinaryContent, # pyright: ignore[reportUnusedImport] - UserContent, # pyright: ignore[reportUnusedImport] - ) + """Initialize Temporal.""" + with workflow.unsafe.imports_passed_through(): + with contextlib.suppress(ModuleNotFoundError): + import pandas # pyright: ignore[reportUnusedImport] # noqa: F401 + + +def temporalize_agent( + agent: Agent[Any, Any], + settings: TemporalSettings | None = None, + temporalize_toolset_func: Callable[ + [AbstractToolset, TemporalSettings | None], list[Callable[..., Any]] + ] = temporalize_toolset, +) -> list[Callable[..., Any]]: + """Temporalize an agent. + + Args: + agent: The agent to temporalize. + settings: The temporal settings to use. + temporalize_toolset_func: The function to use to temporalize the toolsets. + """ + if existing_activities := getattr(agent, '__temporal_activities', None): + return existing_activities + + settings = settings or TemporalSettings() + + # TODO: Doesn't consider model/toolsets passed at iter time. + + activities: list[Callable[..., Any]] = [] + if isinstance(agent.model, Model): + activities.extend(temporalize_model(agent.model, settings, agent._event_stream_handler)) # pyright: ignore[reportPrivateUsage] + + def temporalize_toolset(toolset: AbstractToolset) -> None: + activities.extend(temporalize_toolset_func(toolset, settings)) + + agent.toolset.apply(temporalize_toolset) + + setattr(agent, '__temporal_activities', activities) + return activities + + +# TODO: untemporalize_agent diff --git a/pydantic_ai_slim/pydantic_ai/temporal/function_toolset.py b/pydantic_ai_slim/pydantic_ai/temporal/_function_toolset.py similarity index 98% rename from pydantic_ai_slim/pydantic_ai/temporal/function_toolset.py rename to pydantic_ai_slim/pydantic_ai/temporal/_function_toolset.py index 8b34a8fb5..2c371a1ff 100644 --- a/pydantic_ai_slim/pydantic_ai/temporal/function_toolset.py +++ b/pydantic_ai_slim/pydantic_ai/temporal/_function_toolset.py @@ -10,7 +10,7 @@ from .._run_context import RunContext from ..toolsets import ToolsetTool -from . import TemporalSettings +from ._settings import TemporalSettings @dataclass diff --git a/pydantic_ai_slim/pydantic_ai/temporal/mcp_server.py b/pydantic_ai_slim/pydantic_ai/temporal/_mcp_server.py similarity index 98% rename from pydantic_ai_slim/pydantic_ai/temporal/mcp_server.py rename to pydantic_ai_slim/pydantic_ai/temporal/_mcp_server.py index 6a7248468..6a93e0c34 100644 --- a/pydantic_ai_slim/pydantic_ai/temporal/mcp_server.py +++ b/pydantic_ai_slim/pydantic_ai/temporal/_mcp_server.py @@ -9,7 +9,7 @@ from pydantic_ai.mcp import MCPServer, ToolResult -from . import TemporalSettings +from ._settings import TemporalSettings @dataclass diff --git a/pydantic_ai_slim/pydantic_ai/temporal/model.py b/pydantic_ai_slim/pydantic_ai/temporal/_model.py similarity index 99% rename from pydantic_ai_slim/pydantic_ai/temporal/model.py rename to pydantic_ai_slim/pydantic_ai/temporal/_model.py index 32b4967c0..cbadf7bd1 100644 --- a/pydantic_ai_slim/pydantic_ai/temporal/model.py +++ b/pydantic_ai_slim/pydantic_ai/temporal/_model.py @@ -24,7 +24,7 @@ from ..models import Model, ModelRequestParameters, StreamedResponse from ..settings import ModelSettings from ..usage import Usage -from . import TemporalSettings +from ._settings import TemporalSettings @dataclass diff --git a/pydantic_ai_slim/pydantic_ai/temporal/_run_context.py b/pydantic_ai_slim/pydantic_ai/temporal/_run_context.py new file mode 100644 index 000000000..8bc7029e6 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/temporal/_run_context.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from typing import Any + +from pydantic_ai._run_context import AgentDepsT, RunContext + + +class TemporalRunContext(RunContext[AgentDepsT]): + def __init__(self, **kwargs: Any): + self.__dict__ = kwargs + setattr( + self, + '__dataclass_fields__', + {name: field for name, field in RunContext.__dataclass_fields__.items() if name in kwargs}, + ) + + def __getattribute__(self, name: str) -> Any: + try: + return super().__getattribute__(name) + except AttributeError as e: + if name in RunContext.__dataclass_fields__: + raise AttributeError( + f'Temporalized {RunContext.__name__!r} object has no attribute {name!r}. To make the attribute available, pass a `TemporalSettings` object to `temporalize_agent` that has a custom `serialize_run_context` function that returns a dictionary that includes the attribute.' + ) + else: + raise e + + @classmethod + def serialize_run_context(cls, ctx: RunContext[AgentDepsT]) -> dict[str, Any]: + return { + 'deps': ctx.deps, + 'retries': ctx.retries, + 'tool_call_id': ctx.tool_call_id, + 'tool_name': ctx.tool_name, + 'retry': ctx.retry, + 'run_step': ctx.run_step, + } + + @classmethod + def deserialize_run_context(cls, ctx: dict[str, Any]) -> RunContext[AgentDepsT]: + return cls(**ctx) diff --git a/pydantic_ai_slim/pydantic_ai/temporal/_settings.py b/pydantic_ai_slim/pydantic_ai/temporal/_settings.py new file mode 100644 index 000000000..14c9d595e --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/temporal/_settings.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +from dataclasses import dataclass +from datetime import timedelta +from typing import Any, Callable + +from temporalio.common import Priority, RetryPolicy +from temporalio.workflow import ActivityCancellationType, VersioningIntent + +from pydantic_ai._run_context import RunContext + +from ._run_context import TemporalRunContext + + +@dataclass +class TemporalSettings: + """Settings for Temporal `execute_activity` and Pydantic AI-specific Temporal activity behavior.""" + + # Temporal settings + task_queue: str | None = None + schedule_to_close_timeout: timedelta | None = None + schedule_to_start_timeout: timedelta | None = None + start_to_close_timeout: timedelta | None = None + heartbeat_timeout: timedelta | None = None + retry_policy: RetryPolicy | None = None + cancellation_type: ActivityCancellationType = ActivityCancellationType.TRY_CANCEL + activity_id: str | None = None + versioning_intent: VersioningIntent | None = None + summary: str | None = None + priority: Priority = Priority.default + + # Pydantic AI specific + tool_settings: dict[str, dict[str, TemporalSettings]] | None = None + + def for_tool(self, toolset_id: str, tool_id: str) -> TemporalSettings: + if self.tool_settings is None: + return self + return self.tool_settings.get(toolset_id, {}).get(tool_id, self) + + serialize_run_context: Callable[[RunContext], Any] = TemporalRunContext.serialize_run_context + deserialize_run_context: Callable[[dict[str, Any]], RunContext] = TemporalRunContext.deserialize_run_context + + @property + def execute_activity_options(self) -> dict[str, Any]: + return { + 'task_queue': self.task_queue, + 'schedule_to_close_timeout': self.schedule_to_close_timeout, + 'schedule_to_start_timeout': self.schedule_to_start_timeout, + 'start_to_close_timeout': self.start_to_close_timeout, + 'heartbeat_timeout': self.heartbeat_timeout, + 'retry_policy': self.retry_policy, + 'cancellation_type': self.cancellation_type, + 'activity_id': self.activity_id, + 'versioning_intent': self.versioning_intent, + 'summary': self.summary, + 'priority': self.priority, + } diff --git a/pydantic_ai_slim/pydantic_ai/temporal/_toolset.py b/pydantic_ai_slim/pydantic_ai/temporal/_toolset.py new file mode 100644 index 000000000..289d90071 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/temporal/_toolset.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from typing import Any, Callable + +from pydantic_ai.mcp import MCPServer +from pydantic_ai.toolsets.abstract import AbstractToolset +from pydantic_ai.toolsets.function import FunctionToolset + +from ._function_toolset import temporalize_function_toolset +from ._mcp_server import temporalize_mcp_server +from ._settings import TemporalSettings + + +def temporalize_toolset(toolset: AbstractToolset, settings: TemporalSettings | None) -> list[Callable[..., Any]]: + """Temporalize a toolset. + + Args: + toolset: The toolset to temporalize. + settings: The temporal settings to use. + """ + if isinstance(toolset, FunctionToolset): + return temporalize_function_toolset(toolset, settings) + elif isinstance(toolset, MCPServer): + return temporalize_mcp_server(toolset, settings) + else: + return [] diff --git a/pydantic_ai_slim/pydantic_ai/temporal/agent.py b/pydantic_ai_slim/pydantic_ai/temporal/agent.py deleted file mode 100644 index 5f59d51fc..000000000 --- a/pydantic_ai_slim/pydantic_ai/temporal/agent.py +++ /dev/null @@ -1,66 +0,0 @@ -from __future__ import annotations - -from typing import Any, Callable - -from pydantic_ai.agent import Agent -from pydantic_ai.mcp import MCPServer -from pydantic_ai.toolsets.abstract import AbstractToolset -from pydantic_ai.toolsets.function import FunctionToolset - -from ..models import Model -from . import TemporalSettings -from .function_toolset import temporalize_function_toolset -from .mcp_server import temporalize_mcp_server -from .model import temporalize_model - - -def temporalize_toolset(toolset: AbstractToolset, settings: TemporalSettings | None) -> list[Callable[..., Any]]: - """Temporalize a toolset. - - Args: - toolset: The toolset to temporalize. - settings: The temporal settings to use. - """ - if isinstance(toolset, FunctionToolset): - return temporalize_function_toolset(toolset, settings) - elif isinstance(toolset, MCPServer): - return temporalize_mcp_server(toolset, settings) - else: - return [] - - -def temporalize_agent( - agent: Agent[Any, Any], - settings: TemporalSettings | None = None, - temporalize_toolset_func: Callable[ - [AbstractToolset, TemporalSettings | None], list[Callable[..., Any]] - ] = temporalize_toolset, -) -> list[Callable[..., Any]]: - """Temporalize an agent. - - Args: - agent: The agent to temporalize. - settings: The temporal settings to use. - temporalize_toolset_func: The function to use to temporalize the toolsets. - """ - if existing_activities := getattr(agent, '__temporal_activities', None): - return existing_activities - - settings = settings or TemporalSettings() - - # TODO: Doesn't consider model/toolsets passed at iter time. - - activities: list[Callable[..., Any]] = [] - if isinstance(agent.model, Model): - activities.extend(temporalize_model(agent.model, settings, agent._event_stream_handler)) # pyright: ignore[reportPrivateUsage] - - def temporalize_toolset(toolset: AbstractToolset) -> None: - activities.extend(temporalize_toolset_func(toolset, settings)) - - agent.toolset.apply(temporalize_toolset) - - setattr(agent, '__temporal_activities', activities) - return activities - - -# TODO: untemporalize_agent diff --git a/temporal.py b/temporal.py index 057b427be..8b2baec60 100644 --- a/temporal.py +++ b/temporal.py @@ -9,80 +9,77 @@ from collections.abc import AsyncIterable from datetime import timedelta +import logfire +from opentelemetry import trace from temporalio import workflow from temporalio.client import Client from temporalio.contrib.opentelemetry import TracingInterceptor from temporalio.contrib.pydantic import pydantic_data_converter from temporalio.runtime import OpenTelemetryConfig, Runtime, TelemetryConfig from temporalio.worker import Worker +from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner, SandboxRestrictions from typing_extensions import TypedDict -with workflow.unsafe.imports_passed_through(): - from pydantic_ai import Agent - from pydantic_ai._run_context import RunContext - from pydantic_ai.mcp import MCPServerStdio - from pydantic_ai.messages import AgentStreamEvent, HandleResponseEvent - from pydantic_ai.temporal import ( - TemporalSettings, - initialize_temporal, - ) - from pydantic_ai.temporal.agent import temporalize_agent - from pydantic_ai.toolsets.function import FunctionToolset +from pydantic_ai import Agent, RunContext +from pydantic_ai.mcp import MCPServerStdio +from pydantic_ai.messages import AgentStreamEvent, HandleResponseEvent +from pydantic_ai.temporal import ( + TemporalSettings, + initialize_temporal, + temporalize_agent, +) +from pydantic_ai.toolsets import FunctionToolset - initialize_temporal() +initialize_temporal() - class Deps(TypedDict): - country: str - def get_country(ctx: RunContext[Deps]) -> str: - return ctx.deps['country'] +class Deps(TypedDict): + country: str - toolset = FunctionToolset[Deps](tools=[get_country], id='country') - mcp_server = MCPServerStdio( - 'python', - ['-m', 'tests.mcp_server'], - timeout=20, - id='test', - ) - async def event_stream_handler( - ctx: RunContext[Deps], - stream: AsyncIterable[AgentStreamEvent | HandleResponseEvent], - ): - print(f'{ctx.run_step=}') - async for event in stream: - print(event) - - my_agent = Agent( - 'openai:gpt-4o', - toolsets=[toolset, mcp_server], - event_stream_handler=event_stream_handler, - deps_type=Deps, - ) +def get_country(ctx: RunContext[Deps]) -> str: + return ctx.deps['country'] - temporal_settings = TemporalSettings( - start_to_close_timeout=timedelta(seconds=60), - tool_settings={ # TODO: Allow default temporal settings to be set for an entire toolset - 'country': { - 'get_country': TemporalSettings(start_to_close_timeout=timedelta(seconds=110)), - }, - }, - ) - activities = temporalize_agent(my_agent, temporal_settings) +toolset = FunctionToolset[Deps](tools=[get_country], id='country') +mcp_server = MCPServerStdio( + 'python', + ['-m', 'tests.mcp_server'], + timeout=20, + id='test', +) -def init_runtime_with_telemetry() -> Runtime: - # import logfire - # logfire.configure(send_to_logfire=True, service_version='0.0.1', console=False) - # logfire.instrument_pydantic_ai() - # logfire.instrument_httpx(capture_all=True) +async def event_stream_handler( + ctx: RunContext[Deps], + stream: AsyncIterable[AgentStreamEvent | HandleResponseEvent], +): + logfire.info(f'{ctx.run_step=}') + async for event in stream: + logfire.info(f'{event=}') - # Setup SDK metrics to OTel endpoint - return Runtime(telemetry=TelemetryConfig(metrics=OpenTelemetryConfig(url='http://localhost:4318'))) + +my_agent = Agent( + 'openai:gpt-4o', + toolsets=[toolset, mcp_server], + event_stream_handler=event_stream_handler, + deps_type=Deps, +) + +temporal_settings = TemporalSettings( + start_to_close_timeout=timedelta(seconds=60), + tool_settings={ # TODO: Allow default temporal settings to be set for all activities in a toolset + 'country': { + 'get_country': TemporalSettings(start_to_close_timeout=timedelta(seconds=110)), + }, + }, +) +activities = temporalize_agent(my_agent, temporal_settings) + + +TASK_QUEUE = 'pydantic-ai-agent-task-queue' -# Basic workflow that logs and invokes an activity @workflow.defn class MyAgentWorkflow: @workflow.run @@ -92,18 +89,39 @@ async def run(self, prompt: str, deps: Deps) -> str: async def main(): + def init_runtime_with_telemetry() -> Runtime: + logfire.configure(console=False) + logfire.instrument_pydantic_ai() + logfire.instrument_httpx(capture_all=True) + + # Setup SDK metrics to OTel endpoint + return Runtime(telemetry=TelemetryConfig(metrics=OpenTelemetryConfig(url='http://localhost:4318'))) + client = await Client.connect( 'localhost:7233', - interceptors=[TracingInterceptor()], - data_converter=pydantic_data_converter, - runtime=init_runtime_with_telemetry(), + interceptors=[ # TODO: Use ClientPlugin.configure_client for this + TracingInterceptor(trace.get_tracer('temporal')) + ], + data_converter=pydantic_data_converter, # TODO: Use ClientPlugin.configure_client for this + runtime=init_runtime_with_telemetry(), # TODO: Use ClientPlugin.connect_service_client for this ) async with Worker( client, - task_queue='my-agent-task-queue', + task_queue=TASK_QUEUE, workflows=[MyAgentWorkflow], activities=activities, + workflow_runner=SandboxedWorkflowRunner( # TODO: Use WorkerPlugin.configure_worker for this, see https://github.com/temporalio/sdk-python/blob/da6616a93e9ee5170842bb5a056e2383e18d07c6/tests/test_plugins.py#L71 + restrictions=SandboxRestrictions.default.with_passthrough_modules( + 'pydantic_ai', + 'logfire', # TODO: Only if module available? + # Imported inside `logfire._internal.json_encoder` when running `logfire.info` inside an activity with attributes to serialize + 'attrs', + # Imported inside `logfire._internal.json_schema` when running `logfire.info` inside an activity with attributes to serialize + 'numpy', # TODO: Only if module available? + 'pandas', # TODO: Only if module available? + ), + ), ): output = await client.execute_workflow( # pyright: ignore[reportUnknownMemberType] MyAgentWorkflow.run, @@ -112,7 +130,7 @@ async def main(): Deps(country='Mexico'), ], id=f'my-agent-workflow-id-{random.random()}', - task_queue='my-agent-task-queue', + task_queue=TASK_QUEUE, ) print(output) From 4c87691603a1b4cc444493b6464517a443a5d045 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 25 Jul 2025 17:34:55 +0000 Subject: [PATCH 09/41] with_passthrough_modules doesn't import itself --- temporal.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/temporal.py b/temporal.py index 8b2baec60..ea64f6797 100644 --- a/temporal.py +++ b/temporal.py @@ -114,12 +114,12 @@ def init_runtime_with_telemetry() -> Runtime: workflow_runner=SandboxedWorkflowRunner( # TODO: Use WorkerPlugin.configure_worker for this, see https://github.com/temporalio/sdk-python/blob/da6616a93e9ee5170842bb5a056e2383e18d07c6/tests/test_plugins.py#L71 restrictions=SandboxRestrictions.default.with_passthrough_modules( 'pydantic_ai', - 'logfire', # TODO: Only if module available? + 'logfire', # Imported inside `logfire._internal.json_encoder` when running `logfire.info` inside an activity with attributes to serialize 'attrs', # Imported inside `logfire._internal.json_schema` when running `logfire.info` inside an activity with attributes to serialize - 'numpy', # TODO: Only if module available? - 'pandas', # TODO: Only if module available? + 'numpy', + 'pandas', ), ), ): From 694fa6b2f8705ab0988cdae4197cff579ef1ee9b Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 25 Jul 2025 20:56:49 +0000 Subject: [PATCH 10/41] Use Temporal plugins --- .../pydantic_ai/temporal/__init__.py | 77 ++++++++++++++++--- pyproject.toml | 1 + temporal.py | 53 +++---------- uv.lock | 23 +++--- 4 files changed, 92 insertions(+), 62 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/temporal/__init__.py b/pydantic_ai_slim/pydantic_ai/temporal/__init__.py index f01028c17..32ea06775 100644 --- a/pydantic_ai_slim/pydantic_ai/temporal/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/temporal/__init__.py @@ -1,9 +1,18 @@ from __future__ import annotations -import contextlib +from collections.abc import Sequence +from dataclasses import replace from typing import Any, Callable -from temporalio import workflow +import logfire # TODO: Not always available +from opentelemetry import trace # TODO: Not always available +from temporalio.client import ClientConfig, Plugin as ClientPlugin +from temporalio.contrib.opentelemetry import TracingInterceptor +from temporalio.contrib.pydantic import pydantic_data_converter +from temporalio.runtime import OpenTelemetryConfig, Runtime, TelemetryConfig +from temporalio.service import ConnectConfig, ServiceClient +from temporalio.worker import Plugin as WorkerPlugin, WorkerConfig +from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner from pydantic_ai.agent import Agent from pydantic_ai.toolsets.abstract import AbstractToolset @@ -15,17 +24,66 @@ from ._toolset import temporalize_toolset __all__ = [ - 'initialize_temporal', 'TemporalSettings', 'TemporalRunContext', + 'PydanticAIPlugin', + 'LogfirePlugin', + 'AgentPlugin', ] -def initialize_temporal(): - """Initialize Temporal.""" - with workflow.unsafe.imports_passed_through(): - with contextlib.suppress(ModuleNotFoundError): - import pandas # pyright: ignore[reportUnusedImport] # noqa: F401 +class PydanticAIPlugin(ClientPlugin, WorkerPlugin): + """Temporal client and worker plugin for Pydantic AI.""" + + def configure_client(self, config: ClientConfig) -> ClientConfig: + config['data_converter'] = pydantic_data_converter + return super().configure_client(config) + + def configure_worker(self, config: WorkerConfig) -> WorkerConfig: + runner = config.get('workflow_runner') # pyright: ignore[reportUnknownMemberType] + if isinstance(runner, SandboxedWorkflowRunner): + config['workflow_runner'] = replace( + runner, + restrictions=runner.restrictions.with_passthrough_modules( + 'pydantic_ai', + 'logfire', + # Imported inside `logfire._internal.json_encoder` when running `logfire.info` inside an activity with attributes to serialize + 'attrs', + # Imported inside `logfire._internal.json_schema` when running `logfire.info` inside an activity with attributes to serialize + 'numpy', + 'pandas', + ), + ) + return super().configure_worker(config) + + +class LogfirePlugin(ClientPlugin): + """Temporal client plugin for Logfire.""" + + def configure_client(self, config: ClientConfig) -> ClientConfig: + config['interceptors'] = [TracingInterceptor(trace.get_tracer('temporal'))] + return super().configure_client(config) + + async def connect_service_client(self, config: ConnectConfig) -> ServiceClient: + # TODO: Do we need this here? + logfire.configure(console=False) + logfire.instrument_pydantic_ai() + logfire.instrument_httpx(capture_all=True) + + config.runtime = Runtime(telemetry=TelemetryConfig(metrics=OpenTelemetryConfig(url='http://localhost:4318'))) + return await super().connect_service_client(config) + + +class AgentPlugin(WorkerPlugin): + """Temporal worker plugin for a specific Pydantic AI agent.""" + + def __init__(self, agent: Agent[Any, Any], settings: TemporalSettings | None = None): + self.activities = temporalize_agent(agent, settings) + + def configure_worker(self, config: WorkerConfig) -> WorkerConfig: + activities: Sequence[Callable[..., Any]] = config.get('activities', []) # pyright: ignore[reportUnknownMemberType] + config['activities'] = [*activities, *self.activities] + return super().configure_worker(config) def temporalize_agent( @@ -47,7 +105,8 @@ def temporalize_agent( settings = settings or TemporalSettings() - # TODO: Doesn't consider model/toolsets passed at iter time. + # TODO: Doesn't consider model/toolsets passed at iter time, raise an error if that happens. + # Similarly, passing event_stream_handler at iter time should raise an error. activities: list[Callable[..., Any]] = [] if isinstance(agent.model, Model): diff --git a/pyproject.toml b/pyproject.toml index fdd67f183..e3f2ed07f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,6 +70,7 @@ pydantic-ai-slim = { workspace = true } pydantic-evals = { workspace = true } pydantic-graph = { workspace = true } pydantic-ai-examples = { workspace = true } +temporalio = { git = "https://github.com/temporalio/sdk-python.git", rev = "main" } [tool.uv.workspace] members = [ diff --git a/temporal.py b/temporal.py index ea64f6797..2b8a4a319 100644 --- a/temporal.py +++ b/temporal.py @@ -1,37 +1,25 @@ -# /// script -# dependencies = [ -# "temporalio", -# "logfire", -# ] -# /// import asyncio import random from collections.abc import AsyncIterable from datetime import timedelta import logfire -from opentelemetry import trace from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.opentelemetry import TracingInterceptor -from temporalio.contrib.pydantic import pydantic_data_converter -from temporalio.runtime import OpenTelemetryConfig, Runtime, TelemetryConfig from temporalio.worker import Worker -from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner, SandboxRestrictions from typing_extensions import TypedDict from pydantic_ai import Agent, RunContext from pydantic_ai.mcp import MCPServerStdio from pydantic_ai.messages import AgentStreamEvent, HandleResponseEvent from pydantic_ai.temporal import ( + AgentPlugin, + LogfirePlugin, + PydanticAIPlugin, TemporalSettings, - initialize_temporal, - temporalize_agent, ) from pydantic_ai.toolsets import FunctionToolset -initialize_temporal() - class Deps(TypedDict): country: str @@ -59,7 +47,7 @@ async def event_stream_handler( logfire.info(f'{event=}') -my_agent = Agent( +agent = Agent( 'openai:gpt-4o', toolsets=[toolset, mcp_server], event_stream_handler=event_stream_handler, @@ -74,7 +62,6 @@ async def event_stream_handler( }, }, ) -activities = temporalize_agent(my_agent, temporal_settings) TASK_QUEUE = 'pydantic-ai-agent-task-queue' @@ -84,44 +71,26 @@ async def event_stream_handler( class MyAgentWorkflow: @workflow.run async def run(self, prompt: str, deps: Deps) -> str: - result = await my_agent.run(prompt, deps=deps) + result = await agent.run(prompt, deps=deps) return result.output -async def main(): - def init_runtime_with_telemetry() -> Runtime: - logfire.configure(console=False) - logfire.instrument_pydantic_ai() - logfire.instrument_httpx(capture_all=True) +# TODO: For some reason, when I put this (specifically the temporalize_agent call) inside `async def main()`, +# we get tons of errors. +plugin = AgentPlugin(agent, temporal_settings) - # Setup SDK metrics to OTel endpoint - return Runtime(telemetry=TelemetryConfig(metrics=OpenTelemetryConfig(url='http://localhost:4318'))) +async def main(): client = await Client.connect( 'localhost:7233', - interceptors=[ # TODO: Use ClientPlugin.configure_client for this - TracingInterceptor(trace.get_tracer('temporal')) - ], - data_converter=pydantic_data_converter, # TODO: Use ClientPlugin.configure_client for this - runtime=init_runtime_with_telemetry(), # TODO: Use ClientPlugin.connect_service_client for this + plugins=[PydanticAIPlugin(), LogfirePlugin()], ) async with Worker( client, task_queue=TASK_QUEUE, workflows=[MyAgentWorkflow], - activities=activities, - workflow_runner=SandboxedWorkflowRunner( # TODO: Use WorkerPlugin.configure_worker for this, see https://github.com/temporalio/sdk-python/blob/da6616a93e9ee5170842bb5a056e2383e18d07c6/tests/test_plugins.py#L71 - restrictions=SandboxRestrictions.default.with_passthrough_modules( - 'pydantic_ai', - 'logfire', - # Imported inside `logfire._internal.json_encoder` when running `logfire.info` inside an activity with attributes to serialize - 'attrs', - # Imported inside `logfire._internal.json_schema` when running `logfire.info` inside an activity with attributes to serialize - 'numpy', - 'pandas', - ), - ), + plugins=[plugin], ): output = await client.execute_workflow( # pyright: ignore[reportUnknownMemberType] MyAgentWorkflow.run, diff --git a/uv.lock b/uv.lock index b33d028f9..dcc5090da 100644 --- a/uv.lock +++ b/uv.lock @@ -2286,6 +2286,14 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2a/e2/5d3f6ada4297caebe1a2add3b126fe800c96f56dbe5d1988a2cbe0b267aa/mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d", size = 4695, upload-time = "2023-02-04T12:11:25.002Z" }, ] +[[package]] +name = "nexus-rpc" +version = "1.1.0" +source = { git = "https://github.com/nexus-rpc/sdk-python.git?rev=35f574c711193a6e2560d3e6665732a5bb7ae92c#35f574c711193a6e2560d3e6665732a5bb7ae92c" } +dependencies = [ + { name = "typing-extensions" }, +] + [[package]] name = "nodeenv" version = "1.9.1" @@ -3243,7 +3251,7 @@ requires-dist = [ { name = "rich", marker = "extra == 'cli'", specifier = ">=13" }, { name = "starlette", marker = "extra == 'ag-ui'", specifier = ">=0.45.3" }, { name = "tavily-python", marker = "extra == 'tavily'", specifier = ">=0.5.0" }, - { name = "temporalio", marker = "extra == 'temporal'", specifier = ">=1.13.0" }, + { name = "temporalio", marker = "extra == 'temporal'", git = "https://github.com/temporalio/sdk-python.git?rev=main" }, { name = "typing-inspection", specifier = ">=0.4.0" }, ] provides-extras = ["a2a", "ag-ui", "anthropic", "bedrock", "cli", "cohere", "duckduckgo", "evals", "google", "groq", "huggingface", "logfire", "mcp", "mistral", "openai", "tavily", "temporal", "vertexai"] @@ -4160,22 +4168,15 @@ wheels = [ [[package]] name = "temporalio" -version = "1.13.0" -source = { registry = "https://pypi.org/simple" } +version = "1.14.1" +source = { git = "https://github.com/temporalio/sdk-python.git?rev=main#e767013acca543345e0408a167556bbb987eb130" } dependencies = [ + { name = "nexus-rpc" }, { name = "protobuf" }, { name = "python-dateutil", marker = "python_full_version < '3.11'" }, { name = "types-protobuf" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/3e/a3/a76477b523937f47a21941188c16b3c6b1eef6baadc7c8efeea497d909de/temporalio-1.13.0.tar.gz", hash = "sha256:5a979eee5433da6ab5d8a2bcde25a1e7d454e91920acb0bf7ca93d415750828b", size = 1558745, upload-time = "2025-06-20T19:57:26.944Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/f5/f4/a5a74284c671bd50ce7353ad1dad7dab1a795f891458454049e95bc5378f/temporalio-1.13.0-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:7ee14cab581352e77171d1e4ce01a899231abfe75c5f7233e3e260f361a344cc", size = 12086961, upload-time = "2025-06-20T19:57:15.25Z" }, - { url = "https://files.pythonhosted.org/packages/1f/b7/5dc6e34f4e9a3da8b75cb3fe0d32edca1d9201d598c38d022501d38650a9/temporalio-1.13.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:575a0c57dbb089298b4775f3aca86ebaf8d58d5ba155e7fc5509877c25e6bb44", size = 11745239, upload-time = "2025-06-20T19:57:17.934Z" }, - { url = "https://files.pythonhosted.org/packages/04/30/4b9b15af87c181fd9364b61971faa0faa07d199320d7ff1712b5d51b5bbb/temporalio-1.13.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6bf099a27f22c0dbc22f3d86dba76d59be5da812ff044ba3fa183e3e14bd5e9a", size = 12119197, upload-time = "2025-06-20T19:57:20.509Z" }, - { url = "https://files.pythonhosted.org/packages/46/9f/a5b627d773974c654b6cd22ed3937e7e2471023af244ea417f0e917e617b/temporalio-1.13.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7e20c711f41c66877b9d54ab33c79a14ccaac9ed498a174274f6129110f4d84", size = 12413459, upload-time = "2025-06-20T19:57:22.816Z" }, - { url = "https://files.pythonhosted.org/packages/a3/73/efb6957212eb8c8dfff26c7c2c6ddf745aa5990a3b722cff17c8feaa66fc/temporalio-1.13.0-cp39-abi3-win_amd64.whl", hash = "sha256:9286cb84c1e078b2bcc6e8c6bd0be878d8ed395be991ac0d7cff555e3a82ac0b", size = 12440644, upload-time = "2025-06-20T19:57:25.175Z" }, -] [[package]] name = "tenacity" From 5e858d310e0034b37707d3febc5cb343dc82e65c Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 25 Jul 2025 23:20:46 +0000 Subject: [PATCH 11/41] Polish --- .../pydantic_ai/ext/temporal/__init__.py | 70 ++++++++++ .../pydantic_ai/ext/temporal/_agent.py | 120 +++++++++++++++++ .../{ => ext}/temporal/_function_toolset.py | 49 +++++-- .../pydantic_ai/ext/temporal/_logfire.py | 34 +++++ .../{ => ext}/temporal/_mcp_server.py | 28 +++- .../pydantic_ai/{ => ext}/temporal/_model.py | 49 +++++-- .../pydantic_ai/ext/temporal/_run_context.py | 68 ++++++++++ .../{ => ext}/temporal/_settings.py | 22 ++-- .../pydantic_ai/ext/temporal/_toolset.py | 41 ++++++ .../pydantic_ai/temporal/__init__.py | 124 ------------------ .../pydantic_ai/temporal/_run_context.py | 41 ------ .../pydantic_ai/temporal/_toolset.py | 26 ---- temporal.py | 67 +++++----- 13 files changed, 472 insertions(+), 267 deletions(-) create mode 100644 pydantic_ai_slim/pydantic_ai/ext/temporal/__init__.py create mode 100644 pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py rename pydantic_ai_slim/pydantic_ai/{ => ext}/temporal/_function_toolset.py (51%) create mode 100644 pydantic_ai_slim/pydantic_ai/ext/temporal/_logfire.py rename pydantic_ai_slim/pydantic_ai/{ => ext}/temporal/_mcp_server.py (72%) rename pydantic_ai_slim/pydantic_ai/{ => ext}/temporal/_model.py (82%) create mode 100644 pydantic_ai_slim/pydantic_ai/ext/temporal/_run_context.py rename pydantic_ai_slim/pydantic_ai/{ => ext}/temporal/_settings.py (71%) create mode 100644 pydantic_ai_slim/pydantic_ai/ext/temporal/_toolset.py delete mode 100644 pydantic_ai_slim/pydantic_ai/temporal/__init__.py delete mode 100644 pydantic_ai_slim/pydantic_ai/temporal/_run_context.py delete mode 100644 pydantic_ai_slim/pydantic_ai/temporal/_toolset.py diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/__init__.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/__init__.py new file mode 100644 index 000000000..f1ec642a5 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/__init__.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +from collections.abc import Sequence +from dataclasses import replace +from typing import Any, Callable + +from temporalio.client import ClientConfig, Plugin as ClientPlugin +from temporalio.contrib.pydantic import pydantic_data_converter +from temporalio.worker import Plugin as WorkerPlugin, WorkerConfig +from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner + +from pydantic_ai.agent import Agent + +from ._agent import temporalize_agent, untemporalize_agent +from ._logfire import LogfirePlugin +from ._run_context import TemporalRunContext +from ._settings import TemporalSettings + +__all__ = [ + 'TemporalSettings', + 'TemporalRunContext', + 'PydanticAIPlugin', + 'LogfirePlugin', + 'AgentPlugin', + 'temporalize_agent', + 'untemporalize_agent', +] + + +class PydanticAIPlugin(ClientPlugin, WorkerPlugin): + """Temporal client and worker plugin for Pydantic AI.""" + + def configure_client(self, config: ClientConfig) -> ClientConfig: + config['data_converter'] = pydantic_data_converter + return super().configure_client(config) + + def configure_worker(self, config: WorkerConfig) -> WorkerConfig: + runner = config.get('workflow_runner') # pyright: ignore[reportUnknownMemberType] + if isinstance(runner, SandboxedWorkflowRunner): + config['workflow_runner'] = replace( + runner, + restrictions=runner.restrictions.with_passthrough_modules( + 'pydantic_ai', + 'logfire', + # Imported inside `logfire._internal.json_encoder` when running `logfire.info` inside an activity with attributes to serialize + 'attrs', + # Imported inside `logfire._internal.json_schema` when running `logfire.info` inside an activity with attributes to serialize + 'numpy', + 'pandas', + ), + ) + return super().configure_worker(config) + + +class AgentPlugin(WorkerPlugin): + """Temporal worker plugin for a specific Pydantic AI agent.""" + + def __init__(self, agent: Agent[Any, Any]): + self.agent = agent + + def configure_worker(self, config: WorkerConfig) -> WorkerConfig: + agent_activities = getattr(self.agent, '__temporal_activities', None) + if agent_activities is None: + raise ValueError( + 'The agent has not been temporalized yet, call `temporalize_agent(agent)` (or `with temporalized_agent(agent): ...`) first.' + ) + + activities: Sequence[Callable[..., Any]] = config.get('activities', []) # pyright: ignore[reportUnknownMemberType] + config['activities'] = [*activities, *agent_activities] + return super().configure_worker(config) diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py new file mode 100644 index 000000000..f3717b6da --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py @@ -0,0 +1,120 @@ +from __future__ import annotations + +from collections.abc import Generator +from contextlib import contextmanager +from typing import Any, Callable + +from pydantic_ai.agent import Agent +from pydantic_ai.models import Model +from pydantic_ai.toolsets.abstract import AbstractToolset + +from ._model import temporalize_model, untemporalize_model +from ._settings import TemporalSettings +from ._toolset import temporalize_toolset, untemporalize_toolset + + +def temporalize_agent( + agent: Agent[Any, Any], + settings: TemporalSettings | None = None, + toolset_settings: dict[str, TemporalSettings] = {}, + tool_settings: dict[str, dict[str, TemporalSettings]] = {}, + temporalize_toolset_func: Callable[ + [AbstractToolset, TemporalSettings | None, dict[str, TemporalSettings]], list[Callable[..., Any]] + ] = temporalize_toolset, +) -> list[Callable[..., Any]]: + """Temporalize an agent. + + Args: + agent: The agent to temporalize. + settings: The temporal settings to use. + toolset_settings: The temporal settings to use for specific toolsets identified by ID. + tool_settings: The temporal settings to use for specific tools identified by toolset ID and tool name. + temporalize_toolset_func: The function to use to temporalize the toolsets. + """ + if existing_activities := getattr(agent, '__temporal_activities', None): + return existing_activities + + settings = settings or TemporalSettings() + + activities: list[Callable[..., Any]] = [] + if isinstance(agent.model, Model): + activities.extend(temporalize_model(agent.model, settings, agent._event_stream_handler)) # pyright: ignore[reportPrivateUsage] + + def temporalize_toolset(toolset: AbstractToolset) -> None: + id = toolset.id + if not id: + raise ValueError( + "A toolset needs to have an ID in order to be used with Temporal. The ID will be used to identify the toolset's activities within the workflow." + ) + activities.extend( + temporalize_toolset_func(toolset, settings.merge(toolset_settings.get(id)), tool_settings.get(id, {})) + ) + + agent.toolset.apply(temporalize_toolset) + + original_iter = agent.iter + original_override = agent.override + setattr(agent, '__original_iter', original_iter) + setattr(agent, '__original_override', original_override) + + def iter(*args: Any, **kwargs: Any) -> Any: + if kwargs.get('model') is not None: + raise ValueError( + 'Model cannot be set at agent run time when using Temporal, it must be set at agent creation time.' + ) + if kwargs.get('toolsets') is not None: + raise ValueError( + 'Toolsets cannot be set at agent run time when using Temporal, it must be set at agent creation time.' + ) + if kwargs.get('event_stream_handler') is not None: + raise ValueError( + 'Event stream handler cannot be set at agent run time when using Temporal, it must be set at agent creation time.' + ) + + return original_iter(*args, **kwargs) + + def override(*args: Any, **kwargs: Any) -> Any: + if kwargs.get('model') is not None: + raise ValueError('Model cannot be overridden when using Temporal, it must be set at agent creation time.') + if kwargs.get('toolsets') is not None: + raise ValueError( + 'Toolsets cannot be overridden when using Temporal, it must be set at agent creation time.' + ) + return original_override(*args, **kwargs) + + agent.iter = iter + agent.override = override + + setattr(agent, '__temporal_activities', activities) + return activities + + +def untemporalize_agent(agent: Agent[Any, Any]) -> None: + """Untemporalize an agent. + + Args: + agent: The agent to untemporalize. + """ + if not hasattr(agent, '__temporal_activities'): + return + + if isinstance(agent.model, Model): + untemporalize_model(agent.model) + + agent.toolset.apply(untemporalize_toolset) + + agent.iter = getattr(agent, '__original_iter') + agent.override = getattr(agent, '__original_override') + delattr(agent, '__original_iter') + delattr(agent, '__original_override') + + delattr(agent, '__temporal_activities') + + +@contextmanager +def temporalized_agent(agent: Agent[Any, Any], settings: TemporalSettings | None = None) -> Generator[None, None, None]: + temporalize_agent(agent, settings) + try: + yield + finally: + untemporalize_agent(agent) diff --git a/pydantic_ai_slim/pydantic_ai/temporal/_function_toolset.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_function_toolset.py similarity index 51% rename from pydantic_ai_slim/pydantic_ai/temporal/_function_toolset.py rename to pydantic_ai_slim/pydantic_ai/ext/temporal/_function_toolset.py index 2c371a1ff..054420453 100644 --- a/pydantic_ai_slim/pydantic_ai/temporal/_function_toolset.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_function_toolset.py @@ -6,10 +6,10 @@ from pydantic import ConfigDict, with_config from temporalio import activity, workflow -from pydantic_ai.toolsets.function import FunctionToolset +from pydantic_ai._run_context import RunContext +from pydantic_ai.toolsets import FunctionToolset, ToolsetTool -from .._run_context import RunContext -from ..toolsets import ToolsetTool +from ._run_context import TemporalRunContext from ._settings import TemporalSettings @@ -24,40 +24,50 @@ class _CallToolParams: def temporalize_function_toolset( toolset: FunctionToolset, settings: TemporalSettings | None = None, + tool_settings: dict[str, TemporalSettings] = {}, ) -> list[Callable[..., Any]]: """Temporalize a function toolset. Args: toolset: The function toolset to temporalize. settings: The temporal settings to use. + tool_settings: The temporal settings to use for specific tools identified by tool name. """ if activities := getattr(toolset, '__temporal_activities', None): return activities id = toolset.id - if not id: - raise ValueError( - "A function toolset needs to have an ID in order to be used in a durable execution environment like Temporal. The ID will be used to identify the toolset's activities within the workflow." - ) + assert id is not None settings = settings or TemporalSettings() original_call_tool = toolset.call_tool + setattr(toolset, '__original_call_tool', original_call_tool) @activity.defn(name=f'function_toolset__{id}__call_tool') async def call_tool_activity(params: _CallToolParams) -> Any: name = params.name - ctx = settings.for_tool(id, name).deserialize_run_context(params.serialized_run_context) - tool = (await toolset.get_tools(ctx))[name] + settings_for_tool = settings.merge(tool_settings.get(name)) + ctx = TemporalRunContext.deserialize_run_context( + params.serialized_run_context, settings_for_tool.deserialize_run_context + ) + try: + tool = (await toolset.get_tools(ctx))[name] + except KeyError as e: + raise ValueError( + f'Tool {name!r} not found in toolset {toolset.id!r}. ' + 'Removing or renaming tools during an agent run is not supported with Temporal.' + ) from e + return await original_call_tool(name, params.tool_args, ctx, tool) async def call_tool(name: str, tool_args: dict[str, Any], ctx: RunContext, tool: ToolsetTool) -> Any: - tool_settings = settings.for_tool(id, name) - serialized_run_context = tool_settings.serialize_run_context(ctx) + settings_for_tool = settings.merge(tool_settings.get(name)) + serialized_run_context = TemporalRunContext.serialize_run_context(ctx, settings_for_tool.serialize_run_context) return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] activity=call_tool_activity, arg=_CallToolParams(name=name, tool_args=tool_args, serialized_run_context=serialized_run_context), - **tool_settings.execute_activity_options, + **settings_for_tool.execute_activity_options, ) toolset.call_tool = call_tool @@ -65,3 +75,18 @@ async def call_tool(name: str, tool_args: dict[str, Any], ctx: RunContext, tool: activities = [call_tool_activity] setattr(toolset, '__temporal_activities', activities) return activities + + +def untemporalize_function_toolset(toolset: FunctionToolset) -> None: + """Untemporalize a function toolset. + + Args: + toolset: The function toolset to untemporalize. + """ + if not hasattr(toolset, '__temporal_activities'): + return + + toolset.call_tool = getattr(toolset, '__original_call_tool') + delattr(toolset, '__original_call_tool') + + delattr(toolset, '__temporal_activities') diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_logfire.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_logfire.py new file mode 100644 index 000000000..bb307b990 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_logfire.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from typing import Callable + +from opentelemetry.trace import get_tracer +from temporalio.client import ClientConfig, Plugin as ClientPlugin +from temporalio.contrib.opentelemetry import TracingInterceptor +from temporalio.runtime import OpenTelemetryConfig, Runtime, TelemetryConfig +from temporalio.service import ConnectConfig, ServiceClient + + +def _default_setup_logfire(): + import logfire + + logfire.configure(console=False) + logfire.instrument_pydantic_ai() + + +class LogfirePlugin(ClientPlugin): + """Temporal client plugin for Logfire.""" + + def __init__(self, setup_logfire: Callable[[], None] = _default_setup_logfire): + self.setup_logfire = setup_logfire + + def configure_client(self, config: ClientConfig) -> ClientConfig: + interceptors = config.get('interceptors', []) + config['interceptors'] = [*interceptors, TracingInterceptor(get_tracer('temporal'))] + return super().configure_client(config) + + async def connect_service_client(self, config: ConnectConfig) -> ServiceClient: + self.setup_logfire() + + config.runtime = Runtime(telemetry=TelemetryConfig(metrics=OpenTelemetryConfig(url='http://localhost:4318'))) + return await super().connect_service_client(config) diff --git a/pydantic_ai_slim/pydantic_ai/temporal/_mcp_server.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py similarity index 72% rename from pydantic_ai_slim/pydantic_ai/temporal/_mcp_server.py rename to pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py index 6a93e0c34..edc021a90 100644 --- a/pydantic_ai_slim/pydantic_ai/temporal/_mcp_server.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py @@ -23,26 +23,27 @@ class _CallToolParams: def temporalize_mcp_server( server: MCPServer, settings: TemporalSettings | None = None, + tool_settings: dict[str, TemporalSettings] = {}, ) -> list[Callable[..., Any]]: """Temporalize an MCP server. Args: server: The MCP server to temporalize. settings: The temporal settings to use. + tool_settings: The temporal settings to use for each tool. """ if activities := getattr(server, '__temporal_activities', None): return activities id = server.id - if not id: - raise ValueError( - "An MCP server needs to have an ID in order to be used in a durable execution environment like Temporal. The ID will be used to identify the server's activities within the workflow." - ) + assert id is not None settings = settings or TemporalSettings() original_list_tools = server.list_tools original_direct_call_tool = server.direct_call_tool + setattr(server, '__original_list_tools', original_list_tools) + setattr(server, '__original_direct_call_tool', original_direct_call_tool) @activity.defn(name=f'mcp_server__{id}__list_tools') async def list_tools_activity() -> list[mcp_types.Tool]: @@ -66,7 +67,7 @@ async def direct_call_tool( return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] activity=call_tool_activity, arg=_CallToolParams(name=name, tool_args=args, metadata=metadata), - **settings.for_tool(id, name).execute_activity_options, + **tool_settings.get(name, settings).execute_activity_options, ) server.list_tools = list_tools @@ -75,3 +76,20 @@ async def direct_call_tool( activities = [list_tools_activity, call_tool_activity] setattr(server, '__temporal_activities', activities) return activities + + +def untemporalize_mcp_server(server: MCPServer) -> None: + """Untemporalize an MCP server. + + Args: + server: The MCP server to untemporalize. + """ + if not hasattr(server, '__temporal_activities'): + return + + server.list_tools = getattr(server, '__original_list_tools') + server.direct_call_tool = getattr(server, '__original_direct_call_tool') + delattr(server, '__original_list_tools') + delattr(server, '__original_direct_call_tool') + + delattr(server, '__temporal_activities') diff --git a/pydantic_ai_slim/pydantic_ai/temporal/_model.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_model.py similarity index 82% rename from pydantic_ai_slim/pydantic_ai/temporal/_model.py rename to pydantic_ai_slim/pydantic_ai/ext/temporal/_model.py index cbadf7bd1..9f2240e58 100644 --- a/pydantic_ai_slim/pydantic_ai/temporal/_model.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_model.py @@ -9,10 +9,10 @@ from pydantic import ConfigDict, with_config from temporalio import activity, workflow -from .._run_context import RunContext -from ..agent import EventStreamHandler -from ..exceptions import UserError -from ..messages import ( +from pydantic_ai._run_context import RunContext +from pydantic_ai.agent import EventStreamHandler +from pydantic_ai.exceptions import UserError +from pydantic_ai.messages import ( FinalResultEvent, ModelMessage, ModelResponse, @@ -21,9 +21,11 @@ TextPart, ToolCallPart, ) -from ..models import Model, ModelRequestParameters, StreamedResponse -from ..settings import ModelSettings -from ..usage import Usage +from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse +from pydantic_ai.settings import ModelSettings +from pydantic_ai.usage import Usage + +from ._run_context import TemporalRunContext from ._settings import TemporalSettings @@ -84,13 +86,18 @@ def temporalize_model( # noqa: C901 original_request = model.request original_request_stream = model.request_stream + setattr(model, '__original_request', original_request) + setattr(model, '__original_request_stream', original_request_stream) + @activity.defn(name='model_request') async def request_activity(params: _RequestParams) -> ModelResponse: return await original_request(params.messages, params.model_settings, params.model_request_parameters) @activity.defn(name='model_request_stream') async def request_stream_activity(params: _RequestParams) -> ModelResponse: - run_context = settings.deserialize_run_context(params.serialized_run_context) + run_context = TemporalRunContext.deserialize_run_context( + params.serialized_run_context, settings.deserialize_run_context + ) async with original_request_stream( params.messages, params.model_settings, params.model_request_parameters, run_context ) as streamed_response: @@ -102,6 +109,7 @@ async def request_stream_activity(params: _RequestParams) -> ModelResponse: ] } + # Keep in sync with `AgentStream.__aiter__` async def aiter(): def _get_final_result_event(e: ModelResponseStreamEvent) -> FinalResultEvent | None: """Return an appropriate FinalResultEvent if `e` corresponds to a part that will produce a final result.""" @@ -119,9 +127,9 @@ def _get_final_result_event(e: ModelResponseStreamEvent) -> FinalResultEvent | N elif tool_def.kind == 'deferred': return FinalResultEvent(tool_name=None, tool_call_id=None) - # TODO: usage_checking_stream = _get_usage_checking_stream_response( - # self._raw_stream_response, self._usage_limits, self.usage - # ) + # `AgentStream.__aiter__`, which this is based on, calls `_get_usage_checking_stream_response` here, + # but we don't have access to the `_usage_limits`. + async for event in streamed_response: yield event if (final_result_event := _get_final_result_event(event)) is not None: @@ -168,7 +176,7 @@ async def request_stream( if run_context is None: raise UserError('Streaming with Temporal requires `request_stream` to be called with a `run_context`') - serialized_run_context = settings.serialize_run_context(run_context) + serialized_run_context = TemporalRunContext.serialize_run_context(run_context, settings.serialize_run_context) response = await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] activity=request_stream_activity, arg=_RequestParams( @@ -187,3 +195,20 @@ async def request_stream( activities = [request_activity, request_stream_activity] setattr(model, '__temporal_activities', activities) return activities + + +def untemporalize_model(model: Model) -> None: + """Untemporalize a model. + + Args: + model: The model to untemporalize. + """ + if not hasattr(model, '__temporal_activities'): + return + + model.request = getattr(model, '__original_request') + model.request_stream = getattr(model, '__original_request_stream') + + delattr(model, '__original_request') + delattr(model, '__original_request_stream') + delattr(model, '__temporal_activities') diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_run_context.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_run_context.py new file mode 100644 index 000000000..c1e896508 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_run_context.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +from typing import Any, Callable + +from pydantic_ai._run_context import RunContext + + +class TemporalRunContext(RunContext[Any]): + def __init__(self, **kwargs: Any): + self.__dict__ = kwargs + setattr( + self, + '__dataclass_fields__', + {name: field for name, field in RunContext.__dataclass_fields__.items() if name in kwargs}, + ) + + def __getattribute__(self, name: str) -> Any: + try: + return super().__getattribute__(name) + except AttributeError as e: + if name in RunContext.__dataclass_fields__: + raise AttributeError( + f'Temporalized {RunContext.__name__!r} object has no attribute {name!r}. ' + 'To make the attribute available, pass a `TemporalSettings` object to `temporalize_agent` with a custom `serialize_run_context` function that returns a dictionary that includes the attribute.' + ) + else: + raise e + + @classmethod + def serialize_run_context( + cls, + ctx: RunContext[Any], + extra_serializer: Callable[[RunContext[Any]], dict[str, Any]] | None = None, + ) -> dict[str, Any]: + return { + 'retries': ctx.retries, + 'tool_call_id': ctx.tool_call_id, + 'tool_name': ctx.tool_name, + 'retry': ctx.retry, + 'run_step': ctx.run_step, + **(extra_serializer(ctx) if extra_serializer else {}), + } + + @classmethod + def deserialize_run_context( + cls, ctx: dict[str, Any], extra_deserializer: Callable[[dict[str, Any]], dict[str, Any]] | None = None + ) -> RunContext[Any]: + return cls( + retries=ctx['retries'], + tool_call_id=ctx['tool_call_id'], + tool_name=ctx['tool_name'], + retry=ctx['retry'], + run_step=ctx['run_step'], + **(extra_deserializer(ctx) if extra_deserializer else {}), + ) + + +def serialize_run_context_deps(ctx: RunContext[Any]) -> dict[str, Any]: + if not isinstance(ctx.deps, dict): + raise ValueError( + 'The `deps` object must be a JSON-serializable dictionary in order to be used with Temporal. ' + 'To use a different type, pass a `TemporalSettings` object to `temporalize_agent` with custom `serialize_run_context` and `deserialize_run_context` functions.' + ) + return {'deps': ctx.deps} # pyright: ignore[reportUnknownMemberType] + + +def deserialize_run_context_deps(ctx: dict[str, Any]) -> dict[str, Any]: + return {'deps': ctx['deps']} diff --git a/pydantic_ai_slim/pydantic_ai/temporal/_settings.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_settings.py similarity index 71% rename from pydantic_ai_slim/pydantic_ai/temporal/_settings.py rename to pydantic_ai_slim/pydantic_ai/ext/temporal/_settings.py index 14c9d595e..4340f1863 100644 --- a/pydantic_ai_slim/pydantic_ai/temporal/_settings.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_settings.py @@ -1,6 +1,6 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, fields, replace from datetime import timedelta from typing import Any, Callable @@ -9,7 +9,7 @@ from pydantic_ai._run_context import RunContext -from ._run_context import TemporalRunContext +from ._run_context import deserialize_run_context_deps, serialize_run_context_deps @dataclass @@ -29,16 +29,8 @@ class TemporalSettings: summary: str | None = None priority: Priority = Priority.default - # Pydantic AI specific - tool_settings: dict[str, dict[str, TemporalSettings]] | None = None - - def for_tool(self, toolset_id: str, tool_id: str) -> TemporalSettings: - if self.tool_settings is None: - return self - return self.tool_settings.get(toolset_id, {}).get(tool_id, self) - - serialize_run_context: Callable[[RunContext], Any] = TemporalRunContext.serialize_run_context - deserialize_run_context: Callable[[dict[str, Any]], RunContext] = TemporalRunContext.deserialize_run_context + serialize_run_context: Callable[[RunContext], dict[str, Any]] = serialize_run_context_deps + deserialize_run_context: Callable[[dict[str, Any]], dict[str, Any]] = deserialize_run_context_deps @property def execute_activity_options(self) -> dict[str, Any]: @@ -55,3 +47,9 @@ def execute_activity_options(self) -> dict[str, Any]: 'summary': self.summary, 'priority': self.priority, } + + def merge(self, other: TemporalSettings | None) -> TemporalSettings: + """Merge non-default values from another TemporalSettings instance into this one, returning a new instance.""" + if not other: + return self + return replace(self, **{f.name: value for f in fields(other) if (value := getattr(other, f.name)) != f.default}) diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_toolset.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_toolset.py new file mode 100644 index 000000000..bf8b8281d --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_toolset.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from typing import Any, Callable + +from pydantic_ai.mcp import MCPServer +from pydantic_ai.toolsets.abstract import AbstractToolset +from pydantic_ai.toolsets.function import FunctionToolset + +from ._function_toolset import temporalize_function_toolset, untemporalize_function_toolset +from ._mcp_server import temporalize_mcp_server, untemporalize_mcp_server +from ._settings import TemporalSettings + + +def temporalize_toolset( + toolset: AbstractToolset, settings: TemporalSettings | None, tool_settings: dict[str, TemporalSettings] = {} +) -> list[Callable[..., Any]]: + """Temporalize a toolset. + + Args: + toolset: The toolset to temporalize. + settings: The temporal settings to use. + tool_settings: The temporal settings to use for specific tools identified by tool name. + """ + if isinstance(toolset, FunctionToolset): + return temporalize_function_toolset(toolset, settings, tool_settings) + elif isinstance(toolset, MCPServer): + return temporalize_mcp_server(toolset, settings, tool_settings) + else: + return [] + + +def untemporalize_toolset(toolset: AbstractToolset) -> None: + """Untemporalize a toolset. + + Args: + toolset: The toolset to untemporalize. + """ + if isinstance(toolset, FunctionToolset): + untemporalize_function_toolset(toolset) + elif isinstance(toolset, MCPServer): + untemporalize_mcp_server(toolset) diff --git a/pydantic_ai_slim/pydantic_ai/temporal/__init__.py b/pydantic_ai_slim/pydantic_ai/temporal/__init__.py deleted file mode 100644 index 32ea06775..000000000 --- a/pydantic_ai_slim/pydantic_ai/temporal/__init__.py +++ /dev/null @@ -1,124 +0,0 @@ -from __future__ import annotations - -from collections.abc import Sequence -from dataclasses import replace -from typing import Any, Callable - -import logfire # TODO: Not always available -from opentelemetry import trace # TODO: Not always available -from temporalio.client import ClientConfig, Plugin as ClientPlugin -from temporalio.contrib.opentelemetry import TracingInterceptor -from temporalio.contrib.pydantic import pydantic_data_converter -from temporalio.runtime import OpenTelemetryConfig, Runtime, TelemetryConfig -from temporalio.service import ConnectConfig, ServiceClient -from temporalio.worker import Plugin as WorkerPlugin, WorkerConfig -from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner - -from pydantic_ai.agent import Agent -from pydantic_ai.toolsets.abstract import AbstractToolset - -from ..models import Model -from ._model import temporalize_model -from ._run_context import TemporalRunContext -from ._settings import TemporalSettings -from ._toolset import temporalize_toolset - -__all__ = [ - 'TemporalSettings', - 'TemporalRunContext', - 'PydanticAIPlugin', - 'LogfirePlugin', - 'AgentPlugin', -] - - -class PydanticAIPlugin(ClientPlugin, WorkerPlugin): - """Temporal client and worker plugin for Pydantic AI.""" - - def configure_client(self, config: ClientConfig) -> ClientConfig: - config['data_converter'] = pydantic_data_converter - return super().configure_client(config) - - def configure_worker(self, config: WorkerConfig) -> WorkerConfig: - runner = config.get('workflow_runner') # pyright: ignore[reportUnknownMemberType] - if isinstance(runner, SandboxedWorkflowRunner): - config['workflow_runner'] = replace( - runner, - restrictions=runner.restrictions.with_passthrough_modules( - 'pydantic_ai', - 'logfire', - # Imported inside `logfire._internal.json_encoder` when running `logfire.info` inside an activity with attributes to serialize - 'attrs', - # Imported inside `logfire._internal.json_schema` when running `logfire.info` inside an activity with attributes to serialize - 'numpy', - 'pandas', - ), - ) - return super().configure_worker(config) - - -class LogfirePlugin(ClientPlugin): - """Temporal client plugin for Logfire.""" - - def configure_client(self, config: ClientConfig) -> ClientConfig: - config['interceptors'] = [TracingInterceptor(trace.get_tracer('temporal'))] - return super().configure_client(config) - - async def connect_service_client(self, config: ConnectConfig) -> ServiceClient: - # TODO: Do we need this here? - logfire.configure(console=False) - logfire.instrument_pydantic_ai() - logfire.instrument_httpx(capture_all=True) - - config.runtime = Runtime(telemetry=TelemetryConfig(metrics=OpenTelemetryConfig(url='http://localhost:4318'))) - return await super().connect_service_client(config) - - -class AgentPlugin(WorkerPlugin): - """Temporal worker plugin for a specific Pydantic AI agent.""" - - def __init__(self, agent: Agent[Any, Any], settings: TemporalSettings | None = None): - self.activities = temporalize_agent(agent, settings) - - def configure_worker(self, config: WorkerConfig) -> WorkerConfig: - activities: Sequence[Callable[..., Any]] = config.get('activities', []) # pyright: ignore[reportUnknownMemberType] - config['activities'] = [*activities, *self.activities] - return super().configure_worker(config) - - -def temporalize_agent( - agent: Agent[Any, Any], - settings: TemporalSettings | None = None, - temporalize_toolset_func: Callable[ - [AbstractToolset, TemporalSettings | None], list[Callable[..., Any]] - ] = temporalize_toolset, -) -> list[Callable[..., Any]]: - """Temporalize an agent. - - Args: - agent: The agent to temporalize. - settings: The temporal settings to use. - temporalize_toolset_func: The function to use to temporalize the toolsets. - """ - if existing_activities := getattr(agent, '__temporal_activities', None): - return existing_activities - - settings = settings or TemporalSettings() - - # TODO: Doesn't consider model/toolsets passed at iter time, raise an error if that happens. - # Similarly, passing event_stream_handler at iter time should raise an error. - - activities: list[Callable[..., Any]] = [] - if isinstance(agent.model, Model): - activities.extend(temporalize_model(agent.model, settings, agent._event_stream_handler)) # pyright: ignore[reportPrivateUsage] - - def temporalize_toolset(toolset: AbstractToolset) -> None: - activities.extend(temporalize_toolset_func(toolset, settings)) - - agent.toolset.apply(temporalize_toolset) - - setattr(agent, '__temporal_activities', activities) - return activities - - -# TODO: untemporalize_agent diff --git a/pydantic_ai_slim/pydantic_ai/temporal/_run_context.py b/pydantic_ai_slim/pydantic_ai/temporal/_run_context.py deleted file mode 100644 index 8bc7029e6..000000000 --- a/pydantic_ai_slim/pydantic_ai/temporal/_run_context.py +++ /dev/null @@ -1,41 +0,0 @@ -from __future__ import annotations - -from typing import Any - -from pydantic_ai._run_context import AgentDepsT, RunContext - - -class TemporalRunContext(RunContext[AgentDepsT]): - def __init__(self, **kwargs: Any): - self.__dict__ = kwargs - setattr( - self, - '__dataclass_fields__', - {name: field for name, field in RunContext.__dataclass_fields__.items() if name in kwargs}, - ) - - def __getattribute__(self, name: str) -> Any: - try: - return super().__getattribute__(name) - except AttributeError as e: - if name in RunContext.__dataclass_fields__: - raise AttributeError( - f'Temporalized {RunContext.__name__!r} object has no attribute {name!r}. To make the attribute available, pass a `TemporalSettings` object to `temporalize_agent` that has a custom `serialize_run_context` function that returns a dictionary that includes the attribute.' - ) - else: - raise e - - @classmethod - def serialize_run_context(cls, ctx: RunContext[AgentDepsT]) -> dict[str, Any]: - return { - 'deps': ctx.deps, - 'retries': ctx.retries, - 'tool_call_id': ctx.tool_call_id, - 'tool_name': ctx.tool_name, - 'retry': ctx.retry, - 'run_step': ctx.run_step, - } - - @classmethod - def deserialize_run_context(cls, ctx: dict[str, Any]) -> RunContext[AgentDepsT]: - return cls(**ctx) diff --git a/pydantic_ai_slim/pydantic_ai/temporal/_toolset.py b/pydantic_ai_slim/pydantic_ai/temporal/_toolset.py deleted file mode 100644 index 289d90071..000000000 --- a/pydantic_ai_slim/pydantic_ai/temporal/_toolset.py +++ /dev/null @@ -1,26 +0,0 @@ -from __future__ import annotations - -from typing import Any, Callable - -from pydantic_ai.mcp import MCPServer -from pydantic_ai.toolsets.abstract import AbstractToolset -from pydantic_ai.toolsets.function import FunctionToolset - -from ._function_toolset import temporalize_function_toolset -from ._mcp_server import temporalize_mcp_server -from ._settings import TemporalSettings - - -def temporalize_toolset(toolset: AbstractToolset, settings: TemporalSettings | None) -> list[Callable[..., Any]]: - """Temporalize a toolset. - - Args: - toolset: The toolset to temporalize. - settings: The temporal settings to use. - """ - if isinstance(toolset, FunctionToolset): - return temporalize_function_toolset(toolset, settings) - elif isinstance(toolset, MCPServer): - return temporalize_mcp_server(toolset, settings) - else: - return [] diff --git a/temporal.py b/temporal.py index 2b8a4a319..677e831ba 100644 --- a/temporal.py +++ b/temporal.py @@ -10,38 +10,22 @@ from typing_extensions import TypedDict from pydantic_ai import Agent, RunContext -from pydantic_ai.mcp import MCPServerStdio -from pydantic_ai.messages import AgentStreamEvent, HandleResponseEvent -from pydantic_ai.temporal import ( +from pydantic_ai.ext.temporal import ( AgentPlugin, LogfirePlugin, PydanticAIPlugin, TemporalSettings, + temporalize_agent, ) -from pydantic_ai.toolsets import FunctionToolset +from pydantic_ai.mcp import MCPServerStdio +from pydantic_ai.messages import AgentStreamEvent, HandleResponseEvent class Deps(TypedDict): country: str -def get_country(ctx: RunContext[Deps]) -> str: - return ctx.deps['country'] - - -toolset = FunctionToolset[Deps](tools=[get_country], id='country') -mcp_server = MCPServerStdio( - 'python', - ['-m', 'tests.mcp_server'], - timeout=20, - id='test', -) - - -async def event_stream_handler( - ctx: RunContext[Deps], - stream: AsyncIterable[AgentStreamEvent | HandleResponseEvent], -): +async def event_stream_handler(ctx: RunContext[Deps], stream: AsyncIterable[AgentStreamEvent | HandleResponseEvent]): logfire.info(f'{ctx.run_step=}') async for event in stream: logfire.info(f'{event=}') @@ -49,24 +33,33 @@ async def event_stream_handler( agent = Agent( 'openai:gpt-4o', - toolsets=[toolset, mcp_server], - event_stream_handler=event_stream_handler, deps_type=Deps, + toolsets=[MCPServerStdio('python', ['-m', 'tests.mcp_server'], timeout=20, id='test')], + event_stream_handler=event_stream_handler, ) -temporal_settings = TemporalSettings( - start_to_close_timeout=timedelta(seconds=60), - tool_settings={ # TODO: Allow default temporal settings to be set for all activities in a toolset + +@agent.tool +def get_country(ctx: RunContext[Deps]) -> str: + return ctx.deps['country'] + + +# This needs to be called in the same scope where the `agent` is bound to the workflow, +# as it modifies the `agent` object in place to swap out methods that use IO for ones that use Temporal activities. +temporalize_agent( + agent, + settings=TemporalSettings(start_to_close_timeout=timedelta(seconds=60)), + toolset_settings={ + 'country': TemporalSettings(start_to_close_timeout=timedelta(seconds=120)), + }, + tool_settings={ 'country': { - 'get_country': TemporalSettings(start_to_close_timeout=timedelta(seconds=110)), + 'get_country': TemporalSettings(start_to_close_timeout=timedelta(seconds=180)), }, }, ) -TASK_QUEUE = 'pydantic-ai-agent-task-queue' - - @workflow.defn class MyAgentWorkflow: @workflow.run @@ -75,22 +68,26 @@ async def run(self, prompt: str, deps: Deps) -> str: return result.output -# TODO: For some reason, when I put this (specifically the temporalize_agent call) inside `async def main()`, -# we get tons of errors. -plugin = AgentPlugin(agent, temporal_settings) +TASK_QUEUE = 'pydantic-ai-agent-task-queue' + + +def setup_logfire(): + logfire.configure(console=False) + logfire.instrument_pydantic_ai() + logfire.instrument_httpx(capture_all=True) async def main(): client = await Client.connect( 'localhost:7233', - plugins=[PydanticAIPlugin(), LogfirePlugin()], + plugins=[PydanticAIPlugin(), LogfirePlugin(setup_logfire)], ) async with Worker( client, task_queue=TASK_QUEUE, workflows=[MyAgentWorkflow], - plugins=[plugin], + plugins=[AgentPlugin(agent)], ): output = await client.execute_workflow( # pyright: ignore[reportUnknownMemberType] MyAgentWorkflow.run, From 2474f1a33ee7def4f63f9376307910c593c57f19 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Wed, 30 Jul 2025 22:57:22 +0000 Subject: [PATCH 12/41] Use latest temporalio version with plugins --- pydantic_ai_slim/pyproject.toml | 2 +- pyproject.toml | 1 - uv.lock | 20 ++++++++++++++++---- 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index 5ef86678e..a0fed0b2f 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -85,7 +85,7 @@ a2a = ["fasta2a>=0.4.1"] # AG-UI ag-ui = ["ag-ui-protocol>=0.1.8", "starlette>=0.45.3"] # Temporal -temporal = ["temporalio>=1.13.0"] +temporal = ["temporalio>=1.15.0"] [dependency-groups] dev = [ diff --git a/pyproject.toml b/pyproject.toml index e3f2ed07f..fdd67f183 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,7 +70,6 @@ pydantic-ai-slim = { workspace = true } pydantic-evals = { workspace = true } pydantic-graph = { workspace = true } pydantic-ai-examples = { workspace = true } -temporalio = { git = "https://github.com/temporalio/sdk-python.git", rev = "main" } [tool.uv.workspace] members = [ diff --git a/uv.lock b/uv.lock index dcc5090da..f33e27a01 100644 --- a/uv.lock +++ b/uv.lock @@ -2289,10 +2289,14 @@ wheels = [ [[package]] name = "nexus-rpc" version = "1.1.0" -source = { git = "https://github.com/nexus-rpc/sdk-python.git?rev=35f574c711193a6e2560d3e6665732a5bb7ae92c#35f574c711193a6e2560d3e6665732a5bb7ae92c" } +source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "typing-extensions" }, ] +sdist = { url = "https://files.pythonhosted.org/packages/ef/66/540687556bd28cf1ec370cc6881456203dfddb9dab047b8979c6865b5984/nexus_rpc-1.1.0.tar.gz", hash = "sha256:d65ad6a2f54f14e53ebe39ee30555eaeb894102437125733fb13034a04a44553", size = 77383, upload-time = "2025-07-07T19:03:58.368Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bf/2f/9e9d0dcaa4c6ffa22b7aa31069a8a264c753ff8027b36af602cce038c92f/nexus_rpc-1.1.0-py3-none-any.whl", hash = "sha256:d1b007af2aba186a27e736f8eaae39c03aed05b488084ff6c3d1785c9ba2ad38", size = 27743, upload-time = "2025-07-07T19:03:57.556Z" }, +] [[package]] name = "nodeenv" @@ -3251,7 +3255,7 @@ requires-dist = [ { name = "rich", marker = "extra == 'cli'", specifier = ">=13" }, { name = "starlette", marker = "extra == 'ag-ui'", specifier = ">=0.45.3" }, { name = "tavily-python", marker = "extra == 'tavily'", specifier = ">=0.5.0" }, - { name = "temporalio", marker = "extra == 'temporal'", git = "https://github.com/temporalio/sdk-python.git?rev=main" }, + { name = "temporalio", marker = "extra == 'temporal'", specifier = ">=1.15.0" }, { name = "typing-inspection", specifier = ">=0.4.0" }, ] provides-extras = ["a2a", "ag-ui", "anthropic", "bedrock", "cli", "cohere", "duckduckgo", "evals", "google", "groq", "huggingface", "logfire", "mcp", "mistral", "openai", "tavily", "temporal", "vertexai"] @@ -4168,8 +4172,8 @@ wheels = [ [[package]] name = "temporalio" -version = "1.14.1" -source = { git = "https://github.com/temporalio/sdk-python.git?rev=main#e767013acca543345e0408a167556bbb987eb130" } +version = "1.15.0" +source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "nexus-rpc" }, { name = "protobuf" }, @@ -4177,6 +4181,14 @@ dependencies = [ { name = "types-protobuf" }, { name = "typing-extensions" }, ] +sdist = { url = "https://files.pythonhosted.org/packages/0b/af/1a3619fc62333d0acbdf90cfc5ada97e68e8c0f79610363b2dbb30871d83/temporalio-1.15.0.tar.gz", hash = "sha256:a4bc6ca01717880112caab75d041713aacc8263dc66e41f5019caef68b344fa0", size = 1684485, upload-time = "2025-07-29T03:44:09.071Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0e/2d/0153f2bc459e0cb59d41d4dd71da46bf9a98ca98bc37237576c258d6696b/temporalio-1.15.0-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:74bc5cc0e6bdc161a43015538b0821b8713f5faa716c4209971c274b528e0d47", size = 12703607, upload-time = "2025-07-29T03:43:30.083Z" }, + { url = "https://files.pythonhosted.org/packages/e4/39/1b867ec698c8987aef3b7a7024b5c0c732841112fa88d021303d0fc69bea/temporalio-1.15.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:ee8001304dae5723d79797516cfeebe04b966fdbdf348e658fce3b43afdda3cd", size = 12232853, upload-time = "2025-07-29T03:43:38.909Z" }, + { url = "https://files.pythonhosted.org/packages/5e/3e/647d9a7c8b2f638f639717404c0bcbdd7d54fddd7844fdb802e3f40dc55f/temporalio-1.15.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8febd1ac36720817e69c2176aa4aca14a97fe0b83f0d2449c0c730b8f0174d02", size = 12636700, upload-time = "2025-07-29T03:43:49.066Z" }, + { url = "https://files.pythonhosted.org/packages/9a/13/7aa9ec694fec9fba39efdbf61d892bccf7d2b1aa3d9bd359544534c1d309/temporalio-1.15.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:202d81a42cafaed9ccc7ccbea0898838e3b8bf92fee65394f8790f37eafbaa63", size = 12860186, upload-time = "2025-07-29T03:43:57.644Z" }, + { url = "https://files.pythonhosted.org/packages/9f/2b/ba962401324892236148046dbffd805d4443d6df7a7dc33cc7964b566bf9/temporalio-1.15.0-cp39-abi3-win_amd64.whl", hash = "sha256:aae5b18d7c9960238af0f3ebf6b7e5959e05f452106fc0d21a8278d78724f780", size = 12932800, upload-time = "2025-07-29T03:44:06.271Z" }, +] [[package]] name = "tenacity" From 7c62e351628bffaa12fbde0a83bdfeeaaa4bb5f9 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 31 Jul 2025 15:10:24 +0000 Subject: [PATCH 13/41] Temporalize MCPServer.get_tools and call_tool instead of list_tools and direct_call_tool --- .../pydantic_ai/ext/temporal/_mcp_server.py | 82 +++++++++++++------ pydantic_ai_slim/pydantic_ai/mcp.py | 15 ++-- temporal.py | 3 + 3 files changed, 71 insertions(+), 29 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py index edc021a90..5805639cd 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py @@ -3,21 +3,31 @@ from dataclasses import dataclass from typing import Any, Callable -from mcp import types as mcp_types from pydantic import ConfigDict, with_config from temporalio import activity, workflow +from pydantic_ai._run_context import RunContext from pydantic_ai.mcp import MCPServer, ToolResult +from pydantic_ai.tools import ToolDefinition +from pydantic_ai.toolsets.abstract import ToolsetTool +from ._run_context import TemporalRunContext from ._settings import TemporalSettings +@dataclass +@with_config(ConfigDict(arbitrary_types_allowed=True)) +class _GetToolsParams: + serialized_run_context: Any + + @dataclass @with_config(ConfigDict(arbitrary_types_allowed=True)) class _CallToolParams: name: str tool_args: dict[str, Any] - metadata: dict[str, Any] | None = None + serialized_run_context: Any + tool_def: ToolDefinition def temporalize_mcp_server( @@ -40,40 +50,64 @@ def temporalize_mcp_server( settings = settings or TemporalSettings() - original_list_tools = server.list_tools - original_direct_call_tool = server.direct_call_tool - setattr(server, '__original_list_tools', original_list_tools) - setattr(server, '__original_direct_call_tool', original_direct_call_tool) + original_get_tools = server.get_tools + original_call_tool = server.call_tool + setattr(server, '__original_get_tools', original_get_tools) + setattr(server, '__original_call_tool', original_call_tool) - @activity.defn(name=f'mcp_server__{id}__list_tools') - async def list_tools_activity() -> list[mcp_types.Tool]: - return await original_list_tools() + @activity.defn(name=f'mcp_server__{id}__get_tools') + async def get_tools_activity(params: _GetToolsParams) -> dict[str, ToolDefinition]: + run_context = TemporalRunContext.deserialize_run_context( + params.serialized_run_context, settings.deserialize_run_context + ) + return {name: tool.tool_def for name, tool in (await original_get_tools(run_context)).items()} @activity.defn(name=f'mcp_server__{id}__call_tool') async def call_tool_activity(params: _CallToolParams) -> ToolResult: - return await original_direct_call_tool(params.name, params.tool_args, params.metadata) + run_context = TemporalRunContext.deserialize_run_context( + params.serialized_run_context, settings.deserialize_run_context + ) + return await original_call_tool( + params.name, + params.tool_args, + run_context, + server._toolset_tool_for_tool_def(params.tool_def), # pyright: ignore[reportPrivateUsage] + ) - async def list_tools() -> list[mcp_types.Tool]: - return await workflow.execute_activity( # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType] - activity=list_tools_activity, + async def get_tools(ctx: RunContext[Any]) -> dict[str, ToolsetTool[Any]]: + serialized_run_context = TemporalRunContext.serialize_run_context(ctx, settings.serialize_run_context) + tool_defs = await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] + activity=get_tools_activity, + arg=_GetToolsParams(serialized_run_context=serialized_run_context), **settings.execute_activity_options, ) + return { + name: server._toolset_tool_for_tool_def(tool_def) # pyright: ignore[reportPrivateUsage] + for name, tool_def in tool_defs.items() + } - async def direct_call_tool( + async def call_tool( name: str, - args: dict[str, Any], - metadata: dict[str, Any] | None = None, + tool_args: dict[str, Any], + ctx: RunContext[Any], + tool: ToolsetTool[Any], ) -> ToolResult: + serialized_run_context = TemporalRunContext.serialize_run_context(ctx, settings.serialize_run_context) return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] activity=call_tool_activity, - arg=_CallToolParams(name=name, tool_args=args, metadata=metadata), + arg=_CallToolParams( + name=name, + tool_args=tool_args, + serialized_run_context=serialized_run_context, + tool_def=tool.tool_def, + ), **tool_settings.get(name, settings).execute_activity_options, ) - server.list_tools = list_tools - server.direct_call_tool = direct_call_tool + server.get_tools = get_tools + server.call_tool = call_tool - activities = [list_tools_activity, call_tool_activity] + activities = [get_tools_activity, call_tool_activity] setattr(server, '__temporal_activities', activities) return activities @@ -87,9 +121,9 @@ def untemporalize_mcp_server(server: MCPServer) -> None: if not hasattr(server, '__temporal_activities'): return - server.list_tools = getattr(server, '__original_list_tools') - server.direct_call_tool = getattr(server, '__original_direct_call_tool') - delattr(server, '__original_list_tools') - delattr(server, '__original_direct_call_tool') + server.get_tools = getattr(server, '__original_get_tools') + server.call_tool = getattr(server, '__original_call_tool') + delattr(server, '__original_get_tools') + delattr(server, '__original_call_tool') delattr(server, '__temporal_activities') diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index b0d45399c..82e9976c8 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -207,20 +207,25 @@ async def call_tool( async def get_tools(self, ctx: RunContext[Any]) -> dict[str, ToolsetTool[Any]]: return { - name: ToolsetTool( - toolset=self, - tool_def=ToolDefinition( + name: self._toolset_tool_for_tool_def( + ToolDefinition( name=name, description=mcp_tool.description, parameters_json_schema=mcp_tool.inputSchema, ), - max_retries=self.max_retries, - args_validator=TOOL_SCHEMA_VALIDATOR, ) for mcp_tool in await self.list_tools() if (name := f'{self.tool_prefix}_{mcp_tool.name}' if self.tool_prefix else mcp_tool.name) } + def _toolset_tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[Any]: + return ToolsetTool( + toolset=self, + tool_def=tool_def, + max_retries=self.max_retries, + args_validator=TOOL_SCHEMA_VALIDATOR, + ) + async def __aenter__(self) -> Self: """Enter the MCP server context. diff --git a/temporal.py b/temporal.py index 677e831ba..6a0321a62 100644 --- a/temporal.py +++ b/temporal.py @@ -59,6 +59,9 @@ def get_country(ctx: RunContext[Deps]) -> str: }, ) +with workflow.unsafe.imports_passed_through(): + import pandas # noqa: F401 + @workflow.defn class MyAgentWorkflow: From 8eb677ba968163aee30674e774a6d59d9cefaec8 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 31 Jul 2025 15:17:02 +0000 Subject: [PATCH 14/41] Add ID to model activity names --- pydantic_ai_slim/pydantic_ai/ext/temporal/_model.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_model.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_model.py index 9f2240e58..c1a831ba1 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_model.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_model.py @@ -89,11 +89,13 @@ def temporalize_model( # noqa: C901 setattr(model, '__original_request', original_request) setattr(model, '__original_request_stream', original_request_stream) - @activity.defn(name='model_request') + id = '_'.join([model.system, model.model_name]) + + @activity.defn(name=f'model__{id}__request') async def request_activity(params: _RequestParams) -> ModelResponse: return await original_request(params.messages, params.model_settings, params.model_request_parameters) - @activity.defn(name='model_request_stream') + @activity.defn(name=f'model__{id}__request_stream') async def request_stream_activity(params: _RequestParams) -> ModelResponse: run_context = TemporalRunContext.deserialize_run_context( params.serialized_run_context, settings.deserialize_run_context From 6682e9762dfbe88a23e00efe52fa64e5024250b8 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 31 Jul 2025 16:28:47 +0000 Subject: [PATCH 15/41] Use temporal wrapper classes instead of monkeypatching --- .../pydantic_ai/ext/temporal/__init__.py | 7 +- .../pydantic_ai/ext/temporal/_agent.py | 62 ++---- .../ext/temporal/_function_toolset.py | 120 +++++------ .../pydantic_ai/ext/temporal/_mcp_server.py | 133 ++++++------ .../pydantic_ai/ext/temporal/_model.py | 189 ++++++++---------- .../pydantic_ai/ext/temporal/_toolset.py | 26 +-- .../pydantic_ai/toolsets/abstract.py | 5 + .../pydantic_ai/toolsets/combined.py | 5 + .../pydantic_ai/toolsets/wrapper.py | 7 +- 9 files changed, 237 insertions(+), 317 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/__init__.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/__init__.py index f1ec642a5..4240cff6e 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/__init__.py @@ -11,7 +11,7 @@ from pydantic_ai.agent import Agent -from ._agent import temporalize_agent, untemporalize_agent +from ._agent import temporalize_agent from ._logfire import LogfirePlugin from ._run_context import TemporalRunContext from ._settings import TemporalSettings @@ -23,7 +23,6 @@ 'LogfirePlugin', 'AgentPlugin', 'temporalize_agent', - 'untemporalize_agent', ] @@ -61,9 +60,7 @@ def __init__(self, agent: Agent[Any, Any]): def configure_worker(self, config: WorkerConfig) -> WorkerConfig: agent_activities = getattr(self.agent, '__temporal_activities', None) if agent_activities is None: - raise ValueError( - 'The agent has not been temporalized yet, call `temporalize_agent(agent)` (or `with temporalized_agent(agent): ...`) first.' - ) + raise ValueError('The agent has not been temporalized yet, call `temporalize_agent(agent)` first.') activities: Sequence[Callable[..., Any]] = config.get('activities', []) # pyright: ignore[reportUnknownMemberType] config['activities'] = [*activities, *agent_activities] diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py index f3717b6da..b9cd9be1e 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py @@ -1,16 +1,15 @@ from __future__ import annotations -from collections.abc import Generator -from contextlib import contextmanager -from typing import Any, Callable +from typing import Any, Callable, cast from pydantic_ai.agent import Agent from pydantic_ai.models import Model from pydantic_ai.toolsets.abstract import AbstractToolset +from pydantic_ai.toolsets.function import FunctionToolset -from ._model import temporalize_model, untemporalize_model +from ._model import TemporalModel from ._settings import TemporalSettings -from ._toolset import temporalize_toolset, untemporalize_toolset +from ._toolset import temporalize_toolset def temporalize_agent( @@ -19,7 +18,7 @@ def temporalize_agent( toolset_settings: dict[str, TemporalSettings] = {}, tool_settings: dict[str, dict[str, TemporalSettings]] = {}, temporalize_toolset_func: Callable[ - [AbstractToolset, TemporalSettings | None, dict[str, TemporalSettings]], list[Callable[..., Any]] + [AbstractToolset, TemporalSettings | None, dict[str, TemporalSettings]], AbstractToolset ] = temporalize_toolset, ) -> list[Callable[..., Any]]: """Temporalize an agent. @@ -38,19 +37,27 @@ def temporalize_agent( activities: list[Callable[..., Any]] = [] if isinstance(agent.model, Model): - activities.extend(temporalize_model(agent.model, settings, agent._event_stream_handler)) # pyright: ignore[reportPrivateUsage] + model = TemporalModel(agent.model, settings, agent._event_stream_handler) # pyright: ignore[reportPrivateUsage] + activities.extend(model.activities) + agent.model = model + else: + raise ValueError( + 'Model cannot be set at agent run time when using Temporal, it must be set at agent creation time.' + ) - def temporalize_toolset(toolset: AbstractToolset) -> None: + def temporalize_toolset(toolset: AbstractToolset) -> AbstractToolset: id = toolset.id if not id: raise ValueError( "A toolset needs to have an ID in order to be used with Temporal. The ID will be used to identify the toolset's activities within the workflow." ) - activities.extend( - temporalize_toolset_func(toolset, settings.merge(toolset_settings.get(id)), tool_settings.get(id, {})) - ) + toolset = temporalize_toolset_func(toolset, settings.merge(toolset_settings.get(id)), tool_settings.get(id, {})) + if hasattr(toolset, 'activities'): + activities.extend(getattr(toolset, 'activities')) + return toolset - agent.toolset.apply(temporalize_toolset) + agent._function_toolset = cast(FunctionToolset, temporalize_toolset(agent._function_toolset)) # pyright: ignore[reportPrivateUsage] + agent._user_toolsets = [temporalize_toolset(toolset) for toolset in agent._user_toolsets] # pyright: ignore[reportPrivateUsage] original_iter = agent.iter original_override = agent.override @@ -87,34 +94,3 @@ def override(*args: Any, **kwargs: Any) -> Any: setattr(agent, '__temporal_activities', activities) return activities - - -def untemporalize_agent(agent: Agent[Any, Any]) -> None: - """Untemporalize an agent. - - Args: - agent: The agent to untemporalize. - """ - if not hasattr(agent, '__temporal_activities'): - return - - if isinstance(agent.model, Model): - untemporalize_model(agent.model) - - agent.toolset.apply(untemporalize_toolset) - - agent.iter = getattr(agent, '__original_iter') - agent.override = getattr(agent, '__original_override') - delattr(agent, '__original_iter') - delattr(agent, '__original_override') - - delattr(agent, '__temporal_activities') - - -@contextmanager -def temporalized_agent(agent: Agent[Any, Any], settings: TemporalSettings | None = None) -> Generator[None, None, None]: - temporalize_agent(agent, settings) - try: - yield - finally: - untemporalize_agent(agent) diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_function_toolset.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_function_toolset.py index 054420453..9f403c6e6 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_function_toolset.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_function_toolset.py @@ -1,13 +1,14 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Callable +from typing import Any, Callable, cast from pydantic import ConfigDict, with_config from temporalio import activity, workflow from pydantic_ai._run_context import RunContext from pydantic_ai.toolsets import FunctionToolset, ToolsetTool +from pydantic_ai.toolsets.wrapper import WrapperToolset from ._run_context import TemporalRunContext from ._settings import TemporalSettings @@ -21,72 +22,61 @@ class _CallToolParams: serialized_run_context: Any -def temporalize_function_toolset( - toolset: FunctionToolset, - settings: TemporalSettings | None = None, - tool_settings: dict[str, TemporalSettings] = {}, -) -> list[Callable[..., Any]]: - """Temporalize a function toolset. - - Args: - toolset: The function toolset to temporalize. - settings: The temporal settings to use. - tool_settings: The temporal settings to use for specific tools identified by tool name. - """ - if activities := getattr(toolset, '__temporal_activities', None): - return activities - - id = toolset.id - assert id is not None - - settings = settings or TemporalSettings() - - original_call_tool = toolset.call_tool - setattr(toolset, '__original_call_tool', original_call_tool) - - @activity.defn(name=f'function_toolset__{id}__call_tool') - async def call_tool_activity(params: _CallToolParams) -> Any: - name = params.name - settings_for_tool = settings.merge(tool_settings.get(name)) - ctx = TemporalRunContext.deserialize_run_context( - params.serialized_run_context, settings_for_tool.deserialize_run_context - ) - try: - tool = (await toolset.get_tools(ctx))[name] - except KeyError as e: - raise ValueError( - f'Tool {name!r} not found in toolset {toolset.id!r}. ' - 'Removing or renaming tools during an agent run is not supported with Temporal.' - ) from e - - return await original_call_tool(name, params.tool_args, ctx, tool) - - async def call_tool(name: str, tool_args: dict[str, Any], ctx: RunContext, tool: ToolsetTool) -> Any: - settings_for_tool = settings.merge(tool_settings.get(name)) +class TemporalFunctionToolset(WrapperToolset[Any]): + def __init__( + self, + toolset: FunctionToolset, + settings: TemporalSettings | None = None, + tool_settings: dict[str, TemporalSettings] = {}, + ): + super().__init__(toolset) + self.settings = settings or TemporalSettings() + self.tool_settings = tool_settings + + id = toolset.id + assert id is not None + + @activity.defn(name=f'function_toolset__{id}__call_tool') + async def call_tool_activity(params: _CallToolParams) -> Any: + name = params.name + settings_for_tool = self.settings.merge(self.tool_settings.get(name)) + ctx = TemporalRunContext.deserialize_run_context( + params.serialized_run_context, settings_for_tool.deserialize_run_context + ) + try: + tool = (await toolset.get_tools(ctx))[name] + except KeyError as e: + raise ValueError( + f'Tool {name!r} not found in toolset {toolset.id!r}. ' + 'Removing or renaming tools during an agent run is not supported with Temporal.' + ) from e + + return await self.wrapped.call_tool(name, params.tool_args, ctx, tool) + + self.call_tool_activity = call_tool_activity + + @property + def wrapped_function_toolset(self) -> FunctionToolset: + return cast(FunctionToolset, self.wrapped) + + @property + def activities(self) -> list[Callable[..., Any]]: + return [self.call_tool_activity] + + def tool(self, *args: Any, **kwargs: Any) -> Any: + return self.wrapped_function_toolset.tool(*args, **kwargs) + + def add_function(self, *args: Any, **kwargs: Any) -> None: + return self.wrapped_function_toolset.add_function(*args, **kwargs) + + def add_tool(self, *args: Any, **kwargs: Any) -> None: + return self.wrapped_function_toolset.add_tool(*args, **kwargs) + + async def call_tool(self, name: str, tool_args: dict[str, Any], ctx: RunContext, tool: ToolsetTool) -> Any: + settings_for_tool = self.settings.merge(self.tool_settings.get(name)) serialized_run_context = TemporalRunContext.serialize_run_context(ctx, settings_for_tool.serialize_run_context) return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] - activity=call_tool_activity, + activity=self.call_tool_activity, arg=_CallToolParams(name=name, tool_args=tool_args, serialized_run_context=serialized_run_context), **settings_for_tool.execute_activity_options, ) - - toolset.call_tool = call_tool - - activities = [call_tool_activity] - setattr(toolset, '__temporal_activities', activities) - return activities - - -def untemporalize_function_toolset(toolset: FunctionToolset) -> None: - """Untemporalize a function toolset. - - Args: - toolset: The function toolset to untemporalize. - """ - if not hasattr(toolset, '__temporal_activities'): - return - - toolset.call_tool = getattr(toolset, '__original_call_tool') - delattr(toolset, '__original_call_tool') - - delattr(toolset, '__temporal_activities') diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py index 5805639cd..c3502b520 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Callable +from typing import Any, Callable, cast from pydantic import ConfigDict, with_config from temporalio import activity, workflow @@ -10,6 +10,7 @@ from pydantic_ai.mcp import MCPServer, ToolResult from pydantic_ai.tools import ToolDefinition from pydantic_ai.toolsets.abstract import ToolsetTool +from pydantic_ai.toolsets.wrapper import WrapperToolset from ._run_context import TemporalRunContext from ._settings import TemporalSettings @@ -30,100 +31,78 @@ class _CallToolParams: tool_def: ToolDefinition -def temporalize_mcp_server( - server: MCPServer, - settings: TemporalSettings | None = None, - tool_settings: dict[str, TemporalSettings] = {}, -) -> list[Callable[..., Any]]: - """Temporalize an MCP server. - - Args: - server: The MCP server to temporalize. - settings: The temporal settings to use. - tool_settings: The temporal settings to use for each tool. - """ - if activities := getattr(server, '__temporal_activities', None): - return activities - - id = server.id - assert id is not None - - settings = settings or TemporalSettings() - - original_get_tools = server.get_tools - original_call_tool = server.call_tool - setattr(server, '__original_get_tools', original_get_tools) - setattr(server, '__original_call_tool', original_call_tool) - - @activity.defn(name=f'mcp_server__{id}__get_tools') - async def get_tools_activity(params: _GetToolsParams) -> dict[str, ToolDefinition]: - run_context = TemporalRunContext.deserialize_run_context( - params.serialized_run_context, settings.deserialize_run_context - ) - return {name: tool.tool_def for name, tool in (await original_get_tools(run_context)).items()} - - @activity.defn(name=f'mcp_server__{id}__call_tool') - async def call_tool_activity(params: _CallToolParams) -> ToolResult: - run_context = TemporalRunContext.deserialize_run_context( - params.serialized_run_context, settings.deserialize_run_context - ) - return await original_call_tool( - params.name, - params.tool_args, - run_context, - server._toolset_tool_for_tool_def(params.tool_def), # pyright: ignore[reportPrivateUsage] - ) - - async def get_tools(ctx: RunContext[Any]) -> dict[str, ToolsetTool[Any]]: - serialized_run_context = TemporalRunContext.serialize_run_context(ctx, settings.serialize_run_context) +class TemporalMCPServer(WrapperToolset[Any]): + def __init__( + self, + server: MCPServer, + settings: TemporalSettings | None = None, + tool_settings: dict[str, TemporalSettings] = {}, + ): + super().__init__(server) + self.settings = settings or TemporalSettings() + self.tool_settings = tool_settings + + id = server.id + assert id is not None + + @activity.defn(name=f'mcp_server__{id}__get_tools') + async def get_tools_activity(params: _GetToolsParams) -> dict[str, ToolDefinition]: + run_context = TemporalRunContext.deserialize_run_context( + params.serialized_run_context, self.settings.deserialize_run_context + ) + return {name: tool.tool_def for name, tool in (await self.wrapped.get_tools(run_context)).items()} + + self.get_tools_activity = get_tools_activity + + @activity.defn(name=f'mcp_server__{id}__call_tool') + async def call_tool_activity(params: _CallToolParams) -> ToolResult: + run_context = TemporalRunContext.deserialize_run_context( + params.serialized_run_context, self.settings.deserialize_run_context + ) + return await self.wrapped.call_tool( + params.name, + params.tool_args, + run_context, + self.wrapped_server._toolset_tool_for_tool_def(params.tool_def), # pyright: ignore[reportPrivateUsage] + ) + + self.call_tool_activity = call_tool_activity + + @property + def wrapped_server(self) -> MCPServer: + return cast(MCPServer, self.wrapped) + + @property + def activities(self) -> list[Callable[..., Any]]: + return [self.get_tools_activity, self.call_tool_activity] + + async def get_tools(self, ctx: RunContext[Any]) -> dict[str, ToolsetTool[Any]]: + serialized_run_context = TemporalRunContext.serialize_run_context(ctx, self.settings.serialize_run_context) tool_defs = await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] - activity=get_tools_activity, + activity=self.get_tools_activity, arg=_GetToolsParams(serialized_run_context=serialized_run_context), - **settings.execute_activity_options, + **self.settings.execute_activity_options, ) return { - name: server._toolset_tool_for_tool_def(tool_def) # pyright: ignore[reportPrivateUsage] + name: self.wrapped_server._toolset_tool_for_tool_def(tool_def) # pyright: ignore[reportPrivateUsage] for name, tool_def in tool_defs.items() } async def call_tool( + self, name: str, tool_args: dict[str, Any], ctx: RunContext[Any], tool: ToolsetTool[Any], ) -> ToolResult: - serialized_run_context = TemporalRunContext.serialize_run_context(ctx, settings.serialize_run_context) + serialized_run_context = TemporalRunContext.serialize_run_context(ctx, self.settings.serialize_run_context) return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] - activity=call_tool_activity, + activity=self.call_tool_activity, arg=_CallToolParams( name=name, tool_args=tool_args, serialized_run_context=serialized_run_context, tool_def=tool.tool_def, ), - **tool_settings.get(name, settings).execute_activity_options, + **self.tool_settings.get(name, self.settings).execute_activity_options, ) - - server.get_tools = get_tools - server.call_tool = call_tool - - activities = [get_tools_activity, call_tool_activity] - setattr(server, '__temporal_activities', activities) - return activities - - -def untemporalize_mcp_server(server: MCPServer) -> None: - """Untemporalize an MCP server. - - Args: - server: The MCP server to untemporalize. - """ - if not hasattr(server, '__temporal_activities'): - return - - server.get_tools = getattr(server, '__original_get_tools') - server.call_tool = getattr(server, '__original_call_tool') - delattr(server, '__original_get_tools') - delattr(server, '__original_call_tool') - - delattr(server, '__temporal_activities') diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_model.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_model.py index c1a831ba1..7adeb2172 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_model.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_model.py @@ -22,6 +22,7 @@ ToolCallPart, ) from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse +from pydantic_ai.models.wrapper import WrapperModel from pydantic_ai.settings import ModelSettings from pydantic_ai.usage import Usage @@ -66,151 +67,127 @@ def timestamp(self) -> datetime: return self.response.timestamp -def temporalize_model( # noqa: C901 - model: Model, - settings: TemporalSettings | None = None, - event_stream_handler: EventStreamHandler | None = None, -) -> list[Callable[..., Any]]: - """Temporalize a model. - - Args: - model: The model to temporalize. - settings: The temporal settings to use. - event_stream_handler: The event stream handler to use. - """ - if activities := getattr(model, '__temporal_activities', None): - return activities - - settings = settings or TemporalSettings() - - original_request = model.request - original_request_stream = model.request_stream - - setattr(model, '__original_request', original_request) - setattr(model, '__original_request_stream', original_request_stream) - - id = '_'.join([model.system, model.model_name]) +class TemporalModel(WrapperModel): + def __init__( + self, + model: Model, + settings: TemporalSettings | None = None, + event_stream_handler: EventStreamHandler | None = None, + ): + super().__init__(model) + self.temporal_settings = settings or TemporalSettings() + self.event_stream_handler = event_stream_handler + + id = '_'.join([model.system, model.model_name]) + + @activity.defn(name=f'model__{id}__request') + async def request_activity(params: _RequestParams) -> ModelResponse: + return await self.wrapped.request(params.messages, params.model_settings, params.model_request_parameters) + + self.request_activity = request_activity + + @activity.defn(name=f'model__{id}__request_stream') + async def request_stream_activity(params: _RequestParams) -> ModelResponse: + run_context = TemporalRunContext.deserialize_run_context( + params.serialized_run_context, self.temporal_settings.deserialize_run_context + ) + async with self.wrapped.request_stream( + params.messages, params.model_settings, params.model_request_parameters, run_context + ) as streamed_response: + tool_defs = { + tool_def.name: tool_def + for tool_def in [ + *params.model_request_parameters.output_tools, + *params.model_request_parameters.function_tools, + ] + } + + # Keep in sync with `AgentStream.__aiter__` + async def aiter(): + def _get_final_result_event(e: ModelResponseStreamEvent) -> FinalResultEvent | None: + """Return an appropriate FinalResultEvent if `e` corresponds to a part that will produce a final result.""" + if isinstance(e, PartStartEvent): + new_part = e.part + if ( + isinstance(new_part, TextPart) and params.model_request_parameters.allow_text_output + ): # pragma: no branch + return FinalResultEvent(tool_name=None, tool_call_id=None) + elif isinstance(new_part, ToolCallPart) and (tool_def := tool_defs.get(new_part.tool_name)): + if tool_def.kind == 'output': + return FinalResultEvent( + tool_name=new_part.tool_name, tool_call_id=new_part.tool_call_id + ) + elif tool_def.kind == 'deferred': + return FinalResultEvent(tool_name=None, tool_call_id=None) - @activity.defn(name=f'model__{id}__request') - async def request_activity(params: _RequestParams) -> ModelResponse: - return await original_request(params.messages, params.model_settings, params.model_request_parameters) + # `AgentStream.__aiter__`, which this is based on, calls `_get_usage_checking_stream_response` here, + # but we don't have access to the `_usage_limits`. - @activity.defn(name=f'model__{id}__request_stream') - async def request_stream_activity(params: _RequestParams) -> ModelResponse: - run_context = TemporalRunContext.deserialize_run_context( - params.serialized_run_context, settings.deserialize_run_context - ) - async with original_request_stream( - params.messages, params.model_settings, params.model_request_parameters, run_context - ) as streamed_response: - tool_defs = { - tool_def.name: tool_def - for tool_def in [ - *params.model_request_parameters.output_tools, - *params.model_request_parameters.function_tools, - ] - } - - # Keep in sync with `AgentStream.__aiter__` - async def aiter(): - def _get_final_result_event(e: ModelResponseStreamEvent) -> FinalResultEvent | None: - """Return an appropriate FinalResultEvent if `e` corresponds to a part that will produce a final result.""" - if isinstance(e, PartStartEvent): - new_part = e.part - if ( - isinstance(new_part, TextPart) and params.model_request_parameters.allow_text_output - ): # pragma: no branch - return FinalResultEvent(tool_name=None, tool_call_id=None) - elif isinstance(new_part, ToolCallPart) and (tool_def := tool_defs.get(new_part.tool_name)): - if tool_def.kind == 'output': - return FinalResultEvent( - tool_name=new_part.tool_name, tool_call_id=new_part.tool_call_id - ) - elif tool_def.kind == 'deferred': - return FinalResultEvent(tool_name=None, tool_call_id=None) + async for event in streamed_response: + yield event + if (final_result_event := _get_final_result_event(event)) is not None: + yield final_result_event + break - # `AgentStream.__aiter__`, which this is based on, calls `_get_usage_checking_stream_response` here, - # but we don't have access to the `_usage_limits`. + # If we broke out of the above loop, we need to yield the rest of the events + # If we didn't, this will just be a no-op + async for event in streamed_response: + yield event - async for event in streamed_response: - yield event - if (final_result_event := _get_final_result_event(event)) is not None: - yield final_result_event - break + assert event_stream_handler is not None + await event_stream_handler(run_context, aiter()) - # If we broke out of the above loop, we need to yield the rest of the events - # If we didn't, this will just be a no-op - async for event in streamed_response: - yield event + async for _ in streamed_response: + pass + return streamed_response.get() - assert event_stream_handler is not None - await event_stream_handler(run_context, aiter()) + self.request_stream_activity = request_stream_activity - async for _ in streamed_response: - pass - return streamed_response.get() + @property + def activities(self) -> list[Callable[..., Any]]: + return [self.request_activity, self.request_stream_activity] async def request( + self, messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> ModelResponse: return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] - activity=request_activity, + activity=self.request_activity, arg=_RequestParams( messages=messages, model_settings=model_settings, model_request_parameters=model_request_parameters, serialized_run_context=None, ), - **settings.execute_activity_options, + **self.temporal_settings.execute_activity_options, ) @asynccontextmanager async def request_stream( + self, messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: - if event_stream_handler is None: + if self.event_stream_handler is None: raise UserError('Streaming with Temporal requires `Agent` to have an `event_stream_handler`') if run_context is None: raise UserError('Streaming with Temporal requires `request_stream` to be called with a `run_context`') - serialized_run_context = TemporalRunContext.serialize_run_context(run_context, settings.serialize_run_context) + serialized_run_context = TemporalRunContext.serialize_run_context( + run_context, self.temporal_settings.serialize_run_context + ) response = await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] - activity=request_stream_activity, + activity=self.request_stream_activity, arg=_RequestParams( messages=messages, model_settings=model_settings, model_request_parameters=model_request_parameters, serialized_run_context=serialized_run_context, ), - **settings.execute_activity_options, + **self.temporal_settings.execute_activity_options, ) yield _TemporalStreamedResponse(response) - - model.request = request - model.request_stream = request_stream - - activities = [request_activity, request_stream_activity] - setattr(model, '__temporal_activities', activities) - return activities - - -def untemporalize_model(model: Model) -> None: - """Untemporalize a model. - - Args: - model: The model to untemporalize. - """ - if not hasattr(model, '__temporal_activities'): - return - - model.request = getattr(model, '__original_request') - model.request_stream = getattr(model, '__original_request_stream') - - delattr(model, '__original_request') - delattr(model, '__original_request_stream') - delattr(model, '__temporal_activities') diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_toolset.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_toolset.py index bf8b8281d..00f7f9249 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_toolset.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_toolset.py @@ -1,19 +1,17 @@ from __future__ import annotations -from typing import Any, Callable - from pydantic_ai.mcp import MCPServer from pydantic_ai.toolsets.abstract import AbstractToolset from pydantic_ai.toolsets.function import FunctionToolset -from ._function_toolset import temporalize_function_toolset, untemporalize_function_toolset -from ._mcp_server import temporalize_mcp_server, untemporalize_mcp_server +from ._function_toolset import TemporalFunctionToolset +from ._mcp_server import TemporalMCPServer from ._settings import TemporalSettings def temporalize_toolset( toolset: AbstractToolset, settings: TemporalSettings | None, tool_settings: dict[str, TemporalSettings] = {} -) -> list[Callable[..., Any]]: +) -> AbstractToolset: """Temporalize a toolset. Args: @@ -22,20 +20,8 @@ def temporalize_toolset( tool_settings: The temporal settings to use for specific tools identified by tool name. """ if isinstance(toolset, FunctionToolset): - return temporalize_function_toolset(toolset, settings, tool_settings) + return TemporalFunctionToolset(toolset, settings, tool_settings) elif isinstance(toolset, MCPServer): - return temporalize_mcp_server(toolset, settings, tool_settings) + return TemporalMCPServer(toolset, settings, tool_settings) else: - return [] - - -def untemporalize_toolset(toolset: AbstractToolset) -> None: - """Untemporalize a toolset. - - Args: - toolset: The toolset to untemporalize. - """ - if isinstance(toolset, FunctionToolset): - untemporalize_function_toolset(toolset) - elif isinstance(toolset, MCPServer): - untemporalize_mcp_server(toolset) + return toolset diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/abstract.py b/pydantic_ai_slim/pydantic_ai/toolsets/abstract.py index d73119e58..7f44c2be6 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/abstract.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/abstract.py @@ -130,6 +130,11 @@ def apply(self, visitor: Callable[[AbstractToolset[AgentDepsT]], None]) -> None: """Run a visitor function on all concrete toolsets that are not wrappers (i.e. they implement their own tool listing and calling).""" visitor(self) + def visit_and_replace( + self, visitor: Callable[[AbstractToolset[AgentDepsT]], AbstractToolset[AgentDepsT]] + ) -> AbstractToolset[AgentDepsT]: + return visitor(self) + def filtered( self, filter_func: Callable[[RunContext[AgentDepsT], ToolDefinition], bool] ) -> FilteredToolset[AgentDepsT]: diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/combined.py b/pydantic_ai_slim/pydantic_ai/toolsets/combined.py index 750c54b8e..43e2d0557 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/combined.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/combined.py @@ -94,3 +94,8 @@ async def call_tool( def apply(self, visitor: Callable[[AbstractToolset[AgentDepsT]], None]) -> None: for toolset in self.toolsets: toolset.apply(visitor) + + def visit_and_replace( + self, visitor: Callable[[AbstractToolset[AgentDepsT]], AbstractToolset[AgentDepsT]] + ) -> AbstractToolset[AgentDepsT]: + return CombinedToolset(toolsets=[visitor(toolset) for toolset in self.toolsets]) diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py b/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py index 6d5a409a5..fdbde6baa 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py @@ -1,6 +1,6 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, replace from typing import Any, Callable from typing_extensions import Self @@ -43,3 +43,8 @@ async def call_tool( def apply(self, visitor: Callable[[AbstractToolset[AgentDepsT]], None]) -> None: self.wrapped.apply(visitor) + + def visit_and_replace( + self, visitor: Callable[[AbstractToolset[AgentDepsT]], AbstractToolset[AgentDepsT]] + ) -> AbstractToolset[AgentDepsT]: + return replace(self, wrapped=visitor(self.wrapped)) From fb5259c8d42b16c73070b88c8e59aeb56a9c9b00 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 31 Jul 2025 18:24:19 +0000 Subject: [PATCH 16/41] Let running a tool in a Temporal activity be disabled --- .../pydantic_ai/ext/temporal/_agent.py | 10 ++++--- .../ext/temporal/_function_toolset.py | 24 +++++++++++---- .../pydantic_ai/ext/temporal/_mcp_server.py | 17 +++++++---- .../pydantic_ai/ext/temporal/_toolset.py | 6 +++- .../pydantic_ai/toolsets/function.py | 2 ++ temporal.py | 29 ++++++++++++------- 6 files changed, 63 insertions(+), 25 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py index b9cd9be1e..91ab8d947 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Callable, cast +from typing import Any, Callable, Literal, cast from pydantic_ai.agent import Agent from pydantic_ai.models import Model @@ -16,9 +16,9 @@ def temporalize_agent( agent: Agent[Any, Any], settings: TemporalSettings | None = None, toolset_settings: dict[str, TemporalSettings] = {}, - tool_settings: dict[str, dict[str, TemporalSettings]] = {}, + tool_settings: dict[str, dict[str, TemporalSettings | Literal[False]]] = {}, temporalize_toolset_func: Callable[ - [AbstractToolset, TemporalSettings | None, dict[str, TemporalSettings]], AbstractToolset + [AbstractToolset, TemporalSettings | None, dict[str, TemporalSettings | Literal[False]]], AbstractToolset ] = temporalize_toolset, ) -> list[Callable[..., Any]]: """Temporalize an agent. @@ -73,7 +73,9 @@ def iter(*args: Any, **kwargs: Any) -> Any: raise ValueError( 'Toolsets cannot be set at agent run time when using Temporal, it must be set at agent creation time.' ) - if kwargs.get('event_stream_handler') is not None: + if ( + kwargs.get('event_stream_handler') is not None + ): # TODO: iter won't have event_stream_handler, run/_sync/_stream will raise ValueError( 'Event stream handler cannot be set at agent run time when using Temporal, it must be set at agent creation time.' ) diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_function_toolset.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_function_toolset.py index 9f403c6e6..828aec6d5 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_function_toolset.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_function_toolset.py @@ -1,13 +1,15 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Callable, cast +from typing import Any, Callable, Literal from pydantic import ConfigDict, with_config from temporalio import activity, workflow from pydantic_ai._run_context import RunContext +from pydantic_ai.exceptions import UserError from pydantic_ai.toolsets import FunctionToolset, ToolsetTool +from pydantic_ai.toolsets.function import _FunctionToolsetTool # pyright: ignore[reportPrivateUsage] from pydantic_ai.toolsets.wrapper import WrapperToolset from ._run_context import TemporalRunContext @@ -27,7 +29,7 @@ def __init__( self, toolset: FunctionToolset, settings: TemporalSettings | None = None, - tool_settings: dict[str, TemporalSettings] = {}, + tool_settings: dict[str, TemporalSettings | Literal[False]] = {}, ): super().__init__(toolset) self.settings = settings or TemporalSettings() @@ -39,7 +41,9 @@ def __init__( @activity.defn(name=f'function_toolset__{id}__call_tool') async def call_tool_activity(params: _CallToolParams) -> Any: name = params.name - settings_for_tool = self.settings.merge(self.tool_settings.get(name)) + settings_for_tool = self.tool_settings.get(name) + assert isinstance(settings_for_tool, TemporalSettings) + settings_for_tool = self.settings.merge(settings_for_tool) ctx = TemporalRunContext.deserialize_run_context( params.serialized_run_context, settings_for_tool.deserialize_run_context ) @@ -57,7 +61,8 @@ async def call_tool_activity(params: _CallToolParams) -> Any: @property def wrapped_function_toolset(self) -> FunctionToolset: - return cast(FunctionToolset, self.wrapped) + assert isinstance(self.wrapped, FunctionToolset) + return self.wrapped @property def activities(self) -> list[Callable[..., Any]]: @@ -73,7 +78,16 @@ def add_tool(self, *args: Any, **kwargs: Any) -> None: return self.wrapped_function_toolset.add_tool(*args, **kwargs) async def call_tool(self, name: str, tool_args: dict[str, Any], ctx: RunContext, tool: ToolsetTool) -> Any: - settings_for_tool = self.settings.merge(self.tool_settings.get(name)) + settings_for_tool = self.tool_settings.get(name) + if settings_for_tool is False: + assert isinstance(tool, _FunctionToolsetTool) + if not tool.is_async: + raise UserError( + 'Disabling running a non-async tool in a Temporal activity is not possible. Make the tool function async instead.' + ) + return await super().call_tool(name, tool_args, ctx, tool) + + settings_for_tool = self.settings.merge(settings_for_tool) serialized_run_context = TemporalRunContext.serialize_run_context(ctx, settings_for_tool.serialize_run_context) return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] activity=self.call_tool_activity, diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py index c3502b520..dbca04420 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py @@ -1,12 +1,13 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Callable, cast +from typing import Any, Callable, Literal from pydantic import ConfigDict, with_config from temporalio import activity, workflow from pydantic_ai._run_context import RunContext +from pydantic_ai.exceptions import UserError from pydantic_ai.mcp import MCPServer, ToolResult from pydantic_ai.tools import ToolDefinition from pydantic_ai.toolsets.abstract import ToolsetTool @@ -36,7 +37,7 @@ def __init__( self, server: MCPServer, settings: TemporalSettings | None = None, - tool_settings: dict[str, TemporalSettings] = {}, + tool_settings: dict[str, TemporalSettings | Literal[False]] = {}, ): super().__init__(server) self.settings = settings or TemporalSettings() @@ -70,7 +71,8 @@ async def call_tool_activity(params: _CallToolParams) -> ToolResult: @property def wrapped_server(self) -> MCPServer: - return cast(MCPServer, self.wrapped) + assert isinstance(self.wrapped, MCPServer) + return self.wrapped @property def activities(self) -> list[Callable[..., Any]]: @@ -95,7 +97,12 @@ async def call_tool( ctx: RunContext[Any], tool: ToolsetTool[Any], ) -> ToolResult: - serialized_run_context = TemporalRunContext.serialize_run_context(ctx, self.settings.serialize_run_context) + settings_for_tool = self.tool_settings.get(name) + if settings_for_tool is False: + raise UserError('Disabling running an MCP tool in a Temporal activity is not possible.') + + settings_for_tool = self.settings.merge(settings_for_tool) + serialized_run_context = TemporalRunContext.serialize_run_context(ctx, settings_for_tool.serialize_run_context) return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] activity=self.call_tool_activity, arg=_CallToolParams( @@ -104,5 +111,5 @@ async def call_tool( serialized_run_context=serialized_run_context, tool_def=tool.tool_def, ), - **self.tool_settings.get(name, self.settings).execute_activity_options, + **settings_for_tool.execute_activity_options, ) diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_toolset.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_toolset.py index 00f7f9249..8a17aea9e 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_toolset.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_toolset.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import Literal + from pydantic_ai.mcp import MCPServer from pydantic_ai.toolsets.abstract import AbstractToolset from pydantic_ai.toolsets.function import FunctionToolset @@ -10,7 +12,9 @@ def temporalize_toolset( - toolset: AbstractToolset, settings: TemporalSettings | None, tool_settings: dict[str, TemporalSettings] = {} + toolset: AbstractToolset, + settings: TemporalSettings | None, + tool_settings: dict[str, TemporalSettings | Literal[False]] = {}, ) -> AbstractToolset: """Temporalize a toolset. diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/function.py b/pydantic_ai_slim/pydantic_ai/toolsets/function.py index 81c667a9e..c206b22c5 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/function.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/function.py @@ -24,6 +24,7 @@ class _FunctionToolsetTool(ToolsetTool[AgentDepsT]): """A tool definition for a function toolset tool that keeps track of the function to call.""" call_func: Callable[[dict[str, Any], RunContext[AgentDepsT]], Awaitable[Any]] + is_async: bool @dataclass(init=False) @@ -240,6 +241,7 @@ async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[ max_retries=tool.max_retries if tool.max_retries is not None else self.max_retries, args_validator=tool.function_schema.validator, call_func=tool.function_schema.call, + is_async=tool.function_schema.is_async, ) return tools diff --git a/temporal.py b/temporal.py index 6a0321a62..ab1b8dabd 100644 --- a/temporal.py +++ b/temporal.py @@ -19,6 +19,7 @@ ) from pydantic_ai.mcp import MCPServerStdio from pydantic_ai.messages import AgentStreamEvent, HandleResponseEvent +from pydantic_ai.toolsets.function import FunctionToolset class Deps(TypedDict): @@ -31,19 +32,26 @@ async def event_stream_handler(ctx: RunContext[Deps], stream: AsyncIterable[Agen logfire.info(f'{event=}') +toolset = FunctionToolset[Deps](id='toolset') + + +@toolset.tool +async def get_country(ctx: RunContext[Deps]) -> str: + return ctx.deps['country'] + + +@toolset.tool +def get_weather(city: str) -> str: + return 'sunny' + + agent = Agent( 'openai:gpt-4o', deps_type=Deps, - toolsets=[MCPServerStdio('python', ['-m', 'tests.mcp_server'], timeout=20, id='test')], + toolsets=[toolset, MCPServerStdio('python', ['-m', 'tests.mcp_server'], timeout=20, id='mcp')], event_stream_handler=event_stream_handler, ) - -@agent.tool -def get_country(ctx: RunContext[Deps]) -> str: - return ctx.deps['country'] - - # This needs to be called in the same scope where the `agent` is bound to the workflow, # as it modifies the `agent` object in place to swap out methods that use IO for ones that use Temporal activities. temporalize_agent( @@ -53,8 +61,9 @@ def get_country(ctx: RunContext[Deps]) -> str: 'country': TemporalSettings(start_to_close_timeout=timedelta(seconds=120)), }, tool_settings={ - 'country': { - 'get_country': TemporalSettings(start_to_close_timeout=timedelta(seconds=180)), + 'toolset': { + 'get_country': False, + 'get_weather': TemporalSettings(start_to_close_timeout=timedelta(seconds=180)), }, }, ) @@ -95,7 +104,7 @@ async def main(): output = await client.execute_workflow( # pyright: ignore[reportUnknownMemberType] MyAgentWorkflow.run, args=[ - 'what is the capital of the capital of the country? and what is the product name?', + 'what is the capital of the country? what is the weather there? what is the product name?', Deps(country='Mexico'), ], id=f'my-agent-workflow-id-{random.random()}', From 00002d315d737f1a5b2edcde0a300db0a988e960 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 31 Jul 2025 18:56:11 +0000 Subject: [PATCH 17/41] Use temporal ActivityConfig instead of our TemporalSettings --- .../pydantic_ai/ext/temporal/__init__.py | 7 +-- .../pydantic_ai/ext/temporal/_agent.py | 32 +++++++---- .../ext/temporal/_function_toolset.py | 29 +++++----- .../pydantic_ai/ext/temporal/_mcp_server.py | 34 ++++++------ .../pydantic_ai/ext/temporal/_model.py | 20 +++---- .../pydantic_ai/ext/temporal/_run_context.py | 46 +++++----------- .../pydantic_ai/ext/temporal/_settings.py | 55 ------------------- .../pydantic_ai/ext/temporal/_toolset.py | 18 +++--- temporal.py | 14 +++-- 9 files changed, 95 insertions(+), 160 deletions(-) delete mode 100644 pydantic_ai_slim/pydantic_ai/ext/temporal/_settings.py diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/__init__.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/__init__.py index 4240cff6e..795560587 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/__init__.py @@ -13,12 +13,11 @@ from ._agent import temporalize_agent from ._logfire import LogfirePlugin -from ._run_context import TemporalRunContext -from ._settings import TemporalSettings +from ._run_context import TemporalRunContext, TemporalRunContextWithDeps __all__ = [ - 'TemporalSettings', 'TemporalRunContext', + 'TemporalRunContextWithDeps', 'PydanticAIPlugin', 'LogfirePlugin', 'AgentPlugin', @@ -60,7 +59,7 @@ def __init__(self, agent: Agent[Any, Any]): def configure_worker(self, config: WorkerConfig) -> WorkerConfig: agent_activities = getattr(self.agent, '__temporal_activities', None) if agent_activities is None: - raise ValueError('The agent has not been temporalized yet, call `temporalize_agent(agent)` first.') + raise ValueError('The agent has not been prepared for Temporal yet, call `temporalize_agent(agent)` first.') activities: Sequence[Callable[..., Any]] = config.get('activities', []) # pyright: ignore[reportUnknownMemberType] config['activities'] = [*activities, *agent_activities] diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py index 91ab8d947..45a0a594b 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py @@ -2,42 +2,45 @@ from typing import Any, Callable, Literal, cast +from temporalio.workflow import ActivityConfig + from pydantic_ai.agent import Agent +from pydantic_ai.ext.temporal._run_context import TemporalRunContext from pydantic_ai.models import Model from pydantic_ai.toolsets.abstract import AbstractToolset from pydantic_ai.toolsets.function import FunctionToolset from ._model import TemporalModel -from ._settings import TemporalSettings from ._toolset import temporalize_toolset def temporalize_agent( agent: Agent[Any, Any], - settings: TemporalSettings | None = None, - toolset_settings: dict[str, TemporalSettings] = {}, - tool_settings: dict[str, dict[str, TemporalSettings | Literal[False]]] = {}, + activity_config: ActivityConfig = {}, + toolset_activity_config: dict[str, ActivityConfig] = {}, + tool_activity_config: dict[str, dict[str, ActivityConfig | Literal[False]]] = {}, + run_context_type: type[TemporalRunContext] = TemporalRunContext, temporalize_toolset_func: Callable[ - [AbstractToolset, TemporalSettings | None, dict[str, TemporalSettings | Literal[False]]], AbstractToolset + [AbstractToolset, ActivityConfig, dict[str, ActivityConfig | Literal[False]], type[TemporalRunContext]], + AbstractToolset, ] = temporalize_toolset, ) -> list[Callable[..., Any]]: """Temporalize an agent. Args: agent: The agent to temporalize. - settings: The temporal settings to use. - toolset_settings: The temporal settings to use for specific toolsets identified by ID. - tool_settings: The temporal settings to use for specific tools identified by toolset ID and tool name. + activity_config: The Temporal activity config to use. + toolset_activity_config: The Temporal activity config to use for specific toolsets identified by ID. + tool_activity_config: The Temporal activity config to use for specific tools identified by toolset ID and tool name. + run_context_type: The type of run context to use to serialize and deserialize the run context. temporalize_toolset_func: The function to use to temporalize the toolsets. """ if existing_activities := getattr(agent, '__temporal_activities', None): return existing_activities - settings = settings or TemporalSettings() - activities: list[Callable[..., Any]] = [] if isinstance(agent.model, Model): - model = TemporalModel(agent.model, settings, agent._event_stream_handler) # pyright: ignore[reportPrivateUsage] + model = TemporalModel(agent.model, activity_config, agent._event_stream_handler, run_context_type) # pyright: ignore[reportPrivateUsage] activities.extend(model.activities) agent.model = model else: @@ -51,7 +54,12 @@ def temporalize_toolset(toolset: AbstractToolset) -> AbstractToolset: raise ValueError( "A toolset needs to have an ID in order to be used with Temporal. The ID will be used to identify the toolset's activities within the workflow." ) - toolset = temporalize_toolset_func(toolset, settings.merge(toolset_settings.get(id)), tool_settings.get(id, {})) + toolset = temporalize_toolset_func( + toolset, + activity_config | toolset_activity_config.get(id, {}), + tool_activity_config.get(id, {}), + run_context_type, + ) if hasattr(toolset, 'activities'): activities.extend(getattr(toolset, 'activities')) return toolset diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_function_toolset.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_function_toolset.py index 828aec6d5..b4add7e32 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_function_toolset.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_function_toolset.py @@ -5,6 +5,7 @@ from pydantic import ConfigDict, with_config from temporalio import activity, workflow +from temporalio.workflow import ActivityConfig from pydantic_ai._run_context import RunContext from pydantic_ai.exceptions import UserError @@ -13,7 +14,6 @@ from pydantic_ai.toolsets.wrapper import WrapperToolset from ._run_context import TemporalRunContext -from ._settings import TemporalSettings @dataclass @@ -28,12 +28,14 @@ class TemporalFunctionToolset(WrapperToolset[Any]): def __init__( self, toolset: FunctionToolset, - settings: TemporalSettings | None = None, - tool_settings: dict[str, TemporalSettings | Literal[False]] = {}, + activity_config: ActivityConfig = {}, + tool_activity_config: dict[str, ActivityConfig | Literal[False]] = {}, + run_context_type: type[TemporalRunContext] = TemporalRunContext, ): super().__init__(toolset) - self.settings = settings or TemporalSettings() - self.tool_settings = tool_settings + self.activity_config = activity_config + self.tool_activity_config = tool_activity_config + self.run_context_type = run_context_type id = toolset.id assert id is not None @@ -41,12 +43,7 @@ def __init__( @activity.defn(name=f'function_toolset__{id}__call_tool') async def call_tool_activity(params: _CallToolParams) -> Any: name = params.name - settings_for_tool = self.tool_settings.get(name) - assert isinstance(settings_for_tool, TemporalSettings) - settings_for_tool = self.settings.merge(settings_for_tool) - ctx = TemporalRunContext.deserialize_run_context( - params.serialized_run_context, settings_for_tool.deserialize_run_context - ) + ctx = self.run_context_type.deserialize_run_context(params.serialized_run_context) try: tool = (await toolset.get_tools(ctx))[name] except KeyError as e: @@ -78,8 +75,8 @@ def add_tool(self, *args: Any, **kwargs: Any) -> None: return self.wrapped_function_toolset.add_tool(*args, **kwargs) async def call_tool(self, name: str, tool_args: dict[str, Any], ctx: RunContext, tool: ToolsetTool) -> Any: - settings_for_tool = self.tool_settings.get(name) - if settings_for_tool is False: + tool_activity_config = self.tool_activity_config.get(name, {}) + if tool_activity_config is False: assert isinstance(tool, _FunctionToolsetTool) if not tool.is_async: raise UserError( @@ -87,10 +84,10 @@ async def call_tool(self, name: str, tool_args: dict[str, Any], ctx: RunContext, ) return await super().call_tool(name, tool_args, ctx, tool) - settings_for_tool = self.settings.merge(settings_for_tool) - serialized_run_context = TemporalRunContext.serialize_run_context(ctx, settings_for_tool.serialize_run_context) + tool_activity_config = self.activity_config | tool_activity_config + serialized_run_context = self.run_context_type.serialize_run_context(ctx) return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] activity=self.call_tool_activity, arg=_CallToolParams(name=name, tool_args=tool_args, serialized_run_context=serialized_run_context), - **settings_for_tool.execute_activity_options, + **tool_activity_config, ) diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py index dbca04420..3a4eca0a7 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py @@ -5,6 +5,7 @@ from pydantic import ConfigDict, with_config from temporalio import activity, workflow +from temporalio.workflow import ActivityConfig from pydantic_ai._run_context import RunContext from pydantic_ai.exceptions import UserError @@ -14,7 +15,6 @@ from pydantic_ai.toolsets.wrapper import WrapperToolset from ._run_context import TemporalRunContext -from ._settings import TemporalSettings @dataclass @@ -36,30 +36,28 @@ class TemporalMCPServer(WrapperToolset[Any]): def __init__( self, server: MCPServer, - settings: TemporalSettings | None = None, - tool_settings: dict[str, TemporalSettings | Literal[False]] = {}, + activity_config: ActivityConfig = {}, + tool_activity_config: dict[str, ActivityConfig | Literal[False]] = {}, + run_context_type: type[TemporalRunContext] = TemporalRunContext, ): super().__init__(server) - self.settings = settings or TemporalSettings() - self.tool_settings = tool_settings + self.activity_config = activity_config + self.tool_activity_config = tool_activity_config + self.run_context_type = run_context_type id = server.id assert id is not None @activity.defn(name=f'mcp_server__{id}__get_tools') async def get_tools_activity(params: _GetToolsParams) -> dict[str, ToolDefinition]: - run_context = TemporalRunContext.deserialize_run_context( - params.serialized_run_context, self.settings.deserialize_run_context - ) + run_context = self.run_context_type.deserialize_run_context(params.serialized_run_context) return {name: tool.tool_def for name, tool in (await self.wrapped.get_tools(run_context)).items()} self.get_tools_activity = get_tools_activity @activity.defn(name=f'mcp_server__{id}__call_tool') async def call_tool_activity(params: _CallToolParams) -> ToolResult: - run_context = TemporalRunContext.deserialize_run_context( - params.serialized_run_context, self.settings.deserialize_run_context - ) + run_context = self.run_context_type.deserialize_run_context(params.serialized_run_context) return await self.wrapped.call_tool( params.name, params.tool_args, @@ -79,11 +77,11 @@ def activities(self) -> list[Callable[..., Any]]: return [self.get_tools_activity, self.call_tool_activity] async def get_tools(self, ctx: RunContext[Any]) -> dict[str, ToolsetTool[Any]]: - serialized_run_context = TemporalRunContext.serialize_run_context(ctx, self.settings.serialize_run_context) + serialized_run_context = self.run_context_type.serialize_run_context(ctx) tool_defs = await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] activity=self.get_tools_activity, arg=_GetToolsParams(serialized_run_context=serialized_run_context), - **self.settings.execute_activity_options, + **self.activity_config, ) return { name: self.wrapped_server._toolset_tool_for_tool_def(tool_def) # pyright: ignore[reportPrivateUsage] @@ -97,12 +95,12 @@ async def call_tool( ctx: RunContext[Any], tool: ToolsetTool[Any], ) -> ToolResult: - settings_for_tool = self.tool_settings.get(name) - if settings_for_tool is False: + tool_activity_config = self.tool_activity_config.get(name, {}) + if tool_activity_config is False: raise UserError('Disabling running an MCP tool in a Temporal activity is not possible.') - settings_for_tool = self.settings.merge(settings_for_tool) - serialized_run_context = TemporalRunContext.serialize_run_context(ctx, settings_for_tool.serialize_run_context) + tool_activity_config = self.activity_config | tool_activity_config + serialized_run_context = self.run_context_type.serialize_run_context(ctx) return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] activity=self.call_tool_activity, arg=_CallToolParams( @@ -111,5 +109,5 @@ async def call_tool( serialized_run_context=serialized_run_context, tool_def=tool.tool_def, ), - **settings_for_tool.execute_activity_options, + **tool_activity_config, ) diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_model.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_model.py index 7adeb2172..5282412eb 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_model.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_model.py @@ -8,6 +8,7 @@ from pydantic import ConfigDict, with_config from temporalio import activity, workflow +from temporalio.workflow import ActivityConfig from pydantic_ai._run_context import RunContext from pydantic_ai.agent import EventStreamHandler @@ -27,7 +28,6 @@ from pydantic_ai.usage import Usage from ._run_context import TemporalRunContext -from ._settings import TemporalSettings @dataclass @@ -71,12 +71,14 @@ class TemporalModel(WrapperModel): def __init__( self, model: Model, - settings: TemporalSettings | None = None, + activity_config: ActivityConfig = {}, event_stream_handler: EventStreamHandler | None = None, + run_context_type: type[TemporalRunContext] = TemporalRunContext, ): super().__init__(model) - self.temporal_settings = settings or TemporalSettings() + self.activity_config = activity_config self.event_stream_handler = event_stream_handler + self.run_context_type = run_context_type id = '_'.join([model.system, model.model_name]) @@ -88,9 +90,7 @@ async def request_activity(params: _RequestParams) -> ModelResponse: @activity.defn(name=f'model__{id}__request_stream') async def request_stream_activity(params: _RequestParams) -> ModelResponse: - run_context = TemporalRunContext.deserialize_run_context( - params.serialized_run_context, self.temporal_settings.deserialize_run_context - ) + run_context = self.run_context_type.deserialize_run_context(params.serialized_run_context) async with self.wrapped.request_stream( params.messages, params.model_settings, params.model_request_parameters, run_context ) as streamed_response: @@ -161,7 +161,7 @@ async def request( model_request_parameters=model_request_parameters, serialized_run_context=None, ), - **self.temporal_settings.execute_activity_options, + **self.activity_config, ) @asynccontextmanager @@ -177,9 +177,7 @@ async def request_stream( if run_context is None: raise UserError('Streaming with Temporal requires `request_stream` to be called with a `run_context`') - serialized_run_context = TemporalRunContext.serialize_run_context( - run_context, self.temporal_settings.serialize_run_context - ) + serialized_run_context = self.run_context_type.serialize_run_context(run_context) response = await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] activity=self.request_stream_activity, arg=_RequestParams( @@ -188,6 +186,6 @@ async def request_stream( model_request_parameters=model_request_parameters, serialized_run_context=serialized_run_context, ), - **self.temporal_settings.execute_activity_options, + **self.activity_config, ) yield _TemporalStreamedResponse(response) diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_run_context.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_run_context.py index c1e896508..d172a0db0 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_run_context.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_run_context.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Callable +from typing import Any from pydantic_ai._run_context import RunContext @@ -20,49 +20,33 @@ def __getattribute__(self, name: str) -> Any: except AttributeError as e: if name in RunContext.__dataclass_fields__: raise AttributeError( - f'Temporalized {RunContext.__name__!r} object has no attribute {name!r}. ' - 'To make the attribute available, pass a `TemporalSettings` object to `temporalize_agent` with a custom `serialize_run_context` function that returns a dictionary that includes the attribute.' + f'{self.__class__.__name__!r} object has no attribute {name!r}. ' + 'To make the attribute available, create a `TemporalRunContext` subclass with a custom `serialize_run_context` class method that returns a dictionary that includes the attribute and pass it to `temporalize_agent`.' ) else: raise e @classmethod - def serialize_run_context( - cls, - ctx: RunContext[Any], - extra_serializer: Callable[[RunContext[Any]], dict[str, Any]] | None = None, - ) -> dict[str, Any]: + def serialize_run_context(cls, ctx: RunContext[Any]) -> dict[str, Any]: return { 'retries': ctx.retries, 'tool_call_id': ctx.tool_call_id, 'tool_name': ctx.tool_name, 'retry': ctx.retry, 'run_step': ctx.run_step, - **(extra_serializer(ctx) if extra_serializer else {}), } @classmethod - def deserialize_run_context( - cls, ctx: dict[str, Any], extra_deserializer: Callable[[dict[str, Any]], dict[str, Any]] | None = None - ) -> RunContext[Any]: - return cls( - retries=ctx['retries'], - tool_call_id=ctx['tool_call_id'], - tool_name=ctx['tool_name'], - retry=ctx['retry'], - run_step=ctx['run_step'], - **(extra_deserializer(ctx) if extra_deserializer else {}), - ) - + def deserialize_run_context(cls, ctx: dict[str, Any]) -> RunContext[Any]: + return cls(**ctx) -def serialize_run_context_deps(ctx: RunContext[Any]) -> dict[str, Any]: - if not isinstance(ctx.deps, dict): - raise ValueError( - 'The `deps` object must be a JSON-serializable dictionary in order to be used with Temporal. ' - 'To use a different type, pass a `TemporalSettings` object to `temporalize_agent` with custom `serialize_run_context` and `deserialize_run_context` functions.' - ) - return {'deps': ctx.deps} # pyright: ignore[reportUnknownMemberType] - -def deserialize_run_context_deps(ctx: dict[str, Any]) -> dict[str, Any]: - return {'deps': ctx['deps']} +class TemporalRunContextWithDeps(TemporalRunContext): + @classmethod + def serialize_run_context(cls, ctx: RunContext[Any]) -> dict[str, Any]: + if not isinstance(ctx.deps, dict): + raise ValueError( + 'The `deps` object must be a JSON-serializable dictionary in order to be used with Temporal. ' + 'To use a different type, pass a `TemporalRunContext` subclass to `temporalize_agent` with custom `serialize_run_context` and `deserialize_run_context` class methods.' + ) + return {**super().serialize_run_context(ctx), 'deps': ctx.deps} # pyright: ignore[reportUnknownMemberType] diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_settings.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_settings.py deleted file mode 100644 index 4340f1863..000000000 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_settings.py +++ /dev/null @@ -1,55 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass, fields, replace -from datetime import timedelta -from typing import Any, Callable - -from temporalio.common import Priority, RetryPolicy -from temporalio.workflow import ActivityCancellationType, VersioningIntent - -from pydantic_ai._run_context import RunContext - -from ._run_context import deserialize_run_context_deps, serialize_run_context_deps - - -@dataclass -class TemporalSettings: - """Settings for Temporal `execute_activity` and Pydantic AI-specific Temporal activity behavior.""" - - # Temporal settings - task_queue: str | None = None - schedule_to_close_timeout: timedelta | None = None - schedule_to_start_timeout: timedelta | None = None - start_to_close_timeout: timedelta | None = None - heartbeat_timeout: timedelta | None = None - retry_policy: RetryPolicy | None = None - cancellation_type: ActivityCancellationType = ActivityCancellationType.TRY_CANCEL - activity_id: str | None = None - versioning_intent: VersioningIntent | None = None - summary: str | None = None - priority: Priority = Priority.default - - serialize_run_context: Callable[[RunContext], dict[str, Any]] = serialize_run_context_deps - deserialize_run_context: Callable[[dict[str, Any]], dict[str, Any]] = deserialize_run_context_deps - - @property - def execute_activity_options(self) -> dict[str, Any]: - return { - 'task_queue': self.task_queue, - 'schedule_to_close_timeout': self.schedule_to_close_timeout, - 'schedule_to_start_timeout': self.schedule_to_start_timeout, - 'start_to_close_timeout': self.start_to_close_timeout, - 'heartbeat_timeout': self.heartbeat_timeout, - 'retry_policy': self.retry_policy, - 'cancellation_type': self.cancellation_type, - 'activity_id': self.activity_id, - 'versioning_intent': self.versioning_intent, - 'summary': self.summary, - 'priority': self.priority, - } - - def merge(self, other: TemporalSettings | None) -> TemporalSettings: - """Merge non-default values from another TemporalSettings instance into this one, returning a new instance.""" - if not other: - return self - return replace(self, **{f.name: value for f in fields(other) if (value := getattr(other, f.name)) != f.default}) diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_toolset.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_toolset.py index 8a17aea9e..e4cccbbb9 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_toolset.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_toolset.py @@ -2,30 +2,34 @@ from typing import Literal +from temporalio.workflow import ActivityConfig + +from pydantic_ai.ext.temporal._run_context import TemporalRunContext from pydantic_ai.mcp import MCPServer from pydantic_ai.toolsets.abstract import AbstractToolset from pydantic_ai.toolsets.function import FunctionToolset from ._function_toolset import TemporalFunctionToolset from ._mcp_server import TemporalMCPServer -from ._settings import TemporalSettings def temporalize_toolset( toolset: AbstractToolset, - settings: TemporalSettings | None, - tool_settings: dict[str, TemporalSettings | Literal[False]] = {}, + activity_config: ActivityConfig = {}, + tool_activity_config: dict[str, ActivityConfig | Literal[False]] = {}, + run_context_type: type[TemporalRunContext] = TemporalRunContext, ) -> AbstractToolset: """Temporalize a toolset. Args: toolset: The toolset to temporalize. - settings: The temporal settings to use. - tool_settings: The temporal settings to use for specific tools identified by tool name. + activity_config: The Temporal activity config to use. + tool_activity_config: The Temporal activity config to use for specific tools identified by tool name. + run_context_type: The type of run context to use to serialize and deserialize the run context. """ if isinstance(toolset, FunctionToolset): - return TemporalFunctionToolset(toolset, settings, tool_settings) + return TemporalFunctionToolset(toolset, activity_config, tool_activity_config, run_context_type) elif isinstance(toolset, MCPServer): - return TemporalMCPServer(toolset, settings, tool_settings) + return TemporalMCPServer(toolset, activity_config, tool_activity_config, run_context_type) else: return toolset diff --git a/temporal.py b/temporal.py index ab1b8dabd..fc873de55 100644 --- a/temporal.py +++ b/temporal.py @@ -7,6 +7,7 @@ from temporalio import workflow from temporalio.client import Client from temporalio.worker import Worker +from temporalio.workflow import ActivityConfig from typing_extensions import TypedDict from pydantic_ai import Agent, RunContext @@ -14,7 +15,7 @@ AgentPlugin, LogfirePlugin, PydanticAIPlugin, - TemporalSettings, + TemporalRunContextWithDeps, temporalize_agent, ) from pydantic_ai.mcp import MCPServerStdio @@ -56,16 +57,17 @@ def get_weather(city: str) -> str: # as it modifies the `agent` object in place to swap out methods that use IO for ones that use Temporal activities. temporalize_agent( agent, - settings=TemporalSettings(start_to_close_timeout=timedelta(seconds=60)), - toolset_settings={ - 'country': TemporalSettings(start_to_close_timeout=timedelta(seconds=120)), + activity_config=ActivityConfig(start_to_close_timeout=timedelta(seconds=60)), + toolset_activity_config={ + 'country': ActivityConfig(start_to_close_timeout=timedelta(seconds=120)), }, - tool_settings={ + tool_activity_config={ 'toolset': { 'get_country': False, - 'get_weather': TemporalSettings(start_to_close_timeout=timedelta(seconds=180)), + 'get_weather': ActivityConfig(start_to_close_timeout=timedelta(seconds=180)), }, }, + run_context_type=TemporalRunContextWithDeps, ) with workflow.unsafe.imports_passed_through(): From f91547ddb5efce091c4f120309cd1cb9b180d911 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 31 Jul 2025 19:04:19 +0000 Subject: [PATCH 18/41] Warn when non-default Temporal data converter was swapped out --- .../pydantic_ai/ext/temporal/__init__.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/__init__.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/__init__.py index 795560587..7ea4cf02b 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/__init__.py @@ -1,11 +1,13 @@ from __future__ import annotations +import warnings from collections.abc import Sequence from dataclasses import replace from typing import Any, Callable from temporalio.client import ClientConfig, Plugin as ClientPlugin -from temporalio.contrib.pydantic import pydantic_data_converter +from temporalio.contrib.pydantic import PydanticPayloadConverter, pydantic_data_converter +from temporalio.converter import DefaultPayloadConverter from temporalio.worker import Plugin as WorkerPlugin, WorkerConfig from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner @@ -29,6 +31,14 @@ class PydanticAIPlugin(ClientPlugin, WorkerPlugin): """Temporal client and worker plugin for Pydantic AI.""" def configure_client(self, config: ClientConfig) -> ClientConfig: + if (data_converter := config.get('data_converter')) and data_converter.payload_converter_class not in ( + DefaultPayloadConverter, + PydanticPayloadConverter, + ): + warnings.warn( + 'A non-default Temporal data converter was used which has been replaced with the Pydantic data converter.' + ) + config['data_converter'] = pydantic_data_converter return super().configure_client(config) From 6aeb078140464dd1d72c325cd4508db8305c55dd Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 31 Jul 2025 20:46:37 +0000 Subject: [PATCH 19/41] Use agent.override inside temporalize_agent instead of directly setting fields --- pydantic_ai_slim/pydantic_ai/agent.py | 31 ++++++- .../pydantic_ai/ext/temporal/__init__.py | 3 +- .../pydantic_ai/ext/temporal/_agent.py | 91 ++++++++++++------- .../ext/temporal/_function_toolset.py | 11 ++- .../pydantic_ai/ext/temporal/_mcp_server.py | 12 ++- .../pydantic_ai/ext/temporal/_model.py | 5 +- .../pydantic_ai/ext/temporal/_run_context.py | 3 +- .../pydantic_ai/ext/temporal/_toolset.py | 16 +++- temporal.py | 11 +-- 9 files changed, 124 insertions(+), 59 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index be9adef30..61d749ad2 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -156,6 +156,9 @@ class Agent(Generic[AgentDepsT, OutputDataT]): instrument: InstrumentationSettings | bool | None """Options to automatically instrument with OpenTelemetry.""" + event_stream_handler: EventStreamHandler[AgentDepsT] | None + """Optional handler for events from the agent stream.""" + _instrument_default: ClassVar[InstrumentationSettings | bool] = False _deps_type: type[AgentDepsT] = dataclasses.field(repr=False) @@ -176,7 +179,6 @@ class Agent(Generic[AgentDepsT, OutputDataT]): _prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False) _prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False) _max_result_retries: int = dataclasses.field(repr=False) - _event_stream_handler: EventStreamHandler[AgentDepsT] | None = dataclasses.field(repr=False) _enter_lock: Lock = dataclasses.field(repr=False) _entered_count: int = dataclasses.field(repr=False) @@ -438,13 +440,16 @@ def __init__( self.history_processors = history_processors or [] - self._event_stream_handler = event_stream_handler + self.event_stream_handler = event_stream_handler self._override_deps: ContextVar[_utils.Option[AgentDepsT]] = ContextVar('_override_deps', default=None) self._override_model: ContextVar[_utils.Option[models.Model]] = ContextVar('_override_model', default=None) self._override_toolsets: ContextVar[_utils.Option[Sequence[AbstractToolset[AgentDepsT]]]] = ContextVar( '_override_toolsets', default=None ) + self._override_tools: ContextVar[ + _utils.Option[Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]]] + ] = ContextVar('_override_tools', default=None) self._enter_lock = _utils.get_async_lock() self._entered_count = 0 @@ -575,11 +580,11 @@ async def main(): toolsets=toolsets, ) as agent_run: async for node in agent_run: - if self._event_stream_handler is not None and ( + if self.event_stream_handler is not None and ( self.is_model_request_node(node) or self.is_call_tools_node(node) ): async with node.stream(agent_run.ctx) as stream: - await self._event_stream_handler(_agent_graph.build_run_context(agent_run.ctx), stream) + await self.event_stream_handler(_agent_graph.build_run_context(agent_run.ctx), stream) assert agent_run.result is not None, 'The graph run did not finish properly' return agent_run.result @@ -1210,6 +1215,7 @@ def override( deps: AgentDepsT | _utils.Unset = _utils.UNSET, model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET, toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET, + tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET, ) -> Iterator[None]: """Context manager to temporarily override agent dependencies, model, or toolsets. @@ -1220,6 +1226,7 @@ def override( deps: The dependencies to use instead of the dependencies passed to the agent run. model: The model to use instead of the model passed to the agent run. toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run. + tools: The tools to use instead of the tools registered with the agent. """ if _utils.is_set(deps): deps_token = self._override_deps.set(_utils.Some(deps)) @@ -1236,6 +1243,11 @@ def override( else: toolsets_token = None + if _utils.is_set(tools): + tools_token = self._override_tools.set(_utils.Some(tools)) + else: + tools_token = None + try: yield finally: @@ -1245,6 +1257,8 @@ def override( self._override_model.reset(model_token) if toolsets_token is not None: self._override_toolsets.reset(toolsets_token) + if tools_token is not None: + self._override_tools.reset(tools_token) @overload def instructions( @@ -1697,6 +1711,13 @@ def _get_toolset( output_toolset: The output toolset to use instead of the one built at agent construction time. additional_toolsets: Additional toolsets to add. """ + if some_tools := self._override_tools.get(): + function_toolset = FunctionToolset( + some_tools.value, max_retries=self._function_toolset.max_retries, id='agent' + ) + else: + function_toolset = self._function_toolset + if some_user_toolsets := self._override_toolsets.get(): user_toolsets = some_user_toolsets.value elif additional_toolsets is not None: @@ -1704,7 +1725,7 @@ def _get_toolset( else: user_toolsets = self._user_toolsets - all_toolsets = [self._function_toolset, *user_toolsets] + all_toolsets = [function_toolset, *user_toolsets] if self._prepare_tools: all_toolsets = [PreparedToolset(CombinedToolset(all_toolsets), self._prepare_tools)] diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/__init__.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/__init__.py index 7ea4cf02b..aa6e03e43 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/__init__.py @@ -12,6 +12,7 @@ from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner from pydantic_ai.agent import Agent +from pydantic_ai.exceptions import UserError from ._agent import temporalize_agent from ._logfire import LogfirePlugin @@ -69,7 +70,7 @@ def __init__(self, agent: Agent[Any, Any]): def configure_worker(self, config: WorkerConfig) -> WorkerConfig: agent_activities = getattr(self.agent, '__temporal_activities', None) if agent_activities is None: - raise ValueError('The agent has not been prepared for Temporal yet, call `temporalize_agent(agent)` first.') + raise UserError('The agent has not been prepared for Temporal yet, call `temporalize_agent(agent)` first.') activities: Sequence[Callable[..., Any]] = config.get('activities', []) # pyright: ignore[reportUnknownMemberType] config['activities'] = [*activities, *agent_activities] diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py index 45a0a594b..ef7f857a4 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py @@ -1,20 +1,22 @@ from __future__ import annotations -from typing import Any, Callable, Literal, cast +from contextlib import asynccontextmanager +from typing import Any, Callable, Literal +from temporalio import workflow from temporalio.workflow import ActivityConfig from pydantic_ai.agent import Agent +from pydantic_ai.exceptions import UserError from pydantic_ai.ext.temporal._run_context import TemporalRunContext from pydantic_ai.models import Model from pydantic_ai.toolsets.abstract import AbstractToolset -from pydantic_ai.toolsets.function import FunctionToolset from ._model import TemporalModel -from ._toolset import temporalize_toolset +from ._toolset import TemporalWrapperToolset, temporalize_toolset -def temporalize_agent( +def temporalize_agent( # noqa: C901 agent: Agent[Any, Any], activity_config: ActivityConfig = {}, toolset_activity_config: dict[str, ActivityConfig] = {}, @@ -24,7 +26,7 @@ def temporalize_agent( [AbstractToolset, ActivityConfig, dict[str, ActivityConfig | Literal[False]], type[TemporalRunContext]], AbstractToolset, ] = temporalize_toolset, -) -> list[Callable[..., Any]]: +) -> Agent[Any, Any]: """Temporalize an agent. Args: @@ -33,26 +35,25 @@ def temporalize_agent( toolset_activity_config: The Temporal activity config to use for specific toolsets identified by ID. tool_activity_config: The Temporal activity config to use for specific tools identified by toolset ID and tool name. run_context_type: The type of run context to use to serialize and deserialize the run context. - temporalize_toolset_func: The function to use to temporalize the toolsets. + temporalize_toolset_func: The function to use to prepare the toolsets for Temporal. """ - if existing_activities := getattr(agent, '__temporal_activities', None): - return existing_activities + if getattr(agent, '__temporal_activities', None): + return agent activities: list[Callable[..., Any]] = [] - if isinstance(agent.model, Model): - model = TemporalModel(agent.model, activity_config, agent._event_stream_handler, run_context_type) # pyright: ignore[reportPrivateUsage] - activities.extend(model.activities) - agent.model = model - else: - raise ValueError( + if not isinstance(agent.model, Model): + raise UserError( 'Model cannot be set at agent run time when using Temporal, it must be set at agent creation time.' ) + temporal_model = TemporalModel(agent.model, activity_config, agent.event_stream_handler, run_context_type) + activities.extend(temporal_model.temporal_activities) + def temporalize_toolset(toolset: AbstractToolset) -> AbstractToolset: id = toolset.id if not id: - raise ValueError( - "A toolset needs to have an ID in order to be used with Temporal. The ID will be used to identify the toolset's activities within the workflow." + raise UserError( + "Toolsets that implement their own tool calling need to have an ID in order to be used with Temporal. The ID will be used to identify the toolset's activities within the workflow." ) toolset = temporalize_toolset_func( toolset, @@ -60,47 +61,67 @@ def temporalize_toolset(toolset: AbstractToolset) -> AbstractToolset: tool_activity_config.get(id, {}), run_context_type, ) - if hasattr(toolset, 'activities'): - activities.extend(getattr(toolset, 'activities')) + if isinstance(toolset, TemporalWrapperToolset): + activities.extend(toolset.temporal_activities) return toolset - agent._function_toolset = cast(FunctionToolset, temporalize_toolset(agent._function_toolset)) # pyright: ignore[reportPrivateUsage] - agent._user_toolsets = [temporalize_toolset(toolset) for toolset in agent._user_toolsets] # pyright: ignore[reportPrivateUsage] + temporal_toolsets = [temporalize_toolset(toolset) for toolset in [agent._function_toolset, *agent._user_toolsets]] # pyright: ignore[reportPrivateUsage] original_iter = agent.iter original_override = agent.override - setattr(agent, '__original_iter', original_iter) - setattr(agent, '__original_override', original_override) def iter(*args: Any, **kwargs: Any) -> Any: + if not workflow.in_workflow(): + return original_iter(*args, **kwargs) + if kwargs.get('model') is not None: - raise ValueError( + raise UserError( 'Model cannot be set at agent run time when using Temporal, it must be set at agent creation time.' ) if kwargs.get('toolsets') is not None: - raise ValueError( + raise UserError( 'Toolsets cannot be set at agent run time when using Temporal, it must be set at agent creation time.' ) - if ( - kwargs.get('event_stream_handler') is not None - ): # TODO: iter won't have event_stream_handler, run/_sync/_stream will - raise ValueError( + if kwargs.get('event_stream_handler') is not None: + # TODO: iter won't have event_stream_handler, run/_sync/_stream will + raise UserError( 'Event stream handler cannot be set at agent run time when using Temporal, it must be set at agent creation time.' ) - return original_iter(*args, **kwargs) + @asynccontextmanager + async def async_iter(): + # We reset tools here as the temporalized function toolset is already in temporal_toolsets. + with agent.override(model=temporal_model, toolsets=temporal_toolsets, tools=[]): + async with original_iter(*args, **kwargs) as result: + yield result + + return async_iter() def override(*args: Any, **kwargs: Any) -> Any: - if kwargs.get('model') is not None: - raise ValueError('Model cannot be overridden when using Temporal, it must be set at agent creation time.') - if kwargs.get('toolsets') is not None: - raise ValueError( - 'Toolsets cannot be overridden when using Temporal, it must be set at agent creation time.' + if not workflow.in_workflow(): + return original_override(*args, **kwargs) + + if kwargs.get('model') not in (None, temporal_model): + raise UserError( + 'Model cannot be contextually overridden when using Temporal, it must be set at agent creation time.' + ) + if kwargs.get('toolsets') not in (None, temporal_toolsets): + raise UserError( + 'Toolsets cannot be contextually overridden when using Temporal, they must be set at agent creation time.' + ) + if kwargs.get('tools') not in (None, []): + raise UserError( + 'Tools cannot be contextually overridden when using Temporal, they must be set at agent creation time.' ) return original_override(*args, **kwargs) + def tool(*args: Any, **kwargs: Any) -> Any: + raise UserError('New tools cannot be registered after an agent has been prepared for Temporal.') + agent.iter = iter agent.override = override + agent.tool = tool + agent.tool_plain = tool setattr(agent, '__temporal_activities', activities) - return activities + return agent diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_function_toolset.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_function_toolset.py index b4add7e32..0f397c8cb 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_function_toolset.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_function_toolset.py @@ -11,9 +11,9 @@ from pydantic_ai.exceptions import UserError from pydantic_ai.toolsets import FunctionToolset, ToolsetTool from pydantic_ai.toolsets.function import _FunctionToolsetTool # pyright: ignore[reportPrivateUsage] -from pydantic_ai.toolsets.wrapper import WrapperToolset from ._run_context import TemporalRunContext +from ._toolset import TemporalWrapperToolset @dataclass @@ -24,7 +24,7 @@ class _CallToolParams: serialized_run_context: Any -class TemporalFunctionToolset(WrapperToolset[Any]): +class TemporalFunctionToolset(TemporalWrapperToolset): def __init__( self, toolset: FunctionToolset, @@ -47,7 +47,7 @@ async def call_tool_activity(params: _CallToolParams) -> Any: try: tool = (await toolset.get_tools(ctx))[name] except KeyError as e: - raise ValueError( + raise UserError( f'Tool {name!r} not found in toolset {toolset.id!r}. ' 'Removing or renaming tools during an agent run is not supported with Temporal.' ) from e @@ -62,7 +62,7 @@ def wrapped_function_toolset(self) -> FunctionToolset: return self.wrapped @property - def activities(self) -> list[Callable[..., Any]]: + def temporal_activities(self) -> list[Callable[..., Any]]: return [self.call_tool_activity] def tool(self, *args: Any, **kwargs: Any) -> Any: @@ -75,6 +75,9 @@ def add_tool(self, *args: Any, **kwargs: Any) -> None: return self.wrapped_function_toolset.add_tool(*args, **kwargs) async def call_tool(self, name: str, tool_args: dict[str, Any], ctx: RunContext, tool: ToolsetTool) -> Any: + if not workflow.in_workflow(): + return await super().call_tool(name, tool_args, ctx, tool) + tool_activity_config = self.tool_activity_config.get(name, {}) if tool_activity_config is False: assert isinstance(tool, _FunctionToolsetTool) diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py index 3a4eca0a7..25575b7a5 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py @@ -12,9 +12,9 @@ from pydantic_ai.mcp import MCPServer, ToolResult from pydantic_ai.tools import ToolDefinition from pydantic_ai.toolsets.abstract import ToolsetTool -from pydantic_ai.toolsets.wrapper import WrapperToolset from ._run_context import TemporalRunContext +from ._toolset import TemporalWrapperToolset @dataclass @@ -32,7 +32,7 @@ class _CallToolParams: tool_def: ToolDefinition -class TemporalMCPServer(WrapperToolset[Any]): +class TemporalMCPServer(TemporalWrapperToolset): def __init__( self, server: MCPServer, @@ -73,10 +73,13 @@ def wrapped_server(self) -> MCPServer: return self.wrapped @property - def activities(self) -> list[Callable[..., Any]]: + def temporal_activities(self) -> list[Callable[..., Any]]: return [self.get_tools_activity, self.call_tool_activity] async def get_tools(self, ctx: RunContext[Any]) -> dict[str, ToolsetTool[Any]]: + if not workflow.in_workflow(): + return await super().get_tools(ctx) + serialized_run_context = self.run_context_type.serialize_run_context(ctx) tool_defs = await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] activity=self.get_tools_activity, @@ -95,6 +98,9 @@ async def call_tool( ctx: RunContext[Any], tool: ToolsetTool[Any], ) -> ToolResult: + if not workflow.in_workflow(): + return await super().call_tool(name, tool_args, ctx, tool) + tool_activity_config = self.tool_activity_config.get(name, {}) if tool_activity_config is False: raise UserError('Disabling running an MCP tool in a Temporal activity is not possible.') diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_model.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_model.py index 5282412eb..7b7f249d1 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_model.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_model.py @@ -144,7 +144,7 @@ def _get_final_result_event(e: ModelResponseStreamEvent) -> FinalResultEvent | N self.request_stream_activity = request_stream_activity @property - def activities(self) -> list[Callable[..., Any]]: + def temporal_activities(self) -> list[Callable[..., Any]]: return [self.request_activity, self.request_stream_activity] async def request( @@ -153,6 +153,9 @@ async def request( model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> ModelResponse: + if not workflow.in_workflow(): + return await super().request(messages, model_settings, model_request_parameters) + return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] activity=self.request_activity, arg=_RequestParams( diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_run_context.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_run_context.py index d172a0db0..e6223195c 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_run_context.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_run_context.py @@ -3,6 +3,7 @@ from typing import Any from pydantic_ai._run_context import RunContext +from pydantic_ai.exceptions import UserError class TemporalRunContext(RunContext[Any]): @@ -45,7 +46,7 @@ class TemporalRunContextWithDeps(TemporalRunContext): @classmethod def serialize_run_context(cls, ctx: RunContext[Any]) -> dict[str, Any]: if not isinstance(ctx.deps, dict): - raise ValueError( + raise UserError( 'The `deps` object must be a JSON-serializable dictionary in order to be used with Temporal. ' 'To use a different type, pass a `TemporalRunContext` subclass to `temporalize_agent` with custom `serialize_run_context` and `deserialize_run_context` class methods.' ) diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_toolset.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_toolset.py index e4cccbbb9..32f9ef584 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_toolset.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_toolset.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Literal +from abc import ABC, abstractmethod +from typing import Any, Callable, Literal from temporalio.workflow import ActivityConfig @@ -8,9 +9,14 @@ from pydantic_ai.mcp import MCPServer from pydantic_ai.toolsets.abstract import AbstractToolset from pydantic_ai.toolsets.function import FunctionToolset +from pydantic_ai.toolsets.wrapper import WrapperToolset -from ._function_toolset import TemporalFunctionToolset -from ._mcp_server import TemporalMCPServer + +class TemporalWrapperToolset(WrapperToolset[Any], ABC): + @property + @abstractmethod + def temporal_activities(self) -> list[Callable[..., Any]]: + raise NotImplementedError def temporalize_toolset( @@ -28,8 +34,12 @@ def temporalize_toolset( run_context_type: The type of run context to use to serialize and deserialize the run context. """ if isinstance(toolset, FunctionToolset): + from ._function_toolset import TemporalFunctionToolset + return TemporalFunctionToolset(toolset, activity_config, tool_activity_config, run_context_type) elif isinstance(toolset, MCPServer): + from ._mcp_server import TemporalMCPServer + return TemporalMCPServer(toolset, activity_config, tool_activity_config, run_context_type) else: return toolset diff --git a/temporal.py b/temporal.py index fc873de55..34432629c 100644 --- a/temporal.py +++ b/temporal.py @@ -33,15 +33,10 @@ async def event_stream_handler(ctx: RunContext[Deps], stream: AsyncIterable[Agen logfire.info(f'{event=}') -toolset = FunctionToolset[Deps](id='toolset') - - -@toolset.tool async def get_country(ctx: RunContext[Deps]) -> str: return ctx.deps['country'] -@toolset.tool def get_weather(city: str) -> str: return 'sunny' @@ -49,7 +44,11 @@ def get_weather(city: str) -> str: agent = Agent( 'openai:gpt-4o', deps_type=Deps, - toolsets=[toolset, MCPServerStdio('python', ['-m', 'tests.mcp_server'], timeout=20, id='mcp')], + toolsets=[ + FunctionToolset[Deps](tools=[get_weather], id='toolset'), + MCPServerStdio('python', ['-m', 'tests.mcp_server'], timeout=20, id='mcp'), + ], + tools=[get_country], event_stream_handler=event_stream_handler, ) From 2f3965b7c23635e94b60544010f0cfd25338026f Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 31 Jul 2025 22:31:52 +0000 Subject: [PATCH 20/41] Remove duplication between AgentStream and TemporalModel get_final_result_event --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 28 +++++---------- .../pydantic_ai/ext/temporal/_model.py | 34 +++---------------- .../pydantic_ai/models/__init__.py | 29 +++++++++++++++- .../pydantic_ai/models/anthropic.py | 5 +-- .../pydantic_ai/models/bedrock.py | 5 +-- pydantic_ai_slim/pydantic_ai/models/cohere.py | 5 +-- pydantic_ai_slim/pydantic_ai/models/gemini.py | 4 +-- pydantic_ai_slim/pydantic_ai/models/google.py | 10 ++---- pydantic_ai_slim/pydantic_ai/models/groq.py | 5 +-- .../pydantic_ai/models/huggingface.py | 5 +-- .../pydantic_ai/models/mistral.py | 11 ++---- pydantic_ai_slim/pydantic_ai/models/openai.py | 10 ++---- pydantic_ai_slim/pydantic_ai/result.py | 23 ++----------- 13 files changed, 55 insertions(+), 119 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 0dea48dab..16df8f9dc 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -303,10 +303,18 @@ async def stream( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]], ) -> AsyncIterator[result.AgentStream[DepsT, T]]: - async with self._stream(ctx) as streamed_response: + assert not self._did_stream, 'stream() should only be called once per node' + + model_settings, model_request_parameters, message_history, run_context = await self._prepare_request(ctx) + async with ctx.deps.model.request_stream( + message_history, model_settings, model_request_parameters, run_context + ) as streamed_response: + self._did_stream = True + ctx.state.usage.requests += 1 agent_stream = result.AgentStream[DepsT, T]( streamed_response, ctx.deps.output_schema, + model_request_parameters, ctx.deps.output_validators, build_run_context(ctx), ctx.deps.usage_limits, @@ -318,24 +326,6 @@ async def stream( async for _ in agent_stream: pass - @asynccontextmanager - async def _stream( - self, - ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]], - ) -> AsyncIterator[models.StreamedResponse]: - assert not self._did_stream, 'stream() should only be called once per node' - - model_settings, model_request_parameters, message_history, run_context = await self._prepare_request(ctx) - async with ctx.deps.model.request_stream( - message_history, model_settings, model_request_parameters, run_context - ) as streamed_response: - self._did_stream = True - ctx.state.usage.requests += 1 - yield streamed_response - # In case the user didn't manually consume the full stream, ensure it is fully consumed here, - # otherwise usage won't be properly counted: - async for _ in streamed_response: - pass model_response = streamed_response.get() self._finish_handling(ctx, model_response) diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_model.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_model.py index 7b7f249d1..b87e3a6e8 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_model.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_model.py @@ -14,13 +14,9 @@ from pydantic_ai.agent import EventStreamHandler from pydantic_ai.exceptions import UserError from pydantic_ai.messages import ( - FinalResultEvent, ModelMessage, ModelResponse, ModelResponseStreamEvent, - PartStartEvent, - TextPart, - ToolCallPart, ) from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse from pydantic_ai.models.wrapper import WrapperModel @@ -72,7 +68,7 @@ def __init__( self, model: Model, activity_config: ActivityConfig = {}, - event_stream_handler: EventStreamHandler | None = None, + event_stream_handler: EventStreamHandler[Any] | None = None, run_context_type: type[TemporalRunContext] = TemporalRunContext, ): super().__init__(model) @@ -94,38 +90,16 @@ async def request_stream_activity(params: _RequestParams) -> ModelResponse: async with self.wrapped.request_stream( params.messages, params.model_settings, params.model_request_parameters, run_context ) as streamed_response: - tool_defs = { - tool_def.name: tool_def - for tool_def in [ - *params.model_request_parameters.output_tools, - *params.model_request_parameters.function_tools, - ] - } - # Keep in sync with `AgentStream.__aiter__` async def aiter(): - def _get_final_result_event(e: ModelResponseStreamEvent) -> FinalResultEvent | None: - """Return an appropriate FinalResultEvent if `e` corresponds to a part that will produce a final result.""" - if isinstance(e, PartStartEvent): - new_part = e.part - if ( - isinstance(new_part, TextPart) and params.model_request_parameters.allow_text_output - ): # pragma: no branch - return FinalResultEvent(tool_name=None, tool_call_id=None) - elif isinstance(new_part, ToolCallPart) and (tool_def := tool_defs.get(new_part.tool_name)): - if tool_def.kind == 'output': - return FinalResultEvent( - tool_name=new_part.tool_name, tool_call_id=new_part.tool_call_id - ) - elif tool_def.kind == 'deferred': - return FinalResultEvent(tool_name=None, tool_call_id=None) - # `AgentStream.__aiter__`, which this is based on, calls `_get_usage_checking_stream_response` here, # but we don't have access to the `_usage_limits`. async for event in streamed_response: yield event - if (final_result_event := _get_final_result_event(event)) is not None: + if ( + final_result_event := params.model_request_parameters.get_final_result_event(event) + ) is not None: yield final_result_event break diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 3e457de57..fea76fd75 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -23,7 +23,18 @@ from .._parts_manager import ModelResponsePartsManager from .._run_context import RunContext from ..exceptions import UserError -from ..messages import FileUrl, ModelMessage, ModelRequest, ModelResponse, ModelResponseStreamEvent, VideoUrl +from ..messages import ( + FileUrl, + FinalResultEvent, + ModelMessage, + ModelRequest, + ModelResponse, + ModelResponseStreamEvent, + PartStartEvent, + TextPart, + ToolCallPart, + VideoUrl, +) from ..output import OutputMode from ..profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec from ..profiles._json_schema import JsonSchemaTransformer @@ -343,6 +354,22 @@ class ModelRequestParameters: output_tools: list[ToolDefinition] = field(default_factory=list) allow_text_output: bool = True + @cached_property + def tool_defs(self) -> dict[str, ToolDefinition]: + return {tool_def.name: tool_def for tool_def in [*self.function_tools, *self.output_tools]} + + def get_final_result_event(self, e: ModelResponseStreamEvent) -> FinalResultEvent | None: + """Return an appropriate FinalResultEvent if `e` corresponds to a part that will produce a final result.""" + if isinstance(e, PartStartEvent): + new_part = e.part + if isinstance(new_part, TextPart) and self.allow_text_output: # pragma: no branch + return FinalResultEvent(tool_name=None, tool_call_id=None) + elif isinstance(new_part, ToolCallPart) and (tool_def := self.tool_defs.get(new_part.tool_name)): + if tool_def.kind == 'output': + return FinalResultEvent(tool_name=new_part.tool_name, tool_call_id=new_part.tool_call_id) + elif tool_def.kind == 'deferred': + return FinalResultEvent(tool_name=None, tool_call_id=None) + __repr__ = _utils.dataclasses_no_defaults_repr diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index 3006cfde6..5ea1f5347 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -299,10 +299,7 @@ async def _process_streamed_response(self, response: AsyncStream[BetaRawMessageS ) def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[BetaToolParam]: - tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools] - if model_request_parameters.output_tools: - tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools] - return tools + return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()] async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[BetaMessageParam]]: # noqa: C901 """Just maps a `pydantic_ai.Message` to a `anthropic.types.MessageParam`.""" diff --git a/pydantic_ai_slim/pydantic_ai/models/bedrock.py b/pydantic_ai_slim/pydantic_ai/models/bedrock.py index 3fde8d252..00eeed4ae 100644 --- a/pydantic_ai_slim/pydantic_ai/models/bedrock.py +++ b/pydantic_ai_slim/pydantic_ai/models/bedrock.py @@ -226,10 +226,7 @@ def __init__( super().__init__(settings=settings, profile=profile or provider.model_profile) def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolTypeDef]: - tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools] - if model_request_parameters.output_tools: - tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools] - return tools + return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()] @staticmethod def _map_tool_definition(f: ToolDefinition) -> ToolTypeDef: diff --git a/pydantic_ai_slim/pydantic_ai/models/cohere.py b/pydantic_ai_slim/pydantic_ai/models/cohere.py index 4243ef492..652b6740f 100644 --- a/pydantic_ai_slim/pydantic_ai/models/cohere.py +++ b/pydantic_ai_slim/pydantic_ai/models/cohere.py @@ -238,10 +238,7 @@ def _map_messages(self, messages: list[ModelMessage]) -> list[ChatMessageV2]: return cohere_messages def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolV2]: - tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools] - if model_request_parameters.output_tools: - tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools] - return tools + return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()] @staticmethod def _map_tool_call(t: ToolCallPart) -> ToolCallV2: diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index feb6cf10e..94ed3255f 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -183,9 +183,7 @@ def system(self) -> str: return self._system def _get_tools(self, model_request_parameters: ModelRequestParameters) -> _GeminiTools | None: - tools = [_function_from_abstract_tool(t) for t in model_request_parameters.function_tools] - if model_request_parameters.output_tools: - tools += [_function_from_abstract_tool(t) for t in model_request_parameters.output_tools] + tools = [_function_from_abstract_tool(t) for t in model_request_parameters.tool_defs.values()] return _GeminiTools(function_declarations=tools) if tools else None def _get_tool_config( diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index b25ea8233..d6dd82b4a 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -206,16 +206,10 @@ def system(self) -> str: return self._system def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolDict] | None: - tools: list[ToolDict] = [ + return [ ToolDict(function_declarations=[_function_declaration_from_tool(t)]) - for t in model_request_parameters.function_tools + for t in model_request_parameters.tool_defs.values() ] - if model_request_parameters.output_tools: - tools += [ - ToolDict(function_declarations=[_function_declaration_from_tool(t)]) - for t in model_request_parameters.output_tools - ] - return tools or None def _get_tool_config( self, model_request_parameters: ModelRequestParameters, tools: list[ToolDict] | None diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index 056dc64d5..22dc69d56 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -286,10 +286,7 @@ async def _process_streamed_response(self, response: AsyncStream[chat.ChatComple ) def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]: - tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools] - if model_request_parameters.output_tools: - tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools] - return tools + return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()] def _map_messages(self, messages: list[ModelMessage]) -> list[chat.ChatCompletionMessageParam]: """Just maps a `pydantic_ai.Message` to a `groq.types.ChatCompletionMessageParam`.""" diff --git a/pydantic_ai_slim/pydantic_ai/models/huggingface.py b/pydantic_ai_slim/pydantic_ai/models/huggingface.py index d32970546..ba20f552e 100644 --- a/pydantic_ai_slim/pydantic_ai/models/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/models/huggingface.py @@ -273,10 +273,7 @@ async def _process_streamed_response(self, response: AsyncIterable[ChatCompletio ) def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ChatCompletionInputTool]: - tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools] - if model_request_parameters.output_tools: - tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools] - return tools + return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()] async def _map_messages( self, messages: list[ModelMessage] diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index 2957ba16f..c5e6211b2 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -234,11 +234,7 @@ async def _stream_completions_create( response: MistralEventStreamAsync[MistralCompletionEvent] | None mistral_messages = self._map_messages(messages) - if ( - model_request_parameters.output_tools - and model_request_parameters.function_tools - or model_request_parameters.function_tools - ): + if model_request_parameters.function_tools: # Function Calling response = await self.client.chat.stream_async( model=str(self._model_name), @@ -306,16 +302,13 @@ def _map_function_and_output_tools_definition( Returns None if both function_tools and output_tools are empty. """ - all_tools: list[ToolDefinition] = ( - model_request_parameters.function_tools + model_request_parameters.output_tools - ) tools = [ MistralTool( function=MistralFunction( name=r.name, parameters=r.parameters_json_schema, description=r.description or '' ) ) - for r in all_tools + for r in model_request_parameters.tool_defs.values() ] return tools if tools else None diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 3e4f5d2b9..6124bb964 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -439,10 +439,7 @@ async def _process_streamed_response(self, response: AsyncStream[ChatCompletionC ) def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]: - tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools] - if model_request_parameters.output_tools: - tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools] - return tools + return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()] async def _map_messages(self, messages: list[ModelMessage]) -> list[chat.ChatCompletionMessageParam]: """Just maps a `pydantic_ai.Message` to a `openai.types.ChatCompletionMessageParam`.""" @@ -836,10 +833,7 @@ def _get_reasoning(self, model_settings: OpenAIResponsesModelSettings) -> Reason return Reasoning(effort=reasoning_effort, summary=reasoning_summary) def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[responses.FunctionToolParam]: - tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools] - if model_request_parameters.output_tools: - tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools] - return tools + return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()] def _map_tool_definition(self, f: ToolDefinition) -> responses.FunctionToolParam: return { diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index e640302b2..5261f75c8 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -46,6 +46,7 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]): _raw_stream_response: models.StreamedResponse _output_schema: OutputSchema[OutputDataT] + _model_request_parameters: models.ModelRequestParameters _output_validators: list[OutputValidator[AgentDepsT, OutputDataT]] _run_ctx: RunContext[AgentDepsT] _usage_limits: UsageLimits | None @@ -230,32 +231,12 @@ def __aiter__(self) -> AsyncIterator[AgentStreamEvent]: return self._agent_stream_iterator async def aiter(): - output_schema = self._output_schema - - def _get_final_result_event(e: _messages.ModelResponseStreamEvent) -> _messages.FinalResultEvent | None: - """Return an appropriate FinalResultEvent if `e` corresponds to a part that will produce a final result.""" - if isinstance(e, _messages.PartStartEvent): - new_part = e.part - if isinstance(new_part, _messages.TextPart) and isinstance( - output_schema, TextOutputSchema - ): # pragma: no branch - return _messages.FinalResultEvent(tool_name=None, tool_call_id=None) - elif isinstance(new_part, _messages.ToolCallPart) and ( - tool_def := self._tool_manager.get_tool_def(new_part.tool_name) - ): - if tool_def.kind == 'output': - return _messages.FinalResultEvent( - tool_name=new_part.tool_name, tool_call_id=new_part.tool_call_id - ) - elif tool_def.kind == 'deferred': - return _messages.FinalResultEvent(tool_name=None, tool_call_id=None) - usage_checking_stream = _get_usage_checking_stream_response( self._raw_stream_response, self._usage_limits, self.usage ) async for event in usage_checking_stream: yield event - if (final_result_event := _get_final_result_event(event)) is not None: + if (final_result_event := self._model_request_parameters.get_final_result_event(event)) is not None: self._final_result_event = final_result_event yield final_result_event break From 529f4c898c270749693750f42d574397b72e9dbb Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 31 Jul 2025 23:16:42 +0000 Subject: [PATCH 21/41] Some polish --- pydantic_ai_slim/pydantic_ai/agent.py | 7 +++---- pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py | 1 + .../pydantic_ai/ext/temporal/_function_toolset.py | 9 --------- .../pydantic_ai/ext/temporal/_mcp_server.py | 12 ++++++------ .../pydantic_ai/ext/temporal/_toolset.py | 4 ++-- pydantic_ai_slim/pydantic_ai/mcp.py | 4 ++-- 6 files changed, 14 insertions(+), 23 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 61d749ad2..b4c144fe1 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -10,7 +10,7 @@ from contextvars import ContextVar from copy import deepcopy from types import FrameType -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, cast, final, overload +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, cast, overload from opentelemetry.trace import NoOpTracer, use_span from pydantic.json_schema import GenerateJsonSchema @@ -105,7 +105,6 @@ ] -@final @dataclasses.dataclass(init=False) class Agent(Generic[AgentDepsT, OutputDataT]): """Class for defining "agents" - a way to have a specific type of "conversation" with an LLM. @@ -798,7 +797,7 @@ async def main(): toolset = self._get_toolset(output_toolset=output_toolset, additional_toolsets=toolsets) # This will raise errors for any name conflicts - run_toolset = await ToolManager[AgentDepsT].build(toolset, run_context) + tool_manager = await ToolManager[AgentDepsT].build(toolset, run_context) # Merge model settings in order of precedence: run > agent > model merged_settings = merge_model_settings(model_used.settings, self.model_settings) @@ -842,7 +841,7 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: output_schema=output_schema, output_validators=output_validators, history_processors=self.history_processors, - tool_manager=run_toolset, + tool_manager=tool_manager, tracer=tracer, get_instructions=get_instructions, instrumentation_settings=instrumentation_settings, diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py index ef7f857a4..22454d7f6 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py @@ -65,6 +65,7 @@ def temporalize_toolset(toolset: AbstractToolset) -> AbstractToolset: activities.extend(toolset.temporal_activities) return toolset + # TODO: Use public methods so others can replicate this temporal_toolsets = [temporalize_toolset(toolset) for toolset in [agent._function_toolset, *agent._user_toolsets]] # pyright: ignore[reportPrivateUsage] original_iter = agent.iter diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_function_toolset.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_function_toolset.py index 0f397c8cb..e46477541 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_function_toolset.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_function_toolset.py @@ -65,15 +65,6 @@ def wrapped_function_toolset(self) -> FunctionToolset: def temporal_activities(self) -> list[Callable[..., Any]]: return [self.call_tool_activity] - def tool(self, *args: Any, **kwargs: Any) -> Any: - return self.wrapped_function_toolset.tool(*args, **kwargs) - - def add_function(self, *args: Any, **kwargs: Any) -> None: - return self.wrapped_function_toolset.add_function(*args, **kwargs) - - def add_tool(self, *args: Any, **kwargs: Any) -> None: - return self.wrapped_function_toolset.add_tool(*args, **kwargs) - async def call_tool(self, name: str, tool_args: dict[str, Any], ctx: RunContext, tool: ToolsetTool) -> Any: if not workflow.in_workflow(): return await super().call_tool(name, tool_args, ctx, tool) diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py index 25575b7a5..b89f66ba3 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py @@ -51,7 +51,10 @@ def __init__( @activity.defn(name=f'mcp_server__{id}__get_tools') async def get_tools_activity(params: _GetToolsParams) -> dict[str, ToolDefinition]: run_context = self.run_context_type.deserialize_run_context(params.serialized_run_context) - return {name: tool.tool_def for name, tool in (await self.wrapped.get_tools(run_context)).items()} + tools = await self.wrapped.get_tools(run_context) + # ToolsetTool is not serializable as its holds a SchemaValidator (which is also the same for every MCP tool so unnecessary to pass along the wire every time), + # so we just return the ToolDefinitions and wrap them in ToolsetTool outside of the activity. + return {name: tool.tool_def for name, tool in tools.items()} self.get_tools_activity = get_tools_activity @@ -62,7 +65,7 @@ async def call_tool_activity(params: _CallToolParams) -> ToolResult: params.name, params.tool_args, run_context, - self.wrapped_server._toolset_tool_for_tool_def(params.tool_def), # pyright: ignore[reportPrivateUsage] + self.wrapped_server.tool_for_tool_def(params.tool_def), ) self.call_tool_activity = call_tool_activity @@ -86,10 +89,7 @@ async def get_tools(self, ctx: RunContext[Any]) -> dict[str, ToolsetTool[Any]]: arg=_GetToolsParams(serialized_run_context=serialized_run_context), **self.activity_config, ) - return { - name: self.wrapped_server._toolset_tool_for_tool_def(tool_def) # pyright: ignore[reportPrivateUsage] - for name, tool_def in tool_defs.items() - } + return {name: self.wrapped_server.tool_for_tool_def(tool_def) for name, tool_def in tool_defs.items()} async def call_tool( self, diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_toolset.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_toolset.py index 32f9ef584..c568ef793 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_toolset.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_toolset.py @@ -20,11 +20,11 @@ def temporal_activities(self) -> list[Callable[..., Any]]: def temporalize_toolset( - toolset: AbstractToolset, + toolset: AbstractToolset[Any], activity_config: ActivityConfig = {}, tool_activity_config: dict[str, ActivityConfig | Literal[False]] = {}, run_context_type: type[TemporalRunContext] = TemporalRunContext, -) -> AbstractToolset: +) -> AbstractToolset[Any]: """Temporalize a toolset. Args: diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index 82e9976c8..b3426ed36 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -207,7 +207,7 @@ async def call_tool( async def get_tools(self, ctx: RunContext[Any]) -> dict[str, ToolsetTool[Any]]: return { - name: self._toolset_tool_for_tool_def( + name: self.tool_for_tool_def( ToolDefinition( name=name, description=mcp_tool.description, @@ -218,7 +218,7 @@ async def get_tools(self, ctx: RunContext[Any]) -> dict[str, ToolsetTool[Any]]: if (name := f'{self.tool_prefix}_{mcp_tool.name}' if self.tool_prefix else mcp_tool.name) } - def _toolset_tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[Any]: + def tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[Any]: return ToolsetTool( toolset=self, tool_def=tool_def, From 295a69b4e711cc68e3caebfbe2fa348c38665e66 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 1 Aug 2025 00:11:16 +0000 Subject: [PATCH 22/41] Add AbstractAgent, WrapperAgent, TemporalAgent instead of temporalize_agent with monkeypatching --- pydantic_ai_slim/pydantic_ai/_a2a.py | 10 +- pydantic_ai_slim/pydantic_ai/_cli.py | 6 +- pydantic_ai_slim/pydantic_ai/ag_ui.py | 6 +- pydantic_ai_slim/pydantic_ai/agent.py | 2391 ++++++++++------- .../pydantic_ai/ext/temporal/__init__.py | 15 +- .../pydantic_ai/ext/temporal/_agent.py | 396 ++- .../pydantic_ai/ext/temporal/_run_context.py | 4 +- temporal.py | 8 +- 8 files changed, 1741 insertions(+), 1095 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_a2a.py b/pydantic_ai_slim/pydantic_ai/_a2a.py index 8a916b5cc..101d3f2e9 100644 --- a/pydantic_ai_slim/pydantic_ai/_a2a.py +++ b/pydantic_ai_slim/pydantic_ai/_a2a.py @@ -27,7 +27,7 @@ VideoUrl, ) -from .agent import Agent, AgentDepsT, OutputDataT +from .agent import AbstractAgent, AgentDepsT, OutputDataT # AgentWorker output type needs to be invariant for use in both parameter and return positions WorkerOutputT = TypeVar('WorkerOutputT') @@ -59,7 +59,9 @@ @asynccontextmanager -async def worker_lifespan(app: FastA2A, worker: Worker, agent: Agent[AgentDepsT, OutputDataT]) -> AsyncIterator[None]: +async def worker_lifespan( + app: FastA2A, worker: Worker, agent: AbstractAgent[AgentDepsT, OutputDataT] +) -> AsyncIterator[None]: """Custom lifespan that runs the worker during application startup. This ensures the worker is started and ready to process tasks as soon as the application starts. @@ -70,7 +72,7 @@ async def worker_lifespan(app: FastA2A, worker: Worker, agent: Agent[AgentDepsT, def agent_to_a2a( - agent: Agent[AgentDepsT, OutputDataT], + agent: AbstractAgent[AgentDepsT, OutputDataT], *, storage: Storage | None = None, broker: Broker | None = None, @@ -116,7 +118,7 @@ def agent_to_a2a( class AgentWorker(Worker[list[ModelMessage]], Generic[WorkerOutputT, AgentDepsT]): """A worker that uses an agent to execute tasks.""" - agent: Agent[AgentDepsT, WorkerOutputT] + agent: AbstractAgent[AgentDepsT, WorkerOutputT] async def run_task(self, params: TaskSendParams) -> None: task = await self.storage.load_task(params['id']) diff --git a/pydantic_ai_slim/pydantic_ai/_cli.py b/pydantic_ai_slim/pydantic_ai/_cli.py index ae4f6ff6f..3f596405a 100644 --- a/pydantic_ai_slim/pydantic_ai/_cli.py +++ b/pydantic_ai_slim/pydantic_ai/_cli.py @@ -16,7 +16,7 @@ from . import __version__ from ._run_context import AgentDepsT -from .agent import Agent +from .agent import AbstractAgent, Agent from .exceptions import UserError from .messages import ModelMessage from .models import KnownModelName, infer_model @@ -220,7 +220,7 @@ def cli( # noqa: C901 async def run_chat( stream: bool, - agent: Agent[AgentDepsT, OutputDataT], + agent: AbstractAgent[AgentDepsT, OutputDataT], console: Console, code_theme: str, prog_name: str, @@ -263,7 +263,7 @@ async def run_chat( async def ask_agent( - agent: Agent[AgentDepsT, OutputDataT], + agent: AbstractAgent[AgentDepsT, OutputDataT], prompt: str, stream: bool, console: Console, diff --git a/pydantic_ai_slim/pydantic_ai/ag_ui.py b/pydantic_ai_slim/pydantic_ai/ag_ui.py index d0a23baa6..e4fcb4d81 100644 --- a/pydantic_ai_slim/pydantic_ai/ag_ui.py +++ b/pydantic_ai_slim/pydantic_ai/ag_ui.py @@ -72,7 +72,7 @@ from pydantic import BaseModel, ValidationError from ._agent_graph import CallToolsNode, ModelRequestNode -from .agent import Agent, AgentRun, RunOutputDataT +from .agent import AbstractAgent, AgentRun, RunOutputDataT from .messages import ( AgentStreamEvent, FunctionToolResultEvent, @@ -115,7 +115,7 @@ class AGUIApp(Generic[AgentDepsT, OutputDataT], Starlette): def __init__( self, - agent: Agent[AgentDepsT, OutputDataT], + agent: AbstractAgent[AgentDepsT, OutputDataT], *, # Agent.iter parameters. output_type: OutputSpec[OutputDataT] | None = None, @@ -223,7 +223,7 @@ class _Adapter(Generic[AgentDepsT, OutputDataT]): agent: The Pydantic AI `Agent` to adapt. """ - agent: Agent[AgentDepsT, OutputDataT] = field(repr=False) + agent: AbstractAgent[AgentDepsT, OutputDataT] = field(repr=False) async def run( self, diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index b4c144fe1..c80a328ed 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -4,6 +4,7 @@ import inspect import json import warnings +from abc import ABC, abstractmethod from asyncio import Lock from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterator, Mapping, Sequence from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager @@ -105,397 +106,197 @@ ] -@dataclasses.dataclass(init=False) -class Agent(Generic[AgentDepsT, OutputDataT]): - """Class for defining "agents" - a way to have a specific type of "conversation" with an LLM. +class AbstractAgent(Generic[AgentDepsT, OutputDataT], ABC): + @property + @abstractmethod + def model(self) -> models.Model | models.KnownModelName | str | None: + raise NotImplementedError - Agents are generic in the dependency type they take [`AgentDepsT`][pydantic_ai.tools.AgentDepsT] - and the output type they return, [`OutputDataT`][pydantic_ai.output.OutputDataT]. + @property + @abstractmethod + def name(self) -> str | None: + raise NotImplementedError - By default, if neither generic parameter is customised, agents have type `Agent[None, str]`. + @name.setter + @abstractmethod + def name(self, value: str | None) -> None: + raise NotImplementedError - Minimal usage example: + @property + @abstractmethod + def output_type(self) -> OutputSpec[OutputDataT]: + raise NotImplementedError - ```python - from pydantic_ai import Agent + @property + @abstractmethod + def event_stream_handler(self) -> EventStreamHandler[AgentDepsT] | None: + raise NotImplementedError - agent = Agent('openai:gpt-4o') - result = agent.run_sync('What is the capital of France?') - print(result.output) - #> Paris - ``` - """ + @overload + async def run( + self, + user_prompt: str | Sequence[_messages.UserContent] | None = None, + *, + output_type: None = None, + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + ) -> AgentRunResult[OutputDataT]: ... - model: models.Model | models.KnownModelName | str | None - """The default model configured for this agent. + @overload + async def run( + self, + user_prompt: str | Sequence[_messages.UserContent] | None = None, + *, + output_type: OutputSpec[RunOutputDataT], + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + ) -> AgentRunResult[RunOutputDataT]: ... - We allow `str` here since the actual list of allowed models changes frequently. - """ + @overload + @deprecated('`result_type` is deprecated, use `output_type` instead.') + async def run( + self, + user_prompt: str | Sequence[_messages.UserContent] | None = None, + *, + result_type: type[RunOutputDataT], + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + ) -> AgentRunResult[RunOutputDataT]: ... - name: str | None - """The name of the agent, used for logging. + async def run( + self, + user_prompt: str | Sequence[_messages.UserContent] | None = None, + *, + output_type: OutputSpec[RunOutputDataT] | None = None, + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + **_deprecated_kwargs: Never, + ) -> AgentRunResult[Any]: + """Run the agent with a user prompt in async mode. - If `None`, we try to infer the agent name from the call frame when the agent is first run. - """ - end_strategy: EndStrategy - """Strategy for handling tool calls when a final result is found.""" + This method builds an internal agent graph (using system prompts, tools and result schemas) and then + runs the graph to completion. The result of the run is returned. - model_settings: ModelSettings | None - """Optional model request settings to use for this agents's runs, by default. + Example: + ```python + from pydantic_ai import Agent - Note, if `model_settings` is provided by `run`, `run_sync`, or `run_stream`, those settings will - be merged with this value, with the runtime argument taking priority. - """ + agent = Agent('openai:gpt-4o') - output_type: OutputSpec[OutputDataT] - """ - The type of data output by agent runs, used to validate the data returned by the model, defaults to `str`. - """ + async def main(): + agent_run = await agent.run('What is the capital of France?') + print(agent_run.output) + #> Paris + ``` - instrument: InstrumentationSettings | bool | None - """Options to automatically instrument with OpenTelemetry.""" + Args: + user_prompt: User input to start/continue the conversation. + output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no + output validators since output validators would expect an argument that matches the agent's output type. + message_history: History of the conversation so far. + model: Optional model to use for this run, required if `model` was not set when creating the agent. + deps: Optional dependencies to use for this run. + model_settings: Optional settings to use for this model's request. + usage_limits: Optional limits on model request count or token usage. + usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. + infer_name: Whether to try to infer the agent name from the call frame if it's not set. + toolsets: Optional additional toolsets for this run. - event_stream_handler: EventStreamHandler[AgentDepsT] | None - """Optional handler for events from the agent stream.""" + Returns: + The result of the run. + """ + if infer_name and self.name is None: + self._infer_name(inspect.currentframe()) - _instrument_default: ClassVar[InstrumentationSettings | bool] = False + if 'result_type' in _deprecated_kwargs: # pragma: no cover + if output_type is not str: + raise TypeError('`result_type` and `output_type` cannot be set at the same time.') + warnings.warn('`result_type` is deprecated, use `output_type` instead.', DeprecationWarning, stacklevel=2) + output_type = _deprecated_kwargs.pop('result_type') - _deps_type: type[AgentDepsT] = dataclasses.field(repr=False) - _deprecated_result_tool_name: str | None = dataclasses.field(repr=False) - _deprecated_result_tool_description: str | None = dataclasses.field(repr=False) - _output_schema: _output.BaseOutputSchema[OutputDataT] = dataclasses.field(repr=False) - _output_validators: list[_output.OutputValidator[AgentDepsT, OutputDataT]] = dataclasses.field(repr=False) - _instructions: str | None = dataclasses.field(repr=False) - _instructions_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(repr=False) - _system_prompts: tuple[str, ...] = dataclasses.field(repr=False) - _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(repr=False) - _system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field( - repr=False - ) - _function_toolset: FunctionToolset[AgentDepsT] = dataclasses.field(repr=False) - _output_toolset: OutputToolset[AgentDepsT] | None = dataclasses.field(repr=False) - _user_toolsets: Sequence[AbstractToolset[AgentDepsT]] = dataclasses.field(repr=False) - _prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False) - _prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False) - _max_result_retries: int = dataclasses.field(repr=False) + _utils.validate_empty_kwargs(_deprecated_kwargs) - _enter_lock: Lock = dataclasses.field(repr=False) - _entered_count: int = dataclasses.field(repr=False) - _exit_stack: AsyncExitStack | None = dataclasses.field(repr=False) + async with self.iter( + user_prompt=user_prompt, + output_type=output_type, + message_history=message_history, + model=model, + deps=deps, + model_settings=model_settings, + usage_limits=usage_limits, + usage=usage, + toolsets=toolsets, + ) as agent_run: + async for node in agent_run: + if self.event_stream_handler is not None and ( + self.is_model_request_node(node) or self.is_call_tools_node(node) + ): + async with node.stream(agent_run.ctx) as stream: + await self.event_stream_handler(_agent_graph.build_run_context(agent_run.ctx), stream) + + assert agent_run.result is not None, 'The graph run did not finish properly' + return agent_run.result @overload - def __init__( + def run_sync( self, - model: models.Model | models.KnownModelName | str | None = None, + user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[OutputDataT] = str, - instructions: str - | _system_prompt.SystemPromptFunc[AgentDepsT] - | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]] - | None = None, - system_prompt: str | Sequence[str] = (), - deps_type: type[AgentDepsT] = NoneType, - name: str | None = None, + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, model_settings: ModelSettings | None = None, - retries: int = 1, - output_retries: int | None = None, - tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), - prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, - prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - defer_model_check: bool = False, - end_strategy: EndStrategy = 'early', - instrument: InstrumentationSettings | bool | None = None, - history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, - ) -> None: ... + ) -> AgentRunResult[OutputDataT]: ... @overload - @deprecated( - '`result_type`, `result_tool_name` & `result_tool_description` are deprecated, use `output_type` instead. `result_retries` is deprecated, use `output_retries` instead.' - ) - def __init__( + def run_sync( self, - model: models.Model | models.KnownModelName | str | None = None, + user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - result_type: type[OutputDataT] = str, - instructions: str - | _system_prompt.SystemPromptFunc[AgentDepsT] - | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]] - | None = None, - system_prompt: str | Sequence[str] = (), - deps_type: type[AgentDepsT] = NoneType, - name: str | None = None, + output_type: OutputSpec[RunOutputDataT] | None = None, + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, model_settings: ModelSettings | None = None, - retries: int = 1, - result_tool_name: str = _output.DEFAULT_OUTPUT_TOOL_NAME, - result_tool_description: str | None = None, - result_retries: int | None = None, - tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), - prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, - prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - defer_model_check: bool = False, - end_strategy: EndStrategy = 'early', - instrument: InstrumentationSettings | bool | None = None, - history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, - ) -> None: ... + ) -> AgentRunResult[RunOutputDataT]: ... @overload - @deprecated('`mcp_servers` is deprecated, use `toolsets` instead.') - def __init__( + @deprecated('`result_type` is deprecated, use `output_type` instead.') + def run_sync( self, - model: models.Model | models.KnownModelName | str | None = None, - *, - result_type: type[OutputDataT] = str, - instructions: str - | _system_prompt.SystemPromptFunc[AgentDepsT] - | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]] - | None = None, - system_prompt: str | Sequence[str] = (), - deps_type: type[AgentDepsT] = NoneType, - name: str | None = None, - model_settings: ModelSettings | None = None, - retries: int = 1, - result_tool_name: str = _output.DEFAULT_OUTPUT_TOOL_NAME, - result_tool_description: str | None = None, - result_retries: int | None = None, - tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), - prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, - prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None, - mcp_servers: Sequence[MCPServer] = (), - defer_model_check: bool = False, - end_strategy: EndStrategy = 'early', - instrument: InstrumentationSettings | bool | None = None, - history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, - ) -> None: ... - - def __init__( - self, - model: models.Model | models.KnownModelName | str | None = None, - *, - # TODO change this back to `output_type: _output.OutputType[OutputDataT] = str,` when we remove the overloads - output_type: Any = str, - instructions: str - | _system_prompt.SystemPromptFunc[AgentDepsT] - | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]] - | None = None, - system_prompt: str | Sequence[str] = (), - deps_type: type[AgentDepsT] = NoneType, - name: str | None = None, - model_settings: ModelSettings | None = None, - retries: int = 1, - output_retries: int | None = None, - tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), - prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, - prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - defer_model_check: bool = False, - end_strategy: EndStrategy = 'early', - instrument: InstrumentationSettings | bool | None = None, - history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, - **_deprecated_kwargs: Any, - ): - """Create an agent. - - Args: - model: The default model to use for this agent, if not provide, - you must provide the model when calling it. We allow `str` here since the actual list of allowed models changes frequently. - output_type: The type of the output data, used to validate the data returned by the model, - defaults to `str`. - instructions: Instructions to use for this agent, you can also register instructions via a function with - [`instructions`][pydantic_ai.Agent.instructions]. - system_prompt: Static system prompts to use for this agent, you can also register system - prompts via a function with [`system_prompt`][pydantic_ai.Agent.system_prompt]. - deps_type: The type used for dependency injection, this parameter exists solely to allow you to fully - parameterize the agent, and therefore get the best out of static type checking. - If you're not using deps, but want type checking to pass, you can set `deps=None` to satisfy Pyright - or add a type hint `: Agent[None, ]`. - name: The name of the agent, used for logging. If `None`, we try to infer the agent name from the call frame - when the agent is first run. - model_settings: Optional model request settings to use for this agent's runs, by default. - retries: The default number of retries to allow before raising an error. - output_retries: The maximum number of retries to allow for output validation, defaults to `retries`. - tools: Tools to register with the agent, you can also register tools via the decorators - [`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain]. - prepare_tools: Custom function to prepare the tool definition of all tools for each step, except output tools. - This is useful if you want to customize the definition of multiple tools or you want to register - a subset of tools for a given step. See [`ToolsPrepareFunc`][pydantic_ai.tools.ToolsPrepareFunc] - prepare_output_tools: Custom function to prepare the tool definition of all output tools for each step. - This is useful if you want to customize the definition of multiple output tools or you want to register - a subset of output tools for a given step. See [`ToolsPrepareFunc`][pydantic_ai.tools.ToolsPrepareFunc] - toolsets: Toolsets to register with the agent, including MCP servers. - defer_model_check: by default, if you provide a [named][pydantic_ai.models.KnownModelName] model, - it's evaluated to create a [`Model`][pydantic_ai.models.Model] instance immediately, - which checks for the necessary environment variables. Set this to `false` - to defer the evaluation until the first run. Useful if you want to - [override the model][pydantic_ai.Agent.override] for testing. - end_strategy: Strategy for handling tool calls that are requested alongside a final result. - See [`EndStrategy`][pydantic_ai.agent.EndStrategy] for more information. - instrument: Set to True to automatically instrument with OpenTelemetry, - which will use Logfire if it's configured. - Set to an instance of [`InstrumentationSettings`][pydantic_ai.agent.InstrumentationSettings] to customize. - If this isn't set, then the last value set by - [`Agent.instrument_all()`][pydantic_ai.Agent.instrument_all] - will be used, which defaults to False. - See the [Debugging and Monitoring guide](https://ai.pydantic.dev/logfire/) for more info. - history_processors: Optional list of callables to process the message history before sending it to the model. - Each processor takes a list of messages and returns a modified list of messages. - Processors can be sync or async and are applied in sequence. - event_stream_handler: TODO: Optional handler for events from the agent stream. - """ - if model is None or defer_model_check: - self.model = model - else: - self.model = models.infer_model(model) - - self.end_strategy = end_strategy - self.name = name - self.model_settings = model_settings - - if 'result_type' in _deprecated_kwargs: - if output_type is not str: # pragma: no cover - raise TypeError('`result_type` and `output_type` cannot be set at the same time.') - warnings.warn('`result_type` is deprecated, use `output_type` instead', DeprecationWarning, stacklevel=2) - output_type = _deprecated_kwargs.pop('result_type') - - self.output_type = output_type - - self.instrument = instrument - - self._deps_type = deps_type - - self._deprecated_result_tool_name = _deprecated_kwargs.pop('result_tool_name', None) - if self._deprecated_result_tool_name is not None: - warnings.warn( - '`result_tool_name` is deprecated, use `output_type` with `ToolOutput` instead', - DeprecationWarning, - stacklevel=2, - ) - - self._deprecated_result_tool_description = _deprecated_kwargs.pop('result_tool_description', None) - if self._deprecated_result_tool_description is not None: - warnings.warn( - '`result_tool_description` is deprecated, use `output_type` with `ToolOutput` instead', - DeprecationWarning, - stacklevel=2, - ) - result_retries = _deprecated_kwargs.pop('result_retries', None) - if result_retries is not None: - if output_retries is not None: # pragma: no cover - raise TypeError('`output_retries` and `result_retries` cannot be set at the same time.') - warnings.warn( - '`result_retries` is deprecated, use `max_result_retries` instead', DeprecationWarning, stacklevel=2 - ) - output_retries = result_retries - - if mcp_servers := _deprecated_kwargs.pop('mcp_servers', None): - if toolsets is not None: # pragma: no cover - raise TypeError('`mcp_servers` and `toolsets` cannot be set at the same time.') - warnings.warn('`mcp_servers` is deprecated, use `toolsets` instead', DeprecationWarning) - toolsets = mcp_servers - - _utils.validate_empty_kwargs(_deprecated_kwargs) - - default_output_mode = ( - self.model.profile.default_structured_output_mode if isinstance(self.model, models.Model) else None - ) - - self._output_schema = _output.OutputSchema[OutputDataT].build( - output_type, - default_mode=default_output_mode, - name=self._deprecated_result_tool_name, - description=self._deprecated_result_tool_description, - ) - self._output_validators = [] - - self._instructions = '' - self._instructions_functions = [] - if isinstance(instructions, (str, Callable)): - instructions = [instructions] - for instruction in instructions or []: - if isinstance(instruction, str): - self._instructions += instruction + '\n' - else: - self._instructions_functions.append(_system_prompt.SystemPromptRunner(instruction)) - self._instructions = self._instructions.strip() or None - - self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt) - self._system_prompt_functions = [] - self._system_prompt_dynamic_functions = {} - - self._max_result_retries = output_retries if output_retries is not None else retries - self._prepare_tools = prepare_tools - self._prepare_output_tools = prepare_output_tools - - self._output_toolset = self._output_schema.toolset - if self._output_toolset: - self._output_toolset.max_retries = self._max_result_retries - - self._function_toolset = FunctionToolset(tools, max_retries=retries, id='agent') - self._user_toolsets = toolsets or () - - self.history_processors = history_processors or [] - - self.event_stream_handler = event_stream_handler - - self._override_deps: ContextVar[_utils.Option[AgentDepsT]] = ContextVar('_override_deps', default=None) - self._override_model: ContextVar[_utils.Option[models.Model]] = ContextVar('_override_model', default=None) - self._override_toolsets: ContextVar[_utils.Option[Sequence[AbstractToolset[AgentDepsT]]]] = ContextVar( - '_override_toolsets', default=None - ) - self._override_tools: ContextVar[ - _utils.Option[Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]]] - ] = ContextVar('_override_tools', default=None) - - self._enter_lock = _utils.get_async_lock() - self._entered_count = 0 - self._exit_stack = None - - @staticmethod - def instrument_all(instrument: InstrumentationSettings | bool = True) -> None: - """Set the instrumentation options for all agents where `instrument` is not set.""" - Agent._instrument_default = instrument - - @overload - async def run( - self, - user_prompt: str | Sequence[_messages.UserContent] | None = None, - *, - output_type: None = None, - message_history: list[_messages.ModelMessage] | None = None, - model: models.Model | models.KnownModelName | str | None = None, - deps: AgentDepsT = None, - model_settings: ModelSettings | None = None, - usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, - infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - ) -> AgentRunResult[OutputDataT]: ... - - @overload - async def run( - self, - user_prompt: str | Sequence[_messages.UserContent] | None = None, - *, - output_type: OutputSpec[RunOutputDataT], - message_history: list[_messages.ModelMessage] | None = None, - model: models.Model | models.KnownModelName | str | None = None, - deps: AgentDepsT = None, - model_settings: ModelSettings | None = None, - usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, - infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - ) -> AgentRunResult[RunOutputDataT]: ... - - @overload - @deprecated('`result_type` is deprecated, use `output_type` instead.') - async def run( - self, - user_prompt: str | Sequence[_messages.UserContent] | None = None, + user_prompt: str | Sequence[_messages.UserContent] | None = None, *, result_type: type[RunOutputDataT], message_history: list[_messages.ModelMessage] | None = None, @@ -508,7 +309,7 @@ async def run( toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AgentRunResult[RunOutputDataT]: ... - async def run( + def run_sync( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, @@ -523,10 +324,10 @@ async def run( toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, **_deprecated_kwargs: Never, ) -> AgentRunResult[Any]: - """Run the agent with a user prompt in async mode. + """Synchronously run the agent with a user prompt. - This method builds an internal agent graph (using system prompts, tools and result schemas) and then - runs the graph to completion. The result of the run is returned. + This is a convenience method that wraps [`self.run`][pydantic_ai.Agent.run] with `loop.run_until_complete(...)`. + You therefore can't use this method inside async code or if there's an active event loop. Example: ```python @@ -534,10 +335,9 @@ async def run( agent = Agent('openai:gpt-4o') - async def main(): - agent_run = await agent.run('What is the capital of France?') - print(agent_run.output) - #> Paris + result_sync = agent.run_sync('What is the capital of Italy?') + print(result_sync.output) + #> Rome ``` Args: @@ -567,26 +367,205 @@ async def main(): _utils.validate_empty_kwargs(_deprecated_kwargs) - async with self.iter( - user_prompt=user_prompt, - output_type=output_type, - message_history=message_history, - model=model, - deps=deps, - model_settings=model_settings, + return get_event_loop().run_until_complete( + self.run( + user_prompt, + output_type=output_type, + message_history=message_history, + model=model, + deps=deps, + model_settings=model_settings, + usage_limits=usage_limits, + usage=usage, + infer_name=False, + toolsets=toolsets, + ) + ) + + @overload + def run_stream( + self, + user_prompt: str | Sequence[_messages.UserContent] | None = None, + *, + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, OutputDataT]]: ... + + @overload + def run_stream( + self, + user_prompt: str | Sequence[_messages.UserContent], + *, + output_type: OutputSpec[RunOutputDataT], + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunOutputDataT]]: ... + + @overload + @deprecated('`result_type` is deprecated, use `output_type` instead.') + def run_stream( + self, + user_prompt: str | Sequence[_messages.UserContent] | None = None, + *, + result_type: type[RunOutputDataT], + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunOutputDataT]]: ... + + @asynccontextmanager + async def run_stream( # noqa C901 + self, + user_prompt: str | Sequence[_messages.UserContent] | None = None, + *, + output_type: OutputSpec[RunOutputDataT] | None = None, + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + **_deprecated_kwargs: Never, + ) -> AsyncIterator[result.StreamedRunResult[AgentDepsT, Any]]: + """Run the agent with a user prompt in async mode, returning a streamed response. + + Example: + ```python + from pydantic_ai import Agent + + agent = Agent('openai:gpt-4o') + + async def main(): + async with agent.run_stream('What is the capital of the UK?') as response: + print(await response.get_output()) + #> London + ``` + + Args: + user_prompt: User input to start/continue the conversation. + output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no + output validators since output validators would expect an argument that matches the agent's output type. + message_history: History of the conversation so far. + model: Optional model to use for this run, required if `model` was not set when creating the agent. + deps: Optional dependencies to use for this run. + model_settings: Optional settings to use for this model's request. + usage_limits: Optional limits on model request count or token usage. + usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. + infer_name: Whether to try to infer the agent name from the call frame if it's not set. + toolsets: Optional additional toolsets for this run. + + Returns: + The result of the run. + """ + # TODO: We need to deprecate this now that we have the `iter` method. + # Before that, though, we should add an event for when we reach the final result of the stream. + if infer_name and self.name is None: + # f_back because `asynccontextmanager` adds one frame + if frame := inspect.currentframe(): # pragma: no branch + self._infer_name(frame.f_back) + + if 'result_type' in _deprecated_kwargs: # pragma: no cover + if output_type is not str: + raise TypeError('`result_type` and `output_type` cannot be set at the same time.') + warnings.warn('`result_type` is deprecated, use `output_type` instead.', DeprecationWarning, stacklevel=2) + output_type = _deprecated_kwargs.pop('result_type') + + _utils.validate_empty_kwargs(_deprecated_kwargs) + + yielded = False + async with self.iter( + user_prompt, + output_type=output_type, + message_history=message_history, + model=model, + deps=deps, + model_settings=model_settings, usage_limits=usage_limits, usage=usage, + infer_name=False, toolsets=toolsets, ) as agent_run: - async for node in agent_run: - if self.event_stream_handler is not None and ( - self.is_model_request_node(node) or self.is_call_tools_node(node) - ): - async with node.stream(agent_run.ctx) as stream: - await self.event_stream_handler(_agent_graph.build_run_context(agent_run.ctx), stream) + first_node = agent_run.next_node # start with the first node + assert isinstance(first_node, _agent_graph.UserPromptNode) # the first node should be a user prompt node + node = first_node + while True: + if self.is_model_request_node(node): + graph_ctx = agent_run.ctx + async with node.stream(graph_ctx) as stream: - assert agent_run.result is not None, 'The graph run did not finish properly' - return agent_run.result + async def stream_to_final(s: AgentStream) -> FinalResult[AgentStream] | None: + async for event in stream: + if isinstance(event, _messages.FinalResultEvent): + return FinalResult(s, event.tool_name, event.tool_call_id) + return None + + final_result = await stream_to_final(stream) + if final_result is not None: + if yielded: + raise exceptions.AgentRunError('Agent run produced final results') # pragma: no cover + yielded = True + + messages = graph_ctx.state.message_history.copy() + + async def on_complete() -> None: + """Called when the stream has completed. + + The model response will have been added to messages by now + by `StreamedRunResult._marked_completed`. + """ + last_message = messages[-1] + assert isinstance(last_message, _messages.ModelResponse) + tool_calls = [ + part for part in last_message.parts if isinstance(part, _messages.ToolCallPart) + ] + + parts: list[_messages.ModelRequestPart] = [] + async for _event in _agent_graph.process_function_tools( + graph_ctx.deps.tool_manager, + tool_calls, + final_result, + graph_ctx, + parts, + ): + pass + if parts: + messages.append(_messages.ModelRequest(parts)) + + yield StreamedRunResult( + messages, + graph_ctx.deps.new_message_index, + stream, + on_complete, + ) + break + next_node = await agent_run.next(node) + if not isinstance(next_node, _agent_graph.AgentNode): + raise exceptions.AgentRunError( # pragma: no cover + 'Should have produced a StreamedRunResult before getting here' + ) + node = cast(_agent_graph.AgentNode[Any, Any], next_node) + + if not yielded: + raise exceptions.AgentRunError('Agent run finished without producing a final result') # pragma: no cover @overload def iter( @@ -640,6 +619,7 @@ def iter( ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, Any]]: ... @asynccontextmanager + @abstractmethod async def iter( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, @@ -732,315 +712,912 @@ async def main(): Returns: The result of the run. """ - if infer_name and self.name is None: - self._infer_name(inspect.currentframe()) - model_used = self._get_model(model) - del model + raise NotImplementedError + yield - if 'result_type' in _deprecated_kwargs: # pragma: no cover - if output_type is not str: - raise TypeError('`result_type` and `output_type` cannot be set at the same time.') - warnings.warn('`result_type` is deprecated, use `output_type` instead.', DeprecationWarning, stacklevel=2) - output_type = _deprecated_kwargs.pop('result_type') + @contextmanager + @abstractmethod + def override( + self, + *, + deps: AgentDepsT | _utils.Unset = _utils.UNSET, + model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET, + tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET, + ) -> Iterator[None]: + """Context manager to temporarily override agent dependencies, model, or toolsets. - _utils.validate_empty_kwargs(_deprecated_kwargs) + This is particularly useful when testing. + You can find an example of this [here](../testing.md#overriding-model-via-pytest-fixtures). - deps = self._get_deps(deps) - new_message_index = len(message_history) if message_history else 0 - output_schema = self._prepare_output_schema(output_type, model_used.profile) + Args: + deps: The dependencies to use instead of the dependencies passed to the agent run. + model: The model to use instead of the model passed to the agent run. + toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run. + tools: The tools to use instead of the tools registered with the agent. + """ + raise NotImplementedError + yield - output_type_ = output_type or self.output_type + def _infer_name(self, function_frame: FrameType | None) -> None: + """Infer the agent name from the call frame. - # We consider it a user error if a user tries to restrict the result type while having an output validator that - # may change the result type from the restricted type to something else. Therefore, we consider the following - # typecast reasonable, even though it is possible to violate it with otherwise-type-checked code. - output_validators = cast(list[_output.OutputValidator[AgentDepsT, RunOutputDataT]], self._output_validators) + Usage should be `self._infer_name(inspect.currentframe())`. + """ + assert self.name is None, 'Name already set' + if function_frame is not None: # pragma: no branch + if parent_frame := function_frame.f_back: # pragma: no branch + for name, item in parent_frame.f_locals.items(): + if item is self: + self.name = name + return + if parent_frame.f_locals != parent_frame.f_globals: # pragma: no branch + # if we couldn't find the agent in locals and globals are a different dict, try globals + for name, item in parent_frame.f_globals.items(): + if item is self: + self.name = name + return - output_toolset = self._output_toolset - if output_schema != self._output_schema or output_validators: - output_toolset = cast(OutputToolset[AgentDepsT], output_schema.toolset) - if output_toolset: - output_toolset.max_retries = self._max_result_retries - output_toolset.output_validators = output_validators + @staticmethod + def is_model_request_node( + node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]], + ) -> TypeIs[_agent_graph.ModelRequestNode[T, S]]: + """Check if the node is a `ModelRequestNode`, narrowing the type if it is. - # Build the graph - graph: Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[Any]] = ( - _agent_graph.build_agent_graph(self.name, self._deps_type, output_type_) - ) + This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`. + """ + return isinstance(node, _agent_graph.ModelRequestNode) - # Build the initial state - usage = usage or _usage.Usage() - state = _agent_graph.GraphAgentState( - message_history=message_history[:] if message_history else [], - usage=usage, - retries=0, - run_step=0, - ) + @staticmethod + def is_call_tools_node( + node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]], + ) -> TypeIs[_agent_graph.CallToolsNode[T, S]]: + """Check if the node is a `CallToolsNode`, narrowing the type if it is. - if isinstance(model_used, InstrumentedModel): - instrumentation_settings = model_used.instrumentation_settings - tracer = model_used.instrumentation_settings.tracer - else: - instrumentation_settings = None - tracer = NoOpTracer() + This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`. + """ + return isinstance(node, _agent_graph.CallToolsNode) - run_context = RunContext[AgentDepsT]( + @staticmethod + def is_user_prompt_node( + node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]], + ) -> TypeIs[_agent_graph.UserPromptNode[T, S]]: + """Check if the node is a `UserPromptNode`, narrowing the type if it is. + + This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`. + """ + return isinstance(node, _agent_graph.UserPromptNode) + + @staticmethod + def is_end_node( + node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]], + ) -> TypeIs[End[result.FinalResult[S]]]: + """Check if the node is a `End`, narrowing the type if it is. + + This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`. + """ + return isinstance(node, End) + + @abstractmethod + async def __aenter__(self) -> AbstractAgent[AgentDepsT, OutputDataT]: + raise NotImplementedError + + @abstractmethod + async def __aexit__(self, *args: Any) -> bool | None: + raise NotImplementedError + + def to_ag_ui( + self, + *, + # Agent.iter parameters + output_type: OutputSpec[OutputDataT] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: UsageLimits | None = None, + usage: Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + # Starlette + debug: bool = False, + routes: Sequence[BaseRoute] | None = None, + middleware: Sequence[Middleware] | None = None, + exception_handlers: Mapping[Any, ExceptionHandler] | None = None, + on_startup: Sequence[Callable[[], Any]] | None = None, + on_shutdown: Sequence[Callable[[], Any]] | None = None, + lifespan: Lifespan[AGUIApp[AgentDepsT, OutputDataT]] | None = None, + ) -> AGUIApp[AgentDepsT, OutputDataT]: + """Convert the agent to an AG-UI application. + + This allows you to use the agent with a compatible AG-UI frontend. + + Example: + ```python + from pydantic_ai import Agent + + agent = Agent('openai:gpt-4o') + app = agent.to_ag_ui() + ``` + + The `app` is an ASGI application that can be used with any ASGI server. + + To run the application, you can use the following command: + + ```bash + uvicorn app:app --host 0.0.0.0 --port 8000 + ``` + + See [AG-UI docs](../ag-ui.md) for more information. + + Args: + output_type: Custom output type to use for this run, `output_type` may only be used if the agent has + no output validators since output validators would expect an argument that matches the agent's + output type. + model: Optional model to use for this run, required if `model` was not set when creating the agent. + deps: Optional dependencies to use for this run. + model_settings: Optional settings to use for this model's request. + usage_limits: Optional limits on model request count or token usage. + usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. + infer_name: Whether to try to infer the agent name from the call frame if it's not set. + toolsets: Optional list of toolsets to use for this agent, defaults to the agent's toolset. + + debug: Boolean indicating if debug tracebacks should be returned on errors. + routes: A list of routes to serve incoming HTTP and WebSocket requests. + middleware: A list of middleware to run for every request. A starlette application will always + automatically include two middleware classes. `ServerErrorMiddleware` is added as the very + outermost middleware, to handle any uncaught errors occurring anywhere in the entire stack. + `ExceptionMiddleware` is added as the very innermost middleware, to deal with handled + exception cases occurring in the routing or endpoints. + exception_handlers: A mapping of either integer status codes, or exception class types onto + callables which handle the exceptions. Exception handler callables should be of the form + `handler(request, exc) -> response` and may be either standard functions, or async functions. + on_startup: A list of callables to run on application startup. Startup handler callables do not + take any arguments, and may be either standard functions, or async functions. + on_shutdown: A list of callables to run on application shutdown. Shutdown handler callables do + not take any arguments, and may be either standard functions, or async functions. + lifespan: A lifespan context function, which can be used to perform startup and shutdown tasks. + This is a newer style that replaces the `on_startup` and `on_shutdown` handlers. Use one or + the other, not both. + + Returns: + An ASGI application for running Pydantic AI agents with AG-UI protocol support. + """ + from .ag_ui import AGUIApp + + return AGUIApp( + agent=self, + # Agent.iter parameters + output_type=output_type, + model=model, deps=deps, - model=model_used, + model_settings=model_settings, + usage_limits=usage_limits, usage=usage, - prompt=user_prompt, - messages=state.message_history, - tracer=tracer, - trace_include_content=instrumentation_settings is not None and instrumentation_settings.include_content, - run_step=state.run_step, + infer_name=infer_name, + toolsets=toolsets, + # Starlette + debug=debug, + routes=routes, + middleware=middleware, + exception_handlers=exception_handlers, + on_startup=on_startup, + on_shutdown=on_shutdown, + lifespan=lifespan, ) - toolset = self._get_toolset(output_toolset=output_toolset, additional_toolsets=toolsets) - # This will raise errors for any name conflicts - tool_manager = await ToolManager[AgentDepsT].build(toolset, run_context) + def to_a2a( + self, + *, + storage: Storage | None = None, + broker: Broker | None = None, + # Agent card + name: str | None = None, + url: str = 'http://localhost:8000', + version: str = '1.0.0', + description: str | None = None, + provider: AgentProvider | None = None, + skills: list[Skill] | None = None, + # Starlette + debug: bool = False, + routes: Sequence[Route] | None = None, + middleware: Sequence[Middleware] | None = None, + exception_handlers: dict[Any, ExceptionHandler] | None = None, + lifespan: Lifespan[FastA2A] | None = None, + ) -> FastA2A: + """Convert the agent to a FastA2A application. - # Merge model settings in order of precedence: run > agent > model - merged_settings = merge_model_settings(model_used.settings, self.model_settings) - model_settings = merge_model_settings(merged_settings, model_settings) - usage_limits = usage_limits or _usage.UsageLimits() - agent_name = self.name or 'agent' - run_span = tracer.start_span( - 'agent run', - attributes={ - 'model_name': model_used.model_name if model_used else 'no-model', - 'agent_name': agent_name, - 'logfire.msg': f'{agent_name} run', - }, + Example: + ```python + from pydantic_ai import Agent + + agent = Agent('openai:gpt-4o') + app = agent.to_a2a() + ``` + + The `app` is an ASGI application that can be used with any ASGI server. + + To run the application, you can use the following command: + + ```bash + uvicorn app:app --host 0.0.0.0 --port 8000 + ``` + """ + from ._a2a import agent_to_a2a + + return agent_to_a2a( + self, + storage=storage, + broker=broker, + name=name, + url=url, + version=version, + description=description, + provider=provider, + skills=skills, + debug=debug, + routes=routes, + middleware=middleware, + exception_handlers=exception_handlers, + lifespan=lifespan, ) - async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: - parts = [ - self._instructions, - *[await func.run(run_context) for func in self._instructions_functions], - ] + async def to_cli(self: Self, deps: AgentDepsT = None, prog_name: str = 'pydantic-ai') -> None: + """Run the agent in a CLI chat interface. - model_profile = model_used.profile - if isinstance(output_schema, _output.PromptedOutputSchema): - instructions = output_schema.instructions(model_profile.prompted_output_template) - parts.append(instructions) + Args: + deps: The dependencies to pass to the agent. + prog_name: The name of the program to use for the CLI. Defaults to 'pydantic-ai'. - parts = [p for p in parts if p] - if not parts: - return None - return '\n\n'.join(parts).strip() + Example: + ```python {title="agent_to_cli.py" test="skip"} + from pydantic_ai import Agent - graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT]( - user_deps=deps, - prompt=user_prompt, - new_message_index=new_message_index, - model=model_used, + agent = Agent('openai:gpt-4o', instructions='You always respond in Italian.') + + async def main(): + await agent.to_cli() + ``` + """ + from rich.console import Console + + from pydantic_ai._cli import run_chat + + await run_chat(stream=True, agent=self, deps=deps, console=Console(), code_theme='monokai', prog_name=prog_name) + + def to_cli_sync(self: Self, deps: AgentDepsT = None, prog_name: str = 'pydantic-ai') -> None: + """Run the agent in a CLI chat interface with the non-async interface. + + Args: + deps: The dependencies to pass to the agent. + prog_name: The name of the program to use for the CLI. Defaults to 'pydantic-ai'. + + ```python {title="agent_to_cli_sync.py" test="skip"} + from pydantic_ai import Agent + + agent = Agent('openai:gpt-4o', instructions='You always respond in Italian.') + agent.to_cli_sync() + agent.to_cli_sync(prog_name='assistant') + ``` + """ + return get_event_loop().run_until_complete(self.to_cli(deps=deps, prog_name=prog_name)) + + +class WrapperAgent(AbstractAgent[AgentDepsT, OutputDataT]): + def __init__(self, wrapped: AbstractAgent[AgentDepsT, OutputDataT]): + self.wrapped = wrapped + + @property + def model(self) -> models.Model | models.KnownModelName | str | None: + return self.wrapped.model + + @property + def name(self) -> str | None: + return self.wrapped.name + + @name.setter + def name(self, value: str | None) -> None: + self.wrapped.name = value + + @property + def output_type(self) -> OutputSpec[OutputDataT]: + return self.wrapped.output_type + + @property + def event_stream_handler(self) -> EventStreamHandler[AgentDepsT] | None: + return self.wrapped.event_stream_handler + + async def __aenter__(self) -> AbstractAgent[AgentDepsT, OutputDataT]: + return await self.wrapped.__aenter__() + + async def __aexit__(self, *args: Any) -> bool | None: + return await self.wrapped.__aexit__(*args) + + def __getattr__(self, name: str) -> Any: + return getattr(self._agent, name) + + @overload + def iter( + self, + user_prompt: str | Sequence[_messages.UserContent] | None = None, + *, + output_type: None = None, + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + **_deprecated_kwargs: Never, + ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, OutputDataT]]: ... + + @overload + def iter( + self, + user_prompt: str | Sequence[_messages.UserContent] | None = None, + *, + output_type: OutputSpec[RunOutputDataT], + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + **_deprecated_kwargs: Never, + ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ... + + @overload + @deprecated('`result_type` is deprecated, use `output_type` instead.') + def iter( + self, + user_prompt: str | Sequence[_messages.UserContent] | None = None, + *, + result_type: type[RunOutputDataT], + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, Any]]: ... + + @asynccontextmanager + async def iter( + self, + user_prompt: str | Sequence[_messages.UserContent] | None = None, + *, + output_type: OutputSpec[RunOutputDataT] | None = None, + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + **_deprecated_kwargs: Never, + ) -> AsyncIterator[AgentRun[AgentDepsT, Any]]: + """A contextmanager which can be used to iterate over the agent graph's nodes as they are executed. + + This method builds an internal agent graph (using system prompts, tools and output schemas) and then returns an + `AgentRun` object. The `AgentRun` can be used to async-iterate over the nodes of the graph as they are + executed. This is the API to use if you want to consume the outputs coming from each LLM model response, or the + stream of events coming from the execution of tools. + + The `AgentRun` also provides methods to access the full message history, new messages, and usage statistics, + and the final result of the run once it has completed. + + For more details, see the documentation of `AgentRun`. + + Example: + ```python + from pydantic_ai import Agent + + agent = Agent('openai:gpt-4o') + + async def main(): + nodes = [] + async with agent.iter('What is the capital of France?') as agent_run: + async for node in agent_run: + nodes.append(node) + print(nodes) + ''' + [ + UserPromptNode( + user_prompt='What is the capital of France?', + instructions=None, + instructions_functions=[], + system_prompts=(), + system_prompt_functions=[], + system_prompt_dynamic_functions={}, + ), + ModelRequestNode( + request=ModelRequest( + parts=[ + UserPromptPart( + content='What is the capital of France?', + timestamp=datetime.datetime(...), + ) + ] + ) + ), + CallToolsNode( + model_response=ModelResponse( + parts=[TextPart(content='Paris')], + usage=Usage( + requests=1, request_tokens=56, response_tokens=1, total_tokens=57 + ), + model_name='gpt-4o', + timestamp=datetime.datetime(...), + ) + ), + End(data=FinalResult(output='Paris')), + ] + ''' + print(agent_run.result.output) + #> Paris + ``` + + Args: + user_prompt: User input to start/continue the conversation. + output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no + output validators since output validators would expect an argument that matches the agent's output type. + message_history: History of the conversation so far. + model: Optional model to use for this run, required if `model` was not set when creating the agent. + deps: Optional dependencies to use for this run. + model_settings: Optional settings to use for this model's request. + usage_limits: Optional limits on model request count or token usage. + usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. + infer_name: Whether to try to infer the agent name from the call frame if it's not set. + toolsets: Optional additional toolsets for this run. + + Returns: + The result of the run. + """ + async with self.wrapped.iter( + user_prompt=user_prompt, + output_type=output_type, + message_history=message_history, + model=model, + deps=deps, model_settings=model_settings, usage_limits=usage_limits, - max_result_retries=self._max_result_retries, - end_strategy=self.end_strategy, - output_schema=output_schema, - output_validators=output_validators, - history_processors=self.history_processors, - tool_manager=tool_manager, - tracer=tracer, - get_instructions=get_instructions, - instrumentation_settings=instrumentation_settings, + usage=usage, + infer_name=infer_name, + toolsets=toolsets, + ) as result: + yield result + + @contextmanager + def override( + self, + *, + deps: AgentDepsT | _utils.Unset = _utils.UNSET, + model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET, + tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET, + ) -> Iterator[None]: + """Context manager to temporarily override agent dependencies, model, or toolsets. + + This is particularly useful when testing. + You can find an example of this [here](../testing.md#overriding-model-via-pytest-fixtures). + + Args: + deps: The dependencies to use instead of the dependencies passed to the agent run. + model: The model to use instead of the model passed to the agent run. + toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run. + tools: The tools to use instead of the tools registered with the agent. + """ + with self.wrapped.override(deps=deps, model=model, toolsets=toolsets, tools=tools): + yield + + +@dataclasses.dataclass(init=False) +class Agent(AbstractAgent[AgentDepsT, OutputDataT]): + """Class for defining "agents" - a way to have a specific type of "conversation" with an LLM. + + Agents are generic in the dependency type they take [`AgentDepsT`][pydantic_ai.tools.AgentDepsT] + and the output type they return, [`OutputDataT`][pydantic_ai.output.OutputDataT]. + + By default, if neither generic parameter is customised, agents have type `Agent[None, str]`. + + Minimal usage example: + + ```python + from pydantic_ai import Agent + + agent = Agent('openai:gpt-4o') + result = agent.run_sync('What is the capital of France?') + print(result.output) + #> Paris + ``` + """ + + _model: models.Model | models.KnownModelName | str | None + """The default model configured for this agent. + + We allow `str` here since the actual list of allowed models changes frequently. + """ + + _name: str | None + """The name of the agent, used for logging. + + If `None`, we try to infer the agent name from the call frame when the agent is first run. + """ + end_strategy: EndStrategy + """Strategy for handling tool calls when a final result is found.""" + + model_settings: ModelSettings | None + """Optional model request settings to use for this agents's runs, by default. + + Note, if `model_settings` is provided by `run`, `run_sync`, or `run_stream`, those settings will + be merged with this value, with the runtime argument taking priority. + """ + + _output_type: OutputSpec[OutputDataT] + """ + The type of data output by agent runs, used to validate the data returned by the model, defaults to `str`. + """ + + instrument: InstrumentationSettings | bool | None + """Options to automatically instrument with OpenTelemetry.""" + + _instrument_default: ClassVar[InstrumentationSettings | bool] = False + + _deps_type: type[AgentDepsT] = dataclasses.field(repr=False) + _deprecated_result_tool_name: str | None = dataclasses.field(repr=False) + _deprecated_result_tool_description: str | None = dataclasses.field(repr=False) + _output_schema: _output.BaseOutputSchema[OutputDataT] = dataclasses.field(repr=False) + _output_validators: list[_output.OutputValidator[AgentDepsT, OutputDataT]] = dataclasses.field(repr=False) + _instructions: str | None = dataclasses.field(repr=False) + _instructions_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(repr=False) + _system_prompts: tuple[str, ...] = dataclasses.field(repr=False) + _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(repr=False) + _system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field( + repr=False + ) + _function_toolset: FunctionToolset[AgentDepsT] = dataclasses.field(repr=False) + _output_toolset: OutputToolset[AgentDepsT] | None = dataclasses.field(repr=False) + _user_toolsets: Sequence[AbstractToolset[AgentDepsT]] = dataclasses.field(repr=False) + _prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False) + _prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False) + _max_result_retries: int = dataclasses.field(repr=False) + + _event_stream_handler: EventStreamHandler[AgentDepsT] | None = dataclasses.field(repr=False) + + _enter_lock: Lock = dataclasses.field(repr=False) + _entered_count: int = dataclasses.field(repr=False) + _exit_stack: AsyncExitStack | None = dataclasses.field(repr=False) + + @overload + def __init__( + self, + model: models.Model | models.KnownModelName | str | None = None, + *, + output_type: OutputSpec[OutputDataT] = str, + instructions: str + | _system_prompt.SystemPromptFunc[AgentDepsT] + | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]] + | None = None, + system_prompt: str | Sequence[str] = (), + deps_type: type[AgentDepsT] = NoneType, + name: str | None = None, + model_settings: ModelSettings | None = None, + retries: int = 1, + output_retries: int | None = None, + tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), + prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, + prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + defer_model_check: bool = False, + end_strategy: EndStrategy = 'early', + instrument: InstrumentationSettings | bool | None = None, + history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + ) -> None: ... + + @overload + @deprecated( + '`result_type`, `result_tool_name` & `result_tool_description` are deprecated, use `output_type` instead. `result_retries` is deprecated, use `output_retries` instead.' + ) + def __init__( + self, + model: models.Model | models.KnownModelName | str | None = None, + *, + result_type: type[OutputDataT] = str, + instructions: str + | _system_prompt.SystemPromptFunc[AgentDepsT] + | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]] + | None = None, + system_prompt: str | Sequence[str] = (), + deps_type: type[AgentDepsT] = NoneType, + name: str | None = None, + model_settings: ModelSettings | None = None, + retries: int = 1, + result_tool_name: str = _output.DEFAULT_OUTPUT_TOOL_NAME, + result_tool_description: str | None = None, + result_retries: int | None = None, + tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), + prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, + prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + defer_model_check: bool = False, + end_strategy: EndStrategy = 'early', + instrument: InstrumentationSettings | bool | None = None, + history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, + ) -> None: ... + + @overload + @deprecated('`mcp_servers` is deprecated, use `toolsets` instead.') + def __init__( + self, + model: models.Model | models.KnownModelName | str | None = None, + *, + result_type: type[OutputDataT] = str, + instructions: str + | _system_prompt.SystemPromptFunc[AgentDepsT] + | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]] + | None = None, + system_prompt: str | Sequence[str] = (), + deps_type: type[AgentDepsT] = NoneType, + name: str | None = None, + model_settings: ModelSettings | None = None, + retries: int = 1, + result_tool_name: str = _output.DEFAULT_OUTPUT_TOOL_NAME, + result_tool_description: str | None = None, + result_retries: int | None = None, + tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), + prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, + prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None, + mcp_servers: Sequence[MCPServer] = (), + defer_model_check: bool = False, + end_strategy: EndStrategy = 'early', + instrument: InstrumentationSettings | bool | None = None, + history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + ) -> None: ... + + def __init__( + self, + model: models.Model | models.KnownModelName | str | None = None, + *, + # TODO change this back to `output_type: _output.OutputType[OutputDataT] = str,` when we remove the overloads + output_type: Any = str, + instructions: str + | _system_prompt.SystemPromptFunc[AgentDepsT] + | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]] + | None = None, + system_prompt: str | Sequence[str] = (), + deps_type: type[AgentDepsT] = NoneType, + name: str | None = None, + model_settings: ModelSettings | None = None, + retries: int = 1, + output_retries: int | None = None, + tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), + prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, + prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + defer_model_check: bool = False, + end_strategy: EndStrategy = 'early', + instrument: InstrumentationSettings | bool | None = None, + history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + **_deprecated_kwargs: Any, + ): + """Create an agent. + + Args: + model: The default model to use for this agent, if not provide, + you must provide the model when calling it. We allow `str` here since the actual list of allowed models changes frequently. + output_type: The type of the output data, used to validate the data returned by the model, + defaults to `str`. + instructions: Instructions to use for this agent, you can also register instructions via a function with + [`instructions`][pydantic_ai.Agent.instructions]. + system_prompt: Static system prompts to use for this agent, you can also register system + prompts via a function with [`system_prompt`][pydantic_ai.Agent.system_prompt]. + deps_type: The type used for dependency injection, this parameter exists solely to allow you to fully + parameterize the agent, and therefore get the best out of static type checking. + If you're not using deps, but want type checking to pass, you can set `deps=None` to satisfy Pyright + or add a type hint `: Agent[None, ]`. + name: The name of the agent, used for logging. If `None`, we try to infer the agent name from the call frame + when the agent is first run. + model_settings: Optional model request settings to use for this agent's runs, by default. + retries: The default number of retries to allow before raising an error. + output_retries: The maximum number of retries to allow for output validation, defaults to `retries`. + tools: Tools to register with the agent, you can also register tools via the decorators + [`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain]. + prepare_tools: Custom function to prepare the tool definition of all tools for each step, except output tools. + This is useful if you want to customize the definition of multiple tools or you want to register + a subset of tools for a given step. See [`ToolsPrepareFunc`][pydantic_ai.tools.ToolsPrepareFunc] + prepare_output_tools: Custom function to prepare the tool definition of all output tools for each step. + This is useful if you want to customize the definition of multiple output tools or you want to register + a subset of output tools for a given step. See [`ToolsPrepareFunc`][pydantic_ai.tools.ToolsPrepareFunc] + toolsets: Toolsets to register with the agent, including MCP servers. + defer_model_check: by default, if you provide a [named][pydantic_ai.models.KnownModelName] model, + it's evaluated to create a [`Model`][pydantic_ai.models.Model] instance immediately, + which checks for the necessary environment variables. Set this to `false` + to defer the evaluation until the first run. Useful if you want to + [override the model][pydantic_ai.Agent.override] for testing. + end_strategy: Strategy for handling tool calls that are requested alongside a final result. + See [`EndStrategy`][pydantic_ai.agent.EndStrategy] for more information. + instrument: Set to True to automatically instrument with OpenTelemetry, + which will use Logfire if it's configured. + Set to an instance of [`InstrumentationSettings`][pydantic_ai.agent.InstrumentationSettings] to customize. + If this isn't set, then the last value set by + [`Agent.instrument_all()`][pydantic_ai.Agent.instrument_all] + will be used, which defaults to False. + See the [Debugging and Monitoring guide](https://ai.pydantic.dev/logfire/) for more info. + history_processors: Optional list of callables to process the message history before sending it to the model. + Each processor takes a list of messages and returns a modified list of messages. + Processors can be sync or async and are applied in sequence. + event_stream_handler: TODO: Optional handler for events from the agent stream. + """ + if model is None or defer_model_check: + self._model = model + else: + self._model = models.infer_model(model) + + self._name = name + self.end_strategy = end_strategy + self.model_settings = model_settings + + if 'result_type' in _deprecated_kwargs: + if output_type is not str: # pragma: no cover + raise TypeError('`result_type` and `output_type` cannot be set at the same time.') + warnings.warn('`result_type` is deprecated, use `output_type` instead', DeprecationWarning, stacklevel=2) + output_type = _deprecated_kwargs.pop('result_type') + + self._output_type = output_type + + self.instrument = instrument + + self._deps_type = deps_type + + self._deprecated_result_tool_name = _deprecated_kwargs.pop('result_tool_name', None) + if self._deprecated_result_tool_name is not None: + warnings.warn( + '`result_tool_name` is deprecated, use `output_type` with `ToolOutput` instead', + DeprecationWarning, + stacklevel=2, + ) + + self._deprecated_result_tool_description = _deprecated_kwargs.pop('result_tool_description', None) + if self._deprecated_result_tool_description is not None: + warnings.warn( + '`result_tool_description` is deprecated, use `output_type` with `ToolOutput` instead', + DeprecationWarning, + stacklevel=2, + ) + result_retries = _deprecated_kwargs.pop('result_retries', None) + if result_retries is not None: + if output_retries is not None: # pragma: no cover + raise TypeError('`output_retries` and `result_retries` cannot be set at the same time.') + warnings.warn( + '`result_retries` is deprecated, use `max_result_retries` instead', DeprecationWarning, stacklevel=2 + ) + output_retries = result_retries + + if mcp_servers := _deprecated_kwargs.pop('mcp_servers', None): + if toolsets is not None: # pragma: no cover + raise TypeError('`mcp_servers` and `toolsets` cannot be set at the same time.') + warnings.warn('`mcp_servers` is deprecated, use `toolsets` instead', DeprecationWarning) + toolsets = mcp_servers + + _utils.validate_empty_kwargs(_deprecated_kwargs) + + default_output_mode = ( + self.model.profile.default_structured_output_mode if isinstance(self.model, models.Model) else None ) - start_node = _agent_graph.UserPromptNode[AgentDepsT]( - user_prompt=user_prompt, - instructions=self._instructions, - instructions_functions=self._instructions_functions, - system_prompts=self._system_prompts, - system_prompt_functions=self._system_prompt_functions, - system_prompt_dynamic_functions=self._system_prompt_dynamic_functions, + + self._output_schema = _output.OutputSchema[OutputDataT].build( + output_type, + default_mode=default_output_mode, + name=self._deprecated_result_tool_name, + description=self._deprecated_result_tool_description, ) + self._output_validators = [] - try: - async with graph.iter( - start_node, - state=state, - deps=graph_deps, - span=use_span(run_span) if run_span.is_recording() else None, - infer_name=False, - ) as graph_run: - agent_run = AgentRun(graph_run) - yield agent_run - if (final_result := agent_run.result) is not None and run_span.is_recording(): - if instrumentation_settings and instrumentation_settings.include_content: - run_span.set_attribute( - 'final_result', - ( - final_result.output - if isinstance(final_result.output, str) - else json.dumps(InstrumentedModel.serialize_any(final_result.output)) - ), - ) - finally: - try: - if instrumentation_settings and run_span.is_recording(): - run_span.set_attributes(self._run_span_end_attributes(state, usage, instrumentation_settings)) - finally: - run_span.end() + self._instructions = '' + self._instructions_functions = [] + if isinstance(instructions, (str, Callable)): + instructions = [instructions] + for instruction in instructions or []: + if isinstance(instruction, str): + self._instructions += instruction + '\n' + else: + self._instructions_functions.append(_system_prompt.SystemPromptRunner(instruction)) + self._instructions = self._instructions.strip() or None - def _run_span_end_attributes( - self, state: _agent_graph.GraphAgentState, usage: _usage.Usage, settings: InstrumentationSettings - ): - return { - **usage.opentelemetry_attributes(), - 'all_messages_events': json.dumps( - [InstrumentedModel.event_to_dict(e) for e in settings.messages_to_otel_events(state.message_history)] - ), - 'logfire.json_schema': json.dumps( - { - 'type': 'object', - 'properties': { - 'all_messages_events': {'type': 'array'}, - 'final_result': {'type': 'object'}, - }, - } - ), - } + self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt) + self._system_prompt_functions = [] + self._system_prompt_dynamic_functions = {} - @overload - def run_sync( - self, - user_prompt: str | Sequence[_messages.UserContent] | None = None, - *, - message_history: list[_messages.ModelMessage] | None = None, - model: models.Model | models.KnownModelName | str | None = None, - deps: AgentDepsT = None, - model_settings: ModelSettings | None = None, - usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, - infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - ) -> AgentRunResult[OutputDataT]: ... + self._max_result_retries = output_retries if output_retries is not None else retries + self._prepare_tools = prepare_tools + self._prepare_output_tools = prepare_output_tools - @overload - def run_sync( - self, - user_prompt: str | Sequence[_messages.UserContent] | None = None, - *, - output_type: OutputSpec[RunOutputDataT] | None = None, - message_history: list[_messages.ModelMessage] | None = None, - model: models.Model | models.KnownModelName | str | None = None, - deps: AgentDepsT = None, - model_settings: ModelSettings | None = None, - usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, - infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - ) -> AgentRunResult[RunOutputDataT]: ... + self._output_toolset = self._output_schema.toolset + if self._output_toolset: + self._output_toolset.max_retries = self._max_result_retries - @overload - @deprecated('`result_type` is deprecated, use `output_type` instead.') - def run_sync( - self, - user_prompt: str | Sequence[_messages.UserContent] | None = None, - *, - result_type: type[RunOutputDataT], - message_history: list[_messages.ModelMessage] | None = None, - model: models.Model | models.KnownModelName | str | None = None, - deps: AgentDepsT = None, - model_settings: ModelSettings | None = None, - usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, - infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - ) -> AgentRunResult[RunOutputDataT]: ... + self._function_toolset = FunctionToolset(tools, max_retries=retries, id='agent') + self._user_toolsets = toolsets or () - def run_sync( - self, - user_prompt: str | Sequence[_messages.UserContent] | None = None, - *, - output_type: OutputSpec[RunOutputDataT] | None = None, - message_history: list[_messages.ModelMessage] | None = None, - model: models.Model | models.KnownModelName | str | None = None, - deps: AgentDepsT = None, - model_settings: ModelSettings | None = None, - usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, - infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - **_deprecated_kwargs: Never, - ) -> AgentRunResult[Any]: - """Synchronously run the agent with a user prompt. + self.history_processors = history_processors or [] - This is a convenience method that wraps [`self.run`][pydantic_ai.Agent.run] with `loop.run_until_complete(...)`. - You therefore can't use this method inside async code or if there's an active event loop. + self._event_stream_handler = event_stream_handler - Example: - ```python - from pydantic_ai import Agent + self._override_deps: ContextVar[_utils.Option[AgentDepsT]] = ContextVar('_override_deps', default=None) + self._override_model: ContextVar[_utils.Option[models.Model]] = ContextVar('_override_model', default=None) + self._override_toolsets: ContextVar[_utils.Option[Sequence[AbstractToolset[AgentDepsT]]]] = ContextVar( + '_override_toolsets', default=None + ) + self._override_tools: ContextVar[ + _utils.Option[Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]]] + ] = ContextVar('_override_tools', default=None) - agent = Agent('openai:gpt-4o') + self._enter_lock = _utils.get_async_lock() + self._entered_count = 0 + self._exit_stack = None - result_sync = agent.run_sync('What is the capital of Italy?') - print(result_sync.output) - #> Rome - ``` + @staticmethod + def instrument_all(instrument: InstrumentationSettings | bool = True) -> None: + """Set the instrumentation options for all agents where `instrument` is not set.""" + Agent._instrument_default = instrument - Args: - user_prompt: User input to start/continue the conversation. - output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no - output validators since output validators would expect an argument that matches the agent's output type. - message_history: History of the conversation so far. - model: Optional model to use for this run, required if `model` was not set when creating the agent. - deps: Optional dependencies to use for this run. - model_settings: Optional settings to use for this model's request. - usage_limits: Optional limits on model request count or token usage. - usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. - infer_name: Whether to try to infer the agent name from the call frame if it's not set. - toolsets: Optional additional toolsets for this run. + @property + def model(self) -> models.Model | models.KnownModelName | str | None: + return self._model - Returns: - The result of the run. - """ - if infer_name and self.name is None: - self._infer_name(inspect.currentframe()) + @model.setter + def model(self, value: models.Model | models.KnownModelName | str | None) -> None: + self._model = value - if 'result_type' in _deprecated_kwargs: # pragma: no cover - if output_type is not str: - raise TypeError('`result_type` and `output_type` cannot be set at the same time.') - warnings.warn('`result_type` is deprecated, use `output_type` instead.', DeprecationWarning, stacklevel=2) - output_type = _deprecated_kwargs.pop('result_type') + @property + def name(self) -> str | None: + return self._name - _utils.validate_empty_kwargs(_deprecated_kwargs) + @name.setter + def name(self, value: str | None) -> None: + self._name = value - return get_event_loop().run_until_complete( - self.run( - user_prompt, - output_type=output_type, - message_history=message_history, - model=model, - deps=deps, - model_settings=model_settings, - usage_limits=usage_limits, - usage=usage, - infer_name=False, - toolsets=toolsets, - ) - ) + @property + def output_type(self) -> OutputSpec[OutputDataT]: + return self._output_type + + @property + def event_stream_handler(self) -> EventStreamHandler[AgentDepsT] | None: + return self._event_stream_handler + + def __repr__(self) -> str: # pragma: no cover + return f'{type(self).__name__}(model={self.model!r}, name={self.name!r}, end_strategy={self.end_strategy!r}, model_settings={self.model_settings!r}, output_type={self.output_type!r}, instrument={self.instrument!r})' @overload - def run_stream( + def iter( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, + output_type: None = None, message_history: list[_messages.ModelMessage] | None = None, - model: models.Model | models.KnownModelName | None = None, + model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, OutputDataT]]: ... + **_deprecated_kwargs: Never, + ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, OutputDataT]]: ... @overload - def run_stream( + def iter( self, - user_prompt: str | Sequence[_messages.UserContent], + user_prompt: str | Sequence[_messages.UserContent] | None = None, *, output_type: OutputSpec[RunOutputDataT], message_history: list[_messages.ModelMessage] | None = None, @@ -1051,11 +1628,12 @@ def run_stream( usage: _usage.Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunOutputDataT]]: ... + **_deprecated_kwargs: Never, + ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ... @overload @deprecated('`result_type` is deprecated, use `output_type` instead.') - def run_stream( + def iter( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, @@ -1068,10 +1646,10 @@ def run_stream( usage: _usage.Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunOutputDataT]]: ... + ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, Any]]: ... @asynccontextmanager - async def run_stream( # noqa C901 + async def iter( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, @@ -1085,8 +1663,18 @@ async def run_stream( # noqa C901 infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, **_deprecated_kwargs: Never, - ) -> AsyncIterator[result.StreamedRunResult[AgentDepsT, Any]]: - """Run the agent with a user prompt in async mode, returning a streamed response. + ) -> AsyncIterator[AgentRun[AgentDepsT, Any]]: + """A contextmanager which can be used to iterate over the agent graph's nodes as they are executed. + + This method builds an internal agent graph (using system prompts, tools and output schemas) and then returns an + `AgentRun` object. The `AgentRun` can be used to async-iterate over the nodes of the graph as they are + executed. This is the API to use if you want to consume the outputs coming from each LLM model response, or the + stream of events coming from the execution of tools. + + The `AgentRun` also provides methods to access the full message history, new messages, and usage statistics, + and the final result of the run once it has completed. + + For more details, see the documentation of `AgentRun`. Example: ```python @@ -1095,9 +1683,46 @@ async def run_stream( # noqa C901 agent = Agent('openai:gpt-4o') async def main(): - async with agent.run_stream('What is the capital of the UK?') as response: - print(await response.get_output()) - #> London + nodes = [] + async with agent.iter('What is the capital of France?') as agent_run: + async for node in agent_run: + nodes.append(node) + print(nodes) + ''' + [ + UserPromptNode( + user_prompt='What is the capital of France?', + instructions=None, + instructions_functions=[], + system_prompts=(), + system_prompt_functions=[], + system_prompt_dynamic_functions={}, + ), + ModelRequestNode( + request=ModelRequest( + parts=[ + UserPromptPart( + content='What is the capital of France?', + timestamp=datetime.datetime(...), + ) + ] + ) + ), + CallToolsNode( + model_response=ModelResponse( + parts=[TextPart(content='Paris')], + usage=Usage( + requests=1, request_tokens=56, response_tokens=1, total_tokens=57 + ), + model_name='gpt-4o', + timestamp=datetime.datetime(...), + ) + ), + End(data=FinalResult(output='Paris')), + ] + ''' + print(agent_run.result.output) + #> Paris ``` Args: @@ -1116,12 +1741,10 @@ async def main(): Returns: The result of the run. """ - # TODO: We need to deprecate this now that we have the `iter` method. - # Before that, though, we should add an event for when we reach the final result of the stream. if infer_name and self.name is None: - # f_back because `asynccontextmanager` adds one frame - if frame := inspect.currentframe(): # pragma: no branch - self._infer_name(frame.f_back) + self._infer_name(inspect.currentframe()) + model_used = self._get_model(model) + del model if 'result_type' in _deprecated_kwargs: # pragma: no cover if output_type is not str: @@ -1131,81 +1754,161 @@ async def main(): _utils.validate_empty_kwargs(_deprecated_kwargs) - yielded = False - async with self.iter( - user_prompt, - output_type=output_type, - message_history=message_history, - model=model, + deps = self._get_deps(deps) + new_message_index = len(message_history) if message_history else 0 + output_schema = self._prepare_output_schema(output_type, model_used.profile) + + output_type_ = output_type or self.output_type + + # We consider it a user error if a user tries to restrict the result type while having an output validator that + # may change the result type from the restricted type to something else. Therefore, we consider the following + # typecast reasonable, even though it is possible to violate it with otherwise-type-checked code. + output_validators = cast(list[_output.OutputValidator[AgentDepsT, RunOutputDataT]], self._output_validators) + + output_toolset = self._output_toolset + if output_schema != self._output_schema or output_validators: + output_toolset = cast(OutputToolset[AgentDepsT], output_schema.toolset) + if output_toolset: + output_toolset.max_retries = self._max_result_retries + output_toolset.output_validators = output_validators + + # Build the graph + graph: Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[Any]] = ( + _agent_graph.build_agent_graph(self.name, self._deps_type, output_type_) + ) + + # Build the initial state + usage = usage or _usage.Usage() + state = _agent_graph.GraphAgentState( + message_history=message_history[:] if message_history else [], + usage=usage, + retries=0, + run_step=0, + ) + + if isinstance(model_used, InstrumentedModel): + instrumentation_settings = model_used.instrumentation_settings + tracer = model_used.instrumentation_settings.tracer + else: + instrumentation_settings = None + tracer = NoOpTracer() + + run_context = RunContext[AgentDepsT]( deps=deps, - model_settings=model_settings, - usage_limits=usage_limits, + model=model_used, usage=usage, - infer_name=False, - toolsets=toolsets, - ) as agent_run: - first_node = agent_run.next_node # start with the first node - assert isinstance(first_node, _agent_graph.UserPromptNode) # the first node should be a user prompt node - node = first_node - while True: - if self.is_model_request_node(node): - graph_ctx = agent_run.ctx - async with node.stream(graph_ctx) as stream: + prompt=user_prompt, + messages=state.message_history, + tracer=tracer, + trace_include_content=instrumentation_settings is not None and instrumentation_settings.include_content, + run_step=state.run_step, + ) - async def stream_to_final(s: AgentStream) -> FinalResult[AgentStream] | None: - async for event in stream: - if isinstance(event, _messages.FinalResultEvent): - return FinalResult(s, event.tool_name, event.tool_call_id) - return None + toolset = self._get_toolset(output_toolset=output_toolset, additional_toolsets=toolsets) + # This will raise errors for any name conflicts + tool_manager = await ToolManager[AgentDepsT].build(toolset, run_context) - final_result = await stream_to_final(stream) - if final_result is not None: - if yielded: - raise exceptions.AgentRunError('Agent run produced final results') # pragma: no cover - yielded = True + # Merge model settings in order of precedence: run > agent > model + merged_settings = merge_model_settings(model_used.settings, self.model_settings) + model_settings = merge_model_settings(merged_settings, model_settings) + usage_limits = usage_limits or _usage.UsageLimits() + agent_name = self.name or 'agent' + run_span = tracer.start_span( + 'agent run', + attributes={ + 'model_name': model_used.model_name if model_used else 'no-model', + 'agent_name': agent_name, + 'logfire.msg': f'{agent_name} run', + }, + ) - messages = graph_ctx.state.message_history.copy() + async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: + parts = [ + self._instructions, + *[await func.run(run_context) for func in self._instructions_functions], + ] - async def on_complete() -> None: - """Called when the stream has completed. + model_profile = model_used.profile + if isinstance(output_schema, _output.PromptedOutputSchema): + instructions = output_schema.instructions(model_profile.prompted_output_template) + parts.append(instructions) - The model response will have been added to messages by now - by `StreamedRunResult._marked_completed`. - """ - last_message = messages[-1] - assert isinstance(last_message, _messages.ModelResponse) - tool_calls = [ - part for part in last_message.parts if isinstance(part, _messages.ToolCallPart) - ] + parts = [p for p in parts if p] + if not parts: + return None + return '\n\n'.join(parts).strip() - parts: list[_messages.ModelRequestPart] = [] - async for _event in _agent_graph.process_function_tools( - graph_ctx.deps.tool_manager, - tool_calls, - final_result, - graph_ctx, - parts, - ): - pass - if parts: - messages.append(_messages.ModelRequest(parts)) + graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT]( + user_deps=deps, + prompt=user_prompt, + new_message_index=new_message_index, + model=model_used, + model_settings=model_settings, + usage_limits=usage_limits, + max_result_retries=self._max_result_retries, + end_strategy=self.end_strategy, + output_schema=output_schema, + output_validators=output_validators, + history_processors=self.history_processors, + tool_manager=tool_manager, + tracer=tracer, + get_instructions=get_instructions, + instrumentation_settings=instrumentation_settings, + ) + start_node = _agent_graph.UserPromptNode[AgentDepsT]( + user_prompt=user_prompt, + instructions=self._instructions, + instructions_functions=self._instructions_functions, + system_prompts=self._system_prompts, + system_prompt_functions=self._system_prompt_functions, + system_prompt_dynamic_functions=self._system_prompt_dynamic_functions, + ) - yield StreamedRunResult( - messages, - graph_ctx.deps.new_message_index, - stream, - on_complete, - ) - break - next_node = await agent_run.next(node) - if not isinstance(next_node, _agent_graph.AgentNode): - raise exceptions.AgentRunError( # pragma: no cover - 'Should have produced a StreamedRunResult before getting here' - ) - node = cast(_agent_graph.AgentNode[Any, Any], next_node) + try: + async with graph.iter( + start_node, + state=state, + deps=graph_deps, + span=use_span(run_span) if run_span.is_recording() else None, + infer_name=False, + ) as graph_run: + agent_run = AgentRun(graph_run) + yield agent_run + if (final_result := agent_run.result) is not None and run_span.is_recording(): + if instrumentation_settings and instrumentation_settings.include_content: + run_span.set_attribute( + 'final_result', + ( + final_result.output + if isinstance(final_result.output, str) + else json.dumps(InstrumentedModel.serialize_any(final_result.output)) + ), + ) + finally: + try: + if instrumentation_settings and run_span.is_recording(): + run_span.set_attributes(self._run_span_end_attributes(state, usage, instrumentation_settings)) + finally: + run_span.end() - if not yielded: - raise exceptions.AgentRunError('Agent run finished without producing a final result') # pragma: no cover + def _run_span_end_attributes( + self, state: _agent_graph.GraphAgentState, usage: _usage.Usage, settings: InstrumentationSettings + ): + return { + **usage.opentelemetry_attributes(), + 'all_messages_events': json.dumps( + [InstrumentedModel.event_to_dict(e) for e in settings.messages_to_otel_events(state.message_history)] + ), + 'logfire.json_schema': json.dumps( + { + 'type': 'object', + 'properties': { + 'all_messages_events': {'type': 'array'}, + 'final_result': {'type': 'object'}, + }, + } + ), + } @contextmanager def override( @@ -1745,25 +2448,6 @@ def toolset(self) -> AbstractToolset[AgentDepsT]: """ return self._get_toolset() - def _infer_name(self, function_frame: FrameType | None) -> None: - """Infer the agent name from the call frame. - - Usage should be `self._infer_name(inspect.currentframe())`. - """ - assert self.name is None, 'Name already set' - if function_frame is not None: # pragma: no branch - if parent_frame := function_frame.f_back: # pragma: no branch - for name, item in parent_frame.f_locals.items(): - if item is self: - self.name = name - return - if parent_frame.f_locals != parent_frame.f_globals: # pragma: no branch - # if we couldn't find the agent in locals and globals are a different dict, try globals - for name, item in parent_frame.f_globals.items(): - if item is self: - self.name = name - return - @property @deprecated( 'The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.', category=None @@ -1790,47 +2474,7 @@ def _prepare_output_schema( return schema # pyright: ignore[reportReturnType] - @staticmethod - def is_model_request_node( - node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]], - ) -> TypeIs[_agent_graph.ModelRequestNode[T, S]]: - """Check if the node is a `ModelRequestNode`, narrowing the type if it is. - - This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`. - """ - return isinstance(node, _agent_graph.ModelRequestNode) - - @staticmethod - def is_call_tools_node( - node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]], - ) -> TypeIs[_agent_graph.CallToolsNode[T, S]]: - """Check if the node is a `CallToolsNode`, narrowing the type if it is. - - This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`. - """ - return isinstance(node, _agent_graph.CallToolsNode) - - @staticmethod - def is_user_prompt_node( - node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]], - ) -> TypeIs[_agent_graph.UserPromptNode[T, S]]: - """Check if the node is a `UserPromptNode`, narrowing the type if it is. - - This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`. - """ - return isinstance(node, _agent_graph.UserPromptNode) - - @staticmethod - def is_end_node( - node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]], - ) -> TypeIs[End[result.FinalResult[S]]]: - """Check if the node is a `End`, narrowing the type if it is. - - This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`. - """ - return isinstance(node, End) - - async def __aenter__(self) -> Self: + async def __aenter__(self) -> AbstractAgent[AgentDepsT, OutputDataT]: """Enter the agent context. This will start all [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] registered as `toolsets` so they are ready to be used. @@ -1893,201 +2537,6 @@ async def run_mcp_servers( async with self: yield - def to_ag_ui( - self, - *, - # Agent.iter parameters - output_type: OutputSpec[OutputDataT] | None = None, - model: models.Model | models.KnownModelName | str | None = None, - deps: AgentDepsT = None, - model_settings: ModelSettings | None = None, - usage_limits: UsageLimits | None = None, - usage: Usage | None = None, - infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - # Starlette - debug: bool = False, - routes: Sequence[BaseRoute] | None = None, - middleware: Sequence[Middleware] | None = None, - exception_handlers: Mapping[Any, ExceptionHandler] | None = None, - on_startup: Sequence[Callable[[], Any]] | None = None, - on_shutdown: Sequence[Callable[[], Any]] | None = None, - lifespan: Lifespan[AGUIApp[AgentDepsT, OutputDataT]] | None = None, - ) -> AGUIApp[AgentDepsT, OutputDataT]: - """Convert the agent to an AG-UI application. - - This allows you to use the agent with a compatible AG-UI frontend. - - Example: - ```python - from pydantic_ai import Agent - - agent = Agent('openai:gpt-4o') - app = agent.to_ag_ui() - ``` - - The `app` is an ASGI application that can be used with any ASGI server. - - To run the application, you can use the following command: - - ```bash - uvicorn app:app --host 0.0.0.0 --port 8000 - ``` - - See [AG-UI docs](../ag-ui.md) for more information. - - Args: - output_type: Custom output type to use for this run, `output_type` may only be used if the agent has - no output validators since output validators would expect an argument that matches the agent's - output type. - model: Optional model to use for this run, required if `model` was not set when creating the agent. - deps: Optional dependencies to use for this run. - model_settings: Optional settings to use for this model's request. - usage_limits: Optional limits on model request count or token usage. - usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. - infer_name: Whether to try to infer the agent name from the call frame if it's not set. - toolsets: Optional list of toolsets to use for this agent, defaults to the agent's toolset. - - debug: Boolean indicating if debug tracebacks should be returned on errors. - routes: A list of routes to serve incoming HTTP and WebSocket requests. - middleware: A list of middleware to run for every request. A starlette application will always - automatically include two middleware classes. `ServerErrorMiddleware` is added as the very - outermost middleware, to handle any uncaught errors occurring anywhere in the entire stack. - `ExceptionMiddleware` is added as the very innermost middleware, to deal with handled - exception cases occurring in the routing or endpoints. - exception_handlers: A mapping of either integer status codes, or exception class types onto - callables which handle the exceptions. Exception handler callables should be of the form - `handler(request, exc) -> response` and may be either standard functions, or async functions. - on_startup: A list of callables to run on application startup. Startup handler callables do not - take any arguments, and may be either standard functions, or async functions. - on_shutdown: A list of callables to run on application shutdown. Shutdown handler callables do - not take any arguments, and may be either standard functions, or async functions. - lifespan: A lifespan context function, which can be used to perform startup and shutdown tasks. - This is a newer style that replaces the `on_startup` and `on_shutdown` handlers. Use one or - the other, not both. - - Returns: - An ASGI application for running Pydantic AI agents with AG-UI protocol support. - """ - from .ag_ui import AGUIApp - - return AGUIApp( - agent=self, - # Agent.iter parameters - output_type=output_type, - model=model, - deps=deps, - model_settings=model_settings, - usage_limits=usage_limits, - usage=usage, - infer_name=infer_name, - toolsets=toolsets, - # Starlette - debug=debug, - routes=routes, - middleware=middleware, - exception_handlers=exception_handlers, - on_startup=on_startup, - on_shutdown=on_shutdown, - lifespan=lifespan, - ) - - def to_a2a( - self, - *, - storage: Storage | None = None, - broker: Broker | None = None, - # Agent card - name: str | None = None, - url: str = 'http://localhost:8000', - version: str = '1.0.0', - description: str | None = None, - provider: AgentProvider | None = None, - skills: list[Skill] | None = None, - # Starlette - debug: bool = False, - routes: Sequence[Route] | None = None, - middleware: Sequence[Middleware] | None = None, - exception_handlers: dict[Any, ExceptionHandler] | None = None, - lifespan: Lifespan[FastA2A] | None = None, - ) -> FastA2A: - """Convert the agent to a FastA2A application. - - Example: - ```python - from pydantic_ai import Agent - - agent = Agent('openai:gpt-4o') - app = agent.to_a2a() - ``` - - The `app` is an ASGI application that can be used with any ASGI server. - - To run the application, you can use the following command: - - ```bash - uvicorn app:app --host 0.0.0.0 --port 8000 - ``` - """ - from ._a2a import agent_to_a2a - - return agent_to_a2a( - self, - storage=storage, - broker=broker, - name=name, - url=url, - version=version, - description=description, - provider=provider, - skills=skills, - debug=debug, - routes=routes, - middleware=middleware, - exception_handlers=exception_handlers, - lifespan=lifespan, - ) - - async def to_cli(self: Self, deps: AgentDepsT = None, prog_name: str = 'pydantic-ai') -> None: - """Run the agent in a CLI chat interface. - - Args: - deps: The dependencies to pass to the agent. - prog_name: The name of the program to use for the CLI. Defaults to 'pydantic-ai'. - - Example: - ```python {title="agent_to_cli.py" test="skip"} - from pydantic_ai import Agent - - agent = Agent('openai:gpt-4o', instructions='You always respond in Italian.') - - async def main(): - await agent.to_cli() - ``` - """ - from rich.console import Console - - from pydantic_ai._cli import run_chat - - await run_chat(stream=True, agent=self, deps=deps, console=Console(), code_theme='monokai', prog_name=prog_name) - - def to_cli_sync(self: Self, deps: AgentDepsT = None, prog_name: str = 'pydantic-ai') -> None: - """Run the agent in a CLI chat interface with the non-async interface. - - Args: - deps: The dependencies to pass to the agent. - prog_name: The name of the program to use for the CLI. Defaults to 'pydantic-ai'. - - ```python {title="agent_to_cli_sync.py" test="skip"} - from pydantic_ai import Agent - - agent = Agent('openai:gpt-4o', instructions='You always respond in Italian.') - agent.to_cli_sync() - agent.to_cli_sync(prog_name='assistant') - ``` - """ - return get_event_loop().run_until_complete(self.to_cli(deps=deps, prog_name=prog_name)) - @dataclasses.dataclass(repr=False) class AgentRun(Generic[AgentDepsT, OutputDataT]): diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/__init__.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/__init__.py index aa6e03e43..cd7ba738c 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/__init__.py @@ -11,10 +11,7 @@ from temporalio.worker import Plugin as WorkerPlugin, WorkerConfig from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner -from pydantic_ai.agent import Agent -from pydantic_ai.exceptions import UserError - -from ._agent import temporalize_agent +from ._agent import TemporalAgent from ._logfire import LogfirePlugin from ._run_context import TemporalRunContext, TemporalRunContextWithDeps @@ -24,7 +21,7 @@ 'PydanticAIPlugin', 'LogfirePlugin', 'AgentPlugin', - 'temporalize_agent', + 'TemporalAgent', ] @@ -64,14 +61,10 @@ def configure_worker(self, config: WorkerConfig) -> WorkerConfig: class AgentPlugin(WorkerPlugin): """Temporal worker plugin for a specific Pydantic AI agent.""" - def __init__(self, agent: Agent[Any, Any]): + def __init__(self, agent: TemporalAgent[Any, Any]): self.agent = agent def configure_worker(self, config: WorkerConfig) -> WorkerConfig: - agent_activities = getattr(self.agent, '__temporal_activities', None) - if agent_activities is None: - raise UserError('The agent has not been prepared for Temporal yet, call `temporalize_agent(agent)` first.') - activities: Sequence[Callable[..., Any]] = config.get('activities', []) # pyright: ignore[reportUnknownMemberType] - config['activities'] = [*activities, *agent_activities] + config['activities'] = [*activities, *self.agent.temporal_activities] return super().configure_worker(config) diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py index 22454d7f6..7ef9717a1 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py @@ -1,128 +1,330 @@ from __future__ import annotations -from contextlib import asynccontextmanager -from typing import Any, Callable, Literal +from collections.abc import AsyncIterator, Iterator, Sequence +from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager +from typing import Any, Callable, Literal, overload from temporalio import workflow from temporalio.workflow import ActivityConfig +from typing_extensions import Never, deprecated -from pydantic_ai.agent import Agent +from pydantic_ai import ( + _utils, + messages as _messages, + models, + usage as _usage, +) +from pydantic_ai._run_context import AgentDepsT +from pydantic_ai.agent import AbstractAgent, Agent, AgentRun, RunOutputDataT, WrapperAgent from pydantic_ai.exceptions import UserError from pydantic_ai.ext.temporal._run_context import TemporalRunContext from pydantic_ai.models import Model -from pydantic_ai.toolsets.abstract import AbstractToolset +from pydantic_ai.output import OutputDataT, OutputSpec +from pydantic_ai.settings import ModelSettings +from pydantic_ai.tools import ( + Tool, + ToolFuncEither, +) +from pydantic_ai.toolsets import AbstractToolset from ._model import TemporalModel from ._toolset import TemporalWrapperToolset, temporalize_toolset -def temporalize_agent( # noqa: C901 - agent: Agent[Any, Any], - activity_config: ActivityConfig = {}, - toolset_activity_config: dict[str, ActivityConfig] = {}, - tool_activity_config: dict[str, dict[str, ActivityConfig | Literal[False]]] = {}, - run_context_type: type[TemporalRunContext] = TemporalRunContext, - temporalize_toolset_func: Callable[ - [AbstractToolset, ActivityConfig, dict[str, ActivityConfig | Literal[False]], type[TemporalRunContext]], - AbstractToolset, - ] = temporalize_toolset, -) -> Agent[Any, Any]: - """Temporalize an agent. - - Args: - agent: The agent to temporalize. - activity_config: The Temporal activity config to use. - toolset_activity_config: The Temporal activity config to use for specific toolsets identified by ID. - tool_activity_config: The Temporal activity config to use for specific tools identified by toolset ID and tool name. - run_context_type: The type of run context to use to serialize and deserialize the run context. - temporalize_toolset_func: The function to use to prepare the toolsets for Temporal. - """ - if getattr(agent, '__temporal_activities', None): - return agent - - activities: list[Callable[..., Any]] = [] - if not isinstance(agent.model, Model): - raise UserError( - 'Model cannot be set at agent run time when using Temporal, it must be set at agent creation time.' - ) - - temporal_model = TemporalModel(agent.model, activity_config, agent.event_stream_handler, run_context_type) - activities.extend(temporal_model.temporal_activities) - - def temporalize_toolset(toolset: AbstractToolset) -> AbstractToolset: - id = toolset.id - if not id: - raise UserError( - "Toolsets that implement their own tool calling need to have an ID in order to be used with Temporal. The ID will be used to identify the toolset's activities within the workflow." - ) - toolset = temporalize_toolset_func( - toolset, - activity_config | toolset_activity_config.get(id, {}), - tool_activity_config.get(id, {}), - run_context_type, - ) - if isinstance(toolset, TemporalWrapperToolset): - activities.extend(toolset.temporal_activities) - return toolset - - # TODO: Use public methods so others can replicate this - temporal_toolsets = [temporalize_toolset(toolset) for toolset in [agent._function_toolset, *agent._user_toolsets]] # pyright: ignore[reportPrivateUsage] - - original_iter = agent.iter - original_override = agent.override - - def iter(*args: Any, **kwargs: Any) -> Any: - if not workflow.in_workflow(): - return original_iter(*args, **kwargs) +class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]): + def __init__( + self, + wrapped: AbstractAgent[AgentDepsT, OutputDataT], + activity_config: ActivityConfig = {}, + toolset_activity_config: dict[str, ActivityConfig] = {}, + tool_activity_config: dict[str, dict[str, ActivityConfig | Literal[False]]] = {}, + run_context_type: type[TemporalRunContext] = TemporalRunContext, + temporalize_toolset_func: Callable[ + [ + AbstractToolset[Any], + ActivityConfig, + dict[str, ActivityConfig | Literal[False]], + type[TemporalRunContext], + ], + AbstractToolset[Any], + ] = temporalize_toolset, + ): + """Wrap an agent to make it compatible with Temporal. + + Args: + wrapped: The agent to wrap. + activity_config: The Temporal activity config to use. + toolset_activity_config: The Temporal activity config to use for specific toolsets identified by ID. + tool_activity_config: The Temporal activity config to use for specific tools identified by toolset ID and tool name. + run_context_type: The type of run context to use to serialize and deserialize the run context. + temporalize_toolset_func: The function to use to prepare the toolsets for Temporal. + """ + super().__init__(wrapped) - if kwargs.get('model') is not None: + # TODO: Make this work with any AbstractAgent + assert isinstance(wrapped, Agent) + agent = wrapped + + activities: list[Callable[..., Any]] = [] + if not isinstance(agent.model, Model): raise UserError( 'Model cannot be set at agent run time when using Temporal, it must be set at agent creation time.' ) - if kwargs.get('toolsets') is not None: - raise UserError( - 'Toolsets cannot be set at agent run time when using Temporal, it must be set at agent creation time.' - ) - if kwargs.get('event_stream_handler') is not None: - # TODO: iter won't have event_stream_handler, run/_sync/_stream will - raise UserError( - 'Event stream handler cannot be set at agent run time when using Temporal, it must be set at agent creation time.' + + temporal_model = TemporalModel(agent.model, activity_config, agent.event_stream_handler, run_context_type) + activities.extend(temporal_model.temporal_activities) + + def temporalize_toolset(toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset[AgentDepsT]: + id = toolset.id + if not id: + raise UserError( + "Toolsets that implement their own tool calling need to have an ID in order to be used with Temporal. The ID will be used to identify the toolset's activities within the workflow." + ) + toolset = temporalize_toolset_func( + toolset, + activity_config | toolset_activity_config.get(id, {}), + tool_activity_config.get(id, {}), + run_context_type, ) + if isinstance(toolset, TemporalWrapperToolset): + activities.extend(toolset.temporal_activities) + return toolset + + # TODO: Use public methods so others can replicate this + temporal_toolsets = [ + temporalize_toolset(toolset) + for toolset in [agent._function_toolset, *agent._user_toolsets] # pyright: ignore[reportPrivateUsage] + ] + + self._model = temporal_model + self._toolsets = temporal_toolsets + self._temporal_activities = activities + + @property + def model(self) -> Model: + return self._model + + @property + def temporal_activities(self) -> list[Callable[..., Any]]: + return self._temporal_activities - @asynccontextmanager - async def async_iter(): - # We reset tools here as the temporalized function toolset is already in temporal_toolsets. - with agent.override(model=temporal_model, toolsets=temporal_toolsets, tools=[]): - async with original_iter(*args, **kwargs) as result: - yield result + @overload + def iter( + self, + user_prompt: str | Sequence[_messages.UserContent] | None = None, + *, + output_type: None = None, + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + **_deprecated_kwargs: Never, + ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, OutputDataT]]: ... - return async_iter() + @overload + def iter( + self, + user_prompt: str | Sequence[_messages.UserContent] | None = None, + *, + output_type: OutputSpec[RunOutputDataT], + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + **_deprecated_kwargs: Never, + ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ... - def override(*args: Any, **kwargs: Any) -> Any: + @overload + @deprecated('`result_type` is deprecated, use `output_type` instead.') + def iter( + self, + user_prompt: str | Sequence[_messages.UserContent] | None = None, + *, + result_type: type[RunOutputDataT], + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, Any]]: ... + + @asynccontextmanager + async def iter( + self, + user_prompt: str | Sequence[_messages.UserContent] | None = None, + *, + output_type: OutputSpec[RunOutputDataT] | None = None, + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + **_deprecated_kwargs: Never, + ) -> AsyncIterator[AgentRun[AgentDepsT, Any]]: + """A contextmanager which can be used to iterate over the agent graph's nodes as they are executed. + + This method builds an internal agent graph (using system prompts, tools and output schemas) and then returns an + `AgentRun` object. The `AgentRun` can be used to async-iterate over the nodes of the graph as they are + executed. This is the API to use if you want to consume the outputs coming from each LLM model response, or the + stream of events coming from the execution of tools. + + The `AgentRun` also provides methods to access the full message history, new messages, and usage statistics, + and the final result of the run once it has completed. + + For more details, see the documentation of `AgentRun`. + + Example: + ```python + from pydantic_ai import Agent + + agent = Agent('openai:gpt-4o') + + async def main(): + nodes = [] + async with agent.iter('What is the capital of France?') as agent_run: + async for node in agent_run: + nodes.append(node) + print(nodes) + ''' + [ + UserPromptNode( + user_prompt='What is the capital of France?', + instructions=None, + instructions_functions=[], + system_prompts=(), + system_prompt_functions=[], + system_prompt_dynamic_functions={}, + ), + ModelRequestNode( + request=ModelRequest( + parts=[ + UserPromptPart( + content='What is the capital of France?', + timestamp=datetime.datetime(...), + ) + ] + ) + ), + CallToolsNode( + model_response=ModelResponse( + parts=[TextPart(content='Paris')], + usage=Usage( + requests=1, request_tokens=56, response_tokens=1, total_tokens=57 + ), + model_name='gpt-4o', + timestamp=datetime.datetime(...), + ) + ), + End(data=FinalResult(output='Paris')), + ] + ''' + print(agent_run.result.output) + #> Paris + ``` + + Args: + user_prompt: User input to start/continue the conversation. + output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no + output validators since output validators would expect an argument that matches the agent's output type. + message_history: History of the conversation so far. + model: Optional model to use for this run, required if `model` was not set when creating the agent. + deps: Optional dependencies to use for this run. + model_settings: Optional settings to use for this model's request. + usage_limits: Optional limits on model request count or token usage. + usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. + infer_name: Whether to try to infer the agent name from the call frame if it's not set. + toolsets: Optional additional toolsets for this run. + + Returns: + The result of the run. + """ if not workflow.in_workflow(): - return original_override(*args, **kwargs) + async with super().iter( + user_prompt=user_prompt, + output_type=output_type, + message_history=message_history, + model=model, + deps=deps, + model_settings=model_settings, + usage_limits=usage_limits, + usage=usage, + infer_name=infer_name, + toolsets=toolsets, + ) as result: + yield result - if kwargs.get('model') not in (None, temporal_model): - raise UserError( - 'Model cannot be contextually overridden when using Temporal, it must be set at agent creation time.' - ) - if kwargs.get('toolsets') not in (None, temporal_toolsets): + if model is not None: raise UserError( - 'Toolsets cannot be contextually overridden when using Temporal, they must be set at agent creation time.' + 'Model cannot be set at agent run time when using Temporal, it must be set at agent creation time.' ) - if kwargs.get('tools') not in (None, []): + if toolsets is not None: raise UserError( - 'Tools cannot be contextually overridden when using Temporal, they must be set at agent creation time.' + 'Toolsets cannot be set at agent run time when using Temporal, it must be set at agent creation time.' ) - return original_override(*args, **kwargs) - def tool(*args: Any, **kwargs: Any) -> Any: - raise UserError('New tools cannot be registered after an agent has been prepared for Temporal.') + # We reset tools here as the temporalized function toolset is already in self._toolsets. + with super().override(model=self._model, toolsets=self._toolsets, tools=[]): + async with super().iter( + user_prompt=user_prompt, + output_type=output_type, + message_history=message_history, + model=model, + deps=deps, + model_settings=model_settings, + usage_limits=usage_limits, + usage=usage, + infer_name=infer_name, + toolsets=toolsets, + ) as result: + yield result + + @contextmanager + def override( + self, + *, + deps: AgentDepsT | _utils.Unset = _utils.UNSET, + model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET, + tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET, + ) -> Iterator[None]: + """Context manager to temporarily override agent dependencies, model, or toolsets. + + This is particularly useful when testing. + You can find an example of this [here](../testing.md#overriding-model-via-pytest-fixtures). - agent.iter = iter - agent.override = override - agent.tool = tool - agent.tool_plain = tool + Args: + deps: The dependencies to use instead of the dependencies passed to the agent run. + model: The model to use instead of the model passed to the agent run. + toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run. + tools: The tools to use instead of the tools registered with the agent. + """ + if workflow.in_workflow(): + if _utils.is_set(model): + raise UserError( + 'Model cannot be contextually overridden when using Temporal, it must be set at agent creation time.' + ) + if _utils.is_set(toolsets): + raise UserError( + 'Toolsets cannot be contextually overridden when using Temporal, they must be set at agent creation time.' + ) + if _utils.is_set(tools): + raise UserError( + 'Tools cannot be contextually overridden when using Temporal, they must be set at agent creation time.' + ) - setattr(agent, '__temporal_activities', activities) - return agent + with super().override(deps=deps, model=model, toolsets=toolsets, tools=tools): + yield diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_run_context.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_run_context.py index e6223195c..cec240b3d 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_run_context.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_run_context.py @@ -22,7 +22,7 @@ def __getattribute__(self, name: str) -> Any: if name in RunContext.__dataclass_fields__: raise AttributeError( f'{self.__class__.__name__!r} object has no attribute {name!r}. ' - 'To make the attribute available, create a `TemporalRunContext` subclass with a custom `serialize_run_context` class method that returns a dictionary that includes the attribute and pass it to `temporalize_agent`.' + 'To make the attribute available, create a `TemporalRunContext` subclass with a custom `serialize_run_context` class method that returns a dictionary that includes the attribute and pass it to `TemporalAgent`.' ) else: raise e @@ -48,6 +48,6 @@ def serialize_run_context(cls, ctx: RunContext[Any]) -> dict[str, Any]: if not isinstance(ctx.deps, dict): raise UserError( 'The `deps` object must be a JSON-serializable dictionary in order to be used with Temporal. ' - 'To use a different type, pass a `TemporalRunContext` subclass to `temporalize_agent` with custom `serialize_run_context` and `deserialize_run_context` class methods.' + 'To use a different type, pass a `TemporalRunContext` subclass to `TemporalAgent` with custom `serialize_run_context` and `deserialize_run_context` class methods.' ) return {**super().serialize_run_context(ctx), 'deps': ctx.deps} # pyright: ignore[reportUnknownMemberType] diff --git a/temporal.py b/temporal.py index 34432629c..acb6d476c 100644 --- a/temporal.py +++ b/temporal.py @@ -15,8 +15,8 @@ AgentPlugin, LogfirePlugin, PydanticAIPlugin, + TemporalAgent, TemporalRunContextWithDeps, - temporalize_agent, ) from pydantic_ai.mcp import MCPServerStdio from pydantic_ai.messages import AgentStreamEvent, HandleResponseEvent @@ -54,7 +54,7 @@ def get_weather(city: str) -> str: # This needs to be called in the same scope where the `agent` is bound to the workflow, # as it modifies the `agent` object in place to swap out methods that use IO for ones that use Temporal activities. -temporalize_agent( +temporal_agent = TemporalAgent( agent, activity_config=ActivityConfig(start_to_close_timeout=timedelta(seconds=60)), toolset_activity_config={ @@ -77,7 +77,7 @@ def get_weather(city: str) -> str: class MyAgentWorkflow: @workflow.run async def run(self, prompt: str, deps: Deps) -> str: - result = await agent.run(prompt, deps=deps) + result = await temporal_agent.run(prompt, deps=deps) return result.output @@ -100,7 +100,7 @@ async def main(): client, task_queue=TASK_QUEUE, workflows=[MyAgentWorkflow], - plugins=[AgentPlugin(agent)], + plugins=[AgentPlugin(temporal_agent)], ): output = await client.execute_workflow( # pyright: ignore[reportUnknownMemberType] MyAgentWorkflow.run, From 8d3e04d45fda7ea9d3a87e2c7f00afd320ddeadf Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 1 Aug 2025 16:40:05 +0000 Subject: [PATCH 23/41] Add a todo --- pydantic_ai_slim/pydantic_ai/ext/temporal/_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_model.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_model.py index b87e3a6e8..97f0a7041 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_model.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_model.py @@ -95,6 +95,7 @@ async def aiter(): # `AgentStream.__aiter__`, which this is based on, calls `_get_usage_checking_stream_response` here, # but we don't have access to the `_usage_limits`. + # TODO: Create new stream wrapper that does this async for event in streamed_response: yield event if ( From 13511e7c9477932246a9e929d42058f699d00e3a Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 5 Aug 2025 15:55:03 +0000 Subject: [PATCH 24/41] Properly set up OpenTelemetry metrics --- .../pydantic_ai/ext/temporal/_logfire.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_logfire.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_logfire.py index bb307b990..7bb969bb7 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_logfire.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_logfire.py @@ -30,5 +30,18 @@ def configure_client(self, config: ClientConfig) -> ClientConfig: async def connect_service_client(self, config: ConnectConfig) -> ServiceClient: self.setup_logfire() - config.runtime = Runtime(telemetry=TelemetryConfig(metrics=OpenTelemetryConfig(url='http://localhost:4318'))) + import logfire + + logfire_instance = logfire.configure() + logfire_config = logfire_instance.config + token = logfire_config.token + if token is not None: + base_url = logfire_config.advanced.generate_base_url(token) + metrics_url = base_url + '/v1/metrics' + headers = {'Authorization': f'Bearer {token}'} + + config.runtime = Runtime( + telemetry=TelemetryConfig(metrics=OpenTelemetryConfig(url=metrics_url, headers=headers)) + ) + return await super().connect_service_client(config) From d4325f4ea389bd701723047c39ca196006af7358 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 5 Aug 2025 16:10:09 +0000 Subject: [PATCH 25/41] Address some feedback --- pydantic_ai_slim/pydantic_ai/agent.py | 86 +------------------ pydantic_ai_slim/pydantic_ai/ext/aci.py | 2 +- pydantic_ai_slim/pydantic_ai/ext/langchain.py | 2 +- .../pydantic_ai/ext/temporal/_agent.py | 20 +---- .../pydantic_ai/ext/temporal/_mcp_server.py | 6 +- pydantic_ai_slim/pydantic_ai/mcp.py | 10 ++- .../pydantic_ai/toolsets/deferred.py | 2 +- .../pydantic_ai/toolsets/function.py | 1 + 8 files changed, 17 insertions(+), 112 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index c80a328ed..1565986f4 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -164,23 +164,6 @@ async def run( toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AgentRunResult[RunOutputDataT]: ... - @overload - @deprecated('`result_type` is deprecated, use `output_type` instead.') - async def run( - self, - user_prompt: str | Sequence[_messages.UserContent] | None = None, - *, - result_type: type[RunOutputDataT], - message_history: list[_messages.ModelMessage] | None = None, - model: models.Model | models.KnownModelName | str | None = None, - deps: AgentDepsT = None, - model_settings: ModelSettings | None = None, - usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, - infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - ) -> AgentRunResult[RunOutputDataT]: ... - async def run( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, @@ -292,23 +275,6 @@ def run_sync( toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AgentRunResult[RunOutputDataT]: ... - @overload - @deprecated('`result_type` is deprecated, use `output_type` instead.') - def run_sync( - self, - user_prompt: str | Sequence[_messages.UserContent] | None = None, - *, - result_type: type[RunOutputDataT], - message_history: list[_messages.ModelMessage] | None = None, - model: models.Model | models.KnownModelName | str | None = None, - deps: AgentDepsT = None, - model_settings: ModelSettings | None = None, - usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, - infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - ) -> AgentRunResult[RunOutputDataT]: ... - def run_sync( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, @@ -413,23 +379,6 @@ def run_stream( toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunOutputDataT]]: ... - @overload - @deprecated('`result_type` is deprecated, use `output_type` instead.') - def run_stream( - self, - user_prompt: str | Sequence[_messages.UserContent] | None = None, - *, - result_type: type[RunOutputDataT], - message_history: list[_messages.ModelMessage] | None = None, - model: models.Model | models.KnownModelName | str | None = None, - deps: AgentDepsT = None, - model_settings: ModelSettings | None = None, - usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, - infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunOutputDataT]]: ... - @asynccontextmanager async def run_stream( # noqa C901 self, @@ -601,23 +550,6 @@ def iter( **_deprecated_kwargs: Never, ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ... - @overload - @deprecated('`result_type` is deprecated, use `output_type` instead.') - def iter( - self, - user_prompt: str | Sequence[_messages.UserContent] | None = None, - *, - result_type: type[RunOutputDataT], - message_history: list[_messages.ModelMessage] | None = None, - model: models.Model | models.KnownModelName | str | None = None, - deps: AgentDepsT = None, - model_settings: ModelSettings | None = None, - usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, - infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, Any]]: ... - @asynccontextmanager @abstractmethod async def iter( @@ -1069,23 +1001,6 @@ def iter( **_deprecated_kwargs: Never, ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ... - @overload - @deprecated('`result_type` is deprecated, use `output_type` instead.') - def iter( - self, - user_prompt: str | Sequence[_messages.UserContent] | None = None, - *, - result_type: type[RunOutputDataT], - message_history: list[_messages.ModelMessage] | None = None, - model: models.Model | models.KnownModelName | str | None = None, - deps: AgentDepsT = None, - model_settings: ModelSettings | None = None, - usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, - infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, Any]]: ... - @asynccontextmanager async def iter( self, @@ -1190,6 +1105,7 @@ async def main(): usage=usage, infer_name=infer_name, toolsets=toolsets, + **_deprecated_kwargs, ) as result: yield result diff --git a/pydantic_ai_slim/pydantic_ai/ext/aci.py b/pydantic_ai_slim/pydantic_ai/ext/aci.py index ef686d134..f9db595de 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/aci.py +++ b/pydantic_ai_slim/pydantic_ai/ext/aci.py @@ -71,7 +71,7 @@ def implementation(*args: Any, **kwargs: Any) -> str: class ACIToolset(FunctionToolset): """A toolset that wraps ACI.dev tools.""" - def __init__(self, aci_functions: Sequence[str], linked_account_owner_id: str, id: str | None = None): + def __init__(self, aci_functions: Sequence[str], linked_account_owner_id: str, *, id: str | None = None): super().__init__( [tool_from_aci(aci_function, linked_account_owner_id) for aci_function in aci_functions], id=id ) diff --git a/pydantic_ai_slim/pydantic_ai/ext/langchain.py b/pydantic_ai_slim/pydantic_ai/ext/langchain.py index 3782c0b9d..88c4c8cc3 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/langchain.py +++ b/pydantic_ai_slim/pydantic_ai/ext/langchain.py @@ -65,5 +65,5 @@ def proxy(*args: Any, **kwargs: Any) -> str: class LangChainToolset(FunctionToolset): """A toolset that wraps LangChain tools.""" - def __init__(self, tools: list[LangChainTool], id: str | None = None): + def __init__(self, tools: list[LangChainTool], *, id: str | None = None): super().__init__([tool_from_langchain(tool) for tool in tools], id=id) diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py index 7ef9717a1..83cebec30 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py @@ -6,7 +6,7 @@ from temporalio import workflow from temporalio.workflow import ActivityConfig -from typing_extensions import Never, deprecated +from typing_extensions import Never from pydantic_ai import ( _utils, @@ -142,23 +142,6 @@ def iter( **_deprecated_kwargs: Never, ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ... - @overload - @deprecated('`result_type` is deprecated, use `output_type` instead.') - def iter( - self, - user_prompt: str | Sequence[_messages.UserContent] | None = None, - *, - result_type: type[RunOutputDataT], - message_history: list[_messages.ModelMessage] | None = None, - model: models.Model | models.KnownModelName | str | None = None, - deps: AgentDepsT = None, - model_settings: ModelSettings | None = None, - usage_limits: _usage.UsageLimits | None = None, - usage: _usage.Usage | None = None, - infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, Any]]: ... - @asynccontextmanager async def iter( self, @@ -264,6 +247,7 @@ async def main(): usage=usage, infer_name=infer_name, toolsets=toolsets, + **_deprecated_kwargs, ) as result: yield result diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py index b89f66ba3..6b7242df7 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py @@ -52,7 +52,7 @@ def __init__( async def get_tools_activity(params: _GetToolsParams) -> dict[str, ToolDefinition]: run_context = self.run_context_type.deserialize_run_context(params.serialized_run_context) tools = await self.wrapped.get_tools(run_context) - # ToolsetTool is not serializable as its holds a SchemaValidator (which is also the same for every MCP tool so unnecessary to pass along the wire every time), + # ToolsetTool is not serializable as it holds a SchemaValidator (which is also the same for every MCP tool so unnecessary to pass along the wire every time), # so we just return the ToolDefinitions and wrap them in ToolsetTool outside of the activity. return {name: tool.tool_def for name, tool in tools.items()} @@ -103,7 +103,9 @@ async def call_tool( tool_activity_config = self.tool_activity_config.get(name, {}) if tool_activity_config is False: - raise UserError('Disabling running an MCP tool in a Temporal activity is not possible.') + raise UserError( + f'Temporal activity config for tool {name!r} is `False`, but MCP tools cannot be run out side of an activity.' + ) tool_activity_config = self.activity_config | tool_activity_config serialized_run_context = self.run_context_type.serialize_run_context(ctx) diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index b3426ed36..5cc99a078 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -89,6 +89,7 @@ def __init__( allow_sampling: bool = True, sampling_model: models.Model | None = None, max_retries: int = 1, + *, id: str | None = None, ): self.tool_prefix = tool_prefix @@ -448,7 +449,6 @@ def __init__( args: Sequence[str], env: dict[str, str] | None = None, cwd: str | Path | None = None, - id: str | None = None, tool_prefix: str | None = None, log_level: mcp_types.LoggingLevel | None = None, log_handler: LoggingFnT | None = None, @@ -458,6 +458,8 @@ def __init__( allow_sampling: bool = True, sampling_model: models.Model | None = None, max_retries: int = 1, + *, + id: str | None = None, ): """Build a new MCP server. @@ -466,7 +468,6 @@ def __init__( args: The arguments to pass to the command. env: The environment variables to set in the subprocess. cwd: The working directory to use when spawning the process. - id: An optional unique ID for the MCP server. An MCP server needs to have an ID in order to be used in a durable execution environment like Temporal, in which case the ID will be used to identify the server's activities within the workflow. tool_prefix: A prefix to add to all tools that are registered with the server. log_level: The log level to set when connecting to the server, if any. log_handler: A handler for logging messages from the server. @@ -476,6 +477,7 @@ def __init__( allow_sampling: Whether to allow MCP sampling through this client. sampling_model: The model to use for sampling. max_retries: The maximum number of times to retry a tool call. + id: An optional unique ID for the MCP server. An MCP server needs to have an ID in order to be used in a durable execution environment like Temporal, in which case the ID will be used to identify the server's activities within the workflow. """ self.command = command self.args = args @@ -492,7 +494,7 @@ def __init__( allow_sampling, sampling_model, max_retries, - id, + id=id, ) @asynccontextmanager @@ -662,7 +664,7 @@ def __init__( allow_sampling, sampling_model, max_retries, - id, + id=id, ) @property diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/deferred.py b/pydantic_ai_slim/pydantic_ai/toolsets/deferred.py index a67c3b0ad..91091ea45 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/deferred.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/deferred.py @@ -22,7 +22,7 @@ class DeferredToolset(AbstractToolset[AgentDepsT]): tool_defs: list[ToolDefinition] _id: str | None = field(init=False, default=None) - def __init__(self, tool_defs: list[ToolDefinition], id: str | None = None): + def __init__(self, tool_defs: list[ToolDefinition], *, id: str | None = None): self._id = id self.tool_defs = tool_defs diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/function.py b/pydantic_ai_slim/pydantic_ai/toolsets/function.py index c206b22c5..1b94c57eb 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/function.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/function.py @@ -42,6 +42,7 @@ def __init__( self, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = [], max_retries: int = 1, + *, id: str | None = None, ): """Build a new function toolset. From 99d5664d7dcab3a6bb14ad7466770453f0fb1e11 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 5 Aug 2025 16:46:15 +0000 Subject: [PATCH 26/41] Fix for 3.9 --- pydantic_ai_slim/pydantic_ai/agent.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index e841f8ca3..87e58b6d5 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -11,7 +11,7 @@ from contextvars import ContextVar from copy import deepcopy from types import FrameType -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, cast, overload +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, Union, cast, overload from opentelemetry.trace import NoOpTracer, use_span from pydantic.json_schema import GenerateJsonSchema @@ -100,7 +100,7 @@ EventStreamHandler: TypeAlias = Callable[ [ RunContext[AgentDepsT], - AsyncIterable[_messages.AgentStreamEvent | _messages.HandleResponseEvent], + AsyncIterable[Union[_messages.AgentStreamEvent, _messages.HandleResponseEvent]], ], Awaitable[None], ] From 0507ebf0fd3de1fe618f1af926632744e1a478e9 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 5 Aug 2025 17:32:11 +0000 Subject: [PATCH 27/41] Move FinalResultEvent injection to StreamedResponse --- pydantic_ai_slim/pydantic_ai/direct.py | 8 +-- pydantic_ai_slim/pydantic_ai/ext/aci.py | 13 ++-- pydantic_ai_slim/pydantic_ai/ext/langchain.py | 2 + .../pydantic_ai/ext/temporal/_model.py | 26 ++------ .../pydantic_ai/models/__init__.py | 64 ++++++++++++++----- .../pydantic_ai/models/anthropic.py | 11 +++- .../pydantic_ai/models/bedrock.py | 6 +- .../pydantic_ai/models/function.py | 6 +- pydantic_ai_slim/pydantic_ai/models/gemini.py | 13 +++- pydantic_ai_slim/pydantic_ai/models/google.py | 7 +- pydantic_ai_slim/pydantic_ai/models/groq.py | 7 +- .../pydantic_ai/models/huggingface.py | 7 +- .../pydantic_ai/models/mistral.py | 12 ++-- pydantic_ai_slim/pydantic_ai/models/openai.py | 14 ++-- pydantic_ai_slim/pydantic_ai/models/test.py | 1 + pydantic_ai_slim/pydantic_ai/result.py | 53 ++++++--------- tests/models/test_instrumented.py | 2 +- 17 files changed, 143 insertions(+), 109 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/direct.py b/pydantic_ai_slim/pydantic_ai/direct.py index 6735c928e..60f1c059e 100644 --- a/pydantic_ai_slim/pydantic_ai/direct.py +++ b/pydantic_ai_slim/pydantic_ai/direct.py @@ -273,9 +273,7 @@ class StreamedResponseSync: """ _async_stream_cm: AbstractAsyncContextManager[StreamedResponse] - _queue: queue.Queue[messages.ModelResponseStreamEvent | Exception | None] = field( - default_factory=queue.Queue, init=False - ) + _queue: queue.Queue[messages.AgentStreamEvent | Exception | None] = field(default_factory=queue.Queue, init=False) _thread: threading.Thread | None = field(default=None, init=False) _stream_response: StreamedResponse | None = field(default=None, init=False) _exception: Exception | None = field(default=None, init=False) @@ -295,8 +293,8 @@ def __exit__( ) -> None: self._cleanup() - def __iter__(self) -> Iterator[messages.ModelResponseStreamEvent]: - """Stream the response as an iterable of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s.""" + def __iter__(self) -> Iterator[messages.AgentStreamEvent]: + """Stream the response as an iterable of [`AgentStreamEvent`][pydantic_ai.messages.AgentStreamEvent]s.""" self._check_context_manager_usage() while True: diff --git a/pydantic_ai_slim/pydantic_ai/ext/aci.py b/pydantic_ai_slim/pydantic_ai/ext/aci.py index f9db595de..24754deef 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/aci.py +++ b/pydantic_ai_slim/pydantic_ai/ext/aci.py @@ -1,17 +1,16 @@ -# Checking whether aci-sdk is installed -try: - from aci import ACI -except ImportError as _import_error: - raise ImportError('Please install `aci-sdk` to use ACI.dev tools') from _import_error +from __future__ import annotations from collections.abc import Sequence from typing import Any -from aci import ACI - from pydantic_ai.tools import Tool from pydantic_ai.toolsets.function import FunctionToolset +try: + from aci import ACI +except ImportError as _import_error: + raise ImportError('Please install `aci-sdk` to use ACI.dev tools') from _import_error + def _clean_schema(schema): if isinstance(schema, dict): diff --git a/pydantic_ai_slim/pydantic_ai/ext/langchain.py b/pydantic_ai_slim/pydantic_ai/ext/langchain.py index 88c4c8cc3..9390ab8cb 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/langchain.py +++ b/pydantic_ai_slim/pydantic_ai/ext/langchain.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Any, Protocol from pydantic.json_schema import JsonSchemaValue diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_model.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_model.py index 97f0a7041..e179f06fc 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_model.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_model.py @@ -36,7 +36,8 @@ class _RequestParams: class _TemporalStreamedResponse(StreamedResponse): - def __init__(self, response: ModelResponse): + def __init__(self, model_request_parameters: ModelRequestParameters, response: ModelResponse): + super().__init__(model_request_parameters) self.response = response async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: @@ -90,27 +91,8 @@ async def request_stream_activity(params: _RequestParams) -> ModelResponse: async with self.wrapped.request_stream( params.messages, params.model_settings, params.model_request_parameters, run_context ) as streamed_response: - # Keep in sync with `AgentStream.__aiter__` - async def aiter(): - # `AgentStream.__aiter__`, which this is based on, calls `_get_usage_checking_stream_response` here, - # but we don't have access to the `_usage_limits`. - - # TODO: Create new stream wrapper that does this - async for event in streamed_response: - yield event - if ( - final_result_event := params.model_request_parameters.get_final_result_event(event) - ) is not None: - yield final_result_event - break - - # If we broke out of the above loop, we need to yield the rest of the events - # If we didn't, this will just be a no-op - async for event in streamed_response: - yield event - assert event_stream_handler is not None - await event_stream_handler(run_context, aiter()) + await event_stream_handler(run_context, streamed_response) async for _ in streamed_response: pass @@ -166,4 +148,4 @@ async def request_stream( ), **self.activity_config, ) - yield _TemporalStreamedResponse(response) + yield _TemporalStreamedResponse(model_request_parameters, response) diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index f7c33eee3..a3a3994b4 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -24,6 +24,7 @@ from .._run_context import RunContext from ..exceptions import UserError from ..messages import ( + AgentStreamEvent, FileUrl, FinalResultEvent, ModelMessage, @@ -358,18 +359,6 @@ class ModelRequestParameters: def tool_defs(self) -> dict[str, ToolDefinition]: return {tool_def.name: tool_def for tool_def in [*self.function_tools, *self.output_tools]} - def get_final_result_event(self, e: ModelResponseStreamEvent) -> FinalResultEvent | None: - """Return an appropriate FinalResultEvent if `e` corresponds to a part that will produce a final result.""" - if isinstance(e, PartStartEvent): - new_part = e.part - if isinstance(new_part, TextPart) and self.allow_text_output: # pragma: no branch - return FinalResultEvent(tool_name=None, tool_call_id=None) - elif isinstance(new_part, ToolCallPart) and (tool_def := self.tool_defs.get(new_part.tool_name)): - if tool_def.kind == 'output': - return FinalResultEvent(tool_name=new_part.tool_name, tool_call_id=new_part.tool_call_id) - elif tool_def.kind == 'deferred': - return FinalResultEvent(tool_name=None, tool_call_id=None) - __repr__ = _utils.dataclasses_no_defaults_repr @@ -528,14 +517,40 @@ def _get_instructions(messages: list[ModelMessage]) -> str | None: class StreamedResponse(ABC): """Streamed response from an LLM when calling a tool.""" + model_request_parameters: ModelRequestParameters + _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False) - _event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False) + _event_iterator: AsyncIterator[AgentStreamEvent] | None = field(default=None, init=False) _usage: Usage = field(default_factory=Usage, init=False) + _final_result_event: FinalResultEvent | None = field(default=None, init=False) - def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent]: - """Stream the response as an async iterable of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s.""" + def __aiter__(self) -> AsyncIterator[AgentStreamEvent]: + """Stream the response as an async iterable of [`AgentStreamEvent`][pydantic_ai.messages.AgentStreamEvent]s. + + This proxies the `_event_iterator()` and emits all events, while also checking for matches + on the result schema and emitting a [`FinalResultEvent`][pydantic_ai.messages.FinalResultEvent] if/when the + first match is found. + """ if self._event_iterator is None: - self._event_iterator = self._get_event_iterator() + + async def iterator_with_final_event( + iterator: AsyncIterator[ModelResponseStreamEvent], + ) -> AsyncIterator[AgentStreamEvent]: + async for event in iterator: + yield event + if ( + final_result_event := _get_final_result_event(event, self.model_request_parameters) + ) is not None: + self._final_result_event = final_result_event + yield final_result_event + break + + # If we broke out of the above loop, we need to yield the rest of the events + # If we didn't, this will just be a no-op + async for event in iterator: + yield event + + self._event_iterator = iterator_with_final_event(self._get_event_iterator()) return self._event_iterator @abstractmethod @@ -560,6 +575,10 @@ def get(self) -> ModelResponse: usage=self.usage(), ) + def get_final_result_event(self) -> FinalResultEvent | None: + """Get the final result event for the response.""" + return self._final_result_event + def usage(self) -> Usage: """Get the usage of the response so far. This will not be the final usage until the stream is exhausted.""" return self._usage @@ -831,3 +850,16 @@ def _customize_output_object(transformer: type[JsonSchemaTransformer], o: Output schema_transformer = transformer(o.json_schema, strict=True) son_schema = schema_transformer.walk() return replace(o, json_schema=son_schema) + + +def _get_final_result_event(e: ModelResponseStreamEvent, params: ModelRequestParameters) -> FinalResultEvent | None: + """Return an appropriate FinalResultEvent if `e` corresponds to a part that will produce a final result.""" + if isinstance(e, PartStartEvent): + new_part = e.part + if isinstance(new_part, TextPart) and params.allow_text_output: # pragma: no branch + return FinalResultEvent(tool_name=None, tool_call_id=None) + elif isinstance(new_part, ToolCallPart) and (tool_def := params.tool_defs.get(new_part.tool_name)): + if tool_def.kind == 'output': + return FinalResultEvent(tool_name=new_part.tool_name, tool_call_id=new_part.tool_call_id) + elif tool_def.kind == 'deferred': + return FinalResultEvent(tool_name=None, tool_call_id=None) diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index e98806ebe..d8b03139e 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -179,7 +179,7 @@ async def request_stream( messages, True, cast(AnthropicModelSettings, model_settings or {}), model_request_parameters ) async with response: - yield await self._process_streamed_response(response) + yield await self._process_streamed_response(response, model_request_parameters) @property def model_name(self) -> AnthropicModelName: @@ -286,7 +286,9 @@ def _process_response(self, response: BetaMessage) -> ModelResponse: return ModelResponse(items, usage=_map_usage(response), model_name=response.model, vendor_id=response.id) - async def _process_streamed_response(self, response: AsyncStream[BetaRawMessageStreamEvent]) -> StreamedResponse: + async def _process_streamed_response( + self, response: AsyncStream[BetaRawMessageStreamEvent], model_request_parameters: ModelRequestParameters + ) -> StreamedResponse: peekable_response = _utils.PeekableAsyncStream(response) first_chunk = await peekable_response.peek() if isinstance(first_chunk, _utils.Unset): @@ -295,7 +297,10 @@ async def _process_streamed_response(self, response: AsyncStream[BetaRawMessageS # Since Anthropic doesn't provide a timestamp in the message, we'll use the current time timestamp = datetime.now(tz=timezone.utc) return AnthropicStreamedResponse( - _model_name=self._model_name, _response=peekable_response, _timestamp=timestamp + model_request_parameters=model_request_parameters, + _model_name=self._model_name, + _response=peekable_response, + _timestamp=timestamp, ) def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[BetaToolParam]: diff --git a/pydantic_ai_slim/pydantic_ai/models/bedrock.py b/pydantic_ai_slim/pydantic_ai/models/bedrock.py index 64ce31615..f1b09083e 100644 --- a/pydantic_ai_slim/pydantic_ai/models/bedrock.py +++ b/pydantic_ai_slim/pydantic_ai/models/bedrock.py @@ -266,7 +266,11 @@ async def request_stream( ) -> AsyncIterator[StreamedResponse]: settings = cast(BedrockModelSettings, model_settings or {}) response = await self._messages_create(messages, True, settings, model_request_parameters) - yield BedrockStreamedResponse(_model_name=self.model_name, _event_stream=response) + yield BedrockStreamedResponse( + model_request_parameters=model_request_parameters, + _model_name=self.model_name, + _event_stream=response, + ) async def _process_response(self, response: ConverseResponseTypeDef) -> ModelResponse: items: list[ModelResponsePart] = [] diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index f23feafb1..615f2ea8f 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -164,7 +164,11 @@ async def request_stream( if isinstance(first, _utils.Unset): raise ValueError('Stream function must return at least one item') - yield FunctionStreamedResponse(_model_name=self._model_name, _iter=response_stream) + yield FunctionStreamedResponse( + model_request_parameters=model_request_parameters, + _model_name=self._model_name, + _iter=response_stream, + ) @property def model_name(self) -> str: diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index bd0369bb3..16f2a0f2a 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -170,7 +170,7 @@ async def request_stream( async with self._make_request( messages, True, cast(GeminiModelSettings, model_settings or {}), model_request_parameters ) as http_response: - yield await self._process_streamed_response(http_response) + yield await self._process_streamed_response(http_response, model_request_parameters) @property def model_name(self) -> GeminiModelName: @@ -284,7 +284,9 @@ def _process_response(self, response: _GeminiResponse) -> ModelResponse: vendor_details=vendor_details, ) - async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse: + async def _process_streamed_response( + self, http_response: HTTPResponse, model_request_parameters: ModelRequestParameters + ) -> StreamedResponse: """Process a streamed response, and prepare a streaming response to return.""" aiter_bytes = http_response.aiter_bytes() start_response: _GeminiResponse | None = None @@ -305,7 +307,12 @@ async def _process_streamed_response(self, http_response: HTTPResponse) -> Strea if start_response is None: raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') - return GeminiStreamedResponse(_model_name=self._model_name, _content=content, _stream=aiter_bytes) + return GeminiStreamedResponse( + model_request_parameters=model_request_parameters, + _model_name=self._model_name, + _content=content, + _stream=aiter_bytes, + ) async def _message_to_gemini_content( self, messages: list[ModelMessage] diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index b4f0d6a2d..ccdc5fb00 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -193,7 +193,7 @@ async def request_stream( check_allow_model_requests() model_settings = cast(GoogleModelSettings, model_settings or {}) response = await self._generate_content(messages, True, model_settings, model_request_parameters) - yield await self._process_streamed_response(response) # type: ignore + yield await self._process_streamed_response(response, model_request_parameters) # type: ignore @property def model_name(self) -> GoogleModelName: @@ -321,7 +321,9 @@ def _process_response(self, response: GenerateContentResponse) -> ModelResponse: parts, response.model_version or self._model_name, usage, vendor_id=vendor_id, vendor_details=vendor_details ) - async def _process_streamed_response(self, response: AsyncIterator[GenerateContentResponse]) -> StreamedResponse: + async def _process_streamed_response( + self, response: AsyncIterator[GenerateContentResponse], model_request_parameters: ModelRequestParameters + ) -> StreamedResponse: """Process a streamed response, and prepare a streaming response to return.""" peekable_response = _utils.PeekableAsyncStream(response) first_chunk = await peekable_response.peek() @@ -329,6 +331,7 @@ async def _process_streamed_response(self, response: AsyncIterator[GenerateConte raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') # pragma: no cover return GeminiStreamedResponse( + model_request_parameters=model_request_parameters, _model_name=self._model_name, _response=peekable_response, _timestamp=first_chunk.create_time or _utils.now_utc(), diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index 2d7cdff4f..f7c857f18 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -173,7 +173,7 @@ async def request_stream( messages, True, cast(GroqModelSettings, model_settings or {}), model_request_parameters ) async with response: - yield await self._process_streamed_response(response) + yield await self._process_streamed_response(response, model_request_parameters) @property def model_name(self) -> GroqModelName: @@ -270,7 +270,9 @@ def _process_response(self, response: chat.ChatCompletion) -> ModelResponse: items, usage=_map_usage(response), model_name=response.model, timestamp=timestamp, vendor_id=response.id ) - async def _process_streamed_response(self, response: AsyncStream[chat.ChatCompletionChunk]) -> GroqStreamedResponse: + async def _process_streamed_response( + self, response: AsyncStream[chat.ChatCompletionChunk], model_request_parameters: ModelRequestParameters + ) -> GroqStreamedResponse: """Process a streamed response, and prepare a streaming response to return.""" peekable_response = _utils.PeekableAsyncStream(response) first_chunk = await peekable_response.peek() @@ -280,6 +282,7 @@ async def _process_streamed_response(self, response: AsyncStream[chat.ChatComple ) return GroqStreamedResponse( + model_request_parameters=model_request_parameters, _response=peekable_response, _model_name=self._model_name, _model_profile=self.profile, diff --git a/pydantic_ai_slim/pydantic_ai/models/huggingface.py b/pydantic_ai_slim/pydantic_ai/models/huggingface.py index 0de234cd5..6f1daeddc 100644 --- a/pydantic_ai_slim/pydantic_ai/models/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/models/huggingface.py @@ -158,7 +158,7 @@ async def request_stream( response = await self._completions_create( messages, True, cast(HuggingFaceModelSettings, model_settings or {}), model_request_parameters ) - yield await self._process_streamed_response(response) + yield await self._process_streamed_response(response, model_request_parameters) @property def model_name(self) -> HuggingFaceModelName: @@ -263,7 +263,9 @@ def _process_response(self, response: ChatCompletionOutput) -> ModelResponse: vendor_id=response.id, ) - async def _process_streamed_response(self, response: AsyncIterable[ChatCompletionStreamOutput]) -> StreamedResponse: + async def _process_streamed_response( + self, response: AsyncIterable[ChatCompletionStreamOutput], model_request_parameters: ModelRequestParameters + ) -> StreamedResponse: """Process a streamed response, and prepare a streaming response to return.""" peekable_response = _utils.PeekableAsyncStream(response) first_chunk = await peekable_response.peek() @@ -273,6 +275,7 @@ async def _process_streamed_response(self, response: AsyncIterable[ChatCompletio ) return HuggingFaceStreamedResponse( + model_request_parameters=model_request_parameters, _model_name=self._model_name, _model_profile=self.profile, _response=peekable_response, diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index 527baae74..d83f853a0 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -181,7 +181,7 @@ async def request_stream( messages, cast(MistralModelSettings, model_settings or {}), model_request_parameters ) async with response: - yield await self._process_streamed_response(model_request_parameters.output_tools, response) + yield await self._process_streamed_response(response, model_request_parameters) @property def model_name(self) -> MistralModelName: @@ -340,8 +340,8 @@ def _process_response(self, response: MistralChatCompletionResponse) -> ModelRes async def _process_streamed_response( self, - output_tools: list[ToolDefinition], response: MistralEventStreamAsync[MistralCompletionEvent], + model_request_parameters: ModelRequestParameters, ) -> StreamedResponse: """Process a streamed response, and prepare a streaming response to return.""" peekable_response = _utils.PeekableAsyncStream(response) @@ -357,10 +357,10 @@ async def _process_streamed_response( timestamp = _now_utc() return MistralStreamedResponse( + model_request_parameters=model_request_parameters, _response=peekable_response, _model_name=self._model_name, _timestamp=timestamp, - _output_tools={c.name: c for c in output_tools}, ) @staticmethod @@ -564,7 +564,6 @@ class MistralStreamedResponse(StreamedResponse): _model_name: MistralModelName _response: AsyncIterable[MistralCompletionEvent] _timestamp: datetime - _output_tools: dict[str, ToolDefinition] _delta_content: str = field(default='', init=False) @@ -583,10 +582,11 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: text = _map_content(content) if text: # Attempt to produce an output tool call from the received text - if self._output_tools: + output_tools = {c.name: c for c in self.model_request_parameters.output_tools} + if output_tools: self._delta_content += text # TODO: Port to native "manual JSON" mode - maybe_tool_call_part = self._try_get_output_tool_from_text(self._delta_content, self._output_tools) + maybe_tool_call_part = self._try_get_output_tool_from_text(self._delta_content, output_tools) if maybe_tool_call_part: yield self._parts_manager.handle_tool_call_part( vendor_part_id='output', diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index a2a9c9518..5b99ba8d8 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -261,7 +261,7 @@ async def request_stream( messages, True, cast(OpenAIModelSettings, model_settings or {}), model_request_parameters ) async with response: - yield await self._process_streamed_response(response) + yield await self._process_streamed_response(response, model_request_parameters) @property def model_name(self) -> OpenAIModelName: @@ -423,7 +423,9 @@ def _process_response(self, response: chat.ChatCompletion | str) -> ModelRespons vendor_id=response.id, ) - async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse: + async def _process_streamed_response( + self, response: AsyncStream[ChatCompletionChunk], model_request_parameters: ModelRequestParameters + ) -> OpenAIStreamedResponse: """Process a streamed response, and prepare a streaming response to return.""" peekable_response = _utils.PeekableAsyncStream(response) first_chunk = await peekable_response.peek() @@ -433,6 +435,7 @@ async def _process_streamed_response(self, response: AsyncStream[ChatCompletionC ) return OpenAIStreamedResponse( + model_request_parameters=model_request_parameters, _model_name=self._model_name, _model_profile=self.profile, _response=peekable_response, @@ -684,7 +687,7 @@ async def request_stream( messages, True, cast(OpenAIResponsesModelSettings, model_settings or {}), model_request_parameters ) async with response: - yield await self._process_streamed_response(response) + yield await self._process_streamed_response(response, model_request_parameters) def _process_response(self, response: responses.Response) -> ModelResponse: """Process a non-streamed response, and prepare a message to return.""" @@ -711,7 +714,9 @@ def _process_response(self, response: responses.Response) -> ModelResponse: ) async def _process_streamed_response( - self, response: AsyncStream[responses.ResponseStreamEvent] + self, + response: AsyncStream[responses.ResponseStreamEvent], + model_request_parameters: ModelRequestParameters, ) -> OpenAIResponsesStreamedResponse: """Process a streamed response, and prepare a streaming response to return.""" peekable_response = _utils.PeekableAsyncStream(response) @@ -721,6 +726,7 @@ async def _process_streamed_response( assert isinstance(first_chunk, responses.ResponseCreatedEvent) return OpenAIResponsesStreamedResponse( + model_request_parameters=model_request_parameters, _model_name=self._model_name, _response=peekable_response, _timestamp=number_to_datetime(first_chunk.response.created_at), diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py index 1348ec28f..473b00c56 100644 --- a/pydantic_ai_slim/pydantic_ai/models/test.py +++ b/pydantic_ai_slim/pydantic_ai/models/test.py @@ -125,6 +125,7 @@ async def request_stream( model_response = self._request(messages, model_settings, model_request_parameters) yield TestStreamedResponse( + model_request_parameters=model_request_parameters, _model_name=self._model_name, _structured_response=model_response, _messages=messages, diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 7e0bffcfd..ca2bef3be 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -1,7 +1,7 @@ from __future__ import annotations as _annotations import warnings -from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable +from collections.abc import AsyncIterator, Awaitable, Callable from copy import copy from dataclasses import dataclass, field from datetime import datetime @@ -23,7 +23,7 @@ ToolOutputSchema, ) from ._run_context import AgentDepsT, RunContext -from .messages import AgentStreamEvent, FinalResultEvent +from .messages import AgentStreamEvent from .output import ( OutputDataT, ToolOutput, @@ -53,7 +53,6 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]): _tool_manager: ToolManager[AgentDepsT] _agent_stream_iterator: AsyncIterator[AgentStreamEvent] | None = field(default=None, init=False) - _final_result_event: FinalResultEvent | None = field(default=None, init=False) _initial_run_ctx_usage: Usage = field(init=False) def __post_init__(self): @@ -62,12 +61,12 @@ def __post_init__(self): async def stream_output(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[OutputDataT]: """Asynchronously stream the (validated) agent outputs.""" async for response in self.stream_responses(debounce_by=debounce_by): - if self._final_result_event is not None: + if self._raw_stream_response.get_final_result_event() is not None: try: yield await self._validate_response(response, allow_partial=True) except ValidationError: pass - if self._final_result_event is not None: # pragma: no branch + if self._raw_stream_response.get_final_result_event() is not None: # pragma: no branch yield await self._validate_response(self._raw_stream_response.get()) async def stream_responses(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[_messages.ModelResponse]: @@ -133,10 +132,11 @@ async def get_output(self) -> OutputDataT: async def _validate_response(self, message: _messages.ModelResponse, *, allow_partial: bool = False) -> OutputDataT: """Validate a structured result message.""" - if self._final_result_event is None: + final_result_event = self._raw_stream_response.get_final_result_event() + if final_result_event is None: raise exceptions.UnexpectedModelBehavior('Invalid response, unable to find output') # pragma: no cover - output_tool_name = self._final_result_event.tool_name + output_tool_name = final_result_event.tool_name if isinstance(self._output_schema, ToolOutputSchema) and output_tool_name is not None: tool_call = next( @@ -223,32 +223,12 @@ async def _stream_text_deltas() -> AsyncIterator[str]: yield ''.join(deltas) def __aiter__(self) -> AsyncIterator[AgentStreamEvent]: - """Stream [`AgentStreamEvent`][pydantic_ai.messages.AgentStreamEvent]s. - - This proxies the _raw_stream_response and sends all events to the agent stream, while also checking for matches - on the result schema and emitting a [`FinalResultEvent`][pydantic_ai.messages.FinalResultEvent] if/when the - first match is found. - """ - if self._agent_stream_iterator is not None: - return self._agent_stream_iterator - - async def aiter(): - usage_checking_stream = _get_usage_checking_stream_response( + """Stream [`AgentStreamEvent`][pydantic_ai.messages.AgentStreamEvent]s.""" + if self._agent_stream_iterator is None: + self._agent_stream_iterator = _get_usage_checking_stream_response( self._raw_stream_response, self._usage_limits, self.usage ) - async for event in usage_checking_stream: - yield event - if (final_result_event := self._model_request_parameters.get_final_result_event(event)) is not None: - self._final_result_event = final_result_event - yield final_result_event - break - - # If we broke out of the above loop, we need to yield the rest of the events - # If we didn't, this will just be a no-op - async for event in usage_checking_stream: - yield event - - self._agent_stream_iterator = aiter() + return self._agent_stream_iterator @@ -497,10 +477,10 @@ def data(self) -> OutputDataT: def _get_usage_checking_stream_response( - stream_response: AsyncIterable[_messages.ModelResponseStreamEvent], + stream_response: models.StreamedResponse, limits: UsageLimits | None, get_usage: Callable[[], Usage], -) -> AsyncIterable[_messages.ModelResponseStreamEvent]: +) -> AsyncIterator[AgentStreamEvent]: if limits is not None and limits.has_token_limits(): async def _usage_checking_iterator(): @@ -510,7 +490,12 @@ async def _usage_checking_iterator(): return _usage_checking_iterator() else: - return stream_response + + async def _iterator(): + async for item in stream_response: + yield item + + return _iterator() def coalesce_deprecated_return_content( diff --git a/tests/models/test_instrumented.py b/tests/models/test_instrumented.py index 79e373a20..ae67704be 100644 --- a/tests/models/test_instrumented.py +++ b/tests/models/test_instrumented.py @@ -93,7 +93,7 @@ async def request_stream( model_request_parameters: ModelRequestParameters, run_context: RunContext | None = None, ) -> AsyncIterator[StreamedResponse]: - yield MyResponseStream() + yield MyResponseStream(model_request_parameters=model_request_parameters) class MyResponseStream(StreamedResponse): From 7cd6f15ebc8bd025ad081e4e4515eea9bb7829d7 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 5 Aug 2025 17:44:06 +0000 Subject: [PATCH 28/41] Fix direct test now that FinalResultEvent is included --- tests/test_direct.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_direct.py b/tests/test_direct.py index e9a131ea3..3e5767917 100644 --- a/tests/test_direct.py +++ b/tests/test_direct.py @@ -17,6 +17,7 @@ model_request_sync, ) from pydantic_ai.messages import ( + FinalResultEvent, ModelMessage, ModelRequest, ModelResponse, @@ -86,6 +87,7 @@ def test_model_request_stream_sync(): assert chunks == snapshot( [ PartStartEvent(index=0, part=TextPart(content='')), + FinalResultEvent(tool_name=None, tool_call_id=None), PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='success ')), PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='(no ')), PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='tool ')), @@ -104,6 +106,7 @@ async def test_model_request_stream(): assert chunks == snapshot( [ PartStartEvent(index=0, part=TextPart(content='')), + FinalResultEvent(tool_name=None, tool_call_id=None), PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='success ')), PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='(no ')), PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='tool ')), From b06173b8e539a7ff84c667bb26de1e002bc25890 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 5 Aug 2025 18:36:03 +0000 Subject: [PATCH 29/41] Add public AbstractAgent.toolset method --- pydantic_ai_slim/pydantic_ai/_output.py | 2 +- pydantic_ai_slim/pydantic_ai/ag_ui.py | 2 +- pydantic_ai_slim/pydantic_ai/agent.py | 54 ++++++++++------- .../pydantic_ai/ext/temporal/__init__.py | 1 + .../pydantic_ai/ext/temporal/_agent.py | 60 ++++++++++--------- .../pydantic_ai/ext/temporal/_logfire.py | 16 +++-- .../pydantic_ai/ext/temporal/_mcp_server.py | 9 +++ .../pydantic_ai/toolsets/combined.py | 4 +- temporal.py | 18 +++++- 9 files changed, 101 insertions(+), 65 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 71e7251fd..ffec1df02 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -979,7 +979,7 @@ def __init__( @property def id(self) -> str | None: - return 'output' + return '' async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]: return { diff --git a/pydantic_ai_slim/pydantic_ai/ag_ui.py b/pydantic_ai_slim/pydantic_ai/ag_ui.py index 8c0311e8b..572de1385 100644 --- a/pydantic_ai_slim/pydantic_ai/ag_ui.py +++ b/pydantic_ai_slim/pydantic_ai/ag_ui.py @@ -314,7 +314,7 @@ async def run_ag_ui( ) for tool in run_input.tools ], - id='ag_ui_frontend', + id='', ) toolsets = [*toolsets, toolset] if toolsets else [toolset] diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 87e58b6d5..d9e6e7480 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -132,6 +132,11 @@ def output_type(self) -> OutputSpec[OutputDataT]: def event_stream_handler(self) -> EventStreamHandler[AgentDepsT] | None: raise NotImplementedError + @property + @abstractmethod + def toolset(self) -> AbstractToolset[AgentDepsT]: + raise NotImplementedError + @overload async def run( self, @@ -239,6 +244,8 @@ async def main(): self.is_model_request_node(node) or self.is_call_tools_node(node) ): async with node.stream(agent_run.ctx) as stream: + # TODO: Use actual run context so we get retry counts etc + # Pass node? Or some indicaton of the agent if this is forwarded up from a nested agent? await self.event_stream_handler(_agent_graph.build_run_context(agent_run.ctx), stream) assert agent_run.result is not None, 'The graph run did not finish properly' @@ -962,6 +969,10 @@ def output_type(self) -> OutputSpec[OutputDataT]: def event_stream_handler(self) -> EventStreamHandler[AgentDepsT] | None: return self.wrapped.event_stream_handler + @property + def toolset(self) -> AbstractToolset[AgentDepsT]: + return self.wrapped.toolset + async def __aenter__(self) -> AbstractAgent[AgentDepsT, OutputDataT]: return await self.wrapped.__aenter__() @@ -1465,7 +1476,7 @@ def __init__( if self._output_toolset: self._output_toolset.max_retries = self._max_result_retries - self._function_toolset = FunctionToolset(tools, max_retries=retries, id='agent') + self._function_toolset = FunctionToolset(tools, max_retries=retries, id='') self._user_toolsets = toolsets or () self.history_processors = history_processors or [] @@ -1724,9 +1735,15 @@ async def main(): run_step=state.run_step, ) - toolset = self._get_toolset(output_toolset=output_toolset, additional_toolsets=toolsets) - # This will raise errors for any name conflicts + toolset = self._get_toolset(additional=toolsets) + + if output_toolset is not None: + if self._prepare_output_tools: + output_toolset = PreparedToolset(output_toolset, self._prepare_output_tools) + toolset = CombinedToolset([output_toolset, toolset]) + async with toolset: + # This will raise errors for any name conflicts tool_manager = await ToolManager[AgentDepsT].build(toolset, run_context) # Merge model settings in order of precedence: run > agent > model @@ -2324,48 +2341,39 @@ def _get_deps(self: Agent[T, OutputDataT], deps: T) -> T: return deps def _get_toolset( - self, - output_toolset: AbstractToolset[AgentDepsT] | None | _utils.Unset = _utils.UNSET, - additional_toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + self, additional: Sequence[AbstractToolset[AgentDepsT]] | None = None ) -> AbstractToolset[AgentDepsT]: - """Get the complete toolset. + """Get the combined toolset containing function tools registered directly to the agent and user-provided toolsets including MCP servers. Args: - output_toolset: The output toolset to use instead of the one built at agent construction time. - additional_toolsets: Additional toolsets to add. + additional: Additional toolsets to add. """ if some_tools := self._override_tools.get(): function_toolset = FunctionToolset( - some_tools.value, max_retries=self._function_toolset.max_retries, id='agent' + some_tools.value, max_retries=self._function_toolset.max_retries, id='' ) else: function_toolset = self._function_toolset if some_user_toolsets := self._override_toolsets.get(): user_toolsets = some_user_toolsets.value - elif additional_toolsets is not None: - user_toolsets = [*self._user_toolsets, *additional_toolsets] + elif additional is not None: + user_toolsets = [*self._user_toolsets, *additional] else: user_toolsets = self._user_toolsets - all_toolsets = [function_toolset, *user_toolsets] + toolset = CombinedToolset([function_toolset, *user_toolsets]) if self._prepare_tools: - all_toolsets = [PreparedToolset(CombinedToolset(all_toolsets), self._prepare_tools)] - - output_toolset = output_toolset if _utils.is_set(output_toolset) else self._output_toolset - if output_toolset is not None: - if self._prepare_output_tools: - output_toolset = PreparedToolset(output_toolset, self._prepare_output_tools) - all_toolsets = [output_toolset, *all_toolsets] + toolset = PreparedToolset(toolset, self._prepare_tools) - return CombinedToolset(all_toolsets) + return toolset @property def toolset(self) -> AbstractToolset[AgentDepsT]: - """The complete toolset that will be available to the model during an agent run. + """The complete toolset combining function tools and toolsets registered to the agent. - This will include function tools registered directly to the agent, output tools, and user-provided toolsets including MCP servers. + Output tools are not included. """ return self._get_toolset() diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/__init__.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/__init__.py index cd7ba738c..49a1da7b0 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/__init__.py @@ -48,6 +48,7 @@ def configure_worker(self, config: WorkerConfig) -> WorkerConfig: restrictions=runner.restrictions.with_passthrough_modules( 'pydantic_ai', 'logfire', + 'rich', # Imported inside `logfire._internal.json_encoder` when running `logfire.info` inside an activity with attributes to serialize 'attrs', # Imported inside `logfire._internal.json_schema` when running `logfire.info` inside an activity with attributes to serialize diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py index 83cebec30..9017ec6a7 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py @@ -15,7 +15,7 @@ usage as _usage, ) from pydantic_ai._run_context import AgentDepsT -from pydantic_ai.agent import AbstractAgent, Agent, AgentRun, RunOutputDataT, WrapperAgent +from pydantic_ai.agent import AbstractAgent, AgentRun, RunOutputDataT, WrapperAgent from pydantic_ai.exceptions import UserError from pydantic_ai.ext.temporal._run_context import TemporalRunContext from pydantic_ai.models import Model @@ -61,8 +61,6 @@ def __init__( """ super().__init__(wrapped) - # TODO: Make this work with any AbstractAgent - assert isinstance(wrapped, Agent) agent = wrapped activities: list[Callable[..., Any]] = [] @@ -90,20 +88,21 @@ def temporalize_toolset(toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset activities.extend(toolset.temporal_activities) return toolset - # TODO: Use public methods so others can replicate this - temporal_toolsets = [ - temporalize_toolset(toolset) - for toolset in [agent._function_toolset, *agent._user_toolsets] # pyright: ignore[reportPrivateUsage] - ] + temporal_toolset = agent.toolset.visit_and_replace(temporalize_toolset) self._model = temporal_model - self._toolsets = temporal_toolsets + self._toolset = temporal_toolset self._temporal_activities = activities @property def model(self) -> Model: return self._model + @property + def toolset(self) -> AbstractToolset[AgentDepsT]: + with self._temporal_overrides(): + return super().toolset + @property def temporal_activities(self) -> list[Callable[..., Any]]: return self._temporal_activities @@ -260,22 +259,26 @@ async def main(): 'Toolsets cannot be set at agent run time when using Temporal, it must be set at agent creation time.' ) - # We reset tools here as the temporalized function toolset is already in self._toolsets. - with super().override(model=self._model, toolsets=self._toolsets, tools=[]): + with self._temporal_overrides(): async with super().iter( user_prompt=user_prompt, output_type=output_type, message_history=message_history, - model=model, deps=deps, model_settings=model_settings, usage_limits=usage_limits, usage=usage, infer_name=infer_name, - toolsets=toolsets, + **_deprecated_kwargs, ) as result: yield result + @contextmanager + def _temporal_overrides(self) -> Iterator[None]: + # We reset tools here as the temporalized function toolset is already in self._toolset. + with super().override(model=self._model, toolsets=[self._toolset], tools=[]): + yield + @contextmanager def override( self, @@ -296,19 +299,22 @@ def override( toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run. tools: The tools to use instead of the tools registered with the agent. """ - if workflow.in_workflow(): - if _utils.is_set(model): - raise UserError( - 'Model cannot be contextually overridden when using Temporal, it must be set at agent creation time.' - ) - if _utils.is_set(toolsets): - raise UserError( - 'Toolsets cannot be contextually overridden when using Temporal, they must be set at agent creation time.' - ) - if _utils.is_set(tools): - raise UserError( - 'Tools cannot be contextually overridden when using Temporal, they must be set at agent creation time.' - ) + if not workflow.in_workflow(): + with super().override(deps=deps, model=model, toolsets=toolsets, tools=tools): + yield + + if _utils.is_set(model): + raise UserError( + 'Model cannot be contextually overridden when using Temporal, it must be set at agent creation time.' + ) + if _utils.is_set(toolsets): + raise UserError( + 'Toolsets cannot be contextually overridden when using Temporal, they must be set at agent creation time.' + ) + if _utils.is_set(tools): + raise UserError( + 'Tools cannot be contextually overridden when using Temporal, they must be set at agent creation time.' + ) - with super().override(deps=deps, model=model, toolsets=toolsets, tools=tools): + with super().override(deps=deps): yield diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_logfire.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_logfire.py index 7bb969bb7..c665c8d41 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_logfire.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_logfire.py @@ -2,6 +2,7 @@ from typing import Callable +from logfire import Logfire from opentelemetry.trace import get_tracer from temporalio.client import ClientConfig, Plugin as ClientPlugin from temporalio.contrib.opentelemetry import TracingInterceptor @@ -9,17 +10,18 @@ from temporalio.service import ConnectConfig, ServiceClient -def _default_setup_logfire(): +def _default_setup_logfire() -> Logfire: import logfire - logfire.configure(console=False) + instance = logfire.configure() logfire.instrument_pydantic_ai() + return instance class LogfirePlugin(ClientPlugin): """Temporal client plugin for Logfire.""" - def __init__(self, setup_logfire: Callable[[], None] = _default_setup_logfire): + def __init__(self, setup_logfire: Callable[[], Logfire] = _default_setup_logfire): self.setup_logfire = setup_logfire def configure_client(self, config: ClientConfig) -> ClientConfig: @@ -28,12 +30,8 @@ def configure_client(self, config: ClientConfig) -> ClientConfig: return super().configure_client(config) async def connect_service_client(self, config: ConnectConfig) -> ServiceClient: - self.setup_logfire() - - import logfire - - logfire_instance = logfire.configure() - logfire_config = logfire_instance.config + logfire = self.setup_logfire() + logfire_config = logfire.config token = logfire_config.token if token is not None: base_url = logfire_config.advanced.generate_base_url(token) diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py index 6b7242df7..2cc88e377 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py @@ -6,6 +6,7 @@ from pydantic import ConfigDict, with_config from temporalio import activity, workflow from temporalio.workflow import ActivityConfig +from typing_extensions import Self from pydantic_ai._run_context import RunContext from pydantic_ai.exceptions import UserError @@ -79,6 +80,14 @@ def wrapped_server(self) -> MCPServer: def temporal_activities(self) -> list[Callable[..., Any]]: return [self.get_tools_activity, self.call_tool_activity] + async def __aenter__(self) -> Self: + # The wrapped MCPServer enters itself around listing and calling tools + # so we don't need to enter it here (nor could we because we're not inside a Temporal activity). + return self + + async def __aexit__(self, *args: Any) -> bool | None: + return None + async def get_tools(self, ctx: RunContext[Any]) -> dict[str, ToolsetTool[Any]]: if not workflow.in_workflow(): return await super().get_tools(ctx) diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/combined.py b/pydantic_ai_slim/pydantic_ai/toolsets/combined.py index d52f435d5..f02c32b3c 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/combined.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/combined.py @@ -3,7 +3,7 @@ import asyncio from collections.abc import Sequence from contextlib import AsyncExitStack -from dataclasses import dataclass, field +from dataclasses import dataclass, field, replace from typing import Any, Callable from typing_extensions import Self @@ -99,4 +99,4 @@ def apply(self, visitor: Callable[[AbstractToolset[AgentDepsT]], None]) -> None: def visit_and_replace( self, visitor: Callable[[AbstractToolset[AgentDepsT]], AbstractToolset[AgentDepsT]] ) -> AbstractToolset[AgentDepsT]: - return CombinedToolset(toolsets=[visitor(toolset) for toolset in self.toolsets]) + return replace(self, toolsets=[visitor(toolset) for toolset in self.toolsets]) diff --git a/temporal.py b/temporal.py index acb6d476c..5d3533a10 100644 --- a/temporal.py +++ b/temporal.py @@ -1,6 +1,7 @@ import asyncio import random from collections.abc import AsyncIterable +from dataclasses import dataclass from datetime import timedelta import logfire @@ -41,9 +42,21 @@ def get_weather(city: str) -> str: return 'sunny' +@dataclass +class Answer: + label: str + answer: str + + +@dataclass +class Response: + answers: list[Answer] + + agent = Agent( 'openai:gpt-4o', deps_type=Deps, + output_type=Response, toolsets=[ FunctionToolset[Deps](tools=[get_weather], id='toolset'), MCPServerStdio('python', ['-m', 'tests.mcp_server'], timeout=20, id='mcp'), @@ -76,7 +89,7 @@ def get_weather(city: str) -> str: @workflow.defn class MyAgentWorkflow: @workflow.run - async def run(self, prompt: str, deps: Deps) -> str: + async def run(self, prompt: str, deps: Deps) -> Response: result = await temporal_agent.run(prompt, deps=deps) return result.output @@ -85,9 +98,10 @@ async def run(self, prompt: str, deps: Deps) -> str: def setup_logfire(): - logfire.configure(console=False) + instance = logfire.configure() logfire.instrument_pydantic_ai() logfire.instrument_httpx(capture_all=True) + return instance async def main(): From fe0a75e7ee8b5792b74623e610f6a7d4d245da93 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 5 Aug 2025 18:38:52 +0000 Subject: [PATCH 30/41] Update tests --- pydantic_ai_slim/pydantic_ai/direct.py | 2 ++ tests/models/test_instrumented.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/pydantic_ai_slim/pydantic_ai/direct.py b/pydantic_ai_slim/pydantic_ai/direct.py index 60f1c059e..daba62933 100644 --- a/pydantic_ai_slim/pydantic_ai/direct.py +++ b/pydantic_ai_slim/pydantic_ai/direct.py @@ -167,6 +167,7 @@ async def main(): ''' [ PartStartEvent(index=0, part=TextPart(content='Albert Einstein was ')), + FinalResultEvent(tool_name=None, tool_call_id=None), PartDeltaEvent( index=0, delta=TextPartDelta(content_delta='a German-born theoretical ') ), @@ -223,6 +224,7 @@ def model_request_stream_sync( ''' [ PartStartEvent(index=0, part=TextPart(content='Albert Einstein was ')), + FinalResultEvent(tool_name=None, tool_call_id=None), PartDeltaEvent( index=0, delta=TextPartDelta(content_delta='a German-born theoretical ') ), diff --git a/tests/models/test_instrumented.py b/tests/models/test_instrumented.py index ae67704be..6c6e0c7ad 100644 --- a/tests/models/test_instrumented.py +++ b/tests/models/test_instrumented.py @@ -16,6 +16,7 @@ AudioUrl, BinaryContent, DocumentUrl, + FinalResultEvent, ImageUrl, ModelMessage, ModelRequest, @@ -360,6 +361,7 @@ async def test_instrumented_model_stream(capfire: CaptureLogfire): assert [event async for event in response_stream] == snapshot( [ PartStartEvent(index=0, part=TextPart(content='text1')), + FinalResultEvent(tool_name=None, tool_call_id=None), PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='text2')), ] ) From fc83519eb7a1ff8b29abf70d6a968bf2a4d10fae Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 5 Aug 2025 19:45:08 +0000 Subject: [PATCH 31/41] Add event_stream_handler to run, run_sync and run_stream --- pydantic_ai_slim/pydantic_ai/agent.py | 76 +++- .../pydantic_ai/ext/temporal/_agent.py | 358 ++++++++++++++++-- .../pydantic_ai/ext/temporal/_model.py | 15 +- temporal.py | 9 +- tests/test_streaming.py | 144 ++++++- 5 files changed, 546 insertions(+), 56 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index d9e6e7480..c2bbb4811 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -99,11 +99,13 @@ EventStreamHandler: TypeAlias = Callable[ [ + 'AbstractAgent[AgentDepsT, OutputDataT]', RunContext[AgentDepsT], AsyncIterable[Union[_messages.AgentStreamEvent, _messages.HandleResponseEvent]], ], Awaitable[None], ] +"""TODO: Docstring""" class AbstractAgent(Generic[AgentDepsT, OutputDataT], ABC): @@ -129,7 +131,7 @@ def output_type(self) -> OutputSpec[OutputDataT]: @property @abstractmethod - def event_stream_handler(self) -> EventStreamHandler[AgentDepsT] | None: + def event_stream_handler(self) -> EventStreamHandler[AgentDepsT, OutputDataT] | None: raise NotImplementedError @property @@ -151,6 +153,7 @@ async def run( usage: _usage.Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, OutputDataT] | None = None, ) -> AgentRunResult[OutputDataT]: ... @overload @@ -167,6 +170,7 @@ async def run( usage: _usage.Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, RunOutputDataT] | None = None, ) -> AgentRunResult[RunOutputDataT]: ... async def run( @@ -182,6 +186,7 @@ async def run( usage: _usage.Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, Any] | None = None, **_deprecated_kwargs: Never, ) -> AgentRunResult[Any]: """Run the agent with a user prompt in async mode. @@ -213,6 +218,7 @@ async def main(): usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. toolsets: Optional additional toolsets for this run. + event_stream_handler: Optional event stream handler to use for this run. Returns: The result of the run. @@ -228,6 +234,8 @@ async def main(): _utils.validate_empty_kwargs(_deprecated_kwargs) + event_stream_handler = event_stream_handler or self.event_stream_handler + async with self.iter( user_prompt=user_prompt, output_type=output_type, @@ -240,13 +248,11 @@ async def main(): toolsets=toolsets, ) as agent_run: async for node in agent_run: - if self.event_stream_handler is not None and ( + if event_stream_handler is not None and ( self.is_model_request_node(node) or self.is_call_tools_node(node) ): async with node.stream(agent_run.ctx) as stream: - # TODO: Use actual run context so we get retry counts etc - # Pass node? Or some indicaton of the agent if this is forwarded up from a nested agent? - await self.event_stream_handler(_agent_graph.build_run_context(agent_run.ctx), stream) + await event_stream_handler(self, _agent_graph.build_run_context(agent_run.ctx), stream) assert agent_run.result is not None, 'The graph run did not finish properly' return agent_run.result @@ -256,6 +262,7 @@ def run_sync( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, + output_type: None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -264,6 +271,7 @@ def run_sync( usage: _usage.Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, OutputDataT] | None = None, ) -> AgentRunResult[OutputDataT]: ... @overload @@ -271,7 +279,7 @@ def run_sync( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT] | None = None, + output_type: OutputSpec[RunOutputDataT], message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -280,6 +288,7 @@ def run_sync( usage: _usage.Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, RunOutputDataT] | None = None, ) -> AgentRunResult[RunOutputDataT]: ... def run_sync( @@ -295,6 +304,7 @@ def run_sync( usage: _usage.Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, Any] | None = None, **_deprecated_kwargs: Never, ) -> AgentRunResult[Any]: """Synchronously run the agent with a user prompt. @@ -325,6 +335,7 @@ def run_sync( usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. toolsets: Optional additional toolsets for this run. + event_stream_handler: Optional event stream handler to use for this run. Returns: The result of the run. @@ -352,6 +363,7 @@ def run_sync( usage=usage, infer_name=False, toolsets=toolsets, + event_stream_handler=event_stream_handler, ) ) @@ -360,20 +372,22 @@ def run_stream( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, + output_type: None = None, message_history: list[_messages.ModelMessage] | None = None, - model: models.Model | models.KnownModelName | None = None, + model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, OutputDataT] | None = None, ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, OutputDataT]]: ... @overload def run_stream( self, - user_prompt: str | Sequence[_messages.UserContent], + user_prompt: str | Sequence[_messages.UserContent] | None = None, *, output_type: OutputSpec[RunOutputDataT], message_history: list[_messages.ModelMessage] | None = None, @@ -384,6 +398,7 @@ def run_stream( usage: _usage.Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, RunOutputDataT] | None = None, ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunOutputDataT]]: ... @asynccontextmanager @@ -400,6 +415,7 @@ async def run_stream( # noqa C901 usage: _usage.Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, Any] | None = None, **_deprecated_kwargs: Never, ) -> AsyncIterator[result.StreamedRunResult[AgentDepsT, Any]]: """Run the agent with a user prompt in async mode, returning a streamed response. @@ -428,12 +444,11 @@ async def main(): usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. toolsets: Optional additional toolsets for this run. + event_stream_handler: Optional event stream handler to use for this run. It will receive all the events up until the final result is found, which you can then read or stream from inside the context manager. Returns: The result of the run. """ - # TODO: We need to deprecate this now that we have the `iter` method. - # Before that, though, we should add an event for when we reach the final result of the stream. if infer_name and self.name is None: # f_back because `asynccontextmanager` adds one frame if frame := inspect.currentframe(): # pragma: no branch @@ -447,6 +462,8 @@ async def main(): _utils.validate_empty_kwargs(_deprecated_kwargs) + event_stream_handler = event_stream_handler or self.event_stream_handler + yielded = False async with self.iter( user_prompt, @@ -467,15 +484,28 @@ async def main(): if self.is_model_request_node(node): graph_ctx = agent_run.ctx async with node.stream(graph_ctx) as stream: + final_result_event = None - async def stream_to_final(s: AgentStream) -> FinalResult[AgentStream] | None: + async def stream_to_final(stream: AgentStream) -> AsyncIterator[_messages.AgentStreamEvent]: + nonlocal final_result_event async for event in stream: + yield event if isinstance(event, _messages.FinalResultEvent): - return FinalResult(s, event.tool_name, event.tool_call_id) - return None + final_result_event = event + break + + if event_stream_handler is not None: + await event_stream_handler( + self, _agent_graph.build_run_context(graph_ctx), stream_to_final(stream) + ) + else: + async for _ in stream_to_final(stream): + pass - final_result = await stream_to_final(stream) - if final_result is not None: + if final_result_event is not None: + final_result = FinalResult( + stream, final_result_event.tool_name, final_result_event.tool_call_id + ) if yielded: raise exceptions.AgentRunError('Agent run produced final results') # pragma: no cover yielded = True @@ -513,6 +543,10 @@ async def on_complete() -> None: on_complete, ) break + elif self.is_call_tools_node(node) and event_stream_handler is not None: + async with node.stream(agent_run.ctx) as stream: + await event_stream_handler(self, _agent_graph.build_run_context(agent_run.ctx), stream) + next_node = await agent_run.next(node) if not isinstance(next_node, _agent_graph.AgentNode): raise exceptions.AgentRunError( # pragma: no cover @@ -966,7 +1000,7 @@ def output_type(self) -> OutputSpec[OutputDataT]: return self.wrapped.output_type @property - def event_stream_handler(self) -> EventStreamHandler[AgentDepsT] | None: + def event_stream_handler(self) -> EventStreamHandler[AgentDepsT, OutputDataT] | None: return self.wrapped.event_stream_handler @property @@ -1219,7 +1253,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]): _prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False) _max_result_retries: int = dataclasses.field(repr=False) - _event_stream_handler: EventStreamHandler[AgentDepsT] | None = dataclasses.field(repr=False) + _event_stream_handler: EventStreamHandler[AgentDepsT, OutputDataT] | None = dataclasses.field(repr=False) _enter_lock: Lock = dataclasses.field(repr=False) _entered_count: int = dataclasses.field(repr=False) @@ -1249,7 +1283,7 @@ def __init__( end_strategy: EndStrategy = 'early', instrument: InstrumentationSettings | bool | None = None, history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, OutputDataT] | None = None, ) -> None: ... @overload @@ -1310,7 +1344,7 @@ def __init__( end_strategy: EndStrategy = 'early', instrument: InstrumentationSettings | bool | None = None, history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, OutputDataT] | None = None, ) -> None: ... def __init__( @@ -1337,7 +1371,7 @@ def __init__( end_strategy: EndStrategy = 'early', instrument: InstrumentationSettings | bool | None = None, history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, OutputDataT] | None = None, **_deprecated_kwargs: Any, ): """Create an agent. @@ -1522,7 +1556,7 @@ def output_type(self) -> OutputSpec[OutputDataT]: return self._output_type @property - def event_stream_handler(self) -> EventStreamHandler[AgentDepsT] | None: + def event_stream_handler(self) -> EventStreamHandler[AgentDepsT, OutputDataT] | None: return self._event_stream_handler def __repr__(self) -> str: # pragma: no cover diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py index 9017ec6a7..e05c7c67d 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py @@ -15,11 +15,12 @@ usage as _usage, ) from pydantic_ai._run_context import AgentDepsT -from pydantic_ai.agent import AbstractAgent, AgentRun, RunOutputDataT, WrapperAgent +from pydantic_ai.agent import AbstractAgent, AgentRun, AgentRunResult, EventStreamHandler, RunOutputDataT, WrapperAgent from pydantic_ai.exceptions import UserError from pydantic_ai.ext.temporal._run_context import TemporalRunContext from pydantic_ai.models import Model from pydantic_ai.output import OutputDataT, OutputSpec +from pydantic_ai.result import StreamedRunResult from pydantic_ai.settings import ModelSettings from pydantic_ai.tools import ( Tool, @@ -69,7 +70,7 @@ def __init__( 'Model cannot be set at agent run time when using Temporal, it must be set at agent creation time.' ) - temporal_model = TemporalModel(agent.model, activity_config, agent.event_stream_handler, run_context_type) + temporal_model = TemporalModel(agent.model, agent, activity_config, run_context_type) activities.extend(temporal_model.temporal_activities) def temporalize_toolset(toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset[AgentDepsT]: @@ -107,6 +108,322 @@ def toolset(self) -> AbstractToolset[AgentDepsT]: def temporal_activities(self) -> list[Callable[..., Any]]: return self._temporal_activities + @contextmanager + def _temporal_overrides(self) -> Iterator[None]: + # We reset tools here as the temporalized function toolset is already in self._toolset. + with super().override(model=self._model, toolsets=[self._toolset], tools=[]): + yield + + @overload + async def run( + self, + user_prompt: str | Sequence[_messages.UserContent] | None = None, + *, + output_type: None = None, + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, OutputDataT] | None = None, + ) -> AgentRunResult[OutputDataT]: ... + + @overload + async def run( + self, + user_prompt: str | Sequence[_messages.UserContent] | None = None, + *, + output_type: OutputSpec[RunOutputDataT], + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, RunOutputDataT] | None = None, + ) -> AgentRunResult[RunOutputDataT]: ... + + async def run( + self, + user_prompt: str | Sequence[_messages.UserContent] | None = None, + *, + output_type: OutputSpec[RunOutputDataT] | None = None, + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, Any] | None = None, + **_deprecated_kwargs: Never, + ) -> AgentRunResult[Any]: + """Run the agent with a user prompt in async mode. + + This method builds an internal agent graph (using system prompts, tools and result schemas) and then + runs the graph to completion. The result of the run is returned. + + Example: + ```python + from pydantic_ai import Agent + + agent = Agent('openai:gpt-4o') + + async def main(): + agent_run = await agent.run('What is the capital of France?') + print(agent_run.output) + #> Paris + ``` + + Args: + user_prompt: User input to start/continue the conversation. + output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no + output validators since output validators would expect an argument that matches the agent's output type. + message_history: History of the conversation so far. + model: Optional model to use for this run, required if `model` was not set when creating the agent. + deps: Optional dependencies to use for this run. + model_settings: Optional settings to use for this model's request. + usage_limits: Optional limits on model request count or token usage. + usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. + infer_name: Whether to try to infer the agent name from the call frame if it's not set. + toolsets: Optional additional toolsets for this run. + event_stream_handler: Optional event stream handler to use for this run. + + Returns: + The result of the run. + """ + if workflow.in_workflow() and event_stream_handler is not None: + raise UserError( + 'Event stream handler cannot be set at agent run time when using Temporal, it must be set at agent creation time.' + ) + + return await super().run( + user_prompt, + output_type=output_type, + message_history=message_history, + model=model, + deps=deps, + model_settings=model_settings, + usage_limits=usage_limits, + usage=usage, + infer_name=infer_name, + toolsets=toolsets, + event_stream_handler=event_stream_handler, + **_deprecated_kwargs, + ) + + @overload + def run_sync( + self, + user_prompt: str | Sequence[_messages.UserContent] | None = None, + *, + output_type: None = None, + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, OutputDataT] | None = None, + ) -> AgentRunResult[OutputDataT]: ... + + @overload + def run_sync( + self, + user_prompt: str | Sequence[_messages.UserContent] | None = None, + *, + output_type: OutputSpec[RunOutputDataT], + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, RunOutputDataT] | None = None, + ) -> AgentRunResult[RunOutputDataT]: ... + + def run_sync( + self, + user_prompt: str | Sequence[_messages.UserContent] | None = None, + *, + output_type: OutputSpec[RunOutputDataT] | None = None, + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, Any] | None = None, + **_deprecated_kwargs: Never, + ) -> AgentRunResult[Any]: + """Synchronously run the agent with a user prompt. + + This is a convenience method that wraps [`self.run`][pydantic_ai.Agent.run] with `loop.run_until_complete(...)`. + You therefore can't use this method inside async code or if there's an active event loop. + + Example: + ```python + from pydantic_ai import Agent + + agent = Agent('openai:gpt-4o') + + result_sync = agent.run_sync('What is the capital of Italy?') + print(result_sync.output) + #> Rome + ``` + + Args: + user_prompt: User input to start/continue the conversation. + output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no + output validators since output validators would expect an argument that matches the agent's output type. + message_history: History of the conversation so far. + model: Optional model to use for this run, required if `model` was not set when creating the agent. + deps: Optional dependencies to use for this run. + model_settings: Optional settings to use for this model's request. + usage_limits: Optional limits on model request count or token usage. + usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. + infer_name: Whether to try to infer the agent name from the call frame if it's not set. + toolsets: Optional additional toolsets for this run. + event_stream_handler: Optional event stream handler to use for this run. + + Returns: + The result of the run. + """ + if workflow.in_workflow() and event_stream_handler is not None: + raise UserError( + 'Event stream handler cannot be set at agent run time when using Temporal, it must be set at agent creation time.' + ) + + return super().run_sync( + user_prompt, + output_type=output_type, + message_history=message_history, + model=model, + deps=deps, + model_settings=model_settings, + usage_limits=usage_limits, + usage=usage, + infer_name=infer_name, + toolsets=toolsets, + event_stream_handler=event_stream_handler, + **_deprecated_kwargs, + ) + + @overload + def run_stream( + self, + user_prompt: str | Sequence[_messages.UserContent] | None = None, + *, + output_type: None = None, + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, OutputDataT] | None = None, + ) -> AbstractAsyncContextManager[StreamedRunResult[AgentDepsT, OutputDataT]]: ... + + @overload + def run_stream( + self, + user_prompt: str | Sequence[_messages.UserContent] | None = None, + *, + output_type: OutputSpec[RunOutputDataT], + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, RunOutputDataT] | None = None, + ) -> AbstractAsyncContextManager[StreamedRunResult[AgentDepsT, RunOutputDataT]]: ... + + @asynccontextmanager + async def run_stream( + self, + user_prompt: str | Sequence[_messages.UserContent] | None = None, + *, + output_type: OutputSpec[RunOutputDataT] | None = None, + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, Any] | None = None, + **_deprecated_kwargs: Never, + ) -> AsyncIterator[StreamedRunResult[AgentDepsT, Any]]: + """Run the agent with a user prompt in async mode, returning a streamed response. + + Example: + ```python + from pydantic_ai import Agent + + agent = Agent('openai:gpt-4o') + + async def main(): + async with agent.run_stream('What is the capital of the UK?') as response: + print(await response.get_output()) + #> London + ``` + + Args: + user_prompt: User input to start/continue the conversation. + output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no + output validators since output validators would expect an argument that matches the agent's output type. + message_history: History of the conversation so far. + model: Optional model to use for this run, required if `model` was not set when creating the agent. + deps: Optional dependencies to use for this run. + model_settings: Optional settings to use for this model's request. + usage_limits: Optional limits on model request count or token usage. + usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. + infer_name: Whether to try to infer the agent name from the call frame if it's not set. + toolsets: Optional additional toolsets for this run. + event_stream_handler: Optional event stream handler to use for this run. It will receive all the events up until the final result is found, which you can then read or stream from inside the context manager. + + Returns: + The result of the run. + """ + if workflow.in_workflow() and event_stream_handler is not None: + raise UserError( + 'Event stream handler cannot be set at agent run time when using Temporal, it must be set at agent creation time.' + ) + + async with super().run_stream( + user_prompt, + output_type=output_type, + message_history=message_history, + model=model, + deps=deps, + model_settings=model_settings, + usage_limits=usage_limits, + usage=usage, + infer_name=infer_name, + toolsets=toolsets, + event_stream_handler=event_stream_handler, + **_deprecated_kwargs, + ) as result: + yield result + @overload def iter( self, @@ -273,12 +590,6 @@ async def main(): ) as result: yield result - @contextmanager - def _temporal_overrides(self) -> Iterator[None]: - # We reset tools here as the temporalized function toolset is already in self._toolset. - with super().override(model=self._model, toolsets=[self._toolset], tools=[]): - yield - @contextmanager def override( self, @@ -299,22 +610,19 @@ def override( toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run. tools: The tools to use instead of the tools registered with the agent. """ - if not workflow.in_workflow(): - with super().override(deps=deps, model=model, toolsets=toolsets, tools=tools): - yield - - if _utils.is_set(model): - raise UserError( - 'Model cannot be contextually overridden when using Temporal, it must be set at agent creation time.' - ) - if _utils.is_set(toolsets): - raise UserError( - 'Toolsets cannot be contextually overridden when using Temporal, they must be set at agent creation time.' - ) - if _utils.is_set(tools): - raise UserError( - 'Tools cannot be contextually overridden when using Temporal, they must be set at agent creation time.' - ) + if workflow.in_workflow(): + if _utils.is_set(model): + raise UserError( + 'Model cannot be contextually overridden when using Temporal, it must be set at agent creation time.' + ) + if _utils.is_set(toolsets): + raise UserError( + 'Toolsets cannot be contextually overridden when using Temporal, they must be set at agent creation time.' + ) + if _utils.is_set(tools): + raise UserError( + 'Tools cannot be contextually overridden when using Temporal, they must be set at agent creation time.' + ) - with super().override(deps=deps): + with super().override(deps=deps, model=model, toolsets=toolsets, tools=tools): yield diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_model.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_model.py index e179f06fc..3d6d84f4f 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_model.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_model.py @@ -11,7 +11,7 @@ from temporalio.workflow import ActivityConfig from pydantic_ai._run_context import RunContext -from pydantic_ai.agent import EventStreamHandler +from pydantic_ai.agent import AbstractAgent from pydantic_ai.exceptions import UserError from pydantic_ai.messages import ( ModelMessage, @@ -68,13 +68,13 @@ class TemporalModel(WrapperModel): def __init__( self, model: Model, + agent: AbstractAgent[Any, Any], activity_config: ActivityConfig = {}, - event_stream_handler: EventStreamHandler[Any] | None = None, run_context_type: type[TemporalRunContext] = TemporalRunContext, ): super().__init__(model) self.activity_config = activity_config - self.event_stream_handler = event_stream_handler + self.agent = agent self.run_context_type = run_context_type id = '_'.join([model.system, model.model_name]) @@ -91,8 +91,11 @@ async def request_stream_activity(params: _RequestParams) -> ModelResponse: async with self.wrapped.request_stream( params.messages, params.model_settings, params.model_request_parameters, run_context ) as streamed_response: - assert event_stream_handler is not None - await event_stream_handler(run_context, streamed_response) + event_stream_handler = self.agent.event_stream_handler + if event_stream_handler is None: + raise UserError('Streaming with Temporal requires `Agent` to have an `event_stream_handler` set.') + + await event_stream_handler(self.agent, run_context, streamed_response) async for _ in streamed_response: pass @@ -132,8 +135,6 @@ async def request_stream( model_request_parameters: ModelRequestParameters, run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: - if self.event_stream_handler is None: - raise UserError('Streaming with Temporal requires `Agent` to have an `event_stream_handler`') if run_context is None: raise UserError('Streaming with Temporal requires `request_stream` to be called with a `run_context`') diff --git a/temporal.py b/temporal.py index 5d3533a10..0ff6593be 100644 --- a/temporal.py +++ b/temporal.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import random from collections.abc import AsyncIterable @@ -12,6 +14,7 @@ from typing_extensions import TypedDict from pydantic_ai import Agent, RunContext +from pydantic_ai.agent import AbstractAgent from pydantic_ai.ext.temporal import ( AgentPlugin, LogfirePlugin, @@ -28,7 +31,11 @@ class Deps(TypedDict): country: str -async def event_stream_handler(ctx: RunContext[Deps], stream: AsyncIterable[AgentStreamEvent | HandleResponseEvent]): +async def event_stream_handler( + agent: AbstractAgent[Deps, Response], + ctx: RunContext[Deps], + stream: AsyncIterable[AgentStreamEvent | HandleResponseEvent], +): logfire.info(f'{ctx.run_step=}') async for event in stream: logfire.info(f'{event=}') diff --git a/tests/test_streaming.py b/tests/test_streaming.py index e8861a0e0..711eb990b 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -3,7 +3,7 @@ import datetime import json import re -from collections.abc import AsyncIterator +from collections.abc import AsyncIterable, AsyncIterator from copy import deepcopy from dataclasses import replace from datetime import timezone @@ -14,17 +14,21 @@ from pydantic import BaseModel from pydantic_ai import Agent, RunContext, UnexpectedModelBehavior, UserError, capture_run_messages -from pydantic_ai.agent import AgentRun +from pydantic_ai.agent import AbstractAgent, AgentRun from pydantic_ai.messages import ( + AgentStreamEvent, FinalResultEvent, FunctionToolCallEvent, FunctionToolResultEvent, + HandleResponseEvent, ModelMessage, ModelRequest, ModelResponse, + PartDeltaEvent, PartStartEvent, RetryPromptPart, TextPart, + TextPartDelta, ToolCallPart, ToolReturnPart, UserPromptPart, @@ -1249,3 +1253,139 @@ def my_tool(x: int) -> int: FunctionToolCallEvent(part=ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())), ] ) + + +async def test_run_event_stream_handler(): + m = TestModel() + + test_agent = Agent(m) + assert test_agent.name is None + + @test_agent.tool_plain + async def ret_a(x: str) -> str: + return f'{x}-apple' + + events: list[AgentStreamEvent | HandleResponseEvent] = [] + + async def event_stream_handler( + agent: AbstractAgent[None, str], + ctx: RunContext[None], + stream: AsyncIterable[AgentStreamEvent | HandleResponseEvent], + ): + async for event in stream: + events.append(event) + + result = await test_agent.run('Hello', event_stream_handler=event_stream_handler) + assert result.output == snapshot('{"ret_a":"a-apple"}') + assert events == snapshot( + [ + PartStartEvent( + index=0, + part=ToolCallPart(tool_name='ret_a', args={'x': 'a'}, tool_call_id=IsStr()), + ), + FunctionToolCallEvent(part=ToolCallPart(tool_name='ret_a', args={'x': 'a'}, tool_call_id=IsStr())), + FunctionToolResultEvent( + result=ToolReturnPart( + tool_name='ret_a', + content='a-apple', + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + ) + ), + PartStartEvent(index=0, part=TextPart(content='')), + FinalResultEvent(tool_name=None, tool_call_id=None), + PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='{"ret_a":')), + PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='"a-apple"}')), + ] + ) + + +def test_run_sync_event_stream_handler(): + m = TestModel() + + test_agent = Agent(m) + assert test_agent.name is None + + @test_agent.tool_plain + async def ret_a(x: str) -> str: + return f'{x}-apple' + + events: list[AgentStreamEvent | HandleResponseEvent] = [] + + async def event_stream_handler( + agent: AbstractAgent[None, str], + ctx: RunContext[None], + stream: AsyncIterable[AgentStreamEvent | HandleResponseEvent], + ): + async for event in stream: + events.append(event) + + result = test_agent.run_sync('Hello', event_stream_handler=event_stream_handler) + assert result.output == snapshot('{"ret_a":"a-apple"}') + assert events == snapshot( + [ + PartStartEvent( + index=0, + part=ToolCallPart(tool_name='ret_a', args={'x': 'a'}, tool_call_id=IsStr()), + ), + FunctionToolCallEvent(part=ToolCallPart(tool_name='ret_a', args={'x': 'a'}, tool_call_id=IsStr())), + FunctionToolResultEvent( + result=ToolReturnPart( + tool_name='ret_a', + content='a-apple', + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + ) + ), + PartStartEvent(index=0, part=TextPart(content='')), + FinalResultEvent(tool_name=None, tool_call_id=None), + PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='{"ret_a":')), + PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='"a-apple"}')), + ] + ) + + +async def test_run_stream_event_stream_handler(): + m = TestModel() + + test_agent = Agent(m) + assert test_agent.name is None + + @test_agent.tool_plain + async def ret_a(x: str) -> str: + return f'{x}-apple' + + events: list[AgentStreamEvent | HandleResponseEvent] = [] + + async def event_stream_handler( + agent: AbstractAgent[None, str], + ctx: RunContext[None], + stream: AsyncIterable[AgentStreamEvent | HandleResponseEvent], + ): + async for event in stream: + events.append(event) + + async with test_agent.run_stream('Hello', event_stream_handler=event_stream_handler) as result: + assert [c async for c in result.stream(debounce_by=None)] == snapshot( + ['{"ret_a":', '{"ret_a":"a-apple"}', '{"ret_a":"a-apple"}'] + ) + + assert events == snapshot( + [ + PartStartEvent( + index=0, + part=ToolCallPart(tool_name='ret_a', args={'x': 'a'}, tool_call_id=IsStr()), + ), + FunctionToolCallEvent(part=ToolCallPart(tool_name='ret_a', args={'x': 'a'}, tool_call_id=IsStr())), + FunctionToolResultEvent( + result=ToolReturnPart( + tool_name='ret_a', + content='a-apple', + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + ) + ), + PartStartEvent(index=0, part=TextPart(content='')), + FinalResultEvent(tool_name=None, tool_call_id=None), + ] + ) From 653b84f47a9659e7d817c3d0490683c3a354a732 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 5 Aug 2025 20:12:39 +0000 Subject: [PATCH 32/41] Improve tool name conflict error messages --- pydantic_ai_slim/pydantic_ai/_output.py | 4 +++ pydantic_ai_slim/pydantic_ai/ag_ui.py | 35 +++++++++++++------ pydantic_ai_slim/pydantic_ai/agent.py | 30 +++++++++++----- .../pydantic_ai/ext/temporal/_agent.py | 2 +- .../pydantic_ai/toolsets/combined.py | 8 +++-- tests/test_mcp.py | 2 +- tests/test_tools.py | 4 +-- 7 files changed, 60 insertions(+), 25 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index ffec1df02..c467bc7b4 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -981,6 +981,10 @@ def __init__( def id(self) -> str | None: return '' + @property + def label(self) -> str: + return "the agent's output tools" + async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]: return { tool_def.name: ToolsetTool( diff --git a/pydantic_ai_slim/pydantic_ai/ag_ui.py b/pydantic_ai_slim/pydantic_ai/ag_ui.py index 572de1385..7f7f8ec68 100644 --- a/pydantic_ai_slim/pydantic_ai/ag_ui.py +++ b/pydantic_ai_slim/pydantic_ai/ag_ui.py @@ -72,6 +72,7 @@ ThinkingTextMessageContentEvent, ThinkingTextMessageEndEvent, ThinkingTextMessageStartEvent, + Tool as AGUITool, ToolCallArgsEvent, ToolCallEndEvent, ToolCallResultEvent, @@ -305,17 +306,7 @@ async def run_ag_ui( # AG-UI tools can't be prefixed as that would result in a mismatch between the tool names in the # Pydantic AI events and actual AG-UI tool names, preventing the tool from being called. If any # conflicts arise, the AG-UI tool should be renamed or a `PrefixedToolset` used for local toolsets. - toolset = DeferredToolset[AgentDepsT]( - [ - ToolDefinition( - name=tool.name, - description=tool.description, - parameters_json_schema=tool.parameters, - ) - for tool in run_input.tools - ], - id='', - ) + toolset = _AGUIFrontendToolset[AgentDepsT](run_input.tools) toolsets = [*toolsets, toolset] if toolsets else [toolset] try: @@ -688,3 +679,25 @@ def __init__(self, tool_call_id: str) -> None: message=f'Tool call with ID {tool_call_id} not found in the history.', code='tool_call_not_found', ) + + +class _AGUIFrontendToolset(DeferredToolset[AgentDepsT]): + def __init__(self, tools: list[AGUITool]): + super().__init__( + [ + ToolDefinition( + name=tool.name, + description=tool.description, + parameters_json_schema=tool.parameters, + ) + for tool in tools + ] + ) + + @property + def id(self) -> str: + return '' + + @property + def label(self) -> str: + return 'the AG-UI frontend tools' diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index c2bbb4811..a31cf4dab 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -698,7 +698,7 @@ def override( toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET, ) -> Iterator[None]: - """Context manager to temporarily override agent dependencies, model, or toolsets. + """Context manager to temporarily override agent dependencies, model, toolsets, or tools. This is particularly useful when testing. You can find an example of this [here](../testing.md#overriding-model-via-pytest-fixtures). @@ -1167,7 +1167,7 @@ def override( toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET, ) -> Iterator[None]: - """Context manager to temporarily override agent dependencies, model, or toolsets. + """Context manager to temporarily override agent dependencies, model, toolsets, or tools. This is particularly useful when testing. You can find an example of this [here](../testing.md#overriding-model-via-pytest-fixtures). @@ -1252,6 +1252,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]): _prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False) _prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False) _max_result_retries: int = dataclasses.field(repr=False) + _max_tool_retries: int = dataclasses.field(repr=False) _event_stream_handler: EventStreamHandler[AgentDepsT, OutputDataT] | None = dataclasses.field(repr=False) @@ -1503,6 +1504,7 @@ def __init__( self._system_prompt_dynamic_functions = {} self._max_result_retries = output_retries if output_retries is not None else retries + self._max_tool_retries = retries self._prepare_tools = prepare_tools self._prepare_output_tools = prepare_output_tools @@ -1510,7 +1512,7 @@ def __init__( if self._output_toolset: self._output_toolset.max_retries = self._max_result_retries - self._function_toolset = FunctionToolset(tools, max_retries=retries, id='') + self._function_toolset = _AgentFunctionToolset(tools, max_retries=self._max_tool_retries) self._user_toolsets = toolsets or () self.history_processors = history_processors or [] @@ -1891,7 +1893,7 @@ def override( toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET, ) -> Iterator[None]: - """Context manager to temporarily override agent dependencies, model, or toolsets. + """Context manager to temporarily override agent dependencies, model, toolsets, or tools. This is particularly useful when testing. You can find an example of this [here](../testing.md#overriding-model-via-pytest-fixtures). @@ -2383,9 +2385,7 @@ def _get_toolset( additional: Additional toolsets to add. """ if some_tools := self._override_tools.get(): - function_toolset = FunctionToolset( - some_tools.value, max_retries=self._function_toolset.max_retries, id='' - ) + function_toolset = _AgentFunctionToolset(some_tools.value, max_retries=self._max_tool_retries) else: function_toolset = self._function_toolset @@ -2396,7 +2396,10 @@ def _get_toolset( else: user_toolsets = self._user_toolsets - toolset = CombinedToolset([function_toolset, *user_toolsets]) + if user_toolsets: + toolset = CombinedToolset([function_toolset, *user_toolsets]) + else: + toolset = function_toolset if self._prepare_tools: toolset = PreparedToolset(toolset, self._prepare_tools) @@ -2883,3 +2886,14 @@ def new_messages_json( def usage(self) -> _usage.Usage: """Return the usage of the whole run.""" return self._state.usage + + +@dataclasses.dataclass(init=False) +class _AgentFunctionToolset(FunctionToolset[AgentDepsT]): + @property + def id(self) -> str: + return '' + + @property + def label(self) -> str: + return 'the agent' diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py index e05c7c67d..2ac043310 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py @@ -599,7 +599,7 @@ def override( toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET, ) -> Iterator[None]: - """Context manager to temporarily override agent dependencies, model, or toolsets. + """Context manager to temporarily override agent dependencies, model, toolsets, or tools. This is particularly useful when testing. You can find an example of this [here](../testing.md#overriding-model-via-pytest-fixtures). diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/combined.py b/pydantic_ai_slim/pydantic_ai/toolsets/combined.py index f02c32b3c..dfbe917a4 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/combined.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/combined.py @@ -46,7 +46,10 @@ def id(self) -> str | None: @property def label(self) -> str: - return f'{self.__class__.__name__}({", ".join(toolset.label for toolset in self.toolsets)})' + if len(self.toolsets) == 1: + return self.toolsets[0].label + else: + return f'{self.__class__.__name__}({", ".join(toolset.label for toolset in self.toolsets)})' async def __aenter__(self) -> Self: async with self._enter_lock: @@ -72,8 +75,9 @@ async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[ for toolset, tools in zip(self.toolsets, toolsets_tools): for name, tool in tools.items(): if existing_tools := all_tools.get(name): + capitalized_toolset_label = toolset.label[0].upper() + toolset.label[1:] raise UserError( - f'{toolset.label} defines a tool whose name conflicts with existing tool from {existing_tools.toolset.label}: {name!r}. {toolset.tool_name_conflict_hint}' + f'{capitalized_toolset_label} defines a tool whose name conflicts with existing tool from {existing_tools.toolset.label}: {name!r}. {toolset.tool_name_conflict_hint}' ) all_tools[name] = _CombinedToolsetTool( diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 9f4e5e5f0..178d0b362 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -261,7 +261,7 @@ def get_none() -> None: # pragma: no cover with pytest.raises( UserError, match=re.escape( - "MCPServerStdio(command='python', args=['-m', 'tests.mcp_server']) defines a tool whose name conflicts with existing tool from FunctionToolset 'agent': 'get_none'. Set the `tool_prefix` attribute to avoid name conflicts." + "MCPServerStdio(command='python', args=['-m', 'tests.mcp_server']) defines a tool whose name conflicts with existing tool from the agent: 'get_none'. Set the `tool_prefix` attribute to avoid name conflicts." ), ): await agent.run('Get me a conflict') diff --git a/tests/test_tools.py b/tests/test_tools.py index 94336c307..c88ce659c 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -618,7 +618,7 @@ def test_tool_return_conflict(): with pytest.raises( UserError, match=re.escape( - "FunctionToolset 'agent' defines a tool whose name conflicts with existing tool from OutputToolset 'output': 'ctx_tool'. Rename the tool or wrap the toolset in a `PrefixedToolset` to avoid name conflicts." + "The agent defines a tool whose name conflicts with existing tool from the agent's output tools: 'ctx_tool'. Rename the tool or wrap the toolset in a `PrefixedToolset` to avoid name conflicts." ), ): Agent('test', tools=[ctx_tool], deps_type=int, output_type=ToolOutput(int, name='ctx_tool')).run_sync( @@ -630,7 +630,7 @@ def test_tool_name_conflict_hint(): with pytest.raises( UserError, match=re.escape( - "PrefixedToolset(FunctionToolset 'tool') defines a tool whose name conflicts with existing tool from FunctionToolset 'agent': 'foo_tool'. Change the `prefix` attribute to avoid name conflicts." + "PrefixedToolset(FunctionToolset 'tool') defines a tool whose name conflicts with existing tool from the agent: 'foo_tool'. Change the `prefix` attribute to avoid name conflicts." ), ): From 1e6312150ea290246d334719430ea860c214e649 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 5 Aug 2025 22:41:02 +0000 Subject: [PATCH 33/41] Add Temporal test --- .../pydantic_ai/ext/temporal/__init__.py | 1 + temporal.py | 139 --- .../test_temporal/test_temporal.yaml | 929 ++++++++++++++++++ tests/test_temporal.py | 289 ++++++ 4 files changed, 1219 insertions(+), 139 deletions(-) delete mode 100644 temporal.py create mode 100644 tests/cassettes/test_temporal/test_temporal.yaml create mode 100644 tests/test_temporal.py diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/__init__.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/__init__.py index 49a1da7b0..abe9a63ed 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/__init__.py @@ -49,6 +49,7 @@ def configure_worker(self, config: WorkerConfig) -> WorkerConfig: 'pydantic_ai', 'logfire', 'rich', + 'httpx', # Imported inside `logfire._internal.json_encoder` when running `logfire.info` inside an activity with attributes to serialize 'attrs', # Imported inside `logfire._internal.json_schema` when running `logfire.info` inside an activity with attributes to serialize diff --git a/temporal.py b/temporal.py deleted file mode 100644 index 0ff6593be..000000000 --- a/temporal.py +++ /dev/null @@ -1,139 +0,0 @@ -from __future__ import annotations - -import asyncio -import random -from collections.abc import AsyncIterable -from dataclasses import dataclass -from datetime import timedelta - -import logfire -from temporalio import workflow -from temporalio.client import Client -from temporalio.worker import Worker -from temporalio.workflow import ActivityConfig -from typing_extensions import TypedDict - -from pydantic_ai import Agent, RunContext -from pydantic_ai.agent import AbstractAgent -from pydantic_ai.ext.temporal import ( - AgentPlugin, - LogfirePlugin, - PydanticAIPlugin, - TemporalAgent, - TemporalRunContextWithDeps, -) -from pydantic_ai.mcp import MCPServerStdio -from pydantic_ai.messages import AgentStreamEvent, HandleResponseEvent -from pydantic_ai.toolsets.function import FunctionToolset - - -class Deps(TypedDict): - country: str - - -async def event_stream_handler( - agent: AbstractAgent[Deps, Response], - ctx: RunContext[Deps], - stream: AsyncIterable[AgentStreamEvent | HandleResponseEvent], -): - logfire.info(f'{ctx.run_step=}') - async for event in stream: - logfire.info(f'{event=}') - - -async def get_country(ctx: RunContext[Deps]) -> str: - return ctx.deps['country'] - - -def get_weather(city: str) -> str: - return 'sunny' - - -@dataclass -class Answer: - label: str - answer: str - - -@dataclass -class Response: - answers: list[Answer] - - -agent = Agent( - 'openai:gpt-4o', - deps_type=Deps, - output_type=Response, - toolsets=[ - FunctionToolset[Deps](tools=[get_weather], id='toolset'), - MCPServerStdio('python', ['-m', 'tests.mcp_server'], timeout=20, id='mcp'), - ], - tools=[get_country], - event_stream_handler=event_stream_handler, -) - -# This needs to be called in the same scope where the `agent` is bound to the workflow, -# as it modifies the `agent` object in place to swap out methods that use IO for ones that use Temporal activities. -temporal_agent = TemporalAgent( - agent, - activity_config=ActivityConfig(start_to_close_timeout=timedelta(seconds=60)), - toolset_activity_config={ - 'country': ActivityConfig(start_to_close_timeout=timedelta(seconds=120)), - }, - tool_activity_config={ - 'toolset': { - 'get_country': False, - 'get_weather': ActivityConfig(start_to_close_timeout=timedelta(seconds=180)), - }, - }, - run_context_type=TemporalRunContextWithDeps, -) - -with workflow.unsafe.imports_passed_through(): - import pandas # noqa: F401 - - -@workflow.defn -class MyAgentWorkflow: - @workflow.run - async def run(self, prompt: str, deps: Deps) -> Response: - result = await temporal_agent.run(prompt, deps=deps) - return result.output - - -TASK_QUEUE = 'pydantic-ai-agent-task-queue' - - -def setup_logfire(): - instance = logfire.configure() - logfire.instrument_pydantic_ai() - logfire.instrument_httpx(capture_all=True) - return instance - - -async def main(): - client = await Client.connect( - 'localhost:7233', - plugins=[PydanticAIPlugin(), LogfirePlugin(setup_logfire)], - ) - - async with Worker( - client, - task_queue=TASK_QUEUE, - workflows=[MyAgentWorkflow], - plugins=[AgentPlugin(temporal_agent)], - ): - output = await client.execute_workflow( # pyright: ignore[reportUnknownMemberType] - MyAgentWorkflow.run, - args=[ - 'what is the capital of the country? what is the weather there? what is the product name?', - Deps(country='Mexico'), - ], - id=f'my-agent-workflow-id-{random.random()}', - task_queue=TASK_QUEUE, - ) - print(output) - - -if __name__ == '__main__': - asyncio.run(main()) diff --git a/tests/cassettes/test_temporal/test_temporal.yaml b/tests/cassettes/test_temporal/test_temporal.yaml new file mode 100644 index 000000000..cf5092580 --- /dev/null +++ b/tests/cassettes/test_temporal/test_temporal.yaml @@ -0,0 +1,929 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '4294' + content-type: + - application/json + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: 'Tell me: the capital of the country; the weather there; the product name' + role: user + model: gpt-4o + stream: true + stream_options: + include_usage: true + tool_choice: required + tools: + - function: + description: '' + name: get_weather + parameters: + additionalProperties: false + properties: + city: + type: string + required: + - city + type: object + strict: true + type: function + - function: + description: '' + name: get_country + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: "Convert Celsius to Fahrenheit.\n\n Args:\n celsius: Temperature in Celsius\n\n Returns:\n + \ Temperature in Fahrenheit\n " + name: celsius_to_fahrenheit + parameters: + additionalProperties: false + properties: + celsius: + type: number + required: + - celsius + type: object + strict: true + type: function + - function: + description: "Get the weather forecast for a location.\n\n Args:\n location: The location to get the weather + forecast for.\n\n Returns:\n The weather forecast for the location.\n " + name: get_weather_forecast + parameters: + additionalProperties: false + properties: + location: + type: string + required: + - location + type: object + strict: true + type: function + - function: + description: '' + name: get_image_resource + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: '' + name: get_image_resource_link + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: '' + name: get_audio_resource + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: '' + name: get_audio_resource_link + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: '' + name: get_product_name + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: '' + name: get_product_name_link + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: '' + name: get_image + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: '' + name: get_dict + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: '' + name: get_error + parameters: + additionalProperties: false + properties: + value: + default: false + type: boolean + type: object + type: function + - function: + description: '' + name: get_none + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: '' + name: get_multiple_items + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: "Get the current log level.\n\n Returns:\n The current log level.\n " + name: get_log_level + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: "Echo the run context.\n\n Args:\n ctx: Context object containing request and session information.\n\n + \ Returns:\n Dictionary with an echo message and the deps.\n " + name: echo_deps + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: Use sampling callback. + name: use_sampling + parameters: + additionalProperties: false + properties: + foo: + type: string + required: + - foo + type: object + strict: true + type: function + - function: + description: The final response which ends this conversation + name: final_result + parameters: + $defs: + Answer: + additionalProperties: false + properties: + answer: + type: string + label: + type: string + required: + - label + - answer + type: object + additionalProperties: false + properties: + answers: + items: + $ref: '#/$defs/Answer' + type: array + required: + - answers + type: object + strict: true + type: function + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: |+ + data: {"id":"chatcmpl-C1KMEUDb1vVwsROQUCZTgG6A6vtWo","object":"chat.completion.chunk","created":1754432618,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","usage":null,"choices":[{"index":0,"delta":{"role":"assistant","content":null},"logprobs":null,"finish_reason":null}],"obfuscation":"jP9abrn9XF3"} + + data: {"id":"chatcmpl-C1KMEUDb1vVwsROQUCZTgG6A6vtWo","object":"chat.completion.chunk","created":1754432618,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","usage":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"id":"call_3rqTYrA6H21AYUaRGP4F66oq","type":"function","function":{"name":"get_country","arguments":""}}]},"logprobs":null,"finish_reason":null}],"obfuscation":"8RXlE4Z5NT"} + + data: {"id":"chatcmpl-C1KMEUDb1vVwsROQUCZTgG6A6vtWo","object":"chat.completion.chunk","created":1754432618,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","usage":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{}"}}]},"logprobs":null,"finish_reason":null}],"obfuscation":"CNtv"} + + data: {"id":"chatcmpl-C1KMEUDb1vVwsROQUCZTgG6A6vtWo","object":"chat.completion.chunk","created":1754432618,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","usage":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"id":"call_Xw9XMKBJU48kAAd78WgIswDx","type":"function","function":{"name":"get_product_name","arguments":""}}]},"logprobs":null,"finish_reason":null}],"obfuscation":"TomlO"} + + data: {"id":"chatcmpl-C1KMEUDb1vVwsROQUCZTgG6A6vtWo","object":"chat.completion.chunk","created":1754432618,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","usage":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"{}"}}]},"logprobs":null,"finish_reason":null}],"obfuscation":"4Gko"} + + data: {"id":"chatcmpl-C1KMEUDb1vVwsROQUCZTgG6A6vtWo","object":"chat.completion.chunk","created":1754432618,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}],"usage":null,"obfuscation":"GtnJ"} + + data: {"id":"chatcmpl-C1KMEUDb1vVwsROQUCZTgG6A6vtWo","object":"chat.completion.chunk","created":1754432618,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[],"usage":{"prompt_tokens":364,"completion_tokens":40,"total_tokens":404,"prompt_tokens_details":{"cached_tokens":0,"audio_tokens":0},"completion_tokens_details":{"reasoning_tokens":0,"audio_tokens":0,"accepted_prediction_tokens":0,"rejected_prediction_tokens":0}},"obfuscation":"Ohcj4NkgRRFLL"} + + data: [DONE] + + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-type: + - text/event-stream; charset=utf-8 + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '756' + openai-project: + - proj_dKobscVY9YJxeEaDJen54e3d + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '4720' + content-type: + - application/json + cookie: + - __cf_bm=glnXI8zJVAJ0K_QdEGsOKKRgWNfzACGXNWXiwREzaLg-1754432619-1.0.1.1-_Ef07EWA.dA.ieqsA1ZV5wshb3Z4zXVgZ6bbQCLkVEXqzEUQ4cPSApZhDGjVWQMg9aEywh0CfTkaZwvW0rjDh_nKfoZ5Cc8fMVgN5gAyMNc; + _cfuvid=Ckz8lfebgV0n8QtvIIYuQIvlcwwiwc67I0Aw.L8t4rM-1754432619011-0.0.1.1-604800000 + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: 'Tell me: the capital of the country; the weather there; the product name' + role: user + - role: assistant + tool_calls: + - function: + arguments: '{}' + name: get_country + id: call_3rqTYrA6H21AYUaRGP4F66oq + type: function + - function: + arguments: '{}' + name: get_product_name + id: call_Xw9XMKBJU48kAAd78WgIswDx + type: function + - content: Mexico + role: tool + tool_call_id: call_3rqTYrA6H21AYUaRGP4F66oq + - content: Pydantic AI + role: tool + tool_call_id: call_Xw9XMKBJU48kAAd78WgIswDx + model: gpt-4o + stream: true + stream_options: + include_usage: true + tool_choice: required + tools: + - function: + description: '' + name: get_weather + parameters: + additionalProperties: false + properties: + city: + type: string + required: + - city + type: object + strict: true + type: function + - function: + description: '' + name: get_country + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: "Convert Celsius to Fahrenheit.\n\n Args:\n celsius: Temperature in Celsius\n\n Returns:\n + \ Temperature in Fahrenheit\n " + name: celsius_to_fahrenheit + parameters: + additionalProperties: false + properties: + celsius: + type: number + required: + - celsius + type: object + strict: true + type: function + - function: + description: "Get the weather forecast for a location.\n\n Args:\n location: The location to get the weather + forecast for.\n\n Returns:\n The weather forecast for the location.\n " + name: get_weather_forecast + parameters: + additionalProperties: false + properties: + location: + type: string + required: + - location + type: object + strict: true + type: function + - function: + description: '' + name: get_image_resource + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: '' + name: get_image_resource_link + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: '' + name: get_audio_resource + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: '' + name: get_audio_resource_link + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: '' + name: get_product_name + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: '' + name: get_product_name_link + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: '' + name: get_image + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: '' + name: get_dict + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: '' + name: get_error + parameters: + additionalProperties: false + properties: + value: + default: false + type: boolean + type: object + type: function + - function: + description: '' + name: get_none + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: '' + name: get_multiple_items + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: "Get the current log level.\n\n Returns:\n The current log level.\n " + name: get_log_level + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: "Echo the run context.\n\n Args:\n ctx: Context object containing request and session information.\n\n + \ Returns:\n Dictionary with an echo message and the deps.\n " + name: echo_deps + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: Use sampling callback. + name: use_sampling + parameters: + additionalProperties: false + properties: + foo: + type: string + required: + - foo + type: object + strict: true + type: function + - function: + description: The final response which ends this conversation + name: final_result + parameters: + $defs: + Answer: + additionalProperties: false + properties: + answer: + type: string + label: + type: string + required: + - label + - answer + type: object + additionalProperties: false + properties: + answers: + items: + $ref: '#/$defs/Answer' + type: array + required: + - answers + type: object + strict: true + type: function + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: |+ + data: {"id":"chatcmpl-C1KMJC4uUHgeJ4A0e8jM8wufrmdxX","object":"chat.completion.chunk","created":1754432623,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_Vz0Sie91Ap56nH0ThKGrZXT7","type":"function","function":{"name":"get_weather","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"pJt0pVt5b"} + + data: {"id":"chatcmpl-C1KMJC4uUHgeJ4A0e8jM8wufrmdxX","object":"chat.completion.chunk","created":1754432623,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"u17"} + + data: {"id":"chatcmpl-C1KMJC4uUHgeJ4A0e8jM8wufrmdxX","object":"chat.completion.chunk","created":1754432623,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"city"}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"W5"} + + data: {"id":"chatcmpl-C1KMJC4uUHgeJ4A0e8jM8wufrmdxX","object":"chat.completion.chunk","created":1754432623,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"R"} + + data: {"id":"chatcmpl-C1KMJC4uUHgeJ4A0e8jM8wufrmdxX","object":"chat.completion.chunk","created":1754432623,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Mexico"}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":""} + + data: {"id":"chatcmpl-C1KMJC4uUHgeJ4A0e8jM8wufrmdxX","object":"chat.completion.chunk","created":1754432623,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" City"}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"d"} + + data: {"id":"chatcmpl-C1KMJC4uUHgeJ4A0e8jM8wufrmdxX","object":"chat.completion.chunk","created":1754432623,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"y7y"} + + data: {"id":"chatcmpl-C1KMJC4uUHgeJ4A0e8jM8wufrmdxX","object":"chat.completion.chunk","created":1754432623,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}],"usage":null,"obfuscation":"Wj82"} + + data: {"id":"chatcmpl-C1KMJC4uUHgeJ4A0e8jM8wufrmdxX","object":"chat.completion.chunk","created":1754432623,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[],"usage":{"prompt_tokens":423,"completion_tokens":15,"total_tokens":438,"prompt_tokens_details":{"cached_tokens":0,"audio_tokens":0},"completion_tokens_details":{"reasoning_tokens":0,"audio_tokens":0,"accepted_prediction_tokens":0,"rejected_prediction_tokens":0}},"obfuscation":"Umu4rjZrtKjmq"} + + data: [DONE] + + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-type: + - text/event-stream; charset=utf-8 + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '535' + openai-project: + - proj_dKobscVY9YJxeEaDJen54e3d + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '4969' + content-type: + - application/json + cookie: + - __cf_bm=glnXI8zJVAJ0K_QdEGsOKKRgWNfzACGXNWXiwREzaLg-1754432619-1.0.1.1-_Ef07EWA.dA.ieqsA1ZV5wshb3Z4zXVgZ6bbQCLkVEXqzEUQ4cPSApZhDGjVWQMg9aEywh0CfTkaZwvW0rjDh_nKfoZ5Cc8fMVgN5gAyMNc; + _cfuvid=Ckz8lfebgV0n8QtvIIYuQIvlcwwiwc67I0Aw.L8t4rM-1754432619011-0.0.1.1-604800000 + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: 'Tell me: the capital of the country; the weather there; the product name' + role: user + - role: assistant + tool_calls: + - function: + arguments: '{}' + name: get_country + id: call_3rqTYrA6H21AYUaRGP4F66oq + type: function + - function: + arguments: '{}' + name: get_product_name + id: call_Xw9XMKBJU48kAAd78WgIswDx + type: function + - content: Mexico + role: tool + tool_call_id: call_3rqTYrA6H21AYUaRGP4F66oq + - content: Pydantic AI + role: tool + tool_call_id: call_Xw9XMKBJU48kAAd78WgIswDx + - role: assistant + tool_calls: + - function: + arguments: '{"city":"Mexico City"}' + name: get_weather + id: call_Vz0Sie91Ap56nH0ThKGrZXT7 + type: function + - content: sunny + role: tool + tool_call_id: call_Vz0Sie91Ap56nH0ThKGrZXT7 + model: gpt-4o + stream: true + stream_options: + include_usage: true + tool_choice: required + tools: + - function: + description: '' + name: get_weather + parameters: + additionalProperties: false + properties: + city: + type: string + required: + - city + type: object + strict: true + type: function + - function: + description: '' + name: get_country + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: "Convert Celsius to Fahrenheit.\n\n Args:\n celsius: Temperature in Celsius\n\n Returns:\n + \ Temperature in Fahrenheit\n " + name: celsius_to_fahrenheit + parameters: + additionalProperties: false + properties: + celsius: + type: number + required: + - celsius + type: object + strict: true + type: function + - function: + description: "Get the weather forecast for a location.\n\n Args:\n location: The location to get the weather + forecast for.\n\n Returns:\n The weather forecast for the location.\n " + name: get_weather_forecast + parameters: + additionalProperties: false + properties: + location: + type: string + required: + - location + type: object + strict: true + type: function + - function: + description: '' + name: get_image_resource + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: '' + name: get_image_resource_link + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: '' + name: get_audio_resource + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: '' + name: get_audio_resource_link + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: '' + name: get_product_name + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: '' + name: get_product_name_link + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: '' + name: get_image + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: '' + name: get_dict + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: '' + name: get_error + parameters: + additionalProperties: false + properties: + value: + default: false + type: boolean + type: object + type: function + - function: + description: '' + name: get_none + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: '' + name: get_multiple_items + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: "Get the current log level.\n\n Returns:\n The current log level.\n " + name: get_log_level + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: "Echo the run context.\n\n Args:\n ctx: Context object containing request and session information.\n\n + \ Returns:\n Dictionary with an echo message and the deps.\n " + name: echo_deps + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: Use sampling callback. + name: use_sampling + parameters: + additionalProperties: false + properties: + foo: + type: string + required: + - foo + type: object + strict: true + type: function + - function: + description: The final response which ends this conversation + name: final_result + parameters: + $defs: + Answer: + additionalProperties: false + properties: + answer: + type: string + label: + type: string + required: + - label + - answer + type: object + additionalProperties: false + properties: + answers: + items: + $ref: '#/$defs/Answer' + type: array + required: + - answers + type: object + strict: true + type: function + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: |+ + data: {"id":"chatcmpl-C1KMMrEA9QLIX25pFjKjoRdNkO0nN","object":"chat.completion.chunk","created":1754432626,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_4kc6691zCzjPnOuEtbEGUvz2","type":"function","function":{"name":"final_result","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"fJykmX6H"} + + data: {"id":"chatcmpl-C1KMMrEA9QLIX25pFjKjoRdNkO0nN","object":"chat.completion.chunk","created":1754432626,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"MvE"} + + data: {"id":"chatcmpl-C1KMMrEA9QLIX25pFjKjoRdNkO0nN","object":"chat.completion.chunk","created":1754432626,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"answers"}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"dktqfAJehiLPjyt"} + + data: {"id":"chatcmpl-C1KMMrEA9QLIX25pFjKjoRdNkO0nN","object":"chat.completion.chunk","created":1754432626,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":["}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"1O"} + + data: {"id":"chatcmpl-C1KMMrEA9QLIX25pFjKjoRdNkO0nN","object":"chat.completion.chunk","created":1754432626,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"Lvw"} + + data: {"id":"chatcmpl-C1KMMrEA9QLIX25pFjKjoRdNkO0nN","object":"chat.completion.chunk","created":1754432626,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"label"}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"f"} + + data: {"id":"chatcmpl-C1KMMrEA9QLIX25pFjKjoRdNkO0nN","object":"chat.completion.chunk","created":1754432626,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"U"} + + data: {"id":"chatcmpl-C1KMMrEA9QLIX25pFjKjoRdNkO0nN","object":"chat.completion.chunk","created":1754432626,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Capital"}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"dfpOaGxthacFGXR"} + + data: {"id":"chatcmpl-C1KMMrEA9QLIX25pFjKjoRdNkO0nN","object":"chat.completion.chunk","created":1754432626,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" of"}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"ofN"} + + data: {"id":"chatcmpl-C1KMMrEA9QLIX25pFjKjoRdNkO0nN","object":"chat.completion.chunk","created":1754432626,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" the"}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"Lw"} + + data: {"id":"chatcmpl-C1KMMrEA9QLIX25pFjKjoRdNkO0nN","object":"chat.completion.chunk","created":1754432626,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" country"}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"dTcI51M3iiQGmY"} + + data: {"id":"chatcmpl-C1KMMrEA9QLIX25pFjKjoRdNkO0nN","object":"chat.completion.chunk","created":1754432626,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\",\""}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"8"} + + data: {"id":"chatcmpl-C1KMMrEA9QLIX25pFjKjoRdNkO0nN","object":"chat.completion.chunk","created":1754432626,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"answer"}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":""} + + data: {"id":"chatcmpl-C1KMMrEA9QLIX25pFjKjoRdNkO0nN","object":"chat.completion.chunk","created":1754432626,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"5"} + + data: {"id":"chatcmpl-C1KMMrEA9QLIX25pFjKjoRdNkO0nN","object":"chat.completion.chunk","created":1754432626,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Mexico"}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":""} + + data: {"id":"chatcmpl-C1KMMrEA9QLIX25pFjKjoRdNkO0nN","object":"chat.completion.chunk","created":1754432626,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" City"}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"l"} + + data: {"id":"chatcmpl-C1KMMrEA9QLIX25pFjKjoRdNkO0nN","object":"chat.completion.chunk","created":1754432626,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"},{\""}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"joxN8cSzo5Vtw0V"} + + data: {"id":"chatcmpl-C1KMMrEA9QLIX25pFjKjoRdNkO0nN","object":"chat.completion.chunk","created":1754432626,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"label"}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"H"} + + data: {"id":"chatcmpl-C1KMMrEA9QLIX25pFjKjoRdNkO0nN","object":"chat.completion.chunk","created":1754432626,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"x"} + + data: {"id":"chatcmpl-C1KMMrEA9QLIX25pFjKjoRdNkO0nN","object":"chat.completion.chunk","created":1754432626,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Weather"}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"naC5I0l5UxNvni5"} + + data: {"id":"chatcmpl-C1KMMrEA9QLIX25pFjKjoRdNkO0nN","object":"chat.completion.chunk","created":1754432626,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" in"}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"Osv"} + + data: {"id":"chatcmpl-C1KMMrEA9QLIX25pFjKjoRdNkO0nN","object":"chat.completion.chunk","created":1754432626,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" the"}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"Hz"} + + data: {"id":"chatcmpl-C1KMMrEA9QLIX25pFjKjoRdNkO0nN","object":"chat.completion.chunk","created":1754432626,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" capital"}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"XpkABYJn503NY0"} + + data: {"id":"chatcmpl-C1KMMrEA9QLIX25pFjKjoRdNkO0nN","object":"chat.completion.chunk","created":1754432626,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\",\""}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"D"} + + data: {"id":"chatcmpl-C1KMMrEA9QLIX25pFjKjoRdNkO0nN","object":"chat.completion.chunk","created":1754432626,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"answer"}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":""} + + data: {"id":"chatcmpl-C1KMMrEA9QLIX25pFjKjoRdNkO0nN","object":"chat.completion.chunk","created":1754432626,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"u"} + + data: {"id":"chatcmpl-C1KMMrEA9QLIX25pFjKjoRdNkO0nN","object":"chat.completion.chunk","created":1754432626,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Sunny"}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"Y"} + + data: {"id":"chatcmpl-C1KMMrEA9QLIX25pFjKjoRdNkO0nN","object":"chat.completion.chunk","created":1754432626,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"},{\""}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"q6lCElEngpao86s"} + + data: {"id":"chatcmpl-C1KMMrEA9QLIX25pFjKjoRdNkO0nN","object":"chat.completion.chunk","created":1754432626,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"label"}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"l"} + + data: {"id":"chatcmpl-C1KMMrEA9QLIX25pFjKjoRdNkO0nN","object":"chat.completion.chunk","created":1754432626,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"p"} + + data: {"id":"chatcmpl-C1KMMrEA9QLIX25pFjKjoRdNkO0nN","object":"chat.completion.chunk","created":1754432626,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Product"}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"xxsXQcQiDZz87WR"} + + data: {"id":"chatcmpl-C1KMMrEA9QLIX25pFjKjoRdNkO0nN","object":"chat.completion.chunk","created":1754432626,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" Name"}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"z"} + + data: {"id":"chatcmpl-C1KMMrEA9QLIX25pFjKjoRdNkO0nN","object":"chat.completion.chunk","created":1754432626,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\",\""}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"g"} + + data: {"id":"chatcmpl-C1KMMrEA9QLIX25pFjKjoRdNkO0nN","object":"chat.completion.chunk","created":1754432626,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"answer"}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":""} + + data: {"id":"chatcmpl-C1KMMrEA9QLIX25pFjKjoRdNkO0nN","object":"chat.completion.chunk","created":1754432626,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"i"} + + data: {"id":"chatcmpl-C1KMMrEA9QLIX25pFjKjoRdNkO0nN","object":"chat.completion.chunk","created":1754432626,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"P"}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"4sQOS"} + + data: {"id":"chatcmpl-C1KMMrEA9QLIX25pFjKjoRdNkO0nN","object":"chat.completion.chunk","created":1754432626,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"yd"}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"cEjx"} + + data: {"id":"chatcmpl-C1KMMrEA9QLIX25pFjKjoRdNkO0nN","object":"chat.completion.chunk","created":1754432626,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"antic"}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"l"} + + data: {"id":"chatcmpl-C1KMMrEA9QLIX25pFjKjoRdNkO0nN","object":"chat.completion.chunk","created":1754432626,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" AI"}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"1Fy"} + + data: {"id":"chatcmpl-C1KMMrEA9QLIX25pFjKjoRdNkO0nN","object":"chat.completion.chunk","created":1754432626,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"bId"} + + data: {"id":"chatcmpl-C1KMMrEA9QLIX25pFjKjoRdNkO0nN","object":"chat.completion.chunk","created":1754432626,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"]}"}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"YeZF"} + + data: {"id":"chatcmpl-C1KMMrEA9QLIX25pFjKjoRdNkO0nN","object":"chat.completion.chunk","created":1754432626,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}],"usage":null,"obfuscation":"T9AO"} + + data: {"id":"chatcmpl-C1KMMrEA9QLIX25pFjKjoRdNkO0nN","object":"chat.completion.chunk","created":1754432626,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_ff25b2783a","choices":[],"usage":{"prompt_tokens":448,"completion_tokens":49,"total_tokens":497,"prompt_tokens_details":{"cached_tokens":0,"audio_tokens":0},"completion_tokens_details":{"reasoning_tokens":0,"audio_tokens":0,"accepted_prediction_tokens":0,"rejected_prediction_tokens":0}},"obfuscation":"0pSy41lq4PYDj"} + + data: [DONE] + + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-type: + - text/event-stream; charset=utf-8 + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '585' + openai-project: + - proj_dKobscVY9YJxeEaDJen54e3d + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + status: + code: 200 + message: OK +version: 1 +... diff --git a/tests/test_temporal.py b/tests/test_temporal.py new file mode 100644 index 000000000..2c2e16d5f --- /dev/null +++ b/tests/test_temporal.py @@ -0,0 +1,289 @@ +from __future__ import annotations + +from collections.abc import AsyncIterable, AsyncIterator +from dataclasses import dataclass +from datetime import timedelta + +import logfire +from inline_snapshot import snapshot +from typing_extensions import TypedDict + +from pydantic_ai import Agent, RunContext +from pydantic_ai.agent import AbstractAgent +from pydantic_ai.mcp import MCPServerStdio +from pydantic_ai.messages import AgentStreamEvent, HandleResponseEvent +from pydantic_ai.toolsets.function import FunctionToolset + +try: + from temporalio import workflow + from temporalio.client import Client + from temporalio.testing import WorkflowEnvironment + from temporalio.worker import Worker + from temporalio.workflow import ActivityConfig + + from pydantic_ai.ext.temporal import ( + AgentPlugin, + LogfirePlugin, + PydanticAIPlugin, + TemporalAgent, + TemporalRunContextWithDeps, + ) +except ImportError: + import pytest + + pytest.skip('temporal not installed') + +try: + from logfire.testing import CaptureLogfire +except ImportError: + import pytest + + pytest.skip('logfire not installed') + +with workflow.unsafe.imports_passed_through(): + # Workaround for a race condition when running `logfire.info` inside an activity with attributes to serialize and pandas importable: + # AttributeError: partially initialized module 'pandas' has no attribute '_pandas_parser_CAPI' (most likely due to a circular import) + import pandas # pyright: ignore[reportUnusedImport] # noqa: F401 + + # https://github.com/temporalio/sdk-python/blob/3244f8bffebee05e0e7efefb1240a75039903dda/tests/test_client.py#L112C1-L113C1 + import pytest + + # Loads `vcr`, which Temporal doesn't like without passing through the import + from .conftest import IsStr + +pytestmark = [ + pytest.mark.anyio, + pytest.mark.vcr, +] + +TEMPORAL_PORT = 7243 + + +@pytest.fixture +async def env() -> AsyncIterator[WorkflowEnvironment]: + async with await WorkflowEnvironment.start_local(port=TEMPORAL_PORT) as env: # pyright: ignore[reportUnknownMemberType] + yield env + + +@pytest.fixture +async def client(env: WorkflowEnvironment) -> Client: + return await Client.connect( + f'localhost:{TEMPORAL_PORT}', + plugins=[PydanticAIPlugin(), LogfirePlugin()], + ) + + +class Deps(TypedDict): + country: str + + +async def event_stream_handler( + agent: AbstractAgent[Deps, Response], + ctx: RunContext[Deps], + stream: AsyncIterable[AgentStreamEvent | HandleResponseEvent], +): + logfire.info(f'{ctx.run_step=}') + async for event in stream: + logfire.info(f'{event=}') + + +async def get_country(ctx: RunContext[Deps]) -> str: + return ctx.deps['country'] + + +def get_weather(city: str) -> str: + return 'sunny' + + +@dataclass +class Answer: + label: str + answer: str + + +@dataclass +class Response: + answers: list[Answer] + + +agent = Agent( + 'openai:gpt-4o', + deps_type=Deps, + output_type=Response, + toolsets=[ + FunctionToolset[Deps](tools=[get_country], id='country'), + MCPServerStdio('python', ['-m', 'tests.mcp_server'], timeout=20, id='mcp'), + ], + tools=[get_weather], + event_stream_handler=event_stream_handler, +) + +# This needs to be done before the `agent` is bound to the workflow. +temporal_agent = TemporalAgent( + agent, + activity_config=ActivityConfig(start_to_close_timeout=timedelta(seconds=60)), + toolset_activity_config={ + 'country': ActivityConfig(start_to_close_timeout=timedelta(seconds=120)), + }, + tool_activity_config={ + 'country': { + 'get_country': False, + }, + }, + run_context_type=TemporalRunContextWithDeps, +) + + +@workflow.defn +class AgentWorkflow: + @workflow.run + async def run(self, prompt: str, deps: Deps) -> Response: + result = await temporal_agent.run(prompt, deps=deps) + return result.output + + +async def test_temporal(allow_model_requests: None, client: Client, capfire: CaptureLogfire): + task_queue = 'pydantic-ai-agent-task-queue' + + async with Worker( + client, + task_queue=task_queue, + workflows=[AgentWorkflow], + plugins=[AgentPlugin(temporal_agent)], + ): + output = await client.execute_workflow( # pyright: ignore[reportUnknownMemberType] + AgentWorkflow.run, + args=[ + 'Tell me: the capital of the country; the weather there; the product name', + Deps(country='Mexico'), + ], + id='pydantic-ai-agent-workflow', + task_queue=task_queue, + ) + assert output == snapshot( + Response( + answers=[ + Answer(label='Capital of the country', answer='Mexico City'), + Answer(label='Weather in the capital', answer='Sunny'), + Answer(label='Product Name', answer='Pydantic AI'), + ] + ) + ) + exporter = capfire.exporter + + parsed_spans: list[str | AgentStreamEvent | HandleResponseEvent] = [] + for span in exporter.exported_spans_as_dict(): + attributes = span['attributes'] + if event := attributes.get('event'): + parsed_spans.append(event) + else: + parsed_spans.append(attributes['logfire.msg']) + + assert parsed_spans == snapshot( + [ + 'StartWorkflow:AgentWorkflow', + 'RunWorkflow:AgentWorkflow', + 'StartActivity:mcp_server__mcp__get_tools', + 'RunActivity:mcp_server__mcp__get_tools', + 'StartActivity:mcp_server__mcp__get_tools', + 'RunActivity:mcp_server__mcp__get_tools', + 'StartActivity:model__openai_gpt-4o__request_stream', + 'ctx.run_step=1', + '{"index":0,"part":{"tool_name":"get_country","args":"","tool_call_id":"call_3rqTYrA6H21AYUaRGP4F66oq","part_kind":"tool-call"},"event_kind":"part_start"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":"{}","tool_call_id":"call_3rqTYrA6H21AYUaRGP4F66oq","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":1,"part":{"tool_name":"get_product_name","args":"","tool_call_id":"call_Xw9XMKBJU48kAAd78WgIswDx","part_kind":"tool-call"},"event_kind":"part_start"}', + '{"index":1,"delta":{"tool_name_delta":null,"args_delta":"{}","tool_call_id":"call_Xw9XMKBJU48kAAd78WgIswDx","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + 'RunActivity:model__openai_gpt-4o__request_stream', + 'ctx.run_step=1', + 'chat gpt-4o', + 'ctx.run_step=1', + '{"part":{"tool_name":"get_country","args":"{}","tool_call_id":"call_3rqTYrA6H21AYUaRGP4F66oq","part_kind":"tool-call"},"event_kind":"function_tool_call"}', + '{"part":{"tool_name":"get_product_name","args":"{}","tool_call_id":"call_Xw9XMKBJU48kAAd78WgIswDx","part_kind":"tool-call"},"event_kind":"function_tool_call"}', + 'running tool: get_country', + 'StartActivity:mcp_server__mcp__call_tool', + IsStr( + regex=r'{"result":{"tool_name":"get_country","content":"Mexico","tool_call_id":"call_3rqTYrA6H21AYUaRGP4F66oq","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"event_kind":"function_tool_result"}' + ), + 'RunActivity:mcp_server__mcp__call_tool', + 'running tool: get_product_name', + IsStr( + regex=r'{"result":{"tool_name":"get_product_name","content":"Pydantic AI","tool_call_id":"call_Xw9XMKBJU48kAAd78WgIswDx","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"event_kind":"function_tool_result"}' + ), + 'running 2 tools', + 'StartActivity:mcp_server__mcp__get_tools', + 'RunActivity:mcp_server__mcp__get_tools', + 'StartActivity:model__openai_gpt-4o__request_stream', + 'ctx.run_step=2', + '{"index":0,"part":{"tool_name":"get_weather","args":"","tool_call_id":"call_Vz0Sie91Ap56nH0ThKGrZXT7","part_kind":"tool-call"},"event_kind":"part_start"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":"{\\"","tool_call_id":"call_Vz0Sie91Ap56nH0ThKGrZXT7","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":"city","tool_call_id":"call_Vz0Sie91Ap56nH0ThKGrZXT7","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":"\\":\\"","tool_call_id":"call_Vz0Sie91Ap56nH0ThKGrZXT7","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":"Mexico","tool_call_id":"call_Vz0Sie91Ap56nH0ThKGrZXT7","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":" City","tool_call_id":"call_Vz0Sie91Ap56nH0ThKGrZXT7","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":"\\"}","tool_call_id":"call_Vz0Sie91Ap56nH0ThKGrZXT7","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + 'RunActivity:model__openai_gpt-4o__request_stream', + 'ctx.run_step=2', + 'chat gpt-4o', + 'ctx.run_step=2', + '{"part":{"tool_name":"get_weather","args":"{\\"city\\":\\"Mexico City\\"}","tool_call_id":"call_Vz0Sie91Ap56nH0ThKGrZXT7","part_kind":"tool-call"},"event_kind":"function_tool_call"}', + 'StartActivity:function_toolset____call_tool', + 'RunActivity:function_toolset____call_tool', + 'running tool: get_weather', + IsStr( + regex=r'{"result":{"tool_name":"get_weather","content":"sunny","tool_call_id":"call_Vz0Sie91Ap56nH0ThKGrZXT7","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"event_kind":"function_tool_result"}' + ), + 'running 1 tool', + 'StartActivity:mcp_server__mcp__get_tools', + 'RunActivity:mcp_server__mcp__get_tools', + 'StartActivity:model__openai_gpt-4o__request_stream', + 'ctx.run_step=3', + '{"index":0,"part":{"tool_name":"final_result","args":"","tool_call_id":"call_4kc6691zCzjPnOuEtbEGUvz2","part_kind":"tool-call"},"event_kind":"part_start"}', + '{"tool_name":"final_result","tool_call_id":"call_4kc6691zCzjPnOuEtbEGUvz2","event_kind":"final_result"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":"{\\"","tool_call_id":"call_4kc6691zCzjPnOuEtbEGUvz2","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":"answers","tool_call_id":"call_4kc6691zCzjPnOuEtbEGUvz2","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":"\\":[","tool_call_id":"call_4kc6691zCzjPnOuEtbEGUvz2","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":"{\\"","tool_call_id":"call_4kc6691zCzjPnOuEtbEGUvz2","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":"label","tool_call_id":"call_4kc6691zCzjPnOuEtbEGUvz2","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":"\\":\\"","tool_call_id":"call_4kc6691zCzjPnOuEtbEGUvz2","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":"Capital","tool_call_id":"call_4kc6691zCzjPnOuEtbEGUvz2","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":" of","tool_call_id":"call_4kc6691zCzjPnOuEtbEGUvz2","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":" the","tool_call_id":"call_4kc6691zCzjPnOuEtbEGUvz2","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":" country","tool_call_id":"call_4kc6691zCzjPnOuEtbEGUvz2","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":"\\",\\"","tool_call_id":"call_4kc6691zCzjPnOuEtbEGUvz2","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":"answer","tool_call_id":"call_4kc6691zCzjPnOuEtbEGUvz2","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":"\\":\\"","tool_call_id":"call_4kc6691zCzjPnOuEtbEGUvz2","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":"Mexico","tool_call_id":"call_4kc6691zCzjPnOuEtbEGUvz2","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":" City","tool_call_id":"call_4kc6691zCzjPnOuEtbEGUvz2","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":"\\"},{\\"","tool_call_id":"call_4kc6691zCzjPnOuEtbEGUvz2","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":"label","tool_call_id":"call_4kc6691zCzjPnOuEtbEGUvz2","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":"\\":\\"","tool_call_id":"call_4kc6691zCzjPnOuEtbEGUvz2","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":"Weather","tool_call_id":"call_4kc6691zCzjPnOuEtbEGUvz2","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":" in","tool_call_id":"call_4kc6691zCzjPnOuEtbEGUvz2","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":" the","tool_call_id":"call_4kc6691zCzjPnOuEtbEGUvz2","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":" capital","tool_call_id":"call_4kc6691zCzjPnOuEtbEGUvz2","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":"\\",\\"","tool_call_id":"call_4kc6691zCzjPnOuEtbEGUvz2","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":"answer","tool_call_id":"call_4kc6691zCzjPnOuEtbEGUvz2","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":"\\":\\"","tool_call_id":"call_4kc6691zCzjPnOuEtbEGUvz2","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":"Sunny","tool_call_id":"call_4kc6691zCzjPnOuEtbEGUvz2","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":"\\"},{\\"","tool_call_id":"call_4kc6691zCzjPnOuEtbEGUvz2","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":"label","tool_call_id":"call_4kc6691zCzjPnOuEtbEGUvz2","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":"\\":\\"","tool_call_id":"call_4kc6691zCzjPnOuEtbEGUvz2","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":"Product","tool_call_id":"call_4kc6691zCzjPnOuEtbEGUvz2","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":" Name","tool_call_id":"call_4kc6691zCzjPnOuEtbEGUvz2","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":"\\",\\"","tool_call_id":"call_4kc6691zCzjPnOuEtbEGUvz2","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":"answer","tool_call_id":"call_4kc6691zCzjPnOuEtbEGUvz2","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":"\\":\\"","tool_call_id":"call_4kc6691zCzjPnOuEtbEGUvz2","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":"P","tool_call_id":"call_4kc6691zCzjPnOuEtbEGUvz2","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":"yd","tool_call_id":"call_4kc6691zCzjPnOuEtbEGUvz2","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":"antic","tool_call_id":"call_4kc6691zCzjPnOuEtbEGUvz2","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":" AI","tool_call_id":"call_4kc6691zCzjPnOuEtbEGUvz2","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":"\\"}","tool_call_id":"call_4kc6691zCzjPnOuEtbEGUvz2","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + '{"index":0,"delta":{"tool_name_delta":null,"args_delta":"]}","tool_call_id":"call_4kc6691zCzjPnOuEtbEGUvz2","part_delta_kind":"tool_call"},"event_kind":"part_delta"}', + 'RunActivity:model__openai_gpt-4o__request_stream', + 'ctx.run_step=3', + 'chat gpt-4o', + 'ctx.run_step=3', + 'self run', + 'CompleteWorkflow:AgentWorkflow', + ] + ) From 5c06aaa1fe25c28b47e5260e173f6e3390b69de3 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 5 Aug 2025 22:43:10 +0000 Subject: [PATCH 34/41] Fix API docs crosslinks --- docs/ag-ui.md | 6 +++--- docs/agents.md | 8 ++++---- docs/api/agent.md | 2 ++ docs/message-history.md | 8 ++++---- docs/multi-agent-applications.md | 2 +- docs/output.md | 2 +- docs/toolsets.md | 4 ++-- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 2 +- pydantic_ai_slim/pydantic_ai/agent.py | 2 +- .../pydantic_ai/ext/temporal/_agent.py | 2 +- pydantic_ai_slim/pydantic_ai/messages.py | 4 ++-- tests/evals/test_llm_as_a_judge.py | 20 +++++++++---------- 12 files changed, 32 insertions(+), 30 deletions(-) diff --git a/docs/ag-ui.md b/docs/ag-ui.md index b19f0baf7..e19bb0fa3 100644 --- a/docs/ag-ui.md +++ b/docs/ag-ui.md @@ -35,7 +35,7 @@ There are three ways to run a Pydantic AI agent based on AG-UI run input with st 1. [`run_ag_ui()`][pydantic_ai.ag_ui.run_ag_ui] takes an agent and an AG-UI [`RunAgentInput`](https://docs.ag-ui.com/sdk/python/core/types#runagentinput) object, and returns a stream of AG-UI events encoded as strings. It also takes optional [`Agent.iter()`][pydantic_ai.Agent.iter] arguments including `deps`. Use this if you're using a web framework not based on Starlette (e.g. Django or Flask) or want to modify the input or output some way. 2. [`handle_ag_ui_request()`][pydantic_ai.ag_ui.handle_ag_ui_request] takes an agent and a Starlette request (e.g. from FastAPI) coming from an AG-UI frontend, and returns a streaming Starlette response of AG-UI events that you can return directly from your endpoint. It also takes optional [`Agent.iter()`][pydantic_ai.Agent.iter] arguments including `deps`, that you can vary for each request (e.g. based on the authenticated user). -3. [`Agent.to_ag_ui()`][pydantic_ai.Agent.to_ag_ui] returns an ASGI application that handles every AG-UI request by running the agent. It also takes optional [`Agent.iter()`][pydantic_ai.Agent.iter] arguments including `deps`, but these will be the same for each request, with the exception of the AG-UI state that's injected as described under [state management](#state-management). This ASGI app can be [mounted](https://fastapi.tiangolo.com/advanced/sub-applications/) at a given path in an existing FastAPI app. +3. [`Agent.to_ag_ui()`][pydantic_ai.AbstractAgent.to_ag_ui] returns an ASGI application that handles every AG-UI request by running the agent. It also takes optional [`Agent.iter()`][pydantic_ai.Agent.iter] arguments including `deps`, but these will be the same for each request, with the exception of the AG-UI state that's injected as described under [state management](#state-management). This ASGI app can be [mounted](https://fastapi.tiangolo.com/advanced/sub-applications/) at a given path in an existing FastAPI app. ### Handle run input and output directly @@ -117,7 +117,7 @@ This will expose the agent as an AG-UI server, and your frontend can start sendi ### Stand-alone ASGI app -This example uses [`Agent.to_ag_ui()`][pydantic_ai.Agent.to_ag_ui] to turn the agent into a stand-alone ASGI application: +This example uses [`Agent.to_ag_ui()`][pydantic_ai.AbstractAgent.to_ag_ui] to turn the agent into a stand-alone ASGI application: ```py {title="agent_to_ag_ui.py" py="3.10" hl_lines="4"} from pydantic_ai import Agent @@ -265,7 +265,7 @@ uvicorn ag_ui_tool_events:app --host 0.0.0.0 --port 9000 ## Examples -For more examples of how to use [`to_ag_ui()`][pydantic_ai.Agent.to_ag_ui] see +For more examples of how to use [`to_ag_ui()`][pydantic_ai.AbstractAgent.to_ag_ui] see [`pydantic_ai_examples.ag_ui`](https://github.com/pydantic/pydantic-ai/tree/main/examples/pydantic_ai_examples/ag_ui), which includes a server for use with the [AG-UI Dojo](https://docs.ag-ui.com/tutorials/debugging#the-ag-ui-dojo). diff --git a/docs/agents.md b/docs/agents.md index 91e2602e3..3612d0269 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -63,9 +63,9 @@ print(result.output) There are four ways to run an agent: -1. [`agent.run()`][pydantic_ai.Agent.run] — a coroutine which returns a [`RunResult`][pydantic_ai.agent.AgentRunResult] containing a completed response. -2. [`agent.run_sync()`][pydantic_ai.Agent.run_sync] — a plain, synchronous function which returns a [`RunResult`][pydantic_ai.agent.AgentRunResult] containing a completed response (internally, this just calls `loop.run_until_complete(self.run())`). -3. [`agent.run_stream()`][pydantic_ai.Agent.run_stream] — a coroutine which returns a [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult], which contains methods to stream a response as an async iterable. +1. [`agent.run()`][pydantic_ai.AbstractAgent.run] — a coroutine which returns a [`RunResult`][pydantic_ai.agent.AgentRunResult] containing a completed response. +2. [`agent.run_sync()`][pydantic_ai.AbstractAgent.run_sync] — a plain, synchronous function which returns a [`RunResult`][pydantic_ai.agent.AgentRunResult] containing a completed response (internally, this just calls `loop.run_until_complete(self.run())`). +3. [`agent.run_stream()`][pydantic_ai.AbstractAgent.run_stream] — a coroutine which returns a [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult], which contains methods to stream a response as an async iterable. 4. [`agent.iter()`][pydantic_ai.Agent.iter] — a context manager which returns an [`AgentRun`][pydantic_ai.agent.AgentRun], an async-iterable over the nodes of the agent's underlying [`Graph`][pydantic_graph.graph.Graph]. Here's a simple example demonstrating the first three: @@ -890,4 +890,4 @@ with capture_run_messages() as messages: # (2)! _(This example is complete, it can be run "as is")_ !!! note - If you call [`run`][pydantic_ai.Agent.run], [`run_sync`][pydantic_ai.Agent.run_sync], or [`run_stream`][pydantic_ai.Agent.run_stream] more than once within a single `capture_run_messages` context, `messages` will represent the messages exchanged during the first call only. + If you call [`run`][pydantic_ai.AbstractAgent.run], [`run_sync`][pydantic_ai.AbstractAgent.run_sync], or [`run_stream`][pydantic_ai.AbstractAgent.run_stream] more than once within a single `capture_run_messages` context, `messages` will represent the messages exchanged during the first call only. diff --git a/docs/api/agent.md b/docs/api/agent.md index 2f10bb7a8..bfbfe2a3a 100644 --- a/docs/api/agent.md +++ b/docs/api/agent.md @@ -4,6 +4,8 @@ options: members: - Agent + - AbstractAgent + - WrapperAgent - AgentRun - AgentRunResult - EndStrategy diff --git a/docs/message-history.md b/docs/message-history.md index 3e8cedbd0..3059b9922 100644 --- a/docs/message-history.md +++ b/docs/message-history.md @@ -7,8 +7,8 @@ Pydantic AI provides access to messages exchanged during an agent run. These mes After running an agent, you can access the messages exchanged during that run from the `result` object. Both [`RunResult`][pydantic_ai.agent.AgentRunResult] -(returned by [`Agent.run`][pydantic_ai.Agent.run], [`Agent.run_sync`][pydantic_ai.Agent.run_sync]) -and [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult] (returned by [`Agent.run_stream`][pydantic_ai.Agent.run_stream]) have the following methods: +(returned by [`Agent.run`][pydantic_ai.AbstractAgent.run], [`Agent.run_sync`][pydantic_ai.AbstractAgent.run_sync]) +and [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult] (returned by [`Agent.run_stream`][pydantic_ai.AbstractAgent.run_stream]) have the following methods: - [`all_messages()`][pydantic_ai.agent.AgentRunResult.all_messages]: returns all messages, including messages from prior runs. There's also a variant that returns JSON bytes, [`all_messages_json()`][pydantic_ai.agent.AgentRunResult.all_messages_json]. - [`new_messages()`][pydantic_ai.agent.AgentRunResult.new_messages]: returns only the messages from the current run. There's also a variant that returns JSON bytes, [`new_messages_json()`][pydantic_ai.agent.AgentRunResult.new_messages_json]. @@ -141,8 +141,8 @@ _(This example is complete, it can be run "as is" — you'll need to add `asynci The primary use of message histories in Pydantic AI is to maintain context across multiple agent runs. To use existing messages in a run, pass them to the `message_history` parameter of -[`Agent.run`][pydantic_ai.Agent.run], [`Agent.run_sync`][pydantic_ai.Agent.run_sync] or -[`Agent.run_stream`][pydantic_ai.Agent.run_stream]. +[`Agent.run`][pydantic_ai.AbstractAgent.run], [`Agent.run_sync`][pydantic_ai.AbstractAgent.run_sync] or +[`Agent.run_stream`][pydantic_ai.AbstractAgent.run_stream]. If `message_history` is set and not empty, a new system prompt is not generated — we assume the existing message history includes a system prompt. diff --git a/docs/multi-agent-applications.md b/docs/multi-agent-applications.md index a97cc9813..cbeeb2e80 100644 --- a/docs/multi-agent-applications.md +++ b/docs/multi-agent-applications.md @@ -16,7 +16,7 @@ If you want to hand off control to another agent completely, without coming back Since agents are stateless and designed to be global, you do not need to include the agent itself in agent [dependencies](dependencies.md). -You'll generally want to pass [`ctx.usage`][pydantic_ai.RunContext.usage] to the [`usage`][pydantic_ai.Agent.run] keyword argument of the delegate agent run so usage within that run counts towards the total usage of the parent agent run. +You'll generally want to pass [`ctx.usage`][pydantic_ai.RunContext.usage] to the [`usage`][pydantic_ai.AbstractAgent.run] keyword argument of the delegate agent run so usage within that run counts towards the total usage of the parent agent run. !!! note "Multiple models" Agent delegation doesn't need to use the same model for each agent. If you choose to use different models within a run, calculating the monetary cost from the final [`result.usage()`][pydantic_ai.agent.AgentRunResult.usage] of the run will not be possible, but you can still use [`UsageLimits`][pydantic_ai.usage.UsageLimits] to avoid unexpected costs. diff --git a/docs/output.md b/docs/output.md index 7e1eb16ed..96dae51ce 100644 --- a/docs/output.md +++ b/docs/output.md @@ -505,7 +505,7 @@ async def main(): ``` 1. Streaming works with the standard [`Agent`][pydantic_ai.Agent] class, and doesn't require any special setup, just a model that supports streaming (currently all models support streaming). -2. The [`Agent.run_stream()`][pydantic_ai.Agent.run_stream] method is used to start a streamed run, this method returns a context manager so the connection can be closed when the stream completes. +2. The [`Agent.run_stream()`][pydantic_ai.AbstractAgent.run_stream] method is used to start a streamed run, this method returns a context manager so the connection can be closed when the stream completes. 3. Each item yield by [`StreamedRunResult.stream_text()`][pydantic_ai.result.StreamedRunResult.stream_text] is the complete text response, extended as new data is received. _(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main())` to run `main`)_ diff --git a/docs/toolsets.md b/docs/toolsets.md index 56e4a843e..ba64eb1e2 100644 --- a/docs/toolsets.md +++ b/docs/toolsets.md @@ -8,7 +8,7 @@ Toolsets are used (among many other things) to define [MCP servers](mcp/client.m The toolsets that will be available during an agent run can be specified in three different ways: * at agent construction time, via the [`toolsets`][pydantic_ai.Agent.__init__] keyword argument to `Agent` -* at agent run time, via the `toolsets` keyword argument to [`agent.run()`][pydantic_ai.Agent.run], [`agent.run_sync()`][pydantic_ai.Agent.run_sync], [`agent.run_stream()`][pydantic_ai.Agent.run_stream], or [`agent.iter()`][pydantic_ai.Agent.iter]. These toolsets will be additional to those provided to the `Agent` constructor +* at agent run time, via the `toolsets` keyword argument to [`agent.run()`][pydantic_ai.AbstractAgent.run], [`agent.run_sync()`][pydantic_ai.AbstractAgent.run_sync], [`agent.run_stream()`][pydantic_ai.AbstractAgent.run_stream], or [`agent.iter()`][pydantic_ai.Agent.iter]. These toolsets will be additional to those provided to the `Agent` constructor * as a contextual override, via the `toolsets` keyword argument to the [`agent.override()`][pydantic_ai.Agent.iter] context manager. These toolsets will replace those provided at agent construction or run time during the life of the context manager ```python {title="toolsets.py"} @@ -457,7 +457,7 @@ When the model calls a deferred tool, the agent run ends with a [`DeferredToolCa To enable an agent to call deferred tools, you create a [`DeferredToolset`][pydantic_ai.toolsets.DeferredToolset], pass it a list of [`ToolDefinition`s][pydantic_ai.tools.ToolDefinition], and provide it to the agent using one of the methods described above. Additionally, you need to add `DeferredToolCalls` to the `Agent`'s [`output_type`](output.md#structured-output) so that the possible types of the agent run output are correctly inferred. Finally, you should handle the possible `DeferredToolCalls` output by passing it to the service that will produce the results. -If your agent can also be used in a context where no deferred tools are available, you will not want to include `DeferredToolCalls` in the `output_type` passed to the `Agent` constructor as you'd have to deal with that type everywhere you use the agent. Instead, you can pass the `toolsets` and `output_type` keyword arguments when you run the agent using [`agent.run()`][pydantic_ai.Agent.run], [`agent.run_sync()`][pydantic_ai.Agent.run_sync], [`agent.run_stream()`][pydantic_ai.Agent.run_stream], or [`agent.iter()`][pydantic_ai.Agent.iter]. Note that while `toolsets` provided at this stage are additional to the toolsets provided to the constructor, the `output_type` overrides the one specified at construction time (for type inference reasons), so you'll need to include the original output types explicitly. +If your agent can also be used in a context where no deferred tools are available, you will not want to include `DeferredToolCalls` in the `output_type` passed to the `Agent` constructor as you'd have to deal with that type everywhere you use the agent. Instead, you can pass the `toolsets` and `output_type` keyword arguments when you run the agent using [`agent.run()`][pydantic_ai.AbstractAgent.run], [`agent.run_sync()`][pydantic_ai.AbstractAgent.run_sync], [`agent.run_stream()`][pydantic_ai.AbstractAgent.run_stream], or [`agent.iter()`][pydantic_ai.Agent.iter]. Note that while `toolsets` provided at this stage are additional to the toolsets provided to the constructor, the `output_type` overrides the one specified at construction time (for type inference reasons), so you'll need to include the original output types explicitly. To demonstrate, let us first define a simple agent _without_ deferred tools: diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index b8184381c..a46114b31 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -786,7 +786,7 @@ class _RunMessages: @contextmanager def capture_run_messages() -> Iterator[list[_messages.ModelMessage]]: - """Context manager to access the messages used in a [`run`][pydantic_ai.Agent.run], [`run_sync`][pydantic_ai.Agent.run_sync], or [`run_stream`][pydantic_ai.Agent.run_stream] call. + """Context manager to access the messages used in a [`run`][pydantic_ai.AbstractAgent.run], [`run_sync`][pydantic_ai.AbstractAgent.run_sync], or [`run_stream`][pydantic_ai.AbstractAgent.run_stream] call. Useful when a run may raise an exception, see [model errors](../agents.md#model-errors) for more information. diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index a31cf4dab..e0d8da9f5 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -309,7 +309,7 @@ def run_sync( ) -> AgentRunResult[Any]: """Synchronously run the agent with a user prompt. - This is a convenience method that wraps [`self.run`][pydantic_ai.Agent.run] with `loop.run_until_complete(...)`. + This is a convenience method that wraps [`self.run`][pydantic_ai.AbstractAgent.run] with `loop.run_until_complete(...)`. You therefore can't use this method inside async code or if there's an active event loop. Example: diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py index 2ac043310..d7910a1a4 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py @@ -270,7 +270,7 @@ def run_sync( ) -> AgentRunResult[Any]: """Synchronously run the agent with a user prompt. - This is a convenience method that wraps [`self.run`][pydantic_ai.Agent.run] with `loop.run_until_complete(...)`. + This is a convenience method that wraps [`self.run`][pydantic_ai.AbstractAgent.run] with `loop.run_until_complete(...)`. You therefore can't use this method inside async code or if there's an active event loop. Example: diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index 195c4b791..765d073f9 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -490,8 +490,8 @@ class ToolReturn: class UserPromptPart: """A user prompt, generally written by the end user. - Content comes from the `user_prompt` parameter of [`Agent.run`][pydantic_ai.Agent.run], - [`Agent.run_sync`][pydantic_ai.Agent.run_sync], and [`Agent.run_stream`][pydantic_ai.Agent.run_stream]. + Content comes from the `user_prompt` parameter of [`Agent.run`][pydantic_ai.AbstractAgent.run], + [`Agent.run_sync`][pydantic_ai.AbstractAgent.run_sync], and [`Agent.run_stream`][pydantic_ai.AbstractAgent.run_stream]. """ content: str | Sequence[UserContent] diff --git a/tests/evals/test_llm_as_a_judge.py b/tests/evals/test_llm_as_a_judge.py index 404c1f81a..8fa270f39 100644 --- a/tests/evals/test_llm_as_a_judge.py +++ b/tests/evals/test_llm_as_a_judge.py @@ -75,7 +75,7 @@ async def test_judge_output_mock(mocker: MockerFixture): # Mock the agent run method mock_result = mocker.MagicMock() mock_result.output = GradingOutput(reason='Test passed', pass_=True, score=1.0) - mock_run = mocker.patch('pydantic_ai.Agent.run', return_value=mock_result) + mock_run = mocker.patch('pydantic_ai.AbstractAgent.run', return_value=mock_result) # Test with string output grading_output = await judge_output('Hello world', 'Content contains a greeting') @@ -96,7 +96,7 @@ async def test_judge_output_with_model_settings_mock(mocker: MockerFixture): """Test judge_output function with model_settings and mocked agent.""" mock_result = mocker.MagicMock() mock_result.output = GradingOutput(reason='Test passed with settings', pass_=True, score=1.0) - mock_run = mocker.patch('pydantic_ai.Agent.run', return_value=mock_result) + mock_run = mocker.patch('pydantic_ai.AbstractAgent.run', return_value=mock_result) test_model_settings = ModelSettings(temperature=1) @@ -125,7 +125,7 @@ async def test_judge_input_output_mock(mocker: MockerFixture): # Mock the agent run method mock_result = mocker.MagicMock() mock_result.output = GradingOutput(reason='Test passed', pass_=True, score=1.0) - mock_run = mocker.patch('pydantic_ai.Agent.run', return_value=mock_result) + mock_run = mocker.patch('pydantic_ai.AbstractAgent.run', return_value=mock_result) # Test with string input and output result = await judge_input_output('Hello', 'Hello world', 'Output contains input') @@ -147,7 +147,7 @@ async def test_judge_input_output_binary_content_list_mock(mocker: MockerFixture # Mock the agent run method mock_result = mocker.MagicMock() mock_result.output = GradingOutput(reason='Test passed', pass_=True, score=1.0) - mock_run = mocker.patch('pydantic_ai.Agent.run', return_value=mock_result) + mock_run = mocker.patch('pydantic_ai.AbstractAgent.run', return_value=mock_result) result = await judge_input_output([image_content, image_content], 'Hello world', 'Output contains input') assert isinstance(result, GradingOutput) @@ -171,7 +171,7 @@ async def test_judge_input_output_binary_content_mock(mocker: MockerFixture, ima # Mock the agent run method mock_result = mocker.MagicMock() mock_result.output = GradingOutput(reason='Test passed', pass_=True, score=1.0) - mock_run = mocker.patch('pydantic_ai.Agent.run', return_value=mock_result) + mock_run = mocker.patch('pydantic_ai.AbstractAgent.run', return_value=mock_result) result = await judge_input_output(image_content, 'Hello world', 'Output contains input') assert isinstance(result, GradingOutput) @@ -195,7 +195,7 @@ async def test_judge_input_output_with_model_settings_mock(mocker: MockerFixture """Test judge_input_output function with model_settings and mocked agent.""" mock_result = mocker.MagicMock() mock_result.output = GradingOutput(reason='Test passed with settings', pass_=True, score=1.0) - mock_run = mocker.patch('pydantic_ai.Agent.run', return_value=mock_result) + mock_run = mocker.patch('pydantic_ai.AbstractAgent.run', return_value=mock_result) test_model_settings = ModelSettings(temperature=1) @@ -226,7 +226,7 @@ async def test_judge_input_output_expected_mock(mocker: MockerFixture, image_con # Mock the agent run method mock_result = mocker.MagicMock() mock_result.output = GradingOutput(reason='Test passed', pass_=True, score=1.0) - mock_run = mocker.patch('pydantic_ai.Agent.run', return_value=mock_result) + mock_run = mocker.patch('pydantic_ai.AbstractAgent.run', return_value=mock_result) # Test with string input and output result = await judge_input_output_expected('Hello', 'Hello world', 'Hello', 'Output contains input') @@ -262,7 +262,7 @@ async def test_judge_input_output_expected_with_model_settings_mock( """Test judge_input_output_expected function with model_settings and mocked agent.""" mock_result = mocker.MagicMock() mock_result.output = GradingOutput(reason='Test passed with settings', pass_=True, score=1.0) - mock_run = mocker.patch('pydantic_ai.Agent.run', return_value=mock_result) + mock_run = mocker.patch('pydantic_ai.AbstractAgent.run', return_value=mock_result) test_model_settings = ModelSettings(temperature=1) @@ -396,7 +396,7 @@ async def test_judge_output_expected_mock(mocker: MockerFixture): # Mock the agent run method mock_result = mocker.MagicMock() mock_result.output = GradingOutput(reason='Test passed', pass_=True, score=1.0) - mock_run = mocker.patch('pydantic_ai.Agent.run', return_value=mock_result) + mock_run = mocker.patch('pydantic_ai.AbstractAgent.run', return_value=mock_result) # Test with string output and expected output result = await judge_output_expected('Hello world', 'Hello', 'Output contains input') @@ -418,7 +418,7 @@ async def test_judge_output_expected_with_model_settings_mock(mocker: MockerFixt """Test judge_output_expected function with model_settings and mocked agent.""" mock_result = mocker.MagicMock() mock_result.output = GradingOutput(reason='Test passed with settings', pass_=True, score=1.0) - mock_run = mocker.patch('pydantic_ai.Agent.run', return_value=mock_result) + mock_run = mocker.patch('pydantic_ai.AbstractAgent.run', return_value=mock_result) test_model_settings = ModelSettings(temperature=1) From 38baf2206badaf3e3f86c57f49389031c80425ab Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 5 Aug 2025 22:46:32 +0000 Subject: [PATCH 35/41] Don't import logfire unless available --- tests/test_temporal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_temporal.py b/tests/test_temporal.py index 2c2e16d5f..9b18f79d1 100644 --- a/tests/test_temporal.py +++ b/tests/test_temporal.py @@ -4,7 +4,6 @@ from dataclasses import dataclass from datetime import timedelta -import logfire from inline_snapshot import snapshot from typing_extensions import TypedDict @@ -34,6 +33,7 @@ pytest.skip('temporal not installed') try: + import logfire from logfire.testing import CaptureLogfire except ImportError: import pytest From f16fbbea2eca8b2cf2b99e4d9a21ad51dcca6c7c Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 5 Aug 2025 23:00:04 +0000 Subject: [PATCH 36/41] Make FunctionToolsetTool public so the Temporal integration doesn't use anything private anymore --- .../pydantic_ai/ext/temporal/_function_toolset.py | 6 +++--- pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py | 2 +- pydantic_ai_slim/pydantic_ai/toolsets/function.py | 6 +++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_function_toolset.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_function_toolset.py index e46477541..d0965ab7d 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_function_toolset.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_function_toolset.py @@ -10,7 +10,7 @@ from pydantic_ai._run_context import RunContext from pydantic_ai.exceptions import UserError from pydantic_ai.toolsets import FunctionToolset, ToolsetTool -from pydantic_ai.toolsets.function import _FunctionToolsetTool # pyright: ignore[reportPrivateUsage] +from pydantic_ai.toolsets.function import FunctionToolsetTool from ._run_context import TemporalRunContext from ._toolset import TemporalWrapperToolset @@ -71,10 +71,10 @@ async def call_tool(self, name: str, tool_args: dict[str, Any], ctx: RunContext, tool_activity_config = self.tool_activity_config.get(name, {}) if tool_activity_config is False: - assert isinstance(tool, _FunctionToolsetTool) + assert isinstance(tool, FunctionToolsetTool) if not tool.is_async: raise UserError( - 'Disabling running a non-async tool in a Temporal activity is not possible. Make the tool function async instead.' + f'Temporal activity config for non-async tool {name!r} is `False` (activity disabled), but only async tools can be run outside of an activity. Make the tool function async instead.' ) return await super().call_tool(name, tool_args, ctx, tool) diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py index 2cc88e377..1d0c156ee 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_mcp_server.py @@ -113,7 +113,7 @@ async def call_tool( tool_activity_config = self.tool_activity_config.get(name, {}) if tool_activity_config is False: raise UserError( - f'Temporal activity config for tool {name!r} is `False`, but MCP tools cannot be run out side of an activity.' + f'Temporal activity config for MCP tool {name!r} is `False` (activity disabled), but MCP tools cannot be run outside of an activity.' ) tool_activity_config = self.activity_config | tool_activity_config diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/function.py b/pydantic_ai_slim/pydantic_ai/toolsets/function.py index 1b94c57eb..f756ab276 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/function.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/function.py @@ -20,7 +20,7 @@ @dataclass -class _FunctionToolsetTool(ToolsetTool[AgentDepsT]): +class FunctionToolsetTool(ToolsetTool[AgentDepsT]): """A tool definition for a function toolset tool that keeps track of the function to call.""" call_func: Callable[[dict[str, Any], RunContext[AgentDepsT]], Awaitable[Any]] @@ -236,7 +236,7 @@ async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[ else: raise UserError(f'Tool name conflicts with previously renamed tool: {new_name!r}.') - tools[new_name] = _FunctionToolsetTool( + tools[new_name] = FunctionToolsetTool( toolset=self, tool_def=tool_def, max_retries=tool.max_retries if tool.max_retries is not None else self.max_retries, @@ -249,5 +249,5 @@ async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[ async def call_tool( self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT] ) -> Any: - assert isinstance(tool, _FunctionToolsetTool) + assert isinstance(tool, FunctionToolsetTool) return await tool.call_func(tool_args, ctx) From e69229bdf0eadb18476b1a9bc81b038e95221951 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 5 Aug 2025 23:03:12 +0000 Subject: [PATCH 37/41] Guard against mcp not being installed --- tests/test_temporal.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/test_temporal.py b/tests/test_temporal.py index 9b18f79d1..f6d5ed9ee 100644 --- a/tests/test_temporal.py +++ b/tests/test_temporal.py @@ -9,9 +9,8 @@ from pydantic_ai import Agent, RunContext from pydantic_ai.agent import AbstractAgent -from pydantic_ai.mcp import MCPServerStdio from pydantic_ai.messages import AgentStreamEvent, HandleResponseEvent -from pydantic_ai.toolsets.function import FunctionToolset +from pydantic_ai.toolsets import FunctionToolset try: from temporalio import workflow @@ -40,6 +39,13 @@ pytest.skip('logfire not installed') +try: + from pydantic_ai.mcp import MCPServerStdio +except ImportError: + import pytest + + pytest.skip('mcp not installed') + with workflow.unsafe.imports_passed_through(): # Workaround for a race condition when running `logfire.info` inside an activity with attributes to serialize and pandas importable: # AttributeError: partially initialized module 'pandas' has no attribute '_pandas_parser_CAPI' (most likely due to a circular import) From c7f4b00d921fb0acc8ec244c4bd0c5ca683e9d82 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 5 Aug 2025 23:05:44 +0000 Subject: [PATCH 38/41] Allow model level pytest skip --- tests/test_temporal.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_temporal.py b/tests/test_temporal.py index f6d5ed9ee..88e0b73e8 100644 --- a/tests/test_temporal.py +++ b/tests/test_temporal.py @@ -29,7 +29,7 @@ except ImportError: import pytest - pytest.skip('temporal not installed') + pytest.skip('temporal not installed', allow_module_level=True) try: import logfire @@ -37,14 +37,14 @@ except ImportError: import pytest - pytest.skip('logfire not installed') + pytest.skip('logfire not installed', allow_module_level=True) try: from pydantic_ai.mcp import MCPServerStdio except ImportError: import pytest - pytest.skip('mcp not installed') + pytest.skip('mcp not installed', allow_module_level=True) with workflow.unsafe.imports_passed_through(): # Workaround for a race condition when running `logfire.info` inside an activity with attributes to serialize and pandas importable: From d93d57e49781f18f714e15954e535d4e9d54ff8c Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 5 Aug 2025 23:11:14 +0000 Subject: [PATCH 39/41] Fix pydantic_ai.AbstractAgent references to be pydantic_ai.agent.AbstractAgent --- docs/ag-ui.md | 6 +++--- docs/agents.md | 8 ++++---- docs/message-history.md | 8 ++++---- docs/multi-agent-applications.md | 2 +- docs/output.md | 2 +- docs/toolsets.md | 4 ++-- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 2 +- pydantic_ai_slim/pydantic_ai/agent.py | 2 +- .../pydantic_ai/ext/temporal/_agent.py | 2 +- pydantic_ai_slim/pydantic_ai/messages.py | 4 ++-- tests/evals/test_llm_as_a_judge.py | 20 +++++++++---------- 11 files changed, 30 insertions(+), 30 deletions(-) diff --git a/docs/ag-ui.md b/docs/ag-ui.md index e19bb0fa3..7259c4fe3 100644 --- a/docs/ag-ui.md +++ b/docs/ag-ui.md @@ -35,7 +35,7 @@ There are three ways to run a Pydantic AI agent based on AG-UI run input with st 1. [`run_ag_ui()`][pydantic_ai.ag_ui.run_ag_ui] takes an agent and an AG-UI [`RunAgentInput`](https://docs.ag-ui.com/sdk/python/core/types#runagentinput) object, and returns a stream of AG-UI events encoded as strings. It also takes optional [`Agent.iter()`][pydantic_ai.Agent.iter] arguments including `deps`. Use this if you're using a web framework not based on Starlette (e.g. Django or Flask) or want to modify the input or output some way. 2. [`handle_ag_ui_request()`][pydantic_ai.ag_ui.handle_ag_ui_request] takes an agent and a Starlette request (e.g. from FastAPI) coming from an AG-UI frontend, and returns a streaming Starlette response of AG-UI events that you can return directly from your endpoint. It also takes optional [`Agent.iter()`][pydantic_ai.Agent.iter] arguments including `deps`, that you can vary for each request (e.g. based on the authenticated user). -3. [`Agent.to_ag_ui()`][pydantic_ai.AbstractAgent.to_ag_ui] returns an ASGI application that handles every AG-UI request by running the agent. It also takes optional [`Agent.iter()`][pydantic_ai.Agent.iter] arguments including `deps`, but these will be the same for each request, with the exception of the AG-UI state that's injected as described under [state management](#state-management). This ASGI app can be [mounted](https://fastapi.tiangolo.com/advanced/sub-applications/) at a given path in an existing FastAPI app. +3. [`Agent.to_ag_ui()`][pydantic_ai.agent.AbstractAgent.to_ag_ui] returns an ASGI application that handles every AG-UI request by running the agent. It also takes optional [`Agent.iter()`][pydantic_ai.Agent.iter] arguments including `deps`, but these will be the same for each request, with the exception of the AG-UI state that's injected as described under [state management](#state-management). This ASGI app can be [mounted](https://fastapi.tiangolo.com/advanced/sub-applications/) at a given path in an existing FastAPI app. ### Handle run input and output directly @@ -117,7 +117,7 @@ This will expose the agent as an AG-UI server, and your frontend can start sendi ### Stand-alone ASGI app -This example uses [`Agent.to_ag_ui()`][pydantic_ai.AbstractAgent.to_ag_ui] to turn the agent into a stand-alone ASGI application: +This example uses [`Agent.to_ag_ui()`][pydantic_ai.agent.AbstractAgent.to_ag_ui] to turn the agent into a stand-alone ASGI application: ```py {title="agent_to_ag_ui.py" py="3.10" hl_lines="4"} from pydantic_ai import Agent @@ -265,7 +265,7 @@ uvicorn ag_ui_tool_events:app --host 0.0.0.0 --port 9000 ## Examples -For more examples of how to use [`to_ag_ui()`][pydantic_ai.AbstractAgent.to_ag_ui] see +For more examples of how to use [`to_ag_ui()`][pydantic_ai.agent.AbstractAgent.to_ag_ui] see [`pydantic_ai_examples.ag_ui`](https://github.com/pydantic/pydantic-ai/tree/main/examples/pydantic_ai_examples/ag_ui), which includes a server for use with the [AG-UI Dojo](https://docs.ag-ui.com/tutorials/debugging#the-ag-ui-dojo). diff --git a/docs/agents.md b/docs/agents.md index 3612d0269..b890adbef 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -63,9 +63,9 @@ print(result.output) There are four ways to run an agent: -1. [`agent.run()`][pydantic_ai.AbstractAgent.run] — a coroutine which returns a [`RunResult`][pydantic_ai.agent.AgentRunResult] containing a completed response. -2. [`agent.run_sync()`][pydantic_ai.AbstractAgent.run_sync] — a plain, synchronous function which returns a [`RunResult`][pydantic_ai.agent.AgentRunResult] containing a completed response (internally, this just calls `loop.run_until_complete(self.run())`). -3. [`agent.run_stream()`][pydantic_ai.AbstractAgent.run_stream] — a coroutine which returns a [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult], which contains methods to stream a response as an async iterable. +1. [`agent.run()`][pydantic_ai.agent.AbstractAgent.run] — a coroutine which returns a [`RunResult`][pydantic_ai.agent.AgentRunResult] containing a completed response. +2. [`agent.run_sync()`][pydantic_ai.agent.AbstractAgent.run_sync] — a plain, synchronous function which returns a [`RunResult`][pydantic_ai.agent.AgentRunResult] containing a completed response (internally, this just calls `loop.run_until_complete(self.run())`). +3. [`agent.run_stream()`][pydantic_ai.agent.AbstractAgent.run_stream] — a coroutine which returns a [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult], which contains methods to stream a response as an async iterable. 4. [`agent.iter()`][pydantic_ai.Agent.iter] — a context manager which returns an [`AgentRun`][pydantic_ai.agent.AgentRun], an async-iterable over the nodes of the agent's underlying [`Graph`][pydantic_graph.graph.Graph]. Here's a simple example demonstrating the first three: @@ -890,4 +890,4 @@ with capture_run_messages() as messages: # (2)! _(This example is complete, it can be run "as is")_ !!! note - If you call [`run`][pydantic_ai.AbstractAgent.run], [`run_sync`][pydantic_ai.AbstractAgent.run_sync], or [`run_stream`][pydantic_ai.AbstractAgent.run_stream] more than once within a single `capture_run_messages` context, `messages` will represent the messages exchanged during the first call only. + If you call [`run`][pydantic_ai.agent.AbstractAgent.run], [`run_sync`][pydantic_ai.agent.AbstractAgent.run_sync], or [`run_stream`][pydantic_ai.agent.AbstractAgent.run_stream] more than once within a single `capture_run_messages` context, `messages` will represent the messages exchanged during the first call only. diff --git a/docs/message-history.md b/docs/message-history.md index 3059b9922..f82263ac4 100644 --- a/docs/message-history.md +++ b/docs/message-history.md @@ -7,8 +7,8 @@ Pydantic AI provides access to messages exchanged during an agent run. These mes After running an agent, you can access the messages exchanged during that run from the `result` object. Both [`RunResult`][pydantic_ai.agent.AgentRunResult] -(returned by [`Agent.run`][pydantic_ai.AbstractAgent.run], [`Agent.run_sync`][pydantic_ai.AbstractAgent.run_sync]) -and [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult] (returned by [`Agent.run_stream`][pydantic_ai.AbstractAgent.run_stream]) have the following methods: +(returned by [`Agent.run`][pydantic_ai.agent.AbstractAgent.run], [`Agent.run_sync`][pydantic_ai.agent.AbstractAgent.run_sync]) +and [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult] (returned by [`Agent.run_stream`][pydantic_ai.agent.AbstractAgent.run_stream]) have the following methods: - [`all_messages()`][pydantic_ai.agent.AgentRunResult.all_messages]: returns all messages, including messages from prior runs. There's also a variant that returns JSON bytes, [`all_messages_json()`][pydantic_ai.agent.AgentRunResult.all_messages_json]. - [`new_messages()`][pydantic_ai.agent.AgentRunResult.new_messages]: returns only the messages from the current run. There's also a variant that returns JSON bytes, [`new_messages_json()`][pydantic_ai.agent.AgentRunResult.new_messages_json]. @@ -141,8 +141,8 @@ _(This example is complete, it can be run "as is" — you'll need to add `asynci The primary use of message histories in Pydantic AI is to maintain context across multiple agent runs. To use existing messages in a run, pass them to the `message_history` parameter of -[`Agent.run`][pydantic_ai.AbstractAgent.run], [`Agent.run_sync`][pydantic_ai.AbstractAgent.run_sync] or -[`Agent.run_stream`][pydantic_ai.AbstractAgent.run_stream]. +[`Agent.run`][pydantic_ai.agent.AbstractAgent.run], [`Agent.run_sync`][pydantic_ai.agent.AbstractAgent.run_sync] or +[`Agent.run_stream`][pydantic_ai.agent.AbstractAgent.run_stream]. If `message_history` is set and not empty, a new system prompt is not generated — we assume the existing message history includes a system prompt. diff --git a/docs/multi-agent-applications.md b/docs/multi-agent-applications.md index cbeeb2e80..a531f87b8 100644 --- a/docs/multi-agent-applications.md +++ b/docs/multi-agent-applications.md @@ -16,7 +16,7 @@ If you want to hand off control to another agent completely, without coming back Since agents are stateless and designed to be global, you do not need to include the agent itself in agent [dependencies](dependencies.md). -You'll generally want to pass [`ctx.usage`][pydantic_ai.RunContext.usage] to the [`usage`][pydantic_ai.AbstractAgent.run] keyword argument of the delegate agent run so usage within that run counts towards the total usage of the parent agent run. +You'll generally want to pass [`ctx.usage`][pydantic_ai.RunContext.usage] to the [`usage`][pydantic_ai.agent.AbstractAgent.run] keyword argument of the delegate agent run so usage within that run counts towards the total usage of the parent agent run. !!! note "Multiple models" Agent delegation doesn't need to use the same model for each agent. If you choose to use different models within a run, calculating the monetary cost from the final [`result.usage()`][pydantic_ai.agent.AgentRunResult.usage] of the run will not be possible, but you can still use [`UsageLimits`][pydantic_ai.usage.UsageLimits] to avoid unexpected costs. diff --git a/docs/output.md b/docs/output.md index 96dae51ce..3a9f221be 100644 --- a/docs/output.md +++ b/docs/output.md @@ -505,7 +505,7 @@ async def main(): ``` 1. Streaming works with the standard [`Agent`][pydantic_ai.Agent] class, and doesn't require any special setup, just a model that supports streaming (currently all models support streaming). -2. The [`Agent.run_stream()`][pydantic_ai.AbstractAgent.run_stream] method is used to start a streamed run, this method returns a context manager so the connection can be closed when the stream completes. +2. The [`Agent.run_stream()`][pydantic_ai.agent.AbstractAgent.run_stream] method is used to start a streamed run, this method returns a context manager so the connection can be closed when the stream completes. 3. Each item yield by [`StreamedRunResult.stream_text()`][pydantic_ai.result.StreamedRunResult.stream_text] is the complete text response, extended as new data is received. _(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main())` to run `main`)_ diff --git a/docs/toolsets.md b/docs/toolsets.md index ba64eb1e2..6a0bfc362 100644 --- a/docs/toolsets.md +++ b/docs/toolsets.md @@ -8,7 +8,7 @@ Toolsets are used (among many other things) to define [MCP servers](mcp/client.m The toolsets that will be available during an agent run can be specified in three different ways: * at agent construction time, via the [`toolsets`][pydantic_ai.Agent.__init__] keyword argument to `Agent` -* at agent run time, via the `toolsets` keyword argument to [`agent.run()`][pydantic_ai.AbstractAgent.run], [`agent.run_sync()`][pydantic_ai.AbstractAgent.run_sync], [`agent.run_stream()`][pydantic_ai.AbstractAgent.run_stream], or [`agent.iter()`][pydantic_ai.Agent.iter]. These toolsets will be additional to those provided to the `Agent` constructor +* at agent run time, via the `toolsets` keyword argument to [`agent.run()`][pydantic_ai.agent.AbstractAgent.run], [`agent.run_sync()`][pydantic_ai.agent.AbstractAgent.run_sync], [`agent.run_stream()`][pydantic_ai.agent.AbstractAgent.run_stream], or [`agent.iter()`][pydantic_ai.Agent.iter]. These toolsets will be additional to those provided to the `Agent` constructor * as a contextual override, via the `toolsets` keyword argument to the [`agent.override()`][pydantic_ai.Agent.iter] context manager. These toolsets will replace those provided at agent construction or run time during the life of the context manager ```python {title="toolsets.py"} @@ -457,7 +457,7 @@ When the model calls a deferred tool, the agent run ends with a [`DeferredToolCa To enable an agent to call deferred tools, you create a [`DeferredToolset`][pydantic_ai.toolsets.DeferredToolset], pass it a list of [`ToolDefinition`s][pydantic_ai.tools.ToolDefinition], and provide it to the agent using one of the methods described above. Additionally, you need to add `DeferredToolCalls` to the `Agent`'s [`output_type`](output.md#structured-output) so that the possible types of the agent run output are correctly inferred. Finally, you should handle the possible `DeferredToolCalls` output by passing it to the service that will produce the results. -If your agent can also be used in a context where no deferred tools are available, you will not want to include `DeferredToolCalls` in the `output_type` passed to the `Agent` constructor as you'd have to deal with that type everywhere you use the agent. Instead, you can pass the `toolsets` and `output_type` keyword arguments when you run the agent using [`agent.run()`][pydantic_ai.AbstractAgent.run], [`agent.run_sync()`][pydantic_ai.AbstractAgent.run_sync], [`agent.run_stream()`][pydantic_ai.AbstractAgent.run_stream], or [`agent.iter()`][pydantic_ai.Agent.iter]. Note that while `toolsets` provided at this stage are additional to the toolsets provided to the constructor, the `output_type` overrides the one specified at construction time (for type inference reasons), so you'll need to include the original output types explicitly. +If your agent can also be used in a context where no deferred tools are available, you will not want to include `DeferredToolCalls` in the `output_type` passed to the `Agent` constructor as you'd have to deal with that type everywhere you use the agent. Instead, you can pass the `toolsets` and `output_type` keyword arguments when you run the agent using [`agent.run()`][pydantic_ai.agent.AbstractAgent.run], [`agent.run_sync()`][pydantic_ai.agent.AbstractAgent.run_sync], [`agent.run_stream()`][pydantic_ai.agent.AbstractAgent.run_stream], or [`agent.iter()`][pydantic_ai.Agent.iter]. Note that while `toolsets` provided at this stage are additional to the toolsets provided to the constructor, the `output_type` overrides the one specified at construction time (for type inference reasons), so you'll need to include the original output types explicitly. To demonstrate, let us first define a simple agent _without_ deferred tools: diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index a46114b31..95511c330 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -786,7 +786,7 @@ class _RunMessages: @contextmanager def capture_run_messages() -> Iterator[list[_messages.ModelMessage]]: - """Context manager to access the messages used in a [`run`][pydantic_ai.AbstractAgent.run], [`run_sync`][pydantic_ai.AbstractAgent.run_sync], or [`run_stream`][pydantic_ai.AbstractAgent.run_stream] call. + """Context manager to access the messages used in a [`run`][pydantic_ai.agent.AbstractAgent.run], [`run_sync`][pydantic_ai.agent.AbstractAgent.run_sync], or [`run_stream`][pydantic_ai.agent.AbstractAgent.run_stream] call. Useful when a run may raise an exception, see [model errors](../agents.md#model-errors) for more information. diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index e0d8da9f5..4bbdb65e3 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -309,7 +309,7 @@ def run_sync( ) -> AgentRunResult[Any]: """Synchronously run the agent with a user prompt. - This is a convenience method that wraps [`self.run`][pydantic_ai.AbstractAgent.run] with `loop.run_until_complete(...)`. + This is a convenience method that wraps [`self.run`][pydantic_ai.agent.AbstractAgent.run] with `loop.run_until_complete(...)`. You therefore can't use this method inside async code or if there's an active event loop. Example: diff --git a/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py b/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py index d7910a1a4..0438c28d8 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py +++ b/pydantic_ai_slim/pydantic_ai/ext/temporal/_agent.py @@ -270,7 +270,7 @@ def run_sync( ) -> AgentRunResult[Any]: """Synchronously run the agent with a user prompt. - This is a convenience method that wraps [`self.run`][pydantic_ai.AbstractAgent.run] with `loop.run_until_complete(...)`. + This is a convenience method that wraps [`self.run`][pydantic_ai.agent.AbstractAgent.run] with `loop.run_until_complete(...)`. You therefore can't use this method inside async code or if there's an active event loop. Example: diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index 765d073f9..ee889bff0 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -490,8 +490,8 @@ class ToolReturn: class UserPromptPart: """A user prompt, generally written by the end user. - Content comes from the `user_prompt` parameter of [`Agent.run`][pydantic_ai.AbstractAgent.run], - [`Agent.run_sync`][pydantic_ai.AbstractAgent.run_sync], and [`Agent.run_stream`][pydantic_ai.AbstractAgent.run_stream]. + Content comes from the `user_prompt` parameter of [`Agent.run`][pydantic_ai.agent.AbstractAgent.run], + [`Agent.run_sync`][pydantic_ai.agent.AbstractAgent.run_sync], and [`Agent.run_stream`][pydantic_ai.agent.AbstractAgent.run_stream]. """ content: str | Sequence[UserContent] diff --git a/tests/evals/test_llm_as_a_judge.py b/tests/evals/test_llm_as_a_judge.py index 8fa270f39..2fb8c51a8 100644 --- a/tests/evals/test_llm_as_a_judge.py +++ b/tests/evals/test_llm_as_a_judge.py @@ -75,7 +75,7 @@ async def test_judge_output_mock(mocker: MockerFixture): # Mock the agent run method mock_result = mocker.MagicMock() mock_result.output = GradingOutput(reason='Test passed', pass_=True, score=1.0) - mock_run = mocker.patch('pydantic_ai.AbstractAgent.run', return_value=mock_result) + mock_run = mocker.patch('pydantic_ai.agent.AbstractAgent.run', return_value=mock_result) # Test with string output grading_output = await judge_output('Hello world', 'Content contains a greeting') @@ -96,7 +96,7 @@ async def test_judge_output_with_model_settings_mock(mocker: MockerFixture): """Test judge_output function with model_settings and mocked agent.""" mock_result = mocker.MagicMock() mock_result.output = GradingOutput(reason='Test passed with settings', pass_=True, score=1.0) - mock_run = mocker.patch('pydantic_ai.AbstractAgent.run', return_value=mock_result) + mock_run = mocker.patch('pydantic_ai.agent.AbstractAgent.run', return_value=mock_result) test_model_settings = ModelSettings(temperature=1) @@ -125,7 +125,7 @@ async def test_judge_input_output_mock(mocker: MockerFixture): # Mock the agent run method mock_result = mocker.MagicMock() mock_result.output = GradingOutput(reason='Test passed', pass_=True, score=1.0) - mock_run = mocker.patch('pydantic_ai.AbstractAgent.run', return_value=mock_result) + mock_run = mocker.patch('pydantic_ai.agent.AbstractAgent.run', return_value=mock_result) # Test with string input and output result = await judge_input_output('Hello', 'Hello world', 'Output contains input') @@ -147,7 +147,7 @@ async def test_judge_input_output_binary_content_list_mock(mocker: MockerFixture # Mock the agent run method mock_result = mocker.MagicMock() mock_result.output = GradingOutput(reason='Test passed', pass_=True, score=1.0) - mock_run = mocker.patch('pydantic_ai.AbstractAgent.run', return_value=mock_result) + mock_run = mocker.patch('pydantic_ai.agent.AbstractAgent.run', return_value=mock_result) result = await judge_input_output([image_content, image_content], 'Hello world', 'Output contains input') assert isinstance(result, GradingOutput) @@ -171,7 +171,7 @@ async def test_judge_input_output_binary_content_mock(mocker: MockerFixture, ima # Mock the agent run method mock_result = mocker.MagicMock() mock_result.output = GradingOutput(reason='Test passed', pass_=True, score=1.0) - mock_run = mocker.patch('pydantic_ai.AbstractAgent.run', return_value=mock_result) + mock_run = mocker.patch('pydantic_ai.agent.AbstractAgent.run', return_value=mock_result) result = await judge_input_output(image_content, 'Hello world', 'Output contains input') assert isinstance(result, GradingOutput) @@ -195,7 +195,7 @@ async def test_judge_input_output_with_model_settings_mock(mocker: MockerFixture """Test judge_input_output function with model_settings and mocked agent.""" mock_result = mocker.MagicMock() mock_result.output = GradingOutput(reason='Test passed with settings', pass_=True, score=1.0) - mock_run = mocker.patch('pydantic_ai.AbstractAgent.run', return_value=mock_result) + mock_run = mocker.patch('pydantic_ai.agent.AbstractAgent.run', return_value=mock_result) test_model_settings = ModelSettings(temperature=1) @@ -226,7 +226,7 @@ async def test_judge_input_output_expected_mock(mocker: MockerFixture, image_con # Mock the agent run method mock_result = mocker.MagicMock() mock_result.output = GradingOutput(reason='Test passed', pass_=True, score=1.0) - mock_run = mocker.patch('pydantic_ai.AbstractAgent.run', return_value=mock_result) + mock_run = mocker.patch('pydantic_ai.agent.AbstractAgent.run', return_value=mock_result) # Test with string input and output result = await judge_input_output_expected('Hello', 'Hello world', 'Hello', 'Output contains input') @@ -262,7 +262,7 @@ async def test_judge_input_output_expected_with_model_settings_mock( """Test judge_input_output_expected function with model_settings and mocked agent.""" mock_result = mocker.MagicMock() mock_result.output = GradingOutput(reason='Test passed with settings', pass_=True, score=1.0) - mock_run = mocker.patch('pydantic_ai.AbstractAgent.run', return_value=mock_result) + mock_run = mocker.patch('pydantic_ai.agent.AbstractAgent.run', return_value=mock_result) test_model_settings = ModelSettings(temperature=1) @@ -396,7 +396,7 @@ async def test_judge_output_expected_mock(mocker: MockerFixture): # Mock the agent run method mock_result = mocker.MagicMock() mock_result.output = GradingOutput(reason='Test passed', pass_=True, score=1.0) - mock_run = mocker.patch('pydantic_ai.AbstractAgent.run', return_value=mock_result) + mock_run = mocker.patch('pydantic_ai.agent.AbstractAgent.run', return_value=mock_result) # Test with string output and expected output result = await judge_output_expected('Hello world', 'Hello', 'Output contains input') @@ -418,7 +418,7 @@ async def test_judge_output_expected_with_model_settings_mock(mocker: MockerFixt """Test judge_output_expected function with model_settings and mocked agent.""" mock_result = mocker.MagicMock() mock_result.output = GradingOutput(reason='Test passed with settings', pass_=True, score=1.0) - mock_run = mocker.patch('pydantic_ai.AbstractAgent.run', return_value=mock_result) + mock_run = mocker.patch('pydantic_ai.agent.AbstractAgent.run', return_value=mock_result) test_model_settings = ModelSettings(temperature=1) From dc9bcbb3e528303c7954be7e2ad9d9ce96f81230 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 5 Aug 2025 23:19:39 +0000 Subject: [PATCH 40/41] Use mock OpenAI API key in CI --- tests/test_temporal.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_temporal.py b/tests/test_temporal.py index 88e0b73e8..6b6f02301 100644 --- a/tests/test_temporal.py +++ b/tests/test_temporal.py @@ -1,5 +1,6 @@ from __future__ import annotations +import os from collections.abc import AsyncIterable, AsyncIterator from dataclasses import dataclass from datetime import timedelta @@ -10,6 +11,8 @@ from pydantic_ai import Agent, RunContext from pydantic_ai.agent import AbstractAgent from pydantic_ai.messages import AgentStreamEvent, HandleResponseEvent +from pydantic_ai.models.openai import OpenAIModel +from pydantic_ai.providers.openai import OpenAIProvider from pydantic_ai.toolsets import FunctionToolset try: @@ -113,7 +116,8 @@ class Response: agent = Agent( - 'openai:gpt-4o', + # Can't use the `openai_api_key` fixture here because the workflow needs to be defined at the top level of the file. + OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=os.getenv('OPENAI_API_KEY', 'mock-api-key'))), deps_type=Deps, output_type=Response, toolsets=[ From 99781e3ef1c8efc9b61c1e3b7b271e7ac8361466 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 5 Aug 2025 23:30:23 +0000 Subject: [PATCH 41/41] Guard against openai not being installed --- tests/test_temporal.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/test_temporal.py b/tests/test_temporal.py index 6b6f02301..9beb85025 100644 --- a/tests/test_temporal.py +++ b/tests/test_temporal.py @@ -11,8 +11,6 @@ from pydantic_ai import Agent, RunContext from pydantic_ai.agent import AbstractAgent from pydantic_ai.messages import AgentStreamEvent, HandleResponseEvent -from pydantic_ai.models.openai import OpenAIModel -from pydantic_ai.providers.openai import OpenAIProvider from pydantic_ai.toolsets import FunctionToolset try: @@ -49,6 +47,15 @@ pytest.skip('mcp not installed', allow_module_level=True) +try: + from pydantic_ai.models.openai import OpenAIModel + from pydantic_ai.providers.openai import OpenAIProvider +except ImportError: + import pytest + + pytest.skip('openai not installed', allow_module_level=True) + + with workflow.unsafe.imports_passed_through(): # Workaround for a race condition when running `logfire.info` inside an activity with attributes to serialize and pandas importable: # AttributeError: partially initialized module 'pandas' has no attribute '_pandas_parser_CAPI' (most likely due to a circular import)