diff --git a/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/base.py b/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/base.py index f7991a6..d9c34ba 100644 --- a/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/base.py +++ b/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/base.py @@ -37,6 +37,16 @@ def process( """Process an operation update and return the transformed operation.""" raise NotImplementedError + def _get_start_time( + self, current_operation: Operation | None + ) -> datetime.datetime | None: + start_time: datetime.datetime | None = ( + current_operation.start_timestamp + if current_operation + else datetime.datetime.now(tz=datetime.UTC) + ) + return start_time + def _get_end_time( self, current_operation: Operation | None, status: OperationStatus ) -> datetime.datetime | None: @@ -116,22 +126,6 @@ def _create_invoke_details( return ChainedInvokeDetails(result=update.payload, error=update.error) return None - def _create_wait_details( - self, update: OperationUpdate, current_operation: Operation | None - ) -> WaitDetails | None: - """Create WaitDetails from OperationUpdate.""" - if update.operation_type == OperationType.WAIT and update.wait_options: - if current_operation and current_operation.wait_details: - scheduled_end_timestamp = ( - current_operation.wait_details.scheduled_end_timestamp - ) - else: - scheduled_end_timestamp = datetime.datetime.now( - tz=datetime.UTC - ) + timedelta(seconds=update.wait_options.wait_seconds) - return WaitDetails(scheduled_end_timestamp=scheduled_end_timestamp) - return None - def _translate_update_to_operation( self, update: OperationUpdate, @@ -139,12 +133,10 @@ def _translate_update_to_operation( status: OperationStatus, ) -> Operation: """Transform OperationUpdate to Operation, always creating new Operation.""" - start_time = ( - current_operation.start_timestamp - if current_operation - else datetime.datetime.now(tz=datetime.UTC) + 
start_time: datetime.datetime | None = self._get_start_time(current_operation) + end_time: datetime.datetime | None = self._get_end_time( + current_operation, status ) - end_time = self._get_end_time(current_operation, status) execution_details = self._create_execution_details(update) context_details = self._create_context_details(update) @@ -169,3 +161,19 @@ def _translate_update_to_operation( chained_invoke_details=invoke_details, wait_details=wait_details, ) + + def _create_wait_details( + self, update: OperationUpdate, current_operation: Operation | None + ) -> WaitDetails | None: + """Create WaitDetails from OperationUpdate.""" + if update.operation_type == OperationType.WAIT and update.wait_options: + if current_operation and current_operation.wait_details: + scheduled_end_timestamp = ( + current_operation.wait_details.scheduled_end_timestamp + ) + else: + scheduled_end_timestamp = datetime.datetime.now( + tz=datetime.UTC + ) + timedelta(seconds=update.wait_options.wait_seconds) + return WaitDetails(scheduled_end_timestamp=scheduled_end_timestamp) + return None diff --git a/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/callback.py b/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/callback.py index c47b5ec..c1a0ec7 100644 --- a/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/callback.py +++ b/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/callback.py @@ -2,6 +2,7 @@ from __future__ import annotations +import datetime from typing import TYPE_CHECKING from aws_durable_execution_sdk_python.lambda_service import ( @@ -9,15 +10,16 @@ OperationAction, OperationStatus, OperationUpdate, + CallbackDetails, + OperationType, ) - from aws_durable_execution_sdk_python_testing.checkpoint.processors.base import ( OperationProcessor, ) from aws_durable_execution_sdk_python_testing.exceptions import ( InvalidParameterValueException, ) - +from aws_durable_execution_sdk_python_testing.token import 
CallbackToken if TYPE_CHECKING: from aws_durable_execution_sdk_python_testing.observer import ExecutionNotifier @@ -36,14 +38,46 @@ def process( """Process CALLBACK operation update with scheduler integration for activities.""" match update.action: case OperationAction.START: - # TODO: create CallbackToken (see token module). Add Observer/Notifier for on_callback_created possibly, - # but token might well have enough so don't need to maintain token list on execution itself - return self._translate_update_to_operation( - update=update, - current_operation=current_op, - status=OperationStatus.STARTED, + callback_token: CallbackToken = CallbackToken( + execution_arn=execution_arn, + operation_id=update.operation_id, + ) + + notifier.notify_callback_created( + execution_arn=execution_arn, + operation_id=update.operation_id, + callback_token=callback_token, ) + + callback_id: str = callback_token.to_str() + + callback_details: CallbackDetails | None = ( + CallbackDetails( + callback_id=callback_id, + result=update.payload, + error=update.error, + ) + if update.operation_type == OperationType.CALLBACK + else None + ) + status: OperationStatus = OperationStatus.STARTED + start_time: datetime.datetime | None = self._get_start_time(current_op) + end_time: datetime.datetime | None = self._get_end_time( + current_op, status + ) + operation: Operation = Operation( + operation_id=update.operation_id, + parent_id=update.parent_id, + name=update.name, + start_timestamp=start_time, + end_timestamp=end_time, + operation_type=update.operation_type, + status=status, + sub_type=update.sub_type, + callback_details=callback_details, + ) + + return operation case _: msg: str = "Invalid action for CALLBACK operation." 
- raise InvalidParameterValueException(msg) diff --git a/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/callback.py b/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/callback.py index fb5317b..575db81 100644 --- a/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/callback.py +++ b/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/callback.py @@ -17,7 +17,6 @@ VALID_ACTIONS_FOR_CALLBACK = frozenset( [ OperationAction.START, - OperationAction.CANCEL, ] ) @@ -41,14 +40,6 @@ def validate(current_state: Operation | None, update: OperationUpdate) -> None: "Cannot start a CALLBACK that already exist." ) raise InvalidParameterValueException(msg_callback_exists) - case OperationAction.CANCEL: - if ( - current_state is None - or current_state.status - not in CallbackOperationValidator._ALLOWED_STATUS_TO_CANCEL - ): - msg_callback_cancel: str = "Cannot cancel a CALLBACK that does not exist or has already completed." - raise InvalidParameterValueException(msg_callback_cancel) case _: - msg_callback_invalid: str = "Invalid CALLBACK action." - raise InvalidParameterValueException(msg_callback_invalid) + msg: str = "Invalid action for CALLBACK operation." 
+ raise InvalidParameterValueException(msg) diff --git a/src/aws_durable_execution_sdk_python_testing/execution.py b/src/aws_durable_execution_sdk_python_testing/execution.py index 17f99ef..24e3f81 100644 --- a/src/aws_durable_execution_sdk_python_testing/execution.py +++ b/src/aws_durable_execution_sdk_python_testing/execution.py @@ -28,7 +28,10 @@ from aws_durable_execution_sdk_python_testing.model import ( StartDurableExecutionInput, ) -from aws_durable_execution_sdk_python_testing.token import CheckpointToken +from aws_durable_execution_sdk_python_testing.token import ( + CheckpointToken, + CallbackToken, +) class Execution: @@ -203,6 +206,18 @@ def find_operation(self, operation_id: str) -> tuple[int, Operation]: msg: str = f"Attempting to update state of an Operation [{operation_id}] that doesn't exist" raise IllegalStateException(msg) + def find_callback_operation(self, callback_id: str) -> tuple[int, Operation]: + """Find callback operation by callback_id, return index and operation.""" + for i, operation in enumerate(self.operations): + if ( + operation.operation_type == OperationType.CALLBACK + and operation.callback_details + and operation.callback_details.callback_id == callback_id + ): + return i, operation + msg: str = f"Callback operation with callback_id [{callback_id}] not found" + raise IllegalStateException(msg) + def complete_wait(self, operation_id: str) -> Operation: """Complete WAIT operation when timer fires.""" index, operation = self.find_operation(operation_id) @@ -260,3 +275,55 @@ def complete_retry(self, operation_id: str) -> Operation: # Assign self.operations[index] = updated_operation return updated_operation + + def complete_callback_success( + self, callback_id: str, result: bytes | None = None + ) -> Operation: + """Complete CALLBACK operation with success.""" + index, operation = self.find_callback_operation(callback_id) + if operation.status != OperationStatus.STARTED: + msg: str = f"Callback operation [{callback_id}] is not in 
STARTED state" + raise IllegalStateException(msg) + + with self._state_lock: + self._token_sequence += 1 + updated_callback_details = None + if operation.callback_details: + updated_callback_details = replace( + operation.callback_details, + result=result.decode() if result else None, + ) + + self.operations[index] = replace( + operation, + status=OperationStatus.SUCCEEDED, + end_timestamp=datetime.now(UTC), + callback_details=updated_callback_details, + ) + return self.operations[index] + + def complete_callback_failure( + self, callback_id: str, error: ErrorObject + ) -> Operation: + """Complete CALLBACK operation with failure.""" + index, operation = self.find_callback_operation(callback_id) + + if operation.status != OperationStatus.STARTED: + msg: str = f"Callback operation [{callback_id}] is not in STARTED state" + raise IllegalStateException(msg) + + with self._state_lock: + self._token_sequence += 1 + updated_callback_details = None + if operation.callback_details: + updated_callback_details = replace( + operation.callback_details, error=error + ) + + self.operations[index] = replace( + operation, + status=OperationStatus.FAILED, + end_timestamp=datetime.now(UTC), + callback_details=updated_callback_details, + ) + return self.operations[index] diff --git a/src/aws_durable_execution_sdk_python_testing/executor.py b/src/aws_durable_execution_sdk_python_testing/executor.py index 6f07ebe..12b4352 100644 --- a/src/aws_durable_execution_sdk_python_testing/executor.py +++ b/src/aws_durable_execution_sdk_python_testing/executor.py @@ -13,6 +13,7 @@ InvocationStatus, ) from aws_durable_execution_sdk_python.lambda_service import ( + CallbackTimeoutType, ErrorObject, Operation, OperationUpdate, @@ -51,6 +52,7 @@ Execution as ExecutionSummary, ) from aws_durable_execution_sdk_python_testing.observer import ExecutionObserver +from aws_durable_execution_sdk_python_testing.token import CallbackToken if TYPE_CHECKING: @@ -82,6 +84,10 @@ def __init__( self._invoker = 
invoker self._checkpoint_processor = checkpoint_processor self._completion_events: dict[str, Event] = {} + self._callback_timeouts: dict[str, Event] = {} # callback_id -> timeout event + self._callback_heartbeats: dict[ + str, Event + ] = {} # callback_id -> heartbeat event def start_execution( self, @@ -613,7 +619,7 @@ def checkpoint_execution( def send_callback_success( self, callback_id: str, - result: bytes | None = None, # noqa: ARG002 + result: bytes | None = None, ) -> SendDurableExecutionCallbackSuccessResponse: """Send callback success response. @@ -632,16 +638,23 @@ def send_callback_success( msg: str = "callback_id is required" raise InvalidParameterValueException(msg) - # TODO: Implement actual callback success logic - # This would involve finding the callback operation and completing it - logger.info("Callback success sent for callback_id: %s", callback_id) + try: + callback_token = CallbackToken.from_str(callback_id) + execution = self.get_execution(callback_token.execution_arn) + execution.complete_callback_success(callback_id, result) + self._store.update(execution) + self._cleanup_callback_timeouts(callback_id) + logger.info("Callback success completed for callback_id: %s", callback_id) + except Exception as e: + msg = f"Failed to process callback success: {e}" + raise ResourceNotFoundException(msg) from e return SendDurableExecutionCallbackSuccessResponse() def send_callback_failure( self, callback_id: str, - error: ErrorObject | None = None, # noqa: ARG002 + error: ErrorObject | None = None, ) -> SendDurableExecutionCallbackFailureResponse: """Send callback failure response. 
@@ -660,9 +673,18 @@ def send_callback_failure( msg: str = "callback_id is required" raise InvalidParameterValueException(msg) - # TODO: Implement actual callback failure logic - # This would involve finding the callback operation and failing it - logger.info("Callback failure sent for callback_id: %s", callback_id) + callback_error: ErrorObject = error or ErrorObject.from_message("") + + try: + callback_token: CallbackToken = CallbackToken.from_str(callback_id) + execution: Execution = self.get_execution(callback_token.execution_arn) + execution.complete_callback_failure(callback_id, callback_error) + self._store.update(execution) + self._cleanup_callback_timeouts(callback_id) + logger.info("Callback failure completed for callback_id: %s", callback_id) + except Exception as e: + msg = f"Failed to process callback failure: {e}" + raise ResourceNotFoundException(msg) from e return SendDurableExecutionCallbackFailureResponse() @@ -685,9 +707,24 @@ def send_callback_heartbeat( msg: str = "callback_id is required" raise InvalidParameterValueException(msg) - # TODO: Implement actual callback heartbeat logic - # This would involve updating the callback timeout - logger.info("Callback heartbeat sent for callback_id: %s", callback_id) + try: + callback_token: CallbackToken = CallbackToken.from_str(callback_id) + execution: Execution = self.get_execution(callback_token.execution_arn) + + # Find callback operation to verify it exists and is active + _, operation = execution.find_callback_operation(callback_id) + if operation.status != OperationStatus.STARTED: + msg = f"Callback {callback_id} is not active" + raise ResourceNotFoundException(msg) + + # Reset heartbeat timeout if configured + self._reset_callback_heartbeat_timeout( + callback_id, execution.durable_execution_arn + ) + logger.info("Callback heartbeat processed for callback_id: %s", callback_id) + except Exception as e: + msg = f"Failed to process callback heartbeat: {e}" + raise ResourceNotFoundException(msg) 
from e return SendDurableExecutionCallbackHeartbeatResponse() @@ -1001,4 +1038,197 @@ def retry_handler() -> None: retry_handler, delay=delay, completion_event=completion_event ) + def on_callback_created( + self, execution_arn: str, operation_id: str, callback_token: CallbackToken + ) -> None: + """Handle callback creation. Observer method triggered by notifier.""" + callback_id = callback_token.to_str() + logger.debug( + "[%s] Callback created for operation %s with callback_id: %s", + execution_arn, + operation_id, + callback_id, + ) + + # Schedule callback timeouts if configured + self._schedule_callback_timeouts(execution_arn, operation_id, callback_id) + # endregion ExecutionObserver + + # region Callback Timeouts + def _schedule_callback_timeouts( + self, execution_arn: str, operation_id: str, callback_id: str + ) -> None: + """Schedule callback timeout and heartbeat timeout if configured.""" + try: + execution = self.get_execution(execution_arn) + _, operation = execution.find_operation(operation_id) + + if not operation.callback_details: + return + + # Find the callback options from the operation update that created this callback + # We need to look at the checkpoint updates to find the original callback options + callback_options = None + for update in execution.updates: + if ( + update.operation_id == operation_id + and update.callback_options + and update.action.value == "START" + ): + callback_options = update.callback_options + break + + if not callback_options: + return + + completion_event = self._completion_events.get(execution_arn) + + # Schedule main timeout if configured + if callback_options.timeout_seconds > 0: + + def timeout_handler(): + self._on_callback_timeout(execution_arn, callback_id) + + timeout_event = self._scheduler.create_event() + self._callback_timeouts[callback_id] = timeout_event + self._scheduler.call_later( + timeout_handler, + delay=callback_options.timeout_seconds, + completion_event=completion_event, + ) + + # Schedule 
heartbeat timeout if configured + if callback_options.heartbeat_timeout_seconds > 0: + + def heartbeat_timeout_handler(): + self._on_callback_heartbeat_timeout(execution_arn, callback_id) + + heartbeat_event = self._scheduler.create_event() + self._callback_heartbeats[callback_id] = heartbeat_event + self._scheduler.call_later( + heartbeat_timeout_handler, + delay=callback_options.heartbeat_timeout_seconds, + completion_event=completion_event, + ) + + except Exception: + logger.exception( + "[%s] Error scheduling callback timeouts for %s", + execution_arn, + callback_id, + ) + + def _reset_callback_heartbeat_timeout( + self, callback_id: str, execution_arn: str + ) -> None: + """Reset the heartbeat timeout for a callback.""" + # Cancel existing heartbeat timeout + if heartbeat_event := self._callback_heartbeats.get(callback_id): + heartbeat_event.remove() + del self._callback_heartbeats[callback_id] + + # Find callback options to reschedule heartbeat timeout + try: + callback_token = CallbackToken.from_str(callback_id) + execution = self.get_execution(callback_token.execution_arn) + + # Find callback options from updates + callback_options = None + for update in execution.updates: + if ( + update.operation_id == callback_token.operation_id + and update.callback_options + and update.action.value == "START" + ): + callback_options = update.callback_options + break + + if callback_options and callback_options.heartbeat_timeout_seconds > 0: + + def heartbeat_timeout_handler(): + self._on_callback_heartbeat_timeout(execution_arn, callback_id) + + completion_event = self._completion_events.get(execution_arn) + heartbeat_event = self._scheduler.create_event() + self._callback_heartbeats[callback_id] = heartbeat_event + self._scheduler.call_later( + heartbeat_timeout_handler, + delay=callback_options.heartbeat_timeout_seconds, + completion_event=completion_event, + ) + except Exception: + logger.exception( + "[%s] Error resetting callback heartbeat timeout for %s", + 
execution_arn, + callback_id, + ) + + def _cleanup_callback_timeouts(self, callback_id: str) -> None: + """Clean up timeout events for a completed callback.""" + # Clean up main timeout + if timeout_event := self._callback_timeouts.get(callback_id): + timeout_event.remove() + del self._callback_timeouts[callback_id] + + # Clean up heartbeat timeout + if heartbeat_event := self._callback_heartbeats.get(callback_id): + heartbeat_event.remove() + del self._callback_heartbeats[callback_id] + + def _on_callback_timeout(self, execution_arn: str, callback_id: str) -> None: + """Handle callback timeout.""" + try: + callback_token = CallbackToken.from_str(callback_id) + execution = self.get_execution(callback_token.execution_arn) + + if execution.is_complete: + return + + # Fail the callback with timeout error + timeout_error = ErrorObject.from_message( + f"Callback timed out: {CallbackTimeoutType.TIMEOUT.value}" + ) + execution.complete_callback_failure(callback_id, timeout_error) + self._store.update(execution) + self._invoke_execution(execution_arn) + + logger.warning("[%s] Callback %s timed out", execution_arn, callback_id) + except Exception: + logger.exception( + "[%s] Error processing callback timeout for %s", + execution_arn, + callback_id, + ) + + def _on_callback_heartbeat_timeout( + self, execution_arn: str, callback_id: str + ) -> None: + """Handle callback heartbeat timeout.""" + try: + callback_token = CallbackToken.from_str(callback_id) + execution = self.get_execution(callback_token.execution_arn) + + if execution.is_complete: + return + + # Fail the callback with heartbeat timeout error + + heartbeat_error = ErrorObject.from_message( + f"Callback heartbeat timed out: {CallbackTimeoutType.HEARTBEAT.value}" + ) + execution.complete_callback_failure(callback_id, heartbeat_error) + self._store.update(execution) + self._invoke_execution(execution_arn) + + logger.warning( + "[%s] Callback %s heartbeat timed out", execution_arn, callback_id + ) + except Exception: 
+ logger.exception( + "[%s] Error processing callback heartbeat timeout for %s", + execution_arn, + callback_id, + ) + + # endregion Callback Timeouts diff --git a/src/aws_durable_execution_sdk_python_testing/observer.py b/src/aws_durable_execution_sdk_python_testing/observer.py index eb4aa5a..2473896 100644 --- a/src/aws_durable_execution_sdk_python_testing/observer.py +++ b/src/aws_durable_execution_sdk_python_testing/observer.py @@ -6,6 +6,7 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING +from aws_durable_execution_sdk_python_testing.token import CallbackToken if TYPE_CHECKING: from collections.abc import Callable @@ -36,6 +37,12 @@ def on_step_retry_scheduled( ) -> None: """Called when step retry scheduled.""" + @abstractmethod + def on_callback_created( + self, execution_arn: str, operation_id: str, callback_token: CallbackToken + ) -> None: + """Called when callback is created.""" + class ExecutionNotifier: """Notifies observers about execution events. Thread-safe.""" @@ -91,4 +98,15 @@ def notify_step_retry_scheduled( delay=delay, ) + def notify_callback_created( + self, execution_arn: str, operation_id: str, callback_token: CallbackToken + ) -> None: + """Notify observers about callback creation.""" + self._notify_observers( + ExecutionObserver.on_callback_created, + execution_arn=execution_arn, + operation_id=operation_id, + callback_token=callback_token, + ) + # endregion event emitters diff --git a/tests/checkpoint/validators/checkpoint_test.py b/tests/checkpoint/validators/checkpoint_test.py index 6777d74..548e988 100644 --- a/tests/checkpoint/validators/checkpoint_test.py +++ b/tests/checkpoint/validators/checkpoint_test.py @@ -353,27 +353,6 @@ def test_validate_operation_status_transition_wait(): CheckpointValidator.validate_input(updates, execution) -def test_validate_operation_status_transition_callback(): - """Test validation calls callback validator for CALLBACK operations.""" - execution = _create_test_execution() - - 
callback_op = Operation( - operation_id="callback-1", - operation_type=OperationType.CALLBACK, - status=OperationStatus.STARTED, - ) - execution.operations.append(callback_op) - - updates = [ - OperationUpdate( - operation_id="callback-1", - operation_type=OperationType.CALLBACK, - action=OperationAction.CANCEL, - ) - ] - CheckpointValidator.validate_input(updates, execution) - - def test_validate_operation_status_transition_invoke(): """Test validation calls invoke validator for INVOKE operations.""" execution = _create_test_execution() diff --git a/tests/checkpoint/validators/operations/callback_test.py b/tests/checkpoint/validators/operations/callback_test.py index f497f51..93f6dd3 100644 --- a/tests/checkpoint/validators/operations/callback_test.py +++ b/tests/checkpoint/validators/operations/callback_test.py @@ -47,21 +47,6 @@ def test_validate_start_action_with_existing_state(): CallbackOperationValidator.validate(current_state, update) -def test_validate_cancel_action_with_started_state(): - """Test CANCEL action with STARTED state.""" - current_state = Operation( - operation_id="test-id", - operation_type=OperationType.CALLBACK, - status=OperationStatus.STARTED, - ) - update = OperationUpdate( - operation_id="test-id", - operation_type=OperationType.CALLBACK, - action=OperationAction.CANCEL, - ) - CallbackOperationValidator.validate(current_state, update) - - def test_validate_cancel_action_with_no_current_state(): """Test CANCEL action with no current state raises error.""" update = OperationUpdate( @@ -72,7 +57,7 @@ def test_validate_cancel_action_with_no_current_state(): with pytest.raises( InvalidParameterValueException, - match="Cannot cancel a CALLBACK that does not exist or has already completed", + match="Invalid action for CALLBACK operation.", ): CallbackOperationValidator.validate(None, update) @@ -92,7 +77,7 @@ def test_validate_cancel_action_with_completed_state(): with pytest.raises( InvalidParameterValueException, - match="Cannot cancel a 
CALLBACK that does not exist or has already completed", + match="Invalid action for CALLBACK operation.", ): CallbackOperationValidator.validate(current_state, update) @@ -105,5 +90,7 @@ def test_validate_invalid_action(): action=OperationAction.SUCCEED, ) - with pytest.raises(InvalidParameterValueException, match="Invalid CALLBACK action"): + with pytest.raises( + InvalidParameterValueException, match="Invalid action for CALLBACK operation." + ): CallbackOperationValidator.validate(None, update) diff --git a/tests/checkpoint/validators/transitions_test.py b/tests/checkpoint/validators/transitions_test.py index edf10a9..901db3c 100644 --- a/tests/checkpoint/validators/transitions_test.py +++ b/tests/checkpoint/validators/transitions_test.py @@ -51,7 +51,6 @@ def test_validate_callback_valid_actions(): """Test valid actions for CALLBACK operations.""" valid_actions = [ OperationAction.START, - OperationAction.CANCEL, ] for action in valid_actions: ValidActionsByOperationTypeValidator.validate(OperationType.CALLBACK, action) diff --git a/tests/execution_test.py b/tests/execution_test.py index 26d7469..0a82e26 100644 --- a/tests/execution_test.py +++ b/tests/execution_test.py @@ -1,7 +1,7 @@ """Unit tests for execution module.""" -from datetime import UTC, datetime -from unittest.mock import patch +from datetime import datetime, timezone +from unittest.mock import patch, Mock import pytest from aws_durable_execution_sdk_python.execution import InvocationStatus @@ -11,9 +11,13 @@ OperationStatus, OperationType, StepDetails, + CallbackDetails, ) -from aws_durable_execution_sdk_python_testing.exceptions import IllegalStateException +from aws_durable_execution_sdk_python_testing.exceptions import ( + IllegalStateException, + InvalidParameterValueException, +) from aws_durable_execution_sdk_python_testing.execution import Execution from aws_durable_execution_sdk_python_testing.model import StartDurableExecutionInput @@ -68,7 +72,7 @@ def test_execution_new(mock_uuid4): 
@patch("aws_durable_execution_sdk_python_testing.execution.datetime") def test_execution_start(mock_datetime): """Test Execution.start method.""" - mock_now = datetime(2023, 1, 1, 12, 0, 0, tzinfo=UTC) + mock_now = datetime(2023, 1, 1, 12, 0, 0, tzinfo=timezone.utc) mock_datetime.now.return_value = mock_now start_input = StartDurableExecutionInput( @@ -168,7 +172,7 @@ def test_get_navigable_operations(): operation_id="op1", parent_id=None, name="test", - start_timestamp=datetime.now(UTC), + start_timestamp=datetime.now(timezone.utc), operation_type=OperationType.EXECUTION, status=OperationStatus.STARTED, ) @@ -194,7 +198,7 @@ def test_get_assertable_operations(): operation_id="exec-op", parent_id=None, name="execution", - start_timestamp=datetime.now(UTC), + start_timestamp=datetime.now(timezone.utc), operation_type=OperationType.EXECUTION, status=OperationStatus.STARTED, ) @@ -202,7 +206,7 @@ def test_get_assertable_operations(): operation_id="step-op", parent_id=None, name="step", - start_timestamp=datetime.now(UTC), + start_timestamp=datetime.now(timezone.utc), operation_type=OperationType.STEP, status=OperationStatus.STARTED, ) @@ -230,7 +234,7 @@ def test_has_pending_operations_with_pending_step(): operation_id="op1", parent_id=None, name="test", - start_timestamp=datetime.now(UTC), + start_timestamp=datetime.now(timezone.utc), operation_type=OperationType.STEP, status=OperationStatus.PENDING, ) @@ -257,7 +261,7 @@ def test_has_pending_operations_with_started_wait(): operation_id="op1", parent_id=None, name="test", - start_timestamp=datetime.now(UTC), + start_timestamp=datetime.now(timezone.utc), operation_type=OperationType.WAIT, status=OperationStatus.STARTED, ) @@ -284,7 +288,7 @@ def test_has_pending_operations_with_started_callback(): operation_id="op1", parent_id=None, name="test", - start_timestamp=datetime.now(UTC), + start_timestamp=datetime.now(timezone.utc), operation_type=OperationType.CALLBACK, status=OperationStatus.STARTED, ) @@ -311,7 +315,7 @@ 
def test_has_pending_operations_with_started_invoke(): operation_id="op1", parent_id=None, name="test", - start_timestamp=datetime.now(UTC), + start_timestamp=datetime.now(timezone.utc), operation_type=OperationType.CHAINED_INVOKE, status=OperationStatus.STARTED, ) @@ -338,7 +342,7 @@ def test_has_pending_operations_no_pending(): operation_id="op1", parent_id=None, name="test", - start_timestamp=datetime.now(UTC), + start_timestamp=datetime.now(timezone.utc), operation_type=OperationType.STEP, status=OperationStatus.SUCCEEDED, ) @@ -422,7 +426,7 @@ def test_find_operation_exists(): operation_id="test-op-id", parent_id=None, name="test", - start_timestamp=datetime.now(UTC), + start_timestamp=datetime.now(timezone.utc), operation_type=OperationType.STEP, status=OperationStatus.STARTED, ) @@ -455,7 +459,7 @@ def test_find_operation_not_exists(): @patch("aws_durable_execution_sdk_python_testing.execution.datetime") def test_complete_wait_success(mock_datetime): """Test complete_wait method successful completion.""" - mock_now = datetime(2023, 1, 1, 12, 0, 0, tzinfo=UTC) + mock_now = datetime(2023, 1, 1, 12, 0, 0, tzinfo=timezone.utc) mock_datetime.now.return_value = mock_now start_input = StartDurableExecutionInput( @@ -470,7 +474,7 @@ def test_complete_wait_success(mock_datetime): operation_id="wait-op-id", parent_id=None, name="test-wait", - start_timestamp=datetime.now(UTC), + start_timestamp=datetime.now(timezone.utc), operation_type=OperationType.WAIT, status=OperationStatus.STARTED, ) @@ -498,7 +502,7 @@ def test_complete_wait_wrong_status(): operation_id="wait-op-id", parent_id=None, name="test-wait", - start_timestamp=datetime.now(UTC), + start_timestamp=datetime.now(timezone.utc), operation_type=OperationType.WAIT, status=OperationStatus.SUCCEEDED, ) @@ -524,7 +528,7 @@ def test_complete_wait_wrong_type(): operation_id="step-op-id", parent_id=None, name="test-step", - start_timestamp=datetime.now(UTC), + start_timestamp=datetime.now(timezone.utc), 
operation_type=OperationType.STEP, status=OperationStatus.STARTED, ) @@ -545,14 +549,14 @@ def test_complete_retry_success(): execution_retention_period_days=7, ) step_details = StepDetails( - next_attempt_timestamp=str(datetime.now(UTC)), + next_attempt_timestamp=str(datetime.now(timezone.utc)), attempt=1, ) operation = Operation( operation_id="step-op-id", parent_id=None, name="test-step", - start_timestamp=datetime.now(UTC), + start_timestamp=datetime.now(timezone.utc), operation_type=OperationType.STEP, status=OperationStatus.PENDING, step_details=step_details, @@ -581,7 +585,7 @@ def test_complete_retry_no_step_details(): operation_id="step-op-id", parent_id=None, name="test-step", - start_timestamp=datetime.now(UTC), + start_timestamp=datetime.now(timezone.utc), operation_type=OperationType.STEP, status=OperationStatus.PENDING, ) @@ -608,7 +612,7 @@ def test_complete_retry_wrong_status(): operation_id="step-op-id", parent_id=None, name="test-step", - start_timestamp=datetime.now(UTC), + start_timestamp=datetime.now(timezone.utc), operation_type=OperationType.STEP, status=OperationStatus.STARTED, ) @@ -634,7 +638,7 @@ def test_complete_retry_wrong_type(): operation_id="wait-op-id", parent_id=None, name="test-wait", - start_timestamp=datetime.now(UTC), + start_timestamp=datetime.now(timezone.utc), operation_type=OperationType.WAIT, status=OperationStatus.PENDING, ) @@ -642,3 +646,231 @@ def test_complete_retry_wrong_type(): with pytest.raises(IllegalStateException, match="Expected STEP operation"): execution.complete_retry("wait-op-id") + + +def test_complete_retry_with_step_details(): + """Test complete_retry with operation that has step_details.""" + step_details = StepDetails( + attempt=1, next_attempt_timestamp=datetime.now(timezone.utc) + ) + step_op = Operation( + operation_id="op-1", + operation_type=OperationType.STEP, + status=OperationStatus.PENDING, + step_details=step_details, + ) + + execution = Execution("test-arn", Mock(), [step_op]) + + result = 
execution.complete_retry("op-1") + assert result.status == OperationStatus.READY + assert result.step_details.next_attempt_timestamp is None + + +def test_complete_retry_without_step_details(): + """Test complete_retry with operation that has no step_details.""" + step_op = Operation( + operation_id="op-1", + operation_type=OperationType.STEP, + status=OperationStatus.PENDING, + step_details=None, # No step details + ) + + execution = Execution("test-arn", Mock(), [step_op]) + + result = execution.complete_retry("op-1") + assert result.status == OperationStatus.READY + assert result.step_details is None + + +# endregion retry + + +def test_from_dict_with_none_result(): + """Test from_dict with None result.""" + data = { + "DurableExecutionArn": "test-arn", + "StartInput": {"function_name": "test"}, + "Operations": [], + "Updates": [], + "UsedTokens": [], + "TokenSequence": 0, + "IsComplete": False, + "Result": None, # None result + "ConsecutiveFailedInvocationAttempts": 0, + "CloseStatus": None, + } + + with patch( + "aws_durable_execution_sdk_python_testing.model.StartDurableExecutionInput.from_dict" + ) as mock_from_dict: + mock_from_dict.return_value = Mock() + execution = Execution.from_dict(data) + assert execution.result is None + + +# region callback +def test_find_callback_operation_not_found(): + """Test find_callback_operation raises exception when callback not found.""" + execution = Execution("test-arn", Mock(), []) + + with pytest.raises( + IllegalStateException, + match="Callback operation with callback_id \\[nonexistent\\] not found", + ): + execution.find_callback_operation("nonexistent") + + +def test_complete_callback_success_not_started(): + """Test complete_callback_success raises exception when callback not in STARTED state.""" + # Create callback operation in wrong state + callback_op = Operation( + operation_id="op-1", + operation_type=OperationType.CALLBACK, + status=OperationStatus.SUCCEEDED, # Wrong state + 
callback_details=CallbackDetails(callback_id="test-id"), + ) + + execution = Execution("test-arn", Mock(), [callback_op]) + + with pytest.raises( + IllegalStateException, + match="Callback operation \\[test-id\\] is not in STARTED state", + ): + execution.complete_callback_success("test-id") + + +def test_complete_callback_failure_not_started(): + """Test complete_callback_failure raises exception when callback not in STARTED state.""" + # Create callback operation in wrong state + callback_op = Operation( + operation_id="op-1", + operation_type=OperationType.CALLBACK, + status=OperationStatus.FAILED, # Wrong state + callback_details=CallbackDetails(callback_id="test-id"), + ) + + execution = Execution("test-arn", Mock(), [callback_op]) + error = ErrorObject.from_message("test error") + + with pytest.raises( + IllegalStateException, + match="Callback operation \\[test-id\\] is not in STARTED state", + ): + execution.complete_callback_failure("test-id", error) + + +def test_complete_callback_success_no_callback_details(): + """Test complete_callback_success with a None result on an operation that has callback_details.""" + callback_details = CallbackDetails(callback_id="test-id") + callback_op = Operation( + operation_id="op-1", + operation_type=OperationType.CALLBACK, + status=OperationStatus.STARTED, + callback_details=callback_details, + ) + + execution = Execution("test-arn", Mock(), [callback_op]) + + # Test with None result + result = execution.complete_callback_success("test-id", None) + assert result.status == OperationStatus.SUCCEEDED + + +def test_complete_callback_failure_no_callback_details(): + """Test complete_callback_failure with an operation that has callback_details.""" + callback_details = CallbackDetails(callback_id="test-id") + callback_op = Operation( + operation_id="op-1", + operation_type=OperationType.CALLBACK, + status=OperationStatus.STARTED, + callback_details=callback_details, + ) + + execution = Execution("test-arn", Mock(), [callback_op]) + error = 
ErrorObject.from_message("test error") + + # Test with actual callback details + result = execution.complete_callback_failure("test-id", error) + assert result.status == OperationStatus.FAILED + + +# region callback - details + + +def test_complete_callback_success_with_none_callback_details(): + """Test complete_callback_success when operation has None callback_details.""" + callback_op = Operation( + operation_id="op-1", + operation_type=OperationType.CALLBACK, + status=OperationStatus.STARTED, + callback_details=None, # None callback details + ) + + execution = Execution("test-arn", Mock(), [callback_op]) + + # Mock find_callback_operation to return this operation + execution.find_callback_operation = Mock(return_value=(0, callback_op)) + + result = execution.complete_callback_success("test-id", b"result") + assert result.status == OperationStatus.SUCCEEDED + assert result.callback_details is None + + +def test_complete_callback_failure_with_none_callback_details(): + """Test complete_callback_failure when operation has None callback_details.""" + callback_op = Operation( + operation_id="op-1", + operation_type=OperationType.CALLBACK, + status=OperationStatus.STARTED, + callback_details=None, # None callback details + ) + + execution = Execution("test-arn", Mock(), [callback_op]) + error = ErrorObject.from_message("test error") + + # Mock find_callback_operation to return this operation + execution.find_callback_operation = Mock(return_value=(0, callback_op)) + + result = execution.complete_callback_failure("test-id", error) + assert result.status == OperationStatus.FAILED + assert result.callback_details is None + + +def test_complete_callback_success_with_bytes_result(): + """Test complete_callback_success with bytes result that gets decoded.""" + callback_details = CallbackDetails(callback_id="test-id") + callback_op = Operation( + operation_id="op-1", + operation_type=OperationType.CALLBACK, + status=OperationStatus.STARTED, + 
callback_details=callback_details, + ) + + execution = Execution("test-arn", Mock(), [callback_op]) + + result = execution.complete_callback_success("test-id", b"test result") + assert result.status == OperationStatus.SUCCEEDED + assert result.callback_details.result == "test result" + + +def test_complete_callback_success_with_none_result(): + """Test complete_callback_success with None result.""" + callback_details = CallbackDetails(callback_id="test-id") + callback_op = Operation( + operation_id="op-1", + operation_type=OperationType.CALLBACK, + status=OperationStatus.STARTED, + callback_details=callback_details, + ) + + execution = Execution("test-arn", Mock(), [callback_op]) + + result = execution.complete_callback_success("test-id", None) + assert result.status == OperationStatus.SUCCEEDED + assert result.callback_details.result is None + + +# endregion callback - details + +# endregion callback diff --git a/tests/executor_test.py b/tests/executor_test.py index 2d8ab06..7ae5164 100644 --- a/tests/executor_test.py +++ b/tests/executor_test.py @@ -1,23 +1,28 @@ """Unit tests for executor module.""" import asyncio -import uuid from datetime import UTC, datetime from unittest.mock import Mock, patch import pytest + from aws_durable_execution_sdk_python.execution import ( DurableExecutionInvocationOutput, InvocationStatus, ) from aws_durable_execution_sdk_python.lambda_service import ( - ErrorObject, - ExecutionDetails, + CallbackOptions, + OperationUpdate, + OperationAction, + OperationType, Operation, OperationStatus, - OperationType, + CallbackDetails, +) +from aws_durable_execution_sdk_python.lambda_service import ( + ErrorObject, + ExecutionDetails, ) - from aws_durable_execution_sdk_python_testing.exceptions import ( ExecutionAlreadyStartedException, IllegalStateException, @@ -34,6 +39,9 @@ StartDurableExecutionInput, ) from aws_durable_execution_sdk_python_testing.observer import ExecutionObserver +from aws_durable_execution_sdk_python_testing.token import ( 
+ CallbackToken, +) class MockExecutionObserver(ExecutionObserver): @@ -44,6 +52,7 @@ def __init__(self): self.failed_executions = {} self.wait_timers = {} self.retry_schedules = {} + self.callback_creations = {} def on_completed(self, execution_arn: str, result: str | None = None) -> None: """Capture completion events.""" @@ -68,6 +77,29 @@ def on_step_retry_scheduled( "delay": delay, } + def on_callback_created( + self, execution_arn: str, operation_id: str, callback_token: CallbackToken + ) -> None: + """Capture callback creation events.""" + self.callback_creations[execution_arn] = { + "operation_id": operation_id, + "callback_id": callback_token.to_str(), + } + + def on_callback_completed( + self, execution_arn: str, operation_id: str, callback_id: str + ) -> None: + """Capture callback completion events.""" + pass # Not needed for current tests + + def on_timed_out(self, execution_arn: str, error: ErrorObject) -> None: + """Capture timeout events.""" + pass # Not needed for current tests + + def on_stopped(self, execution_arn: str, error: ErrorObject) -> None: + """Capture stop events.""" + pass # Not needed for current tests + @pytest.fixture def test_observer(): @@ -2170,12 +2202,29 @@ def test_checkpoint_execution_invalid_token(executor, mock_store): # Callback method tests -def test_send_callback_success(executor): +def test_send_callback_success(executor, mock_store): """Test send_callback_success method.""" + from aws_durable_execution_sdk_python_testing.token import CallbackToken - result = executor.send_callback_success("test-callback-id", "success-result") + # Create valid callback token + callback_token = CallbackToken(execution_arn="test-arn", operation_id="op-123") + callback_id = callback_token.to_str() + + # Create mock execution with callback operation + mock_execution = Mock() + mock_execution.find_callback_operation.return_value = (0, Mock()) + mock_execution.complete_callback_success.return_value = Mock() + mock_store.load.return_value = 
mock_execution + + with patch.object(executor, "_invoke_execution"): + result = executor.send_callback_success(callback_id, b"success-result") assert isinstance(result, SendDurableExecutionCallbackSuccessResponse) + mock_store.load.assert_called_once_with("test-arn") + mock_execution.complete_callback_success.assert_called_once_with( + callback_id, b"success-result" + ) + mock_store.update.assert_called_once_with(mock_execution) def test_send_callback_success_empty_callback_id(executor): @@ -2190,19 +2239,46 @@ def test_send_callback_success_none_callback_id(executor): executor.send_callback_success(None) -def test_send_callback_success_with_result(executor): +def test_send_callback_success_with_result(executor, mock_store): """Test send_callback_success with result data.""" - result = executor.send_callback_success("test-callback-id", "test-result") + from aws_durable_execution_sdk_python_testing.token import CallbackToken + + # Create valid callback token + callback_token = CallbackToken(execution_arn="test-arn", operation_id="op-123") + callback_id = callback_token.to_str() + + # Create mock execution with callback operation + mock_execution = Mock() + mock_execution.find_callback_operation.return_value = (0, Mock()) + mock_execution.complete_callback_success.return_value = Mock() + mock_store.load.return_value = mock_execution + + with patch.object(executor, "_invoke_execution"): + result = executor.send_callback_success(callback_id, b"test-result") assert isinstance(result, SendDurableExecutionCallbackSuccessResponse) -def test_send_callback_failure(executor): +def test_send_callback_failure(executor, mock_store): """Test send_callback_failure method.""" + from aws_durable_execution_sdk_python_testing.token import CallbackToken + + # Create valid callback token + callback_token = CallbackToken(execution_arn="test-arn", operation_id="op-123") + callback_id = callback_token.to_str() + + # Create mock execution with callback operation + mock_execution = Mock() + 
mock_execution.find_callback_operation.return_value = (0, Mock()) + mock_execution.complete_callback_failure.return_value = Mock() + mock_store.load.return_value = mock_execution - result = executor.send_callback_failure("test-callback-id") + with patch.object(executor, "_invoke_execution"): + result = executor.send_callback_failure(callback_id) assert isinstance(result, SendDurableExecutionCallbackFailureResponse) + mock_store.load.assert_called_once_with("test-arn") + mock_store.update.assert_called_once_with(mock_execution) def test_send_callback_failure_empty_callback_id(executor): @@ -2217,20 +2293,46 @@ def test_send_callback_failure_none_callback_id(executor): executor.send_callback_failure(None) -def test_send_callback_failure_with_error(executor): +def test_send_callback_failure_with_error(executor, mock_store): """Test send_callback_failure with error object.""" + # Create valid callback token + callback_token = CallbackToken(execution_arn="test-arn", operation_id="op-123") + callback_id = callback_token.to_str() + + # Create mock execution with callback operation + mock_execution = Mock() + mock_execution.find_callback_operation.return_value = (0, Mock()) + mock_execution.complete_callback_failure.return_value = Mock() + mock_store.load.return_value = mock_execution + error = ErrorObject.from_message("Test callback error") - result = executor.send_callback_failure("test-callback-id", error) + with patch.object(executor, "_invoke_execution"): + result = executor.send_callback_failure(callback_id, error) assert isinstance(result, SendDurableExecutionCallbackFailureResponse) + mock_execution.complete_callback_failure.assert_called_once_with(callback_id, error) -def test_send_callback_heartbeat(executor): +def test_send_callback_heartbeat(executor, mock_store): """Test send_callback_heartbeat method.""" + # Create valid callback token + callback_token = CallbackToken(execution_arn="test-arn", operation_id="op-123") + callback_id = callback_token.to_str() + + 
# Create mock execution with callback operation + mock_execution = Mock() + mock_operation = Mock() + mock_operation.status = OperationStatus.STARTED + mock_execution.find_callback_operation.return_value = (0, mock_operation) + mock_execution.updates = [] # No callback options to reset timeout + mock_store.load.return_value = mock_execution - result = executor.send_callback_heartbeat("test-callback-id") + result = executor.send_callback_heartbeat(callback_id) assert isinstance(result, SendDurableExecutionCallbackHeartbeatResponse) + # Called twice: once in get_execution, once in _reset_callback_heartbeat_timeout + assert mock_store.load.call_count == 2 + mock_execution.find_callback_operation.assert_called_once_with(callback_id) def test_send_callback_heartbeat_empty_callback_id(executor): @@ -2243,3 +2345,388 @@ def test_send_callback_heartbeat_none_callback_id(executor): """Test send_callback_heartbeat with None callback_id.""" with pytest.raises(InvalidParameterValueException, match="callback_id is required"): executor.send_callback_heartbeat(None) + + +def test_complete_execution_no_result(mock_store, executor): + """Test complete_execution when execution has no result after completion.""" + mock_execution = Mock() + mock_execution.result = None # No result after completion + mock_store.load.return_value = mock_execution + + with patch.object(executor, "_complete_events"): + with pytest.raises(IllegalStateException, match="Execution result is required"): + executor.complete_execution("test-arn", "result") + + +def test_fail_execution_no_result(mock_store, executor): + """Test fail_execution when execution has no result after failure.""" + mock_execution = Mock() + mock_execution.result = None # No result after failure + mock_store.load.return_value = mock_execution + error = ErrorObject.from_message("test error") + + with patch.object(executor, "_complete_events"): + with pytest.raises(IllegalStateException, match="Execution result is required"): + 
executor.fail_execution("test-arn", error) + + +def test_send_callback_heartbeat_inactive_callback(mock_store, executor): + """Test send_callback_heartbeat with inactive callback.""" + + # Create valid callback token + callback_token = CallbackToken(execution_arn="test-arn", operation_id="op-123") + callback_id = callback_token.to_str() + + # Create mock execution with inactive callback operation + mock_execution = Mock() + mock_operation = Mock() + mock_operation.status = OperationStatus.SUCCEEDED # Not STARTED + mock_execution.find_callback_operation.return_value = (0, mock_operation) + mock_store.load.return_value = mock_execution + + with pytest.raises(ResourceNotFoundException, match="Callback .* is not active"): + executor.send_callback_heartbeat(callback_id) + + +def test_send_callback_success_invalid_token(executor): + """Test send_callback_success with invalid token format.""" + with pytest.raises( + ResourceNotFoundException, match="Failed to process callback success" + ): + executor.send_callback_success("invalid-token") + + +def test_send_callback_failure_invalid_token(executor): + """Test send_callback_failure with invalid token format.""" + with pytest.raises( + ResourceNotFoundException, match="Failed to process callback failure" + ): + executor.send_callback_failure("invalid-token") + + +def test_send_callback_heartbeat_invalid_token(executor): + """Test send_callback_heartbeat with invalid token format.""" + with pytest.raises( + ResourceNotFoundException, match="Failed to process callback heartbeat" + ): + executor.send_callback_heartbeat("invalid-token") + + +def test_complete_events_no_event(executor): + """Test _complete_events when no event exists.""" + # Should not raise exception when event doesn't exist + executor._complete_events("nonexistent-arn") # Should handle gracefully + + +# Tests for callback timeout functionality + + +def test_callback_timeout_scheduling(executor, mock_store, mock_scheduler): + """Test that callback timeouts are 
scheduled when callback is created.""" + # Create mock execution with callback operation and updates + mock_execution = Mock() + mock_execution.durable_execution_arn = "test-arn" + + # Create callback operation with details + callback_operation = Operation( + operation_id="op-123", + operation_type=OperationType.CALLBACK, + status=OperationStatus.STARTED, + callback_details=CallbackDetails(callback_id="callback-id"), + ) + mock_execution.find_operation.return_value = (0, callback_operation) + + # Create callback update with timeout options + callback_options = CallbackOptions(timeout_seconds=60, heartbeat_timeout_seconds=30) + update = OperationUpdate( + operation_id="op-123", + operation_type=OperationType.CALLBACK, + action=OperationAction.START, + callback_options=callback_options, + ) + mock_execution.updates = [update] + + mock_store.load.return_value = mock_execution + mock_scheduler.create_event.return_value = Mock() + + # Set up completion event + executor._completion_events["test-arn"] = Mock() + + # Test the timeout scheduling directly + executor._schedule_callback_timeouts("test-arn", "op-123", "callback-id") + + # Verify scheduler was called for both timeouts + assert mock_scheduler.call_later.call_count == 2 # main timeout + heartbeat timeout + assert mock_scheduler.create_event.call_count == 2 # events for both timeouts + + +def test_callback_timeout_cleanup(executor, mock_store): + """Test that callback timeouts are cleaned up when callback completes.""" + # Create mock timeout events + timeout_event = Mock() + heartbeat_event = Mock() + + executor._callback_timeouts["callback-id"] = timeout_event + executor._callback_heartbeats["callback-id"] = heartbeat_event + + # Trigger cleanup + executor._cleanup_callback_timeouts("callback-id") + + # Verify events were removed and cleaned up + timeout_event.remove.assert_called_once() + heartbeat_event.remove.assert_called_once() + assert "callback-id" not in executor._callback_timeouts + assert "callback-id" 
not in executor._callback_heartbeats + + +def test_callback_heartbeat_timeout_reset(executor, mock_store, mock_scheduler): + """Test that heartbeat timeout is reset when heartbeat is received.""" + + # Create callback token + callback_token = CallbackToken(execution_arn="test-arn", operation_id="op-123") + callback_id = callback_token.to_str() + + # Create mock execution with callback options + mock_execution = Mock() + callback_options = CallbackOptions(heartbeat_timeout_seconds=30) + update = OperationUpdate( + operation_id="op-123", + operation_type=OperationType.CALLBACK, + action=OperationAction.START, + callback_options=callback_options, + ) + mock_execution.updates = [update] + + mock_store.load.return_value = mock_execution + mock_scheduler.create_event.return_value = Mock() + + # Set up existing heartbeat event + old_event = Mock() + executor._callback_heartbeats[callback_id] = old_event + + # Reset heartbeat timeout + executor._reset_callback_heartbeat_timeout(callback_id, "test-arn") + + # Verify old event was removed and new one scheduled + old_event.remove.assert_called_once() + mock_scheduler.call_later.assert_called() + + +def test_callback_timeout_handlers(executor, mock_store): + """Test callback timeout and heartbeat timeout handlers.""" + # Create callback token + callback_token = CallbackToken(execution_arn="test-arn", operation_id="op-123") + callback_id = callback_token.to_str() + + # Create mock execution + mock_execution = Mock() + mock_execution.is_complete = False + mock_store.load.return_value = mock_execution + + with patch.object(executor, "_invoke_execution"): + # Test main timeout handler + executor._on_callback_timeout("test-arn", callback_id) + + # Verify callback was failed with timeout error + mock_execution.complete_callback_failure.assert_called() + timeout_error = mock_execution.complete_callback_failure.call_args[0][1] + assert "Callback.Timeout" in str(timeout_error.message) + + # Test heartbeat timeout handler + 
executor._on_callback_heartbeat_timeout("test-arn", callback_id) + + # Verify callback was failed with heartbeat timeout error + assert mock_execution.complete_callback_failure.call_count == 2 + heartbeat_error = mock_execution.complete_callback_failure.call_args[0][1] + assert "Callback.Heartbeat" in str(heartbeat_error.message) + + +def test_callback_timeout_completed_execution(executor, mock_store): + """Test that timeout handlers ignore completed executions.""" + + # Create callback token + callback_token = CallbackToken(execution_arn="test-arn", operation_id="op-123") + callback_id = callback_token.to_str() + + # Create completed execution + mock_execution = Mock() + mock_execution.is_complete = True + mock_store.load.return_value = mock_execution + + # Test timeout handlers with completed execution + executor._on_callback_timeout("test-arn", callback_id) + executor._on_callback_heartbeat_timeout("test-arn", callback_id) + + # Verify no callback operations were performed + mock_execution.complete_callback_failure.assert_not_called() + mock_store.update.assert_not_called() + + +def test_schedule_callback_timeouts_no_callback_details(executor, mock_store): + """Test _schedule_callback_timeouts when operation has no callback details.""" + + # Create operation without callback details + operation = Operation( + operation_id="op-123", + operation_type=OperationType.CALLBACK, + status=OperationStatus.STARTED, + callback_details=None, + ) + + mock_execution = Mock() + mock_execution.find_operation.return_value = (0, operation) + mock_store.load.return_value = mock_execution + + # Should return early without scheduling + executor._schedule_callback_timeouts("test-arn", "op-123", "callback-id") + + # No scheduler calls should be made + assert len(executor._callback_timeouts) == 0 + assert len(executor._callback_heartbeats) == 0 + + +def test_schedule_callback_timeouts_no_callback_options(executor, mock_store): + """Test _schedule_callback_timeouts when no callback 
options are found.""" + + # Create operation with callback details but no matching updates + operation = Operation( + operation_id="op-123", + operation_type=OperationType.CALLBACK, + status=OperationStatus.STARTED, + callback_details=CallbackDetails(callback_id="callback-id"), + ) + + mock_execution = Mock() + mock_execution.find_operation.return_value = (0, operation) + mock_execution.updates = [] # No updates with callback options + mock_store.load.return_value = mock_execution + + # Should return early without scheduling + executor._schedule_callback_timeouts("test-arn", "op-123", "callback-id") + + # No scheduler calls should be made + assert len(executor._callback_timeouts) == 0 + assert len(executor._callback_heartbeats) == 0 + + +def test_schedule_callback_timeouts_zero_timeouts(executor, mock_store, mock_scheduler): + """Test _schedule_callback_timeouts with zero timeout values.""" + # Create operation with callback details + operation = Operation( + operation_id="op-123", + operation_type=OperationType.CALLBACK, + status=OperationStatus.STARTED, + callback_details=CallbackDetails(callback_id="callback-id"), + ) + + mock_execution = Mock() + mock_execution.find_operation.return_value = (0, operation) + + # Create update with zero timeouts (disabled) + callback_options = CallbackOptions(timeout_seconds=0, heartbeat_timeout_seconds=0) + update = OperationUpdate( + operation_id="op-123", + operation_type=OperationType.CALLBACK, + action=OperationAction.START, + callback_options=callback_options, + ) + mock_execution.updates = [update] + + mock_store.load.return_value = mock_execution + executor._completion_events["test-arn"] = Mock() + + # Should not schedule any timeouts + executor._schedule_callback_timeouts("test-arn", "op-123", "callback-id") + + # No scheduler calls should be made + mock_scheduler.call_later.assert_not_called() + assert len(executor._callback_timeouts) == 0 + assert len(executor._callback_heartbeats) == 0 + + +def 
test_schedule_callback_timeouts_only_main_timeout( + executor, mock_store, mock_scheduler +): + """Test _schedule_callback_timeouts with only main timeout configured.""" + + # Create operation with callback details + operation = Operation( + operation_id="op-123", + operation_type=OperationType.CALLBACK, + status=OperationStatus.STARTED, + callback_details=CallbackDetails(callback_id="callback-id"), + ) + + mock_execution = Mock() + mock_execution.find_operation.return_value = (0, operation) + + # Create update with only main timeout + callback_options = CallbackOptions(timeout_seconds=60, heartbeat_timeout_seconds=0) + update = OperationUpdate( + operation_id="op-123", + operation_type=OperationType.CALLBACK, + action=OperationAction.START, + callback_options=callback_options, + ) + mock_execution.updates = [update] + + mock_store.load.return_value = mock_execution + mock_scheduler.create_event.return_value = Mock() + executor._completion_events["test-arn"] = Mock() + + executor._schedule_callback_timeouts("test-arn", "op-123", "callback-id") + + # Only main timeout should be scheduled + assert mock_scheduler.call_later.call_count == 1 + assert len(executor._callback_timeouts) == 1 + assert len(executor._callback_heartbeats) == 0 + + +def test_schedule_callback_timeouts_only_heartbeat_timeout( + executor, mock_store, mock_scheduler +): + """Test _schedule_callback_timeouts with only heartbeat timeout configured.""" + # Create operation with callback details + operation = Operation( + operation_id="op-123", + operation_type=OperationType.CALLBACK, + status=OperationStatus.STARTED, + callback_details=CallbackDetails(callback_id="callback-id"), + ) + + mock_execution = Mock() + mock_execution.find_operation.return_value = (0, operation) + + # Create update with only heartbeat timeout + callback_options = CallbackOptions(timeout_seconds=0, heartbeat_timeout_seconds=30) + update = OperationUpdate( + operation_id="op-123", + operation_type=OperationType.CALLBACK, + 
action=OperationAction.START, + callback_options=callback_options, + ) + mock_execution.updates = [update] + + mock_store.load.return_value = mock_execution + mock_scheduler.create_event.return_value = Mock() + executor._completion_events["test-arn"] = Mock() + + executor._schedule_callback_timeouts("test-arn", "op-123", "callback-id") + + # Only heartbeat timeout should be scheduled + assert mock_scheduler.call_later.call_count == 1 + assert len(executor._callback_timeouts) == 0 + assert len(executor._callback_heartbeats) == 1 + + +def test_schedule_callback_timeouts_exception_handling(executor, mock_store): + """Test _schedule_callback_timeouts handles exceptions gracefully.""" + # Make get_execution raise an exception + mock_store.load.side_effect = Exception("Test error") + + # Should not raise exception + executor._schedule_callback_timeouts("test-arn", "op-123", "callback-id") + + # No timeouts should be scheduled + assert len(executor._callback_timeouts) == 0 + assert len(executor._callback_heartbeats) == 0 diff --git a/tests/observer_test.py b/tests/observer_test.py index 2944a23..9464452 100644 --- a/tests/observer_test.py +++ b/tests/observer_test.py @@ -11,6 +11,7 @@ ExecutionNotifier, ExecutionObserver, ) +from aws_durable_execution_sdk_python_testing.token import CallbackToken class MockExecutionObserver(ExecutionObserver): @@ -21,6 +22,7 @@ def __init__(self): self.on_failed_calls = [] self.on_wait_timer_scheduled_calls = [] self.on_step_retry_scheduled_calls = [] + self.on_callback_created_calls = [] def on_completed(self, execution_arn: str, result: str | None = None) -> None: self.on_completed_calls.append((execution_arn, result)) @@ -38,6 +40,13 @@ def on_step_retry_scheduled( ) -> None: self.on_step_retry_scheduled_calls.append((execution_arn, operation_id, delay)) + def on_callback_created( + self, execution_arn: str, operation_id: str, callback_token: CallbackToken + ) -> None: + self.on_callback_created_calls.append( + (execution_arn, 
operation_id, callback_token) + ) + def test_execution_notifier_init(): """Test ExecutionNotifier initialization."""