Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 19 additions & 18 deletions docs/durable_execution/temporal.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ from temporalio.worker import Worker

from pydantic_ai import Agent
from pydantic_ai.durable_exec.temporal import (
AgentPlugin,
PydanticAIPlugin,
PydanticAIWorkflow,
TemporalAgent,
)

Expand All @@ -101,26 +101,27 @@ temporal_agent = TemporalAgent(agent) # (1)!


@workflow.defn
class GeographyWorkflow: # (2)!
class GeographyWorkflow(PydanticAIWorkflow): # (2)!
__pydantic_ai_agents__ = [temporal_agent] # (3)!

@workflow.run
async def run(self, prompt: str) -> str:
result = await temporal_agent.run(prompt) # (3)!
result = await temporal_agent.run(prompt) # (4)!
return result.output


async def main():
client = await Client.connect( # (4)!
'localhost:7233', # (5)!
plugins=[PydanticAIPlugin()], # (6)!
client = await Client.connect( # (5)!
'localhost:7233', # (6)!
plugins=[PydanticAIPlugin()], # (7)!
)

async with Worker( # (7)!
async with Worker( # (8)!
client,
task_queue='geography',
workflows=[GeographyWorkflow],
plugins=[AgentPlugin(temporal_agent)], # (8)!
):
output = await client.execute_workflow( # (9)!
output = await client.execute_workflow( # (10)!
GeographyWorkflow.run,
args=['What is the capital of Mexico?'],
id=f'geography-{uuid.uuid4()}',
Expand All @@ -131,15 +132,15 @@ async def main():
```

1. The original `Agent` cannot be used inside a deterministic Temporal workflow, but the `TemporalAgent` can.
2. As explained above, the workflow represents a deterministic piece of code that can use non-deterministic activities for operations that require I/O.
3. [`TemporalAgent.run()`][pydantic_ai.durable_exec.temporal.TemporalAgent.run] works just like [`Agent.run()`][pydantic_ai.Agent.run], but it will automatically offload model requests, tool calls, and MCP server communication to Temporal activities.
4. We connect to the Temporal server which keeps track of workflow and activity execution.
5. This assumes the Temporal server is [running locally](https://github.com/temporalio/temporal#download-and-start-temporal-server-locally).
6. The [`PydanticAIPlugin`][pydantic_ai.durable_exec.temporal.PydanticAIPlugin] tells Temporal to use Pydantic for serialization and deserialization, and to treat [`UserError`][pydantic_ai.exceptions.UserError] exceptions as non-retryable.
7. We start the worker that will listen on the specified task queue and run workflows and activities. In a real world application, this might be run in a separate service.
8. The [`AgentPlugin`][pydantic_ai.durable_exec.temporal.AgentPlugin] registers the `TemporalAgent`'s activities with the worker.
9. We call on the server to execute the workflow on a worker that's listening on the specified task queue.
10. The agent's `name` is used to uniquely identify its activities.
2. As explained above, the workflow represents a deterministic piece of code that can use non-deterministic activities for operations that require I/O. Subclassing [`PydanticAIWorkflow`][pydantic_ai.durable_exec.temporal.PydanticAIWorkflow] is optional but provides proper typing for the `__pydantic_ai_agents__` class variable.
3. List the `TemporalAgent`s used by this workflow. The [`PydanticAIPlugin`][pydantic_ai.durable_exec.temporal.PydanticAIPlugin] will automatically register their activities with the worker. Alternatively, if modifying the worker initialization is easier than the workflow class, you can use [`AgentPlugin`][pydantic_ai.durable_exec.temporal.AgentPlugin] to register agents directly on the worker.
4. [`TemporalAgent.run()`][pydantic_ai.durable_exec.temporal.TemporalAgent.run] works just like [`Agent.run()`][pydantic_ai.Agent.run], but it will automatically offload model requests, tool calls, and MCP server communication to Temporal activities.
5. We connect to the Temporal server which keeps track of workflow and activity execution.
6. This assumes the Temporal server is [running locally](https://github.com/temporalio/temporal#download-and-start-temporal-server-locally).
7. The [`PydanticAIPlugin`][pydantic_ai.durable_exec.temporal.PydanticAIPlugin] tells Temporal to use Pydantic for serialization and deserialization, treats [`UserError`][pydantic_ai.exceptions.UserError] exceptions as non-retryable, and automatically registers activities for agents listed in `__pydantic_ai_agents__`.
8. We start the worker that will listen on the specified task queue and run workflows and activities. In a real world application, this might be run in a separate service.
9. The agent's `name` is used to uniquely identify its activities.
10. We call on the server to execute the workflow on a worker that's listening on the specified task queue.

_(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main())` to run `main`)_

Expand Down
30 changes: 29 additions & 1 deletion pydantic_ai_slim/pydantic_ai/durable_exec/temporal/__init__.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
from __future__ import annotations

import warnings
from collections.abc import Sequence
from dataclasses import replace
from typing import Any

from pydantic.errors import PydanticUserError
from temporalio.contrib.pydantic import PydanticPayloadConverter, pydantic_data_converter
from temporalio.converter import DataConverter, DefaultPayloadConverter
from temporalio.plugin import SimplePlugin
from temporalio.worker import WorkflowRunner
from temporalio.worker import WorkerConfig, WorkflowRunner
from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner

from ...exceptions import UserError
from ._agent import TemporalAgent
from ._logfire import LogfirePlugin
from ._run_context import TemporalRunContext
from ._toolset import TemporalWrapperToolset
from ._workflow import PydanticAIWorkflow

__all__ = [
'TemporalAgent',
Expand All @@ -24,6 +26,7 @@
'AgentPlugin',
'TemporalRunContext',
'TemporalWrapperToolset',
'PydanticAIWorkflow',
]

# We need eagerly import the anyio backends or it will happens inside workflow code and temporal has issues
Expand Down Expand Up @@ -91,6 +94,31 @@ def __init__(self):
workflow_failure_exception_types=[UserError, PydanticUserError],
)

def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
config = super().configure_worker(config)

workflows = list(config.get('workflows', [])) # type: ignore[reportUnknownMemberType]
activities = list(config.get('activities', [])) # type: ignore[reportUnknownMemberType]

for workflow_class in workflows: # type: ignore[reportUnknownMemberType]
agents = getattr(workflow_class, '__pydantic_ai_agents__', None) # type: ignore[reportUnknownMemberType]
if agents is None:
continue
if not isinstance(agents, Sequence):
raise TypeError( # pragma: no cover
f'__pydantic_ai_agents__ must be a Sequence of TemporalAgent instances, got {type(agents)}'
)
for agent in agents: # type: ignore[reportUnknownVariableType]
if not isinstance(agent, TemporalAgent):
raise TypeError( # pragma: no cover
f'__pydantic_ai_agents__ must be a Sequence of TemporalAgent, got {type(agent)}' # type: ignore[reportUnknownVariableType]
)
activities.extend(agent.temporal_activities) # type: ignore[reportUnknownMemberType]

config['activities'] = activities

return config


class AgentPlugin(SimplePlugin):
"""Temporal worker plugin for a specific Pydantic AI agent."""
Expand Down
10 changes: 10 additions & 0 deletions pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_workflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from collections.abc import Sequence
from typing import Any

from pydantic_ai.durable_exec.temporal import TemporalAgent


class PydanticAIWorkflow:
"""Temporal Workflow base class that provides `__pydantic_ai_agents__` for direct agent registration."""

__pydantic_ai_agents__: Sequence[TemporalAgent[Any, Any]]
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
interactions:
- request:
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- '105'
content-type:
- application/json
host:
- api.openai.com
method: POST
parsed_body:
messages:
- content: What is the capital of Mexico?
role: user
model: gpt-4o
stream: false
uri: https://api.openai.com/v1/chat/completions
response:
headers:
access-control-expose-headers:
- X-Request-ID
alt-svc:
- h3=":443"; ma=86400
connection:
- keep-alive
content-length:
- '838'
content-type:
- application/json
openai-organization:
- pydantic-28gund
openai-processing-ms:
- '403'
openai-project:
- proj_dKobscVY9YJxeEaDJen54e3d
openai-version:
- '2020-10-01'
strict-transport-security:
- max-age=31536000; includeSubDomains; preload
transfer-encoding:
- chunked
parsed_body:
choices:
- finish_reason: stop
index: 0
logprobs: null
message:
annotations: []
content: The capital of Mexico is Mexico City.
refusal: null
role: assistant
created: 1754675179
id: chatcmpl-C2LSVwAtcuMjKCHykKXgKphwTaQVB
model: gpt-4o-2024-08-06
object: chat.completion
service_tier: default
system_fingerprint: fp_ff25b2783a
usage:
completion_tokens: 8
completion_tokens_details:
accepted_prediction_tokens: 0
audio_tokens: 0
reasoning_tokens: 0
rejected_prediction_tokens: 0
prompt_tokens: 14
prompt_tokens_details:
audio_tokens: 0
cached_tokens: 0
total_tokens: 22
status:
code: 200
message: OK
version: 1
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
interactions:
- request:
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- '105'
content-type:
- application/json
host:
- api.openai.com
method: POST
parsed_body:
messages:
- content: What is the capital of Mexico?
role: user
model: gpt-4o
stream: false
uri: https://api.openai.com/v1/chat/completions
response:
headers:
access-control-expose-headers:
- X-Request-ID
alt-svc:
- h3=":443"; ma=86400
connection:
- keep-alive
content-length:
- '838'
content-type:
- application/json
openai-organization:
- pydantic-28gund
openai-processing-ms:
- '403'
openai-project:
- proj_dKobscVY9YJxeEaDJen54e3d
openai-version:
- '2020-10-01'
strict-transport-security:
- max-age=31536000; includeSubDomains; preload
transfer-encoding:
- chunked
parsed_body:
choices:
- finish_reason: stop
index: 0
logprobs: null
message:
annotations: []
content: The capital of Mexico is Mexico City.
refusal: null
role: assistant
created: 1754675179
id: chatcmpl-C2LSVwAtcuMjKCHykKXgKphwTaQVB
model: gpt-4o-2024-08-06
object: chat.completion
service_tier: default
system_fingerprint: fp_ff25b2783a
usage:
completion_tokens: 8
completion_tokens_details:
accepted_prediction_tokens: 0
audio_tokens: 0
reasoning_tokens: 0
rejected_prediction_tokens: 0
prompt_tokens: 14
prompt_tokens_details:
audio_tokens: 0
cached_tokens: 0
total_tokens: 22
status:
code: 200
message: OK
version: 1
58 changes: 57 additions & 1 deletion tests/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,13 @@
from temporalio.worker import Worker
from temporalio.workflow import ActivityConfig

from pydantic_ai.durable_exec.temporal import AgentPlugin, LogfirePlugin, PydanticAIPlugin, TemporalAgent
from pydantic_ai.durable_exec.temporal import (
AgentPlugin,
LogfirePlugin,
PydanticAIPlugin,
PydanticAIWorkflow,
TemporalAgent,
)
from pydantic_ai.durable_exec.temporal._function_toolset import TemporalFunctionToolset
from pydantic_ai.durable_exec.temporal._mcp_server import TemporalMCPServer
from pydantic_ai.durable_exec.temporal._model import TemporalModel
Expand Down Expand Up @@ -2356,3 +2362,53 @@ async def test_beta_graph_parallel_execution_in_workflow(client: Client):
# Results can be in any order due to parallel execution
# 10 * 2 = 20, 10 * 3 = 30, 10 * 4 = 40
assert sorted(output) == [20, 30, 40]


@workflow.defn
class WorkflowWithAgents(PydanticAIWorkflow):
__pydantic_ai_agents__ = [simple_temporal_agent]

@workflow.run
async def run(self, prompt: str) -> str:
result = await simple_temporal_agent.run(prompt)
return result.output


@workflow.defn
class WorkflowWithAgentsWithoutPydanticAIWorkflow:
__pydantic_ai_agents__ = [simple_temporal_agent]

@workflow.run
async def run(self, prompt: str) -> str:
result = await simple_temporal_agent.run(prompt)
return result.output


async def test_passing_agents_through_workflow(allow_model_requests: None, client: Client):
async with Worker(
client,
task_queue=TASK_QUEUE,
workflows=[WorkflowWithAgents],
):
output = await client.execute_workflow(
WorkflowWithAgents.run,
args=['What is the capital of Mexico?'],
id=WorkflowWithAgents.__name__,
task_queue=TASK_QUEUE,
)
assert output == snapshot('The capital of Mexico is Mexico City.')


async def test_passing_agents_through_workflow_without_pydantic_ai_workflow(allow_model_requests: None, client: Client):
async with Worker(
client,
task_queue=TASK_QUEUE,
workflows=[WorkflowWithAgentsWithoutPydanticAIWorkflow],
):
output = await client.execute_workflow(
WorkflowWithAgentsWithoutPydanticAIWorkflow.run,
args=['What is the capital of Mexico?'],
id=WorkflowWithAgentsWithoutPydanticAIWorkflow.__name__,
task_queue=TASK_QUEUE,
)
assert output == snapshot('The capital of Mexico is Mexico City.')