diff --git a/src/aws_durable_execution_sdk_python_testing/checkpoint/processor.py b/src/aws_durable_execution_sdk_python_testing/checkpoint/processor.py index 04b991c..f4937c0 100644 --- a/src/aws_durable_execution_sdk_python_testing/checkpoint/processor.py +++ b/src/aws_durable_execution_sdk_python_testing/checkpoint/processor.py @@ -9,6 +9,7 @@ CheckpointUpdatedExecutionState, OperationUpdate, StateOutput, + Operation, ) from aws_durable_execution_sdk_python_testing.checkpoint.transformer import ( @@ -88,14 +89,37 @@ def process_checkpoint( def get_execution_state( self, checkpoint_token: str, - next_marker: str, # noqa: ARG002 - max_items: int = 1000, # noqa: ARG002 + next_marker: str | None = None, + max_items: int = 1000, ) -> StateOutput: - """Get current execution state.""" + """Get current execution state with batched checkpoint token validation and pagination.""" + if not checkpoint_token: + msg: str = "Checkpoint token is required" + raise InvalidParameterValueException(msg) + token: CheckpointToken = CheckpointToken.from_str(checkpoint_token) execution: Execution = self._store.load(token.execution_arn) + execution.validate_checkpoint_token(checkpoint_token) + + # Get all operations + all_operations: list[Operation] = execution.get_navigable_operations() + + # Apply pagination + start_index: int = 0 + if next_marker: + try: + start_index = int(next_marker) + except ValueError: + start_index = 0 + + end_index: int = start_index + max_items + paginated_operations: list[Operation] = all_operations[start_index:end_index] + + # Determine next marker + next_marker_result: str | None = ( + str(end_index) if end_index < len(all_operations) else None + ) - # TODO: paging when size or max return StateOutput( - operations=execution.get_navigable_operations(), next_marker=None + operations=paginated_operations, next_marker=next_marker_result ) diff --git a/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/execution.py b/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/execution.py index e8ad2ef..5f5f788 100644 --- a/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/execution.py +++ b/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/execution.py @@ -38,12 +38,8 @@ def process( ) case _: # intentional. actual service will fail any EXECUTION update that is not SUCCEED. - error = ( - update.error - if update.error - else ErrorObject.from_message( - "There is no error details but EXECUTION checkpoint action is not SUCCEED." - ) + error = update.error or ErrorObject.from_message( + "There is no error details but EXECUTION checkpoint action is not SUCCEED." ) # All EXECUTION failures go through normal fail path # Timeout/Stop status is set by executor based on the operation that caused it diff --git a/src/aws_durable_execution_sdk_python_testing/checkpoint/transformer.py b/src/aws_durable_execution_sdk_python_testing/checkpoint/transformer.py index cd37b8a..223f56f 100644 --- a/src/aws_durable_execution_sdk_python_testing/checkpoint/transformer.py +++ b/src/aws_durable_execution_sdk_python_testing/checkpoint/transformer.py @@ -55,7 +55,7 @@ def __init__( self, processors: MutableMapping[OperationType, OperationProcessor] | None = None, ): - self.processors = processors if processors else self._DEFAULT_PROCESSORS + self.processors = processors or self._DEFAULT_PROCESSORS def process_updates( self, diff --git a/src/aws_durable_execution_sdk_python_testing/execution.py b/src/aws_durable_execution_sdk_python_testing/execution.py index b651bf1..43040ba 100644 --- a/src/aws_durable_execution_sdk_python_testing/execution.py +++ b/src/aws_durable_execution_sdk_python_testing/execution.py @@ -60,7 +60,7 @@ def __init__( self.start_input: StartDurableExecutionInput = start_input self.operations: list[Operation] = operations self.updates: list[OperationUpdate] = [] - self.used_tokens: set[str] = set() + self.generated_tokens: set[str] = set() # TODO: this will need to persist/rehydrate depending on inmemory vs sqllite store self._token_sequence: int = 0 self._state_lock: Lock = Lock() @@ -101,7 +101,7 @@ def to_dict(self) -> dict[str, Any]: "StartInput": self.start_input.to_dict(), "Operations": [op.to_dict() for op in self.operations], "Updates": [update.to_dict() for update in self.updates], - "UsedTokens": list(self.used_tokens), + "GeneratedTokens": list(self.generated_tokens), "TokenSequence": self._token_sequence, "IsComplete": self.is_complete, "Result": self.result.to_dict() if self.result else None, @@ -129,7 +129,7 @@ def from_dict(cls, data: dict[str, Any]) -> Execution: execution.updates = [ OperationUpdate.from_dict(update_data) for update_data in data["Updates"] ] - execution.used_tokens = set(data["UsedTokens"]) + execution.generated_tokens = set(data["GeneratedTokens"]) execution._token_sequence = data["TokenSequence"] # noqa: SLF001 execution.is_complete = data["IsComplete"] execution.result = ( @@ -184,13 +184,38 @@ def get_new_checkpoint_token(self) -> str: token_sequence=new_token_sequence, ) token_str = token.to_str() - self.used_tokens.add(token_str) + self.generated_tokens.add(token_str) return token_str def get_navigable_operations(self) -> list[Operation]: """Get list of operations, but exclude child operations where the parent has already completed.""" return self.operations + def validate_checkpoint_token( + self, + token: str | None, + checkpoint_required_msg: str | None = None, + ) -> None: + """Validate checkpoint token against this execution.""" + if not token: + msg: str = checkpoint_required_msg or "Checkpoint token is required" + raise InvalidParameterValueException(msg) + + checkpoint_token: CheckpointToken = CheckpointToken.from_str(token) + if checkpoint_token.execution_arn != self.durable_execution_arn: + msg = "Checkpoint token does not match execution ARN" + raise InvalidParameterValueException(msg) + + if self.is_complete or checkpoint_token.token_sequence > self.token_sequence: + msg = "Invalid or expired checkpoint token" + raise InvalidParameterValueException(msg) + + # Check if token has been generated + token_str: str = checkpoint_token.to_str() + if token_str not in self.generated_tokens: + msg = f"Invalid checkpoint token: {token_str}" + raise InvalidParameterValueException(msg) + def get_assertable_operations(self) -> list[Operation]: """Get list of operations, but exclude the EXECUTION operations""" # TODO: this excludes EXECUTION at start, but can there be an EXECUTION at the end if there was a checkpoint with large payload? diff --git a/src/aws_durable_execution_sdk_python_testing/executor.py b/src/aws_durable_execution_sdk_python_testing/executor.py index 3c398f6..0e78107 100644 --- a/src/aws_durable_execution_sdk_python_testing/executor.py +++ b/src/aws_durable_execution_sdk_python_testing/executor.py @@ -58,6 +58,8 @@ if TYPE_CHECKING: from collections.abc import Awaitable, Callable + from aws_durable_execution_sdk_python.lambda_service import Operation + from aws_durable_execution_sdk_python_testing.checkpoint.processor import ( CheckpointProcessor, ) @@ -347,32 +349,33 @@ def get_execution_state( ResourceNotFoundException: If execution does not exist InvalidParameterValueException: If checkpoint token is invalid """ - execution = self.get_execution(execution_arn) + execution: Execution = self.get_execution(execution_arn) + is_checkpoint_required: bool = not execution.is_complete and marker is not None - # TODO: Validate checkpoint token if provided - if checkpoint_token and checkpoint_token not in execution.used_tokens: - msg: str = f"Invalid checkpoint token: {checkpoint_token}" - raise InvalidParameterValueException(msg) + if is_checkpoint_required or checkpoint_token: + checkpoint_required_msg: str = "Checkpoint token is required for paginated requests on active executions" + execution.validate_checkpoint_token( + checkpoint_token, checkpoint_required_msg + ) # Get operations (excluding the initial EXECUTION operation for state) - operations = execution.get_assertable_operations() + operations: list[Operation] = execution.get_assertable_operations() # Apply pagination if max_items is None: max_items = 100 - # Simple pagination - in real implementation would need proper marker handling - start_index = 0 + start_index: int = 0 if marker: try: start_index = int(marker) except ValueError: start_index = 0 - end_index = start_index + max_items - paginated_operations = operations[start_index:end_index] + end_index: int = start_index + max_items + paginated_operations: list[Operation] = operations[start_index:end_index] - next_marker = None + next_marker: str | None = None if end_index < len(operations): next_marker = str(end_index) @@ -541,11 +544,10 @@ def checkpoint_execution( InvalidParameterValueException: If checkpoint token is invalid """ execution = self.get_execution(execution_arn) - - # Validate checkpoint token - if checkpoint_token not in execution.used_tokens: - msg: str = f"Invalid checkpoint token: {checkpoint_token}" - raise InvalidParameterValueException(msg) + execution.validate_checkpoint_token( + checkpoint_token, + checkpoint_required_msg="Checkpoint token is required for checkpoint operations", + ) if updates: checkpoint_output = self._checkpoint_processor.process_checkpoint( diff --git a/tests/execution_test.py b/tests/execution_test.py index 602698b..8639c37 100644 --- a/tests/execution_test.py +++ b/tests/execution_test.py @@ -43,7 +43,7 @@ def test_execution_init(): assert execution.start_input == start_input assert execution.operations == operations assert execution.updates == [] - assert execution.used_tokens == set() + assert execution.generated_tokens == set() assert execution.token_sequence == 0 assert execution.is_complete is False assert execution.consecutive_failed_invocation_attempts == 0 @@ -154,8 +154,8 @@ def test_get_new_checkpoint_token(): token2 = execution.get_new_checkpoint_token() assert execution.token_sequence == 2 - assert token1 in execution.used_tokens - assert token2 in execution.used_tokens + assert token1 in execution.generated_tokens + assert token2 in execution.generated_tokens assert token1 != token2 @@ -801,7 +801,7 @@ def test_from_dict_with_none_result(): "StartInput": {"function_name": "test"}, "Operations": [], "Updates": [], - "UsedTokens": [], + "GeneratedTokens": [], "TokenSequence": 0, "IsComplete": False, "Result": None, # None result diff --git a/tests/executor_test.py b/tests/executor_test.py index 008a4a0..41cabef 100644 --- a/tests/executor_test.py +++ b/tests/executor_test.py @@ -47,6 +47,7 @@ ) from aws_durable_execution_sdk_python_testing.token import ( CallbackToken, + CheckpointToken, ) @@ -300,7 +301,6 @@ def test_should_complete_workflow_with_error_when_invocation_fails( handler = mock_scheduler.call_later.call_args[0][0] # Execute the handler to trigger the invocation logic - import asyncio asyncio.run(handler()) @@ -344,7 +344,6 @@ def test_should_complete_workflow_with_result_when_invocation_succeeds( handler = mock_scheduler.call_later.call_args[0][0] # Execute the handler to trigger the invocation logic - import asyncio asyncio.run(handler()) @@ -385,7 +384,6 @@ def test_should_handle_pending_status_when_operations_exist( handler = mock_scheduler.call_later.call_args[0][0] # Execute the handler to trigger the invocation logic - import asyncio asyncio.run(handler()) @@ -424,7 +422,6 @@ def test_should_ignore_response_when_execution_already_complete( handler = mock_scheduler.call_later.call_args[0][0] # Execute the handler to trigger the invocation logic - import asyncio asyncio.run(handler()) @@ -914,7 +911,6 @@ def test_should_fail_execution_when_function_not_found( handler = mock_scheduler.call_later.call_args[0][0] # Execute the handler to trigger the invocation logic - import asyncio asyncio.run(handler()) @@ -958,7 +954,6 @@ def test_should_fail_execution_when_retries_exhausted( handler = mock_scheduler.call_later.call_args[0][0] # Execute the handler to trigger the invocation logic - import asyncio asyncio.run(handler()) @@ -1045,7 +1040,6 @@ def test_should_retry_invocation_when_under_limit_through_public_api( # Simulate scheduler executing the initial invocation handler initial_handler = mock_scheduler.call_later.call_args[0][0] - import asyncio asyncio.run(initial_handler()) @@ -2037,9 +2031,15 @@ def test_get_execution_not_found(executor, mock_store): def test_get_execution_state(executor, mock_store): """Test get_execution_state method.""" - mock_execution = Mock() - mock_execution.used_tokens = {"token1", "token2"} + mock_execution.durable_execution_arn = "test-arn" + mock_execution.token_sequence = 5 + mock_execution.is_complete = False + + # Create valid token and add to used_tokens + token = CheckpointToken("test-arn", 3) + valid_token = token.to_str() + mock_execution.used_tokens = {valid_token} # Create mock operations operations = [ @@ -2064,7 +2064,7 @@ def test_get_execution_state(executor, mock_store): mock_store.load.return_value = mock_execution - result = executor.get_execution_state("test-arn", checkpoint_token="token1") # noqa: S106 + result = executor.get_execution_state("test-arn", checkpoint_token=valid_token) assert len(result.operations) == 2 assert result.next_marker is None @@ -2073,14 +2073,22 @@ def test_get_execution_state(executor, mock_store): def test_get_execution_state_invalid_token(executor, mock_store): """Test get_execution_state with invalid checkpoint token.""" - mock_execution = Mock() - mock_execution.used_tokens = {"token1", "token2"} - mock_store.load.return_value = mock_execution + # Use real Execution object so validation actually runs + real_execution = Execution("test-arn", Mock(), []) + real_execution._token_sequence = 10 # noqa: SLF001 + # Don't add the token to used_tokens so it will be invalid + real_execution.generated_tokens = {"other-token"} + + mock_store.load.return_value = real_execution + + token = CheckpointToken("invalid-arn", 3) # Different ARN + invalid_token = token.to_str() with pytest.raises( - InvalidParameterValueException, match="Invalid checkpoint token" + InvalidParameterValueException, + match="Checkpoint token does not match execution ARN", ): - executor.get_execution_state("test-arn", checkpoint_token="invalid-token") # noqa: S106 + executor.get_execution_state("test-arn", checkpoint_token=invalid_token) def test_get_execution_history(executor, mock_store): @@ -2232,13 +2240,22 @@ def test_get_execution_history_invalid_marker(executor, mock_store): def test_checkpoint_execution(executor, mock_store): """Test checkpoint_execution method.""" mock_execution = Mock() - mock_execution.used_tokens = {"token1", "token2"} - mock_execution.get_new_checkpoint_token.return_value = "new-token" + mock_execution.durable_execution_arn = "test-arn" + mock_execution.token_sequence = 5 + mock_execution.is_complete = False + new_token = "new-token" # noqa:S105 + mock_execution.get_new_checkpoint_token.return_value = new_token + + # Create valid token and add to used_tokens + token = CheckpointToken("test-arn", 3) + valid_token = token.to_str() + mock_execution.used_tokens = {valid_token} + mock_store.load.return_value = mock_execution - result = executor.checkpoint_execution("test-arn", "token1") + result = executor.checkpoint_execution("test-arn", valid_token) - assert result.checkpoint_token == "new-token" # noqa: S105 + assert result.checkpoint_token == new_token assert result.new_execution_state is None mock_store.load.assert_called_once_with("test-arn") mock_execution.get_new_checkpoint_token.assert_called_once() @@ -2246,14 +2263,19 @@ def test_checkpoint_execution(executor, mock_store): def test_checkpoint_execution_invalid_token(executor, mock_store): """Test checkpoint_execution with invalid checkpoint token.""" - mock_execution = Mock() - mock_execution.used_tokens = {"token1", "token2"} - mock_store.load.return_value = mock_execution - + execution_arn = "test_arn" + start_durable_execution_input = Mock() + execution = Execution(execution_arn, start_durable_execution_input, []) + execution.generated_tokens = {"token1", "token2"} + execution.durable_execution_arn = execution_arn + mock_store.load.return_value = execution + token = CheckpointToken("execution-arn", 3) + invalid_token = token.to_str() with pytest.raises( - InvalidParameterValueException, match="Invalid checkpoint token" + InvalidParameterValueException, + match="Checkpoint token does not match execution ARN", ): - executor.checkpoint_execution("test-arn", "invalid-token") + executor.checkpoint_execution(execution_arn, invalid_token) # Callback method tests @@ -2860,3 +2882,41 @@ def test_notify_stopped(): notifier.notify_stopped("test-arn", error) observer.on_stopped.assert_called_once_with(execution_arn="test-arn", error=error) + + +def test_get_execution_state_no_token_with_marker_active_execution( + mock_store, mock_scheduler, mock_invoker, mock_checkpoint_processor +): + """Test get_execution_state fails when no token provided with marker on active execution.""" + executor = Executor( + mock_store, mock_scheduler, mock_invoker, mock_checkpoint_processor + ) + execution_arn = "test-arn" + + # Create an active execution + execution = Execution(execution_arn, "test-function", {}) + execution.is_complete = False + mock_store.load.return_value = execution + + with pytest.raises( + InvalidParameterValueException, match="Checkpoint token is required" + ): + executor.get_execution_state(execution_arn, marker="some-marker") + + +def test_checkpoint_execution_no_token( + mock_store, mock_scheduler, mock_invoker, mock_checkpoint_processor +): + """Test checkpoint_execution fails when no token provided.""" + executor = Executor( + mock_store, mock_scheduler, mock_invoker, mock_checkpoint_processor + ) + execution_arn = "test-arn" + + execution = Execution(execution_arn, "test-function", {}) + mock_store.load.return_value = execution + + with pytest.raises( + InvalidParameterValueException, match="Checkpoint token is required" + ): + executor.checkpoint_execution(execution_arn, "", [], "client-token") diff --git a/tests/token_test.py b/tests/token_test.py index 714d8c9..1c6aae7 100644 --- a/tests/token_test.py +++ b/tests/token_test.py @@ -2,9 +2,14 @@ import base64 import json +from unittest.mock import Mock import pytest +from aws_durable_execution_sdk_python_testing.exceptions import ( + InvalidParameterValueException, +) +from aws_durable_execution_sdk_python_testing.execution import Execution from aws_durable_execution_sdk_python_testing.token import ( CallbackToken, CheckpointToken, @@ -130,3 +135,82 @@ def test_callback_token_frozen_dataclass(): with pytest.raises(AttributeError): token.operation_id = "new-op" + + +def test_checkpoint_token_validate_for_execution_success(): + """Test successful token validation.""" + token = CheckpointToken("test-arn", 5) + execution = Execution("test-arn", Mock(), []) + execution._token_sequence = 10 # noqa: SLF001 + execution.generated_tokens = {token.to_str()} + + execution.validate_checkpoint_token(token.to_str()) + + +def test_checkpoint_token_validate_for_execution_arn_mismatch(): + """Test token validation fails when ARN doesn't match.""" + token = CheckpointToken("test-arn", 5) + execution = Execution("different-arn", "test-name", "test-input") + execution._token_sequence = 10 # noqa: SLF001 + + with pytest.raises( + InvalidParameterValueException, match="does not match execution ARN" + ): + execution.validate_checkpoint_token(token.to_str()) + + +def test_checkpoint_token_validate_for_execution_completed(): + """Test token validation fails when execution is complete.""" + token = CheckpointToken("test-arn", 5) + start_input = Mock() + execution = Execution("test-arn", start_input, []) + execution.generated_tokens = {token.to_str()} # Add token to used_tokens + execution.is_complete = True + + with pytest.raises(InvalidParameterValueException, match="Invalid or expired"): + execution.validate_checkpoint_token(token.to_str()) + + +def test_checkpoint_token_validate_for_execution_future_sequence(): + """Test token validation fails when token sequence is from future.""" + token = CheckpointToken("test-arn", 15) + execution = Execution("test-arn", "test-name", "test-input") + execution._token_sequence = 10 # noqa: SLF001 + + with pytest.raises(InvalidParameterValueException, match="Invalid or expired"): + execution.validate_checkpoint_token(token.to_str()) + + +def test_checkpoint_token_validate_for_execution_equal_sequence(): + """Test token validation succeeds when sequences are equal.""" + token = CheckpointToken("test-arn", 10) + execution = Execution("test-arn", "test-name", "test-input") + execution._token_sequence = 10 # noqa: SLF001 + execution.generated_tokens = {token.to_str()} + + execution.validate_checkpoint_token(token.to_str()) + + +def test_checkpoint_token_validate_for_execution_not_in_used_tokens(): + """Test token validation fails when token not in used_tokens.""" + token = CheckpointToken("test-arn", 5) + execution = Execution("test-arn", "test-name", "test-input") + execution._token_sequence = 10 # noqa: SLF001 + execution.generated_tokens = {"other-token"} + + with pytest.raises( + InvalidParameterValueException, match="Invalid checkpoint token" + ): + execution.validate_checkpoint_token(token.to_str()) + + +def test_checkpoint_token_validate_for_execution_in_used_tokens(): + """Test token validation succeeds when token is in used_tokens.""" + token = CheckpointToken("test-arn", 5) + execution = Execution("test-arn", "test-name", "test-input") + execution._token_sequence = 10 # noqa: SLF001 + # Mock the token string that would be generated + token_str = token.to_str() + execution.generated_tokens = {token_str} + + execution.validate_checkpoint_token(token_str)