diff --git a/docs/durable_execution/temporal.md b/docs/durable_execution/temporal.md index f52626caff..bb1450c587 100644 --- a/docs/durable_execution/temporal.md +++ b/docs/durable_execution/temporal.md @@ -86,8 +86,8 @@ from temporalio.worker import Worker from pydantic_ai import Agent from pydantic_ai.durable_exec.temporal import ( - AgentPlugin, PydanticAIPlugin, + PydanticAIWorkflow, TemporalAgent, ) @@ -101,26 +101,27 @@ temporal_agent = TemporalAgent(agent) # (1)! @workflow.defn -class GeographyWorkflow: # (2)! +class GeographyWorkflow(PydanticAIWorkflow): # (2)! + __pydantic_ai_agents__ = [temporal_agent] # (3)! + @workflow.run async def run(self, prompt: str) -> str: - result = await temporal_agent.run(prompt) # (3)! + result = await temporal_agent.run(prompt) # (4)! return result.output async def main(): - client = await Client.connect( # (4)! - 'localhost:7233', # (5)! - plugins=[PydanticAIPlugin()], # (6)! + client = await Client.connect( # (5)! + 'localhost:7233', # (6)! + plugins=[PydanticAIPlugin()], # (7)! ) - async with Worker( # (7)! + async with Worker( # (8)! client, task_queue='geography', workflows=[GeographyWorkflow], - plugins=[AgentPlugin(temporal_agent)], # (8)! ): - output = await client.execute_workflow( # (9)! + output = await client.execute_workflow( # (10)! GeographyWorkflow.run, args=['What is the capital of Mexico?'], id=f'geography-{uuid.uuid4()}', @@ -131,15 +132,15 @@ async def main(): ``` 1. The original `Agent` cannot be used inside a deterministic Temporal workflow, but the `TemporalAgent` can. -2. As explained above, the workflow represents a deterministic piece of code that can use non-deterministic activities for operations that require I/O. -3. [`TemporalAgent.run()`][pydantic_ai.durable_exec.temporal.TemporalAgent.run] works just like [`Agent.run()`][pydantic_ai.Agent.run], but it will automatically offload model requests, tool calls, and MCP server communication to Temporal activities. -4. We connect to the Temporal server which keeps track of workflow and activity execution. -5. This assumes the Temporal server is [running locally](https://github.com/temporalio/temporal#download-and-start-temporal-server-locally). -6. The [`PydanticAIPlugin`][pydantic_ai.durable_exec.temporal.PydanticAIPlugin] tells Temporal to use Pydantic for serialization and deserialization, and to treat [`UserError`][pydantic_ai.exceptions.UserError] exceptions as non-retryable. -7. We start the worker that will listen on the specified task queue and run workflows and activities. In a real world application, this might be run in a separate service. -8. The [`AgentPlugin`][pydantic_ai.durable_exec.temporal.AgentPlugin] registers the `TemporalAgent`'s activities with the worker. -9. We call on the server to execute the workflow on a worker that's listening on the specified task queue. -10. The agent's `name` is used to uniquely identify its activities. +2. As explained above, the workflow represents a deterministic piece of code that can use non-deterministic activities for operations that require I/O. Subclassing [`PydanticAIWorkflow`][pydantic_ai.durable_exec.temporal.PydanticAIWorkflow] is optional but provides proper typing for the `__pydantic_ai_agents__` class variable. +3. List the `TemporalAgent`s used by this workflow. The [`PydanticAIPlugin`][pydantic_ai.durable_exec.temporal.PydanticAIPlugin] will automatically register their activities with the worker. Alternatively, if modifying the worker initialization is easier than the workflow class, you can use [`AgentPlugin`][pydantic_ai.durable_exec.temporal.AgentPlugin] to register agents directly on the worker. +4. [`TemporalAgent.run()`][pydantic_ai.durable_exec.temporal.TemporalAgent.run] works just like [`Agent.run()`][pydantic_ai.Agent.run], but it will automatically offload model requests, tool calls, and MCP server communication to Temporal activities. +5. We connect to the Temporal server which keeps track of workflow and activity execution. +6. This assumes the Temporal server is [running locally](https://github.com/temporalio/temporal#download-and-start-temporal-server-locally). +7. The [`PydanticAIPlugin`][pydantic_ai.durable_exec.temporal.PydanticAIPlugin] tells Temporal to use Pydantic for serialization and deserialization, treats [`UserError`][pydantic_ai.exceptions.UserError] exceptions as non-retryable, and automatically registers activities for agents listed in `__pydantic_ai_agents__`. +8. We start the worker that will listen on the specified task queue and run workflows and activities. In a real world application, this might be run in a separate service. +9. The agent's `name` is used to uniquely identify its activities. +10. We call on the server to execute the workflow on a worker that's listening on the specified task queue. _(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main())` to run `main`)_ diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/__init__.py b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/__init__.py index 48e4489ab9..dc27bd9409 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/__init__.py @@ -1,6 +1,7 @@ from __future__ import annotations import warnings +from collections.abc import Sequence from dataclasses import replace from typing import Any @@ -8,7 +9,7 @@ from temporalio.contrib.pydantic import PydanticPayloadConverter, pydantic_data_converter from temporalio.converter import DataConverter, DefaultPayloadConverter from temporalio.plugin import SimplePlugin -from temporalio.worker import WorkflowRunner +from temporalio.worker import WorkerConfig, WorkflowRunner from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner from ...exceptions import UserError @@ -16,6 +17,7 @@ from ._logfire import LogfirePlugin from ._run_context import TemporalRunContext from ._toolset import TemporalWrapperToolset +from ._workflow import PydanticAIWorkflow __all__ = [ 'TemporalAgent', @@ -24,6 +26,7 @@ 'AgentPlugin', 'TemporalRunContext', 'TemporalWrapperToolset', + 'PydanticAIWorkflow', ] # We need eagerly import the anyio backends or it will happens inside workflow code and temporal has issues @@ -91,6 +94,31 @@ def __init__(self): workflow_failure_exception_types=[UserError, PydanticUserError], ) + def configure_worker(self, config: WorkerConfig) -> WorkerConfig: + config = super().configure_worker(config) + + workflows = list(config.get('workflows', [])) # type: ignore[reportUnknownMemberType] + activities = list(config.get('activities', [])) # type: ignore[reportUnknownMemberType] + + for workflow_class in workflows: # type: ignore[reportUnknownMemberType] + agents = getattr(workflow_class, '__pydantic_ai_agents__', None) # type: ignore[reportUnknownMemberType] + if agents is None: + continue + if not isinstance(agents, Sequence): + raise TypeError( # pragma: no cover + f'__pydantic_ai_agents__ must be a Sequence of TemporalAgent instances, got {type(agents)}' + ) + for agent in agents: # type: ignore[reportUnknownVariableType] + if not isinstance(agent, TemporalAgent): + raise TypeError( # pragma: no cover + f'__pydantic_ai_agents__ must be a Sequence of TemporalAgent, got {type(agent)}' # type: ignore[reportUnknownVariableType] + ) + activities.extend(agent.temporal_activities) # type: ignore[reportUnknownMemberType] + + config['activities'] = activities + + return config + class AgentPlugin(SimplePlugin): """Temporal worker plugin for a specific Pydantic AI agent.""" diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_workflow.py b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_workflow.py new file mode 100644 index 0000000000..fb07de4fb8 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_workflow.py @@ -0,0 +1,10 @@ +from collections.abc import Sequence +from typing import Any + +from pydantic_ai.durable_exec.temporal import TemporalAgent + + +class PydanticAIWorkflow: + """Temporal Workflow base class that provides `__pydantic_ai_agents__` for direct agent registration.""" + + __pydantic_ai_agents__: Sequence[TemporalAgent[Any, Any]] diff --git a/tests/cassettes/test_temporal/test_passing_agents_through_workflow.yaml b/tests/cassettes/test_temporal/test_passing_agents_through_workflow.yaml new file mode 100644 index 0000000000..9aa924f5f5 --- /dev/null +++ b/tests/cassettes/test_temporal/test_passing_agents_through_workflow.yaml @@ -0,0 +1,79 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '105' + content-type: + - application/json + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: What is the capital of Mexico? + role: user + model: gpt-4o + stream: false + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '838' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '403' + openai-project: + - proj_dKobscVY9YJxeEaDJen54e3d + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + annotations: [] + content: The capital of Mexico is Mexico City. + refusal: null + role: assistant + created: 1754675179 + id: chatcmpl-C2LSVwAtcuMjKCHykKXgKphwTaQVB + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_ff25b2783a + usage: + completion_tokens: 8 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 14 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 22 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/cassettes/test_temporal/test_passing_agents_through_workflow_without_pydantic_ai_workflow.yaml b/tests/cassettes/test_temporal/test_passing_agents_through_workflow_without_pydantic_ai_workflow.yaml new file mode 100644 index 0000000000..9aa924f5f5 --- /dev/null +++ b/tests/cassettes/test_temporal/test_passing_agents_through_workflow_without_pydantic_ai_workflow.yaml @@ -0,0 +1,79 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '105' + content-type: + - application/json + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: What is the capital of Mexico? + role: user + model: gpt-4o + stream: false + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '838' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '403' + openai-project: + - proj_dKobscVY9YJxeEaDJen54e3d + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + annotations: [] + content: The capital of Mexico is Mexico City. + refusal: null + role: assistant + created: 1754675179 + id: chatcmpl-C2LSVwAtcuMjKCHykKXgKphwTaQVB + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_ff25b2783a + usage: + completion_tokens: 8 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 14 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 22 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/test_temporal.py b/tests/test_temporal.py index 98039d9078..13dad4ed9f 100644 --- a/tests/test_temporal.py +++ b/tests/test_temporal.py @@ -63,7 +63,13 @@ from temporalio.worker import Worker from temporalio.workflow import ActivityConfig - from pydantic_ai.durable_exec.temporal import AgentPlugin, LogfirePlugin, PydanticAIPlugin, TemporalAgent + from pydantic_ai.durable_exec.temporal import ( + AgentPlugin, + LogfirePlugin, + PydanticAIPlugin, + PydanticAIWorkflow, + TemporalAgent, + ) from pydantic_ai.durable_exec.temporal._function_toolset import TemporalFunctionToolset from pydantic_ai.durable_exec.temporal._mcp_server import TemporalMCPServer from pydantic_ai.durable_exec.temporal._model import TemporalModel @@ -2356,3 +2362,53 @@ async def test_beta_graph_parallel_execution_in_workflow(client: Client): # Results can be in any order due to parallel execution # 10 * 2 = 20, 10 * 3 = 30, 10 * 4 = 40 assert sorted(output) == [20, 30, 40] + + +@workflow.defn +class WorkflowWithAgents(PydanticAIWorkflow): + __pydantic_ai_agents__ = [simple_temporal_agent] + + @workflow.run + async def run(self, prompt: str) -> str: + result = await simple_temporal_agent.run(prompt) + return result.output + + +@workflow.defn +class WorkflowWithAgentsWithoutPydanticAIWorkflow: + __pydantic_ai_agents__ = [simple_temporal_agent] + + @workflow.run + async def run(self, prompt: str) -> str: + result = await simple_temporal_agent.run(prompt) + return result.output + + +async def test_passing_agents_through_workflow(allow_model_requests: None, client: Client): + async with Worker( + client, + task_queue=TASK_QUEUE, + workflows=[WorkflowWithAgents], + ): + output = await client.execute_workflow( + WorkflowWithAgents.run, + args=['What is the capital of Mexico?'], + id=WorkflowWithAgents.__name__, + task_queue=TASK_QUEUE, + ) + assert output == snapshot('The capital of Mexico is Mexico City.') + + +async def test_passing_agents_through_workflow_without_pydantic_ai_workflow(allow_model_requests: None, client: Client): + async with Worker( + client, + task_queue=TASK_QUEUE, + workflows=[WorkflowWithAgentsWithoutPydanticAIWorkflow], + ): + output = await client.execute_workflow( + WorkflowWithAgentsWithoutPydanticAIWorkflow.run, + args=['What is the capital of Mexico?'], + id=WorkflowWithAgentsWithoutPydanticAIWorkflow.__name__, + task_queue=TASK_QUEUE, + ) + assert output == snapshot('The capital of Mexico is Mexico City.')