1818
1919import aws_durable_execution_sdk_python
2020import boto3 # type: ignore
21+ from botocore .exceptions import ClientError # type: ignore
2122from aws_durable_execution_sdk_python .execution import (
2223 InvocationStatus ,
2324 durable_execution ,
7576
7677 from aws_durable_execution_sdk_python_testing .execution import Execution
7778 from aws_durable_execution_sdk_python_testing .web .server import WebServiceConfig
79+ from aws_durable_execution_sdk_python_testing .model import Event
80+
7881
7982logger = logging .getLogger (__name__ )
8083
@@ -792,9 +795,9 @@ def run(
792795 msg = f"Failed to invoke Lambda function { self .function_name } : { e } "
793796 raise DurableFunctionsTestError (msg ) from e
794797
795- # Check HTTP status code ( 200 for RequestResponse, 202 for Event, 204 for DryRun)
798+ # Check HTTP status code, 200 for RequestResponse
796799 status_code = response .get ("StatusCode" )
797- if status_code not in ( 200 , 202 , 204 ) :
800+ if status_code != 200 :
798801 error_payload = response ["Payload" ].read ().decode ("utf-8" )
799802 msg = f"Lambda invocation failed with status { status_code } : { error_payload } "
800803 raise DurableFunctionsTestError (msg )
@@ -819,17 +822,129 @@ def run(
819822 )
820823 raise DurableFunctionsTestError (msg )
821824
822- # Poll for completion
823- execution_response = self ._wait_for_completion (execution_arn , timeout )
825+ return self .wait_for_result (execution_arn = execution_arn , timeout = timeout )
824826
825- # Get execution history
826- history_response = self ._get_execution_history (execution_arn )
827+ def run_async (
828+ self ,
829+ input : str | None = None , # noqa: A002
830+ timeout : int = 60 ,
831+ ) -> str :
832+ """Execute function on AWS Lambda asynchronously"""
833+ logger .info (
834+ "Invoking Lambda function: %s (timeout: %ds)" , self .function_name , timeout
835+ )
836+ payload = json .dumps (input )
837+ try :
838+ response = self .lambda_client .invoke (
839+ FunctionName = self .function_name ,
840+ InvocationType = "Event" ,
841+ Payload = payload ,
842+ )
843+ except Exception as e :
844+ msg = f"Failed to invoke Lambda function { self .function_name } : { e } "
845+ raise DurableFunctionsTestError (msg ) from e
827846
828- # Build test result from execution history
829- return DurableFunctionTestResult .from_execution_history (
830- execution_response , history_response
847+ # Check HTTP status code, 202 for Event
848+ status_code = response .get ("StatusCode" )
849+ if status_code != 202 :
850+ error_payload = response ["Payload" ].read ().decode ("utf-8" )
851+ msg = f"Lambda invocation failed with status { status_code } : { error_payload } "
852+ raise DurableFunctionsTestError (msg )
853+
854+ return response .get ("DurableExecutionArn" )
855+
856+ def _get_callback_id_from_events (
857+ self , events : list [Event ], name : str | None = None
858+ ) -> str | None :
859+ """
860+ Get callback ID from execution history for callbacks that haven't completed.
861+
862+ Args:
863+ execution_arn: The ARN of the execution to query.
864+ name: Optional callback name to search for. If not provided, returns the latest callback.
865+
866+ Returns:
867+ The callback ID string for a non-completed callback, or None if not found.
868+
869+ Raises:
870+ DurableFunctionsTestError: If the named callback has already succeeded/failed/timed out.
871+ """
872+ callback_started_events = [
873+ event for event in events if event .event_type == "CallbackStarted"
874+ ]
875+
876+ if not callback_started_events :
877+ return None
878+
879+ completed_callback_ids = {
880+ event .event_id
881+ for event in events
882+ if event .event_type
883+ in ["CallbackSucceeded" , "CallbackFailed" , "CallbackTimedOut" ]
884+ }
885+
886+ if name is not None :
887+ for event in callback_started_events :
888+ if event .name == name :
889+ callback_id = event .event_id
890+ if callback_id in completed_callback_ids :
891+ raise DurableFunctionsTestError (
892+ f"Callback { name } has already completed (succeeded/failed/timed out)"
893+ )
894+ return (
895+ event .callback_started_details .callback_id
896+ if event .callback_started_details
897+ else None
898+ )
899+ return None
900+
901+ # If name is not provided, find the latest non-completed callback event
902+ active_callbacks = [
903+ event
904+ for event in callback_started_events
905+ if event .event_id not in completed_callback_ids
906+ ]
907+
908+ if not active_callbacks :
909+ return None
910+
911+ latest_event = active_callbacks [- 1 ]
912+ return (
913+ latest_event .callback_started_details .callback_id
914+ if latest_event .callback_started_details
915+ else None
916+ )
917+
918+ def send_callback_success (self , callback_id : str ) -> None :
919+ self ._send_callback (
920+ "success" ,
921+ callback_id ,
922+ self .lambda_client .send_durable_execution_callback_success ,
831923 )
832924
925+ def send_callback_failure (self , callback_id : str ) -> None :
926+ self ._send_callback (
927+ "failure" ,
928+ callback_id ,
929+ self .lambda_client .send_durable_execution_callback_failure ,
930+ )
931+
932+ def send_callback_heartbeat (self , callback_id : str ) -> None :
933+ self ._send_callback (
934+ "heartbeat" ,
935+ callback_id ,
936+ self .lambda_client .send_durable_execution_callback_heartbeat ,
937+ )
938+
939+ def _send_callback (self , operation : str , callback_id : str , api_method ) -> None :
940+ """Helper method to send callback operations."""
941+ method_name = f"send_durable_execution_callback_{ operation } "
942+ try :
943+ api_method (CallbackId = callback_id )
944+ except Exception as e :
945+ msg = f"Failed to send callback { operation } for { self .function_name } , callback_id { callback_id } : { e } "
946+ raise DurableFunctionsTestError (msg ) from e
947+
833948 def _wait_for_completion (
834949 self , execution_arn : str , timeout : int
835950 ) -> GetDurableExecutionResponse :
@@ -886,7 +1001,81 @@ def _wait_for_completion(
8861001 )
8871002 raise TimeoutError (msg )
8881003
889- def _get_execution_history (
1004+ def wait_for_result (
1005+ self , execution_arn : str , timeout : int = 60
1006+ ) -> DurableFunctionTestResult :
1007+ # Poll for completion
1008+ execution_response = self ._wait_for_completion (execution_arn , timeout )
1009+
1010+ try :
1011+ history_response = self ._fetch_execution_history (execution_arn )
1012+ except Exception as e :
1013+ msg = f"Failed to fetch execution history: { e } "
1014+ raise DurableFunctionsTestError (msg ) from e
1015+
1016+ # Build test result from execution history
1017+ return DurableFunctionTestResult .from_execution_history (
1018+ execution_response , history_response
1019+ )
1020+
1021+ def wait_for_callback (
1022+ self , execution_arn : str , name : str | None = None , timeout : int = 60
1023+ ) -> str :
1024+ """
1025+ Wait for and retrieve a callback ID from a Step Functions execution.
1026+
1027+ Polls the execution history at regular intervals until a callback ID is found
1028+ or the timeout is reached.
1029+
1030+ Args:
1031+ execution_arn: Execution Arn
1032+ name: Specific callback name, default to None
1033+ timeout: Maximum time in seconds to wait for callback. Defaults to 60.
1034+
1035+ Returns:
1036+ str: The callback ID/token retrieved from the execution history
1037+
1038+ Raises:
1039+ TimeoutError: If callback is not found within the specified timeout period
1040+ DurableFunctionsTestError: If there's an error fetching execution history
1041+ (excluding retryable errors)
1042+ """
1043+ start_time = time .time ()
1044+
1045+ while time .time () - start_time < timeout :
1046+ try :
1047+ history_response = self ._fetch_execution_history (execution_arn )
1048+ callback_id = self ._get_callback_id_from_events (
1049+ events = history_response .events , name = name
1050+ )
1051+ if callback_id :
1052+ return callback_id
1053+ except ClientError as e :
1054+ error_code = e .response ["Error" ]["Code" ]
1055+ # retryable error, the execution may not start yet in async invoke situation
1056+ if error_code in ["ResourceNotFoundException" ]:
1057+ pass
1058+ else :
1059+ msg = f"Failed to fetch execution history: { e } "
1060+ raise DurableFunctionsTestError (msg ) from e
1061+ except DurableFunctionsTestError as e :
1062+ raise e
1063+ except Exception as e :
1064+ msg = f"Failed to fetch execution history: { e } "
1065+ raise DurableFunctionsTestError (msg ) from e
1066+
1067+ # Wait before next poll
1068+ time .sleep (self .poll_interval )
1069+
1070+ # Timeout reached
1071+ elapsed = time .time () - start_time
1072+ msg = (
1073+ f"Callback did not available within { timeout } s "
1074+ f"(elapsed: { elapsed :.1f} s."
1075+ )
1076+ raise TimeoutError (msg )
1077+
1078+ def _fetch_execution_history (
8901079 self , execution_arn : str
8911080 ) -> GetDurableExecutionHistoryResponse :
8921081 """Retrieve execution history from Lambda service.
@@ -898,19 +1087,13 @@ def _get_execution_history(
8981087 GetDurableExecutionHistoryResponse with typed Event objects
8991088
9001089 Raises:
901- DurableFunctionsTestError : If history retrieval fails
1090+ ClientError : If lambda client encounter error
9021091 """
903- try :
904- history_dict = self .lambda_client .get_durable_execution_history (
905- DurableExecutionArn = execution_arn ,
906- IncludeExecutionData = True ,
907- )
908- history_response = GetDurableExecutionHistoryResponse .from_dict (
909- history_dict
910- )
911- except Exception as e :
912- msg = f"Failed to get execution history: { e } "
913- raise DurableFunctionsTestError (msg ) from e
1092+ history_dict = self .lambda_client .get_durable_execution_history (
1093+ DurableExecutionArn = execution_arn ,
1094+ IncludeExecutionData = True ,
1095+ )
1096+ history_response = GetDurableExecutionHistoryResponse .from_dict (history_dict )
9141097
9151098 logger .info ("Retrieved %d events from history" , len (history_response .events ))
9161099
0 commit comments