
Let Agent be run in a Temporal workflow by moving model requests, tool calls, and MCP to Temporal activities #2225


Draft

DouweM wants to merge 24 commits into main from temporal-agent

Commits (24)
6263510  Add optional `id` field to toolsets (DouweM, Jul 21, 2025)
d02361a  WIP: temporalize_agent (DouweM, Jul 21, 2025)
f657709  Add Agent event_stream_handler (DouweM, Jul 22, 2025)
2f04894  Pass run_context to Model.request_stream for Temporal (DouweM, Jul 22, 2025)
5f6cfa7  Streaming with Temporal (DouweM, Jul 22, 2025)
a1e96e6  Fix google types issues by importing only google.genai.Client (DouweM, Jul 24, 2025)
966e7f8  Import TypeAlias from typing_extensions for Python 3.9 (DouweM, Jul 24, 2025)
d3811c9  Merge branch 'main' into temporal-agent (DouweM, Jul 24, 2025)
090ec23  Start cleaning up temporal integration (DouweM, Jul 24, 2025)
4c87691  with_passthrough_modules doesn't import itself (DouweM, Jul 25, 2025)
694fa6b  Use Temporal plugins (DouweM, Jul 25, 2025)
5e858d3  Polish (DouweM, Jul 25, 2025)
2474f1a  Use latest temporalio version with plugins (DouweM, Jul 30, 2025)
7c62e35  Temporalize MCPServer.get_tools and call_tool instead of list_tools a… (DouweM, Jul 31, 2025)
8eb677b  Add ID to model activity names (DouweM, Jul 31, 2025)
6682e97  Use temporal wrapper classes instead of monkeypatching (DouweM, Jul 31, 2025)
fb5259c  Let running a tool in a Temporal activity be disabled (DouweM, Jul 31, 2025)
00002d3  Use temporal ActivityConfig instead of our TemporalSettings (DouweM, Jul 31, 2025)
f91547d  Warn when non-default Temporal data converter was swapped out (DouweM, Jul 31, 2025)
6aeb078  Use agent.override inside temporalize_agent instead of directly setti… (DouweM, Jul 31, 2025)
2f3965b  Remove duplication between AgentStream and TemporalModel get_final_re… (DouweM, Jul 31, 2025)
529f4c8  Some polish (DouweM, Jul 31, 2025)
295a69b  Add AbstractAgent, WrapperAgent, TemporalAgent instead of temporalize… (DouweM, Aug 1, 2025)
8d3e04d  Add a todo (DouweM, Aug 1, 2025)
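
Taken together, the commits sketch the intended usage: wrap an `Agent` in the new `TemporalAgent` (commit 295a69b) so that model requests, tool calls, and MCP communication run as Temporal activities, with wiring handled by Temporal plugins (commits 694fa6b, 2474f1a). Below is a minimal usage sketch inferred from the commit messages; the `pydantic_ai.temporal` import path and the `TemporalAgent` signature are assumptions, not the final API.

```python
from datetime import timedelta

from temporalio import workflow
from temporalio.workflow import ActivityConfig

from pydantic_ai import Agent
from pydantic_ai.toolsets import FunctionToolset

# Hypothetical import path; commit 295a69b names the class but not its module.
from pydantic_ai.temporal import TemporalAgent

toolset = FunctionToolset(id='weather')  # toolsets need an ID to name their activities


@toolset.tool
def temperature_celsius(city: str) -> float:
    return 21.0


agent = Agent('openai:gpt-4o', toolsets=[toolset])

# Per commit 00002d3, activity options use Temporal's own ActivityConfig
# rather than a custom settings class.
temporal_agent = TemporalAgent(
    agent,
    activity_config=ActivityConfig(start_to_close_timeout=timedelta(minutes=1)),
)


@workflow.defn
class AgentWorkflow:
    @workflow.run
    async def run(self, prompt: str) -> str:
        # Orchestration stays deterministic inside the workflow; model
        # requests, tool calls, and MCP calls run as Temporal activities.
        result = await temporal_agent.run(prompt)
        return result.output
```

On the worker side, the PR's plugins presumably register the agent's generated activities and swap in a pydantic-aware data converter (commit f91547d warns when a non-default converter is replaced), and commit 4c87691's `with_passthrough_modules` hints that the integration also manages Temporal's workflow sandbox imports.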
4 changes: 2 additions & 2 deletions docs/models/google.md
@@ -104,14 +104,14 @@ You can supply a custom `GoogleProvider` instance using the `provider` argument
 This is useful if you're using a custom-compatible endpoint with the Google Generative Language API.
 
 ```python
-from google import genai
+from google.genai import Client
 from google.genai.types import HttpOptions
 
 from pydantic_ai import Agent
 from pydantic_ai.models.google import GoogleModel
 from pydantic_ai.providers.google import GoogleProvider
 
-client = genai.Client(
+client = Client(
     api_key='gemini-custom-api-key',
     http_options=HttpOptions(base_url='gemini-custom-base-url'),
 )
3 changes: 2 additions & 1 deletion docs/tools.md
@@ -770,7 +770,7 @@ from pydantic_ai.ext.langchain import LangChainToolset
 
 
 toolkit = SlackToolkit()
-toolset = LangChainToolset(toolkit.get_tools())
+toolset = LangChainToolset(toolkit.get_tools(), id='slack')
 
 agent = Agent('openai:gpt-4o', toolsets=[toolset])
 # ...
@@ -823,6 +823,7 @@ toolset = ACIToolset(
         'OPEN_WEATHER_MAP__FORECAST',
     ],
     linked_account_owner_id=os.getenv('LINKED_ACCOUNT_OWNER_ID'),
+    id='open_weather_map',
 )
 
 agent = Agent('openai:gpt-4o', toolsets=[toolset])
15 changes: 10 additions & 5 deletions docs/toolsets.md
@@ -84,7 +84,10 @@ def temperature_fahrenheit(city: str) -> float:
     return 69.8
 
 
-weather_toolset = FunctionToolset(tools=[temperature_celsius, temperature_fahrenheit])
+weather_toolset = FunctionToolset(
+    tools=[temperature_celsius, temperature_fahrenheit],
+    id='weather',  # (1)!
+)
 
 
 @weather_toolset.tool
@@ -95,10 +98,10 @@ def conditions(ctx: RunContext, city: str) -> str:
         return "It's raining"
 
 
-datetime_toolset = FunctionToolset()
+datetime_toolset = FunctionToolset(id='datetime')
 datetime_toolset.add_function(lambda: datetime.now(), name='now')
 
-test_model = TestModel()  # (1)!
+test_model = TestModel()  # (2)!
 agent = Agent(test_model)
 
 result = agent.run_sync('What tools are available?', toolsets=[weather_toolset])
@@ -110,7 +113,8 @@ print([t.name for t in test_model.last_model_request_parameters.function_tools])
 #> ['now']
 ```
 
-1. We're using [`TestModel`][pydantic_ai.models.test.TestModel] here because it makes it easy to see which tools were available on each run.
+1. `FunctionToolset` supports an optional `id` argument that can help to identify the toolset in error messages. A toolset also needs to have an ID in order to be used in a durable execution environment like Temporal, in which case the ID will be used to identify the toolset's activities within the workflow.
+2. We're using [`TestModel`][pydantic_ai.models.test.TestModel] here because it makes it easy to see which tools were available on each run.
 
 _(This example is complete, it can be run "as is")_
 
@@ -609,7 +613,7 @@ from pydantic_ai.ext.langchain import LangChainToolset
 
 
 toolkit = SlackToolkit()
-toolset = LangChainToolset(toolkit.get_tools())
+toolset = LangChainToolset(toolkit.get_tools(), id='slack')
 
 agent = Agent('openai:gpt-4o', toolsets=[toolset])
 # ...
@@ -634,6 +638,7 @@ toolset = ACIToolset(
         'OPEN_WEATHER_MAP__FORECAST',
     ],
     linked_account_owner_id=os.getenv('LINKED_ACCOUNT_OWNER_ID'),
+    id='open_weather_map',
 )
 
 agent = Agent('openai:gpt-4o', toolsets=[toolset])
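
The new documentation note boils down to this pattern; a condensed, self-contained version of the documented example:

```python
from datetime import datetime

from pydantic_ai.toolsets import FunctionToolset


def temperature_celsius(city: str) -> float:
    return 21.0


# Give each toolset a stable, unique ID if it may run in a durable execution
# environment like Temporal, where the ID names the toolset's activities.
weather_toolset = FunctionToolset(tools=[temperature_celsius], id='weather')

datetime_toolset = FunctionToolset(id='datetime')
datetime_toolset.add_function(lambda: datetime.now(), name='now')
```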
10 changes: 6 additions & 4 deletions pydantic_ai_slim/pydantic_ai/_a2a.py
@@ -27,7 +27,7 @@
     VideoUrl,
 )
 
-from .agent import Agent, AgentDepsT, OutputDataT
+from .agent import AbstractAgent, AgentDepsT, OutputDataT
 
 # AgentWorker output type needs to be invariant for use in both parameter and return positions
 WorkerOutputT = TypeVar('WorkerOutputT')
@@ -59,7 +59,9 @@
 
 
 @asynccontextmanager
-async def worker_lifespan(app: FastA2A, worker: Worker, agent: Agent[AgentDepsT, OutputDataT]) -> AsyncIterator[None]:
+async def worker_lifespan(
+    app: FastA2A, worker: Worker, agent: AbstractAgent[AgentDepsT, OutputDataT]
+) -> AsyncIterator[None]:
     """Custom lifespan that runs the worker during application startup.
 
     This ensures the worker is started and ready to process tasks as soon as the application starts.
@@ -70,7 +72,7 @@ async def worker_lifespan(app: FastA2A, worker: Worker, agent: Agent[AgentDepsT,
 
 
 def agent_to_a2a(
-    agent: Agent[AgentDepsT, OutputDataT],
+    agent: AbstractAgent[AgentDepsT, OutputDataT],
     *,
     storage: Storage | None = None,
     broker: Broker | None = None,
@@ -116,7 +118,7 @@ def agent_to_a2a(
 class AgentWorker(Worker[list[ModelMessage]], Generic[WorkerOutputT, AgentDepsT]):
     """A worker that uses an agent to execute tasks."""
 
-    agent: Agent[AgentDepsT, WorkerOutputT]
+    agent: AbstractAgent[AgentDepsT, WorkerOutputT]
 
     async def run_task(self, params: TaskSendParams) -> None:
         task = await self.storage.load_task(params['id'])
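
The widening from `Agent` to `AbstractAgent` in the signatures above is what lets a wrapped agent be exposed over A2A. A sketch of the consequence, assuming the `TemporalAgent` wrapper from commit 295a69b satisfies `AbstractAgent`:

```python
from pydantic_ai import Agent

agent = Agent('openai:gpt-4o')

# Previously agent_to_a2a required a concrete Agent; with the annotations
# above it accepts any AbstractAgent, e.g. (hypothetically) a TemporalAgent:
# app = agent_to_a2a(TemporalAgent(agent))
app = agent.to_a2a()
```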
51 changes: 21 additions & 30 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -303,10 +303,18 @@ async def stream(
         self,
         ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
     ) -> AsyncIterator[result.AgentStream[DepsT, T]]:
-        async with self._stream(ctx) as streamed_response:
+        assert not self._did_stream, 'stream() should only be called once per node'
+
+        model_settings, model_request_parameters, message_history, run_context = await self._prepare_request(ctx)
+        async with ctx.deps.model.request_stream(
+            message_history, model_settings, model_request_parameters, run_context
+        ) as streamed_response:
+            self._did_stream = True
+            ctx.state.usage.requests += 1
             agent_stream = result.AgentStream[DepsT, T](
                 streamed_response,
                 ctx.deps.output_schema,
+                model_request_parameters,
                 ctx.deps.output_validators,
                 build_run_context(ctx),
                 ctx.deps.usage_limits,
@@ -318,28 +326,6 @@ async def stream(
             async for _ in agent_stream:
                 pass
 
-    @asynccontextmanager
-    async def _stream(
-        self,
-        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
-    ) -> AsyncIterator[models.StreamedResponse]:
-        assert not self._did_stream, 'stream() should only be called once per node'
-
-        model_settings, model_request_parameters = await self._prepare_request(ctx)
-        model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters)
-        message_history = await _process_message_history(
-            ctx.state.message_history, ctx.deps.history_processors, build_run_context(ctx)
-        )
-        async with ctx.deps.model.request_stream(
-            message_history, model_settings, model_request_parameters
-        ) as streamed_response:
-            self._did_stream = True
-            ctx.state.usage.requests += 1
-            yield streamed_response
-            # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
-            # otherwise usage won't be properly counted:
-            async for _ in streamed_response:
-                pass
         model_response = streamed_response.get()
 
         self._finish_handling(ctx, model_response)
@@ -351,19 +337,15 @@ async def _make_request(
         if self._result is not None:
             return self._result  # pragma: no cover
 
-        model_settings, model_request_parameters = await self._prepare_request(ctx)
-        model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters)
-        message_history = await _process_message_history(
-            ctx.state.message_history, ctx.deps.history_processors, build_run_context(ctx)
-        )
+        model_settings, model_request_parameters, message_history, _ = await self._prepare_request(ctx)
         model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters)
         ctx.state.usage.incr(_usage.Usage())
 
         return self._finish_handling(ctx, model_response)
 
     async def _prepare_request(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
-    ) -> tuple[ModelSettings | None, models.ModelRequestParameters]:
+    ) -> tuple[ModelSettings | None, models.ModelRequestParameters, list[_messages.ModelMessage], RunContext[DepsT]]:
         ctx.state.message_history.append(self.request)
 
         # Check usage
@@ -373,9 +355,18 @@ async def _prepare_request(
         # Increment run_step
         ctx.state.run_step += 1
 
+        run_context = build_run_context(ctx)
+
         model_settings = merge_model_settings(ctx.deps.model_settings, None)
+
         model_request_parameters = await _prepare_request_parameters(ctx)
-        return model_settings, model_request_parameters
+        model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters)
+
+        message_history = await _process_message_history(
+            ctx.state.message_history, ctx.deps.history_processors, run_context
+        )
+
+        return model_settings, model_request_parameters, message_history, run_context
 
     def _finish_handling(
         self,
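
The net effect of this refactor: `_prepare_request` now returns the processed message history and the `RunContext` alongside the settings and parameters, and the streaming path forwards that `RunContext` into `Model.request_stream` (commit 2f04894). That extra argument is what a wrapping model needs to re-dispatch a streamed request through a Temporal activity. A hypothetical wrapper illustrating the hook; this is not the PR's actual `TemporalModel`:

```python
from contextlib import asynccontextmanager


class ForwardingModel:
    """Hypothetical Model wrapper; argument names follow the diff above."""

    def __init__(self, wrapped):
        self.wrapped = wrapped

    async def request(self, messages, model_settings, model_request_parameters):
        # The non-streaming path is unchanged by this PR.
        return await self.wrapped.request(messages, model_settings, model_request_parameters)

    @asynccontextmanager
    async def request_stream(self, messages, model_settings, model_request_parameters, run_context=None):
        # run_context is the new fourth argument; a Temporal wrapper would use
        # it when scheduling the streamed request as an activity.
        async with self.wrapped.request_stream(
            messages, model_settings, model_request_parameters, run_context
        ) as streamed_response:
            yield streamed_response
```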
6 changes: 3 additions & 3 deletions pydantic_ai_slim/pydantic_ai/_cli.py
@@ -16,7 +16,7 @@
 
 from . import __version__
 from ._run_context import AgentDepsT
-from .agent import Agent
+from .agent import AbstractAgent, Agent
 from .exceptions import UserError
 from .messages import ModelMessage
 from .models import KnownModelName, infer_model
@@ -220,7 +220,7 @@ def cli(  # noqa: C901
 
 async def run_chat(
     stream: bool,
-    agent: Agent[AgentDepsT, OutputDataT],
+    agent: AbstractAgent[AgentDepsT, OutputDataT],
     console: Console,
     code_theme: str,
     prog_name: str,
@@ -263,7 +263,7 @@ async def run_chat(
 
 
 async def ask_agent(
-    agent: Agent[AgentDepsT, OutputDataT],
+    agent: AbstractAgent[AgentDepsT, OutputDataT],
     prompt: str,
     stream: bool,
     console: Console,
4 changes: 4 additions & 0 deletions pydantic_ai_slim/pydantic_ai/_output.py
@@ -977,6 +977,10 @@ def __init__(
         self.max_retries = max_retries
         self.output_validators = output_validators or []
 
+    @property
+    def id(self) -> str | None:
+        return 'output'
+
     async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
         return {
             tool_def.name: ToolsetTool(
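
The new `id` property gives the internal output toolset a fixed identifier, matching the optional `id` other toolsets accept. For a custom toolset, the equivalent would presumably look like the sketch below (`AbstractToolset` has further abstract methods not shown):

```python
from pydantic_ai.toolsets import AbstractToolset


class MyToolset(AbstractToolset):
    @property
    def id(self) -> str | None:
        # A stable ID, required in durable execution environments like
        # Temporal, where it names this toolset's activities.
        return 'my_toolset'
```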
9 changes: 5 additions & 4 deletions pydantic_ai_slim/pydantic_ai/ag_ui.py
@@ -72,7 +72,7 @@
 from pydantic import BaseModel, ValidationError
 
 from ._agent_graph import CallToolsNode, ModelRequestNode
-from .agent import Agent, AgentRun, RunOutputDataT
+from .agent import AbstractAgent, AgentRun, RunOutputDataT
 from .messages import (
     AgentStreamEvent,
     FunctionToolResultEvent,
@@ -115,7 +115,7 @@ class AGUIApp(Generic[AgentDepsT, OutputDataT], Starlette):
 
     def __init__(
         self,
-        agent: Agent[AgentDepsT, OutputDataT],
+        agent: AbstractAgent[AgentDepsT, OutputDataT],
         *,
         # Agent.iter parameters.
         output_type: OutputSpec[OutputDataT] | None = None,
@@ -223,7 +223,7 @@ class _Adapter(Generic[AgentDepsT, OutputDataT]):
         agent: The Pydantic AI `Agent` to adapt.
     """
 
-    agent: Agent[AgentDepsT, OutputDataT] = field(repr=False)
+    agent: AbstractAgent[AgentDepsT, OutputDataT] = field(repr=False)
 
     async def run(
         self,
@@ -273,7 +273,8 @@ async def run(
                     parameters_json_schema=tool.parameters,
                 )
                 for tool in run_input.tools
-            ]
+            ],
+            id='ag_ui_frontend',
         )
         toolsets = [*toolsets, toolset] if toolsets else [toolset]
 