Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ def process(
"There is no error details but EXECUTION checkpoint action is not SUCCEED."
)
)
# All EXECUTION failures go through normal fail path
# Timeout/Stop status is set by executor based on the operation that caused it
notifier.notify_failed(execution_arn=execution_arn, error=error)
# TODO: Svc doesn't actually create checkpoint for EXECUTION. might have to for localrunner though.
return None
66 changes: 65 additions & 1 deletion src/aws_durable_execution_sdk_python_testing/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
from dataclasses import replace
from datetime import UTC, datetime
from enum import Enum
from threading import Lock
from typing import Any
from uuid import uuid4
Expand All @@ -20,11 +21,12 @@
OperationUpdate,
)

# Import AWS exceptions
from aws_durable_execution_sdk_python_testing.exceptions import (
IllegalStateException,
InvalidParameterValueException,
)

# Import AWS exceptions
from aws_durable_execution_sdk_python_testing.model import (
StartDurableExecutionInput,
)
Expand All @@ -34,6 +36,16 @@
)


class ExecutionStatus(Enum):
"""Execution status for API responses."""

RUNNING = "RUNNING"
SUCCEEDED = "SUCCEEDED"
FAILED = "FAILED"
STOPPED = "STOPPED"
TIMED_OUT = "TIMED_OUT"


class Execution:
"""Execution state."""

Expand All @@ -55,12 +67,24 @@ def __init__(
self.is_complete: bool = False
self.result: DurableExecutionInvocationOutput | None = None
self.consecutive_failed_invocation_attempts: int = 0
self.close_status: ExecutionStatus | None = None

@property
def token_sequence(self) -> int:
"""Get current token sequence value."""
return self._token_sequence

def current_status(self) -> ExecutionStatus:
"""Get execution status."""
if not self.is_complete:
return ExecutionStatus.RUNNING

if not self.close_status:
msg: str = "close_status must be set when execution is complete"
raise IllegalStateException(msg)

return self.close_status

@staticmethod
def new(input: StartDurableExecutionInput) -> Execution: # noqa: A002
# make a nicer arn
Expand All @@ -82,6 +106,7 @@ def to_dict(self) -> dict[str, Any]:
"IsComplete": self.is_complete,
"Result": self.result.to_dict() if self.result else None,
"ConsecutiveFailedInvocationAttempts": self.consecutive_failed_invocation_attempts,
"CloseStatus": self.close_status.value if self.close_status else None,
}

@classmethod
Expand Down Expand Up @@ -115,6 +140,10 @@ def from_dict(cls, data: dict[str, Any]) -> Execution:
execution.consecutive_failed_invocation_attempts = data[
"ConsecutiveFailedInvocationAttempts"
]
close_status_str = data.get("CloseStatus")
execution.close_status = (
ExecutionStatus(close_status_str) if close_status_str else None
)

return execution

Expand Down Expand Up @@ -187,16 +216,40 @@ def has_pending_operations(self, execution: Execution) -> bool:
return False

def complete_success(self, result: str | None) -> None:
"""Complete execution successfully (DecisionType.COMPLETE_WORKFLOW_EXECUTION)."""
self.result = DurableExecutionInvocationOutput(
status=InvocationStatus.SUCCEEDED, result=result
)
self.is_complete = True
self.close_status = ExecutionStatus.SUCCEEDED
self._end_execution(OperationStatus.SUCCEEDED)

def complete_fail(self, error: ErrorObject) -> None:
"""Complete execution with failure (DecisionType.FAIL_WORKFLOW_EXECUTION)."""
self.result = DurableExecutionInvocationOutput(
status=InvocationStatus.FAILED, error=error
)
self.is_complete = True
self.close_status = ExecutionStatus.FAILED
self._end_execution(OperationStatus.FAILED)

def complete_timeout(self, error: ErrorObject) -> None:
"""Complete execution with timeout."""
self.result = DurableExecutionInvocationOutput(
status=InvocationStatus.FAILED, error=error
)
self.is_complete = True
self.close_status = ExecutionStatus.TIMED_OUT
self._end_execution(OperationStatus.TIMED_OUT)

def complete_stopped(self, error: ErrorObject) -> None:
"""Complete execution as terminated (TerminateWorkflowExecutionV2Request)."""
self.result = DurableExecutionInvocationOutput(
status=InvocationStatus.FAILED, error=error
)
self.is_complete = True
self.close_status = ExecutionStatus.STOPPED
self._end_execution(OperationStatus.STOPPED)

def find_operation(self, operation_id: str) -> tuple[int, Operation]:
"""Find operation by ID, return index and operation."""
Expand Down Expand Up @@ -327,3 +380,14 @@ def complete_callback_failure(
callback_details=updated_callback_details,
)
return self.operations[index]

def _end_execution(self, status: OperationStatus) -> None:
"""Set the end_timestamp on the main EXECUTION operation when execution completes."""
execution_op: Operation = self.get_operation_execution_started()
if execution_op.operation_type == OperationType.EXECUTION:
with self._state_lock:
self.operations[0] = replace(
execution_op,
status=status,
end_timestamp=datetime.now(UTC),
)
Loading
Loading