1919 OperationUpdate ,
2020 OperationStatus ,
2121 OperationType ,
22+ CallbackOptions ,
2223)
2324
2425from aws_durable_execution_sdk_python_testing .exceptions import (
5758
5859if TYPE_CHECKING :
5960 from collections .abc import Awaitable , Callable
61+ from concurrent .futures import Future
6062
6163 from aws_durable_execution_sdk_python_testing .checkpoint .processor import (
6264 CheckpointProcessor ,
@@ -84,10 +86,8 @@ def __init__(
8486 self ._invoker = invoker
8587 self ._checkpoint_processor = checkpoint_processor
8688 self ._completion_events : dict [str , Event ] = {}
87- self ._callback_timeouts : dict [str , Event ] = {} # callback_id -> timeout event
88- self ._callback_heartbeats : dict [
89- str , Event
90- ] = {} # callback_id -> heartbeat event
89+ self ._callback_timeouts : dict [str , Future ] = {}
90+ self ._callback_heartbeats : dict [str , Future ] = {}
9191
9292 def start_execution (
9393 self ,
@@ -1011,7 +1011,11 @@ def retry_handler() -> None:
10111011 )
10121012
10131013 def on_callback_created (
1014- self , execution_arn : str , operation_id : str , callback_token : CallbackToken
1014+ self ,
1015+ execution_arn : str ,
1016+ operation_id : str ,
1017+ callback_options : CallbackOptions | None ,
1018+ callback_token : CallbackToken ,
10151019 ) -> None :
10161020 """Handle callback creation. Observer method triggered by notifier."""
10171021 callback_id = callback_token .to_str ()
@@ -1023,34 +1027,19 @@ def on_callback_created(
10231027 )
10241028
10251029 # Schedule callback timeouts if configured
1026- self ._schedule_callback_timeouts (execution_arn , operation_id , callback_id )
1030+ self ._schedule_callback_timeouts (execution_arn , callback_options , callback_id )
10271031
10281032 # endregion ExecutionObserver
10291033
10301034 # region Callback Timeouts
10311035 def _schedule_callback_timeouts (
1032- self , execution_arn : str , operation_id : str , callback_id : str
1036+ self ,
1037+ execution_arn : str ,
1038+ callback_options : CallbackOptions | None ,
1039+ callback_id : str ,
10331040 ) -> None :
10341041 """Schedule callback timeout and heartbeat timeout if configured."""
10351042 try :
1036- execution = self .get_execution (execution_arn )
1037- _ , operation = execution .find_operation (operation_id )
1038-
1039- if not operation .callback_details :
1040- return
1041-
1042- # Find the callback options from the operation update that created this callback
1043- # We need to look at the checkpoint updates to find the original callback options
1044- callback_options = None
1045- for update in execution .updates :
1046- if (
1047- update .operation_id == operation_id
1048- and update .callback_options
1049- and update .action .value == "START"
1050- ):
1051- callback_options = update .callback_options
1052- break
1053-
10541043 if not callback_options :
10551044 return
10561045
@@ -1062,27 +1051,25 @@ def _schedule_callback_timeouts(
10621051 def timeout_handler ():
10631052 self ._on_callback_timeout (execution_arn , callback_id )
10641053
1065- timeout_event = self ._scheduler .create_event ()
1066- self ._callback_timeouts [callback_id ] = timeout_event
1067- self ._scheduler .call_later (
1054+ timeout_future = self ._scheduler .call_later (
10681055 timeout_handler ,
10691056 delay = callback_options .timeout_seconds ,
10701057 completion_event = completion_event ,
10711058 )
1059+ self ._callback_timeouts [callback_id ] = timeout_future
10721060
10731061 # Schedule heartbeat timeout if configured
10741062 if callback_options .heartbeat_timeout_seconds > 0 :
10751063
10761064 def heartbeat_timeout_handler ():
10771065 self ._on_callback_heartbeat_timeout (execution_arn , callback_id )
10781066
1079- heartbeat_event = self ._scheduler .create_event ()
1080- self ._callback_heartbeats [callback_id ] = heartbeat_event
1081- self ._scheduler .call_later (
1067+ heartbeat_future = self ._scheduler .call_later (
10821068 heartbeat_timeout_handler ,
10831069 delay = callback_options .heartbeat_timeout_seconds ,
10841070 completion_event = completion_event ,
10851071 )
1072+ self ._callback_heartbeats [callback_id ] = heartbeat_future
10861073
10871074 except Exception :
10881075 logger .exception (
@@ -1096,16 +1083,14 @@ def _reset_callback_heartbeat_timeout(
10961083 ) -> None :
10971084 """Reset the heartbeat timeout for a callback."""
10981085 # Cancel existing heartbeat timeout
1099- if heartbeat_event := self ._callback_heartbeats .get (callback_id ):
1100- heartbeat_event .remove ()
1101- del self ._callback_heartbeats [callback_id ]
1086+ if heartbeat_future := self ._callback_heartbeats .pop (callback_id , None ):
1087+ heartbeat_future .cancel ()
11021088
11031089 # Find callback options to reschedule heartbeat timeout
11041090 try :
11051091 callback_token = CallbackToken .from_str (callback_id )
11061092 execution = self .get_execution (callback_token .execution_arn )
11071093
1108- # Find callback options from updates
11091094 callback_options = None
11101095 for update in execution .updates :
11111096 if (
@@ -1122,13 +1107,14 @@ def heartbeat_timeout_handler():
11221107 self ._on_callback_heartbeat_timeout (execution_arn , callback_id )
11231108
11241109 completion_event = self ._completion_events .get (execution_arn )
1125- heartbeat_event = self ._scheduler .create_event ()
1126- self ._callback_heartbeats [callback_id ] = heartbeat_event
1127- self ._scheduler .call_later (
1110+
1111+ heartbeat_future = self ._scheduler .call_later (
11281112 heartbeat_timeout_handler ,
11291113 delay = callback_options .heartbeat_timeout_seconds ,
11301114 completion_event = completion_event ,
11311115 )
1116+ self ._callback_heartbeats [callback_id ] = heartbeat_future
1117+
11321118 except Exception :
11331119 logger .exception (
11341120 "[%s] Error resetting callback heartbeat timeout for %s" ,
@@ -1139,14 +1125,12 @@ def heartbeat_timeout_handler():
11391125 def _cleanup_callback_timeouts (self , callback_id : str ) -> None :
11401126 """Clean up timeout events for a completed callback."""
11411127 # Clean up main timeout
1142- if timeout_event := self ._callback_timeouts .get (callback_id ):
1143- timeout_event .remove ()
1144- del self ._callback_timeouts [callback_id ]
1128+ if timeout_future := self ._callback_timeouts .pop (callback_id , None ):
1129+ timeout_future .cancel ()
11451130
11461131 # Clean up heartbeat timeout
1147- if heartbeat_event := self ._callback_heartbeats .get (callback_id ):
1148- heartbeat_event .remove ()
1149- del self ._callback_heartbeats [callback_id ]
1132+ if heartbeat_future := self ._callback_heartbeats .pop (callback_id , None ):
1133+ heartbeat_future .cancel ()
11501134
11511135 def _on_callback_timeout (self , execution_arn : str , callback_id : str ) -> None :
11521136 """Handle callback timeout."""
@@ -1161,7 +1145,7 @@ def _on_callback_timeout(self, execution_arn: str, callback_id: str) -> None:
11611145 timeout_error = ErrorObject .from_message (
11621146 f"Callback timed out: { CallbackTimeoutType .TIMEOUT .value } "
11631147 )
1164- execution .complete_callback_failure (callback_id , timeout_error )
1148+ execution .complete_callback_timeout (callback_id , timeout_error )
11651149 execution .complete_fail (timeout_error )
11661150 self ._store .update (execution )
11671151 logger .warning ("[%s] Callback %s timed out" , execution_arn , callback_id )
@@ -1188,7 +1172,7 @@ def _on_callback_heartbeat_timeout(
11881172 heartbeat_error = ErrorObject .from_message (
11891173 f"Callback heartbeat timed out: { CallbackTimeoutType .HEARTBEAT .value } "
11901174 )
1191- execution .complete_callback_failure (callback_id , heartbeat_error )
1175+ execution .complete_callback_timeout (callback_id , heartbeat_error )
11921176 execution .complete_fail (heartbeat_error )
11931177 self ._store .update (execution )
11941178 logger .warning (
0 commit comments