Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ license = "MIT"
license-files = ["LICENSE"]
keywords = ["temporal", "workflow"]
dependencies = [
"nexus-rpc==1.1.0",
"nexus-rpc @ git+https://github.com/nexus-rpc/sdk-python@task-cancellation",
"protobuf>=3.20,<7.0.0",
"python-dateutil>=2.8.2,<3 ; python_version < '3.11'",
"types-protobuf>=3.20",
Expand All @@ -28,10 +28,7 @@ classifiers = [
grpc = ["grpcio>=1.48.2,<2"]
opentelemetry = ["opentelemetry-api>=1.11.1,<2", "opentelemetry-sdk>=1.11.1,<2"]
pydantic = ["pydantic>=2.0.0,<3"]
openai-agents = [
"openai-agents>=0.3,<0.5",
"mcp>=1.9.4, <2",
]
openai-agents = ["openai-agents>=0.3,<0.5", "mcp>=1.9.4, <2"]

[project.urls]
Homepage = "https://github.com/temporalio/sdk-python"
Expand Down
2 changes: 1 addition & 1 deletion temporalio/nexus/_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ async def _start(
return WorkflowRunOperationHandler(_start)

method_name = get_callable_name(start)
nexusrpc.set_operation_definition(
nexusrpc.set_operation(
operation_handler_factory,
nexusrpc.Operation(
name=name or method_name,
Expand Down
19 changes: 0 additions & 19 deletions temporalio/nexus/_operation_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,10 @@
HandlerError,
HandlerErrorType,
InputT,
OperationInfo,
OutputT,
)
from nexusrpc.handler import (
CancelOperationContext,
FetchOperationInfoContext,
FetchOperationResultContext,
OperationHandler,
StartOperationContext,
StartOperationResultAsync,
Expand Down Expand Up @@ -81,22 +78,6 @@ async def cancel(self, ctx: CancelOperationContext, token: str) -> None:
"""Cancel the operation, by cancelling the workflow."""
await _cancel_workflow(token)

async def fetch_info(
    self, ctx: FetchOperationInfoContext, token: str
) -> OperationInfo:
    """Reject the request: fetching operation info is not supported.

    Args:
        ctx: Context of the fetch-info request (unused).
        token: Operation token (unused).

    Raises:
        NotImplementedError: Always.
    """
    unsupported = (
        "Temporal Nexus operation handlers do not support fetching operation info."
    )
    raise NotImplementedError(unsupported)

async def fetch_result(
    self, ctx: FetchOperationResultContext, token: str
) -> OutputT:
    """Reject the request: fetching the operation result is not supported.

    Args:
        ctx: Context of the fetch-result request (unused).
        token: Operation token (unused).

    Raises:
        NotImplementedError: Always.
    """
    unsupported = "Temporal Nexus operation handlers do not support fetching the operation result."
    raise NotImplementedError(unsupported)


async def _cancel_workflow(
token: str,
Expand Down
4 changes: 2 additions & 2 deletions temporalio/nexus/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,12 @@ def get_operation_factory(

``obj`` should be a decorated operation start method.
"""
op_defn = nexusrpc.get_operation_definition(obj)
op_defn = nexusrpc.get_operation(obj)
if op_defn:
factory = obj
else:
if factory := getattr(obj, "__nexus_operation_factory__", None):
op_defn = nexusrpc.get_operation_definition(factory)
op_defn = nexusrpc.get_operation(factory)
if not isinstance(op_defn, nexusrpc.Operation):
return None, None
return factory, op_defn
Expand Down
99 changes: 91 additions & 8 deletions temporalio/worker/_nexus.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import asyncio
import concurrent.futures
import json
import threading
from dataclasses import dataclass
from typing import (
Any,
Expand Down Expand Up @@ -32,7 +33,10 @@
import temporalio.common
import temporalio.converter
import temporalio.nexus
from temporalio.exceptions import ApplicationError, WorkflowAlreadyStartedError
from temporalio.exceptions import (
ApplicationError,
WorkflowAlreadyStartedError,
)
from temporalio.nexus import Info, logger
from temporalio.service import RPCError, RPCStatusCode

Expand All @@ -41,6 +45,16 @@
_TEMPORAL_FAILURE_PROTO_TYPE = "temporal.api.failure.v1.Failure"


@dataclass
class _RunningNexusTask:
task: asyncio.Task[Any]
cancellation: _NexusTaskCancellation

def cancel(self, reason: Optional[str] = None):
self.cancellation.cancel(reason)
self.task.cancel()


class _NexusWorker:
def __init__(
self,
Expand All @@ -65,7 +79,7 @@ def __init__(
self._interceptors = interceptors
# TODO(nexus-preview): metric_meter
self._metric_meter = metric_meter
self._running_tasks: dict[bytes, asyncio.Task[Any]] = {}
self._running_tasks: dict[bytes, _RunningNexusTask] = {}
self._fail_worker_exception_queue: asyncio.Queue[Exception] = asyncio.Queue()

async def run(self) -> None:
Expand All @@ -90,21 +104,31 @@ async def raise_from_exception_queue() -> NoReturn:
if nexus_task.HasField("task"):
task = nexus_task.task
if task.request.HasField("start_operation"):
self._running_tasks[task.task_token] = asyncio.create_task(
task_cancellation = _NexusTaskCancellation()
start_op_task = asyncio.create_task(
self._handle_start_operation_task(
task.task_token,
task.request.start_operation,
dict(task.request.header),
task_cancellation,
)
)
self._running_tasks[task.task_token] = _RunningNexusTask(
start_op_task, task_cancellation
)
elif task.request.HasField("cancel_operation"):
self._running_tasks[task.task_token] = asyncio.create_task(
task_cancellation = _NexusTaskCancellation()
cancel_op_task = asyncio.create_task(
self._handle_cancel_operation_task(
task.task_token,
task.request.cancel_operation,
dict(task.request.header),
task_cancellation,
)
)
self._running_tasks[task.task_token] = _RunningNexusTask(
cancel_op_task, task_cancellation
)
else:
raise NotImplementedError(
f"Invalid Nexus task request: {task.request}"
Expand All @@ -113,8 +137,12 @@ async def raise_from_exception_queue() -> NoReturn:
if running_task := self._running_tasks.get(
nexus_task.cancel_task.task_token
):
# TODO(nexus-prerelease): when do we remove the entry from _running_tasks?
running_task.cancel()
reason = (
temporalio.bridge.proto.nexus.NexusTaskCancelReason.Name(
nexus_task.cancel_task.reason
)
)
running_task.cancel(reason)
else:
logger.debug(
f"Received cancel_task but no running task exists for "
Expand Down Expand Up @@ -147,7 +175,10 @@ async def drain_poll_queue(self) -> None:
# Only call this after run()/drain_poll_queue() have returned. This will not
# raise an exception.
async def wait_all_completed(self) -> None:
    """Wait until every tracked Nexus handler task has finished.

    ``self._running_tasks`` maps task tokens to wrapper objects, so the
    underlying ``asyncio`` tasks are extracted first; passing the wrappers
    themselves to ``asyncio.gather`` would fail because they are not
    awaitable. Exceptions raised by individual tasks are collected by
    ``gather`` rather than propagated, so this method does not raise on
    handler failure.
    """
    pending = [entry.task for entry in self._running_tasks.values()]
    await asyncio.gather(*pending, return_exceptions=True)

# TODO(nexus-preview): stack trace pruning. See sdk-typescript NexusHandler.execute
# "Any call up to this function and including this one will be trimmed out of stack traces."
Expand All @@ -157,6 +188,7 @@ async def _handle_cancel_operation_task(
task_token: bytes,
request: temporalio.api.nexus.v1.CancelOperationRequest,
headers: Mapping[str, str],
task_cancellation: nexusrpc.handler.OperationTaskCancellation,
) -> None:
"""Handle a cancel operation task.

Expand All @@ -168,6 +200,7 @@ async def _handle_cancel_operation_task(
service=request.service,
operation=request.operation,
headers=headers,
task_cancellation=task_cancellation,
)
temporalio.nexus._operation_context._TemporalCancelOperationContext(
info=lambda: Info(task_queue=self._task_queue),
Expand All @@ -177,6 +210,11 @@ async def _handle_cancel_operation_task(
try:
try:
await self._handler.cancel_operation(ctx, request.operation_token)
except asyncio.CancelledError:
completion = temporalio.bridge.proto.nexus.NexusTaskCompletion(
task_token=task_token,
ack_cancel=task_cancellation.is_cancelled(),
)
except BaseException as err:
logger.warning("Failed to execute Nexus cancel operation method")
completion = temporalio.bridge.proto.nexus.NexusTaskCompletion(
Expand Down Expand Up @@ -209,6 +247,7 @@ async def _handle_start_operation_task(
task_token: bytes,
start_request: temporalio.api.nexus.v1.StartOperationRequest,
headers: Mapping[str, str],
task_cancellation: nexusrpc.handler.OperationTaskCancellation,
) -> None:
"""Handle a start operation task.

Expand All @@ -217,7 +256,14 @@ async def _handle_start_operation_task(
"""
try:
try:
start_response = await self._start_operation(start_request, headers)
start_response = await self._start_operation(
start_request, headers, task_cancellation
)
except asyncio.CancelledError:
completion = temporalio.bridge.proto.nexus.NexusTaskCompletion(
task_token=task_token,
ack_cancel=task_cancellation.is_cancelled(),
)
except BaseException as err:
logger.warning("Failed to execute Nexus start operation method")
completion = temporalio.bridge.proto.nexus.NexusTaskCompletion(
Expand All @@ -226,6 +272,7 @@ async def _handle_start_operation_task(
_exception_to_handler_error(err)
),
)

if isinstance(err, concurrent.futures.BrokenExecutor):
self._fail_worker_exception_queue.put_nowait(err)
else:
Expand All @@ -235,6 +282,7 @@ async def _handle_start_operation_task(
start_operation=start_response
),
)

await self._bridge_worker().complete_nexus_task(completion)
except Exception:
logger.exception("Failed to send Nexus task completion")
Expand All @@ -250,6 +298,7 @@ async def _start_operation(
self,
start_request: temporalio.api.nexus.v1.StartOperationRequest,
headers: Mapping[str, str],
cancellation: nexusrpc.handler.OperationTaskCancellation,
) -> temporalio.api.nexus.v1.StartOperationResponse:
"""Invoke the Nexus handler's start_operation method and construct the StartOperationResponse.

Expand All @@ -268,6 +317,7 @@ async def _start_operation(
for link in start_request.links
],
callback_headers=dict(start_request.callback_header),
task_cancellation=cancellation,
)
temporalio.nexus._operation_context._TemporalStartOperationContext(
nexus_context=ctx,
Expand Down Expand Up @@ -517,3 +567,36 @@ def _exception_to_handler_error(err: BaseException) -> nexusrpc.HandlerError:
)
handler_err.__cause__ = err
return handler_err


class _NexusTaskCancellation(nexusrpc.handler.OperationTaskCancellation):
    """Cancellation signal for one Nexus task, usable from sync and async code.

    The signal is one-shot: the first ``cancel`` call records the reason and
    sets both events; later calls are no-ops.
    """

    def __init__(self) -> None:
        # Set once cancelled; backs ``is_cancelled`` and the blocking wait.
        self._thread_evt = threading.Event()
        # Set once cancelled; backs the awaitable wait.
        self._async_evt = asyncio.Event()
        # Guards ``_reason`` and makes the cancelled transition atomic.
        self._lock = threading.Lock()
        self._reason: Optional[str] = None

    def is_cancelled(self) -> bool:
        """Return True once ``cancel`` has been called."""
        return self._thread_evt.is_set()

    def cancellation_reason(self) -> Optional[str]:
        """Return the reason given to the first ``cancel`` call, or None."""
        with self._lock:
            return self._reason

    def wait_until_cancelled_sync(self, timeout: Optional[float] = None) -> bool:
        """Block until cancelled or until ``timeout`` seconds elapse.

        Returns:
            True if the task was cancelled, False if the wait timed out.
        """
        return self._thread_evt.wait(timeout)

    async def wait_until_cancelled(self) -> None:
        """Suspend the calling coroutine until the task is cancelled."""
        await self._async_evt.wait()

    def cancel(self, reason: Optional[str] = None) -> bool:
        """Mark the task as cancelled, recording ``reason``.

        Args:
            reason: Description of why the task was cancelled, if known.

        Returns:
            True if this call performed the cancellation, False if the task
            had already been cancelled (the original reason is kept).
        """
        # A single in-lock check is sufficient; an extra check outside the
        # lock adds nothing.
        with self._lock:
            if self._thread_evt.is_set():
                return False
            self._reason = reason
            self._thread_evt.set()
            # NOTE(review): asyncio.Event is not thread-safe; this assumes
            # cancel() is invoked on the event loop thread - confirm.
            self._async_evt.set()
            return True
33 changes: 33 additions & 0 deletions temporalio/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5417,6 +5417,22 @@ async def start_operation(
headers: Optional[Mapping[str, str]] = None,
) -> NexusOperationHandle[OutputT]: ...

# Overload for operation_handler: ``operation`` is a callable that takes the
# service handler and returns a ``nexusrpc.handler.OperationHandler``. The
# returned handle is typed by that handler's output type.
@overload
@abstractmethod
async def start_operation(
self,
operation: Callable[
[ServiceHandlerT], nexusrpc.handler.OperationHandler[InputT, OutputT]
],
input: InputT,
*,
output_type: Optional[Type[OutputT]] = None,
schedule_to_close_timeout: Optional[timedelta] = None,
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
headers: Optional[Mapping[str, str]] = None,
) -> NexusOperationHandle[OutputT]: ...

@abstractmethod
async def start_operation(
self,
Expand Down Expand Up @@ -5527,6 +5543,23 @@ async def execute_operation(
headers: Optional[Mapping[str, str]] = None,
) -> OutputT: ...

# Overload for operation_handler: ``operation`` is a callable that returns a
# ``nexusrpc.handler.OperationHandler``. The return value is typed by that
# handler's output type.
# NOTE(review): this overload parameterizes the callable with ``ServiceT``,
# while the matching ``start_operation`` overload uses ``ServiceHandlerT`` -
# confirm which type variable is intended.
@overload
@abstractmethod
async def execute_operation(
self,
operation: Callable[
[ServiceT],
nexusrpc.handler.OperationHandler[InputT, OutputT],
],
input: InputT,
*,
output_type: Optional[Type[OutputT]] = None,
schedule_to_close_timeout: Optional[timedelta] = None,
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
headers: Optional[Mapping[str, str]] = None,
) -> OutputT: ...

@abstractmethod
async def execute_operation(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from agents import Runner, custom_span, gen_trace_id, trace

import temporalio.workflow
from tests.contrib.openai_agents.research_agents.planner_agent import (
WebSearchItem,
WebSearchPlan,
Expand Down Expand Up @@ -45,7 +46,7 @@ async def _perform_searches(self, search_plan: WebSearchPlan) -> list[str]:
asyncio.create_task(self._search(item)) for item in search_plan.searches
]
results = []
for task in asyncio.as_completed(tasks):
for task in temporalio.workflow.as_completed(tasks):
result = await task
if result is not None:
results.append(result)
Expand Down
Loading