1818
1919import aws_durable_execution_sdk_python
2020import boto3 # type: ignore
21+ from botocore .exceptions import ClientError
2122from aws_durable_execution_sdk_python .execution import (
2223 InvocationStatus ,
2324 durable_execution ,
7576
7677 from aws_durable_execution_sdk_python_testing .execution import Execution
7778 from aws_durable_execution_sdk_python_testing .web .server import WebServiceConfig
79+ from aws_durable_execution_sdk_python_testing .model import Event
80+
7881
7982logger = logging .getLogger (__name__ )
8083
@@ -792,9 +795,9 @@ def run(
792795 msg = f"Failed to invoke Lambda function { self .function_name } : { e } "
793796 raise DurableFunctionsTestError (msg ) from e
794797
795- # Check HTTP status code ( 200 for RequestResponse, 202 for Event, 204 for DryRun)
798+ # Check HTTP status code, 200 for RequestResponse
796799 status_code = response .get ("StatusCode" )
797- if status_code not in ( 200 , 202 , 204 ) :
800+ if status_code != 200 :
798801 error_payload = response ["Payload" ].read ().decode ("utf-8" )
799802 msg = f"Lambda invocation failed with status { status_code } : { error_payload } "
800803 raise DurableFunctionsTestError (msg )
@@ -819,16 +822,111 @@ def run(
819822 )
820823 raise DurableFunctionsTestError (msg )
821824
822- # Poll for completion
823- execution_response = self ._wait_for_completion (execution_arn , timeout )
824-
825- # Get execution history
826- history_response = self ._get_execution_history (execution_arn )
825+ return self .wait_and_fetch_test_result (
826+ execution_arn = execution_arn , timeout = timeout
827+ )
827828
828- # Build test result from execution history
829- return DurableFunctionTestResult .from_execution_history (
830- execution_response , history_response
829+ def run_async (
830+ self ,
831+ input : str | None = None , # noqa: A002
832+ timeout : int = 60 ,
833+ ) -> str :
834+ """Execute function on AWS Lambda and wait for completion."""
835+ logger .info (
836+ "Invoking Lambda function: %s (timeout: %ds)" , self .function_name , timeout
831837 )
838+ payload = json .dumps (input )
839+ try :
840+ response = self .lambda_client .invoke (
841+ FunctionName = self .function_name ,
842+ InvocationType = "Event" ,
843+ Payload = payload ,
844+ )
845+ except Exception as e :
846+ msg = f"Failed to invoke Lambda function { self .function_name } : { e } "
847+ raise DurableFunctionsTestError (msg ) from e
848+
849+ # Check HTTP status code, 202 for Event
850+ status_code = response .get ("StatusCode" )
851+ if status_code != 202 :
852+ error_payload = response ["Payload" ].read ().decode ("utf-8" )
853+ msg = f"Lambda invocation failed with status { status_code } : { error_payload } "
854+ raise DurableFunctionsTestError (msg )
855+
856+ return response .get ("DurableExecutionArn" )
857+
858+ def _get_callback_id_from_events (
859+ self , events : list [Event ], name : str | None = None
860+ ) -> str :
861+ """
862+ Get callback ID from execution history for callbacks that haven't completed.
863+
864+ Args:
865+ execution_arn: The ARN of the execution to query.
866+ name: Optional callback name to search for. If not provided, returns the latest callback.
867+
868+ Returns:
869+ The callback ID string for a non-completed callback, or None if not found.
870+
871+ Raises:
872+ DurableFunctionsTestError: If the named callback has already succeeded/failed/timed out.
873+ """
874+ callback_started_events = [
875+ event for event in events if event .event_type == "CallbackStarted"
876+ ]
877+
878+ if not callback_started_events :
879+ return None
880+
881+ completed_callback_ids = {
882+ event .event_id
883+ for event in events
884+ if event .event_type
885+ in ["CallbackSucceeded" , "CallbackFailed" , "CallbackTimedOut" ]
886+ }
887+
888+ if name is not None :
889+ for event in callback_started_events :
890+ if event .name == name :
891+ callback_id = event .event_id
892+ if callback_id in completed_callback_ids :
893+ raise DurableFunctionsTestError (
894+ f"Callback { name } has already completed (succeeded/failed/timed out)"
895+ )
896+ return event .callback_started_details .callback_id
897+ return None
898+
899+ # If name is not provided, find the latest non-completed callback event
900+ active_callbacks = [
901+ event
902+ for event in callback_started_events
903+ if event .event_id not in completed_callback_ids
904+ ]
905+
906+ if not active_callbacks :
907+ return None
908+
909+ latest_event = active_callbacks [- 1 ]
910+ return latest_event .callback_started_details .callback_id
911+
912+ def send_callback_success (self , callback_id : str ) -> None :
913+ self ._send_callback ("success" , callback_id )
914+
915+ def send_callback_failure (self , callback_id : str ) -> None :
916+ self ._send_callback ("failure" , callback_id )
917+
918+ def send_callback_heartbeat (self , callback_id : str ) -> None :
919+ self ._send_callback ("heartbeat" , callback_id )
920+
921+ def _send_callback (self , operation : str , callback_id : str ) -> None :
922+ """Helper method to send callback operations."""
923+ method_name = f"send_durable_execution_callback_{ operation } "
924+ try :
925+ api_method = getattr (self .lambda_client , method_name )
926+ api_method (CallbackId = callback_id )
927+ except Exception as e :
928+ msg = f"Failed to send callback { operation } for { self .function_name } , callback_id { callback_id } : { e } "
929+ raise DurableFunctionsTestError (msg ) from e
832930
833931 def _wait_for_completion (
834932 self , execution_arn : str , timeout : int
@@ -886,7 +984,62 @@ def _wait_for_completion(
886984 )
887985 raise TimeoutError (msg )
888986
889- def _get_execution_history (
987+ def wait_and_fetch_test_result (
988+ self , execution_arn : str , timeout : int = 60
989+ ) -> DurableFunctionTestResult :
990+ # Poll for completion
991+ execution_response = self ._wait_for_completion (execution_arn , timeout )
992+
993+ try :
994+ history_response = self ._fetch_execution_history (execution_arn )
995+ except Exception as e :
996+ msg = f"Failed to fetch execution history: { e } "
997+ raise DurableFunctionsTestError (msg ) from e
998+
999+ # Build test result from execution history
1000+ return DurableFunctionTestResult .from_execution_history (
1001+ execution_response , history_response
1002+ )
1003+
1004+ def wait_and_fetch_callback_id (
1005+ self , execution_arn : str , name : str | None = None , timeout : int = 60
1006+ ) -> str :
1007+ start_time = time .time ()
1008+
1009+ while time .time () - start_time < timeout :
1010+ try :
1011+ history_response = self ._fetch_execution_history (execution_arn )
1012+ callback_id = self ._get_callback_id_from_events (
1013+ events = history_response .events , name = name
1014+ )
1015+ if callback_id :
1016+ return callback_id
1017+ except ClientError as e :
1018+ error_code = e .response ["Error" ]["Code" ]
1019+ # retryable error, the execution may not start yet in async invoke situation
1020+ if error_code in ["ResourceNotFoundException" ]:
1021+ pass
1022+ else :
1023+ msg = f"Failed to fetch execution history: { e } "
1024+ raise DurableFunctionsTestError (msg ) from e
1025+ except DurableFunctionsTestError as e :
1026+ raise e
1027+ except Exception as e :
1028+ msg = f"Failed to fetch execution history: { e } "
1029+ raise DurableFunctionsTestError (msg ) from e
1030+
1031+ # Wait before next poll
1032+ time .sleep (self .poll_interval )
1033+
1034+ # Timeout reached
1035+ elapsed = time .time () - start_time
1036+ msg = (
1037+ f"Callback did not available within { timeout } s "
1038+ f"(elapsed: { elapsed :.1f} s."
1039+ )
1040+ raise TimeoutError (msg )
1041+
1042+ def _fetch_execution_history (
8901043 self , execution_arn : str
8911044 ) -> GetDurableExecutionHistoryResponse :
8921045 """Retrieve execution history from Lambda service.
@@ -898,19 +1051,13 @@ def _get_execution_history(
8981051 GetDurableExecutionHistoryResponse with typed Event objects
8991052
9001053 Raises:
901- DurableFunctionsTestError : If history retrieval fails
1054+ ClientError : If lambda client encounter error
9021055 """
903- try :
904- history_dict = self .lambda_client .get_durable_execution_history (
905- DurableExecutionArn = execution_arn ,
906- IncludeExecutionData = True ,
907- )
908- history_response = GetDurableExecutionHistoryResponse .from_dict (
909- history_dict
910- )
911- except Exception as e :
912- msg = f"Failed to get execution history: { e } "
913- raise DurableFunctionsTestError (msg ) from e
1056+ history_dict = self .lambda_client .get_durable_execution_history (
1057+ DurableExecutionArn = execution_arn ,
1058+ IncludeExecutionData = True ,
1059+ )
1060+ history_response = GetDurableExecutionHistoryResponse .from_dict (history_dict )
9141061
9151062 logger .info ("Retrieved %d events from history" , len (history_response .events ))
9161063
0 commit comments