Skip to content

Commit 4e8ecb8

Browse files
rarepolzRares Polenciucbchampp
authored
feat(testing-sdk): implement callback token generation and processing (#95)
- feat: implement callback token generation and processing - Add CallbackToken generation in callback processor with observer integration - Implement SendCallbackSuccess, SendCallbackFailure, and SendCallbackHeartbeat - Add callback operation lookup and completion methods to execution - Ensure unique token generation across executions - fix: type check for on_callback_created in executor test --------- Co-authored-by: Rares Polenciuc <rarepolz@amazon.com> Co-authored-by: Brent Champion <brchamp@amazon.com>
1 parent a3cda91 commit 4e8ecb8

File tree

12 files changed

+1171
-130
lines changed

12 files changed

+1171
-130
lines changed

src/aws_durable_execution_sdk_python_testing/checkpoint/processors/base.py

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,16 @@ def process(
3838
"""Process an operation update and return the transformed operation."""
3939
raise NotImplementedError
4040

41+
def _get_start_time(
42+
self, current_operation: Operation | None
43+
) -> datetime.datetime | None:
44+
start_time: datetime.datetime | None = (
45+
current_operation.start_timestamp
46+
if current_operation
47+
else datetime.datetime.now(tz=datetime.UTC)
48+
)
49+
return start_time
50+
4151
def _get_end_time(
4252
self, current_operation: Operation | None, status: OperationStatus
4353
) -> datetime.datetime | None:
@@ -130,35 +140,17 @@ def _create_invoke_details(
130140
return ChainedInvokeDetails(result=update.payload, error=update.error)
131141
return None
132142

133-
def _create_wait_details(
134-
self, update: OperationUpdate, current_operation: Operation | None
135-
) -> WaitDetails | None:
136-
"""Create WaitDetails from OperationUpdate."""
137-
if update.operation_type == OperationType.WAIT and update.wait_options:
138-
if current_operation and current_operation.wait_details:
139-
scheduled_end_timestamp = (
140-
current_operation.wait_details.scheduled_end_timestamp
141-
)
142-
else:
143-
scheduled_end_timestamp = datetime.datetime.now(
144-
tz=datetime.UTC
145-
) + timedelta(seconds=update.wait_options.wait_seconds)
146-
return WaitDetails(scheduled_end_timestamp=scheduled_end_timestamp)
147-
return None
148-
149143
def _translate_update_to_operation(
150144
self,
151145
update: OperationUpdate,
152146
current_operation: Operation | None,
153147
status: OperationStatus,
154148
) -> Operation:
155149
"""Transform OperationUpdate to Operation, always creating new Operation."""
156-
start_time = (
157-
current_operation.start_timestamp
158-
if current_operation
159-
else datetime.datetime.now(tz=datetime.UTC)
150+
start_time: datetime.datetime | None = self._get_start_time(current_operation)
151+
end_time: datetime.datetime | None = self._get_end_time(
152+
current_operation, status
160153
)
161-
end_time = self._get_end_time(current_operation, status)
162154

163155
execution_details = self._create_execution_details(update)
164156
context_details = self._create_context_details(update)
@@ -183,3 +175,19 @@ def _translate_update_to_operation(
183175
chained_invoke_details=invoke_details,
184176
wait_details=wait_details,
185177
)
178+
179+
def _create_wait_details(
180+
self, update: OperationUpdate, current_operation: Operation | None
181+
) -> WaitDetails | None:
182+
"""Create WaitDetails from OperationUpdate."""
183+
if update.operation_type == OperationType.WAIT and update.wait_options:
184+
if current_operation and current_operation.wait_details:
185+
scheduled_end_timestamp = (
186+
current_operation.wait_details.scheduled_end_timestamp
187+
)
188+
else:
189+
scheduled_end_timestamp = datetime.datetime.now(
190+
tz=datetime.UTC
191+
) + timedelta(seconds=update.wait_options.wait_seconds)
192+
return WaitDetails(scheduled_end_timestamp=scheduled_end_timestamp)
193+
return None

src/aws_durable_execution_sdk_python_testing/checkpoint/processors/callback.py

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,24 @@
22

33
from __future__ import annotations
44

5+
import datetime
56
from typing import TYPE_CHECKING
67

78
from aws_durable_execution_sdk_python.lambda_service import (
89
Operation,
910
OperationAction,
1011
OperationStatus,
1112
OperationUpdate,
13+
CallbackDetails,
14+
OperationType,
1215
)
13-
1416
from aws_durable_execution_sdk_python_testing.checkpoint.processors.base import (
1517
OperationProcessor,
1618
)
1719
from aws_durable_execution_sdk_python_testing.exceptions import (
1820
InvalidParameterValueException,
1921
)
20-
22+
from aws_durable_execution_sdk_python_testing.token import CallbackToken
2123

2224
if TYPE_CHECKING:
2325
from aws_durable_execution_sdk_python_testing.observer import ExecutionNotifier
@@ -36,14 +38,46 @@ def process(
3638
"""Process CALLBACK operation update with scheduler integration for activities."""
3739
match update.action:
3840
case OperationAction.START:
39-
# TODO: create CallbackToken (see token module). Add Observer/Notifier for on_callback_created possibly,
40-
# but token might well have enough so don't need to maintain token list on execution itself
41-
return self._translate_update_to_operation(
42-
update=update,
43-
current_operation=current_op,
44-
status=OperationStatus.STARTED,
41+
callback_token: CallbackToken = CallbackToken(
42+
execution_arn=execution_arn,
43+
operation_id=update.operation_id,
44+
)
45+
46+
notifier.notify_callback_created(
47+
execution_arn=execution_arn,
48+
operation_id=update.operation_id,
49+
callback_token=callback_token,
4550
)
51+
52+
callback_id: str = callback_token.to_str()
53+
54+
callback_details: CallbackDetails | None = (
55+
CallbackDetails(
56+
callback_id=callback_id,
57+
result=update.payload,
58+
error=update.error,
59+
)
60+
if update.operation_type == OperationType.CALLBACK
61+
else None
62+
)
63+
status: OperationStatus = OperationStatus.STARTED
64+
start_time: datetime.datetime | None = self._get_start_time(current_op)
65+
end_time: datetime.datetime | None = self._get_end_time(
66+
current_op, status
67+
)
68+
operation: Operation = Operation(
69+
operation_id=update.operation_id,
70+
parent_id=update.parent_id,
71+
name=update.name,
72+
start_timestamp=start_time,
73+
end_timestamp=end_time,
74+
operation_type=update.operation_type,
75+
status=status,
76+
sub_type=update.sub_type,
77+
callback_details=callback_details,
78+
)
79+
80+
return operation
4681
case _:
4782
msg: str = "Invalid action for CALLBACK operation."
48-
4983
raise InvalidParameterValueException(msg)

src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/callback.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
VALID_ACTIONS_FOR_CALLBACK = frozenset(
1818
[
1919
OperationAction.START,
20-
OperationAction.CANCEL,
2120
]
2221
)
2322

@@ -41,14 +40,6 @@ def validate(current_state: Operation | None, update: OperationUpdate) -> None:
4140
"Cannot start a CALLBACK that already exist."
4241
)
4342
raise InvalidParameterValueException(msg_callback_exists)
44-
case OperationAction.CANCEL:
45-
if (
46-
current_state is None
47-
or current_state.status
48-
not in CallbackOperationValidator._ALLOWED_STATUS_TO_CANCEL
49-
):
50-
msg_callback_cancel: str = "Cannot cancel a CALLBACK that does not exist or has already completed."
51-
raise InvalidParameterValueException(msg_callback_cancel)
5243
case _:
53-
msg_callback_invalid: str = "Invalid CALLBACK action."
54-
raise InvalidParameterValueException(msg_callback_invalid)
44+
msg: str = "Invalid action for CALLBACK operation."
45+
raise InvalidParameterValueException(msg)

src/aws_durable_execution_sdk_python_testing/execution.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,10 @@
2828
from aws_durable_execution_sdk_python_testing.model import (
2929
StartDurableExecutionInput,
3030
)
31-
from aws_durable_execution_sdk_python_testing.token import CheckpointToken
31+
from aws_durable_execution_sdk_python_testing.token import (
32+
CheckpointToken,
33+
CallbackToken,
34+
)
3235

3336

3437
class Execution:
@@ -203,6 +206,18 @@ def find_operation(self, operation_id: str) -> tuple[int, Operation]:
203206
msg: str = f"Attempting to update state of an Operation [{operation_id}] that doesn't exist"
204207
raise IllegalStateException(msg)
205208

209+
def find_callback_operation(self, callback_id: str) -> tuple[int, Operation]:
210+
"""Find callback operation by callback_id, return index and operation."""
211+
for i, operation in enumerate(self.operations):
212+
if (
213+
operation.operation_type == OperationType.CALLBACK
214+
and operation.callback_details
215+
and operation.callback_details.callback_id == callback_id
216+
):
217+
return i, operation
218+
msg: str = f"Callback operation with callback_id [{callback_id}] not found"
219+
raise IllegalStateException(msg)
220+
206221
def complete_wait(self, operation_id: str) -> Operation:
207222
"""Complete WAIT operation when timer fires."""
208223
index, operation = self.find_operation(operation_id)
@@ -260,3 +275,55 @@ def complete_retry(self, operation_id: str) -> Operation:
260275
# Assign
261276
self.operations[index] = updated_operation
262277
return updated_operation
278+
279+
def complete_callback_success(
280+
self, callback_id: str, result: bytes | None = None
281+
) -> Operation:
282+
"""Complete CALLBACK operation with success."""
283+
index, operation = self.find_callback_operation(callback_id)
284+
if operation.status != OperationStatus.STARTED:
285+
msg: str = f"Callback operation [{callback_id}] is not in STARTED state"
286+
raise IllegalStateException(msg)
287+
288+
with self._state_lock:
289+
self._token_sequence += 1
290+
updated_callback_details = None
291+
if operation.callback_details:
292+
updated_callback_details = replace(
293+
operation.callback_details,
294+
result=result.decode() if result else None,
295+
)
296+
297+
self.operations[index] = replace(
298+
operation,
299+
status=OperationStatus.SUCCEEDED,
300+
end_timestamp=datetime.now(UTC),
301+
callback_details=updated_callback_details,
302+
)
303+
return self.operations[index]
304+
305+
def complete_callback_failure(
306+
self, callback_id: str, error: ErrorObject
307+
) -> Operation:
308+
"""Complete CALLBACK operation with failure."""
309+
index, operation = self.find_callback_operation(callback_id)
310+
311+
if operation.status != OperationStatus.STARTED:
312+
msg: str = f"Callback operation [{callback_id}] is not in STARTED state"
313+
raise IllegalStateException(msg)
314+
315+
with self._state_lock:
316+
self._token_sequence += 1
317+
updated_callback_details = None
318+
if operation.callback_details:
319+
updated_callback_details = replace(
320+
operation.callback_details, error=error
321+
)
322+
323+
self.operations[index] = replace(
324+
operation,
325+
status=OperationStatus.FAILED,
326+
end_timestamp=datetime.now(UTC),
327+
callback_details=updated_callback_details,
328+
)
329+
return self.operations[index]

0 commit comments

Comments
 (0)