Skip to content

Commit 2116ab4

Browse files
rarepolzRares Polenciuc
andauthored
fix: callback timeout handling (#140)
Co-authored-by: Rares Polenciuc <rarepolz@amazon.com>
1 parent 18eec29 commit 2116ab4

File tree

8 files changed

+255
-149
lines changed

8 files changed

+255
-149
lines changed

src/aws_durable_execution_sdk_python_testing/checkpoint/processors/callback.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
OperationUpdate,
1313
CallbackDetails,
1414
OperationType,
15+
CallbackOptions,
1516
)
1617
from aws_durable_execution_sdk_python_testing.checkpoint.processors.base import (
1718
OperationProcessor,
@@ -43,12 +44,6 @@ def process(
4344
operation_id=update.operation_id,
4445
)
4546

46-
notifier.notify_callback_created(
47-
execution_arn=execution_arn,
48-
operation_id=update.operation_id,
49-
callback_token=callback_token,
50-
)
51-
5247
callback_id: str = callback_token.to_str()
5348

5449
callback_details: CallbackDetails | None = (
@@ -60,11 +55,15 @@ def process(
6055
if update.operation_type == OperationType.CALLBACK
6156
else None
6257
)
58+
6359
status: OperationStatus = OperationStatus.STARTED
60+
6461
start_time: datetime.datetime | None = self._get_start_time(current_op)
62+
6563
end_time: datetime.datetime | None = self._get_end_time(
6664
current_op, status
6765
)
66+
6867
operation: Operation = Operation(
6968
operation_id=update.operation_id,
7069
parent_id=update.parent_id,
@@ -76,7 +75,14 @@ def process(
7675
sub_type=update.sub_type,
7776
callback_details=callback_details,
7877
)
78+
callback_options: CallbackOptions | None = update.callback_options
7979

80+
notifier.notify_callback_created(
81+
execution_arn=execution_arn,
82+
operation_id=update.operation_id,
83+
callback_options=callback_options,
84+
callback_token=callback_token,
85+
)
8086
return operation
8187
case _:
8288
msg: str = "Invalid action for CALLBACK operation."

src/aws_durable_execution_sdk_python_testing/execution.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,32 @@ def complete_callback_failure(
381381
)
382382
return self.operations[index]
383383

384+
def complete_callback_timeout(
385+
self, callback_id: str, error: ErrorObject
386+
) -> Operation:
387+
"""Complete CALLBACK operation with timeout."""
388+
index, operation = self.find_callback_operation(callback_id)
389+
390+
if operation.status != OperationStatus.STARTED:
391+
msg: str = f"Callback operation [{callback_id}] is not in STARTED state"
392+
raise IllegalStateException(msg)
393+
394+
with self._state_lock:
395+
self._token_sequence += 1
396+
updated_callback_details = None
397+
if operation.callback_details:
398+
updated_callback_details = replace(
399+
operation.callback_details, error=error
400+
)
401+
402+
self.operations[index] = replace(
403+
operation,
404+
status=OperationStatus.TIMED_OUT,
405+
end_timestamp=datetime.now(UTC),
406+
callback_details=updated_callback_details,
407+
)
408+
return self.operations[index]
409+
384410
def _end_execution(self, status: OperationStatus) -> None:
385411
"""Set the end_timestamp on the main EXECUTION operation when execution completes."""
386412
execution_op: Operation = self.get_operation_execution_started()

src/aws_durable_execution_sdk_python_testing/executor.py

Lines changed: 30 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
OperationUpdate,
2020
OperationStatus,
2121
OperationType,
22+
CallbackOptions,
2223
)
2324

2425
from aws_durable_execution_sdk_python_testing.exceptions import (
@@ -57,6 +58,7 @@
5758

5859
if TYPE_CHECKING:
5960
from collections.abc import Awaitable, Callable
61+
from concurrent.futures import Future
6062

6163
from aws_durable_execution_sdk_python_testing.checkpoint.processor import (
6264
CheckpointProcessor,
@@ -84,10 +86,8 @@ def __init__(
8486
self._invoker = invoker
8587
self._checkpoint_processor = checkpoint_processor
8688
self._completion_events: dict[str, Event] = {}
87-
self._callback_timeouts: dict[str, Event] = {} # callback_id -> timeout event
88-
self._callback_heartbeats: dict[
89-
str, Event
90-
] = {} # callback_id -> heartbeat event
89+
self._callback_timeouts: dict[str, Future] = {}
90+
self._callback_heartbeats: dict[str, Future] = {}
9191

9292
def start_execution(
9393
self,
@@ -1011,7 +1011,11 @@ def retry_handler() -> None:
10111011
)
10121012

10131013
def on_callback_created(
1014-
self, execution_arn: str, operation_id: str, callback_token: CallbackToken
1014+
self,
1015+
execution_arn: str,
1016+
operation_id: str,
1017+
callback_options: CallbackOptions | None,
1018+
callback_token: CallbackToken,
10151019
) -> None:
10161020
"""Handle callback creation. Observer method triggered by notifier."""
10171021
callback_id = callback_token.to_str()
@@ -1023,34 +1027,19 @@ def on_callback_created(
10231027
)
10241028

10251029
# Schedule callback timeouts if configured
1026-
self._schedule_callback_timeouts(execution_arn, operation_id, callback_id)
1030+
self._schedule_callback_timeouts(execution_arn, callback_options, callback_id)
10271031

10281032
# endregion ExecutionObserver
10291033

10301034
# region Callback Timeouts
10311035
def _schedule_callback_timeouts(
1032-
self, execution_arn: str, operation_id: str, callback_id: str
1036+
self,
1037+
execution_arn: str,
1038+
callback_options: CallbackOptions | None,
1039+
callback_id: str,
10331040
) -> None:
10341041
"""Schedule callback timeout and heartbeat timeout if configured."""
10351042
try:
1036-
execution = self.get_execution(execution_arn)
1037-
_, operation = execution.find_operation(operation_id)
1038-
1039-
if not operation.callback_details:
1040-
return
1041-
1042-
# Find the callback options from the operation update that created this callback
1043-
# We need to look at the checkpoint updates to find the original callback options
1044-
callback_options = None
1045-
for update in execution.updates:
1046-
if (
1047-
update.operation_id == operation_id
1048-
and update.callback_options
1049-
and update.action.value == "START"
1050-
):
1051-
callback_options = update.callback_options
1052-
break
1053-
10541043
if not callback_options:
10551044
return
10561045

@@ -1062,27 +1051,25 @@ def _schedule_callback_timeouts(
10621051
def timeout_handler():
10631052
self._on_callback_timeout(execution_arn, callback_id)
10641053

1065-
timeout_event = self._scheduler.create_event()
1066-
self._callback_timeouts[callback_id] = timeout_event
1067-
self._scheduler.call_later(
1054+
timeout_future = self._scheduler.call_later(
10681055
timeout_handler,
10691056
delay=callback_options.timeout_seconds,
10701057
completion_event=completion_event,
10711058
)
1059+
self._callback_timeouts[callback_id] = timeout_future
10721060

10731061
# Schedule heartbeat timeout if configured
10741062
if callback_options.heartbeat_timeout_seconds > 0:
10751063

10761064
def heartbeat_timeout_handler():
10771065
self._on_callback_heartbeat_timeout(execution_arn, callback_id)
10781066

1079-
heartbeat_event = self._scheduler.create_event()
1080-
self._callback_heartbeats[callback_id] = heartbeat_event
1081-
self._scheduler.call_later(
1067+
heartbeat_future = self._scheduler.call_later(
10821068
heartbeat_timeout_handler,
10831069
delay=callback_options.heartbeat_timeout_seconds,
10841070
completion_event=completion_event,
10851071
)
1072+
self._callback_heartbeats[callback_id] = heartbeat_future
10861073

10871074
except Exception:
10881075
logger.exception(
@@ -1096,16 +1083,14 @@ def _reset_callback_heartbeat_timeout(
10961083
) -> None:
10971084
"""Reset the heartbeat timeout for a callback."""
10981085
# Cancel existing heartbeat timeout
1099-
if heartbeat_event := self._callback_heartbeats.get(callback_id):
1100-
heartbeat_event.remove()
1101-
del self._callback_heartbeats[callback_id]
1086+
if heartbeat_future := self._callback_heartbeats.pop(callback_id, None):
1087+
heartbeat_future.cancel()
11021088

11031089
# Find callback options to reschedule heartbeat timeout
11041090
try:
11051091
callback_token = CallbackToken.from_str(callback_id)
11061092
execution = self.get_execution(callback_token.execution_arn)
11071093

1108-
# Find callback options from updates
11091094
callback_options = None
11101095
for update in execution.updates:
11111096
if (
@@ -1122,13 +1107,14 @@ def heartbeat_timeout_handler():
11221107
self._on_callback_heartbeat_timeout(execution_arn, callback_id)
11231108

11241109
completion_event = self._completion_events.get(execution_arn)
1125-
heartbeat_event = self._scheduler.create_event()
1126-
self._callback_heartbeats[callback_id] = heartbeat_event
1127-
self._scheduler.call_later(
1110+
1111+
heartbeat_future = self._scheduler.call_later(
11281112
heartbeat_timeout_handler,
11291113
delay=callback_options.heartbeat_timeout_seconds,
11301114
completion_event=completion_event,
11311115
)
1116+
self._callback_heartbeats[callback_id] = heartbeat_future
1117+
11321118
except Exception:
11331119
logger.exception(
11341120
"[%s] Error resetting callback heartbeat timeout for %s",
@@ -1139,14 +1125,12 @@ def heartbeat_timeout_handler():
11391125
def _cleanup_callback_timeouts(self, callback_id: str) -> None:
11401126
"""Clean up timeout events for a completed callback."""
11411127
# Clean up main timeout
1142-
if timeout_event := self._callback_timeouts.get(callback_id):
1143-
timeout_event.remove()
1144-
del self._callback_timeouts[callback_id]
1128+
if timeout_future := self._callback_timeouts.pop(callback_id, None):
1129+
timeout_future.cancel()
11451130

11461131
# Clean up heartbeat timeout
1147-
if heartbeat_event := self._callback_heartbeats.get(callback_id):
1148-
heartbeat_event.remove()
1149-
del self._callback_heartbeats[callback_id]
1132+
if heartbeat_future := self._callback_heartbeats.pop(callback_id, None):
1133+
heartbeat_future.cancel()
11501134

11511135
def _on_callback_timeout(self, execution_arn: str, callback_id: str) -> None:
11521136
"""Handle callback timeout."""
@@ -1161,7 +1145,7 @@ def _on_callback_timeout(self, execution_arn: str, callback_id: str) -> None:
11611145
timeout_error = ErrorObject.from_message(
11621146
f"Callback timed out: {CallbackTimeoutType.TIMEOUT.value}"
11631147
)
1164-
execution.complete_callback_failure(callback_id, timeout_error)
1148+
execution.complete_callback_timeout(callback_id, timeout_error)
11651149
execution.complete_fail(timeout_error)
11661150
self._store.update(execution)
11671151
logger.warning("[%s] Callback %s timed out", execution_arn, callback_id)
@@ -1188,7 +1172,7 @@ def _on_callback_heartbeat_timeout(
11881172
heartbeat_error = ErrorObject.from_message(
11891173
f"Callback heartbeat timed out: {CallbackTimeoutType.HEARTBEAT.value}"
11901174
)
1191-
execution.complete_callback_failure(callback_id, heartbeat_error)
1175+
execution.complete_callback_timeout(callback_id, heartbeat_error)
11921176
execution.complete_fail(heartbeat_error)
11931177
self._store.update(execution)
11941178
logger.warning(

src/aws_durable_execution_sdk_python_testing/observer.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@
1111
if TYPE_CHECKING:
1212
from collections.abc import Callable
1313

14-
from aws_durable_execution_sdk_python.lambda_service import ErrorObject
14+
from aws_durable_execution_sdk_python.lambda_service import (
15+
ErrorObject,
16+
CallbackOptions,
17+
)
1518

1619

1720
class ExecutionObserver(ABC):
@@ -47,7 +50,11 @@ def on_step_retry_scheduled(
4750

4851
@abstractmethod
4952
def on_callback_created(
50-
self, execution_arn: str, operation_id: str, callback_token: CallbackToken
53+
self,
54+
execution_arn: str,
55+
operation_id: str,
56+
callback_options: CallbackOptions | None,
57+
callback_token: CallbackToken,
5158
) -> None:
5259
"""Called when callback is created."""
5360

@@ -119,13 +126,18 @@ def notify_step_retry_scheduled(
119126
)
120127

121128
def notify_callback_created(
122-
self, execution_arn: str, operation_id: str, callback_token: CallbackToken
129+
self,
130+
execution_arn: str,
131+
operation_id: str,
132+
callback_options: CallbackOptions | None,
133+
callback_token: CallbackToken,
123134
) -> None:
124135
"""Notify observers about callback creation."""
125136
self._notify_observers(
126137
ExecutionObserver.on_callback_created,
127138
execution_arn=execution_arn,
128139
operation_id=operation_id,
140+
callback_options=callback_options,
129141
callback_token=callback_token,
130142
)
131143

0 commit comments

Comments
 (0)