Skip to content

Commit 6ac4a9f

Browse files
committed
fix: add InvocationCompleted event support
1 parent c2cc3e9 commit 6ac4a9f

File tree

4 files changed

+53
-7
lines changed

4 files changed

+53
-7
lines changed

src/aws_durable_execution_sdk_python_testing/execution.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def __init__(
6060
self.start_input: StartDurableExecutionInput = start_input
6161
self.operations: list[Operation] = operations
6262
self.updates: list[OperationUpdate] = []
63+
self.invocation_completions: list[tuple[float, float, str]] = [] # (start_ts, end_ts, request_id)
6364
self.used_tokens: set[str] = set()
6465
# TODO: this will need to persist/rehydrate depending on inmemory vs sqllite store
6566
self._token_sequence: int = 0
@@ -101,6 +102,7 @@ def to_dict(self) -> dict[str, Any]:
101102
"StartInput": self.start_input.to_dict(),
102103
"Operations": [op.to_dict() for op in self.operations],
103104
"Updates": [update.to_dict() for update in self.updates],
105+
"InvocationCompletions": [[start, end, req_id] for start, end, req_id in self.invocation_completions],
104106
"UsedTokens": list(self.used_tokens),
105107
"TokenSequence": self._token_sequence,
106108
"IsComplete": self.is_complete,
@@ -129,6 +131,9 @@ def from_dict(cls, data: dict[str, Any]) -> Execution:
129131
execution.updates = [
130132
OperationUpdate.from_dict(update_data) for update_data in data["Updates"]
131133
]
134+
execution.invocation_completions = [
135+
tuple(item) for item in data.get("InvocationCompletions", [])
136+
]
132137
execution.used_tokens = set(data["UsedTokens"])
133138
execution._token_sequence = data["TokenSequence"] # noqa: SLF001
134139
execution.is_complete = data["IsComplete"]
@@ -215,6 +220,12 @@ def has_pending_operations(self, execution: Execution) -> bool:
215220
return True
216221
return False
217222

223+
def record_invocation_completion(
224+
self, start_timestamp: float, end_timestamp: float, request_id: str
225+
) -> None:
226+
"""Record an invocation completion event."""
227+
self.invocation_completions.append((start_timestamp, end_timestamp, request_id))
228+
218229
def complete_success(self, result: str | None) -> None:
219230
"""Complete execution successfully (DecisionType.COMPLETE_WORKFLOW_EXECUTION)."""
220231
self.result = DurableExecutionInvocationOutput(

src/aws_durable_execution_sdk_python_testing/executor.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import logging
6+
import time
67
import uuid
78
from datetime import UTC, datetime
89
from typing import TYPE_CHECKING
@@ -413,6 +414,20 @@ def get_execution_history(
413414
updates_dict: dict[str, OperationUpdate] = {u.operation_id: u for u in updates}
414415
durable_execution_arn: str = execution.durable_execution_arn
415416

417+
# Add InvocationCompleted events
418+
for start_ts, end_ts, request_id in execution.invocation_completions:
419+
invocation_event = HistoryEvent(
420+
event_id=0, # Temporary, will be reassigned
421+
event_type="InvocationCompleted",
422+
event_timestamp=datetime.fromtimestamp(end_ts, tz=UTC),
423+
invocation_completed_details={
424+
"StartTimestamp": start_ts,
425+
"EndTimestamp": end_ts,
426+
"RequestId": request_id,
427+
},
428+
)
429+
all_events.append(invocation_event)
430+
416431
# Generate all events first (without final event IDs)
417432
for op in ops:
418433
operation_update: OperationUpdate | None = updates_dict.get(
@@ -769,14 +784,21 @@ async def invoke() -> None:
769784

770785
self._store.save(execution)
771786

772-
response: DurableExecutionInvocationOutput = self._invoker.invoke(
787+
invocation_start = time.time()
788+
response, request_id = self._invoker.invoke(
773789
execution.start_input.function_name,
774790
invocation_input,
775791
execution.start_input.lambda_endpoint,
776792
)
793+
invocation_end = time.time()
777794

778795
# Reload execution after invocation in case it was completed via checkpoint
779796
execution = self._store.load(execution_arn)
797+
798+
# Record invocation completion and save immediately
799+
execution.record_invocation_completion(invocation_start, invocation_end, request_id)
800+
self._store.save(execution)
801+
780802
if execution.is_complete:
781803
logger.info(
782804
"[%s] Execution completed during invocation, ignoring result",

src/aws_durable_execution_sdk_python_testing/invoker.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import json
44
from threading import Lock
55
from typing import TYPE_CHECKING, Any, Protocol
6+
from uuid import uuid4
67

78
import boto3 # type: ignore
89
from aws_durable_execution_sdk_python.execution import (
@@ -65,7 +66,7 @@ def invoke(
6566
function_name: str,
6667
input: DurableExecutionInvocationInput,
6768
endpoint_url: str | None = None,
68-
) -> DurableExecutionInvocationOutput: ... # pragma: no cover
69+
) -> tuple[DurableExecutionInvocationOutput, str]: ... # pragma: no cover
6970

7071
def update_endpoint(
7172
self, endpoint_url: str, region_name: str
@@ -96,14 +97,15 @@ def invoke(
9697
function_name: str, # noqa: ARG002
9798
input: DurableExecutionInvocationInput,
9899
endpoint_url: str | None = None, # noqa: ARG002
99-
) -> DurableExecutionInvocationOutput:
100+
) -> tuple[DurableExecutionInvocationOutput, str]:
100101
# TODO: reasses if function_name will be used in future
101102
input_with_client = DurableExecutionInvocationInputWithClient.from_durable_execution_invocation_input(
102103
input, self.service_client
103104
)
104105
context = create_test_lambda_context()
105106
response_dict = self.handler(input_with_client, context)
106-
return DurableExecutionInvocationOutput.from_dict(response_dict)
107+
output = DurableExecutionInvocationOutput.from_dict(response_dict)
108+
return output, context.aws_request_id
107109

108110
def update_endpoint(self, endpoint_url: str, region_name: str) -> None:
109111
"""No-op for in-process invoker."""
@@ -192,7 +194,7 @@ def invoke(
192194
function_name: str,
193195
input: DurableExecutionInvocationInput,
194196
endpoint_url: str | None = None,
195-
) -> DurableExecutionInvocationOutput:
197+
) -> tuple[DurableExecutionInvocationOutput, str]:
196198
"""Invoke AWS Lambda function and return durable execution result.
197199
198200
Args:
@@ -201,7 +203,7 @@ def invoke(
201203
endpoint_url: Lambda endpoint url
202204
203205
Returns:
204-
DurableExecutionInvocationOutput: Result of the function execution
206+
tuple: (DurableExecutionInvocationOutput, request_id)
205207
206208
Raises:
207209
ResourceNotFoundException: If function does not exist
@@ -247,8 +249,13 @@ def invoke(
247249
response_payload = response["Payload"].read().decode("utf-8")
248250
response_dict = json.loads(response_payload)
249251

252+
# Extract request ID from response headers (x-amzn-RequestId or x-amzn-request-id)
253+
headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {})
254+
request_id = headers.get("x-amzn-RequestId") or headers.get("x-amzn-request-id") or f"local-{uuid4()}"
255+
250256
# Convert to DurableExecutionInvocationOutput
251-
return DurableExecutionInvocationOutput.from_dict(response_dict)
257+
output = DurableExecutionInvocationOutput.from_dict(response_dict)
258+
return output, request_id
252259

253260
except client.exceptions.ResourceNotFoundException as e:
254261
msg = f"Function not found: {function_name}"

src/aws_durable_execution_sdk_python_testing/model.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1329,6 +1329,7 @@ class Event:
13291329
callback_succeeded_details: CallbackSucceededDetails | None = None
13301330
callback_failed_details: CallbackFailedDetails | None = None
13311331
callback_timed_out_details: CallbackTimedOutDetails | None = None
1332+
invocation_completed_details: dict[str, Any] | None = None
13321333

13331334
@classmethod
13341335
def from_dict(cls, data: dict) -> Event:
@@ -1447,6 +1448,8 @@ def from_dict(cls, data: dict) -> Event:
14471448
if details_data := data.get("CallbackTimedOutDetails"):
14481449
callback_timed_out_details = CallbackTimedOutDetails.from_dict(details_data)
14491450

1451+
invocation_completed_details = data.get("InvocationCompletedDetails")
1452+
14501453
return cls(
14511454
event_type=data["EventType"],
14521455
event_timestamp=data["EventTimestamp"],
@@ -1479,6 +1482,7 @@ def from_dict(cls, data: dict) -> Event:
14791482
callback_succeeded_details=callback_succeeded_details,
14801483
callback_failed_details=callback_failed_details,
14811484
callback_timed_out_details=callback_timed_out_details,
1485+
invocation_completed_details=invocation_completed_details,
14821486
)
14831487

14841488
def to_dict(self) -> dict[str, Any]:
@@ -1563,6 +1567,8 @@ def to_dict(self) -> dict[str, Any]:
15631567
result["CallbackTimedOutDetails"] = (
15641568
self.callback_timed_out_details.to_dict()
15651569
)
1570+
if self.invocation_completed_details is not None:
1571+
result["InvocationCompletedDetails"] = self.invocation_completed_details
15661572
return result
15671573

15681574
# region execution

0 commit comments

Comments
 (0)