Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 55 additions & 106 deletions pydantic_ai_slim/pydantic_ai/durable_exec/temporal/__init__.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,14 @@
from __future__ import annotations

import warnings
from collections.abc import AsyncIterator, Callable, Sequence
from contextlib import AbstractAsyncContextManager
from dataclasses import replace
from typing import Any

from pydantic.errors import PydanticUserError
from temporalio.client import ClientConfig, Plugin as ClientPlugin, WorkflowHistory
from temporalio.contrib.pydantic import PydanticPayloadConverter, pydantic_data_converter
from temporalio.converter import DataConverter, DefaultPayloadConverter
from temporalio.service import ConnectConfig, ServiceClient
from temporalio.worker import (
Plugin as WorkerPlugin,
Replayer,
ReplayerConfig,
Worker,
WorkerConfig,
WorkflowReplayResult,
)
from temporalio.plugin import SimplePlugin
from temporalio.worker import WorkflowRunner
from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner

from ...exceptions import UserError
Expand All @@ -37,102 +27,61 @@
]


class PydanticAIPlugin(ClientPlugin, WorkerPlugin):
def _data_converter(converter: DataConverter | None) -> DataConverter:
    """Return the Pydantic data converter to use for Temporal payload serialization.

    Emits a warning when an existing converter with a custom payload converter class
    is being replaced; the default and Pydantic payload converters are silently swapped.
    """
    recognized = (DefaultPayloadConverter, PydanticPayloadConverter)
    if converter and converter.payload_converter_class not in recognized:
        warnings.warn(  # pragma: no cover
            'A non-default Temporal data converter was used which has been replaced with the Pydantic data converter.'
        )

    return pydantic_data_converter


def _workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner:
    """Return the workflow runner, extending a sandboxed runner's passthrough modules.

    Raises:
        ValueError: if no runner was provided by the Temporal plugin chain.
    """
    if not runner:
        raise ValueError('No WorkflowRunner provided to the Pydantic AI plugin.')

    # Non-sandboxed runners need no module restrictions adjustment.
    if not isinstance(runner, SandboxedWorkflowRunner):
        return runner

    passthrough_modules = (
        'pydantic_ai',
        'pydantic',
        'pydantic_core',
        'logfire',
        'rich',
        'httpx',
        'anyio',
        'httpcore',
        # Imported inside `logfire._internal.json_encoder` when running `logfire.info` inside an activity with attributes to serialize
        'attrs',
        # Imported inside `logfire._internal.json_schema` when running `logfire.info` inside an activity with attributes to serialize
        'numpy',
        'pandas',
    )
    return replace(
        runner,
        restrictions=runner.restrictions.with_passthrough_modules(*passthrough_modules),
    )


class PydanticAIPlugin(SimplePlugin):
"""Temporal client and worker plugin for Pydantic AI."""

def init_client_plugin(self, next: ClientPlugin) -> None:
self.next_client_plugin = next

def init_worker_plugin(self, next: WorkerPlugin) -> None:
self.next_worker_plugin = next

def configure_client(self, config: ClientConfig) -> ClientConfig:
config['data_converter'] = self._get_new_data_converter(config.get('data_converter'))
return self.next_client_plugin.configure_client(config)

def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
runner = config.get('workflow_runner') # pyright: ignore[reportUnknownMemberType]
if isinstance(runner, SandboxedWorkflowRunner): # pragma: no branch
config['workflow_runner'] = replace(
runner,
restrictions=runner.restrictions.with_passthrough_modules(
'pydantic_ai',
'pydantic',
'pydantic_core',
'logfire',
'rich',
'httpx',
'anyio',
'httpcore',
# Imported inside `logfire._internal.json_encoder` when running `logfire.info` inside an activity with attributes to serialize
'attrs',
# Imported inside `logfire._internal.json_schema` when running `logfire.info` inside an activity with attributes to serialize
'numpy',
'pandas',
),
)

config['workflow_failure_exception_types'] = [
*config.get('workflow_failure_exception_types', []), # pyright: ignore[reportUnknownMemberType]
UserError,
PydanticUserError,
]

return self.next_worker_plugin.configure_worker(config)

async def connect_service_client(self, config: ConnectConfig) -> ServiceClient:
return await self.next_client_plugin.connect_service_client(config)

async def run_worker(self, worker: Worker) -> None:
await self.next_worker_plugin.run_worker(worker)

def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: # pragma: no cover
config['data_converter'] = self._get_new_data_converter(config.get('data_converter')) # pyright: ignore[reportUnknownMemberType]
return self.next_worker_plugin.configure_replayer(config)

def run_replayer(
self,
replayer: Replayer,
histories: AsyncIterator[WorkflowHistory],
) -> AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]]: # pragma: no cover
return self.next_worker_plugin.run_replayer(replayer, histories)

def _get_new_data_converter(self, converter: DataConverter | None) -> DataConverter:
if converter and converter.payload_converter_class not in (
DefaultPayloadConverter,
PydanticPayloadConverter,
):
warnings.warn( # pragma: no cover
'A non-default Temporal data converter was used which has been replaced with the Pydantic data converter.'
)

return pydantic_data_converter


class AgentPlugin(WorkerPlugin):
"""Temporal worker plugin for a specific Pydantic AI agent."""

def __init__(self, agent: TemporalAgent[Any, Any]):
self.agent = agent

def init_worker_plugin(self, next: WorkerPlugin) -> None:
self.next_worker_plugin = next
def __init__(self):
super().__init__( # type: ignore[reportUnknownMemberType]
name='PydanticAIPlugin',
data_converter=_data_converter,
workflow_runner=_workflow_runner,
workflow_failure_exception_types=[UserError, PydanticUserError],
)

def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
activities: Sequence[Callable[..., Any]] = config.get('activities', []) # pyright: ignore[reportUnknownMemberType]
# Activities are checked for name conflicts by Temporal.
config['activities'] = [*activities, *self.agent.temporal_activities]
return self.next_worker_plugin.configure_worker(config)

async def run_worker(self, worker: Worker) -> None:
await self.next_worker_plugin.run_worker(worker)

def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: # pragma: no cover
return self.next_worker_plugin.configure_replayer(config)
class AgentPlugin(SimplePlugin):
"""Temporal worker plugin for a specific Pydantic AI agent."""

def run_replayer(
self,
replayer: Replayer,
histories: AsyncIterator[WorkflowHistory],
) -> AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]]: # pragma: no cover
return self.next_worker_plugin.run_replayer(replayer, histories)
def __init__(self, agent: TemporalAgent[Any, Any]):
super().__init__( # type: ignore[reportUnknownMemberType]
name='AgentPlugin',
activities=agent.temporal_activities,
)
28 changes: 13 additions & 15 deletions pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_logfire.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import annotations

from collections.abc import Callable
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING

from temporalio.client import ClientConfig, Plugin as ClientPlugin
from temporalio.plugin import SimplePlugin
from temporalio.runtime import OpenTelemetryConfig, Runtime, TelemetryConfig
from temporalio.service import ConnectConfig, ServiceClient

Expand All @@ -19,12 +19,14 @@ def _default_setup_logfire() -> Logfire:
return instance


class LogfirePlugin(ClientPlugin):
class LogfirePlugin(SimplePlugin):
"""Temporal client plugin for Logfire."""

def __init__(self, setup_logfire: Callable[[], Logfire] = _default_setup_logfire, *, metrics: bool = True):
try:
import logfire # noqa: F401 # pyright: ignore[reportUnusedImport]
from opentelemetry.trace import get_tracer
from temporalio.contrib.opentelemetry import TracingInterceptor
except ImportError as _import_error:
raise ImportError(
'Please install the `logfire` package to use the Logfire plugin, '
Expand All @@ -34,18 +36,14 @@ def __init__(self, setup_logfire: Callable[[], Logfire] = _default_setup_logfire
self.setup_logfire = setup_logfire
self.metrics = metrics

def init_client_plugin(self, next: ClientPlugin) -> None:
self.next_client_plugin = next
super().__init__( # type: ignore[reportUnknownMemberType]
name='LogfirePlugin',
client_interceptors=[TracingInterceptor(get_tracer('temporalio'))],
)

def configure_client(self, config: ClientConfig) -> ClientConfig:
from opentelemetry.trace import get_tracer
from temporalio.contrib.opentelemetry import TracingInterceptor

interceptors = config.get('interceptors', [])
config['interceptors'] = [*interceptors, TracingInterceptor(get_tracer('temporalio'))]
return self.next_client_plugin.configure_client(config)

async def connect_service_client(self, config: ConnectConfig) -> ServiceClient:
async def connect_service_client(
self, config: ConnectConfig, next: Callable[[ConnectConfig], Awaitable[ServiceClient]]
) -> ServiceClient:
logfire = self.setup_logfire()

if self.metrics:
Expand All @@ -60,4 +58,4 @@ async def connect_service_client(self, config: ConnectConfig) -> ServiceClient:
telemetry=TelemetryConfig(metrics=OpenTelemetryConfig(url=metrics_url, headers=headers))
)

return await self.next_client_plugin.connect_service_client(config)
return await next(config)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ pydantic-ai-slim = { workspace = true }
pydantic-evals = { workspace = true }
pydantic-graph = { workspace = true }
pydantic-ai-examples = { workspace = true }
temporalio = { git = "https://github.com/temporalio/sdk-python.git", rev = "main" }
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove and rebuild uv.lock once temporalio/sdk-python#1139 is merged


[tool.uv.workspace]
members = [
Expand Down
14 changes: 3 additions & 11 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading