From 2972d9b79b2d4f7d2b549e1dde677ba8e55a1106 Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Thu, 30 Oct 2025 18:02:16 -0700 Subject: [PATCH 1/6] WIP: update to latest nexus patterns and try out task cancellation --- pyproject.toml | 7 +- temporalio/nexus/_decorators.py | 2 +- temporalio/nexus/_operation_handlers.py | 19 ---- temporalio/nexus/_util.py | 4 +- temporalio/worker/_nexus.py | 98 +++++++++++++++++-- ...ynamic_creation_of_user_handler_classes.py | 18 +--- tests/nexus/test_handler.py | 25 +---- tests/nexus/test_handler_async_operation.py | 46 --------- tests/nexus/test_workflow_caller.py | 31 +++--- tests/nexus/test_workflow_caller_errors.py | 69 ++++++++++--- uv.lock | 8 +- 11 files changed, 171 insertions(+), 156 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bfd8bee67..2599a0a65 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ license = "MIT" license-files = ["LICENSE"] keywords = ["temporal", "workflow"] dependencies = [ - "nexus-rpc==1.1.0", + "nexus-rpc @ git+https://github.com/nexus-rpc/sdk-python@task-cancellation", "protobuf>=3.20,<7.0.0", "python-dateutil>=2.8.2,<3 ; python_version < '3.11'", "types-protobuf>=3.20", @@ -28,10 +28,7 @@ classifiers = [ grpc = ["grpcio>=1.48.2,<2"] opentelemetry = ["opentelemetry-api>=1.11.1,<2", "opentelemetry-sdk>=1.11.1,<2"] pydantic = ["pydantic>=2.0.0,<3"] -openai-agents = [ - "openai-agents>=0.3,<0.5", - "mcp>=1.9.4, <2", -] +openai-agents = ["openai-agents>=0.3,<0.5", "mcp>=1.9.4, <2"] [project.urls] Homepage = "https://github.com/temporalio/sdk-python" diff --git a/temporalio/nexus/_decorators.py b/temporalio/nexus/_decorators.py index 28c625816..3c97482ff 100644 --- a/temporalio/nexus/_decorators.py +++ b/temporalio/nexus/_decorators.py @@ -117,7 +117,7 @@ async def _start( return WorkflowRunOperationHandler(_start) method_name = get_callable_name(start) - nexusrpc.set_operation_definition( + nexusrpc.set_operation( operation_handler_factory, nexusrpc.Operation( name=name or method_name, diff --git a/temporalio/nexus/_operation_handlers.py b/temporalio/nexus/_operation_handlers.py index a73c3eb50..aa5351353 100644 --- a/temporalio/nexus/_operation_handlers.py +++ b/temporalio/nexus/_operation_handlers.py @@ -10,13 +10,10 @@ HandlerError, HandlerErrorType, InputT, - OperationInfo, OutputT, ) from nexusrpc.handler import ( CancelOperationContext, - FetchOperationInfoContext, - FetchOperationResultContext, OperationHandler, StartOperationContext, StartOperationResultAsync, @@ -81,22 +78,6 @@ async def cancel(self, ctx: CancelOperationContext, token: str) -> None: """Cancel the operation, by cancelling the workflow.""" await _cancel_workflow(token) - async def fetch_info( - self, ctx: FetchOperationInfoContext, token: str - ) -> OperationInfo: - """Fetch operation info (not supported for Temporal Nexus operations).""" - raise NotImplementedError( - "Temporal Nexus operation handlers do not support fetching operation info." - ) - - async def fetch_result( - self, ctx: FetchOperationResultContext, token: str - ) -> OutputT: - """Fetch operation result (not supported for Temporal Nexus operations).""" - raise NotImplementedError( - "Temporal Nexus operation handlers do not support fetching the operation result." - ) - async def _cancel_workflow( token: str, diff --git a/temporalio/nexus/_util.py b/temporalio/nexus/_util.py index ef005d0c4..de5bf94b2 100644 --- a/temporalio/nexus/_util.py +++ b/temporalio/nexus/_util.py @@ -129,12 +129,12 @@ def get_operation_factory( ``obj`` should be a decorated operation start method. """ - op_defn = nexusrpc.get_operation_definition(obj) + op_defn = nexusrpc.get_operation(obj) if op_defn: factory = obj else: if factory := getattr(obj, "__nexus_operation_factory__", None): - op_defn = nexusrpc.get_operation_definition(factory) + op_defn = nexusrpc.get_operation(factory) if not isinstance(op_defn, nexusrpc.Operation): return None, None return factory, op_defn diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py index 72b47187f..20537bda5 100644 --- a/temporalio/worker/_nexus.py +++ b/temporalio/worker/_nexus.py @@ -6,6 +6,7 @@ import concurrent.futures import json from dataclasses import dataclass +import threading from typing import ( Any, Callable, @@ -32,7 +33,11 @@ import temporalio.common import temporalio.converter import temporalio.nexus -from temporalio.exceptions import ApplicationError, WorkflowAlreadyStartedError +from temporalio.exceptions import ( + ApplicationError, + WorkflowAlreadyStartedError, + CancelledError, +) from temporalio.nexus import Info, logger from temporalio.service import RPCError, RPCStatusCode @@ -41,6 +46,16 @@ _TEMPORAL_FAILURE_PROTO_TYPE = "temporal.api.failure.v1.Failure" +@dataclass +class _RunningNexusTask: + task: asyncio.Task[Any] + cancellation: _NexusTaskCancellation + + def cancel(self, reason: Optional[str] = None): + self.cancellation.cancel(reason) + self.task.cancel() + + class _NexusWorker: def __init__( self, @@ -65,7 +80,7 @@ def __init__( self._interceptors = interceptors # TODO(nexus-preview): metric_meter self._metric_meter = metric_meter - self._running_tasks: dict[bytes, asyncio.Task[Any]] = {} + self._running_tasks: dict[bytes, _RunningNexusTask] = {} self._fail_worker_exception_queue: asyncio.Queue[Exception] = asyncio.Queue() async def run(self) -> None: @@ -90,21 +105,31 @@ async def raise_from_exception_queue() -> NoReturn: if nexus_task.HasField("task"): task = nexus_task.task if task.request.HasField("start_operation"): - self._running_tasks[task.task_token] = asyncio.create_task( + cancellation = _NexusTaskCancellation() + start_op_task = asyncio.create_task( self._handle_start_operation_task( task.task_token, task.request.start_operation, dict(task.request.header), + cancellation, ) ) + self._running_tasks[task.task_token] = _RunningNexusTask( + start_op_task, cancellation + ) elif task.request.HasField("cancel_operation"): - self._running_tasks[task.task_token] = asyncio.create_task( + cancellation = _NexusTaskCancellation() + cancel_op_task = asyncio.create_task( self._handle_cancel_operation_task( task.task_token, task.request.cancel_operation, dict(task.request.header), + cancellation, ) ) + self._running_tasks[task.task_token] = _RunningNexusTask( + cancel_op_task, cancellation + ) else: raise NotImplementedError( f"Invalid Nexus task request: {task.request}" @@ -114,7 +139,8 @@ async def raise_from_exception_queue() -> NoReturn: nexus_task.cancel_task.task_token ): # TODO(nexus-prerelease): when do we remove the entry from _running_operations? - running_task.cancel() + # TODO(amazzeo): put real reason here? + running_task.cancel("timeout") else: logger.debug( f"Received cancel_task but no running task exists for " @@ -147,7 +173,10 @@ async def drain_poll_queue(self) -> None: # Only call this after run()/drain_poll_queue() have returned. This will not # raise an exception. async def wait_all_completed(self) -> None: - await asyncio.gather(*self._running_tasks.values(), return_exceptions=True) + running_tasks = [ + running_task.task for running_task in self._running_tasks.values() + ] + await asyncio.gather(*running_tasks, return_exceptions=True) # TODO(nexus-preview): stack trace pruning. See sdk-typescript NexusHandler.execute # "Any call up to this function and including this one will be trimmed out of stack traces."" @@ -157,6 +186,7 @@ async def _handle_cancel_operation_task( task_token: bytes, request: temporalio.api.nexus.v1.CancelOperationRequest, headers: Mapping[str, str], + task_cancellation: nexusrpc.handler.OperationTaskCancellation, ) -> None: """Handle a cancel operation task. @@ -168,6 +198,7 @@ async def _handle_cancel_operation_task( service=request.service, operation=request.operation, headers=headers, + task_cancellation=task_cancellation, ) temporalio.nexus._operation_context._TemporalCancelOperationContext( info=lambda: Info(task_queue=self._task_queue), @@ -184,6 +215,7 @@ async def _handle_cancel_operation_task( error=await self._handler_error_to_proto( _exception_to_handler_error(err) ), + ack_cancel=task_cancellation.is_cancelled(), ) else: completion = temporalio.bridge.proto.nexus.NexusTaskCompletion( @@ -209,6 +241,7 @@ async def _handle_start_operation_task( task_token: bytes, start_request: temporalio.api.nexus.v1.StartOperationRequest, headers: Mapping[str, str], + task_cancellation: nexusrpc.handler.OperationTaskCancellation, ) -> None: """Handle a start operation task. @@ -217,7 +250,14 @@ async def _handle_start_operation_task( """ try: try: - start_response = await self._start_operation(start_request, headers) + start_response = await self._start_operation( + start_request, headers, task_cancellation + ) + except asyncio.CancelledError: + completion = temporalio.bridge.proto.nexus.NexusTaskCompletion( + task_token=task_token, + ack_cancel=task_cancellation.is_cancelled(), + ) except BaseException as err: logger.warning("Failed to execute Nexus start operation method") completion = temporalio.bridge.proto.nexus.NexusTaskCompletion( @@ -226,6 +266,7 @@ async def _handle_start_operation_task( _exception_to_handler_error(err) ), ) + if isinstance(err, concurrent.futures.BrokenExecutor): self._fail_worker_exception_queue.put_nowait(err) else: @@ -235,6 +276,7 @@ async def _handle_start_operation_task( start_operation=start_response ), ) + await self._bridge_worker().complete_nexus_task(completion) except Exception: logger.exception("Failed to send Nexus task completion") @@ -250,6 +292,7 @@ async def _start_operation( self, start_request: temporalio.api.nexus.v1.StartOperationRequest, headers: Mapping[str, str], + cancellation: nexusrpc.handler.OperationTaskCancellation, ) -> temporalio.api.nexus.v1.StartOperationResponse: """Invoke the Nexus handler's start_operation method and construct the StartOperationResponse. @@ -268,6 +311,7 @@ async def _start_operation( for link in start_request.links ], callback_headers=dict(start_request.callback_header), + task_cancellation=cancellation, ) temporalio.nexus._operation_context._TemporalStartOperationContext( nexus_context=ctx, @@ -511,9 +555,49 @@ def _exception_to_handler_error(err: BaseException) -> nexusrpc.HandlerError: f"Unhandled RPC error status: {err.status}", type=nexusrpc.HandlerErrorType.INTERNAL, ) + elif isinstance(err, asyncio.CancelledError): + # TODO(amazzeo): What type should we use? a new type? + handler_err = nexusrpc.HandlerError( + "Cancelled", + type=nexusrpc.HandlerErrorType.RESOURCE_EXHAUSTED, + ) else: handler_err = nexusrpc.HandlerError( str(err), type=nexusrpc.HandlerErrorType.INTERNAL ) handler_err.__cause__ = err return handler_err + + +class _NexusTaskCancellation(nexusrpc.handler.OperationTaskCancellation): + def __init__(self): + self._thread_evt = threading.Event() + self._aysnc_evt = asyncio.Event() + self._lock = threading.Lock() + self._reason: Optional[str] = None + self._thread_ident: Optional[int] = None + + def is_cancelled(self) -> bool: + return self._thread_evt.is_set() + + def cancellation_reason(self) -> Optional[str]: + with self._lock: + return self._reason + + def wait_until_cancelled_sync(self, timeout: float | None = None) -> bool: + return self._thread_evt.wait(timeout) + + async def wait_until_cancelled(self) -> None: + await self._aysnc_evt.wait() + + def cancel(self, reason: Optional[str] = None) -> bool: + if self._thread_evt.is_set(): + return False + + with self._lock: + if self._thread_evt.is_set(): + return False + self._reason = reason + self._thread_evt.set() + self._aysnc_evt.set() + return True diff --git a/tests/nexus/test_dynamic_creation_of_user_handler_classes.py b/tests/nexus/test_dynamic_creation_of_user_handler_classes.py index 0eef14b84..bc4eacd27 100644 --- a/tests/nexus/test_dynamic_creation_of_user_handler_classes.py +++ b/tests/nexus/test_dynamic_creation_of_user_handler_classes.py @@ -44,20 +44,6 @@ async def cancel( ) -> None: raise NotImplementedError - async def fetch_info( - self, - ctx: nexusrpc.handler.FetchOperationInfoContext, - token: str, - ) -> nexusrpc.OperationInfo: - raise NotImplementedError - - async def fetch_result( - self, - ctx: nexusrpc.handler.FetchOperationResultContext, - token: str, - ) -> int: - raise NotImplementedError - @nexusrpc.handler.service_handler class MyServiceHandlerWithWorkflowRunOperation: @@ -78,8 +64,8 @@ async def test_run_nexus_service_from_programmatically_created_service_handler( service_handler = nexusrpc.handler._core.ServiceHandler( service=nexusrpc.ServiceDefinition( name="MyService", - operations={ - "increment": nexusrpc.Operation[int, int]( + operation_definitions={ + "increment": nexusrpc.OperationDefinition[int, int]( name="increment", method_name="increment", input_type=int, diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 1f3420da3..5cbe8643c 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -32,12 +32,9 @@ HandlerErrorType, OperationError, OperationErrorState, - OperationInfo, ) from nexusrpc.handler import ( CancelOperationContext, - FetchOperationInfoContext, - FetchOperationResultContext, OperationHandler, StartOperationContext, StartOperationResultSync, @@ -261,16 +258,6 @@ async def start( # type: ignore[override] # intentional test error # or StartOperationResultAsync return Output(value="unwrapped result error") - async def fetch_info( - self, ctx: FetchOperationInfoContext, token: str - ) -> OperationInfo: - raise NotImplementedError - - async def fetch_result( - self, ctx: FetchOperationResultContext, token: str - ) -> Output: - raise NotImplementedError - async def cancel(self, ctx: CancelOperationContext, token: str) -> None: raise NotImplementedError @@ -885,14 +872,6 @@ async def start( def cancel(self, ctx: CancelOperationContext, token: str) -> None: return None # type: ignore - def fetch_info( - self, ctx: FetchOperationInfoContext, token: str - ) -> OperationInfo: - raise NotImplementedError - - def fetch_result(self, ctx: FetchOperationResultContext, token: str) -> Output: - raise NotImplementedError - @operation_handler def echo(self) -> OperationHandler[Input, Output]: return SyncCancelHandler.SyncCancel() @@ -1084,7 +1063,9 @@ async def start_two_workflows_with_conflicting_workflow_ids( assert status_code == 201 op_info = resp.json() assert op_info["token"] - assert op_info["state"] == nexusrpc.OperationState.RUNNING.value + assert ( + op_info["state"] == "running" + ) # nexusrpc.OperationState.RUNNING.value <--- this doesn't exist anymore else: assert status_code >= 400 failure = Failure(**resp.json()) diff --git a/tests/nexus/test_handler_async_operation.py b/tests/nexus/test_handler_async_operation.py index df245d0ff..bf8099491 100644 --- a/tests/nexus/test_handler_async_operation.py +++ b/tests/nexus/test_handler_async_operation.py @@ -12,14 +12,9 @@ from dataclasses import dataclass, field from typing import Any, Type, Union -import nexusrpc -import nexusrpc.handler import pytest -from nexusrpc import OperationInfo from nexusrpc.handler import ( CancelOperationContext, - FetchOperationInfoContext, - FetchOperationResultContext, OperationHandler, StartOperationContext, StartOperationResultAsync, @@ -57,23 +52,6 @@ async def task() -> Output: await self.executor.add_task(task_id, task()) return StartOperationResultAsync(token=task_id) - async def fetch_info( - self, ctx: FetchOperationInfoContext, token: str - ) -> OperationInfo: - # status = self.executor.get_task_status(task_id=token) - # return OperationInfo(token=token, status=status) - raise NotImplementedError( - "Not possible to test this currently since the server's Nexus implementation does not support fetch_info" - ) - - async def fetch_result( - self, ctx: FetchOperationResultContext, token: str - ) -> Output: - # return await self.executor.get_task_result(task_id=token) - raise NotImplementedError( - "Not possible to test this currently since the server's Nexus implementation does not support fetch_result" - ) - async def cancel(self, ctx: CancelOperationContext, token: str) -> None: self.executor.request_cancel_task(task_id=token) @@ -93,19 +71,6 @@ async def task() -> Output: self.executor.add_task_sync(task_id, task()) return StartOperationResultAsync(token=task_id) - def fetch_info(self, ctx: FetchOperationInfoContext, token: str) -> OperationInfo: - # status = self.executor.get_task_status(task_id=token) - # return OperationInfo(token=token, status=status) - raise NotImplementedError( - "Not possible to test this currently since the server's Nexus implementation does not support fetch_info" - ) - - def fetch_result(self, ctx: FetchOperationResultContext, token: str) -> Output: - # return self.executor.get_task_result_sync(task_id=token) - raise NotImplementedError( - "Not possible to test this currently since the server's Nexus implementation does not support fetch_result" - ) - def cancel(self, ctx: CancelOperationContext, token: str) -> None: self.executor.request_cancel_task(task_id=token) @@ -213,17 +178,6 @@ def add_task_sync(self, task_id: str, coro: Coroutine[Any, Any, Any]) -> None: self.add_task(task_id, coro), self.event_loop ).result() - def get_task_status(self, task_id: str) -> nexusrpc.OperationState: - task = self.tasks[task_id] - if not task.done(): - return nexusrpc.OperationState.RUNNING - elif task.cancelled(): - return nexusrpc.OperationState.CANCELED - elif task.exception(): - return nexusrpc.OperationState.FAILED - else: - return nexusrpc.OperationState.SUCCEEDED - async def get_task_result(self, task_id: str) -> Any: """ Get the result of a task from the task execution platform. diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 743f2b3e0..2e79b6625 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -11,8 +11,6 @@ import pytest from nexusrpc.handler import ( CancelOperationContext, - FetchOperationInfoContext, - FetchOperationResultContext, OperationHandler, StartOperationContext, StartOperationResultAsync, @@ -174,16 +172,6 @@ async def start( # type: ignore[override] async def cancel(self, ctx: CancelOperationContext, token: str) -> None: return await temporalio.nexus._operation_handlers._cancel_workflow(token) - async def fetch_info( - self, ctx: FetchOperationInfoContext, token: str - ) -> nexusrpc.OperationInfo: - raise NotImplementedError - - async def fetch_result( - self, ctx: FetchOperationResultContext, token: str - ) -> OpOutput: - raise NotImplementedError - @service_handler(service=ServiceInterface) class ServiceImpl: @@ -503,14 +491,21 @@ async def test_workflow_run_operation_happy_path( # TODO(nexus-preview): cross-namespace tests # TODO(nexus-preview): nexus endpoint pytest fixture? # TODO(nexus-prerelease): test headers -@pytest.mark.parametrize("exception_in_operation_start", [False, True]) -@pytest.mark.parametrize("request_cancel", [False, True]) -@pytest.mark.parametrize( - "op_definition_type", [OpDefinitionType.SHORTHAND, OpDefinitionType.LONGHAND] -) +# @pytest.mark.parametrize("exception_in_operation_start", [False, True]) +@pytest.mark.parametrize("exception_in_operation_start", [True]) +# @pytest.mark.parametrize("request_cancel", [False, True]) +@pytest.mark.parametrize("request_cancel", [False]) +# @pytest.mark.parametrize( +# "op_definition_type", [OpDefinitionType.SHORTHAND, OpDefinitionType.LONGHAND] +# ) +@pytest.mark.parametrize("op_definition_type", [OpDefinitionType.SHORTHAND]) +# @pytest.mark.parametrize( +# "caller_reference", +# [CallerReference.IMPL_WITH_INTERFACE, CallerReference.INTERFACE], +# ) @pytest.mark.parametrize( "caller_reference", - [CallerReference.IMPL_WITH_INTERFACE, CallerReference.INTERFACE], + [CallerReference.IMPL_WITH_INTERFACE], ) async def test_sync_response( client: Client, diff --git a/tests/nexus/test_workflow_caller_errors.py b/tests/nexus/test_workflow_caller_errors.py index 2bff390da..5956a3e60 100644 --- a/tests/nexus/test_workflow_caller_errors.py +++ b/tests/nexus/test_workflow_caller_errors.py @@ -12,8 +12,6 @@ import pytest from nexusrpc.handler import ( CancelOperationContext, - FetchOperationInfoContext, - FetchOperationResultContext, OperationHandler, StartOperationContext, StartOperationResultAsync, @@ -31,6 +29,7 @@ NexusOperationError, TimeoutError, ) +import temporalio.exceptions from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker from tests.helpers import assert_eq_eventually @@ -236,7 +235,41 @@ class StartTimeoutTestService: async def op_handler_that_never_returns( self, ctx: StartOperationContext, input: None ) -> None: - await asyncio.Future() + try: + await asyncio.wait_for(ctx.task_cancellation.wait_until_cancelled(), 1) + except asyncio.TimeoutError: + print("timeout") + raise ApplicationError("expected cancel", non_retryable=True) + + @sync_operation + def op_handler_that_never_returns_but_sync( + self, ctx: StartOperationContext, input: None + ) -> None: + cancelled = ctx.task_cancellation.wait_until_cancelled_sync(1) + if not cancelled: + raise ApplicationError("expected cancel", non_retryable=True) + reason = ctx.task_cancellation.cancellation_reason() + if reason != "timeout": + raise ApplicationError("expected cancel details", non_retryable=True) + + +@workflow.defn +class StartTimeoutTestCallerWorkflowSync: + @workflow.init + def __init__(self): + self.nexus_client = workflow.create_nexus_client( + service=StartTimeoutTestService, + endpoint=make_nexus_endpoint_name(workflow.info().task_queue), + ) + + @workflow.run + async def run(self) -> None: + await self.nexus_client.execute_operation( + StartTimeoutTestService.op_handler_that_never_returns_but_sync, # type: ignore[arg-type] # mypy can't infer OutputT=None in Union type + None, + output_type=None, + schedule_to_close_timeout=timedelta(seconds=0.1), + ) @workflow.defn @@ -268,8 +301,9 @@ async def test_error_raised_by_timeout_of_nexus_start_operation( async with Worker( client, nexus_service_handlers=[StartTimeoutTestService()], - workflows=[StartTimeoutTestCallerWorkflow], + workflows=[StartTimeoutTestCallerWorkflow, StartTimeoutTestCallerWorkflowSync], task_queue=task_queue, + nexus_task_executor=concurrent.futures.ThreadPoolExecutor(), ): await create_nexus_endpoint(task_queue, client) try: @@ -285,6 +319,19 @@ async def test_error_raised_by_timeout_of_nexus_start_operation( else: pytest.fail("Expected exception due to timeout of nexus start operation") + try: + await client.execute_workflow( + StartTimeoutTestCallerWorkflowSync.run, + id=str(uuid.uuid4()), + task_queue=task_queue, + ) + except Exception as err: + assert isinstance(err, WorkflowFailureError) + assert isinstance(err.__cause__, NexusOperationError) + assert isinstance(err.__cause__.__cause__, TimeoutError) + else: + pytest.fail("Expected exception due to timeout of nexus start operation") + # Cancellation timeout test @@ -296,15 +343,10 @@ async def start( return StartOperationResultAsync("fake-token") async def cancel(self, ctx: CancelOperationContext, token: str) -> None: - await asyncio.Future() - - async def fetch_info( - self, ctx: FetchOperationInfoContext, token: str - ) -> nexusrpc.OperationInfo: - raise NotImplementedError("Not implemented") - - async def fetch_result(self, ctx: FetchOperationResultContext, token: str) -> None: - raise NotImplementedError("Not implemented") + try: + await asyncio.wait_for(ctx.task_cancellation.wait_until_cancelled(), 10) + except asyncio.TimeoutError: + raise RuntimeError("expected cancellation") @service_handler @@ -347,7 +389,6 @@ async def test_error_raised_by_timeout_of_nexus_cancel_operation( if env.supports_time_skipping: pytest.skip("Nexus tests don't work with time-skipping server") - pytest.skip("TODO(nexus-prerelease): finish writing this test") task_queue = str(uuid.uuid4()) async with Worker( client, diff --git a/uv.lock b/uv.lock index 68f2f73bd..e11decd03 100644 --- a/uv.lock +++ b/uv.lock @@ -1761,14 +1761,10 @@ wheels = [ [[package]] name = "nexus-rpc" version = "1.1.0" -source = { registry = "https://pypi.org/simple" } +source = { git = "https://github.com/nexus-rpc/sdk-python?rev=task-cancellation#847d9cb903d962e95d147de26b02fb1e412ee7cb" } dependencies = [ { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ef/66/540687556bd28cf1ec370cc6881456203dfddb9dab047b8979c6865b5984/nexus_rpc-1.1.0.tar.gz", hash = "sha256:d65ad6a2f54f14e53ebe39ee30555eaeb894102437125733fb13034a04a44553", size = 77383, upload-time = "2025-07-07T19:03:58.368Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/bf/2f/9e9d0dcaa4c6ffa22b7aa31069a8a264c753ff8027b36af602cce038c92f/nexus_rpc-1.1.0-py3-none-any.whl", hash = "sha256:d1b007af2aba186a27e736f8eaae39c03aed05b488084ff6c3d1785c9ba2ad38", size = 27743, upload-time = "2025-07-07T19:03:57.556Z" }, -] [[package]] name = "nh3" @@ -3007,7 +3003,7 @@ dev = [ requires-dist = [ { name = "grpcio", marker = "extra == 'grpc'", specifier = ">=1.48.2,<2" }, { name = "mcp", marker = "extra == 'openai-agents'", specifier = ">=1.9.4,<2" }, - { name = "nexus-rpc", specifier = "==1.1.0" }, + { name = "nexus-rpc", git = "https://github.com/nexus-rpc/sdk-python?rev=task-cancellation" }, { name = "openai-agents", marker = "extra == 'openai-agents'", specifier = ">=0.3,<0.5" }, { name = "opentelemetry-api", marker = "extra == 'opentelemetry'", specifier = ">=1.11.1,<2" }, { name = "opentelemetry-sdk", marker = "extra == 'opentelemetry'", specifier = ">=1.11.1,<2" }, From 134b473f825f920df26e2ba27380746cc6f4d350 Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Mon, 3 Nov 2025 11:50:39 -0800 Subject: [PATCH 2/6] fix up failing nexus test. Use deterministic as_completed in research_manager. Finish implementing nexus task cancellation --- temporalio/worker/_nexus.py | 28 +++--- temporalio/workflow.py | 33 +++++++ .../research_agents/research_manager.py | 3 +- tests/nexus/test_handler.py | 4 +- tests/nexus/test_workflow_caller.py | 19 ++-- tests/nexus/test_workflow_caller_errors.py | 92 ++++++++----------- 6 files changed, 98 insertions(+), 81 deletions(-) diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py index 20537bda5..6a5fffad1 100644 --- a/temporalio/worker/_nexus.py +++ b/temporalio/worker/_nexus.py @@ -36,7 +36,6 @@ from temporalio.exceptions import ( ApplicationError, WorkflowAlreadyStartedError, - CancelledError, ) from temporalio.nexus import Info, logger from temporalio.service import RPCError, RPCStatusCode @@ -105,30 +104,30 @@ async def raise_from_exception_queue() -> NoReturn: if nexus_task.HasField("task"): task = nexus_task.task if task.request.HasField("start_operation"): - cancellation = _NexusTaskCancellation() + task_cancellation = _NexusTaskCancellation() start_op_task = asyncio.create_task( self._handle_start_operation_task( task.task_token, task.request.start_operation, dict(task.request.header), - cancellation, + task_cancellation, ) ) self._running_tasks[task.task_token] = _RunningNexusTask( - start_op_task, cancellation + start_op_task, task_cancellation ) elif task.request.HasField("cancel_operation"): - cancellation = _NexusTaskCancellation() + task_cancellation = _NexusTaskCancellation() cancel_op_task = asyncio.create_task( self._handle_cancel_operation_task( task.task_token, task.request.cancel_operation, dict(task.request.header), - cancellation, + task_cancellation, ) ) self._running_tasks[task.task_token] = _RunningNexusTask( - cancel_op_task, cancellation + cancel_op_task, task_cancellation ) else: raise NotImplementedError( @@ -138,9 +137,12 @@ async def raise_from_exception_queue() -> NoReturn: if running_task := self._running_tasks.get( nexus_task.cancel_task.task_token ): - # TODO(nexus-prerelease): when do we remove the entry from _running_operations? - # TODO(amazzeo): put real reason here? - running_task.cancel("timeout") + reason = ( + temporalio.bridge.proto.nexus.NexusTaskCancelReason.Name( + nexus_task.cancel_task.reason + ) + ) + running_task.cancel(reason) else: logger.debug( f"Received cancel_task but no running task exists for " @@ -208,6 +210,11 @@ async def _handle_cancel_operation_task( try: try: await self._handler.cancel_operation(ctx, request.operation_token) + except asyncio.CancelledError: + completion = temporalio.bridge.proto.nexus.NexusTaskCompletion( + task_token=task_token, + ack_cancel=task_cancellation.is_cancelled(), + ) except BaseException as err: logger.warning("Failed to execute Nexus cancel operation method") completion = temporalio.bridge.proto.nexus.NexusTaskCompletion( @@ -215,7 +222,6 @@ async def _handle_cancel_operation_task( error=await self._handler_error_to_proto( _exception_to_handler_error(err) ), - ack_cancel=task_cancellation.is_cancelled(), ) else: completion = temporalio.bridge.proto.nexus.NexusTaskCompletion( diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 1dc70c8b4..e84bb40ee 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -5417,6 +5417,22 @@ async def start_operation( headers: Optional[Mapping[str, str]] = None, ) -> NexusOperationHandle[OutputT]: ... + # Overload for operation_handler + @overload + @abstractmethod + async def start_operation( + self, + operation: Callable[ + [ServiceHandlerT], nexusrpc.handler.OperationHandler[InputT, OutputT] + ], + input: InputT, + *, + output_type: Optional[Type[OutputT]] = None, + schedule_to_close_timeout: Optional[timedelta] = None, + cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED, + headers: Optional[Mapping[str, str]] = None, + ) -> NexusOperationHandle[OutputT]: ... + @abstractmethod async def start_operation( self, @@ -5527,6 +5543,23 @@ async def execute_operation( headers: Optional[Mapping[str, str]] = None, ) -> OutputT: ... + # Overload for operation_handler + @overload + @abstractmethod + async def execute_operation( + self, + operation: Callable[ + [ServiceT], + nexusrpc.handler.OperationHandler[InputT, OutputT], + ], + input: InputT, + *, + output_type: Optional[Type[OutputT]] = None, + schedule_to_close_timeout: Optional[timedelta] = None, + cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED, + headers: Optional[Mapping[str, str]] = None, + ) -> OutputT: ... + @abstractmethod async def execute_operation( self, diff --git a/tests/contrib/openai_agents/research_agents/research_manager.py b/tests/contrib/openai_agents/research_agents/research_manager.py index de721f9b9..febc56ec4 100644 --- a/tests/contrib/openai_agents/research_agents/research_manager.py +++ b/tests/contrib/openai_agents/research_agents/research_manager.py @@ -14,6 +14,7 @@ ReportData, new_writer_agent, ) +import temporalio.workflow class ResearchManager: @@ -45,7 +46,7 @@ async def _perform_searches(self, search_plan: WebSearchPlan) -> list[str]: asyncio.create_task(self._search(item)) for item in search_plan.searches ] results = [] - for task in asyncio.as_completed(tasks): + for task in temporalio.workflow.as_completed(tasks): result = await task if result is not None: results.append(result) diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 5cbe8643c..34ad33e37 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -80,8 +80,6 @@ class NonSerializableOutput: # TODO(nexus-preview): type check nexus implementation under mypy # TODO(nexus-preview): test malformed inbound_links and outbound_links -# TODO(nexus-prerelease): 2025-07-02T23:29:20.000489Z WARN temporal_sdk_core::worker::nexus: Nexus task not found on completion. This may happen if the operation has already been cancelled but completed anyway. details=Status { code: NotFound, message: "Nexus task not found or already expired", details: b"\x08\x05\x12'Nexus task not found or already expired\x1aB\n@type.googleapis.com/temporal.api.errordetails.v1.NotFoundFailure", metadata: MetadataMap { headers: {"content-type": "application/grpc"} }, source: None } - @nexusrpc.service class MyService: @@ -766,7 +764,7 @@ async def test_start_operation_without_type_annotations( def test_operation_without_type_annotations_without_service_definition_raises_validation_error(): with pytest.raises( ValueError, - match=r"has no input type.+has no output type", + match=r"has no input type", # TODO(amazzeo): the previous implementation gathered validation errors. Are we okay dropping that? ".+has no output type", ): service_handler(MyServiceHandlerWithOperationsWithoutTypeAnnotations) diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 2e79b6625..06638ef89 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -491,21 +491,14 @@ async def test_workflow_run_operation_happy_path( # TODO(nexus-preview): cross-namespace tests # TODO(nexus-preview): nexus endpoint pytest fixture? # TODO(nexus-prerelease): test headers -# @pytest.mark.parametrize("exception_in_operation_start", [False, True]) -@pytest.mark.parametrize("exception_in_operation_start", [True]) -# @pytest.mark.parametrize("request_cancel", [False, True]) -@pytest.mark.parametrize("request_cancel", [False]) -# @pytest.mark.parametrize( -# "op_definition_type", [OpDefinitionType.SHORTHAND, OpDefinitionType.LONGHAND] -# ) -@pytest.mark.parametrize("op_definition_type", [OpDefinitionType.SHORTHAND]) -# @pytest.mark.parametrize( -# "caller_reference", -# [CallerReference.IMPL_WITH_INTERFACE, CallerReference.INTERFACE], -# ) +@pytest.mark.parametrize("exception_in_operation_start", [False, True]) +@pytest.mark.parametrize("request_cancel", [False, True]) +@pytest.mark.parametrize( + "op_definition_type", [OpDefinitionType.SHORTHAND, OpDefinitionType.LONGHAND] +) @pytest.mark.parametrize( "caller_reference", - [CallerReference.IMPL_WITH_INTERFACE], + [CallerReference.IMPL_WITH_INTERFACE, CallerReference.INTERFACE], ) async def test_sync_response( client: Client, diff --git a/tests/nexus/test_workflow_caller_errors.py b/tests/nexus/test_workflow_caller_errors.py index 5956a3e60..190e9b026 100644 --- a/tests/nexus/test_workflow_caller_errors.py +++ b/tests/nexus/test_workflow_caller_errors.py @@ -6,6 +6,7 @@ from collections import Counter from dataclasses import dataclass from datetime import timedelta +from logging import getLogger import nexusrpc import nexusrpc.handler @@ -17,6 +18,7 @@ StartOperationResultAsync, service_handler, sync_operation, + operation_handler, ) from temporalio import nexus, workflow @@ -29,14 +31,15 @@ NexusOperationError, TimeoutError, ) -import temporalio.exceptions from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker -from tests.helpers import assert_eq_eventually +from tests.helpers import LogCapturer, assert_eq_eventually from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name operation_invocation_counts = Counter[str]() +logger = getLogger(__name__) + @dataclass class ErrorTestInput: @@ -232,17 +235,16 @@ async def test_nexus_operation_fails_without_retry_as_handler_error( @service_handler class StartTimeoutTestService: @sync_operation - async def op_handler_that_never_returns( + async def expect_timeout_cancellation_async( self, ctx: StartOperationContext, input: None ) -> None: try: await asyncio.wait_for(ctx.task_cancellation.wait_until_cancelled(), 1) except asyncio.TimeoutError: - print("timeout") raise ApplicationError("expected cancel", non_retryable=True) @sync_operation - def op_handler_that_never_returns_but_sync( + def expect_timeout_cancellation_sync( self, ctx: StartOperationContext, input: None ) -> None: cancelled = ctx.task_cancellation.wait_until_cancelled_sync(1) @@ -253,38 +255,19 @@ def op_handler_that_never_returns_but_sync( raise ApplicationError("expected cancel details", non_retryable=True) -@workflow.defn -class StartTimeoutTestCallerWorkflowSync: - @workflow.init - def __init__(self): - self.nexus_client = workflow.create_nexus_client( - service=StartTimeoutTestService, - endpoint=make_nexus_endpoint_name(workflow.info().task_queue), - ) - - @workflow.run - async def run(self) -> None: - await self.nexus_client.execute_operation( - StartTimeoutTestService.op_handler_that_never_returns_but_sync, # type: ignore[arg-type] # mypy can't infer OutputT=None in Union type - None, - output_type=None, - schedule_to_close_timeout=timedelta(seconds=0.1), - ) - - @workflow.defn class StartTimeoutTestCallerWorkflow: @workflow.init - def __init__(self): + def __init__(self, operation: str): self.nexus_client = workflow.create_nexus_client( service=StartTimeoutTestService, endpoint=make_nexus_endpoint_name(workflow.info().task_queue), ) @workflow.run - async def run(self) -> None: + async def run(self, operation: str) -> None: await self.nexus_client.execute_operation( - StartTimeoutTestService.op_handler_that_never_returns, # type: ignore[arg-type] # mypy can't infer OutputT=None in Union type + operation, None, output_type=None, schedule_to_close_timeout=timedelta(seconds=0.1), @@ -301,7 +284,7 @@ async def test_error_raised_by_timeout_of_nexus_start_operation( async with Worker( client, nexus_service_handlers=[StartTimeoutTestService()], - workflows=[StartTimeoutTestCallerWorkflow, StartTimeoutTestCallerWorkflowSync], + workflows=[StartTimeoutTestCallerWorkflow], task_queue=task_queue, nexus_task_executor=concurrent.futures.ThreadPoolExecutor(), ): @@ -309,6 +292,7 @@ async def test_error_raised_by_timeout_of_nexus_start_operation( try: await client.execute_workflow( StartTimeoutTestCallerWorkflow.run, + "expect_timeout_cancellation_async", id=str(uuid.uuid4()), task_queue=task_queue, ) @@ -321,7 +305,8 @@ async def test_error_raised_by_timeout_of_nexus_start_operation( try: await client.execute_workflow( - StartTimeoutTestCallerWorkflowSync.run, + StartTimeoutTestCallerWorkflow.run, + "expect_timeout_cancellation_sync", id=str(uuid.uuid4()), task_queue=task_queue, ) @@ -336,7 +321,7 @@ async def test_error_raised_by_timeout_of_nexus_start_operation( # Cancellation timeout test -class OperationWithCancelMethodThatNeverReturns(OperationHandler[None, None]): +class OperationWithCancelMethodThatExpectsCancel(OperationHandler[None, None]): async def start( self, ctx: StartOperationContext, input: None ) -> StartOperationResultAsync: @@ -346,16 +331,17 @@ async def cancel(self, ctx: CancelOperationContext, token: str) -> None: try: await asyncio.wait_for(ctx.task_cancellation.wait_until_cancelled(), 10) except asyncio.TimeoutError: - raise RuntimeError("expected cancellation") + logger.error("expected cancellation") + raise ApplicationError("expected cancellation", non_retryable=True) @service_handler class CancellationTimeoutTestService: - @nexusrpc.handler._decorators.operation_handler - def op_with_cancel_method_that_never_returns( + @operation_handler + def op_with_cancel_method_that_expects_cancel( self, ) -> OperationHandler[None, None]: - return OperationWithCancelMethodThatNeverReturns() + return OperationWithCancelMethodThatExpectsCancel() @workflow.defn @@ -369,13 +355,8 @@ def __init__(self): @workflow.run async def run(self) -> None: - # TODO(nexus-prerelease) op_handle = await self.nexus_client.start_operation( - # Although the tests are making use of it, we are not exposing operation - # factory methods to users as a way to write nexus operations, and so the - # types on NexusClient start_operation/execute_operation do not currently - # permit it. - CancellationTimeoutTestService.op_with_cancel_method_that_never_returns, # type: ignore + CancellationTimeoutTestService.op_with_cancel_method_that_expects_cancel, None, schedule_to_close_timeout=timedelta(seconds=0.1), ) @@ -396,16 +377,21 @@ async def test_error_raised_by_timeout_of_nexus_cancel_operation( workflows=[CancellationTimeoutTestCallerWorkflow], task_queue=task_queue, ): - await create_nexus_endpoint(task_queue, client) - try: - await client.execute_workflow( - CancellationTimeoutTestCallerWorkflow.run, - id=str(uuid.uuid4()), - task_queue=task_queue, - ) - except Exception as err: - assert isinstance(err, WorkflowFailureError) - assert isinstance(err.__cause__, NexusOperationError) - assert isinstance(err.__cause__.__cause__, TimeoutError) - else: - pytest.fail("Expected exception due to timeout of nexus cancel operation") + with LogCapturer().logs_captured(logger) as capturer: + await create_nexus_endpoint(task_queue, client) + try: + await client.execute_workflow( + CancellationTimeoutTestCallerWorkflow.run, + id=str(uuid.uuid4()), + task_queue=task_queue, + ) + except Exception as err: + assert isinstance(err, WorkflowFailureError) + assert isinstance(err.__cause__, NexusOperationError) + assert isinstance(err.__cause__.__cause__, TimeoutError) + else: + pytest.fail( + "Expected exception due to timeout of nexus cancel operation" + ) + + assert capturer.find_log("expected cancellation") is None From 883ba6005efd9946115483c5e976f3e7e410eea4 Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Mon, 3 Nov 2025 12:10:00 -0800 Subject: [PATCH 3/6] Apply formatting. Fix linter typing error. Remove some WIP elements that weren't relevant --- temporalio/worker/_nexus.py | 8 +------- .../research_agents/research_manager.py | 2 +- tests/nexus/test_handler.py | 6 ++---- tests/nexus/test_workflow_caller.py | 14 ++++++++------ tests/nexus/test_workflow_caller_errors.py | 2 +- 5 files changed, 13 insertions(+), 19 deletions(-) diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py index 6a5fffad1..38c1f926e 100644 --- a/temporalio/worker/_nexus.py +++ b/temporalio/worker/_nexus.py @@ -5,8 +5,8 @@ import asyncio import concurrent.futures import json -from dataclasses import dataclass import threading +from dataclasses import dataclass from typing import ( Any, Callable, @@ -561,12 +561,6 @@ def _exception_to_handler_error(err: BaseException) -> nexusrpc.HandlerError: f"Unhandled RPC error status: {err.status}", type=nexusrpc.HandlerErrorType.INTERNAL, ) - elif isinstance(err, asyncio.CancelledError): - # TODO(amazzeo): What type should we use? a new type? - handler_err = nexusrpc.HandlerError( - "Cancelled", - type=nexusrpc.HandlerErrorType.RESOURCE_EXHAUSTED, - ) else: handler_err = nexusrpc.HandlerError( str(err), type=nexusrpc.HandlerErrorType.INTERNAL diff --git a/tests/contrib/openai_agents/research_agents/research_manager.py b/tests/contrib/openai_agents/research_agents/research_manager.py index febc56ec4..f37eb6293 100644 --- a/tests/contrib/openai_agents/research_agents/research_manager.py +++ b/tests/contrib/openai_agents/research_agents/research_manager.py @@ -4,6 +4,7 @@ from agents import Runner, custom_span, gen_trace_id, trace +import temporalio.workflow from tests.contrib.openai_agents.research_agents.planner_agent import ( WebSearchItem, WebSearchPlan, @@ -14,7 +15,6 @@ ReportData, new_writer_agent, ) -import temporalio.workflow class ResearchManager: diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 34ad33e37..7ce0149f9 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -764,7 +764,7 @@ async def test_start_operation_without_type_annotations( def test_operation_without_type_annotations_without_service_definition_raises_validation_error(): with pytest.raises( ValueError, - match=r"has no input type", # TODO(amazzeo): the previous implementation gathered validation errors. Are we okay dropping that? ".+has no output type", + match=r"has no input type", ): service_handler(MyServiceHandlerWithOperationsWithoutTypeAnnotations) @@ -1061,9 +1061,7 @@ async def start_two_workflows_with_conflicting_workflow_ids( assert status_code == 201 op_info = resp.json() assert op_info["token"] - assert ( - op_info["state"] == "running" - ) # nexusrpc.OperationState.RUNNING.value <--- this doesn't exist anymore + assert op_info["state"] == "running" else: assert status_code >= 400 failure = Failure(**resp.json()) diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 06638ef89..d53c178fa 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -241,12 +241,14 @@ def __init__( request_cancel: bool, task_queue: str, ) -> None: - self.nexus_client = workflow.create_nexus_client( - service={ - CallerReference.IMPL_WITH_INTERFACE: ServiceImpl, - CallerReference.INTERFACE: ServiceInterface, - }[input.op_input.caller_reference], - endpoint=make_nexus_endpoint_name(task_queue), + self.nexus_client: workflow.NexusClient[ServiceInterface] = ( + workflow.create_nexus_client( + service={ + CallerReference.IMPL_WITH_INTERFACE: ServiceImpl, + CallerReference.INTERFACE: ServiceInterface, + }[input.op_input.caller_reference], + endpoint=make_nexus_endpoint_name(task_queue), + ) ) self._nexus_operation_start_resolved = False self._proceed = False diff --git a/tests/nexus/test_workflow_caller_errors.py b/tests/nexus/test_workflow_caller_errors.py index 190e9b026..f97eec1a1 100644 --- a/tests/nexus/test_workflow_caller_errors.py +++ b/tests/nexus/test_workflow_caller_errors.py @@ -16,9 +16,9 @@ OperationHandler, StartOperationContext, StartOperationResultAsync, + operation_handler, service_handler, sync_operation, - operation_handler, ) from temporalio import nexus, workflow From fa971e079f46fe36605dbaaaa7fd4fccfa216cba Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Mon, 3 Nov 2025 13:26:32 -0800 Subject: [PATCH 4/6] Fix test to properly reference new cancellation details string --- tests/nexus/test_workflow_caller_errors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/nexus/test_workflow_caller_errors.py b/tests/nexus/test_workflow_caller_errors.py index f97eec1a1..02cff8d52 100644 --- a/tests/nexus/test_workflow_caller_errors.py +++ b/tests/nexus/test_workflow_caller_errors.py @@ -251,7 +251,7 @@ def expect_timeout_cancellation_sync( if not cancelled: raise ApplicationError("expected cancel", non_retryable=True) reason = ctx.task_cancellation.cancellation_reason() - if reason != "timeout": + if reason != "TIMED_OUT": raise ApplicationError("expected cancel details", non_retryable=True) From 5a5d031f1aa61827358830495c957f2781d225f4 Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Mon, 3 Nov 2025 13:41:13 -0800 Subject: [PATCH 5/6] use threading.Event and logs to make sure test covers cancellation reason accuracy --- tests/nexus/test_workflow_caller_errors.py | 43 ++++++++++++++-------- 1 file changed, 28 insertions(+), 15 deletions(-) diff --git a/tests/nexus/test_workflow_caller_errors.py b/tests/nexus/test_workflow_caller_errors.py index 02cff8d52..0fc32ba5b 100644 --- a/tests/nexus/test_workflow_caller_errors.py +++ b/tests/nexus/test_workflow_caller_errors.py @@ -2,6 +2,7 @@ import asyncio import concurrent.futures +import threading import uuid from collections import Counter from dataclasses import dataclass @@ -231,6 +232,9 @@ async def test_nexus_operation_fails_without_retry_as_handler_error( pytest.fail("Unreachable") +_start_operation_sync_complete = threading.Event() + + # Start timeout test @service_handler class StartTimeoutTestService: @@ -247,12 +251,14 @@ async def expect_timeout_cancellation_async( def expect_timeout_cancellation_sync( self, ctx: StartOperationContext, input: None ) -> None: + global _start_operation_sync_complete cancelled = ctx.task_cancellation.wait_until_cancelled_sync(1) if not cancelled: raise ApplicationError("expected cancel", non_retryable=True) reason = ctx.task_cancellation.cancellation_reason() if reason != "TIMED_OUT": - raise ApplicationError("expected cancel details", non_retryable=True) + logger.error("unexpected cancellation reason: %s", reason) + _start_operation_sync_complete.set() @workflow.defn @@ -279,6 +285,7 @@ async def test_error_raised_by_timeout_of_nexus_start_operation( ): if env.supports_time_skipping: pytest.skip("Nexus tests don't work with time-skipping server") + global _start_operation_sync_complete task_queue = str(uuid.uuid4()) async with Worker( @@ -303,19 +310,25 @@ async def test_error_raised_by_timeout_of_nexus_start_operation( else: pytest.fail("Expected exception due to timeout of nexus start operation") - try: - await client.execute_workflow( - StartTimeoutTestCallerWorkflow.run, - "expect_timeout_cancellation_sync", - id=str(uuid.uuid4()), - task_queue=task_queue, - ) - except Exception as err: - assert isinstance(err, WorkflowFailureError) - assert isinstance(err.__cause__, NexusOperationError) - assert isinstance(err.__cause__.__cause__, TimeoutError) - else: - pytest.fail("Expected exception due to timeout of nexus start operation") + with LogCapturer().logs_captured(logger) as capturer: + try: + await client.execute_workflow( + StartTimeoutTestCallerWorkflow.run, + "expect_timeout_cancellation_sync", + id=str(uuid.uuid4()), + task_queue=task_queue, + ) + except Exception as err: + assert isinstance(err, WorkflowFailureError) + assert isinstance(err.__cause__, NexusOperationError) + assert isinstance(err.__cause__.__cause__, TimeoutError) + else: + pytest.fail( + "Expected exception due to timeout of nexus start operation" + ) + + _start_operation_sync_complete.wait() + assert capturer.find_log("unexpected cancellation reason") is None # Cancellation timeout test @@ -329,7 +342,7 @@ async def start( async def cancel(self, ctx: CancelOperationContext, token: str) -> None: try: - await asyncio.wait_for(ctx.task_cancellation.wait_until_cancelled(), 10) + await asyncio.wait_for(ctx.task_cancellation.wait_until_cancelled(), 1) except asyncio.TimeoutError: logger.error("expected cancellation") raise ApplicationError("expected cancellation", non_retryable=True) From 4956b3e564cf43349b969c15f768b35ea89ec5a4 Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Mon, 3 Nov 2025 14:07:58 -0800 Subject: [PATCH 6/6] remove unsued field in _NexusTaskCancellation --- temporalio/worker/_nexus.py | 1 - 1 file changed, 1 deletion(-) diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py index 38c1f926e..ecedcbd97 100644 --- a/temporalio/worker/_nexus.py +++ b/temporalio/worker/_nexus.py @@ -575,7 +575,6 @@ def __init__(self): self._aysnc_evt = asyncio.Event() self._lock = threading.Lock() self._reason: Optional[str] = None - self._thread_ident: Optional[int] = None def is_cancelled(self) -> bool: return self._thread_evt.is_set()