From 77590d1fbad881c3b23453e87465585e009159e7 Mon Sep 17 00:00:00 2001 From: Rares Polenciuc Date: Mon, 20 Oct 2025 15:17:00 +0100 Subject: [PATCH] feat: complete sqlite store and function handler implementation - Add SQLiteExecutionStore with database persistence and indexing - Implement query system with pagination support - Add BaseExecutionStore with shared query processing logic - Update Executor to use new query system for efficient operations - Complete ListDurableExecutionsByFunctionHandler with proper filtering - Add function name validation and error handling - Add comprehensive test coverage for all implementations - Support concurrent access patterns with proper database handling --- .../checkpoint/processors/execution.py | 2 + .../execution.py | 66 +- .../executor.py | 171 ++-- .../model.py | 201 +++- .../observer.py | 20 + .../runner.py | 2 +- .../stores/base.py | 122 ++- .../stores/filesystem.py | 5 +- .../stores/memory.py | 6 +- .../stores/sqlite.py | 274 ++++++ .../web/handlers.py | 132 +-- tests/execution_test.py | 116 ++- tests/executor_test.py | 263 ++++-- tests/model_test.py | 22 +- tests/observer_test.py | 17 +- tests/stores/concurrent_test.py | 171 +++- tests/stores/filesystem_store_test.py | 152 ++++ tests/stores/memory_store_test.py | 344 +++++++ tests/stores/sqlite_store_test.py | 860 ++++++++++++++++++ tests/web/handlers_test.py | 36 +- 20 files changed, 2647 insertions(+), 335 deletions(-) create mode 100644 src/aws_durable_execution_sdk_python_testing/stores/sqlite.py create mode 100644 tests/stores/sqlite_store_test.py diff --git a/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/execution.py b/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/execution.py index b81117b..e8ad2ef 100644 --- a/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/execution.py +++ b/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/execution.py @@ -45,6 +45,8 @@ def process( "There is no 
error details but EXECUTION checkpoint action is not SUCCEED." ) ) + # All EXECUTION failures go through normal fail path + # Timeout/Stop status is set by executor based on the operation that caused it notifier.notify_failed(execution_arn=execution_arn, error=error) # TODO: Svc doesn't actually create checkpoint for EXECUTION. might have to for localrunner though. return None diff --git a/src/aws_durable_execution_sdk_python_testing/execution.py b/src/aws_durable_execution_sdk_python_testing/execution.py index 24e3f81..b651bf1 100644 --- a/src/aws_durable_execution_sdk_python_testing/execution.py +++ b/src/aws_durable_execution_sdk_python_testing/execution.py @@ -3,6 +3,7 @@ import json from dataclasses import replace from datetime import UTC, datetime +from enum import Enum from threading import Lock from typing import Any from uuid import uuid4 @@ -20,11 +21,12 @@ OperationUpdate, ) -# Import AWS exceptions from aws_durable_execution_sdk_python_testing.exceptions import ( IllegalStateException, InvalidParameterValueException, ) + +# Import AWS exceptions from aws_durable_execution_sdk_python_testing.model import ( StartDurableExecutionInput, ) @@ -34,6 +36,16 @@ ) +class ExecutionStatus(Enum): + """Execution status for API responses.""" + + RUNNING = "RUNNING" + SUCCEEDED = "SUCCEEDED" + FAILED = "FAILED" + STOPPED = "STOPPED" + TIMED_OUT = "TIMED_OUT" + + class Execution: """Execution state.""" @@ -55,12 +67,24 @@ def __init__( self.is_complete: bool = False self.result: DurableExecutionInvocationOutput | None = None self.consecutive_failed_invocation_attempts: int = 0 + self.close_status: ExecutionStatus | None = None @property def token_sequence(self) -> int: """Get current token sequence value.""" return self._token_sequence + def current_status(self) -> ExecutionStatus: + """Get execution status.""" + if not self.is_complete: + return ExecutionStatus.RUNNING + + if not self.close_status: + msg: str = "close_status must be set when execution is complete" + 
raise IllegalStateException(msg) + + return self.close_status + @staticmethod def new(input: StartDurableExecutionInput) -> Execution: # noqa: A002 # make a nicer arn @@ -82,6 +106,7 @@ def to_dict(self) -> dict[str, Any]: "IsComplete": self.is_complete, "Result": self.result.to_dict() if self.result else None, "ConsecutiveFailedInvocationAttempts": self.consecutive_failed_invocation_attempts, + "CloseStatus": self.close_status.value if self.close_status else None, } @classmethod @@ -115,6 +140,10 @@ def from_dict(cls, data: dict[str, Any]) -> Execution: execution.consecutive_failed_invocation_attempts = data[ "ConsecutiveFailedInvocationAttempts" ] + close_status_str = data.get("CloseStatus") + execution.close_status = ( + ExecutionStatus(close_status_str) if close_status_str else None + ) return execution @@ -187,16 +216,40 @@ def has_pending_operations(self, execution: Execution) -> bool: return False def complete_success(self, result: str | None) -> None: + """Complete execution successfully (DecisionType.COMPLETE_WORKFLOW_EXECUTION).""" self.result = DurableExecutionInvocationOutput( status=InvocationStatus.SUCCEEDED, result=result ) self.is_complete = True + self.close_status = ExecutionStatus.SUCCEEDED + self._end_execution(OperationStatus.SUCCEEDED) def complete_fail(self, error: ErrorObject) -> None: + """Complete execution with failure (DecisionType.FAIL_WORKFLOW_EXECUTION).""" self.result = DurableExecutionInvocationOutput( status=InvocationStatus.FAILED, error=error ) self.is_complete = True + self.close_status = ExecutionStatus.FAILED + self._end_execution(OperationStatus.FAILED) + + def complete_timeout(self, error: ErrorObject) -> None: + """Complete execution with timeout.""" + self.result = DurableExecutionInvocationOutput( + status=InvocationStatus.FAILED, error=error + ) + self.is_complete = True + self.close_status = ExecutionStatus.TIMED_OUT + self._end_execution(OperationStatus.TIMED_OUT) + + def complete_stopped(self, error: ErrorObject) -> 
None: + """Complete execution as terminated (TerminateWorkflowExecutionV2Request).""" + self.result = DurableExecutionInvocationOutput( + status=InvocationStatus.FAILED, error=error + ) + self.is_complete = True + self.close_status = ExecutionStatus.STOPPED + self._end_execution(OperationStatus.STOPPED) def find_operation(self, operation_id: str) -> tuple[int, Operation]: """Find operation by ID, return index and operation.""" @@ -327,3 +380,14 @@ def complete_callback_failure( callback_details=updated_callback_details, ) return self.operations[index] + + def _end_execution(self, status: OperationStatus) -> None: + """Set the end_timestamp on the main EXECUTION operation when execution completes.""" + execution_op: Operation = self.get_operation_execution_started() + if execution_op.operation_type == OperationType.EXECUTION: + with self._state_lock: + self.operations[0] = replace( + execution_op, + status=status, + end_timestamp=datetime.now(UTC), + ) diff --git a/src/aws_durable_execution_sdk_python_testing/executor.py b/src/aws_durable_execution_sdk_python_testing/executor.py index 518bc9e..3c398f6 100644 --- a/src/aws_durable_execution_sdk_python_testing/executor.py +++ b/src/aws_durable_execution_sdk_python_testing/executor.py @@ -28,7 +28,6 @@ ResourceNotFoundException, ) from aws_durable_execution_sdk_python_testing.execution import Execution -from aws_durable_execution_sdk_python_testing.exceptions import IllegalStateException from aws_durable_execution_sdk_python_testing.model import ( CheckpointDurableExecutionResponse, CheckpointUpdatedExecutionState, @@ -157,18 +156,7 @@ def get_execution_details(self, execution_arn: str) -> GetDurableExecutionRespon # Extract execution details from the first operation (EXECUTION type) execution_op = execution.get_operation_execution_started() - - # Determine status based on execution state - if execution.is_complete: - if ( - execution.result - and execution.result.status == InvocationStatus.SUCCEEDED - ): - status = 
"SUCCEEDED" - else: - status = "FAILED" - else: - status = "RUNNING" + status = execution.current_status().value # Extract result and error from execution result result = None @@ -204,8 +192,8 @@ def list_executions( function_version: str | None = None, # noqa: ARG002 execution_name: str | None = None, status_filter: str | None = None, - time_after: str | None = None, # noqa: ARG002 - time_before: str | None = None, # noqa: ARG002 + started_after: str | None = None, + started_before: str | None = None, marker: str | None = None, max_items: int | None = None, reverse_order: bool = False, # noqa: FBT001, FBT002 @@ -217,8 +205,8 @@ def list_executions( function_version: Filter by function version execution_name: Filter by execution name status_filter: Filter by status (RUNNING, SUCCEEDED, FAILED) - time_after: Filter executions started after this time - time_before: Filter executions started before this time + started_after: Filter executions started after this time + started_before: Filter executions started before this time marker: Pagination marker max_items: Maximum items to return (default 50) reverse_order: Return results in reverse chronological order @@ -226,77 +214,34 @@ def list_executions( Returns: ListDurableExecutionsResponse: List of executions with pagination """ - # Get all executions from store - all_executions = self._store.list_all() - - # Apply filters - filtered_executions = [] - for execution in all_executions: - # Filter by function name - if function_name and execution.start_input.function_name != function_name: - continue - - # Filter by execution name - if ( - execution_name - and execution.start_input.execution_name != execution_name - ): - continue - - # Determine execution status - execution_status = "RUNNING" - if execution.is_complete: - if ( - execution.result - and execution.result.status == InvocationStatus.SUCCEEDED - ): - execution_status = "SUCCEEDED" - else: - execution_status = "FAILED" - - # Filter by status - if status_filter 
and execution_status != status_filter: - continue - - # Convert to ExecutionSummary - execution_op = execution.get_operation_execution_started() - execution_summary = ExecutionSummary( - durable_execution_arn=execution.durable_execution_arn, - durable_execution_name=execution.start_input.execution_name, - function_arn=f"arn:aws:lambda:us-east-1:123456789012:function:{execution.start_input.function_name}", - status=execution_status, - start_timestamp=execution_op.start_timestamp - if execution_op.start_timestamp - else datetime.now(UTC), - end_timestamp=execution_op.end_timestamp - if execution_op.end_timestamp - else None, - ) - filtered_executions.append(execution_summary) - - # Sort by start date - filtered_executions.sort(key=lambda e: e.start_timestamp, reverse=reverse_order) - - # Apply pagination - if max_items is None: - max_items = 50 - - start_index = 0 + # Convert marker to offset + offset: int = 0 if marker: try: - start_index = int(marker) + offset = int(marker) except ValueError: - start_index = 0 + offset = 0 - end_index = start_index + max_items - paginated_executions = filtered_executions[start_index:end_index] + # Query store directly with parameters + executions, next_marker = self._store.query( + function_name=function_name, + execution_name=execution_name, + status_filter=status_filter, + started_after=started_after, + started_before=started_before, + limit=max_items or 50, + offset=offset, + reverse_order=reverse_order, + ) - next_marker = None - if end_index < len(filtered_executions): - next_marker = str(end_index) + # Convert to ExecutionSummary objects + execution_summaries: list[ExecutionSummary] = [ + ExecutionSummary.from_execution(execution, execution.current_status().value) + for execution in executions + ] return ListDurableExecutionsResponse( - durable_executions=paginated_executions, next_marker=next_marker + durable_executions=execution_summaries, next_marker=next_marker ) def list_executions_by_function( @@ -305,8 +250,8 @@ def 
list_executions_by_function( qualifier: str | None = None, # noqa: ARG002 execution_name: str | None = None, status_filter: str | None = None, - time_after: str | None = None, - time_before: str | None = None, + started_after: str | None = None, + started_before: str | None = None, marker: str | None = None, max_items: int | None = None, reverse_order: bool = False, # noqa: FBT001, FBT002 @@ -318,8 +263,8 @@ def list_executions_by_function( qualifier: Function qualifier/version execution_name: Filter by execution name status_filter: Filter by status (RUNNING, SUCCEEDED, FAILED) - time_after: Filter executions started after this time - time_before: Filter executions started before this time + started_after: Filter executions started after this time + started_before: Filter executions started before this time marker: Pagination marker max_items: Maximum items to return (default 50) reverse_order: Return results in reverse chronological order @@ -332,8 +277,8 @@ def list_executions_by_function( function_name=function_name, execution_name=execution_name, status_filter=status_filter, - time_after=time_after, - time_before=time_before, + started_after=started_after, + started_before=started_before, marker=marker, max_items=max_items, reverse_order=reverse_order, @@ -372,8 +317,11 @@ def stop_execution( "Execution stopped by user request" ) - # Stop the execution - self.fail_execution(execution_arn, stop_error) + # Stop sets TERMINATED close status (different from fail) + logger.exception("[%s] Stopping execution.", execution_arn) + execution.complete_stopped(error=stop_error) # Sets CloseStatus.TERMINATED + self._store.update(execution) + self._complete_events(execution_arn=execution_arn) return StopDurableExecutionResponse(stop_timestamp=datetime.now(UTC)) @@ -459,13 +407,13 @@ def get_execution_history( # Generate events all_events: list[HistoryEvent] = [] - event_id: int = 1 ops: list[Operation] = execution.operations updates: list[OperationUpdate] = execution.updates 
updates_dict: dict[str, OperationUpdate] = {u.operation_id: u for u in updates} durable_execution_arn: str = execution.durable_execution_arn + + # Generate all events first (without final event IDs) for op in ops: - # Step Operation can have PENDING status -> not included in History operation_update: OperationUpdate | None = updates_dict.get( op.operation_id, None ) @@ -478,7 +426,7 @@ def get_execution_history( continue context: EventCreationContext = EventCreationContext( op, - event_id, + 0, # Temporary event_id, will be reassigned after sorting durable_execution_arn, execution.start_input, execution.result, @@ -487,11 +435,10 @@ def get_execution_history( ) pending = HistoryEvent.create_chained_invoke_event_pending(context) all_events.append(pending) - event_id += 1 if op.start_timestamp is not None: context = EventCreationContext( op, - event_id, + 0, # Temporary event_id, will be reassigned after sorting durable_execution_arn, execution.start_input, execution.result, @@ -500,11 +447,10 @@ def get_execution_history( ) started = HistoryEvent.create_event_started(context) all_events.append(started) - event_id += 1 if op.end_timestamp is not None and op.status in TERMINAL_STATUSES: context = EventCreationContext( op, - event_id, + 0, # Temporary event_id, will be reassigned after sorting durable_execution_arn, execution.start_input, execution.result, @@ -513,7 +459,15 @@ def get_execution_history( ) finished = HistoryEvent.create_event_terminated(context) all_events.append(finished) - event_id += 1 + + # Sort events by timestamp to get correct chronological order + all_events.sort(key=lambda event: event.event_timestamp) + + # Reassign event IDs based on chronological order + all_events = [ + HistoryEvent.from_event_with_id(event, i) + for i, event in enumerate(all_events, 1) + ] # Apply cursor-based pagination if max_items is None: @@ -938,27 +892,25 @@ def wait_until_complete( raise ResourceNotFoundException(msg) def complete_execution(self, execution_arn: str, 
result: str | None = None) -> None: - """Complete execution successfully.""" + """Complete execution successfully (COMPLETE_WORKFLOW_EXECUTION decision).""" logger.debug("[%s] Completing execution with result: %s", execution_arn, result) execution: Execution = self._store.load(execution_arn=execution_arn) - execution.complete_success(result=result) + execution.complete_success(result=result) # Sets CloseStatus.COMPLETED self._store.update(execution) if execution.result is None: msg: str = "Execution result is required" - raise IllegalStateException(msg) self._complete_events(execution_arn=execution_arn) def fail_execution(self, execution_arn: str, error: ErrorObject) -> None: - """Fail execution with error.""" + """Fail execution with error (FAIL_WORKFLOW_EXECUTION decision).""" logger.error("[%s] Completing execution with error: %s", execution_arn, error) execution: Execution = self._store.load(execution_arn=execution_arn) - execution.complete_fail(error=error) + execution.complete_fail(error=error) # Sets CloseStatus.FAILED self._store.update(execution) # set by complete_fail if execution.result is None: msg: str = "Execution result is required" - raise IllegalStateException(msg) self._complete_events(execution_arn=execution_arn) @@ -1010,6 +962,19 @@ def on_failed(self, execution_arn: str, error: ErrorObject) -> None: """Fail execution. Observer method triggered by notifier.""" self.fail_execution(execution_arn, error) + def on_timed_out(self, execution_arn: str, error: ErrorObject) -> None: + """Handle execution timeout (workflow timeout). 
Observer method triggered by notifier.""" + logger.exception("[%s] Execution timed out.", execution_arn) + execution: Execution = self._store.load(execution_arn=execution_arn) + execution.complete_timeout(error=error) # Sets CloseStatus.TIMED_OUT + self._store.update(execution) + self._complete_events(execution_arn=execution_arn) + + def on_stopped(self, execution_arn: str, error: ErrorObject) -> None: + """Handle execution stop. Observer method triggered by notifier.""" + # This should not be called directly - stop_execution handles termination + self.fail_execution(execution_arn, error) + def on_wait_timer_scheduled( self, execution_arn: str, operation_id: str, delay: float ) -> None: diff --git a/src/aws_durable_execution_sdk_python_testing/model.py b/src/aws_durable_execution_sdk_python_testing/model.py index 469a0d3..27da2d8 100644 --- a/src/aws_durable_execution_sdk_python_testing/model.py +++ b/src/aws_durable_execution_sdk_python_testing/model.py @@ -297,6 +297,24 @@ def to_dict(self) -> dict[str, Any]: result["EndTimestamp"] = self.end_timestamp return result + @classmethod + def from_execution(cls, execution, status: str) -> Execution: + """Create ExecutionSummary from Execution object.""" + + execution_op = execution.get_operation_execution_started() + return cls( + durable_execution_arn=execution.durable_execution_arn, + durable_execution_name=execution.start_input.execution_name, + function_arn=f"arn:aws:lambda:us-east-1:123456789012:function:{execution.start_input.function_name}", + status=status, + start_timestamp=execution_op.start_timestamp + if execution_op.start_timestamp + else datetime.datetime.now(datetime.UTC), + end_timestamp=execution_op.end_timestamp + if execution_op.end_timestamp + else None, + ) + @dataclass(frozen=True) class ListDurableExecutionsRequest: @@ -306,24 +324,71 @@ class ListDurableExecutionsRequest: function_version: str | None = None durable_execution_name: str | None = None status_filter: list[str] | None = None - 
time_after: str | None = None - time_before: str | None = None + started_after: str | None = None + started_before: str | None = None marker: str | None = None max_items: int = 0 reverse_order: bool | None = None @classmethod def from_dict(cls, data: dict) -> ListDurableExecutionsRequest: + # Handle query parameters that may be lists + function_name = data.get("FunctionName") + if isinstance(function_name, list): + function_name = function_name[0] if function_name else None + + function_version = data.get("FunctionVersion") + if isinstance(function_version, list): + function_version = function_version[0] if function_version else None + + durable_execution_name = data.get("DurableExecutionName") + if isinstance(durable_execution_name, list): + durable_execution_name = ( + durable_execution_name[0] if durable_execution_name else None + ) + + status_filter = data.get("StatusFilter") + if isinstance(status_filter, list): + status_filter = status_filter if status_filter else None + elif status_filter: + status_filter = [status_filter] + + started_after = data.get("StartedAfter") + if isinstance(started_after, list): + started_after = started_after[0] if started_after else None + + started_before = data.get("StartedBefore") + if isinstance(started_before, list): + started_before = started_before[0] if started_before else None + + marker = data.get("Marker") + if isinstance(marker, list): + marker = marker[0] if marker else None + + max_items = data.get("MaxItems", 0) + if isinstance(max_items, list): + max_items = int(max_items[0]) if max_items else 0 + + reverse_order = data.get("ReverseOrder") + if isinstance(reverse_order, list): + reverse_order = ( + reverse_order[0].lower() in ("true", "1", "yes") + if reverse_order + else None + ) + elif isinstance(reverse_order, str): + reverse_order = reverse_order.lower() in ("true", "1", "yes") + return cls( - function_name=data.get("FunctionName"), - function_version=data.get("FunctionVersion"), - 
durable_execution_name=data.get("DurableExecutionName"), - status_filter=data.get("StatusFilter"), - time_after=data.get("TimeAfter"), - time_before=data.get("TimeBefore"), - marker=data.get("Marker"), - max_items=data.get("MaxItems", 0), - reverse_order=data.get("ReverseOrder"), + function_name=function_name, + function_version=function_version, + durable_execution_name=durable_execution_name, + status_filter=status_filter, + started_after=started_after, + started_before=started_before, + marker=marker, + max_items=max_items, + reverse_order=reverse_order, ) def to_dict(self) -> dict[str, Any]: @@ -336,10 +401,10 @@ def to_dict(self) -> dict[str, Any]: result["DurableExecutionName"] = self.durable_execution_name if self.status_filter is not None: result["StatusFilter"] = self.status_filter - if self.time_after is not None: - result["TimeAfter"] = self.time_after - if self.time_before is not None: - result["TimeBefore"] = self.time_before + if self.started_after is not None: + result["StartedAfter"] = self.started_after + if self.started_before is not None: + result["StartedBefore"] = self.started_before if self.marker is not None: result["Marker"] = self.marker if self.max_items is not None: @@ -2144,6 +2209,43 @@ def create_event_started(cls, context: EventCreationContext) -> Event: msg = f"Unknown operation type: {context.operation.operation_type}" raise InvalidParameterValueException(msg) + @classmethod + def from_event_with_id(cls, event: Event, event_id: int) -> Event: + """Create a new Event from an existing event with updated event_id.""" + return cls( + event_type=event.event_type, + event_timestamp=event.event_timestamp, + sub_type=event.sub_type, + event_id=event_id, + operation_id=event.operation_id, + name=event.name, + parent_id=event.parent_id, + execution_started_details=event.execution_started_details, + execution_succeeded_details=event.execution_succeeded_details, + execution_failed_details=event.execution_failed_details, + 
execution_timed_out_details=event.execution_timed_out_details, + execution_stopped_details=event.execution_stopped_details, + context_started_details=event.context_started_details, + context_succeeded_details=event.context_succeeded_details, + context_failed_details=event.context_failed_details, + wait_started_details=event.wait_started_details, + wait_succeeded_details=event.wait_succeeded_details, + wait_cancelled_details=event.wait_cancelled_details, + step_started_details=event.step_started_details, + step_succeeded_details=event.step_succeeded_details, + step_failed_details=event.step_failed_details, + chained_invoke_pending_details=event.chained_invoke_pending_details, + chained_invoke_started_details=event.chained_invoke_started_details, + chained_invoke_succeeded_details=event.chained_invoke_succeeded_details, + chained_invoke_failed_details=event.chained_invoke_failed_details, + chained_invoke_timed_out_details=event.chained_invoke_timed_out_details, + chained_invoke_stopped_details=event.chained_invoke_stopped_details, + callback_started_details=event.callback_started_details, + callback_succeeded_details=event.callback_succeeded_details, + callback_failed_details=event.callback_failed_details, + callback_timed_out_details=event.callback_timed_out_details, + ) + @classmethod def create_event_terminated(cls, context: EventCreationContext) -> Event: """Convert operation to finished event.""" @@ -2696,16 +2798,67 @@ class ListDurableExecutionsByFunctionRequest: @classmethod def from_dict(cls, data: dict) -> ListDurableExecutionsByFunctionRequest: + # Handle query parameters that may be lists + function_name = data.get("FunctionName") + if isinstance(function_name, list): + function_name = function_name[0] if function_name else "" + elif not function_name: + function_name = "" + + qualifier = data.get("Qualifier") or data.get("functionVersion") + if isinstance(qualifier, list): + qualifier = qualifier[0] if qualifier else None + + durable_execution_name = 
data.get("DurableExecutionName") or data.get( + "executionName" + ) + if isinstance(durable_execution_name, list): + durable_execution_name = ( + durable_execution_name[0] if durable_execution_name else None + ) + + status_filter = data.get("StatusFilter") or data.get("statusFilter") + if isinstance(status_filter, list): + status_filter = status_filter if status_filter else None + elif status_filter: + status_filter = [status_filter] + + started_after = data.get("StartedAfter") or data.get("startedAfter") + if isinstance(started_after, list): + started_after = started_after[0] if started_after else None + + started_before = data.get("StartedBefore") or data.get("startedBefore") + if isinstance(started_before, list): + started_before = started_before[0] if started_before else None + + marker = data.get("Marker") or data.get("marker") + if isinstance(marker, list): + marker = marker[0] if marker else None + + max_items = data.get("MaxItems") or data.get("maxItems", 0) + if isinstance(max_items, list): + max_items = int(max_items[0]) if max_items else 0 + + reverse_order = data.get("ReverseOrder") or data.get("reverseOrder") + if isinstance(reverse_order, list): + reverse_order = ( + reverse_order[0].lower() in ("true", "1", "yes") + if reverse_order + else None + ) + elif isinstance(reverse_order, str): + reverse_order = reverse_order.lower() in ("true", "1", "yes") + return cls( - function_name=data["FunctionName"], - qualifier=data.get("Qualifier"), - durable_execution_name=data.get("DurableExecutionName"), - status_filter=data.get("StatusFilter"), - started_after=data.get("StartedAfter"), - started_before=data.get("StartedBefore"), - marker=data.get("Marker"), - max_items=data.get("MaxItems", 0), - reverse_order=data.get("ReverseOrder"), + function_name=function_name, + qualifier=qualifier, + durable_execution_name=durable_execution_name, + status_filter=status_filter, + started_after=started_after, + started_before=started_before, + marker=marker, + 
max_items=max_items, + reverse_order=reverse_order, ) def to_dict(self) -> dict[str, Any]: diff --git a/src/aws_durable_execution_sdk_python_testing/observer.py b/src/aws_durable_execution_sdk_python_testing/observer.py index 2473896..3a21fd5 100644 --- a/src/aws_durable_execution_sdk_python_testing/observer.py +++ b/src/aws_durable_execution_sdk_python_testing/observer.py @@ -25,6 +25,14 @@ def on_completed(self, execution_arn: str, result: str | None = None) -> None: def on_failed(self, execution_arn: str, error: ErrorObject) -> None: """Called when execution fails.""" + @abstractmethod + def on_timed_out(self, execution_arn: str, error: ErrorObject) -> None: + """Called when execution times out.""" + + @abstractmethod + def on_stopped(self, execution_arn: str, error: ErrorObject) -> None: + """Called when execution is stopped.""" + @abstractmethod def on_wait_timer_scheduled( self, execution_arn: str, operation_id: str, delay: float @@ -76,6 +84,18 @@ def notify_failed(self, execution_arn: str, error: ErrorObject) -> None: ExecutionObserver.on_failed, execution_arn=execution_arn, error=error ) + def notify_timed_out(self, execution_arn: str, error: ErrorObject) -> None: + """Notify observers about execution timeout.""" + self._notify_observers( + ExecutionObserver.on_timed_out, execution_arn=execution_arn, error=error + ) + + def notify_stopped(self, execution_arn: str, error: ErrorObject) -> None: + """Notify observers about execution being stopped.""" + self._notify_observers( + ExecutionObserver.on_stopped, execution_arn=execution_arn, error=error + ) + def notify_wait_timer_scheduled( self, execution_arn: str, operation_id: str, delay: float ) -> None: diff --git a/src/aws_durable_execution_sdk_python_testing/runner.py b/src/aws_durable_execution_sdk_python_testing/runner.py index 866f50d..37a48a9 100644 --- a/src/aws_durable_execution_sdk_python_testing/runner.py +++ b/src/aws_durable_execution_sdk_python_testing/runner.py @@ -742,7 +742,7 @@ class 
DurableFunctionCloudTestRunner: ... ) >>> with runner: ... result = runner.run(input={"name": "World"}, timeout=60) - >>> assert result.status == InvocationStatus.SUCCEEDED + >>> assert result.current_status == InvocationStatus.SUCCEEDED """ def __init__( diff --git a/src/aws_durable_execution_sdk_python_testing/stores/base.py b/src/aws_durable_execution_sdk_python_testing/stores/base.py index f4943e9..ca87e28 100644 --- a/src/aws_durable_execution_sdk_python_testing/stores/base.py +++ b/src/aws_durable_execution_sdk_python_testing/stores/base.py @@ -2,11 +2,14 @@ from __future__ import annotations +from datetime import UTC from enum import Enum from typing import TYPE_CHECKING, Protocol if TYPE_CHECKING: + from aws_durable_execution_sdk_python.lambda_service import Operation + from aws_durable_execution_sdk_python_testing.execution import Execution @@ -15,6 +18,7 @@ class StoreType(Enum): MEMORY = "memory" FILESYSTEM = "filesystem" + SQLITE = "sqlite" class ExecutionStore(Protocol): @@ -24,4 +28,120 @@ class ExecutionStore(Protocol): def save(self, execution: Execution) -> None: ... # pragma: no cover def load(self, execution_arn: str) -> Execution: ... # pragma: no cover def update(self, execution: Execution) -> None: ... # pragma: no cover - def list_all(self) -> list[Execution]: ... # pragma: no cover + def query( + self, + function_name: str | None = None, + execution_name: str | None = None, + status_filter: str | None = None, + started_after: str | None = None, + started_before: str | None = None, + limit: int | None = None, + offset: int = 0, + reverse_order: bool = False, # noqa: FBT001, FBT002 + ) -> tuple[list[Execution], str | None]: ... # pragma: no cover + def list_all( + self, + ) -> list[Execution]: ... 
# pragma: no cover # Keep for backward compatibility + + +class BaseExecutionStore(ExecutionStore): + """Base implementation for execution stores with shared query logic.""" + + @staticmethod + def process_query( + executions: list[Execution], + function_name: str | None = None, + execution_name: str | None = None, + status_filter: str | None = None, + started_after: str | None = None, + started_before: str | None = None, + limit: int | None = None, + offset: int = 0, + reverse_order: bool = False, # noqa: FBT001, FBT002 + ) -> tuple[list[Execution], str | None]: + """Apply filtering, sorting, and pagination to executions.""" + # Apply filters + filtered: list[Execution] = [] + for execution in executions: + if function_name and execution.start_input.function_name != function_name: + continue + if ( + execution_name + and execution.start_input.execution_name != execution_name + ): + continue + + # Status filtering + if status_filter and execution.current_status().value != status_filter: + continue + + # Time filtering + if started_after or started_before: + try: + operation: Operation = execution.get_operation_execution_started() + if operation.start_timestamp: + timestamp: float = ( + operation.start_timestamp.timestamp() + if hasattr(operation.start_timestamp, "timestamp") + else operation.start_timestamp.replace( + tzinfo=UTC + ).timestamp() + ) + if started_after and timestamp < float(started_after): + continue + if started_before and timestamp > float(started_before): + continue + except (ValueError, AttributeError): + continue + + filtered.append(execution) + + # Sort by start timestamp + def get_sort_key(exe: Execution): + try: + op: Operation = exe.get_operation_execution_started() + if op.start_timestamp: + return ( + op.start_timestamp.timestamp() + if hasattr(op.start_timestamp, "timestamp") + else op.start_timestamp.replace(tzinfo=UTC).timestamp() + ) + except Exception: # noqa: BLE001, S110 + pass + return 0 + + filtered.sort(key=get_sort_key, 
reverse=reverse_order) + + # Apply pagination + if limit is not None and limit > 0: + end_idx: int = offset + limit + paginated: list[Execution] = filtered[offset:end_idx] + has_more: bool = end_idx < len(filtered) + next_marker: str | None = str(end_idx) if has_more else None + return paginated, next_marker + return filtered[offset:], None + + def query( + self, + function_name: str | None = None, + execution_name: str | None = None, + status_filter: str | None = None, + started_after: str | None = None, + started_before: str | None = None, + limit: int | None = None, + offset: int = 0, + reverse_order: bool = False, # noqa: FBT001, FBT002 + ) -> tuple[list[Execution], str | None]: + """Apply filtering, sorting, and pagination to executions.""" + executions: list[Execution] = self.list_all() + return self.process_query( + executions, + function_name=function_name, + execution_name=execution_name, + status_filter=status_filter, + started_after=started_after, + started_before=started_before, + limit=limit, + offset=offset, + reverse_order=reverse_order, + ) diff --git a/src/aws_durable_execution_sdk_python_testing/stores/filesystem.py b/src/aws_durable_execution_sdk_python_testing/stores/filesystem.py index 6ccd4b1..9306532 100644 --- a/src/aws_durable_execution_sdk_python_testing/stores/filesystem.py +++ b/src/aws_durable_execution_sdk_python_testing/stores/filesystem.py @@ -11,6 +11,9 @@ ResourceNotFoundException, ) from aws_durable_execution_sdk_python_testing.execution import Execution +from aws_durable_execution_sdk_python_testing.stores.base import ( + BaseExecutionStore, +) class DateTimeEncoder(json.JSONEncoder): @@ -37,7 +40,7 @@ def datetime_object_hook(obj): return obj -class FileSystemExecutionStore: +class FileSystemExecutionStore(BaseExecutionStore): """File system-based execution store for persistence.""" def __init__(self, storage_dir: Path) -> None: diff --git a/src/aws_durable_execution_sdk_python_testing/stores/memory.py 
"""SQLite-based execution store implementation."""

from __future__ import annotations

import json
import logging
import sqlite3
from contextlib import closing
from datetime import datetime
from pathlib import Path
from typing import Any, cast

from aws_durable_execution_sdk_python_testing.exceptions import (
    InvalidParameterValueException,
    ResourceNotFoundException,
    RuntimeException,
)
from aws_durable_execution_sdk_python_testing.execution import Execution
from aws_durable_execution_sdk_python_testing.stores.base import (
    ExecutionStore,
)
from aws_durable_execution_sdk_python_testing.stores.filesystem import DateTimeEncoder

logger = logging.getLogger(__name__)


class SQLiteExecutionStore(ExecutionStore):
    """SQLite-based execution store for efficient querying.

    Rows carry denormalized filter columns (function name, status,
    timestamps) next to the full JSON-serialized execution so queries can be
    answered with SQL instead of deserializing every record.
    """

    def __init__(self, db_path: Path) -> None:
        """Remember the database path; use create_and_initialize for a ready store."""
        self.db_path: Path = db_path

    @classmethod
    def create_and_initialize(
        cls, db_path: Path | str | None = None
    ) -> SQLiteExecutionStore:
        """Create a SQLite store and initialize its schema.

        Args:
            db_path: Database file path; defaults to ./durable-executions.db.
        """
        path: Path = Path(db_path) if db_path else Path("durable-executions.db")
        # parents=True: the target directory may be nested and not exist yet.
        path.parent.mkdir(parents=True, exist_ok=True)
        store: SQLiteExecutionStore = cls(path)
        store._init_db()
        return store

    def _get_connection(self) -> sqlite3.Connection:
        """Open a new connection with WAL journaling for better concurrency."""
        conn: sqlite3.Connection = sqlite3.connect(self.db_path, timeout=30.0)
        conn.execute("PRAGMA journal_mode=WAL;")
        conn.execute("PRAGMA synchronous=NORMAL;")
        return conn

    def _init_db(self) -> None:
        """Initialize database schema and query indexes.

        Raises:
            RuntimeError: If schema creation fails.
        """
        try:
            # closing() is required: sqlite3's connection context manager only
            # commits/rolls back the transaction, it never closes the
            # connection, which would leak handles (and WAL side files).
            with closing(self._get_connection()) as conn, conn:
                conn.execute("""
                    CREATE TABLE IF NOT EXISTS executions (
                        durable_execution_arn TEXT PRIMARY KEY,
                        function_name TEXT NOT NULL,
                        execution_name TEXT,
                        status TEXT NOT NULL,
                        start_timestamp REAL,
                        end_timestamp REAL,
                        data TEXT NOT NULL
                    )
                """)
                # Create indexes for better query performance
                conn.execute(
                    "CREATE INDEX IF NOT EXISTS idx_function_name ON executions(function_name)"
                )
                conn.execute(
                    "CREATE INDEX IF NOT EXISTS idx_status ON executions(status)"
                )
                conn.execute(
                    "CREATE INDEX IF NOT EXISTS idx_start_timestamp ON executions(start_timestamp)"
                )
                conn.execute(
                    "CREATE INDEX IF NOT EXISTS idx_composite ON executions(function_name, status, start_timestamp)"
                )
        except sqlite3.Error as e:
            raise RuntimeError(f"Failed to initialize database: {e}") from e

    def save(self, execution: Execution) -> None:
        """Insert or replace the execution row (filter columns + JSON blob).

        Raises:
            RuntimeError: On SQLite failure.
            ValueError: If the execution cannot be serialized.
        """
        # NOTE(review): save/load raise builtin RuntimeError while query raises
        # the project RuntimeException — confirm which one callers expect.
        try:
            execution_op = execution.get_operation_execution_started()
            status: str = execution.current_status().value

            with closing(self._get_connection()) as conn, conn:
                conn.execute(
                    """
                    INSERT OR REPLACE INTO executions
                    (durable_execution_arn, function_name, execution_name, status, start_timestamp, end_timestamp, data)
                    VALUES (?, ?, ?, ?, ?, ?, ?)
                    """,
                    (
                        execution.durable_execution_arn,
                        execution.start_input.function_name,
                        execution.start_input.execution_name,
                        status,
                        execution_op.start_timestamp.timestamp()
                        if execution_op.start_timestamp
                        else None,
                        execution_op.end_timestamp.timestamp()
                        if execution_op.end_timestamp
                        else None,
                        json.dumps(execution.to_dict(), cls=DateTimeEncoder),
                    ),
                )
        except sqlite3.Error as e:
            raise RuntimeError(
                f"Failed to save execution {execution.durable_execution_arn}: {e}"
            ) from e
        except (AttributeError, TypeError) as e:
            raise ValueError(f"Invalid execution data: {e}") from e

    def load(self, execution_arn: str) -> Execution:
        """Load and deserialize one execution.

        Raises:
            ResourceNotFoundException: If no row matches the ARN.
            RuntimeError: On SQLite failure.
            ValueError: If the stored JSON is corrupted.
        """
        try:
            with closing(self._get_connection()) as conn, conn:
                cursor: sqlite3.Cursor = conn.execute(
                    "SELECT data FROM executions WHERE durable_execution_arn = ?",
                    (execution_arn,),
                )
                row: tuple[str] | None = cursor.fetchone()

                if not row:
                    raise ResourceNotFoundException(
                        f"Execution {execution_arn} not found"
                    )

                return Execution.from_dict(json.loads(row[0]))
        except sqlite3.Error as e:
            raise RuntimeError(f"Failed to load execution {execution_arn}: {e}") from e
        except json.JSONDecodeError as e:
            raise ValueError(
                f"Corrupted execution data for {execution_arn}: {e}"
            ) from e

    def update(self, execution: Execution) -> None:
        """Update execution (same as save)."""
        self.save(execution)

    def query(
        self,
        function_name: str | None = None,
        execution_name: str | None = None,
        status_filter: str | None = None,
        started_after: str | None = None,
        started_before: str | None = None,
        limit: int | None = None,
        offset: int = 0,
        reverse_order: bool = False,  # noqa: FBT001, FBT002
    ) -> tuple[list[Execution], str | None]:
        """Query executions with efficient SQL filtering.

        Args:
            function_name: Filter by function name.
            execution_name: Filter by execution name.
            status_filter: Filter by current status value.
            started_after: ISO-8601 inclusive lower bound on start time.
            started_before: ISO-8601 inclusive upper bound on start time.
            limit: Page size; None or <= 0 returns everything after offset.
            offset: Rows to skip.
            reverse_order: Newest-first when True.

        Returns:
            The page of executions and the next-page marker or None.

        Raises:
            RuntimeException: On SQLite failure.
            InvalidParameterValueException: On unparseable time bounds.
        """
        try:
            # Build query safely with parameterized conditions.
            conditions: list[str] = []
            params: list[str | float | int] = []

            if function_name:
                conditions.append("function_name = ?")
                params.append(function_name)

            if execution_name:
                conditions.append("execution_name = ?")
                params.append(execution_name)

            if status_filter:
                conditions.append("status = ?")
                params.append(status_filter)

            if started_after:
                started_after_float: float = datetime.fromisoformat(
                    started_after
                ).timestamp()
                conditions.append("start_timestamp >= ?")
                params.append(started_after_float)

            if started_before:
                started_before_float: float = datetime.fromisoformat(
                    started_before
                ).timestamp()
                conditions.append("start_timestamp <= ?")
                params.append(started_before_float)

            where_clause: str = ""
            if conditions:
                where_clause = "WHERE " + " AND ".join(conditions)

            order_direction: str = "DESC" if reverse_order else "ASC"
            order_clause: str = f"ORDER BY start_timestamp {order_direction}"

            base_query: str = f"FROM executions {where_clause}"
            count_query: str = f"SELECT COUNT(*) {base_query}"

            limit_exists: bool = limit is not None and limit > 0

            # Only fetch the rows the page actually needs.
            if limit_exists:
                data_query: str = f"SELECT durable_execution_arn, data {base_query} {order_clause} LIMIT ? OFFSET ?"
                params_with_limit: list[str | float | int] = params + [
                    cast(int, limit),
                    offset,
                ]
            else:
                data_query = (
                    f"SELECT durable_execution_arn, data {base_query} {order_clause}"
                )
                params_with_limit = params

            with closing(self._get_connection()) as conn, conn:
                # Total count drives the has-more decision below.
                total_count: int = int(conn.execute(count_query, params).fetchone()[0])

                cursor: sqlite3.Cursor = conn.execute(data_query, params_with_limit)
                rows: list[tuple[str, str]] = cursor.fetchall()

                # Deserialize only the fetched page; skip (but log) corrupted rows.
                executions: list[Execution] = []
                for durable_execution_arn, data in rows:
                    try:
                        executions.append(Execution.from_dict(json.loads(data)))
                    except (json.JSONDecodeError, ValueError) as e:
                        logger.warning(
                            "Skipping corrupted execution %s: %s",
                            durable_execution_arn,
                            e,
                        )
                        continue

                # Advance the marker by rows *consumed* from the result set,
                # not by rows successfully deserialized — otherwise a skipped
                # corrupted row would make the next page re-read (or silently
                # drop) records.
                consumed: int = offset + len(rows)
                has_more: bool = limit_exists and consumed < total_count
                next_marker: str | None = str(consumed) if has_more else None

                return executions, next_marker

        except sqlite3.Error as e:
            raise RuntimeException(f"Query failed: {e}") from e
        except ValueError as e:
            raise InvalidParameterValueException(
                f"Invalid query parameters: {e}"
            ) from e

    def list_all(self) -> list[Execution]:
        """List all executions (for backward compatibility)."""
        executions, _ = self.query()
        return executions

    def get_execution_metadata(self, execution_arn: str) -> dict[str, Any] | None:
        """Get just the metadata without full deserialization for performance.

        Returns:
            Metadata dict, or None when the execution does not exist.

        Raises:
            RuntimeError: On SQLite failure.
        """
        try:
            with closing(self._get_connection()) as conn, conn:
                cursor: sqlite3.Cursor = conn.execute(
                    "SELECT function_name, execution_name, status, start_timestamp, end_timestamp FROM executions WHERE durable_execution_arn = ?",
                    (execution_arn,),
                )
                row: tuple[str, str | None, str, float | None, float | None] | None = (
                    cursor.fetchone()
                )

                if not row:
                    return None

                return {
                    "durable_execution_arn": execution_arn,
                    "function_name": row[0],
                    "execution_name": row[1],
                    "status": row[2],
                    "start_timestamp": row[3],
                    "end_timestamp": row[4],
                }
        except sqlite3.Error as e:
            raise RuntimeError(
                f"Failed to get metadata for {execution_arn}: {e}"
            ) from e
Also, it is a GET, to confirm AWS SDK does - # pass args in querystring rather than body (per spec body should be ignored) - if function_name := self._parse_query_param(request, "FunctionName"): - query_params["FunctionName"] = function_name - if function_version := self._parse_query_param(request, "FunctionVersion"): - query_params["FunctionVersion"] = function_version - if durable_execution_name := self._parse_query_param( - request, "DurableExecutionName" - ): - query_params["DurableExecutionName"] = durable_execution_name - if status_filter := self._parse_query_param(request, "StatusFilter"): - query_params["StatusFilter"] = [ - status_filter - ] # Convert to list for model - if time_after := self._parse_query_param(request, "TimeAfter"): - query_params["TimeAfter"] = time_after - if time_before := self._parse_query_param(request, "TimeBefore"): - query_params["TimeBefore"] = time_before - if marker := self._parse_query_param(request, "Marker"): - query_params["Marker"] = marker - - # Parse integer parameters - if max_items_str := self._parse_query_param(request, "MaxItems"): - try: - query_params["MaxItems"] = int(max_items_str) - except ValueError as e: - error_msg: str = f"Invalid MaxItems value: {max_items_str}" - raise InvalidParameterValueException(error_msg) from e - - # Parse boolean parameters - if reverse_order_str := self._parse_query_param(request, "ReverseOrder"): - query_params["ReverseOrder"] = reverse_order_str.lower() in ( - "true", - "1", - "yes", - ) - - # Create request object from query parameters list_request: ListDurableExecutionsRequest = ( - ListDurableExecutionsRequest.from_dict(query_params) + ListDurableExecutionsRequest.from_dict(request.query_params) ) # Call executor method with correct attribute mapping @@ -547,8 +506,8 @@ def handle(self, parsed_route: Route, request: HTTPRequest) -> HTTPResponse: # status_filter=list_request.status_filter[0] if list_request.status_filter else None, # Executor expects single string - 
class ListDurableExecutionsByFunctionHandler(EndpointHandler):
    """Handler for GET /2025-12-01/functions/{function_name}/durable-executions."""

    @staticmethod
    def _validate_function_name(function_name: str) -> None:
        """Validate function name parameter.

        Raises:
            InvalidParameterValueException: If the name is empty or whitespace.
        """
        if not function_name or not function_name.strip():
            msg = "Function name is required"
            raise InvalidParameterValueException(msg)

    def handle(self, parsed_route: Route, request: HTTPRequest) -> HTTPResponse:
        """Handle list durable executions by function request.

        Args:
            parsed_route: The parsed route containing the function name.
            request: The HTTP request.

        Returns:
            HTTPResponse: The HTTP response to send to the client.
        """
        function_route = cast(ListDurableExecutionsByFunctionRoute, parsed_route)
        function_name: str = function_route.function_name

        try:
            # Validate inside the try so the InvalidParameterValueException is
            # converted into an error response by _handle_aws_exception like
            # every other AwsApiException here, instead of escaping the
            # handler as a raw exception.
            self._validate_function_name(function_name)

            # Merge the path parameter into the query parameters; from_dict
            # expects query-string style list values.
            query_params = dict(request.query_params)
            query_params["FunctionName"] = [function_name]
            list_request = ListDurableExecutionsByFunctionRequest.from_dict(
                query_params
            )

            list_response = self.executor.list_executions_by_function(
                function_name=list_request.function_name,
                qualifier=list_request.qualifier,
                execution_name=list_request.durable_execution_name,
                # Executor expects a single status string, not a list.
                status_filter=list_request.status_filter[0]
                if list_request.status_filter
                else None,
                started_after=list_request.started_after,
                started_before=list_request.started_before,
                marker=list_request.marker,
                max_items=list_request.max_items
                if list_request.max_items > 0
                else None,
                reverse_order=list_request.reverse_order or False,
            )

            return self._success_response(list_response.to_dict())
        except AwsApiException as e:
            return self._handle_aws_exception(e)
def _make_status_execution(operations):
    """Build an Execution with the standard test fixture inputs."""
    fixture_input = StartDurableExecutionInput(
        account_id="123456789012",
        function_name="test-function",
        function_qualifier="$LATEST",
        execution_name="test-execution",
        execution_timeout_seconds=300,
        execution_retention_period_days=7,
    )
    return Execution("test-arn", fixture_input, operations)


def test_status_succeeded():
    """current_status() reports SUCCEEDED after complete_success."""
    execution = _make_status_execution([Mock()])
    execution.complete_success("success result")

    assert execution.current_status().value == "SUCCEEDED"


def test_status_failed():
    """current_status() reports FAILED after complete_fail."""
    execution = _make_status_execution([Mock()])
    execution.complete_fail(ErrorObject.from_message("Test error"))

    assert execution.current_status().value == "FAILED"


def test_status_timed_out():
    """current_status() reports TIMED_OUT after complete_timeout."""
    execution = _make_status_execution([Mock()])
    timeout_error = ErrorObject(
        message="Execution timed out", type="TimeoutError", data=None, stack_trace=None
    )
    execution.complete_timeout(timeout_error)

    assert execution.current_status().value == "TIMED_OUT"


def test_status_stopped():
    """current_status() reports STOPPED after complete_stopped."""
    execution = _make_status_execution([Mock()])
    stop_error = ErrorObject(
        message="Execution stopped", type="StopError", data=None, stack_trace=None
    )
    execution.complete_stopped(stop_error)

    assert execution.current_status().value == "STOPPED"


def test_status_no_result():
    """current_status() raises when a completed execution lacks close_status."""
    execution = _make_status_execution([])
    execution.is_complete = True
    execution.result = None

    with pytest.raises(
        IllegalStateException,
        match="close_status must be set when execution is complete",
    ):
        execution.current_status()
StartDurableExecutionInput, ) -from aws_durable_execution_sdk_python_testing.observer import ExecutionObserver +from aws_durable_execution_sdk_python_testing.observer import ( + ExecutionNotifier, + ExecutionObserver, +) from aws_durable_execution_sdk_python_testing.token import ( CallbackToken, ) @@ -1727,20 +1733,24 @@ def test_retry_handler_execution(executor, mock_scheduler): def test_get_execution_details(executor, mock_store): """Test get_execution_details method.""" - # Create mock execution with operation - mock_execution = Mock() - mock_execution.durable_execution_arn = "test-arn" - mock_execution.start_input.execution_name = "test-execution" - mock_execution.start_input.function_name = "test-function" - mock_execution.is_complete = True + # Create real execution instance with mocked start_input + mock_start_input = Mock() + mock_start_input.execution_name = "test-execution" + mock_start_input.function_name = "test-function" + + execution = Execution( + durable_execution_arn="test-arn", start_input=mock_start_input, operations=[] + ) + execution.is_complete = True # Create mock result mock_result = DurableExecutionInvocationOutput( status=InvocationStatus.SUCCEEDED, result="test-result" ) - mock_execution.result = mock_result + execution.result = mock_result + execution.close_status = ExecutionStatus.SUCCEEDED - # Create mock operation + # Create mock operation and add to execution mock_operation = Operation( operation_id="op-1", parent_id=None, @@ -1751,9 +1761,9 @@ def test_get_execution_details(executor, mock_store): status=OperationStatus.SUCCEEDED, execution_details=ExecutionDetails(input_payload='{"test": "data"}'), ) - mock_execution.get_operation_execution_started.return_value = mock_operation + execution.operations = [mock_operation] - mock_store.load.return_value = mock_execution + mock_store.load.return_value = execution result = executor.get_execution_details("test-arn") @@ -1776,20 +1786,23 @@ def test_get_execution_details_not_found(executor, 
mock_store): def test_get_execution_details_failed_execution(executor, mock_store): """Test get_execution_details with failed execution.""" - # Create mock execution with failed result - mock_execution = Mock() - mock_execution.durable_execution_arn = "test-arn" - mock_execution.start_input.execution_name = "test-execution" - mock_execution.start_input.function_name = "test-function" - mock_execution.is_complete = True + # Create real execution instance with mocked start_input + mock_start_input = Mock() + mock_start_input.execution_name = "test-execution" + mock_start_input.function_name = "test-function" + + execution = Execution( + durable_execution_arn="test-arn", start_input=mock_start_input, operations=[] + ) + execution.is_complete = True error = ErrorObject.from_message("Test error") mock_result = DurableExecutionInvocationOutput( status=InvocationStatus.FAILED, error=error ) - mock_execution.result = mock_result + execution.result = mock_result - # Create mock operation + # Create mock operation and add to execution mock_operation = Operation( operation_id="op-1", parent_id=None, @@ -1799,12 +1812,16 @@ def test_get_execution_details_failed_execution(executor, mock_store): status=OperationStatus.FAILED, execution_details=ExecutionDetails(input_payload='{"test": "data"}'), ) - mock_execution.get_operation_execution_started.return_value = mock_operation - - mock_store.load.return_value = mock_execution + execution.operations = [mock_operation] + mock_store.load.return_value = execution + with pytest.raises( + IllegalStateException, + match="close_status must be set when execution is complete", + ): + executor.get_execution_details("test-arn") + execution.close_status = ExecutionStatus.FAILED result = executor.get_execution_details("test-arn") - assert result.status == "FAILED" assert result.result is None assert result.error == error @@ -1812,35 +1829,29 @@ def test_get_execution_details_failed_execution(executor, mock_store): def 
test_list_executions_empty(executor, mock_store): """Test list_executions with no executions.""" - mock_store.list_all.return_value = [] + query_result = ([], None) + mock_store.query.return_value = query_result result = executor.list_executions() assert result.durable_executions == [] assert result.next_marker is None - mock_store.list_all.assert_called_once() + mock_store.query.assert_called_once() def test_list_executions_with_filtering(executor, mock_store): """Test list_executions with function name filtering.""" + # Create real execution instance + mock_start_input = Mock() + mock_start_input.execution_name = "exec1" + mock_start_input.function_name = "function1" - # Create mock executions - execution1 = Mock() - execution1.durable_execution_arn = "arn1" - execution1.start_input.execution_name = "exec1" - execution1.start_input.function_name = "function1" + execution1 = Execution( + durable_execution_arn="arn1", start_input=mock_start_input, operations=[] + ) execution1.is_complete = False execution1.result = None - execution2 = Mock() - execution2.durable_execution_arn = "arn2" - execution2.start_input.execution_name = "exec2" - execution2.start_input.function_name = "function2" - execution2.is_complete = True - execution2.result = DurableExecutionInvocationOutput( - status=InvocationStatus.SUCCEEDED, result="result" - ) - # Create mock operations op1 = Operation( operation_id="op-1", @@ -1851,20 +1862,11 @@ def test_list_executions_with_filtering(executor, mock_store): status=OperationStatus.STARTED, execution_details=ExecutionDetails(input_payload="{}"), ) - op2 = Operation( - operation_id="op-2", - parent_id=None, - name="exec2", - start_timestamp=datetime.now(UTC), - operation_type=OperationType.EXECUTION, - status=OperationStatus.SUCCEEDED, - execution_details=ExecutionDetails(input_payload="{}"), - ) - - execution1.get_operation_execution_started.return_value = op1 - execution2.get_operation_execution_started.return_value = op2 + execution1.operations 
= [op1] - mock_store.list_all.return_value = [execution1, execution2] + # Mock the query method to return filtered results + query_result = ([execution1], "1") + mock_store.query.return_value = query_result # Test filtering by function name result = executor.list_executions(function_name="function1") @@ -1876,10 +1878,31 @@ def test_list_executions_with_filtering(executor, mock_store): def test_list_executions_with_pagination(executor, mock_store): """Test list_executions with pagination.""" + # Create multiple mock executions for first page + executions_page1 = [] + for i in range(2): + execution = Mock() + execution.durable_execution_arn = f"arn{i}" + execution.start_input.execution_name = f"exec{i}" + execution.start_input.function_name = "test-function" + execution.is_complete = False + execution.result = None - # Create multiple mock executions - executions = [] - for i in range(5): + op = Operation( + operation_id=f"op-{i}", + parent_id=None, + name=f"exec{i}", + start_timestamp=datetime.now(UTC), + operation_type=OperationType.EXECUTION, + status=OperationStatus.STARTED, + execution_details=ExecutionDetails(input_payload="{}"), + ) + execution.get_operation_execution_started.return_value = op + executions_page1.append(execution) + + # Create executions for second page + executions_page2 = [] + for i in range(2, 4): execution = Mock() execution.durable_execution_arn = f"arn{i}" execution.start_input.execution_name = f"exec{i}" @@ -1897,9 +1920,14 @@ def test_list_executions_with_pagination(executor, mock_store): execution_details=ExecutionDetails(input_payload="{}"), ) execution.get_operation_execution_started.return_value = op - executions.append(execution) + executions_page2.append(execution) + + # Mock query responses for pagination + query_result1 = (executions_page1, "2") + + query_result2 = (executions_page2, "4") - mock_store.list_all.return_value = executions + mock_store.query.side_effect = [query_result1, query_result2] # Test pagination with 
max_items=2 result = executor.list_executions(max_items=2) @@ -1930,8 +1958,8 @@ def test_list_executions_by_function(executor): function_name="test-function", execution_name=None, status_filter="RUNNING", - time_after=None, - time_before=None, + started_after=None, + started_before=None, marker=None, max_items=None, reverse_order=False, @@ -1942,16 +1970,26 @@ def test_list_executions_by_function(executor): def test_stop_execution(executor, mock_store): """Test stop_execution method.""" - mock_execution = Mock() - mock_execution.is_complete = False - mock_store.load.return_value = mock_execution + # Create real execution instance with mocked start_input + mock_start_input = Mock() + mock_start_input.execution_name = "test-execution" + mock_start_input.function_name = "test-function" + + execution = Execution( + durable_execution_arn="test-arn", + start_input=mock_start_input, + operations=[Mock()], + ) + execution.is_complete = False + mock_store.load.return_value = execution - with patch.object(executor, "fail_execution") as mock_fail: - result = executor.stop_execution("test-arn") + result = executor.stop_execution("test-arn") mock_store.load.assert_called_once_with("test-arn") - mock_fail.assert_called_once() + mock_store.update.assert_called_once_with(execution) assert result.stop_timestamp is not None + assert execution.is_complete is True + assert execution.close_status == ExecutionStatus.STOPPED def test_stop_execution_already_complete(executor, mock_store): @@ -1966,16 +2004,35 @@ def test_stop_execution_already_complete(executor, mock_store): def test_stop_execution_with_custom_error(executor, mock_store): """Test stop_execution with custom error.""" - mock_execution = Mock() - mock_execution.is_complete = False - mock_store.load.return_value = mock_execution + # Create real execution instance with mocked start_input + mock_start_input = Mock() + mock_start_input.execution_name = "test-execution" + mock_start_input.function_name = "test-function" + + 
execution = Execution( + durable_execution_arn="test-arn", + start_input=mock_start_input, + operations=[Mock()], + ) + execution.is_complete = False + mock_store.load.return_value = execution custom_error = ErrorObject.from_message("Custom stop error") - with patch.object(executor, "fail_execution") as mock_fail: - executor.stop_execution("test-arn", error=custom_error) + executor.stop_execution("test-arn", error=custom_error) + + mock_store.load.assert_called_once_with("test-arn") + mock_store.update.assert_called_once_with(execution) + assert execution.is_complete is True + assert execution.close_status == ExecutionStatus.STOPPED + assert execution.result.error == custom_error - mock_fail.assert_called_once_with("test-arn", custom_error) + +def test_get_execution_not_found(executor, mock_store): + mock_store.load.side_effect = KeyError("not found") + + with pytest.raises(ResourceNotFoundException): + executor.get_execution("test-arn") def test_get_execution_state(executor, mock_store): @@ -2741,3 +2798,65 @@ def test_schedule_callback_timeouts_exception_handling(executor, mock_store): # No timeouts should be scheduled assert len(executor._callback_timeouts) == 0 assert len(executor._callback_heartbeats) == 0 + + +def test_on_timed_out(executor, mock_store): + """Test on_timed_out method.""" + # Create real execution instance + mock_start_input = Mock() + mock_start_input.execution_name = "test-execution" + mock_start_input.function_name = "test-function" + + execution = Execution( + durable_execution_arn="test-arn", + start_input=mock_start_input, + operations=[Mock()], + ) + execution.is_complete = False + mock_store.load.return_value = execution + + error = ErrorObject.from_message("Execution timeout") + + with patch.object(executor, "_complete_events") as mock_complete_events: + executor.on_timed_out("test-arn", error) + + mock_store.load.assert_called_once_with(execution_arn="test-arn") + mock_store.update.assert_called_once_with(execution) + 
mock_complete_events.assert_called_once_with(execution_arn="test-arn") + assert execution.is_complete is True + assert execution.close_status == ExecutionStatus.TIMED_OUT + assert execution.result.error == error + + +def test_on_stopped(executor): + """Test on_stopped method.""" + error = ErrorObject.from_message("Execution stopped") + + with patch.object(executor, "fail_execution") as mock_fail: + executor.on_stopped("test-arn", error) + + mock_fail.assert_called_once_with("test-arn", error) + + +def test_notify_timed_out(): + """Test notify_timed_out method.""" + notifier = ExecutionNotifier() + observer = Mock() + notifier.add_observer(observer) + + error = ErrorObject.from_message("Timeout error") + notifier.notify_timed_out("test-arn", error) + + observer.on_timed_out.assert_called_once_with(execution_arn="test-arn", error=error) + + +def test_notify_stopped(): + """Test notify_stopped method.""" + notifier = ExecutionNotifier() + observer = Mock() + notifier.add_observer(observer) + + error = ErrorObject.from_message("Stop error") + notifier.notify_stopped("test-arn", error) + + observer.on_stopped.assert_called_once_with(execution_arn="test-arn", error=error) diff --git a/tests/model_test.py b/tests/model_test.py index 015ac9d..5605c3f 100644 --- a/tests/model_test.py +++ b/tests/model_test.py @@ -256,8 +256,8 @@ def test_list_durable_executions_request_serialization(): "FunctionVersion": "$LATEST", "DurableExecutionName": "test-execution", "StatusFilter": ["RUNNING", "SUCCEEDED"], - "TimeAfter": TIMESTAMP_2023_01_01_00_00, - "TimeBefore": TIMESTAMP_2023_01_02_00_00, + "StartedAfter": TIMESTAMP_2023_01_01_00_00, + "StartedBefore": TIMESTAMP_2023_01_02_00_00, "Marker": "marker-123", "MaxItems": 10, "ReverseOrder": True, @@ -268,8 +268,8 @@ def test_list_durable_executions_request_serialization(): assert request_obj.function_version == "$LATEST" assert request_obj.durable_execution_name == "test-execution" assert request_obj.status_filter == ["RUNNING", 
"SUCCEEDED"] - assert request_obj.time_after == TIMESTAMP_2023_01_01_00_00 - assert request_obj.time_before == TIMESTAMP_2023_01_02_00_00 + assert request_obj.started_after == TIMESTAMP_2023_01_01_00_00 + assert request_obj.started_before == TIMESTAMP_2023_01_02_00_00 assert request_obj.marker == "marker-123" assert request_obj.max_items == 10 assert request_obj.reverse_order is True @@ -291,8 +291,8 @@ def test_list_durable_executions_request_empty(): assert request_obj.function_version is None assert request_obj.durable_execution_name is None assert request_obj.status_filter is None - assert request_obj.time_after is None - assert request_obj.time_before is None + assert request_obj.started_after is None + assert request_obj.started_before is None assert request_obj.marker is None assert request_obj.max_items == 0 # Default value from Smithy assert request_obj.reverse_order is None @@ -1253,8 +1253,8 @@ def test_list_durable_executions_request_all_optional_fields(): function_version=None, durable_execution_name=None, status_filter=None, - time_after=None, - time_before=None, + started_after=None, + started_before=None, marker=None, max_items=None, reverse_order=None, @@ -1273,8 +1273,8 @@ def test_list_durable_executions_request_partial_fields(): function_version=None, durable_execution_name="test-execution", status_filter=None, - time_after=TIMESTAMP_2023_01_01_00_00, - time_before=None, + started_after=TIMESTAMP_2023_01_01_00_00, + started_before=None, marker="marker-123", max_items=10, reverse_order=None, @@ -1284,7 +1284,7 @@ def test_list_durable_executions_request_partial_fields(): expected_data = { "FunctionName": "my-function", "DurableExecutionName": "test-execution", - "TimeAfter": TIMESTAMP_2023_01_01_00_00, + "StartedAfter": TIMESTAMP_2023_01_01_00_00, "Marker": "marker-123", "MaxItems": 10, } diff --git a/tests/observer_test.py b/tests/observer_test.py index 9464452..4847eee 100644 --- a/tests/observer_test.py +++ b/tests/observer_test.py @@ -20,6 
+20,8 @@ class MockExecutionObserver(ExecutionObserver): def __init__(self): self.on_completed_calls = [] self.on_failed_calls = [] + self.on_timed_out_calls = [] + self.on_stopped_calls = [] self.on_wait_timer_scheduled_calls = [] self.on_step_retry_scheduled_calls = [] self.on_callback_created_calls = [] @@ -30,6 +32,12 @@ def on_completed(self, execution_arn: str, result: str | None = None) -> None: def on_failed(self, execution_arn: str, error: ErrorObject) -> None: self.on_failed_calls.append((execution_arn, error)) + def on_timed_out(self, execution_arn: str, error: ErrorObject) -> None: + self.on_timed_out_calls.append((execution_arn, error)) + + def on_stopped(self, execution_arn: str, error: ErrorObject) -> None: + self.on_stopped_calls.append((execution_arn, error)) + def on_wait_timer_scheduled( self, execution_arn: str, operation_id: str, delay: float ) -> None: @@ -243,14 +251,19 @@ def test_mock_execution_observer_implementation(): observer = MockExecutionObserver() # Test all methods can be called + error = ErrorObject("Error", "Message", "data", ["trace"]) observer.on_completed("arn", "result") - observer.on_failed("arn", ErrorObject("Error", "Message", "data", ["trace"])) + observer.on_failed("arn", error) + observer.on_timed_out("arn", error) + observer.on_stopped("arn", error) observer.on_wait_timer_scheduled("arn", "op", 1.0) observer.on_step_retry_scheduled("arn", "op", 2.0) # Verify calls were recorded assert len(observer.on_completed_calls) == 1 assert len(observer.on_failed_calls) == 1 + assert len(observer.on_timed_out_calls) == 1 + assert len(observer.on_stopped_calls) == 1 assert len(observer.on_wait_timer_scheduled_calls) == 1 assert len(observer.on_step_retry_scheduled_calls) == 1 @@ -289,6 +302,8 @@ def test_execution_observer_abstract_method_coverage(): assert "on_completed" in method_names assert "on_failed" in method_names + assert "on_timed_out" in method_names + assert "on_stopped" in method_names assert "on_wait_timer_scheduled" 
in method_names assert "on_step_retry_scheduled" in method_names diff --git a/tests/stores/concurrent_test.py b/tests/stores/concurrent_test.py index bb06e77..8703d4f 100644 --- a/tests/stores/concurrent_test.py +++ b/tests/stores/concurrent_test.py @@ -1,13 +1,21 @@ -"""Concurrent access tests for InMemoryExecutionStore.""" +"""Concurrent access tests for execution stores.""" +import tempfile import threading from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path + +import pytest from aws_durable_execution_sdk_python_testing.execution import Execution from aws_durable_execution_sdk_python_testing.model import StartDurableExecutionInput +from aws_durable_execution_sdk_python_testing.stores.filesystem import ( + FileSystemExecutionStore, +) from aws_durable_execution_sdk_python_testing.stores.memory import ( InMemoryExecutionStore, ) +from aws_durable_execution_sdk_python_testing.stores.sqlite import SQLiteExecutionStore def test_concurrent_save_load(): @@ -107,3 +115,164 @@ def list_executions(): assert len(results) == 6 final_list = store.list_all() assert len(final_list) == 3 + + +@pytest.fixture +def temp_storage_dir(): + """Create a temporary directory for testing.""" + with tempfile.TemporaryDirectory() as temp_dir: + yield Path(temp_dir) + + +@pytest.fixture +def temp_db_path(): + """Create a temporary database file for testing.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as temp_file: + temp_path = Path(temp_file.name) + yield temp_path + if temp_path.exists(): + temp_path.unlink() + + +def test_concurrent_filesystem_save_load(temp_storage_dir): + """Test concurrent save and load operations with filesystem store.""" + store = FileSystemExecutionStore.create(temp_storage_dir) + results = [] + results_lock = threading.Lock() + + def save_execution(i: int): + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + 
execution_name=f"test-{i}", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id=f"inv-{i}", + input=f'{{"test": {i}}}', + ) + execution = Execution.new(input_data) + execution.durable_execution_arn = f"arn-{i}" + execution.start() + store.save(execution) + with results_lock: + results.append(f"saved-{i}") + + def load_execution(i: int): + try: + execution = store.load(f"arn-{i}") + with results_lock: + results.append(f"loaded-{execution.start_input.execution_name}") + except KeyError: + with results_lock: + results.append(f"not-found-{i}") + + with ThreadPoolExecutor(max_workers=8) as executor: + # Submit save operations first + futures = [executor.submit(save_execution, i) for i in range(4)] + for future in as_completed(futures): + future.result() + + # Then submit load operations + futures = [executor.submit(load_execution, i) for i in range(4)] + for future in as_completed(futures): + future.result() + + assert len(results) == 8 + + +def test_concurrent_sqlite_save_load(temp_db_path): + """Test concurrent save and load operations with SQLite store.""" + store = SQLiteExecutionStore.create_and_initialize(temp_db_path) + results = [] + results_lock = threading.Lock() + + def save_execution(i: int): + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name=f"test-{i}", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id=f"inv-{i}", + input=f'{{"test": {i}}}', + ) + execution = Execution.new(input_data) + execution.durable_execution_arn = f"arn-{i}" + execution.start() + store.save(execution) + with results_lock: + results.append(f"saved-{i}") + + def load_execution(i: int): + try: + execution = store.load(f"arn-{i}") + with results_lock: + results.append(f"loaded-{execution.start_input.execution_name}") + except KeyError: + with results_lock: + results.append(f"not-found-{i}") + + with 
ThreadPoolExecutor(max_workers=8) as executor: + # Submit save operations first + futures = [executor.submit(save_execution, i) for i in range(4)] + for future in as_completed(futures): + future.result() + + # Then submit load operations + futures = [executor.submit(load_execution, i) for i in range(4)] + for future in as_completed(futures): + future.result() + + assert len(results) == 8 + + +def test_concurrent_query_operations(): + """Test concurrent query operations on memory store.""" + store = InMemoryExecutionStore() + results = [] + results_lock = threading.Lock() + + # Pre-populate store with test data + for i in range(10): + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name=f"function-{i % 3}", # 3 different functions + function_qualifier="$LATEST", + execution_name=f"exec-{i}", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id=f"inv-{i}", + ) + execution = Execution.new(input_data) + execution.start() + # Complete some executions + if i % 4 == 0: + execution.complete_success("success") + store.save(execution) + + def query_store(query_type: str): + if query_type == "function": + executions, next_marker = store.query(function_name="function-1") + elif query_type == "status": + executions, next_marker = store.query(status_filter="SUCCEEDED") + elif query_type == "pagination": + executions, next_marker = store.query(limit=3, offset=2) + else: + executions, next_marker = store.query() + + with results_lock: + results.append(f"{query_type}-{len(executions)}") + + with ThreadPoolExecutor(max_workers=4) as executor: + futures = [ + executor.submit(query_store, "function"), + executor.submit(query_store, "status"), + executor.submit(query_store, "pagination"), + executor.submit(query_store, "all"), + ] + for future in as_completed(futures): + future.result() + + assert len(results) == 4 diff --git a/tests/stores/filesystem_store_test.py b/tests/stores/filesystem_store_test.py index 
6b613c8..7a0c803 100644 --- a/tests/stores/filesystem_store_test.py +++ b/tests/stores/filesystem_store_test.py @@ -280,3 +280,155 @@ def test_datetime_object_hook_converts_timestamp_fields(): expected_datetime = datetime.fromtimestamp(timestamp, tz=timezone.utc) assert result["start_timestamp"] == expected_datetime + + +def test_filesystem_execution_store_query_empty(store): + """Test query method with empty store.""" + executions, next_marker = store.query() + + assert executions == [] + assert next_marker is None + + +def test_filesystem_execution_store_query_by_function_name(store): + """Test query filtering by function name.""" + # Create executions with different function names + input1 = StartDurableExecutionInput( + account_id="123456789012", + function_name="function-a", + function_qualifier="$LATEST", + execution_name="exec-1", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-1", + ) + input2 = StartDurableExecutionInput( + account_id="123456789012", + function_name="function-b", + function_qualifier="$LATEST", + execution_name="exec-2", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-2", + ) + + exec1 = Execution.new(input1) + exec1.start() + exec2 = Execution.new(input2) + exec2.start() + store.save(exec1) + store.save(exec2) + + # Query for function-a only + executions, next_marker = store.query(function_name="function-a") + + assert len(executions) == 1 + assert executions[0].durable_execution_arn == exec1.durable_execution_arn + assert next_marker is None + + +def test_filesystem_execution_store_query_by_status(store): + """Test query filtering by status.""" + # Create running execution + input1 = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="running-exec", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-1", + ) 
+ exec1 = Execution.new(input1) + exec1.start() + + # Create completed execution + input2 = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="completed-exec", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-2", + ) + exec2 = Execution.new(input2) + exec2.start() + exec2.complete_success("success result") + + store.save(exec1) + store.save(exec2) + + # Query for running executions + executions, next_marker = store.query(status_filter="RUNNING") + + assert len(executions) == 1 + assert executions[0].durable_execution_arn == exec1.durable_execution_arn + + # Query for succeeded executions + executions, next_marker = store.query(status_filter="SUCCEEDED") + + assert len(executions) == 1 + assert executions[0].durable_execution_arn == exec2.durable_execution_arn + + +def test_filesystem_execution_store_query_pagination(store): + """Test query pagination.""" + # Create multiple executions + executions = [] + for i in range(5): + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name=f"exec-{i}", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id=f"invocation-{i}", + ) + exec_obj = Execution.new(input_data) + exec_obj.start() + executions.append(exec_obj) + store.save(exec_obj) + + # Test first page + executions, next_marker = store.query(limit=2, offset=0) + + assert len(executions) == 2 + assert next_marker is not None + + # Test last page + executions, next_marker = store.query(limit=2, offset=4) + + assert len(executions) == 1 + assert next_marker is None + + +def test_filesystem_execution_store_query_corrupted_file_handling( + store, temp_storage_dir +): + """Test that corrupted files are skipped during query.""" + # Create a valid execution + input_data = StartDurableExecutionInput( + 
account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + execution = Execution.new(input_data) + execution.start() + store.save(execution) + + # Create a corrupted file + corrupted_file = temp_storage_dir / "corrupted.json" + with open(corrupted_file, "w") as f: + f.write("invalid json content") + + # Query should skip the corrupted file and return only valid executions + executions, next_marker = store.query() + + assert len(executions) == 1 + assert executions[0].durable_execution_arn == execution.durable_execution_arn diff --git a/tests/stores/memory_store_test.py b/tests/stores/memory_store_test.py index a58cf54..b4d5b3e 100644 --- a/tests/stores/memory_store_test.py +++ b/tests/stores/memory_store_test.py @@ -1,5 +1,6 @@ """Tests for InMemoryExecutionStore.""" +from datetime import UTC from unittest.mock import Mock import pytest @@ -148,3 +149,346 @@ def test_in_memory_execution_store_list_all_with_executions(): assert execution1 in result assert execution2 in result assert execution3 in result + + +def test_in_memory_execution_store_query_empty(): + """Test query method with empty store.""" + store = InMemoryExecutionStore() + + executions, next_marker = store.query() + + assert executions == [] + assert next_marker is None + + +def test_in_memory_execution_store_query_by_function_name(): + """Test query filtering by function name.""" + store = InMemoryExecutionStore() + + # Create executions with different function names + input1 = StartDurableExecutionInput( + account_id="123456789012", + function_name="function-a", + function_qualifier="$LATEST", + execution_name="exec-1", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-1", + ) + input2 = StartDurableExecutionInput( + account_id="123456789012", + function_name="function-b", + 
function_qualifier="$LATEST", + execution_name="exec-2", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-2", + ) + + exec1 = Execution.new(input1) + exec1.start() + exec2 = Execution.new(input2) + exec2.start() + store.save(exec1) + store.save(exec2) + + # Query for function-a only + executions, next_marker = store.query(function_name="function-a") + + assert len(executions) == 1 + assert executions[0] is exec1 + assert next_marker is None + + +def test_in_memory_execution_store_query_by_execution_name(): + """Test query filtering by execution name.""" + store = InMemoryExecutionStore() + + input1 = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="exec-alpha", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-1", + ) + input2 = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="exec-beta", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-2", + ) + + exec1 = Execution.new(input1) + exec1.start() + exec2 = Execution.new(input2) + exec2.start() + store.save(exec1) + store.save(exec2) + + executions, next_marker = store.query(execution_name="exec-beta") + + assert len(executions) == 1 + assert executions[0] is exec2 + + +def test_in_memory_execution_store_query_by_status(): + """Test query filtering by status.""" + store = InMemoryExecutionStore() + + # Create running execution + input1 = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="running-exec", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-1", + ) + exec1 = Execution.new(input1) + exec1.start() + + # Create completed execution + input2 = 
StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="completed-exec", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-2", + ) + exec2 = Execution.new(input2) + exec2.start() + exec2.complete_success("success result") + + store.save(exec1) + store.save(exec2) + + # Query for running executions + executions, next_marker = store.query(status_filter="RUNNING") + + assert len(executions) == 1 + assert executions[0] is exec1 + + # Query for succeeded executions + executions, next_marker = store.query(status_filter="SUCCEEDED") + + assert len(executions) == 1 + assert executions[0] is exec2 + + +def test_in_memory_execution_store_query_pagination(): + """Test query pagination.""" + store = InMemoryExecutionStore() + + # Create multiple executions + executions = [] + for i in range(5): + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name=f"exec-{i}", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id=f"invocation-{i}", + ) + exec_obj = Execution.new(input_data) + exec_obj.start() + executions.append(exec_obj) + store.save(exec_obj) + + # Test first page + executions, next_marker = store.query(limit=2, offset=0) + + assert len(executions) == 2 + assert next_marker is not None + + # Test second page + executions, next_marker = store.query(limit=2, offset=2) + + assert len(executions) == 2 + assert next_marker is not None + + # Test last page + executions, next_marker = store.query(limit=2, offset=4) + + assert len(executions) == 1 + assert next_marker is None + + +def test_in_memory_execution_store_query_sorting(): + """Test query sorting by timestamp.""" + store = InMemoryExecutionStore() + + # Create executions - they will be sorted by creation order + executions = [] + for i in range(3): + input_data = 
StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name=f"exec-{i}", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id=f"invocation-{i}", + ) + exec_obj = Execution.new(input_data) + exec_obj.start() + executions.append(exec_obj) + store.save(exec_obj) + + # Test ascending order (default) + executions, next_marker = store.query(reverse_order=False) + + assert len(executions) == 3 + + # Test descending order + executions, next_marker = store.query(reverse_order=True) + + assert len(executions) == 3 + + +def test_in_memory_execution_store_query_combined_filters(): + """Test query with multiple filters combined.""" + store = InMemoryExecutionStore() + + # Create various executions + inputs = [ + StartDurableExecutionInput( + account_id="123456789012", + function_name="function-a", + function_qualifier="$LATEST", + execution_name="target-exec", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-1", + ), + StartDurableExecutionInput( + account_id="123456789012", + function_name="function-b", + function_qualifier="$LATEST", + execution_name="target-exec", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-2", + ), + StartDurableExecutionInput( + account_id="123456789012", + function_name="function-a", + function_qualifier="$LATEST", + execution_name="other-exec", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-3", + ), + ] + + executions = [] + for input_data in inputs: + exec_obj = Execution.new(input_data) + exec_obj.start() + executions.append(exec_obj) + store.save(exec_obj) + + # Query with both function_name and execution_name filters + filtered_executions, next_marker = store.query( + function_name="function-a", execution_name="target-exec" + ) + + assert len(filtered_executions) == 1 + assert 
filtered_executions[0] is executions[0]
+
+
+def test_time_filtering_logic():
+    """Test time filtering logic in process_query method."""
+    from datetime import datetime
+    from unittest.mock import Mock
+
+    store = InMemoryExecutionStore()
+
+    # Create mock executions with different timestamps
+    exec1 = Mock()
+    exec1.start_input.function_name = "test-function"
+    exec1.start_input.execution_name = "exec1"
+    exec1.status = "RUNNING"
+
+    exec2 = Mock()
+    exec2.start_input.function_name = "test-function"
+    exec2.start_input.execution_name = "exec2"
+    exec2.status = "RUNNING"
+
+    exec3 = Mock()
+    exec3.start_input.function_name = "test-function"
+    exec3.start_input.execution_name = "exec3"
+    exec3.status = "RUNNING"
+
+    # Use real datetime objects for timestamps
+    op1 = Mock()
+    op1.start_timestamp = datetime(2023, 1, 1, 12, 0, 0, tzinfo=UTC)
+
+    op2 = Mock()
+    op2.start_timestamp = datetime(2023, 1, 2, 12, 0, 0, tzinfo=UTC)
+
+    op3 = Mock()
+    op3.start_timestamp = datetime(2023, 1, 3, 12, 0, 0)  # noqa: DTZ001
+
+    exec1.get_operation_execution_started.return_value = op1
+    exec2.get_operation_execution_started.return_value = op2
+    exec3.get_operation_execution_started.return_value = op3
+
+    executions = [exec1, exec2, exec3]
+
+    # Test time_after filtering
+    filtered, _ = store.process_query(
+        executions,
+        started_after="1672617600.0",  # 2023-01-02 00:00:00 UTC (between exec1 and exec2)
+    )
+    assert len(filtered) == 2
+    assert exec2 in filtered
+    assert exec3 in filtered
+    assert exec1 not in filtered
+
+    # Test time_before filtering
+    filtered, _ = store.process_query(
+        executions,
+        started_before="1672617600.0",  # 2023-01-02 00:00:00 UTC
+    )
+    assert len(filtered) == 1
+    assert exec1 in filtered
+    assert exec2 not in filtered
+    assert exec3 not in filtered
+
+    # Test both time_after and time_before
+    filtered, _ = store.process_query(
+        executions,
+        started_after="1672617600.0",  # 2023-01-02 00:00:00 UTC (between exec1 and exec2)
+        started_before="1672704000.0",  # 
2023-01-03 00:00:00 UTC (between exec2 and exec3) + ) + assert len(filtered) == 1 + assert exec2 in filtered + + # Test exception handling - exec with AttributeError + exec_error = Mock() + exec_error.start_input.function_name = "test-function" + exec_error.start_input.execution_name = "exec_error" + exec_error.status = "RUNNING" + exec_error.get_operation_execution_started.side_effect = AttributeError( + "No operation" + ) + + executions_with_error = [exec1, exec_error, exec2] + filtered, _ = store.process_query( + executions_with_error, + started_after="1672617600.0", # After exec1, before exec2 + ) + # exec_error should be filtered out due to exception, only exec2 should remain + assert len(filtered) == 1 + assert exec2 in filtered + assert exec_error not in filtered diff --git a/tests/stores/sqlite_store_test.py b/tests/stores/sqlite_store_test.py new file mode 100644 index 0000000..7c7feb4 --- /dev/null +++ b/tests/stores/sqlite_store_test.py @@ -0,0 +1,860 @@ +"""Tests for SQLiteExecutionStore.""" + +import tempfile +import time +from datetime import datetime, UTC +from pathlib import Path + +import pytest + +from aws_durable_execution_sdk_python_testing.exceptions import ( + ResourceNotFoundException, + InvalidParameterValueException, +) +from aws_durable_execution_sdk_python_testing.execution import ( + ExecutionStatus, + Execution, +) +from aws_durable_execution_sdk_python_testing.model import StartDurableExecutionInput +from aws_durable_execution_sdk_python_testing.stores.sqlite import SQLiteExecutionStore + + +@pytest.fixture +def temp_db_path(): + """Create a temporary database file for testing.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as temp_file: + temp_path = Path(temp_file.name) + yield temp_path + # Cleanup + if temp_path.exists(): + temp_path.unlink() + + +@pytest.fixture +def store(temp_db_path): + """Create a SQLiteExecutionStore with temporary database.""" + return 
SQLiteExecutionStore.create_and_initialize(temp_db_path) + + +@pytest.fixture +def sample_execution(): + """Create a sample execution for testing.""" + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + return Execution.new(input_data) + + +def test_sqlite_execution_store_save_and_load(store, sample_execution): + """Test saving and loading an execution.""" + sample_execution.start() + store.save(sample_execution) + loaded_execution = store.load(sample_execution.durable_execution_arn) + + assert ( + loaded_execution.durable_execution_arn == sample_execution.durable_execution_arn + ) + assert ( + loaded_execution.start_input.function_name + == sample_execution.start_input.function_name + ) + assert ( + loaded_execution.start_input.execution_name + == sample_execution.start_input.execution_name + ) + assert loaded_execution.token_sequence == sample_execution.token_sequence + assert loaded_execution.is_complete == sample_execution.is_complete + + +def test_sqlite_execution_store_load_nonexistent(store): + """Test loading a nonexistent execution raises KeyError.""" + with pytest.raises( + ResourceNotFoundException, match="Execution nonexistent-arn not found" + ): + store.load("nonexistent-arn") + + +def test_sqlite_execution_store_update(store, sample_execution): + """Test updating an execution.""" + sample_execution.start() + store.save(sample_execution) + + sample_execution.is_complete = True + sample_execution.close_status = ExecutionStatus.SUCCEEDED + for _ in range(5): + sample_execution.get_new_checkpoint_token() + store.update(sample_execution) + + loaded_execution = store.load(sample_execution.durable_execution_arn) + assert loaded_execution.is_complete is True + assert loaded_execution.token_sequence == 5 + + +def 
test_sqlite_execution_store_update_overwrites(store): + """Test that update overwrites existing execution.""" + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + execution1 = Execution.new(input_data) + execution1.start() + execution2 = Execution.new(input_data) + execution2.start() + execution2.durable_execution_arn = execution1.durable_execution_arn + for _ in range(10): + execution2.get_new_checkpoint_token() + + store.save(execution1) + store.update(execution2) + + loaded_execution = store.load(execution1.durable_execution_arn) + assert loaded_execution.token_sequence == 10 + + +def test_sqlite_execution_store_multiple_executions(store): + """Test storing multiple executions.""" + input_data1 = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function-1", + function_qualifier="$LATEST", + execution_name="test-execution-1", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id-1", + ) + input_data2 = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function-2", + function_qualifier="$LATEST", + execution_name="test-execution-2", + execution_timeout_seconds=600, + execution_retention_period_days=14, + invocation_id="test-invocation-id-2", + ) + + execution1 = Execution.new(input_data1) + execution1.start() + execution2 = Execution.new(input_data2) + execution2.start() + + store.save(execution1) + store.save(execution2) + + loaded_execution1 = store.load(execution1.durable_execution_arn) + loaded_execution2 = store.load(execution2.durable_execution_arn) + + assert loaded_execution1.durable_execution_arn == execution1.durable_execution_arn + assert loaded_execution2.durable_execution_arn == execution2.durable_execution_arn + 
assert loaded_execution1.start_input.function_name == "test-function-1" + assert loaded_execution2.start_input.function_name == "test-function-2" + + +def test_sqlite_execution_store_list_all_empty(store): + """Test list_all method with empty store.""" + result = store.list_all() + assert result == [] + + +def test_sqlite_execution_store_list_all_with_executions(store): + """Test list_all method with multiple executions.""" + # Create test executions + executions = [] + for i in range(3): + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name=f"test-function-{i}", + function_qualifier="$LATEST", + execution_name=f"test-execution-{i}", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id=f"test-invocation-id-{i}", + ) + execution = Execution.new(input_data) + execution.start() + executions.append(execution) + store.save(execution) + + # Test list_all + result = store.list_all() + + assert len(result) == 3 + arns = {execution.durable_execution_arn for execution in result} + for execution in executions: + assert execution.durable_execution_arn in arns + + +def test_sqlite_execution_store_query_empty(store): + """Test query method with empty store.""" + executions, next_marker = store.query() + + assert executions == [] + assert next_marker is None + + +def test_sqlite_execution_store_query_by_function_name(store): + """Test query filtering by function name.""" + # Create executions with different function names + input1 = StartDurableExecutionInput( + account_id="123456789012", + function_name="function-a", + function_qualifier="$LATEST", + execution_name="exec-1", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-1", + ) + input2 = StartDurableExecutionInput( + account_id="123456789012", + function_name="function-b", + function_qualifier="$LATEST", + execution_name="exec-2", + execution_timeout_seconds=300, + execution_retention_period_days=7, + 
invocation_id="invocation-2", + ) + + exec1 = Execution.new(input1) + exec1.start() + exec2 = Execution.new(input2) + exec2.start() + store.save(exec1) + store.save(exec2) + + # Query for function-a only + executions, next_marker = store.query(function_name="function-a") + + assert len(executions) == 1 + assert executions[0].durable_execution_arn == exec1.durable_execution_arn + assert next_marker is None + + +def test_sqlite_execution_store_query_by_execution_name(store): + """Test query filtering by execution name.""" + input1 = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="exec-alpha", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-1", + ) + input2 = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="exec-beta", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-2", + ) + + exec1 = Execution.new(input1) + exec1.start() + exec2 = Execution.new(input2) + exec2.start() + store.save(exec1) + store.save(exec2) + + executions, next_marker = store.query(execution_name="exec-beta") + + assert len(executions) == 1 + assert executions[0].durable_execution_arn == exec2.durable_execution_arn + + +def test_sqlite_execution_store_query_by_status(store): + """Test query filtering by status.""" + # Create running execution + input1 = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="running-exec", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-1", + ) + exec1 = Execution.new(input1) + exec1.start() + + # Create completed execution + input2 = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + 
execution_name="completed-exec", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-2", + ) + exec2 = Execution.new(input2) + exec2.start() + exec2.complete_success("success result") + + store.save(exec1) + store.save(exec2) + + # Query for running executions + executions, next_marker = store.query(status_filter="RUNNING") + + assert len(executions) == 1 + assert executions[0].durable_execution_arn == exec1.durable_execution_arn + + # Query for succeeded executions + executions, next_marker = store.query(status_filter="SUCCEEDED") + + assert len(executions) == 1 + assert executions[0].durable_execution_arn == exec2.durable_execution_arn + + +def test_sqlite_execution_store_query_pagination(store): + """Test query pagination.""" + # Create multiple executions + executions = [] + for i in range(5): + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name=f"exec-{i}", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id=f"invocation-{i}", + ) + exec_obj = Execution.new(input_data) + exec_obj.start() + executions.append(exec_obj) + store.save(exec_obj) + + # Test first page + executions, next_marker = store.query(limit=2, offset=0) + + assert len(executions) == 2 + assert next_marker is not None + + # Test second page + executions, next_marker = store.query(limit=2, offset=2) + + assert len(executions) == 2 + assert next_marker is not None + + # Test last page + executions, next_marker = store.query(limit=2, offset=4) + + assert len(executions) == 1 + assert next_marker is None + + +def test_sqlite_execution_store_query_sorting(store): + """Test query sorting by timestamp.""" + # Create executions - they will be sorted by creation order + executions = [] + for i in range(3): + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + 
function_qualifier="$LATEST", + execution_name=f"exec-{i}", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id=f"invocation-{i}", + ) + exec_obj = Execution.new(input_data) + exec_obj.start() + executions.append(exec_obj) + store.save(exec_obj) + + # Test ascending order (default) + executions, next_marker = store.query(reverse_order=False) + + assert len(executions) == 3 + + # Test descending order + executions, next_marker = store.query(reverse_order=True) + + assert len(executions) == 3 + + +def test_sqlite_execution_store_query_combined_filters(store): + """Test query with multiple filters combined.""" + # Create various executions + inputs = [ + StartDurableExecutionInput( + account_id="123456789012", + function_name="function-a", + function_qualifier="$LATEST", + execution_name="target-exec", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-1", + ), + StartDurableExecutionInput( + account_id="123456789012", + function_name="function-b", + function_qualifier="$LATEST", + execution_name="target-exec", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-2", + ), + StartDurableExecutionInput( + account_id="123456789012", + function_name="function-a", + function_qualifier="$LATEST", + execution_name="other-exec", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="invocation-3", + ), + ] + + executions = [] + for input_data in inputs: + exec_obj = Execution.new(input_data) + exec_obj.start() + executions.append(exec_obj) + store.save(exec_obj) + + # Query with both function_name and execution_name filters + filtered_executions, next_marker = store.query( + function_name="function-a", execution_name="target-exec" + ) + + assert len(filtered_executions) == 1 + assert ( + filtered_executions[0].durable_execution_arn + == executions[0].durable_execution_arn + ) + + +def 
test_sqlite_execution_store_database_initialization(temp_db_path): + """Test that database is properly initialized with schema.""" + store = SQLiteExecutionStore.create_and_initialize(temp_db_path) + + # Verify database file exists + assert temp_db_path.exists() + + # Verify we can perform basic operations + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + execution = Execution.new(input_data) + execution.start() + + store.save(execution) + loaded = store.load(execution.durable_execution_arn) + assert loaded.durable_execution_arn == execution.durable_execution_arn + + +def test_sqlite_execution_store_custom_db_path(): + """Test creating store with custom database path.""" + with tempfile.TemporaryDirectory() as temp_dir: + custom_path = Path(temp_dir) / "custom" / "executions.db" + store = SQLiteExecutionStore.create_and_initialize(custom_path) + + # Directory should be created + assert custom_path.parent.exists() + assert custom_path.exists() + + # Verify functionality + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + execution = Execution.new(input_data) + execution.start() + + store.save(execution) + loaded = store.load(execution.durable_execution_arn) + assert loaded.durable_execution_arn == execution.durable_execution_arn + + +def test_sqlite_execution_store_failed_execution_status(store): + """Test that failed executions are properly stored and queried.""" + from aws_durable_execution_sdk_python.lambda_service import ErrorObject + + input_data = StartDurableExecutionInput( + 
account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="failed-exec", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + execution = Execution.new(input_data) + execution.start() + + # Complete with failure + error = ErrorObject( + type="TestError", message="Test failure", data=None, stack_trace=None + ) + execution.complete_fail(error) + + store.save(execution) + + # Query for failed executions + executions, next_marker = store.query(status_filter="FAILED") + + assert len(executions) == 1 + assert executions[0].durable_execution_arn == execution.durable_execution_arn + assert executions[0].is_complete is True + + +def test_sqlite_execution_store_error_handling(temp_db_path): + """Test error handling for database operations.""" + store = SQLiteExecutionStore.create_and_initialize(temp_db_path) + + # Test with corrupted database by removing the file after creation + temp_db_path.unlink() + + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + execution = Execution.new(input_data) + execution.start() + + # Should raise RuntimeError for database operations + with pytest.raises(RuntimeError, match="Failed to save execution"): + store.save(execution) + + +def test_sqlite_execution_store_invalid_execution_data(store): + """Test handling of invalid execution data.""" + # Create execution and start it + execution = Execution.new( + StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + ) + execution.start() + + # 
Corrupt the execution object to trigger serialization error + execution.start_input = None + + with pytest.raises(ValueError, match="Invalid execution data"): + store.save(execution) + + +def test_sqlite_execution_store_sql_injection_protection(store): + """Test SQL injection protection in query parameters.""" + # Create test execution + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + execution = Execution.new(input_data) + execution.start() + store.save(execution) + + # Try SQL injection attempts - should be safely parameterized + malicious_inputs = [ + "'; DROP TABLE executions; --", + "test' OR '1'='1", + "test'; DELETE FROM executions; --", + "test' UNION SELECT * FROM executions --", + ] + + for malicious_input in malicious_inputs: + # These should return empty results, not cause SQL errors + executions, _ = store.query(function_name=malicious_input) + assert executions == [] + + executions, _ = store.query(execution_name=malicious_input) + assert executions == [] + + executions, _ = store.query(status_filter=malicious_input) + assert executions == [] + + +def test_sqlite_execution_store_time_filtering(store): + """Test time-based filtering with edge cases.""" + + # Create executions at different times + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + + execution1 = Execution.new(input_data) + execution1.start() + store.save(execution1) + + # Small delay to ensure different timestamps + time.sleep(0.01) + + execution2 = Execution.new(input_data) + execution2.start() + store.save(execution2) + + # 
Get timestamps as ISO strings + start_time_iso = ( + execution1.get_operation_execution_started().start_timestamp.isoformat() + ) + mid_time = ( + execution1.get_operation_execution_started().start_timestamp.timestamp() + 0.005 + ) + mid_time_iso = datetime.fromtimestamp(mid_time, tz=UTC).isoformat() + end_time_iso = datetime.fromtimestamp( + execution2.get_operation_execution_started().start_timestamp.timestamp() + 1, + tz=UTC, + ).isoformat() + + # Test started_after filter + executions, _ = store.query(started_after=mid_time_iso) + assert len(executions) == 1 + + # Test started_before filter + executions, _ = store.query(started_before=mid_time_iso) + assert len(executions) == 1 + + # Test both filters + executions, _ = store.query( + started_after=start_time_iso, started_before=end_time_iso + ) + assert len(executions) == 2 + + +def test_sqlite_execution_store_corrupted_data_handling(store, temp_db_path): + """Test handling of corrupted JSON data in database.""" + import sqlite3 + + # Insert corrupted JSON data directly + with sqlite3.connect(temp_db_path) as conn: + conn.execute( + """ + INSERT INTO executions + (durable_execution_arn, function_name, execution_name, status, start_timestamp, end_timestamp, data) + VALUES (?, ?, ?, ?, ?, ?, ?) 
+ """, + ( + "corrupted-arn", + "test-function", + "test-execution", + "RUNNING", + 1234567890.0, + None, + "invalid json data {{{", + ), + ) + + # Loading corrupted data should raise ValueError + with pytest.raises(ValueError, match="Corrupted execution data"): + store.load("corrupted-arn") + + # Query should skip corrupted records and continue + executions, _ = store.query() + # Should not include the corrupted record + assert all(exec.durable_execution_arn != "corrupted-arn" for exec in executions) + + +def test_sqlite_execution_store_get_execution_metadata(store): + """Test get_execution_metadata method.""" + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + execution = Execution.new(input_data) + execution.start() + store.save(execution) + + # Test existing execution + metadata = store.get_execution_metadata(execution.durable_execution_arn) + assert metadata is not None + assert metadata["durable_execution_arn"] == execution.durable_execution_arn + assert metadata["function_name"] == "test-function" + assert metadata["execution_name"] == "test-execution" + assert metadata["status"] == "RUNNING" + assert metadata["start_timestamp"] is not None + + # Test nonexistent execution + metadata = store.get_execution_metadata("nonexistent-arn") + assert metadata is None + + +def test_sqlite_execution_store_database_init_error(): + """Test database initialization error handling.""" + # Try to create database in non-existent directory without permission + invalid_path = Path("/invalid/path/that/does/not/exist/test.db") + + with pytest.raises(RuntimeError, match="Failed to initialize database"): + store = SQLiteExecutionStore(invalid_path) + store._init_db() + + +def test_sqlite_execution_store_query_invalid_parameters(store): + """Test query 
with invalid parameters.""" + # Test with invalid time parameters + with pytest.raises( + InvalidParameterValueException, match="Invalid query parameters" + ): + store.query(started_after="invalid_timestamp") + + with pytest.raises( + InvalidParameterValueException, match="Invalid query parameters" + ): + store.query(started_before="not_a_number") + + +def test_sqlite_execution_store_query_no_limit_no_offset(store): + """Test query without limit and offset parameters.""" + # Create test execution + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + execution = Execution.new(input_data) + execution.start() + store.save(execution) + + # Query without limit should use different code path + executions, next_marker = store.query() + assert len(executions) == 1 + assert next_marker is None + + +def test_sqlite_execution_store_query_with_end_timestamp(store): + """Test execution with end timestamp.""" + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + execution = Execution.new(input_data) + execution.start() + execution.complete_success("result") # This should set end_timestamp + store.save(execution) + + loaded = store.load(execution.durable_execution_arn) + assert loaded.is_complete is True + + +def test_sqlite_execution_store_metadata_error_handling(temp_db_path): + """Test metadata retrieval error handling.""" + store = SQLiteExecutionStore.create_and_initialize(temp_db_path) + + # Remove database file to trigger error + temp_db_path.unlink() + + with pytest.raises(RuntimeError, match="Failed to get 
metadata"): + store.get_execution_metadata("test-arn") + + +def test_sqlite_execution_store_load_error_handling(temp_db_path): + """Test load error handling.""" + store = SQLiteExecutionStore.create_and_initialize(temp_db_path) + + # Remove database file to trigger error + temp_db_path.unlink() + + with pytest.raises(RuntimeError, match="Failed to load execution"): + store.load("test-arn") + + +def test_sqlite_execution_store_query_with_corrupted_data_warning( + store, temp_db_path, capsys +): + """Test that corrupted data in query results prints warning and continues.""" + import sqlite3 + + # Create a valid execution first + input_data = StartDurableExecutionInput( + account_id="123456789012", + function_name="test-function", + function_qualifier="$LATEST", + execution_name="test-execution", + execution_timeout_seconds=300, + execution_retention_period_days=7, + invocation_id="test-invocation-id", + ) + execution = Execution.new(input_data) + execution.start() + store.save(execution) + + # Insert corrupted JSON data directly + with sqlite3.connect(temp_db_path) as conn: + conn.execute( + """ + INSERT INTO executions + (durable_execution_arn, function_name, execution_name, status, start_timestamp, end_timestamp, data) + VALUES (?, ?, ?, ?, ?, ?, ?) 
+ """, + ( + "corrupted-arn-2", + "test-function", + "test-execution", + "RUNNING", + 1234567890.0, + None, + "invalid json data {{{", + ), + ) + + # Query should skip corrupted records and print warning + executions, _ = store.query() + + # Should get the valid execution, skip the corrupted one + assert len(executions) == 1 + assert executions[0].durable_execution_arn == execution.durable_execution_arn + + # Check that warning was printed + captured = capsys.readouterr() + assert "Warning: Skipping corrupted execution corrupted-arn-2" in captured.out diff --git a/tests/web/handlers_test.py b/tests/web/handlers_test.py index bd8b5d4..fbf016d 100644 --- a/tests/web/handlers_test.py +++ b/tests/web/handlers_test.py @@ -1297,8 +1297,8 @@ def test_list_durable_executions_handler_success(): function_version=None, execution_name=None, status_filter=None, - time_after=None, - time_before=None, + started_after=None, + started_before=None, marker=None, max_items=None, reverse_order=False, @@ -1342,8 +1342,8 @@ def test_list_durable_executions_handler_with_filters(): "FunctionVersion": ["$LATEST"], "DurableExecutionName": ["filtered-execution"], "StatusFilter": ["SUCCEEDED"], - "TimeAfter": ["2023-01-01T00:00:00Z"], - "TimeBefore": ["2023-01-01T23:59:59Z"], + "StartedAfter": ["2023-01-01T00:00:00Z"], + "StartedBefore": ["2023-01-01T23:59:59Z"], "Marker": ["start-token"], "MaxItems": ["10"], "ReverseOrder": ["true"], @@ -1376,8 +1376,8 @@ def test_list_durable_executions_handler_with_filters(): function_version="$LATEST", execution_name="filtered-execution", status_filter="SUCCEEDED", - time_after="2023-01-01T00:00:00Z", - time_before="2023-01-01T23:59:59Z", + started_after="2023-01-01T00:00:00Z", + started_before="2023-01-01T23:59:59Z", marker="start-token", max_items=10, reverse_order=True, @@ -1437,8 +1437,8 @@ def test_list_durable_executions_handler_pagination(): function_version=None, execution_name=None, status_filter=None, - time_after=None, - time_before=None, + 
started_after=None, + started_before=None, marker="current-page-marker", max_items=3, reverse_order=False, @@ -1538,8 +1538,8 @@ def test_list_durable_executions_handler_dataclass_serialization(): function_version=None, execution_name=None, status_filter=None, - time_after=None, - time_before=None, + started_after=None, + started_before=None, marker=None, max_items=5, reverse_order=False, @@ -1716,8 +1716,8 @@ def test_list_durable_executions_by_function_handler_success(): qualifier=None, execution_name=None, status_filter=None, - time_after=None, - time_before=None, + started_after=None, + started_before=None, marker=None, max_items=None, reverse_order=False, @@ -1763,8 +1763,8 @@ def test_list_durable_executions_by_function_handler_with_filters(): "functionVersion": ["$LATEST"], "executionName": ["filtered-execution"], "statusFilter": ["SUCCEEDED"], - "timeAfter": ["2023-01-01T00:00:00Z"], - "timeBefore": ["2023-01-01T23:59:59Z"], + "startedAfter": ["2023-01-01T00:00:00Z"], + "startedBefore": ["2023-01-01T23:59:59Z"], "marker": ["start-token"], "maxItems": ["5"], "reverseOrder": ["true"], @@ -1797,8 +1797,8 @@ def test_list_durable_executions_by_function_handler_with_filters(): qualifier="$LATEST", execution_name="filtered-execution", status_filter="SUCCEEDED", - time_after="2023-01-01T00:00:00Z", - time_before="2023-01-01T23:59:59Z", + started_after="2023-01-01T00:00:00Z", + started_before="2023-01-01T23:59:59Z", marker="start-token", max_items=5, reverse_order=True, @@ -1866,8 +1866,8 @@ def test_list_durable_executions_by_function_handler_dataclass_serialization(): qualifier="$LATEST", execution_name=None, status_filter=None, - time_after=None, - time_before=None, + started_after=None, + started_before=None, marker=None, max_items=10, reverse_order=False,