diff --git a/src/aws_durable_execution_sdk_python_testing/checkpoint/processor.py b/src/aws_durable_execution_sdk_python_testing/checkpoint/processor.py index 9189706..04b991c 100644 --- a/src/aws_durable_execution_sdk_python_testing/checkpoint/processor.py +++ b/src/aws_durable_execution_sdk_python_testing/checkpoint/processor.py @@ -71,15 +71,15 @@ def process_checkpoint( execution_arn=token.execution_arn, ) - # 5. Save update + # 5. Generate a new checkpoint token and save updated operations + new_checkpoint_token = execution.get_new_checkpoint_token() execution.operations = updated_operations execution.updates.extend(all_updates) - self._store.update(execution) # 6. Return checkpoint result return CheckpointOutput( - checkpoint_token=execution.get_new_checkpoint_token(), + checkpoint_token=new_checkpoint_token, new_execution_state=CheckpointUpdatedExecutionState( operations=execution.get_navigable_operations(), next_marker=None ), diff --git a/src/aws_durable_execution_sdk_python_testing/executor.py b/src/aws_durable_execution_sdk_python_testing/executor.py index 12b4352..518bc9e 100644 --- a/src/aws_durable_execution_sdk_python_testing/executor.py +++ b/src/aws_durable_execution_sdk_python_testing/executor.py @@ -28,6 +28,7 @@ ResourceNotFoundException, ) from aws_durable_execution_sdk_python_testing.execution import Execution +from aws_durable_execution_sdk_python_testing.exceptions import IllegalStateException from aws_durable_execution_sdk_python_testing.model import ( CheckpointDurableExecutionResponse, CheckpointUpdatedExecutionState, @@ -611,8 +612,12 @@ def checkpoint_execution( new_execution_state=new_execution_state, ) + # Save execution state after generating new token + new_checkpoint_token = execution.get_new_checkpoint_token() + self._store.update(execution) + return CheckpointDurableExecutionResponse( - checkpoint_token=execution.get_new_checkpoint_token(), + checkpoint_token=new_checkpoint_token, new_execution_state=None, ) @@ -644,6 +649,7 @@ def send_callback_success( execution.complete_callback_success(callback_id, result) self._store.update(execution) self._cleanup_callback_timeouts(callback_id) + self._invoke_execution(callback_token.execution_arn) logger.info("Callback success completed for callback_id: %s", callback_id) except Exception as e: msg = f"Failed to process callback success: {e}" @@ -681,6 +687,7 @@ def send_callback_failure( execution.complete_callback_failure(callback_id, callback_error) self._store.update(execution) self._cleanup_callback_timeouts(callback_id) + self._invoke_execution(callback_token.execution_arn) logger.info("Callback failure completed for callback_id: %s", callback_id) except Exception as e: msg = f"Failed to process callback failure: {e}" @@ -944,7 +951,7 @@ def complete_execution(self, execution_arn: str, result: str | None = None) -> N def fail_execution(self, execution_arn: str, error: ErrorObject) -> None: """Fail execution with error.""" - logger.exception("[%s] Completing execution with error.", execution_arn) + logger.error("[%s] Completing execution with error: %s", execution_arn, error) execution: Execution = self._store.load(execution_arn=execution_arn) execution.complete_fail(error=error) self._store.update(execution) @@ -1190,9 +1197,8 @@ def _on_callback_timeout(self, execution_arn: str, callback_id: str) -> None: f"Callback timed out: {CallbackTimeoutType.TIMEOUT.value}" ) execution.complete_callback_failure(callback_id, timeout_error) + execution.complete_fail(timeout_error) self._store.update(execution) - self._invoke_execution(execution_arn) - logger.warning("[%s] Callback %s timed out", execution_arn, callback_id) except Exception: logger.exception( @@ -1218,9 +1224,8 @@ def _on_callback_heartbeat_timeout( f"Callback heartbeat timed out: {CallbackTimeoutType.HEARTBEAT.value}" ) execution.complete_callback_failure(callback_id, heartbeat_error) + execution.complete_fail(heartbeat_error) self._store.update(execution) - self._invoke_execution(execution_arn) - logger.warning( "[%s] Callback %s heartbeat timed out", execution_arn, callback_id ) diff --git a/tests/executor_test.py b/tests/executor_test.py index 7ae5164..a598950 100644 --- a/tests/executor_test.py +++ b/tests/executor_test.py @@ -2216,7 +2216,7 @@ def test_send_callback_success(executor, mock_store): mock_execution.complete_callback_success.return_value = Mock() mock_store.load.return_value = mock_execution - with patch.object(executor, "_invoke_execution"): + with patch.object(executor, "_invoke_execution") as mock_invoke: result = executor.send_callback_success(callback_id, b"success-result") assert isinstance(result, SendDurableExecutionCallbackSuccessResponse) @@ -2225,6 +2225,8 @@ def test_send_callback_success(executor, mock_store): callback_id, b"success-result" ) mock_store.update.assert_called_once_with(mock_execution) + # Verify execution is invoked after callback success + mock_invoke.assert_called_once_with("test-arn") def test_send_callback_success_empty_callback_id(executor): @@ -2253,10 +2255,15 @@ def test_send_callback_success_with_result(executor, mock_store): mock_execution.complete_callback_success.return_value = Mock() mock_store.load.return_value = mock_execution - with patch.object(executor, "_invoke_execution"): + with patch.object(executor, "_invoke_execution") as mock_invoke: result = executor.send_callback_success(callback_id, b"test-result") assert isinstance(result, SendDurableExecutionCallbackSuccessResponse) + mock_execution.complete_callback_success.assert_called_once_with( + callback_id, b"test-result" + ) + # Verify execution is invoked after callback success + mock_invoke.assert_called_once_with("test-arn") def test_send_callback_failure(executor, mock_store): @@ -2273,12 +2280,14 @@ def test_send_callback_failure(executor, mock_store): mock_execution.complete_callback_failure.return_value = Mock() mock_store.load.return_value = mock_execution - with patch.object(executor, "_invoke_execution"): + with patch.object(executor, "_invoke_execution") as mock_invoke: result = executor.send_callback_failure(callback_id) assert isinstance(result, SendDurableExecutionCallbackFailureResponse) mock_store.load.assert_called_once_with("test-arn") mock_store.update.assert_called_once_with(mock_execution) + # Verify execution is invoked after callback failure + mock_invoke.assert_called_once_with("test-arn") def test_send_callback_failure_empty_callback_id(executor): @@ -2306,11 +2315,13 @@ def test_send_callback_failure_with_error(executor, mock_store): mock_store.load.return_value = mock_execution error = ErrorObject.from_message("Test callback error") - with patch.object(executor, "_invoke_execution"): + with patch.object(executor, "_invoke_execution") as mock_invoke: result = executor.send_callback_failure(callback_id, error) assert isinstance(result, SendDurableExecutionCallbackFailureResponse) mock_execution.complete_callback_failure.assert_called_once_with(callback_id, error) + # Verify execution is invoked after callback failure + mock_invoke.assert_called_once_with("test-arn") def test_send_callback_heartbeat(executor, mock_store):