Skip to content

Commit d01e471

Browse files
author
Alex Wang
committed
feat: Implement callback for web runner
- Implement async run, send callback for web runner - Unit tests
1 parent a3cda91 commit d01e471

File tree

4 files changed

+696
-23
lines changed

4 files changed

+696
-23
lines changed

examples/test/conftest.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,36 @@ def run(
105105
"""Execute the durable function and return results."""
106106
return self._runner.run(input=input, timeout=timeout)
107107

108+
def run_async(
109+
self,
110+
input: str | None = None, # noqa: A002
111+
timeout: int = 60,
112+
) -> str:
113+
return self._runner.run_async(input=input, timeout=timeout)
114+
115+
def send_callback_success(self, callback_id: str) -> None:
116+
self._runner.send_callback_success(callback_id=callback_id)
117+
118+
def send_callback_failure(self, callback_id: str) -> None:
119+
self._runner.send_callback_failure(callback_id=callback_id)
120+
121+
def send_callback_heartbeat(self, callback_id: str) -> None:
122+
self._runner.send_callback_heartbeat(callback_id=callback_id)
123+
124+
def wait_and_fetch_test_result(
125+
self, execution_arn: str, timeout: int = 60
126+
) -> DurableFunctionTestResult:
127+
return self._runner.wait_and_fetch_test_result(
128+
execution_arn=execution_arn, timeout=timeout
129+
)
130+
131+
def wait_and_fetch_callback_id(
132+
self, execution_arn: str, name: str | None = None, timeout: int = 60
133+
) -> str:
134+
return self._runner.wait_and_fetch_callback_id(
135+
execution_arn=execution_arn, name=name, timeout=timeout
136+
)
137+
108138
@property
109139
def mode(self) -> str:
110140
"""Get the runner mode (local or cloud)."""

src/aws_durable_execution_sdk_python_testing/runner.py

Lines changed: 170 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import aws_durable_execution_sdk_python
2020
import boto3 # type: ignore
21+
from botocore.exceptions import ClientError
2122
from aws_durable_execution_sdk_python.execution import (
2223
InvocationStatus,
2324
durable_execution,
@@ -75,6 +76,8 @@
7576

7677
from aws_durable_execution_sdk_python_testing.execution import Execution
7778
from aws_durable_execution_sdk_python_testing.web.server import WebServiceConfig
79+
from aws_durable_execution_sdk_python_testing.model import Event
80+
7881

7982
logger = logging.getLogger(__name__)
8083

@@ -792,9 +795,9 @@ def run(
792795
msg = f"Failed to invoke Lambda function {self.function_name}: {e}"
793796
raise DurableFunctionsTestError(msg) from e
794797

795-
# Check HTTP status code (200 for RequestResponse, 202 for Event, 204 for DryRun)
798+
# Check HTTP status code, 200 for RequestResponse
796799
status_code = response.get("StatusCode")
797-
if status_code not in (200, 202, 204):
800+
if status_code != 200:
798801
error_payload = response["Payload"].read().decode("utf-8")
799802
msg = f"Lambda invocation failed with status {status_code}: {error_payload}"
800803
raise DurableFunctionsTestError(msg)
@@ -819,16 +822,111 @@ def run(
819822
)
820823
raise DurableFunctionsTestError(msg)
821824

822-
# Poll for completion
823-
execution_response = self._wait_for_completion(execution_arn, timeout)
824-
825-
# Get execution history
826-
history_response = self._get_execution_history(execution_arn)
825+
return self.wait_and_fetch_test_result(
826+
execution_arn=execution_arn, timeout=timeout
827+
)
827828

828-
# Build test result from execution history
829-
return DurableFunctionTestResult.from_execution_history(
830-
execution_response, history_response
829+
def run_async(
830+
self,
831+
input: str | None = None, # noqa: A002
832+
timeout: int = 60,
833+
) -> str:
834+
"""Execute function on AWS Lambda and wait for completion."""
835+
logger.info(
836+
"Invoking Lambda function: %s (timeout: %ds)", self.function_name, timeout
831837
)
838+
payload = json.dumps(input)
839+
try:
840+
response = self.lambda_client.invoke(
841+
FunctionName=self.function_name,
842+
InvocationType="Event",
843+
Payload=payload,
844+
)
845+
except Exception as e:
846+
msg = f"Failed to invoke Lambda function {self.function_name}: {e}"
847+
raise DurableFunctionsTestError(msg) from e
848+
849+
# Check HTTP status code, 202 for Event
850+
status_code = response.get("StatusCode")
851+
if status_code != 202:
852+
error_payload = response["Payload"].read().decode("utf-8")
853+
msg = f"Lambda invocation failed with status {status_code}: {error_payload}"
854+
raise DurableFunctionsTestError(msg)
855+
856+
return response.get("DurableExecutionArn")
857+
858+
def _get_callback_id_from_events(
859+
self, events: list[Event], name: str | None = None
860+
) -> str:
861+
"""
862+
Get callback ID from execution history for callbacks that haven't completed.
863+
864+
Args:
865+
execution_arn: The ARN of the execution to query.
866+
name: Optional callback name to search for. If not provided, returns the latest callback.
867+
868+
Returns:
869+
The callback ID string for a non-completed callback, or None if not found.
870+
871+
Raises:
872+
DurableFunctionsTestError: If the named callback has already succeeded/failed/timed out.
873+
"""
874+
callback_started_events = [
875+
event for event in events if event.event_type == "CallbackStarted"
876+
]
877+
878+
if not callback_started_events:
879+
return None
880+
881+
completed_callback_ids = {
882+
event.event_id
883+
for event in events
884+
if event.event_type
885+
in ["CallbackSucceeded", "CallbackFailed", "CallbackTimedOut"]
886+
}
887+
888+
if name is not None:
889+
for event in callback_started_events:
890+
if event.name == name:
891+
callback_id = event.event_id
892+
if callback_id in completed_callback_ids:
893+
raise DurableFunctionsTestError(
894+
f"Callback {name} has already completed (succeeded/failed/timed out)"
895+
)
896+
return event.callback_started_details.callback_id
897+
return None
898+
899+
# If name is not provided, find the latest non-completed callback event
900+
active_callbacks = [
901+
event
902+
for event in callback_started_events
903+
if event.event_id not in completed_callback_ids
904+
]
905+
906+
if not active_callbacks:
907+
return None
908+
909+
latest_event = active_callbacks[-1]
910+
return latest_event.callback_started_details.callback_id
911+
912+
def send_callback_success(self, callback_id: str) -> None:
913+
self._send_callback("success", callback_id)
914+
915+
def send_callback_failure(self, callback_id: str) -> None:
916+
self._send_callback("failure", callback_id)
917+
918+
def send_callback_heartbeat(self, callback_id: str) -> None:
919+
self._send_callback("heartbeat", callback_id)
920+
921+
def _send_callback(self, operation: str, callback_id: str) -> None:
922+
"""Helper method to send callback operations."""
923+
method_name = f"send_durable_execution_callback_{operation}"
924+
try:
925+
api_method = getattr(self.lambda_client, method_name)
926+
api_method(CallbackId=callback_id)
927+
except Exception as e:
928+
msg = f"Failed to send callback {operation} for {self.function_name}, callback_id {callback_id}: {e}"
929+
raise DurableFunctionsTestError(msg) from e
832930

833931
def _wait_for_completion(
834932
self, execution_arn: str, timeout: int
@@ -886,7 +984,62 @@ def _wait_for_completion(
886984
)
887985
raise TimeoutError(msg)
888986

889-
def _get_execution_history(
987+
def wait_and_fetch_test_result(
988+
self, execution_arn: str, timeout: int = 60
989+
) -> DurableFunctionTestResult:
990+
# Poll for completion
991+
execution_response = self._wait_for_completion(execution_arn, timeout)
992+
993+
try:
994+
history_response = self._fetch_execution_history(execution_arn)
995+
except Exception as e:
996+
msg = f"Failed to fetch execution history: {e}"
997+
raise DurableFunctionsTestError(msg) from e
998+
999+
# Build test result from execution history
1000+
return DurableFunctionTestResult.from_execution_history(
1001+
execution_response, history_response
1002+
)
1003+
1004+
def wait_and_fetch_callback_id(
1005+
self, execution_arn: str, name: str | None = None, timeout: int = 60
1006+
) -> str:
1007+
start_time = time.time()
1008+
1009+
while time.time() - start_time < timeout:
1010+
try:
1011+
history_response = self._fetch_execution_history(execution_arn)
1012+
callback_id = self._get_callback_id_from_events(
1013+
events=history_response.events, name=name
1014+
)
1015+
if callback_id:
1016+
return callback_id
1017+
except ClientError as e:
1018+
error_code = e.response["Error"]["Code"]
1019+
# retryable error, the execution may not start yet in async invoke situation
1020+
if error_code in ["ResourceNotFoundException"]:
1021+
pass
1022+
else:
1023+
msg = f"Failed to fetch execution history: {e}"
1024+
raise DurableFunctionsTestError(msg) from e
1025+
except DurableFunctionsTestError as e:
1026+
raise e
1027+
except Exception as e:
1028+
msg = f"Failed to fetch execution history: {e}"
1029+
raise DurableFunctionsTestError(msg) from e
1030+
1031+
# Wait before next poll
1032+
time.sleep(self.poll_interval)
1033+
1034+
# Timeout reached
1035+
elapsed = time.time() - start_time
1036+
msg = (
1037+
f"Callback did not available within {timeout}s "
1038+
f"(elapsed: {elapsed:.1f}s."
1039+
)
1040+
raise TimeoutError(msg)
1041+
1042+
def _fetch_execution_history(
8901043
self, execution_arn: str
8911044
) -> GetDurableExecutionHistoryResponse:
8921045
"""Retrieve execution history from Lambda service.
@@ -898,19 +1051,13 @@ def _get_execution_history(
8981051
GetDurableExecutionHistoryResponse with typed Event objects
8991052
9001053
Raises:
901-
DurableFunctionsTestError: If history retrieval fails
1054+
ClientError: If lambda client encounter error
9021055
"""
903-
try:
904-
history_dict = self.lambda_client.get_durable_execution_history(
905-
DurableExecutionArn=execution_arn,
906-
IncludeExecutionData=True,
907-
)
908-
history_response = GetDurableExecutionHistoryResponse.from_dict(
909-
history_dict
910-
)
911-
except Exception as e:
912-
msg = f"Failed to get execution history: {e}"
913-
raise DurableFunctionsTestError(msg) from e
1056+
history_dict = self.lambda_client.get_durable_execution_history(
1057+
DurableExecutionArn=execution_arn,
1058+
IncludeExecutionData=True,
1059+
)
1060+
history_response = GetDurableExecutionHistoryResponse.from_dict(history_dict)
9141061

9151062
logger.info("Retrieved %d events from history", len(history_response.events))
9161063

0 commit comments

Comments
 (0)