1818
1919import aws_durable_execution_sdk_python
2020import boto3 # type: ignore
21+ from botocore .exceptions import ClientError # type: ignore
2122from aws_durable_execution_sdk_python .execution import (
2223 InvocationStatus ,
2324 durable_execution ,
7576
7677 from aws_durable_execution_sdk_python_testing .execution import Execution
7778 from aws_durable_execution_sdk_python_testing .web .server import WebServiceConfig
79+ from aws_durable_execution_sdk_python_testing .model import Event
80+
7881
7982logger = logging .getLogger (__name__ )
8083
@@ -792,9 +795,9 @@ def run(
792795 msg = f"Failed to invoke Lambda function { self .function_name } : { e } "
793796 raise DurableFunctionsTestError (msg ) from e
794797
795- # Check HTTP status code ( 200 for RequestResponse, 202 for Event, 204 for DryRun)
798+ # Check HTTP status code, 200 for RequestResponse
796799 status_code = response .get ("StatusCode" )
797- if status_code not in ( 200 , 202 , 204 ) :
800+ if status_code != 200 :
798801 error_payload = response ["Payload" ].read ().decode ("utf-8" )
799802 msg = f"Lambda invocation failed with status { status_code } : { error_payload } "
800803 raise DurableFunctionsTestError (msg )
@@ -819,17 +822,126 @@ def run(
819822 )
820823 raise DurableFunctionsTestError (msg )
821824
822- # Poll for completion
823- execution_response = self ._wait_for_completion (execution_arn , timeout )
825+ return self .wait_for_result (execution_arn = execution_arn , timeout = timeout )
824826
825- # Get execution history
826- history_response = self ._get_execution_history (execution_arn )
827+ def run_async (
828+ self ,
829+ input : str | None = None , # noqa: A002
830+ timeout : int = 60 ,
831+ ) -> str :
832+ """Execute function on AWS Lambda asynchronously"""
833+ logger .info (
834+ "Invoking Lambda function: %s (timeout: %ds)" , self .function_name , timeout
835+ )
836+ payload = json .dumps (input )
837+ try :
838+ response = self .lambda_client .invoke (
839+ FunctionName = self .function_name ,
840+ InvocationType = "Event" ,
841+ Payload = payload ,
842+ )
843+ except Exception as e :
844+ msg = f"Failed to invoke Lambda function { self .function_name } : { e } "
845+ raise DurableFunctionsTestError (msg ) from e
827846
828- # Build test result from execution history
829- return DurableFunctionTestResult .from_execution_history (
830- execution_response , history_response
847+ # Check HTTP status code, 202 for Event
848+ status_code = response .get ("StatusCode" )
849+ if status_code != 202 :
850+ error_payload = response ["Payload" ].read ().decode ("utf-8" )
851+ msg = f"Lambda invocation failed with status { status_code } : { error_payload } "
852+ raise DurableFunctionsTestError (msg )
853+
854+ return response .get ("DurableExecutionArn" )
855+
856+ def _get_callback_id_from_events (
857+ self , events : list [Event ], name : str | None = None
858+ ) -> str | None :
859+ """
860+ Get callback ID from execution history for callbacks that haven't completed.
861+
862+ Args:
863+ execution_arn: The ARN of the execution to query.
864+ name: Optional callback name to search for. If not provided, returns the latest callback.
865+
866+ Returns:
867+ The callback ID string for a non-completed callback, or None if not found.
868+
869+ Raises:
870+ DurableFunctionsTestError: If the named callback has already succeeded/failed/timed out.
871+ """
872+ callback_started_events = [
873+ event for event in events if event .event_type == "CallbackStarted"
874+ ]
875+
876+ if not callback_started_events :
877+ return None
878+
879+ completed_callback_ids = {
880+ event .event_id
881+ for event in events
882+ if event .event_type
883+ in ["CallbackSucceeded" , "CallbackFailed" , "CallbackTimedOut" ]
884+ }
885+
886+ if name is not None :
887+ for event in callback_started_events :
888+ if event .name == name :
889+ callback_id = event .event_id
890+ if callback_id in completed_callback_ids :
891+ raise DurableFunctionsTestError (
892+ f"Callback { name } has already completed (succeeded/failed/timed out)"
893+ )
894+ return (
895+ event .callback_started_details .callback_id
896+ if event .callback_started_details
897+ else None
898+ )
899+ return None
900+
901+ # If name is not provided, find the latest non-completed callback event
902+ active_callbacks = [
903+ event
904+ for event in callback_started_events
905+ if event .event_id not in completed_callback_ids
906+ ]
907+
908+ if not active_callbacks :
909+ return None
910+
911+ latest_event = active_callbacks [- 1 ]
912+ return (
913+ latest_event .callback_started_details .callback_id
914+ if latest_event .callback_started_details
915+ else None
831916 )
832917
918+ def send_callback_success (self , callback_id : str ) -> None :
919+ try :
920+ self .lambda_client .send_durable_execution_callback_success (
921+ CallbackId = callback_id
922+ )
923+ except Exception as e :
924+ msg = f"Failed to send callback success for { self .function_name } , callback_id { callback_id } : { e } "
925+ raise DurableFunctionsTestError (msg ) from e
926+
927+ def send_callback_failure (self , callback_id : str ) -> None :
928+ try :
929+ self .lambda_client .send_durable_execution_callback_failure (
930+ CallbackId = callback_id
931+ )
932+ except Exception as e :
933+ msg = f"Failed to send callback failure for { self .function_name } , callback_id { callback_id } : { e } "
934+ raise DurableFunctionsTestError (msg ) from e
935+
936+ def send_callback_heartbeat (self , callback_id : str ) -> None :
937+ try :
938+ self .lambda_client .send_durable_execution_callback_heartbeat (
939+ CallbackId = callback_id
940+ )
941+ except Exception as e :
942+ msg = f"Failed to send callback heartbeat for { self .function_name } , callback_id { callback_id } : { e } "
943+ raise DurableFunctionsTestError (msg ) from e
944+
833945 def _wait_for_completion (
834946 self , execution_arn : str , timeout : int
835947 ) -> GetDurableExecutionResponse :
@@ -886,7 +998,81 @@ def _wait_for_completion(
886998 )
887999 raise TimeoutError (msg )
8881000
889- def _get_execution_history (
1001+ def wait_for_result (
1002+ self , execution_arn : str , timeout : int = 60
1003+ ) -> DurableFunctionTestResult :
1004+ # Poll for completion
1005+ execution_response = self ._wait_for_completion (execution_arn , timeout )
1006+
1007+ try :
1008+ history_response = self ._fetch_execution_history (execution_arn )
1009+ except Exception as e :
1010+ msg = f"Failed to fetch execution history: { e } "
1011+ raise DurableFunctionsTestError (msg ) from e
1012+
1013+ # Build test result from execution history
1014+ return DurableFunctionTestResult .from_execution_history (
1015+ execution_response , history_response
1016+ )
1017+
1018+ def wait_for_callback (
1019+ self , execution_arn : str , name : str | None = None , timeout : int = 60
1020+ ) -> str :
1021+ """
1022+ Wait for and retrieve a callback ID from a Step Functions execution.
1023+
1024+ Polls the execution history at regular intervals until a callback ID is found
1025+ or the timeout is reached.
1026+
1027+ Args:
1028+ execution_arn: Execution Arn
1029+ name: Specific callback name, default to None
1030+ timeout: Maximum time in seconds to wait for callback. Defaults to 60.
1031+
1032+ Returns:
1033+ str: The callback ID/token retrieved from the execution history
1034+
1035+ Raises:
1036+ TimeoutError: If callback is not found within the specified timeout period
1037+ DurableFunctionsTestError: If there's an error fetching execution history
1038+ (excluding retryable errors)
1039+ """
1040+ start_time = time .time ()
1041+
1042+ while time .time () - start_time < timeout :
1043+ try :
1044+ history_response = self ._fetch_execution_history (execution_arn )
1045+ callback_id = self ._get_callback_id_from_events (
1046+ events = history_response .events , name = name
1047+ )
1048+ if callback_id :
1049+ return callback_id
1050+ except ClientError as e :
1051+ error_code = e .response ["Error" ]["Code" ]
1052+ # retryable error, the execution may not start yet in async invoke situation
1053+ if error_code in ["ResourceNotFoundException" ]:
1054+ pass
1055+ else :
1056+ msg = f"Failed to fetch execution history: { e } "
1057+ raise DurableFunctionsTestError (msg ) from e
1058+ except DurableFunctionsTestError as e :
1059+ raise e
1060+ except Exception as e :
1061+ msg = f"Failed to fetch execution history: { e } "
1062+ raise DurableFunctionsTestError (msg ) from e
1063+
1064+ # Wait before next poll
1065+ time .sleep (self .poll_interval )
1066+
1067+ # Timeout reached
1068+ elapsed = time .time () - start_time
1069+ msg = (
1070+ f"Callback did not available within { timeout } s "
1071+ f"(elapsed: { elapsed :.1f} s."
1072+ )
1073+ raise TimeoutError (msg )
1074+
1075+ def _fetch_execution_history (
8901076 self , execution_arn : str
8911077 ) -> GetDurableExecutionHistoryResponse :
8921078 """Retrieve execution history from Lambda service.
@@ -898,19 +1084,13 @@ def _get_execution_history(
8981084 GetDurableExecutionHistoryResponse with typed Event objects
8991085
9001086 Raises:
901- DurableFunctionsTestError : If history retrieval fails
1087+ ClientError : If lambda client encounter error
9021088 """
903- try :
904- history_dict = self .lambda_client .get_durable_execution_history (
905- DurableExecutionArn = execution_arn ,
906- IncludeExecutionData = True ,
907- )
908- history_response = GetDurableExecutionHistoryResponse .from_dict (
909- history_dict
910- )
911- except Exception as e :
912- msg = f"Failed to get execution history: { e } "
913- raise DurableFunctionsTestError (msg ) from e
1089+ history_dict = self .lambda_client .get_durable_execution_history (
1090+ DurableExecutionArn = execution_arn ,
1091+ IncludeExecutionData = True ,
1092+ )
1093+ history_response = GetDurableExecutionHistoryResponse .from_dict (history_dict )
9141094
9151095 logger .info ("Retrieved %d events from history" , len (history_response .events ))
9161096
0 commit comments