-
Notifications
You must be signed in to change notification settings - Fork 140
Nexus task cancellation #1204
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Nexus task cancellation #1204
Changes from all commits
2972d9b
134b473
883ba60
fa971e0
5a5d031
4956b3e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,6 +5,7 @@ | |
| import asyncio | ||
| import concurrent.futures | ||
| import json | ||
| import threading | ||
| from dataclasses import dataclass | ||
| from typing import ( | ||
| Any, | ||
|
|
@@ -32,7 +33,10 @@ | |
| import temporalio.common | ||
| import temporalio.converter | ||
| import temporalio.nexus | ||
| from temporalio.exceptions import ApplicationError, WorkflowAlreadyStartedError | ||
| from temporalio.exceptions import ( | ||
| ApplicationError, | ||
| WorkflowAlreadyStartedError, | ||
| ) | ||
| from temporalio.nexus import Info, logger | ||
| from temporalio.service import RPCError, RPCStatusCode | ||
|
|
||
|
|
@@ -41,6 +45,16 @@ | |
| _TEMPORAL_FAILURE_PROTO_TYPE = "temporal.api.failure.v1.Failure" | ||
|
|
||
|
|
||
| @dataclass | ||
| class _RunningNexusTask: | ||
| task: asyncio.Task[Any] | ||
| cancellation: _NexusTaskCancellation | ||
|
|
||
| def cancel(self, reason: Optional[str] = None): | ||
| self.cancellation.cancel(reason) | ||
| self.task.cancel() | ||
|
|
||
|
|
||
| class _NexusWorker: | ||
| def __init__( | ||
| self, | ||
|
|
@@ -65,7 +79,7 @@ def __init__( | |
| self._interceptors = interceptors | ||
| # TODO(nexus-preview): metric_meter | ||
| self._metric_meter = metric_meter | ||
| self._running_tasks: dict[bytes, asyncio.Task[Any]] = {} | ||
| self._running_tasks: dict[bytes, _RunningNexusTask] = {} | ||
| self._fail_worker_exception_queue: asyncio.Queue[Exception] = asyncio.Queue() | ||
|
|
||
| async def run(self) -> None: | ||
|
|
@@ -90,21 +104,31 @@ async def raise_from_exception_queue() -> NoReturn: | |
| if nexus_task.HasField("task"): | ||
| task = nexus_task.task | ||
| if task.request.HasField("start_operation"): | ||
| self._running_tasks[task.task_token] = asyncio.create_task( | ||
| task_cancellation = _NexusTaskCancellation() | ||
| start_op_task = asyncio.create_task( | ||
| self._handle_start_operation_task( | ||
| task.task_token, | ||
| task.request.start_operation, | ||
| dict(task.request.header), | ||
| task_cancellation, | ||
| ) | ||
| ) | ||
| self._running_tasks[task.task_token] = _RunningNexusTask( | ||
| start_op_task, task_cancellation | ||
| ) | ||
| elif task.request.HasField("cancel_operation"): | ||
| self._running_tasks[task.task_token] = asyncio.create_task( | ||
| task_cancellation = _NexusTaskCancellation() | ||
| cancel_op_task = asyncio.create_task( | ||
| self._handle_cancel_operation_task( | ||
| task.task_token, | ||
| task.request.cancel_operation, | ||
| dict(task.request.header), | ||
| task_cancellation, | ||
| ) | ||
| ) | ||
| self._running_tasks[task.task_token] = _RunningNexusTask( | ||
| cancel_op_task, task_cancellation | ||
| ) | ||
| else: | ||
| raise NotImplementedError( | ||
| f"Invalid Nexus task request: {task.request}" | ||
|
|
@@ -113,8 +137,12 @@ async def raise_from_exception_queue() -> NoReturn: | |
| if running_task := self._running_tasks.get( | ||
| nexus_task.cancel_task.task_token | ||
| ): | ||
| # TODO(nexus-prerelease): when do we remove the entry from _running_operations? | ||
| running_task.cancel() | ||
| reason = ( | ||
| temporalio.bridge.proto.nexus.NexusTaskCancelReason.Name( | ||
| nexus_task.cancel_task.reason | ||
| ) | ||
| ) | ||
| running_task.cancel(reason) | ||
| else: | ||
| logger.debug( | ||
| f"Received cancel_task but no running task exists for " | ||
|
|
@@ -147,7 +175,10 @@ async def drain_poll_queue(self) -> None: | |
| # Only call this after run()/drain_poll_queue() have returned. This will not | ||
| # raise an exception. | ||
| async def wait_all_completed(self) -> None: | ||
| await asyncio.gather(*self._running_tasks.values(), return_exceptions=True) | ||
| running_tasks = [ | ||
| running_task.task for running_task in self._running_tasks.values() | ||
| ] | ||
| await asyncio.gather(*running_tasks, return_exceptions=True) | ||
|
|
||
| # TODO(nexus-preview): stack trace pruning. See sdk-typescript NexusHandler.execute | ||
| # "Any call up to this function and including this one will be trimmed out of stack traces."" | ||
|
|
@@ -157,6 +188,7 @@ async def _handle_cancel_operation_task( | |
| task_token: bytes, | ||
| request: temporalio.api.nexus.v1.CancelOperationRequest, | ||
| headers: Mapping[str, str], | ||
| task_cancellation: nexusrpc.handler.OperationTaskCancellation, | ||
| ) -> None: | ||
| """Handle a cancel operation task. | ||
|
|
||
|
|
@@ -168,6 +200,7 @@ async def _handle_cancel_operation_task( | |
| service=request.service, | ||
| operation=request.operation, | ||
| headers=headers, | ||
| task_cancellation=task_cancellation, | ||
| ) | ||
| temporalio.nexus._operation_context._TemporalCancelOperationContext( | ||
| info=lambda: Info(task_queue=self._task_queue), | ||
|
|
@@ -177,6 +210,11 @@ async def _handle_cancel_operation_task( | |
| try: | ||
| try: | ||
| await self._handler.cancel_operation(ctx, request.operation_token) | ||
| except asyncio.CancelledError: | ||
| completion = temporalio.bridge.proto.nexus.NexusTaskCompletion( | ||
| task_token=task_token, | ||
| ack_cancel=task_cancellation.is_cancelled(), | ||
| ) | ||
| except BaseException as err: | ||
| logger.warning("Failed to execute Nexus cancel operation method") | ||
| completion = temporalio.bridge.proto.nexus.NexusTaskCompletion( | ||
|
|
@@ -209,6 +247,7 @@ async def _handle_start_operation_task( | |
| task_token: bytes, | ||
| start_request: temporalio.api.nexus.v1.StartOperationRequest, | ||
| headers: Mapping[str, str], | ||
| task_cancellation: nexusrpc.handler.OperationTaskCancellation, | ||
| ) -> None: | ||
| """Handle a start operation task. | ||
|
|
||
|
|
@@ -217,7 +256,14 @@ async def _handle_start_operation_task( | |
| """ | ||
| try: | ||
| try: | ||
| start_response = await self._start_operation(start_request, headers) | ||
| start_response = await self._start_operation( | ||
| start_request, headers, task_cancellation | ||
| ) | ||
| except asyncio.CancelledError: | ||
| completion = temporalio.bridge.proto.nexus.NexusTaskCompletion( | ||
| task_token=task_token, | ||
| ack_cancel=task_cancellation.is_cancelled(), | ||
| ) | ||
| except BaseException as err: | ||
| logger.warning("Failed to execute Nexus start operation method") | ||
| completion = temporalio.bridge.proto.nexus.NexusTaskCompletion( | ||
|
|
@@ -226,6 +272,7 @@ async def _handle_start_operation_task( | |
| _exception_to_handler_error(err) | ||
| ), | ||
| ) | ||
|
|
||
| if isinstance(err, concurrent.futures.BrokenExecutor): | ||
| self._fail_worker_exception_queue.put_nowait(err) | ||
| else: | ||
|
|
@@ -235,6 +282,7 @@ async def _handle_start_operation_task( | |
| start_operation=start_response | ||
| ), | ||
| ) | ||
|
|
||
| await self._bridge_worker().complete_nexus_task(completion) | ||
| except Exception: | ||
| logger.exception("Failed to send Nexus task completion") | ||
|
|
@@ -250,6 +298,7 @@ async def _start_operation( | |
| self, | ||
| start_request: temporalio.api.nexus.v1.StartOperationRequest, | ||
| headers: Mapping[str, str], | ||
| cancellation: nexusrpc.handler.OperationTaskCancellation, | ||
| ) -> temporalio.api.nexus.v1.StartOperationResponse: | ||
| """Invoke the Nexus handler's start_operation method and construct the StartOperationResponse. | ||
|
|
||
|
|
@@ -268,6 +317,7 @@ async def _start_operation( | |
| for link in start_request.links | ||
| ], | ||
| callback_headers=dict(start_request.callback_header), | ||
| task_cancellation=cancellation, | ||
| ) | ||
| temporalio.nexus._operation_context._TemporalStartOperationContext( | ||
| nexus_context=ctx, | ||
|
|
@@ -517,3 +567,36 @@ def _exception_to_handler_error(err: BaseException) -> nexusrpc.HandlerError: | |
| ) | ||
| handler_err.__cause__ = err | ||
| return handler_err | ||
|
|
||
|
|
||
| class _NexusTaskCancellation(nexusrpc.handler.OperationTaskCancellation): | ||
| def __init__(self): | ||
| self._thread_evt = threading.Event() | ||
| self._aysnc_evt = asyncio.Event() | ||
| self._lock = threading.Lock() | ||
| self._reason: Optional[str] = None | ||
|
|
||
| def is_cancelled(self) -> bool: | ||
| return self._thread_evt.is_set() | ||
|
|
||
| def cancellation_reason(self) -> Optional[str]: | ||
| with self._lock: | ||
| return self._reason | ||
|
|
||
| def wait_until_cancelled_sync(self, timeout: float | None = None) -> bool: | ||
| return self._thread_evt.wait(timeout) | ||
|
|
||
| async def wait_until_cancelled(self) -> None: | ||
| await self._aysnc_evt.wait() | ||
|
|
||
| def cancel(self, reason: Optional[str] = None) -> bool: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Probably no need to return a never-read value from an effectively module-private method, but meh
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there a possibility of reason being none? Even if there is (some newer proto value on the server that we don't have an entry for yet), is there any value in defaulting this parameter (or the one on the |
||
| if self._thread_evt.is_set(): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No real value in Python in double-checked in and out of lock, can just rely on the in-lock check (arguably lock not as valuable in Python as it is in other langs, could just change top of |
||
| return False | ||
|
|
||
| with self._lock: | ||
| if self._thread_evt.is_set(): | ||
| return False | ||
| self._reason = reason | ||
| self._thread_evt.set() | ||
| self._aysnc_evt.set() | ||
| return True | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.