1818
1919import aws_durable_execution_sdk_python
2020import boto3 # type: ignore
21+ from botocore .exceptions import ClientError # type: ignore
2122from aws_durable_execution_sdk_python .execution import (
2223 InvocationStatus ,
2324 durable_execution ,
7576
7677 from aws_durable_execution_sdk_python_testing .execution import Execution
7778 from aws_durable_execution_sdk_python_testing .web .server import WebServiceConfig
79+ from aws_durable_execution_sdk_python_testing .model import Event
80+
7881
7982logger = logging .getLogger (__name__ )
8083
@@ -792,9 +795,9 @@ def run(
792795 msg = f"Failed to invoke Lambda function { self .function_name } : { e } "
793796 raise DurableFunctionsTestError (msg ) from e
794797
795- # Check HTTP status code ( 200 for RequestResponse, 202 for Event, 204 for DryRun)
798+ # Check HTTP status code, 200 for RequestResponse
796799 status_code = response .get ("StatusCode" )
797- if status_code not in ( 200 , 202 , 204 ) :
800+ if status_code != 200 :
798801 error_payload = response ["Payload" ].read ().decode ("utf-8" )
799802 msg = f"Lambda invocation failed with status { status_code } : { error_payload } "
800803 raise DurableFunctionsTestError (msg )
@@ -819,17 +822,128 @@ def run(
819822 )
820823 raise DurableFunctionsTestError (msg )
821824
822- # Poll for completion
823- execution_response = self ._wait_for_completion (execution_arn , timeout )
825+ return self .wait_for_result (execution_arn = execution_arn , timeout = timeout )
824826
825- # Get execution history
826- history_response = self ._get_execution_history (execution_arn )
827+ def run_async (
828+ self ,
829+ input : str | None = None , # noqa: A002
830+ timeout : int = 60 ,
831+ ) -> str :
832+ """Execute function on AWS Lambda asynchronously"""
833+ logger .info (
834+ "Invoking Lambda function: %s (timeout: %ds)" , self .function_name , timeout
835+ )
836+ payload = json .dumps (input )
837+ try :
838+ response = self .lambda_client .invoke (
839+ FunctionName = self .function_name ,
840+ InvocationType = "Event" ,
841+ Payload = payload ,
842+ )
843+ except Exception as e :
844+ msg = f"Failed to invoke Lambda function { self .function_name } : { e } "
845+ raise DurableFunctionsTestError (msg ) from e
827846
828- # Build test result from execution history
829- return DurableFunctionTestResult .from_execution_history (
830- execution_response , history_response
847+ # Check HTTP status code, 202 for Event
848+ status_code = response .get ("StatusCode" )
849+ if status_code != 202 :
850+ error_payload = response ["Payload" ].read ().decode ("utf-8" )
851+ msg = f"Lambda invocation failed with status { status_code } : { error_payload } "
852+ raise DurableFunctionsTestError (msg )
853+
854+ return response .get ("DurableExecutionArn" )
855+
856+ def _get_callback_id_from_events (
857+ self , events : list [Event ], name : str | None = None
858+ ) -> str | None :
859+ """
860+ Get callback ID from execution history for callbacks that haven't completed.
861+
862+ Args:
863+ execution_arn: The ARN of the execution to query.
864+ name: Optional callback name to search for. If not provided, returns the latest callback.
865+
866+ Returns:
867+ The callback ID string for a non-completed callback, or None if not found.
868+
869+ Raises:
870+ DurableFunctionsTestError: If the named callback has already succeeded/failed/timed out.
871+ """
872+ callback_started_events = [
873+ event for event in events if event .event_type == "CallbackStarted"
874+ ]
875+
876+ if not callback_started_events :
877+ return None
878+
879+ completed_callback_ids = {
880+ event .event_id
881+ for event in events
882+ if event .event_type
883+ in ["CallbackSucceeded" , "CallbackFailed" , "CallbackTimedOut" ]
884+ }
885+
886+ if name is not None :
887+ for event in callback_started_events :
888+ if event .name == name :
889+ callback_id = event .event_id
890+ if callback_id in completed_callback_ids :
891+ raise DurableFunctionsTestError (
892+ f"Callback { name } has already completed (succeeded/failed/timed out)"
893+ )
894+ return (
895+ event .callback_started_details .callback_id
896+ if event .callback_started_details
897+ else None
898+ )
899+ return None
900+
901+ # If name is not provided, find the latest non-completed callback event
902+ active_callbacks = [
903+ event
904+ for event in callback_started_events
905+ if event .event_id not in completed_callback_ids
906+ ]
907+
908+ if not active_callbacks :
909+ return None
910+
911+ latest_event = active_callbacks [- 1 ]
912+ return (
913+ latest_event .callback_started_details .callback_id
914+ if latest_event .callback_started_details
915+ else None
916+ )
917+
918+ def send_callback_success (self , callback_id : str ) -> None :
919+ self ._send_callback (
920+ "success" ,
921+ callback_id ,
922+ self .lambda_client .send_durable_execution_callback_success ,
831923 )
832924
925+ def send_callback_failure (self , callback_id : str ) -> None :
926+ self ._send_callback (
927+ "failure" ,
928+ callback_id ,
929+ self .lambda_client .send_durable_execution_callback_failure ,
930+ )
931+
932+ def send_callback_heartbeat (self , callback_id : str ) -> None :
933+ self ._send_callback (
934+ "heartbeat" ,
935+ callback_id ,
936+ self .lambda_client .send_durable_execution_callback_heartbeat ,
937+ )
938+
939+ def _send_callback (self , operation : str , callback_id : str , api_method ) -> None :
940+ """Helper method to send callback operations."""
941+ try :
942+ api_method (CallbackId = callback_id )
943+ except Exception as e :
944+ msg = f"Failed to send callback { operation } for { self .function_name } , callback_id { callback_id } : { e } "
945+ raise DurableFunctionsTestError (msg ) from e
946+
833947 def _wait_for_completion (
834948 self , execution_arn : str , timeout : int
835949 ) -> GetDurableExecutionResponse :
@@ -886,7 +1000,81 @@ def _wait_for_completion(
8861000 )
8871001 raise TimeoutError (msg )
8881002
889- def _get_execution_history (
1003+ def wait_for_result (
1004+ self , execution_arn : str , timeout : int = 60
1005+ ) -> DurableFunctionTestResult :
1006+ # Poll for completion
1007+ execution_response = self ._wait_for_completion (execution_arn , timeout )
1008+
1009+ try :
1010+ history_response = self ._fetch_execution_history (execution_arn )
1011+ except Exception as e :
1012+ msg = f"Failed to fetch execution history: { e } "
1013+ raise DurableFunctionsTestError (msg ) from e
1014+
1015+ # Build test result from execution history
1016+ return DurableFunctionTestResult .from_execution_history (
1017+ execution_response , history_response
1018+ )
1019+
1020+ def wait_for_callback (
1021+ self , execution_arn : str , name : str | None = None , timeout : int = 60
1022+ ) -> str :
1023+ """
1024+ Wait for and retrieve a callback ID from a Step Functions execution.
1025+
1026+ Polls the execution history at regular intervals until a callback ID is found
1027+ or the timeout is reached.
1028+
1029+ Args:
1030+ execution_arn: Execution Arn
1031+ name: Specific callback name, default to None
1032+ timeout: Maximum time in seconds to wait for callback. Defaults to 60.
1033+
1034+ Returns:
1035+ str: The callback ID/token retrieved from the execution history
1036+
1037+ Raises:
1038+ TimeoutError: If callback is not found within the specified timeout period
1039+ DurableFunctionsTestError: If there's an error fetching execution history
1040+ (excluding retryable errors)
1041+ """
1042+ start_time = time .time ()
1043+
1044+ while time .time () - start_time < timeout :
1045+ try :
1046+ history_response = self ._fetch_execution_history (execution_arn )
1047+ callback_id = self ._get_callback_id_from_events (
1048+ events = history_response .events , name = name
1049+ )
1050+ if callback_id :
1051+ return callback_id
1052+ except ClientError as e :
1053+ error_code = e .response ["Error" ]["Code" ]
1054+ # retryable error, the execution may not start yet in async invoke situation
1055+ if error_code in ["ResourceNotFoundException" ]:
1056+ pass
1057+ else :
1058+ msg = f"Failed to fetch execution history: { e } "
1059+ raise DurableFunctionsTestError (msg ) from e
1060+ except DurableFunctionsTestError as e :
1061+ raise e
1062+ except Exception as e :
1063+ msg = f"Failed to fetch execution history: { e } "
1064+ raise DurableFunctionsTestError (msg ) from e
1065+
1066+ # Wait before next poll
1067+ time .sleep (self .poll_interval )
1068+
1069+ # Timeout reached
1070+ elapsed = time .time () - start_time
1071+ msg = (
1072+ f"Callback did not available within { timeout } s "
1073+ f"(elapsed: { elapsed :.1f} s."
1074+ )
1075+ raise TimeoutError (msg )
1076+
1077+ def _fetch_execution_history (
8901078 self , execution_arn : str
8911079 ) -> GetDurableExecutionHistoryResponse :
8921080 """Retrieve execution history from Lambda service.
@@ -898,19 +1086,13 @@ def _get_execution_history(
8981086 GetDurableExecutionHistoryResponse with typed Event objects
8991087
9001088 Raises:
901- DurableFunctionsTestError : If history retrieval fails
1089+ ClientError : If lambda client encounter error
9021090 """
903- try :
904- history_dict = self .lambda_client .get_durable_execution_history (
905- DurableExecutionArn = execution_arn ,
906- IncludeExecutionData = True ,
907- )
908- history_response = GetDurableExecutionHistoryResponse .from_dict (
909- history_dict
910- )
911- except Exception as e :
912- msg = f"Failed to get execution history: { e } "
913- raise DurableFunctionsTestError (msg ) from e
1091+ history_dict = self .lambda_client .get_durable_execution_history (
1092+ DurableExecutionArn = execution_arn ,
1093+ IncludeExecutionData = True ,
1094+ )
1095+ history_response = GetDurableExecutionHistoryResponse .from_dict (history_dict )
9141096
9151097 logger .info ("Retrieved %d events from history" , len (history_response .events ))
9161098
0 commit comments