1818
1919import aws_durable_execution_sdk_python
2020import boto3 # type: ignore
21+ from botocore .exceptions import ClientError # type: ignore
2122from aws_durable_execution_sdk_python .execution import (
2223 InvocationStatus ,
2324 durable_execution ,
7576
7677 from aws_durable_execution_sdk_python_testing .execution import Execution
7778 from aws_durable_execution_sdk_python_testing .web .server import WebServiceConfig
79+ from aws_durable_execution_sdk_python_testing .model import Event
80+
7881
7982logger = logging .getLogger (__name__ )
8083
@@ -792,9 +795,9 @@ def run(
792795 msg = f"Failed to invoke Lambda function { self .function_name } : { e } "
793796 raise DurableFunctionsTestError (msg ) from e
794797
795- # Check HTTP status code ( 200 for RequestResponse, 202 for Event, 204 for DryRun)
798+ # Check HTTP status code, 200 for RequestResponse
796799 status_code = response .get ("StatusCode" )
797- if status_code not in ( 200 , 202 , 204 ) :
800+ if status_code != 200 :
798801 error_payload = response ["Payload" ].read ().decode ("utf-8" )
799802 msg = f"Lambda invocation failed with status { status_code } : { error_payload } "
800803 raise DurableFunctionsTestError (msg )
@@ -819,17 +822,120 @@ def run(
819822 )
820823 raise DurableFunctionsTestError (msg )
821824
822- # Poll for completion
823- execution_response = self ._wait_for_completion (execution_arn , timeout )
825+ return self .wait_and_fetch_test_result (
826+ execution_arn = execution_arn , timeout = timeout
827+ )
824828
825- # Get execution history
826- history_response = self ._get_execution_history (execution_arn )
829+ def run_async (
830+ self ,
831+ input : str | None = None , # noqa: A002
832+ timeout : int = 60 ,
833+ ) -> str :
834+ """Execute function on AWS Lambda asynchronously"""
835+ logger .info (
836+ "Invoking Lambda function: %s (timeout: %ds)" , self .function_name , timeout
837+ )
838+ payload = json .dumps (input )
839+ try :
840+ response = self .lambda_client .invoke (
841+ FunctionName = self .function_name ,
842+ InvocationType = "Event" ,
843+ Payload = payload ,
844+ )
845+ except Exception as e :
846+ msg = f"Failed to invoke Lambda function { self .function_name } : { e } "
847+ raise DurableFunctionsTestError (msg ) from e
827848
828- # Build test result from execution history
829- return DurableFunctionTestResult .from_execution_history (
830- execution_response , history_response
849+ # Check HTTP status code, 202 for Event
850+ status_code = response .get ("StatusCode" )
851+ if status_code != 202 :
852+ error_payload = response ["Payload" ].read ().decode ("utf-8" )
853+ msg = f"Lambda invocation failed with status { status_code } : { error_payload } "
854+ raise DurableFunctionsTestError (msg )
855+
856+ return response .get ("DurableExecutionArn" )
857+
858+ def _get_callback_id_from_events (
859+ self , events : list [Event ], name : str | None = None
860+ ) -> str | None :
861+ """
862+ Get callback ID from execution history for callbacks that haven't completed.
863+
864+ Args:
865+ execution_arn: The ARN of the execution to query.
866+ name: Optional callback name to search for. If not provided, returns the latest callback.
867+
868+ Returns:
869+ The callback ID string for a non-completed callback, or None if not found.
870+
871+ Raises:
872+ DurableFunctionsTestError: If the named callback has already succeeded/failed/timed out.
873+ """
874+ callback_started_events = [
875+ event for event in events if event .event_type == "CallbackStarted"
876+ ]
877+
878+ if not callback_started_events :
879+ return None
880+
881+ completed_callback_ids = {
882+ event .event_id
883+ for event in events
884+ if event .event_type
885+ in ["CallbackSucceeded" , "CallbackFailed" , "CallbackTimedOut" ]
886+ }
887+
888+ if name is not None :
889+ for event in callback_started_events :
890+ if event .name == name :
891+ callback_id = event .event_id
892+ if callback_id in completed_callback_ids :
893+ raise DurableFunctionsTestError (
894+ f"Callback { name } has already completed (succeeded/failed/timed out)"
895+ )
896+ return (
897+ event .callback_started_details .callback_id
898+ if event .callback_started_details
899+ else None
900+ )
901+ return None
902+
903+ # If name is not provided, find the latest non-completed callback event
904+ active_callbacks = [
905+ event
906+ for event in callback_started_events
907+ if event .event_id not in completed_callback_ids
908+ ]
909+
910+ if not active_callbacks :
911+ return None
912+
913+ latest_event = active_callbacks [- 1 ]
914+ return (
915+ latest_event .callback_started_details .callback_id
916+ if latest_event .callback_started_details
917+ else None
831918 )
832919
920+ def send_callback_success (self , callback_id : str ) -> None :
921+ self ._send_callback ("success" , callback_id )
922+
923+ def send_callback_failure (self , callback_id : str ) -> None :
924+ self ._send_callback ("failure" , callback_id )
925+
926+ def send_callback_heartbeat (self , callback_id : str ) -> None :
927+ self ._send_callback ("heartbeat" , callback_id )
928+
929+ def _send_callback (self , operation : str , callback_id : str ) -> None :
930+ """Helper method to send callback operations."""
931+ method_name = f"send_durable_execution_callback_{ operation } "
932+ try :
933+ api_method = getattr (self .lambda_client , method_name )
934+ api_method (CallbackId = callback_id )
935+ except Exception as e :
936+ msg = f"Failed to send callback { operation } for { self .function_name } , callback_id { callback_id } : { e } "
937+ raise DurableFunctionsTestError (msg ) from e
938+
833939 def _wait_for_completion (
834940 self , execution_arn : str , timeout : int
835941 ) -> GetDurableExecutionResponse :
@@ -886,7 +992,81 @@ def _wait_for_completion(
886992 )
887993 raise TimeoutError (msg )
888994
889- def _get_execution_history (
995+ def wait_and_fetch_test_result (
996+ self , execution_arn : str , timeout : int = 60
997+ ) -> DurableFunctionTestResult :
998+ # Poll for completion
999+ execution_response = self ._wait_for_completion (execution_arn , timeout )
1000+
1001+ try :
1002+ history_response = self ._fetch_execution_history (execution_arn )
1003+ except Exception as e :
1004+ msg = f"Failed to fetch execution history: { e } "
1005+ raise DurableFunctionsTestError (msg ) from e
1006+
1007+ # Build test result from execution history
1008+ return DurableFunctionTestResult .from_execution_history (
1009+ execution_response , history_response
1010+ )
1011+
1012+ def wait_and_fetch_callback_id (
1013+ self , execution_arn : str , name : str | None = None , timeout : int = 60
1014+ ) -> str :
1015+ """
1016+ Wait for and retrieve a callback ID from a Step Functions execution.
1017+
1018+ Polls the execution history at regular intervals until a callback ID is found
1019+ or the timeout is reached.
1020+
1021+ Args:
1022+ execution_arn: Execution Arn
1023+ name: Specific callback name, default to None
1024+ timeout: Maximum time in seconds to wait for callback. Defaults to 60.
1025+
1026+ Returns:
1027+ str: The callback ID/token retrieved from the execution history
1028+
1029+ Raises:
1030+ TimeoutError: If callback is not found within the specified timeout period
1031+ DurableFunctionsTestError: If there's an error fetching execution history
1032+ (excluding retryable errors)
1033+ """
1034+ start_time = time .time ()
1035+
1036+ while time .time () - start_time < timeout :
1037+ try :
1038+ history_response = self ._fetch_execution_history (execution_arn )
1039+ callback_id = self ._get_callback_id_from_events (
1040+ events = history_response .events , name = name
1041+ )
1042+ if callback_id :
1043+ return callback_id
1044+ except ClientError as e :
1045+ error_code = e .response ["Error" ]["Code" ]
1046+ # retryable error, the execution may not start yet in async invoke situation
1047+ if error_code in ["ResourceNotFoundException" ]:
1048+ pass
1049+ else :
1050+ msg = f"Failed to fetch execution history: { e } "
1051+ raise DurableFunctionsTestError (msg ) from e
1052+ except DurableFunctionsTestError as e :
1053+ raise e
1054+ except Exception as e :
1055+ msg = f"Failed to fetch execution history: { e } "
1056+ raise DurableFunctionsTestError (msg ) from e
1057+
1058+ # Wait before next poll
1059+ time .sleep (self .poll_interval )
1060+
1061+ # Timeout reached
1062+ elapsed = time .time () - start_time
1063+ msg = (
1064+ f"Callback did not available within { timeout } s "
1065+ f"(elapsed: { elapsed :.1f} s."
1066+ )
1067+ raise TimeoutError (msg )
1068+
1069+ def _fetch_execution_history (
8901070 self , execution_arn : str
8911071 ) -> GetDurableExecutionHistoryResponse :
8921072 """Retrieve execution history from Lambda service.
@@ -898,19 +1078,13 @@ def _get_execution_history(
8981078 GetDurableExecutionHistoryResponse with typed Event objects
8991079
9001080 Raises:
901- DurableFunctionsTestError : If history retrieval fails
1081+ ClientError : If lambda client encounter error
9021082 """
903- try :
904- history_dict = self .lambda_client .get_durable_execution_history (
905- DurableExecutionArn = execution_arn ,
906- IncludeExecutionData = True ,
907- )
908- history_response = GetDurableExecutionHistoryResponse .from_dict (
909- history_dict
910- )
911- except Exception as e :
912- msg = f"Failed to get execution history: { e } "
913- raise DurableFunctionsTestError (msg ) from e
1083+ history_dict = self .lambda_client .get_durable_execution_history (
1084+ DurableExecutionArn = execution_arn ,
1085+ IncludeExecutionData = True ,
1086+ )
1087+ history_response = GetDurableExecutionHistoryResponse .from_dict (history_dict )
9141088
9151089 logger .info ("Retrieved %d events from history" , len (history_response .events ))
9161090
0 commit comments