diff --git a/pyproject.toml b/pyproject.toml index bfd8bee67..2599a0a65 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ license = "MIT" license-files = ["LICENSE"] keywords = ["temporal", "workflow"] dependencies = [ - "nexus-rpc==1.1.0", + "nexus-rpc @ git+https://github.com/nexus-rpc/sdk-python@task-cancellation", "protobuf>=3.20,<7.0.0", "python-dateutil>=2.8.2,<3 ; python_version < '3.11'", "types-protobuf>=3.20", @@ -28,10 +28,7 @@ classifiers = [ grpc = ["grpcio>=1.48.2,<2"] opentelemetry = ["opentelemetry-api>=1.11.1,<2", "opentelemetry-sdk>=1.11.1,<2"] pydantic = ["pydantic>=2.0.0,<3"] -openai-agents = [ - "openai-agents>=0.3,<0.5", - "mcp>=1.9.4, <2", -] +openai-agents = ["openai-agents>=0.3,<0.5", "mcp>=1.9.4, <2"] [project.urls] Homepage = "https://github.com/temporalio/sdk-python" diff --git a/temporalio/nexus/_decorators.py b/temporalio/nexus/_decorators.py index 28c625816..3c97482ff 100644 --- a/temporalio/nexus/_decorators.py +++ b/temporalio/nexus/_decorators.py @@ -117,7 +117,7 @@ async def _start( return WorkflowRunOperationHandler(_start) method_name = get_callable_name(start) - nexusrpc.set_operation_definition( + nexusrpc.set_operation( operation_handler_factory, nexusrpc.Operation( name=name or method_name, diff --git a/temporalio/nexus/_operation_handlers.py b/temporalio/nexus/_operation_handlers.py index a73c3eb50..aa5351353 100644 --- a/temporalio/nexus/_operation_handlers.py +++ b/temporalio/nexus/_operation_handlers.py @@ -10,13 +10,10 @@ HandlerError, HandlerErrorType, InputT, - OperationInfo, OutputT, ) from nexusrpc.handler import ( CancelOperationContext, - FetchOperationInfoContext, - FetchOperationResultContext, OperationHandler, StartOperationContext, StartOperationResultAsync, @@ -81,22 +78,6 @@ async def cancel(self, ctx: CancelOperationContext, token: str) -> None: """Cancel the operation, by cancelling the workflow.""" await _cancel_workflow(token) - async def fetch_info( - self, ctx: FetchOperationInfoContext, token: str - ) -> OperationInfo: - """Fetch operation info (not supported for Temporal Nexus operations).""" - raise NotImplementedError( - "Temporal Nexus operation handlers do not support fetching operation info." - ) - - async def fetch_result( - self, ctx: FetchOperationResultContext, token: str - ) -> OutputT: - """Fetch operation result (not supported for Temporal Nexus operations).""" - raise NotImplementedError( - "Temporal Nexus operation handlers do not support fetching the operation result." - ) - async def _cancel_workflow( token: str, diff --git a/temporalio/nexus/_util.py b/temporalio/nexus/_util.py index ef005d0c4..de5bf94b2 100644 --- a/temporalio/nexus/_util.py +++ b/temporalio/nexus/_util.py @@ -129,12 +129,12 @@ def get_operation_factory( ``obj`` should be a decorated operation start method. """ - op_defn = nexusrpc.get_operation_definition(obj) + op_defn = nexusrpc.get_operation(obj) if op_defn: factory = obj else: if factory := getattr(obj, "__nexus_operation_factory__", None): - op_defn = nexusrpc.get_operation_definition(factory) + op_defn = nexusrpc.get_operation(factory) if not isinstance(op_defn, nexusrpc.Operation): return None, None return factory, op_defn diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py index 72b47187f..ecedcbd97 100644 --- a/temporalio/worker/_nexus.py +++ b/temporalio/worker/_nexus.py @@ -5,6 +5,7 @@ import asyncio import concurrent.futures import json +import threading from dataclasses import dataclass from typing import ( Any, @@ -32,7 +33,10 @@ import temporalio.common import temporalio.converter import temporalio.nexus -from temporalio.exceptions import ApplicationError, WorkflowAlreadyStartedError +from temporalio.exceptions import ( + ApplicationError, + WorkflowAlreadyStartedError, +) from temporalio.nexus import Info, logger from temporalio.service import RPCError, RPCStatusCode @@ -41,6 +45,16 @@ _TEMPORAL_FAILURE_PROTO_TYPE = "temporal.api.failure.v1.Failure" +@dataclass +class _RunningNexusTask: + task: asyncio.Task[Any] + cancellation: _NexusTaskCancellation + + def cancel(self, reason: Optional[str] = None): + self.cancellation.cancel(reason) + self.task.cancel() + + class _NexusWorker: def __init__( self, @@ -65,7 +79,7 @@ def __init__( self._interceptors = interceptors # TODO(nexus-preview): metric_meter self._metric_meter = metric_meter - self._running_tasks: dict[bytes, asyncio.Task[Any]] = {} + self._running_tasks: dict[bytes, _RunningNexusTask] = {} self._fail_worker_exception_queue: asyncio.Queue[Exception] = asyncio.Queue() async def run(self) -> None: @@ -90,21 +104,31 @@ async def raise_from_exception_queue() -> NoReturn: if nexus_task.HasField("task"): task = nexus_task.task if task.request.HasField("start_operation"): - self._running_tasks[task.task_token] = asyncio.create_task( + task_cancellation = _NexusTaskCancellation() + start_op_task = asyncio.create_task( self._handle_start_operation_task( task.task_token, task.request.start_operation, dict(task.request.header), + task_cancellation, ) ) + self._running_tasks[task.task_token] = _RunningNexusTask( + start_op_task, task_cancellation + ) elif task.request.HasField("cancel_operation"): - self._running_tasks[task.task_token] = asyncio.create_task( + task_cancellation = _NexusTaskCancellation() + cancel_op_task = asyncio.create_task( self._handle_cancel_operation_task( task.task_token, task.request.cancel_operation, dict(task.request.header), + task_cancellation, ) ) + self._running_tasks[task.task_token] = _RunningNexusTask( + cancel_op_task, task_cancellation + ) else: raise NotImplementedError( f"Invalid Nexus task request: {task.request}" @@ -113,8 +137,12 @@ async def raise_from_exception_queue() -> NoReturn: if running_task := self._running_tasks.get( nexus_task.cancel_task.task_token ): - # TODO(nexus-prerelease): when do we remove the entry from _running_operations? - running_task.cancel() + reason = ( + temporalio.bridge.proto.nexus.NexusTaskCancelReason.Name( + nexus_task.cancel_task.reason + ) + ) + running_task.cancel(reason) else: logger.debug( f"Received cancel_task but no running task exists for " @@ -147,7 +175,10 @@ async def drain_poll_queue(self) -> None: # Only call this after run()/drain_poll_queue() have returned. This will not # raise an exception. async def wait_all_completed(self) -> None: - await asyncio.gather(*self._running_tasks.values(), return_exceptions=True) + running_tasks = [ + running_task.task for running_task in self._running_tasks.values() + ] + await asyncio.gather(*running_tasks, return_exceptions=True) # TODO(nexus-preview): stack trace pruning. See sdk-typescript NexusHandler.execute # "Any call up to this function and including this one will be trimmed out of stack traces."" @@ -157,6 +188,7 @@ async def _handle_cancel_operation_task( task_token: bytes, request: temporalio.api.nexus.v1.CancelOperationRequest, headers: Mapping[str, str], + task_cancellation: nexusrpc.handler.OperationTaskCancellation, ) -> None: """Handle a cancel operation task. @@ -168,6 +200,7 @@ async def _handle_cancel_operation_task( service=request.service, operation=request.operation, headers=headers, + task_cancellation=task_cancellation, ) temporalio.nexus._operation_context._TemporalCancelOperationContext( info=lambda: Info(task_queue=self._task_queue), @@ -177,6 +210,11 @@ async def _handle_cancel_operation_task( try: try: await self._handler.cancel_operation(ctx, request.operation_token) + except asyncio.CancelledError: + completion = temporalio.bridge.proto.nexus.NexusTaskCompletion( + task_token=task_token, + ack_cancel=task_cancellation.is_cancelled(), + ) except BaseException as err: logger.warning("Failed to execute Nexus cancel operation method") completion = temporalio.bridge.proto.nexus.NexusTaskCompletion( @@ -209,6 +247,7 @@ async def _handle_start_operation_task( task_token: bytes, start_request: temporalio.api.nexus.v1.StartOperationRequest, headers: Mapping[str, str], + task_cancellation: nexusrpc.handler.OperationTaskCancellation, ) -> None: """Handle a start operation task. @@ -217,7 +256,14 @@ async def _handle_start_operation_task( """ try: try: - start_response = await self._start_operation(start_request, headers) + start_response = await self._start_operation( + start_request, headers, task_cancellation + ) + except asyncio.CancelledError: + completion = temporalio.bridge.proto.nexus.NexusTaskCompletion( + task_token=task_token, + ack_cancel=task_cancellation.is_cancelled(), + ) except BaseException as err: logger.warning("Failed to execute Nexus start operation method") completion = temporalio.bridge.proto.nexus.NexusTaskCompletion( @@ -226,6 +272,7 @@ async def _handle_start_operation_task( _exception_to_handler_error(err) ), ) + if isinstance(err, concurrent.futures.BrokenExecutor): self._fail_worker_exception_queue.put_nowait(err) else: @@ -235,6 +282,7 @@ async def _handle_start_operation_task( start_operation=start_response ), ) + await self._bridge_worker().complete_nexus_task(completion) except Exception: logger.exception("Failed to send Nexus task completion") @@ -250,6 +298,7 @@ async def _start_operation( self, start_request: temporalio.api.nexus.v1.StartOperationRequest, headers: Mapping[str, str], + cancellation: nexusrpc.handler.OperationTaskCancellation, ) -> temporalio.api.nexus.v1.StartOperationResponse: """Invoke the Nexus handler's start_operation method and construct the StartOperationResponse. @@ -268,6 +317,7 @@ async def _start_operation( for link in start_request.links ], callback_headers=dict(start_request.callback_header), + task_cancellation=cancellation, ) temporalio.nexus._operation_context._TemporalStartOperationContext( nexus_context=ctx, @@ -517,3 +567,36 @@ def _exception_to_handler_error(err: BaseException) -> nexusrpc.HandlerError: ) handler_err.__cause__ = err return handler_err + + +class _NexusTaskCancellation(nexusrpc.handler.OperationTaskCancellation): + def __init__(self): + self._thread_evt = threading.Event() + self._aysnc_evt = asyncio.Event() + self._lock = threading.Lock() + self._reason: Optional[str] = None + + def is_cancelled(self) -> bool: + return self._thread_evt.is_set() + + def cancellation_reason(self) -> Optional[str]: + with self._lock: + return self._reason + + def wait_until_cancelled_sync(self, timeout: float | None = None) -> bool: + return self._thread_evt.wait(timeout) + + async def wait_until_cancelled(self) -> None: + await self._aysnc_evt.wait() + + def cancel(self, reason: Optional[str] = None) -> bool: + if self._thread_evt.is_set(): + return False + + with self._lock: + if self._thread_evt.is_set(): + return False + self._reason = reason + self._thread_evt.set() + self._aysnc_evt.set() + return True diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 1dc70c8b4..e84bb40ee 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -5417,6 +5417,22 @@ async def start_operation( headers: Optional[Mapping[str, str]] = None, ) -> NexusOperationHandle[OutputT]: ... + # Overload for operation_handler + @overload + @abstractmethod + async def start_operation( + self, + operation: Callable[ + [ServiceHandlerT], nexusrpc.handler.OperationHandler[InputT, OutputT] + ], + input: InputT, + *, + output_type: Optional[Type[OutputT]] = None, + schedule_to_close_timeout: Optional[timedelta] = None, + cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED, + headers: Optional[Mapping[str, str]] = None, + ) -> NexusOperationHandle[OutputT]: ... + @abstractmethod async def start_operation( self, @@ -5527,6 +5543,23 @@ async def execute_operation( headers: Optional[Mapping[str, str]] = None, ) -> OutputT: ... + # Overload for operation_handler + @overload + @abstractmethod + async def execute_operation( + self, + operation: Callable[ + [ServiceT], + nexusrpc.handler.OperationHandler[InputT, OutputT], + ], + input: InputT, + *, + output_type: Optional[Type[OutputT]] = None, + schedule_to_close_timeout: Optional[timedelta] = None, + cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED, + headers: Optional[Mapping[str, str]] = None, + ) -> OutputT: ... + @abstractmethod async def execute_operation( self, diff --git a/tests/contrib/openai_agents/research_agents/research_manager.py b/tests/contrib/openai_agents/research_agents/research_manager.py index de721f9b9..f37eb6293 100644 --- a/tests/contrib/openai_agents/research_agents/research_manager.py +++ b/tests/contrib/openai_agents/research_agents/research_manager.py @@ -4,6 +4,7 @@ from agents import Runner, custom_span, gen_trace_id, trace +import temporalio.workflow from tests.contrib.openai_agents.research_agents.planner_agent import ( WebSearchItem, WebSearchPlan, @@ -45,7 +46,7 @@ async def _perform_searches(self, search_plan: WebSearchPlan) -> list[str]: asyncio.create_task(self._search(item)) for item in search_plan.searches ] results = [] - for task in asyncio.as_completed(tasks): + for task in temporalio.workflow.as_completed(tasks): result = await task if result is not None: results.append(result) diff --git a/tests/nexus/test_dynamic_creation_of_user_handler_classes.py b/tests/nexus/test_dynamic_creation_of_user_handler_classes.py index 0eef14b84..bc4eacd27 100644 --- a/tests/nexus/test_dynamic_creation_of_user_handler_classes.py +++ b/tests/nexus/test_dynamic_creation_of_user_handler_classes.py @@ -44,20 +44,6 @@ async def cancel( ) -> None: raise NotImplementedError - async def fetch_info( - self, - ctx: nexusrpc.handler.FetchOperationInfoContext, - token: str, - ) -> nexusrpc.OperationInfo: - raise NotImplementedError - - async def fetch_result( - self, - ctx: nexusrpc.handler.FetchOperationResultContext, - token: str, - ) -> int: - raise NotImplementedError - @nexusrpc.handler.service_handler class MyServiceHandlerWithWorkflowRunOperation: @@ -78,8 +64,8 @@ async def test_run_nexus_service_from_programmatically_created_service_handler( service_handler = nexusrpc.handler._core.ServiceHandler( service=nexusrpc.ServiceDefinition( name="MyService", - operations={ - "increment": nexusrpc.Operation[int, int]( + operation_definitions={ + "increment": nexusrpc.OperationDefinition[int, int]( name="increment", method_name="increment", input_type=int, diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 1f3420da3..7ce0149f9 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -32,12 +32,9 @@ HandlerErrorType, OperationError, OperationErrorState, - OperationInfo, ) from nexusrpc.handler import ( CancelOperationContext, - FetchOperationInfoContext, - FetchOperationResultContext, OperationHandler, StartOperationContext, StartOperationResultSync, @@ -83,8 +80,6 @@ class NonSerializableOutput: # TODO(nexus-preview): type check nexus implementation under mypy # TODO(nexus-preview): test malformed inbound_links and outbound_links -# TODO(nexus-prerelease): 2025-07-02T23:29:20.000489Z WARN temporal_sdk_core::worker::nexus: Nexus task not found on completion. This may happen if the operation has already been cancelled but completed anyway. details=Status { code: NotFound, message: "Nexus task not found or already expired", details: b"\x08\x05\x12'Nexus task not found or already expired\x1aB\n@type.googleapis.com/temporal.api.errordetails.v1.NotFoundFailure", metadata: MetadataMap { headers: {"content-type": "application/grpc"} }, source: None } - @nexusrpc.service class MyService: @@ -261,16 +256,6 @@ async def start( # type: ignore[override] # intentional test error # or StartOperationResultAsync return Output(value="unwrapped result error") - async def fetch_info( - self, ctx: FetchOperationInfoContext, token: str - ) -> OperationInfo: - raise NotImplementedError - - async def fetch_result( - self, ctx: FetchOperationResultContext, token: str - ) -> Output: - raise NotImplementedError - async def cancel(self, ctx: CancelOperationContext, token: str) -> None: raise NotImplementedError @@ -779,7 +764,7 @@ async def test_start_operation_without_type_annotations( def test_operation_without_type_annotations_without_service_definition_raises_validation_error(): with pytest.raises( ValueError, - match=r"has no input type.+has no output type", + match=r"has no input type", ): service_handler(MyServiceHandlerWithOperationsWithoutTypeAnnotations) @@ -885,14 +870,6 @@ async def start( def cancel(self, ctx: CancelOperationContext, token: str) -> None: return None # type: ignore - def fetch_info( - self, ctx: FetchOperationInfoContext, token: str - ) -> OperationInfo: - raise NotImplementedError - - def fetch_result(self, ctx: FetchOperationResultContext, token: str) -> Output: - raise NotImplementedError - @operation_handler def echo(self) -> OperationHandler[Input, Output]: return SyncCancelHandler.SyncCancel() @@ -1084,7 +1061,7 @@ async def start_two_workflows_with_conflicting_workflow_ids( assert status_code == 201 op_info = resp.json() assert op_info["token"] - assert op_info["state"] == nexusrpc.OperationState.RUNNING.value + assert op_info["state"] == "running" else: assert status_code >= 400 failure = Failure(**resp.json()) diff --git a/tests/nexus/test_handler_async_operation.py b/tests/nexus/test_handler_async_operation.py index df245d0ff..bf8099491 100644 --- a/tests/nexus/test_handler_async_operation.py +++ b/tests/nexus/test_handler_async_operation.py @@ -12,14 +12,9 @@ from dataclasses import dataclass, field from typing import Any, Type, Union -import nexusrpc -import nexusrpc.handler import pytest -from nexusrpc import OperationInfo from nexusrpc.handler import ( CancelOperationContext, - FetchOperationInfoContext, - FetchOperationResultContext, OperationHandler, StartOperationContext, StartOperationResultAsync, @@ -57,23 +52,6 @@ async def task() -> Output: await self.executor.add_task(task_id, task()) return StartOperationResultAsync(token=task_id) - async def fetch_info( - self, ctx: FetchOperationInfoContext, token: str - ) -> OperationInfo: - # status = self.executor.get_task_status(task_id=token) - # return OperationInfo(token=token, status=status) - raise NotImplementedError( - "Not possible to test this currently since the server's Nexus implementation does not support fetch_info" - ) - - async def fetch_result( - self, ctx: FetchOperationResultContext, token: str - ) -> Output: - # return await self.executor.get_task_result(task_id=token) - raise NotImplementedError( - "Not possible to test this currently since the server's Nexus implementation does not support fetch_result" - ) - async def cancel(self, ctx: CancelOperationContext, token: str) -> None: self.executor.request_cancel_task(task_id=token) @@ -93,19 +71,6 @@ async def task() -> Output: self.executor.add_task_sync(task_id, task()) return StartOperationResultAsync(token=task_id) - def fetch_info(self, ctx: FetchOperationInfoContext, token: str) -> OperationInfo: - # status = self.executor.get_task_status(task_id=token) - # return OperationInfo(token=token, status=status) - raise NotImplementedError( - "Not possible to test this currently since the server's Nexus implementation does not support fetch_info" - ) - - def fetch_result(self, ctx: FetchOperationResultContext, token: str) -> Output: - # return self.executor.get_task_result_sync(task_id=token) - raise NotImplementedError( - "Not possible to test this currently since the server's Nexus implementation does not support fetch_result" - ) - def cancel(self, ctx: CancelOperationContext, token: str) -> None: self.executor.request_cancel_task(task_id=token) @@ -213,17 +178,6 @@ def add_task_sync(self, task_id: str, coro: Coroutine[Any, Any, Any]) -> None: self.add_task(task_id, coro), self.event_loop ).result() - def get_task_status(self, task_id: str) -> nexusrpc.OperationState: - task = self.tasks[task_id] - if not task.done(): - return nexusrpc.OperationState.RUNNING - elif task.cancelled(): - return nexusrpc.OperationState.CANCELED - elif task.exception(): - return nexusrpc.OperationState.FAILED - else: - return nexusrpc.OperationState.SUCCEEDED - async def get_task_result(self, task_id: str) -> Any: """ Get the result of a task from the task execution platform. diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 743f2b3e0..d53c178fa 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -11,8 +11,6 @@ import pytest from nexusrpc.handler import ( CancelOperationContext, - FetchOperationInfoContext, - FetchOperationResultContext, OperationHandler, StartOperationContext, StartOperationResultAsync, @@ -174,16 +172,6 @@ async def start( # type: ignore[override] async def cancel(self, ctx: CancelOperationContext, token: str) -> None: return await temporalio.nexus._operation_handlers._cancel_workflow(token) - async def fetch_info( - self, ctx: FetchOperationInfoContext, token: str - ) -> nexusrpc.OperationInfo: - raise NotImplementedError - - async def fetch_result( - self, ctx: FetchOperationResultContext, token: str - ) -> OpOutput: - raise NotImplementedError - @service_handler(service=ServiceInterface) class ServiceImpl: @@ -253,12 +241,14 @@ def __init__( request_cancel: bool, task_queue: str, ) -> None: - self.nexus_client = workflow.create_nexus_client( - service={ - CallerReference.IMPL_WITH_INTERFACE: ServiceImpl, - CallerReference.INTERFACE: ServiceInterface, - }[input.op_input.caller_reference], - endpoint=make_nexus_endpoint_name(task_queue), + self.nexus_client: workflow.NexusClient[ServiceInterface] = ( + workflow.create_nexus_client( + service={ + CallerReference.IMPL_WITH_INTERFACE: ServiceImpl, + CallerReference.INTERFACE: ServiceInterface, + }[input.op_input.caller_reference], + endpoint=make_nexus_endpoint_name(task_queue), + ) ) self._nexus_operation_start_resolved = False self._proceed = False diff --git a/tests/nexus/test_workflow_caller_errors.py b/tests/nexus/test_workflow_caller_errors.py index 2bff390da..0fc32ba5b 100644 --- a/tests/nexus/test_workflow_caller_errors.py +++ b/tests/nexus/test_workflow_caller_errors.py @@ -2,21 +2,22 @@ import asyncio import concurrent.futures +import threading import uuid from collections import Counter from dataclasses import dataclass from datetime import timedelta +from logging import getLogger import nexusrpc import nexusrpc.handler import pytest from nexusrpc.handler import ( CancelOperationContext, - FetchOperationInfoContext, - FetchOperationResultContext, OperationHandler, StartOperationContext, StartOperationResultAsync, + operation_handler, service_handler, sync_operation, ) @@ -33,11 +34,13 @@ ) from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker -from tests.helpers import assert_eq_eventually +from tests.helpers import LogCapturer, assert_eq_eventually from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name operation_invocation_counts = Counter[str]() +logger = getLogger(__name__) + @dataclass class ErrorTestInput: @@ -229,29 +232,48 @@ async def test_nexus_operation_fails_without_retry_as_handler_error( pytest.fail("Unreachable") +_start_operation_sync_complete = threading.Event() + + # Start timeout test @service_handler class StartTimeoutTestService: @sync_operation - async def op_handler_that_never_returns( + async def expect_timeout_cancellation_async( self, ctx: StartOperationContext, input: None ) -> None: - await asyncio.Future() + try: + await asyncio.wait_for(ctx.task_cancellation.wait_until_cancelled(), 1) + except asyncio.TimeoutError: + raise ApplicationError("expected cancel", non_retryable=True) + + @sync_operation + def expect_timeout_cancellation_sync( + self, ctx: StartOperationContext, input: None + ) -> None: + global _start_operation_sync_complete + cancelled = ctx.task_cancellation.wait_until_cancelled_sync(1) + if not cancelled: + raise ApplicationError("expected cancel", non_retryable=True) + reason = ctx.task_cancellation.cancellation_reason() + if reason != "TIMED_OUT": + logger.error("unexpected cancellation reason: %s", reason) + _start_operation_sync_complete.set() @workflow.defn class StartTimeoutTestCallerWorkflow: @workflow.init - def __init__(self): + def __init__(self, operation: str): self.nexus_client = workflow.create_nexus_client( service=StartTimeoutTestService, endpoint=make_nexus_endpoint_name(workflow.info().task_queue), ) @workflow.run - async def run(self) -> None: + async def run(self, operation: str) -> None: await self.nexus_client.execute_operation( - StartTimeoutTestService.op_handler_that_never_returns, # type: ignore[arg-type] # mypy can't infer OutputT=None in Union type + operation, None, output_type=None, schedule_to_close_timeout=timedelta(seconds=0.1), @@ -263,6 +285,7 @@ async def test_error_raised_by_timeout_of_nexus_start_operation( ): if env.supports_time_skipping: pytest.skip("Nexus tests don't work with time-skipping server") + global _start_operation_sync_complete task_queue = str(uuid.uuid4()) async with Worker( @@ -270,11 +293,13 @@ async def test_error_raised_by_timeout_of_nexus_start_operation( nexus_service_handlers=[StartTimeoutTestService()], workflows=[StartTimeoutTestCallerWorkflow], task_queue=task_queue, + nexus_task_executor=concurrent.futures.ThreadPoolExecutor(), ): await create_nexus_endpoint(task_queue, client) try: await client.execute_workflow( StartTimeoutTestCallerWorkflow.run, + "expect_timeout_cancellation_async", id=str(uuid.uuid4()), task_queue=task_queue, ) @@ -285,35 +310,51 @@ async def test_error_raised_by_timeout_of_nexus_start_operation( else: pytest.fail("Expected exception due to timeout of nexus start operation") + with LogCapturer().logs_captured(logger) as capturer: + try: + await client.execute_workflow( + StartTimeoutTestCallerWorkflow.run, + "expect_timeout_cancellation_sync", + id=str(uuid.uuid4()), + task_queue=task_queue, + ) + except Exception as err: + assert isinstance(err, WorkflowFailureError) + assert isinstance(err.__cause__, NexusOperationError) + assert isinstance(err.__cause__.__cause__, TimeoutError) + else: + pytest.fail( + "Expected exception due to timeout of nexus start operation" + ) + + _start_operation_sync_complete.wait() + assert capturer.find_log("unexpected cancellation reason") is None + # Cancellation timeout test -class OperationWithCancelMethodThatNeverReturns(OperationHandler[None, None]): +class OperationWithCancelMethodThatExpectsCancel(OperationHandler[None, None]): async def start( self, ctx: StartOperationContext, input: None ) -> StartOperationResultAsync: return StartOperationResultAsync("fake-token") async def cancel(self, ctx: CancelOperationContext, token: str) -> None: - await asyncio.Future() - - async def fetch_info( - self, ctx: FetchOperationInfoContext, token: str - ) -> nexusrpc.OperationInfo: - raise NotImplementedError("Not implemented") - - async def fetch_result(self, ctx: FetchOperationResultContext, token: str) -> None: - raise NotImplementedError("Not implemented") + try: + await asyncio.wait_for(ctx.task_cancellation.wait_until_cancelled(), 1) + except asyncio.TimeoutError: + logger.error("expected cancellation") + raise ApplicationError("expected cancellation", non_retryable=True) @service_handler class CancellationTimeoutTestService: - @nexusrpc.handler._decorators.operation_handler - def op_with_cancel_method_that_never_returns( + @operation_handler + def op_with_cancel_method_that_expects_cancel( self, ) -> OperationHandler[None, None]: - return OperationWithCancelMethodThatNeverReturns() + return OperationWithCancelMethodThatExpectsCancel() @workflow.defn @@ -327,13 +368,8 @@ def __init__(self): @workflow.run async def run(self) -> None: - # TODO(nexus-prerelease) op_handle = await self.nexus_client.start_operation( - # Although the tests are making use of it, we are not exposing operation - # factory methods to users as a way to write nexus operations, and so the - # types on NexusClient start_operation/execute_operation do not currently - # permit it. - CancellationTimeoutTestService.op_with_cancel_method_that_never_returns, # type: ignore + CancellationTimeoutTestService.op_with_cancel_method_that_expects_cancel, None, schedule_to_close_timeout=timedelta(seconds=0.1), ) @@ -347,7 +383,6 @@ async def test_error_raised_by_timeout_of_nexus_cancel_operation( if env.supports_time_skipping: pytest.skip("Nexus tests don't work with time-skipping server") - pytest.skip("TODO(nexus-prerelease): finish writing this test") task_queue = str(uuid.uuid4()) async with Worker( client, @@ -355,16 +390,21 @@ async def test_error_raised_by_timeout_of_nexus_cancel_operation( workflows=[CancellationTimeoutTestCallerWorkflow], task_queue=task_queue, ): - await create_nexus_endpoint(task_queue, client) - try: - await client.execute_workflow( - CancellationTimeoutTestCallerWorkflow.run, - id=str(uuid.uuid4()), - task_queue=task_queue, - ) - except Exception as err: - assert isinstance(err, WorkflowFailureError) - assert isinstance(err.__cause__, NexusOperationError) - assert isinstance(err.__cause__.__cause__, TimeoutError) - else: - pytest.fail("Expected exception due to timeout of nexus cancel operation") + with LogCapturer().logs_captured(logger) as capturer: + await create_nexus_endpoint(task_queue, client) + try: + await client.execute_workflow( + CancellationTimeoutTestCallerWorkflow.run, + id=str(uuid.uuid4()), + task_queue=task_queue, + ) + except Exception as err: + assert isinstance(err, WorkflowFailureError) + assert isinstance(err.__cause__, NexusOperationError) + assert isinstance(err.__cause__.__cause__, TimeoutError) + else: + pytest.fail( + "Expected exception due to timeout of nexus cancel operation" + ) + + assert capturer.find_log("expected cancellation") is None diff --git a/uv.lock b/uv.lock index 68f2f73bd..e11decd03 100644 --- a/uv.lock +++ b/uv.lock @@ -1761,14 +1761,10 @@ wheels = [ [[package]] name = "nexus-rpc" version = "1.1.0" -source = { registry = "https://pypi.org/simple" } +source = { git = "https://github.com/nexus-rpc/sdk-python?rev=task-cancellation#847d9cb903d962e95d147de26b02fb1e412ee7cb" } dependencies = [ { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ef/66/540687556bd28cf1ec370cc6881456203dfddb9dab047b8979c6865b5984/nexus_rpc-1.1.0.tar.gz", hash = "sha256:d65ad6a2f54f14e53ebe39ee30555eaeb894102437125733fb13034a04a44553", size = 77383, upload-time = "2025-07-07T19:03:58.368Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/bf/2f/9e9d0dcaa4c6ffa22b7aa31069a8a264c753ff8027b36af602cce038c92f/nexus_rpc-1.1.0-py3-none-any.whl", hash = "sha256:d1b007af2aba186a27e736f8eaae39c03aed05b488084ff6c3d1785c9ba2ad38", size = 27743, upload-time = "2025-07-07T19:03:57.556Z" }, -] [[package]] name = "nh3" @@ -3007,7 +3003,7 @@ dev = [ requires-dist = [ { name = "grpcio", marker = "extra == 'grpc'", specifier = ">=1.48.2,<2" }, { name = "mcp", marker = "extra == 'openai-agents'", specifier = ">=1.9.4,<2" }, - { name = "nexus-rpc", specifier = "==1.1.0" }, + { name = "nexus-rpc", git = "https://github.com/nexus-rpc/sdk-python?rev=task-cancellation" }, { name = "openai-agents", marker = "extra == 'openai-agents'", specifier = ">=0.3,<0.5" }, { name = "opentelemetry-api", marker = "extra == 'opentelemetry'", specifier = ">=1.11.1,<2" }, { name = "opentelemetry-sdk", marker = "extra == 'opentelemetry'", specifier = ">=1.11.1,<2" },