Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
OperationUpdate,
CallbackDetails,
OperationType,
CallbackOptions,
)
from aws_durable_execution_sdk_python_testing.checkpoint.processors.base import (
OperationProcessor,
Expand Down Expand Up @@ -43,12 +44,6 @@ def process(
operation_id=update.operation_id,
)

notifier.notify_callback_created(
execution_arn=execution_arn,
operation_id=update.operation_id,
callback_token=callback_token,
)

callback_id: str = callback_token.to_str()

callback_details: CallbackDetails | None = (
Expand All @@ -60,11 +55,15 @@ def process(
if update.operation_type == OperationType.CALLBACK
else None
)

status: OperationStatus = OperationStatus.STARTED

start_time: datetime.datetime | None = self._get_start_time(current_op)

end_time: datetime.datetime | None = self._get_end_time(
current_op, status
)

operation: Operation = Operation(
operation_id=update.operation_id,
parent_id=update.parent_id,
Expand All @@ -76,7 +75,14 @@ def process(
sub_type=update.sub_type,
callback_details=callback_details,
)
callback_options: CallbackOptions | None = update.callback_options

notifier.notify_callback_created(
execution_arn=execution_arn,
operation_id=update.operation_id,
callback_options=callback_options,
callback_token=callback_token,
)
return operation
case _:
msg: str = "Invalid action for CALLBACK operation."
Expand Down
26 changes: 26 additions & 0 deletions src/aws_durable_execution_sdk_python_testing/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,32 @@ def complete_callback_failure(
)
return self.operations[index]

def complete_callback_timeout(
    self, callback_id: str, error: ErrorObject
) -> Operation:
    """Transition a STARTED callback operation to TIMED_OUT.

    Args:
        callback_id: Identifier used to look up the callback operation.
        error: Error recorded on the operation's callback details, when
            the operation has callback details at all.

    Returns:
        The updated operation as stored in ``self.operations``.

    Raises:
        IllegalStateException: If the operation is not in STARTED state.
    """
    position, current = self.find_callback_operation(callback_id)

    # Only a callback that is still pending can time out.
    if current.status != OperationStatus.STARTED:
        raise IllegalStateException(
            f"Callback operation [{callback_id}] is not in STARTED state"
        )

    with self._state_lock:
        self._token_sequence += 1

        # Attach the error to the existing details; operations without
        # callback details keep None.
        details_with_error = (
            replace(current.callback_details, error=error)
            if current.callback_details
            else None
        )

        timed_out = replace(
            current,
            status=OperationStatus.TIMED_OUT,
            end_timestamp=datetime.now(UTC),
            callback_details=details_with_error,
        )
        self.operations[position] = timed_out
        return self.operations[position]

def _end_execution(self, status: OperationStatus) -> None:
"""Set the end_timestamp on the main EXECUTION operation when execution completes."""
execution_op: Operation = self.get_operation_execution_started()
Expand Down
76 changes: 30 additions & 46 deletions src/aws_durable_execution_sdk_python_testing/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
OperationUpdate,
OperationStatus,
OperationType,
CallbackOptions,
)

from aws_durable_execution_sdk_python_testing.exceptions import (
Expand Down Expand Up @@ -57,6 +58,7 @@

if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
from concurrent.futures import Future

from aws_durable_execution_sdk_python_testing.checkpoint.processor import (
CheckpointProcessor,
Expand Down Expand Up @@ -84,10 +86,8 @@ def __init__(
self._invoker = invoker
self._checkpoint_processor = checkpoint_processor
self._completion_events: dict[str, Event] = {}
self._callback_timeouts: dict[str, Event] = {} # callback_id -> timeout event
self._callback_heartbeats: dict[
str, Event
] = {} # callback_id -> heartbeat event
self._callback_timeouts: dict[str, Future] = {}
self._callback_heartbeats: dict[str, Future] = {}

def start_execution(
self,
Expand Down Expand Up @@ -1011,7 +1011,11 @@ def retry_handler() -> None:
)

def on_callback_created(
self, execution_arn: str, operation_id: str, callback_token: CallbackToken
self,
execution_arn: str,
operation_id: str,
callback_options: CallbackOptions | None,
callback_token: CallbackToken,
) -> None:
"""Handle callback creation. Observer method triggered by notifier."""
callback_id = callback_token.to_str()
Expand All @@ -1023,34 +1027,19 @@ def on_callback_created(
)

# Schedule callback timeouts if configured
self._schedule_callback_timeouts(execution_arn, operation_id, callback_id)
self._schedule_callback_timeouts(execution_arn, callback_options, callback_id)

# endregion ExecutionObserver

# region Callback Timeouts
def _schedule_callback_timeouts(
self, execution_arn: str, operation_id: str, callback_id: str
self,
execution_arn: str,
callback_options: CallbackOptions | None,
callback_id: str,
) -> None:
"""Schedule callback timeout and heartbeat timeout if configured."""
try:
execution = self.get_execution(execution_arn)
_, operation = execution.find_operation(operation_id)

if not operation.callback_details:
return

# Find the callback options from the operation update that created this callback
# We need to look at the checkpoint updates to find the original callback options
callback_options = None
for update in execution.updates:
if (
update.operation_id == operation_id
and update.callback_options
and update.action.value == "START"
):
callback_options = update.callback_options
break

if not callback_options:
return

Expand All @@ -1062,27 +1051,25 @@ def _schedule_callback_timeouts(
def timeout_handler():
self._on_callback_timeout(execution_arn, callback_id)

timeout_event = self._scheduler.create_event()
self._callback_timeouts[callback_id] = timeout_event
self._scheduler.call_later(
timeout_future = self._scheduler.call_later(
timeout_handler,
delay=callback_options.timeout_seconds,
completion_event=completion_event,
)
self._callback_timeouts[callback_id] = timeout_future

# Schedule heartbeat timeout if configured
if callback_options.heartbeat_timeout_seconds > 0:

def heartbeat_timeout_handler():
self._on_callback_heartbeat_timeout(execution_arn, callback_id)

heartbeat_event = self._scheduler.create_event()
self._callback_heartbeats[callback_id] = heartbeat_event
self._scheduler.call_later(
heartbeat_future = self._scheduler.call_later(
heartbeat_timeout_handler,
delay=callback_options.heartbeat_timeout_seconds,
completion_event=completion_event,
)
self._callback_heartbeats[callback_id] = heartbeat_future

except Exception:
logger.exception(
Expand All @@ -1096,16 +1083,14 @@ def _reset_callback_heartbeat_timeout(
) -> None:
"""Reset the heartbeat timeout for a callback."""
# Cancel existing heartbeat timeout
if heartbeat_event := self._callback_heartbeats.get(callback_id):
heartbeat_event.remove()
del self._callback_heartbeats[callback_id]
if heartbeat_future := self._callback_heartbeats.pop(callback_id, None):
heartbeat_future.cancel()

# Find callback options to reschedule heartbeat timeout
try:
callback_token = CallbackToken.from_str(callback_id)
execution = self.get_execution(callback_token.execution_arn)

# Find callback options from updates
callback_options = None
for update in execution.updates:
if (
Expand All @@ -1122,13 +1107,14 @@ def heartbeat_timeout_handler():
self._on_callback_heartbeat_timeout(execution_arn, callback_id)

completion_event = self._completion_events.get(execution_arn)
heartbeat_event = self._scheduler.create_event()
self._callback_heartbeats[callback_id] = heartbeat_event
self._scheduler.call_later(

heartbeat_future = self._scheduler.call_later(
heartbeat_timeout_handler,
delay=callback_options.heartbeat_timeout_seconds,
completion_event=completion_event,
)
self._callback_heartbeats[callback_id] = heartbeat_future

except Exception:
logger.exception(
"[%s] Error resetting callback heartbeat timeout for %s",
Expand All @@ -1139,14 +1125,12 @@ def heartbeat_timeout_handler():
def _cleanup_callback_timeouts(self, callback_id: str) -> None:
    """Cancel and forget the pending timeout futures for a callback.

    Removes the callback's entries from both the main-timeout and the
    heartbeat-timeout maps and cancels whatever future was registered.
    A callback with no registered timeout is a no-op.

    Args:
        callback_id: Key of the callback whose timeouts should be dropped.
    """
    # pop() removes and returns in one step, so a missing key needs no
    # separate lookup or delete. Future.cancel() returns False when the
    # call is already running or finished; that result is deliberately
    # ignored because cleanup is best-effort.
    timeout_future = self._callback_timeouts.pop(callback_id, None)
    if timeout_future is not None:
        timeout_future.cancel()

    heartbeat_future = self._callback_heartbeats.pop(callback_id, None)
    if heartbeat_future is not None:
        heartbeat_future.cancel()

def _on_callback_timeout(self, execution_arn: str, callback_id: str) -> None:
"""Handle callback timeout."""
Expand All @@ -1161,7 +1145,7 @@ def _on_callback_timeout(self, execution_arn: str, callback_id: str) -> None:
timeout_error = ErrorObject.from_message(
f"Callback timed out: {CallbackTimeoutType.TIMEOUT.value}"
)
execution.complete_callback_failure(callback_id, timeout_error)
execution.complete_callback_timeout(callback_id, timeout_error)
execution.complete_fail(timeout_error)
self._store.update(execution)
logger.warning("[%s] Callback %s timed out", execution_arn, callback_id)
Expand All @@ -1188,7 +1172,7 @@ def _on_callback_heartbeat_timeout(
heartbeat_error = ErrorObject.from_message(
f"Callback heartbeat timed out: {CallbackTimeoutType.HEARTBEAT.value}"
)
execution.complete_callback_failure(callback_id, heartbeat_error)
execution.complete_callback_timeout(callback_id, heartbeat_error)
execution.complete_fail(heartbeat_error)
self._store.update(execution)
logger.warning(
Expand Down
18 changes: 15 additions & 3 deletions src/aws_durable_execution_sdk_python_testing/observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
if TYPE_CHECKING:
from collections.abc import Callable

from aws_durable_execution_sdk_python.lambda_service import ErrorObject
from aws_durable_execution_sdk_python.lambda_service import (
ErrorObject,
CallbackOptions,
)


class ExecutionObserver(ABC):
Expand Down Expand Up @@ -47,7 +50,11 @@ def on_step_retry_scheduled(

@abstractmethod
def on_callback_created(
    self,
    execution_arn: str,
    operation_id: str,
    callback_options: CallbackOptions | None,
    callback_token: CallbackToken,
) -> None:
    """Called when callback is created.

    Args:
        execution_arn: ARN of the execution that owns the callback.
        operation_id: Identifier of the callback operation.
        callback_options: Options supplied with the operation update that
            created the callback, or None when none were given.
        callback_token: Token identifying the newly created callback.
    """

Expand Down Expand Up @@ -119,13 +126,18 @@ def notify_step_retry_scheduled(
)

def notify_callback_created(
    self,
    execution_arn: str,
    operation_id: str,
    callback_options: CallbackOptions | None,
    callback_token: CallbackToken,
) -> None:
    """Notify observers about callback creation.

    Forwards every argument by keyword to each observer's
    ``on_callback_created`` via ``_notify_observers``.

    Args:
        execution_arn: ARN of the execution that owns the callback.
        operation_id: Identifier of the callback operation.
        callback_options: Options from the update that created the
            callback, or None when none were supplied.
        callback_token: Token identifying the newly created callback.
    """
    self._notify_observers(
        ExecutionObserver.on_callback_created,
        execution_arn=execution_arn,
        operation_id=operation_id,
        callback_options=callback_options,
        callback_token=callback_token,
    )

Expand Down
Loading
Loading