Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,16 @@ def process(
"""Process an operation update and return the transformed operation."""
raise NotImplementedError

def _get_start_time(
self, current_operation: Operation | None
) -> datetime.datetime | None:
start_time: datetime.datetime | None = (
current_operation.start_timestamp
if current_operation
else datetime.datetime.now(tz=datetime.UTC)
)
return start_time

def _get_end_time(
self, current_operation: Operation | None, status: OperationStatus
) -> datetime.datetime | None:
Expand Down Expand Up @@ -116,35 +126,17 @@ def _create_invoke_details(
return ChainedInvokeDetails(result=update.payload, error=update.error)
return None

def _create_wait_details(
self, update: OperationUpdate, current_operation: Operation | None
) -> WaitDetails | None:
"""Create WaitDetails from OperationUpdate."""
if update.operation_type == OperationType.WAIT and update.wait_options:
if current_operation and current_operation.wait_details:
scheduled_end_timestamp = (
current_operation.wait_details.scheduled_end_timestamp
)
else:
scheduled_end_timestamp = datetime.datetime.now(
tz=datetime.UTC
) + timedelta(seconds=update.wait_options.wait_seconds)
return WaitDetails(scheduled_end_timestamp=scheduled_end_timestamp)
return None

def _translate_update_to_operation(
self,
update: OperationUpdate,
current_operation: Operation | None,
status: OperationStatus,
) -> Operation:
"""Transform OperationUpdate to Operation, always creating new Operation."""
start_time = (
current_operation.start_timestamp
if current_operation
else datetime.datetime.now(tz=datetime.UTC)
start_time: datetime.datetime | None = self._get_start_time(current_operation)
end_time: datetime.datetime | None = self._get_end_time(
current_operation, status
)
end_time = self._get_end_time(current_operation, status)

execution_details = self._create_execution_details(update)
context_details = self._create_context_details(update)
Expand All @@ -169,3 +161,19 @@ def _translate_update_to_operation(
chained_invoke_details=invoke_details,
wait_details=wait_details,
)

def _create_wait_details(
self, update: OperationUpdate, current_operation: Operation | None
) -> WaitDetails | None:
"""Create WaitDetails from OperationUpdate."""
if update.operation_type == OperationType.WAIT and update.wait_options:
if current_operation and current_operation.wait_details:
scheduled_end_timestamp = (
current_operation.wait_details.scheduled_end_timestamp
)
else:
scheduled_end_timestamp = datetime.datetime.now(
tz=datetime.UTC
) + timedelta(seconds=update.wait_options.wait_seconds)
return WaitDetails(scheduled_end_timestamp=scheduled_end_timestamp)
return None
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,24 @@

from __future__ import annotations

import datetime
from typing import TYPE_CHECKING

from aws_durable_execution_sdk_python.lambda_service import (
Operation,
OperationAction,
OperationStatus,
OperationUpdate,
CallbackDetails,
OperationType,
)

from aws_durable_execution_sdk_python_testing.checkpoint.processors.base import (
OperationProcessor,
)
from aws_durable_execution_sdk_python_testing.exceptions import (
InvalidParameterValueException,
)

from aws_durable_execution_sdk_python_testing.token import CallbackToken

if TYPE_CHECKING:
from aws_durable_execution_sdk_python_testing.observer import ExecutionNotifier
Expand All @@ -36,14 +38,46 @@ def process(
"""Process CALLBACK operation update with scheduler integration for activities."""
match update.action:
case OperationAction.START:
# TODO: create CallbackToken (see token module). Add Observer/Notifier for on_callback_created possibly,
# but token might well have enough so don't need to maintain token list on execution itself
return self._translate_update_to_operation(
update=update,
current_operation=current_op,
status=OperationStatus.STARTED,
callback_token: CallbackToken = CallbackToken(
execution_arn=execution_arn,
operation_id=update.operation_id,
)

notifier.notify_callback_created(
execution_arn=execution_arn,
operation_id=update.operation_id,
callback_token=callback_token,
)

callback_id: str = callback_token.to_str()

callback_details: CallbackDetails | None = (
CallbackDetails(
callback_id=callback_id,
result=update.payload,
error=update.error,
)
if update.operation_type == OperationType.CALLBACK
else None
)
status: OperationStatus = OperationStatus.STARTED
start_time: datetime.datetime | None = self._get_start_time(current_op)
end_time: datetime.datetime | None = self._get_end_time(
current_op, status
)
operation: Operation = Operation(
operation_id=update.operation_id,
parent_id=update.parent_id,
name=update.name,
start_timestamp=start_time,
end_timestamp=end_time,
operation_type=update.operation_type,
status=status,
sub_type=update.sub_type,
callback_details=callback_details,
)

return operation
case _:
msg: str = "Invalid action for CALLBACK operation."

raise InvalidParameterValueException(msg)
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
VALID_ACTIONS_FOR_CALLBACK = frozenset(
[
OperationAction.START,
OperationAction.CANCEL,
]
)

Expand All @@ -41,14 +40,6 @@ def validate(current_state: Operation | None, update: OperationUpdate) -> None:
"Cannot start a CALLBACK that already exist."
)
raise InvalidParameterValueException(msg_callback_exists)
case OperationAction.CANCEL:
if (
current_state is None
or current_state.status
not in CallbackOperationValidator._ALLOWED_STATUS_TO_CANCEL
):
msg_callback_cancel: str = "Cannot cancel a CALLBACK that does not exist or has already completed."
raise InvalidParameterValueException(msg_callback_cancel)
case _:
msg_callback_invalid: str = "Invalid CALLBACK action."
raise InvalidParameterValueException(msg_callback_invalid)
msg: str = "Invalid action for CALLBACK operation."
raise InvalidParameterValueException(msg)
69 changes: 68 additions & 1 deletion src/aws_durable_execution_sdk_python_testing/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@
from aws_durable_execution_sdk_python_testing.model import (
StartDurableExecutionInput,
)
from aws_durable_execution_sdk_python_testing.token import CheckpointToken
from aws_durable_execution_sdk_python_testing.token import (
CheckpointToken,
CallbackToken,
)


class Execution:
Expand Down Expand Up @@ -203,6 +206,18 @@ def find_operation(self, operation_id: str) -> tuple[int, Operation]:
msg: str = f"Attempting to update state of an Operation [{operation_id}] that doesn't exist"
raise IllegalStateException(msg)

def find_callback_operation(self, callback_id: str) -> tuple[int, Operation]:
"""Find callback operation by callback_id, return index and operation."""
for i, operation in enumerate(self.operations):
if (
operation.operation_type == OperationType.CALLBACK
and operation.callback_details
and operation.callback_details.callback_id == callback_id
):
return i, operation
msg: str = f"Callback operation with callback_id [{callback_id}] not found"
raise IllegalStateException(msg)

def complete_wait(self, operation_id: str) -> Operation:
"""Complete WAIT operation when timer fires."""
index, operation = self.find_operation(operation_id)
Expand Down Expand Up @@ -260,3 +275,55 @@ def complete_retry(self, operation_id: str) -> Operation:
# Assign
self.operations[index] = updated_operation
return updated_operation

def complete_callback_success(
self, callback_id: str, result: bytes | None = None
) -> Operation:
"""Complete CALLBACK operation with success."""
index, operation = self.find_callback_operation(callback_id)
if operation.status != OperationStatus.STARTED:
msg: str = f"Callback operation [{callback_id}] is not in STARTED state"
raise IllegalStateException(msg)

with self._state_lock:
self._token_sequence += 1
updated_callback_details = None
if operation.callback_details:
updated_callback_details = replace(
operation.callback_details,
result=result.decode() if result else None,
)

self.operations[index] = replace(
operation,
status=OperationStatus.SUCCEEDED,
end_timestamp=datetime.now(UTC),
callback_details=updated_callback_details,
)
return self.operations[index]

def complete_callback_failure(
self, callback_id: str, error: ErrorObject
) -> Operation:
"""Complete CALLBACK operation with failure."""
index, operation = self.find_callback_operation(callback_id)

if operation.status != OperationStatus.STARTED:
msg: str = f"Callback operation [{callback_id}] is not in STARTED state"
raise IllegalStateException(msg)

with self._state_lock:
self._token_sequence += 1
updated_callback_details = None
if operation.callback_details:
updated_callback_details = replace(
operation.callback_details, error=error
)

self.operations[index] = replace(
operation,
status=OperationStatus.FAILED,
end_timestamp=datetime.now(UTC),
callback_details=updated_callback_details,
)
return self.operations[index]
Loading