Skip to content

Commit ea2d9d3

Browse files
authored
Let TemporalAgents be registered to a Temporal workflow using __pydantic_ai_agents__ field (#3676)
1 parent 87ad3d5 commit ea2d9d3

File tree

6 files changed

+273
-20
lines changed

6 files changed

+273
-20
lines changed

docs/durable_execution/temporal.md

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,8 @@ from temporalio.worker import Worker
8686

8787
from pydantic_ai import Agent
8888
from pydantic_ai.durable_exec.temporal import (
89-
AgentPlugin,
9089
PydanticAIPlugin,
90+
PydanticAIWorkflow,
9191
TemporalAgent,
9292
)
9393

@@ -101,26 +101,27 @@ temporal_agent = TemporalAgent(agent) # (1)!
101101

102102

103103
@workflow.defn
104-
class GeographyWorkflow: # (2)!
104+
class GeographyWorkflow(PydanticAIWorkflow): # (2)!
105+
__pydantic_ai_agents__ = [temporal_agent] # (3)!
106+
105107
@workflow.run
106108
async def run(self, prompt: str) -> str:
107-
result = await temporal_agent.run(prompt) # (3)!
109+
result = await temporal_agent.run(prompt) # (4)!
108110
return result.output
109111

110112

111113
async def main():
112-
client = await Client.connect( # (4)!
113-
'localhost:7233', # (5)!
114-
plugins=[PydanticAIPlugin()], # (6)!
114+
client = await Client.connect( # (5)!
115+
'localhost:7233', # (6)!
116+
plugins=[PydanticAIPlugin()], # (7)!
115117
)
116118

117-
async with Worker( # (7)!
119+
async with Worker( # (8)!
118120
client,
119121
task_queue='geography',
120122
workflows=[GeographyWorkflow],
121-
plugins=[AgentPlugin(temporal_agent)], # (8)!
122123
):
123-
output = await client.execute_workflow( # (9)!
124+
output = await client.execute_workflow( # (10)!
124125
GeographyWorkflow.run,
125126
args=['What is the capital of Mexico?'],
126127
id=f'geography-{uuid.uuid4()}',
@@ -131,15 +132,15 @@ async def main():
131132
```
132133

133134
1. The original `Agent` cannot be used inside a deterministic Temporal workflow, but the `TemporalAgent` can.
134-
2. As explained above, the workflow represents a deterministic piece of code that can use non-deterministic activities for operations that require I/O.
135-
3. [`TemporalAgent.run()`][pydantic_ai.durable_exec.temporal.TemporalAgent.run] works just like [`Agent.run()`][pydantic_ai.Agent.run], but it will automatically offload model requests, tool calls, and MCP server communication to Temporal activities.
136-
4. We connect to the Temporal server which keeps track of workflow and activity execution.
137-
5. This assumes the Temporal server is [running locally](https://github.com/temporalio/temporal#download-and-start-temporal-server-locally).
138-
6. The [`PydanticAIPlugin`][pydantic_ai.durable_exec.temporal.PydanticAIPlugin] tells Temporal to use Pydantic for serialization and deserialization, and to treat [`UserError`][pydantic_ai.exceptions.UserError] exceptions as non-retryable.
139-
7. We start the worker that will listen on the specified task queue and run workflows and activities. In a real world application, this might be run in a separate service.
140-
8. The [`AgentPlugin`][pydantic_ai.durable_exec.temporal.AgentPlugin] registers the `TemporalAgent`'s activities with the worker.
141-
9. We call on the server to execute the workflow on a worker that's listening on the specified task queue.
142-
10. The agent's `name` is used to uniquely identify its activities.
135+
2. As explained above, the workflow represents a deterministic piece of code that can use non-deterministic activities for operations that require I/O. Subclassing [`PydanticAIWorkflow`][pydantic_ai.durable_exec.temporal.PydanticAIWorkflow] is optional but provides proper typing for the `__pydantic_ai_agents__` class variable.
136+
3. List the `TemporalAgent`s used by this workflow. The [`PydanticAIPlugin`][pydantic_ai.durable_exec.temporal.PydanticAIPlugin] will automatically register their activities with the worker. Alternatively, if modifying the worker initialization is easier than the workflow class, you can use [`AgentPlugin`][pydantic_ai.durable_exec.temporal.AgentPlugin] to register agents directly on the worker.
137+
4. [`TemporalAgent.run()`][pydantic_ai.durable_exec.temporal.TemporalAgent.run] works just like [`Agent.run()`][pydantic_ai.Agent.run], but it will automatically offload model requests, tool calls, and MCP server communication to Temporal activities.
138+
5. We connect to the Temporal server which keeps track of workflow and activity execution.
139+
6. This assumes the Temporal server is [running locally](https://github.com/temporalio/temporal#download-and-start-temporal-server-locally).
140+
7. The [`PydanticAIPlugin`][pydantic_ai.durable_exec.temporal.PydanticAIPlugin] tells Temporal to use Pydantic for serialization and deserialization, treats [`UserError`][pydantic_ai.exceptions.UserError] exceptions as non-retryable, and automatically registers activities for agents listed in `__pydantic_ai_agents__`.
141+
8. We start the worker that will listen on the specified task queue and run workflows and activities. In a real world application, this might be run in a separate service.
142+
9. The agent's `name` is used to uniquely identify its activities.
143+
10. We call on the server to execute the workflow on a worker that's listening on the specified task queue.
143144

144145
_(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main())` to run `main`)_
145146

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/__init__.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,23 @@
11
from __future__ import annotations
22

33
import warnings
4+
from collections.abc import Sequence
45
from dataclasses import replace
56
from typing import Any
67

78
from pydantic.errors import PydanticUserError
89
from temporalio.contrib.pydantic import PydanticPayloadConverter, pydantic_data_converter
910
from temporalio.converter import DataConverter, DefaultPayloadConverter
1011
from temporalio.plugin import SimplePlugin
11-
from temporalio.worker import WorkflowRunner
12+
from temporalio.worker import WorkerConfig, WorkflowRunner
1213
from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner
1314

1415
from ...exceptions import UserError
1516
from ._agent import TemporalAgent
1617
from ._logfire import LogfirePlugin
1718
from ._run_context import TemporalRunContext
1819
from ._toolset import TemporalWrapperToolset
20+
from ._workflow import PydanticAIWorkflow
1921

2022
__all__ = [
2123
'TemporalAgent',
@@ -24,6 +26,7 @@
2426
'AgentPlugin',
2527
'TemporalRunContext',
2628
'TemporalWrapperToolset',
29+
'PydanticAIWorkflow',
2730
]
2831

2932
# We need eagerly import the anyio backends or it will happens inside workflow code and temporal has issues
@@ -91,6 +94,31 @@ def __init__(self):
9194
workflow_failure_exception_types=[UserError, PydanticUserError],
9295
)
9396

97+
def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
98+
config = super().configure_worker(config)
99+
100+
workflows = list(config.get('workflows', [])) # type: ignore[reportUnknownMemberType]
101+
activities = list(config.get('activities', [])) # type: ignore[reportUnknownMemberType]
102+
103+
for workflow_class in workflows: # type: ignore[reportUnknownMemberType]
104+
agents = getattr(workflow_class, '__pydantic_ai_agents__', None) # type: ignore[reportUnknownMemberType]
105+
if agents is None:
106+
continue
107+
if not isinstance(agents, Sequence):
108+
raise TypeError( # pragma: no cover
109+
f'__pydantic_ai_agents__ must be a Sequence of TemporalAgent instances, got {type(agents)}'
110+
)
111+
for agent in agents: # type: ignore[reportUnknownVariableType]
112+
if not isinstance(agent, TemporalAgent):
113+
raise TypeError( # pragma: no cover
114+
f'__pydantic_ai_agents__ must be a Sequence of TemporalAgent, got {type(agent)}' # type: ignore[reportUnknownVariableType]
115+
)
116+
activities.extend(agent.temporal_activities) # type: ignore[reportUnknownMemberType]
117+
118+
config['activities'] = activities
119+
120+
return config
121+
94122

95123
class AgentPlugin(SimplePlugin):
96124
"""Temporal worker plugin for a specific Pydantic AI agent."""
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from collections.abc import Sequence
2+
from typing import Any
3+
4+
from pydantic_ai.durable_exec.temporal import TemporalAgent
5+
6+
7+
class PydanticAIWorkflow:
8+
"""Temporal Workflow base class that provides `__pydantic_ai_agents__` for direct agent registration."""
9+
10+
__pydantic_ai_agents__: Sequence[TemporalAgent[Any, Any]]
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
interactions:
2+
- request:
3+
headers:
4+
accept:
5+
- application/json
6+
accept-encoding:
7+
- gzip, deflate
8+
connection:
9+
- keep-alive
10+
content-length:
11+
- '105'
12+
content-type:
13+
- application/json
14+
host:
15+
- api.openai.com
16+
method: POST
17+
parsed_body:
18+
messages:
19+
- content: What is the capital of Mexico?
20+
role: user
21+
model: gpt-4o
22+
stream: false
23+
uri: https://api.openai.com/v1/chat/completions
24+
response:
25+
headers:
26+
access-control-expose-headers:
27+
- X-Request-ID
28+
alt-svc:
29+
- h3=":443"; ma=86400
30+
connection:
31+
- keep-alive
32+
content-length:
33+
- '838'
34+
content-type:
35+
- application/json
36+
openai-organization:
37+
- pydantic-28gund
38+
openai-processing-ms:
39+
- '403'
40+
openai-project:
41+
- proj_dKobscVY9YJxeEaDJen54e3d
42+
openai-version:
43+
- '2020-10-01'
44+
strict-transport-security:
45+
- max-age=31536000; includeSubDomains; preload
46+
transfer-encoding:
47+
- chunked
48+
parsed_body:
49+
choices:
50+
- finish_reason: stop
51+
index: 0
52+
logprobs: null
53+
message:
54+
annotations: []
55+
content: The capital of Mexico is Mexico City.
56+
refusal: null
57+
role: assistant
58+
created: 1754675179
59+
id: chatcmpl-C2LSVwAtcuMjKCHykKXgKphwTaQVB
60+
model: gpt-4o-2024-08-06
61+
object: chat.completion
62+
service_tier: default
63+
system_fingerprint: fp_ff25b2783a
64+
usage:
65+
completion_tokens: 8
66+
completion_tokens_details:
67+
accepted_prediction_tokens: 0
68+
audio_tokens: 0
69+
reasoning_tokens: 0
70+
rejected_prediction_tokens: 0
71+
prompt_tokens: 14
72+
prompt_tokens_details:
73+
audio_tokens: 0
74+
cached_tokens: 0
75+
total_tokens: 22
76+
status:
77+
code: 200
78+
message: OK
79+
version: 1
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
interactions:
2+
- request:
3+
headers:
4+
accept:
5+
- application/json
6+
accept-encoding:
7+
- gzip, deflate
8+
connection:
9+
- keep-alive
10+
content-length:
11+
- '105'
12+
content-type:
13+
- application/json
14+
host:
15+
- api.openai.com
16+
method: POST
17+
parsed_body:
18+
messages:
19+
- content: What is the capital of Mexico?
20+
role: user
21+
model: gpt-4o
22+
stream: false
23+
uri: https://api.openai.com/v1/chat/completions
24+
response:
25+
headers:
26+
access-control-expose-headers:
27+
- X-Request-ID
28+
alt-svc:
29+
- h3=":443"; ma=86400
30+
connection:
31+
- keep-alive
32+
content-length:
33+
- '838'
34+
content-type:
35+
- application/json
36+
openai-organization:
37+
- pydantic-28gund
38+
openai-processing-ms:
39+
- '403'
40+
openai-project:
41+
- proj_dKobscVY9YJxeEaDJen54e3d
42+
openai-version:
43+
- '2020-10-01'
44+
strict-transport-security:
45+
- max-age=31536000; includeSubDomains; preload
46+
transfer-encoding:
47+
- chunked
48+
parsed_body:
49+
choices:
50+
- finish_reason: stop
51+
index: 0
52+
logprobs: null
53+
message:
54+
annotations: []
55+
content: The capital of Mexico is Mexico City.
56+
refusal: null
57+
role: assistant
58+
created: 1754675179
59+
id: chatcmpl-C2LSVwAtcuMjKCHykKXgKphwTaQVB
60+
model: gpt-4o-2024-08-06
61+
object: chat.completion
62+
service_tier: default
63+
system_fingerprint: fp_ff25b2783a
64+
usage:
65+
completion_tokens: 8
66+
completion_tokens_details:
67+
accepted_prediction_tokens: 0
68+
audio_tokens: 0
69+
reasoning_tokens: 0
70+
rejected_prediction_tokens: 0
71+
prompt_tokens: 14
72+
prompt_tokens_details:
73+
audio_tokens: 0
74+
cached_tokens: 0
75+
total_tokens: 22
76+
status:
77+
code: 200
78+
message: OK
79+
version: 1

tests/test_temporal.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,13 @@
6363
from temporalio.worker import Worker
6464
from temporalio.workflow import ActivityConfig
6565

66-
from pydantic_ai.durable_exec.temporal import AgentPlugin, LogfirePlugin, PydanticAIPlugin, TemporalAgent
66+
from pydantic_ai.durable_exec.temporal import (
67+
AgentPlugin,
68+
LogfirePlugin,
69+
PydanticAIPlugin,
70+
PydanticAIWorkflow,
71+
TemporalAgent,
72+
)
6773
from pydantic_ai.durable_exec.temporal._function_toolset import TemporalFunctionToolset
6874
from pydantic_ai.durable_exec.temporal._mcp_server import TemporalMCPServer
6975
from pydantic_ai.durable_exec.temporal._model import TemporalModel
@@ -2361,3 +2367,53 @@ async def test_beta_graph_parallel_execution_in_workflow(client: Client):
23612367
# Results can be in any order due to parallel execution
23622368
# 10 * 2 = 20, 10 * 3 = 30, 10 * 4 = 40
23632369
assert sorted(output) == [20, 30, 40]
2370+
2371+
2372+
@workflow.defn
2373+
class WorkflowWithAgents(PydanticAIWorkflow):
2374+
__pydantic_ai_agents__ = [simple_temporal_agent]
2375+
2376+
@workflow.run
2377+
async def run(self, prompt: str) -> str:
2378+
result = await simple_temporal_agent.run(prompt)
2379+
return result.output
2380+
2381+
2382+
@workflow.defn
2383+
class WorkflowWithAgentsWithoutPydanticAIWorkflow:
2384+
__pydantic_ai_agents__ = [simple_temporal_agent]
2385+
2386+
@workflow.run
2387+
async def run(self, prompt: str) -> str:
2388+
result = await simple_temporal_agent.run(prompt)
2389+
return result.output
2390+
2391+
2392+
async def test_passing_agents_through_workflow(allow_model_requests: None, client: Client):
2393+
async with Worker(
2394+
client,
2395+
task_queue=TASK_QUEUE,
2396+
workflows=[WorkflowWithAgents],
2397+
):
2398+
output = await client.execute_workflow(
2399+
WorkflowWithAgents.run,
2400+
args=['What is the capital of Mexico?'],
2401+
id=WorkflowWithAgents.__name__,
2402+
task_queue=TASK_QUEUE,
2403+
)
2404+
assert output == snapshot('The capital of Mexico is Mexico City.')
2405+
2406+
2407+
async def test_passing_agents_through_workflow_without_pydantic_ai_workflow(allow_model_requests: None, client: Client):
2408+
async with Worker(
2409+
client,
2410+
task_queue=TASK_QUEUE,
2411+
workflows=[WorkflowWithAgentsWithoutPydanticAIWorkflow],
2412+
):
2413+
output = await client.execute_workflow(
2414+
WorkflowWithAgentsWithoutPydanticAIWorkflow.run,
2415+
args=['What is the capital of Mexico?'],
2416+
id=WorkflowWithAgentsWithoutPydanticAIWorkflow.__name__,
2417+
task_queue=TASK_QUEUE,
2418+
)
2419+
assert output == snapshot('The capital of Mexico is Mexico City.')

0 commit comments

Comments
 (0)