Skip to content

Commit c442bcd

Browse files
committed
fix: add InvocationCompleted event support
1 parent c2cc3e9 commit c442bcd

File tree

6 files changed

+103
-26
lines changed

6 files changed

+103
-26
lines changed

src/aws_durable_execution_sdk_python_testing/execution.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ def __init__(
6060
self.start_input: StartDurableExecutionInput = start_input
6161
self.operations: list[Operation] = operations
6262
self.updates: list[OperationUpdate] = []
63+
self.invocation_completions: list[
64+
tuple[float, float, str]
65+
] = [] # (start_ts, end_ts, request_id)
6366
self.used_tokens: set[str] = set()
6467
# TODO: this will need to persist/rehydrate depending on inmemory vs sqllite store
6568
self._token_sequence: int = 0
@@ -101,6 +104,10 @@ def to_dict(self) -> dict[str, Any]:
101104
"StartInput": self.start_input.to_dict(),
102105
"Operations": [op.to_dict() for op in self.operations],
103106
"Updates": [update.to_dict() for update in self.updates],
107+
"InvocationCompletions": [
108+
[start, end, req_id]
109+
for start, end, req_id in self.invocation_completions
110+
],
104111
"UsedTokens": list(self.used_tokens),
105112
"TokenSequence": self._token_sequence,
106113
"IsComplete": self.is_complete,
@@ -129,6 +136,9 @@ def from_dict(cls, data: dict[str, Any]) -> Execution:
129136
execution.updates = [
130137
OperationUpdate.from_dict(update_data) for update_data in data["Updates"]
131138
]
139+
execution.invocation_completions = [
140+
tuple(item) for item in data.get("InvocationCompletions", [])
141+
]
132142
execution.used_tokens = set(data["UsedTokens"])
133143
execution._token_sequence = data["TokenSequence"] # noqa: SLF001
134144
execution.is_complete = data["IsComplete"]
@@ -215,6 +225,12 @@ def has_pending_operations(self, execution: Execution) -> bool:
215225
return True
216226
return False
217227

228+
def record_invocation_completion(
229+
self, start_timestamp: float, end_timestamp: float, request_id: str
230+
) -> None:
231+
"""Record an invocation completion event."""
232+
self.invocation_completions.append((start_timestamp, end_timestamp, request_id))
233+
218234
def complete_success(self, result: str | None) -> None:
219235
"""Complete execution successfully (DecisionType.COMPLETE_WORKFLOW_EXECUTION)."""
220236
self.result = DurableExecutionInvocationOutput(

src/aws_durable_execution_sdk_python_testing/executor.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import logging
6+
import time
67
import uuid
78
from datetime import UTC, datetime
89
from typing import TYPE_CHECKING
@@ -413,6 +414,20 @@ def get_execution_history(
413414
updates_dict: dict[str, OperationUpdate] = {u.operation_id: u for u in updates}
414415
durable_execution_arn: str = execution.durable_execution_arn
415416

417+
# Add InvocationCompleted events
418+
for start_ts, end_ts, request_id in execution.invocation_completions:
419+
invocation_event = HistoryEvent(
420+
event_id=0, # Temporary, will be reassigned
421+
event_type="InvocationCompleted",
422+
event_timestamp=datetime.fromtimestamp(end_ts, tz=UTC),
423+
invocation_completed_details={
424+
"StartTimestamp": start_ts,
425+
"EndTimestamp": end_ts,
426+
"RequestId": request_id,
427+
},
428+
)
429+
all_events.append(invocation_event)
430+
416431
# Generate all events first (without final event IDs)
417432
for op in ops:
418433
operation_update: OperationUpdate | None = updates_dict.get(
@@ -769,14 +784,23 @@ async def invoke() -> None:
769784

770785
self._store.save(execution)
771786

772-
response: DurableExecutionInvocationOutput = self._invoker.invoke(
787+
invocation_start = time.time()
788+
response, request_id = self._invoker.invoke(
773789
execution.start_input.function_name,
774790
invocation_input,
775791
execution.start_input.lambda_endpoint,
776792
)
793+
invocation_end = time.time()
777794

778795
# Reload execution after invocation in case it was completed via checkpoint
779796
execution = self._store.load(execution_arn)
797+
798+
# Record invocation completion and save immediately
799+
execution.record_invocation_completion(
800+
invocation_start, invocation_end, request_id
801+
)
802+
self._store.save(execution)
803+
780804
if execution.is_complete:
781805
logger.info(
782806
"[%s] Execution completed during invocation, ignoring result",

src/aws_durable_execution_sdk_python_testing/invoker.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import json
44
from threading import Lock
55
from typing import TYPE_CHECKING, Any, Protocol
6+
from uuid import uuid4
67

78
import boto3 # type: ignore
89
from aws_durable_execution_sdk_python.execution import (
@@ -65,7 +66,7 @@ def invoke(
6566
function_name: str,
6667
input: DurableExecutionInvocationInput,
6768
endpoint_url: str | None = None,
68-
) -> DurableExecutionInvocationOutput: ... # pragma: no cover
69+
) -> tuple[DurableExecutionInvocationOutput, str]: ... # pragma: no cover
6970

7071
def update_endpoint(
7172
self, endpoint_url: str, region_name: str
@@ -96,14 +97,15 @@ def invoke(
9697
function_name: str, # noqa: ARG002
9798
input: DurableExecutionInvocationInput,
9899
endpoint_url: str | None = None, # noqa: ARG002
99-
) -> DurableExecutionInvocationOutput:
100+
) -> tuple[DurableExecutionInvocationOutput, str]:
100101
# TODO: reasses if function_name will be used in future
101102
input_with_client = DurableExecutionInvocationInputWithClient.from_durable_execution_invocation_input(
102103
input, self.service_client
103104
)
104105
context = create_test_lambda_context()
105106
response_dict = self.handler(input_with_client, context)
106-
return DurableExecutionInvocationOutput.from_dict(response_dict)
107+
output = DurableExecutionInvocationOutput.from_dict(response_dict)
108+
return output, context.aws_request_id
107109

108110
def update_endpoint(self, endpoint_url: str, region_name: str) -> None:
109111
"""No-op for in-process invoker."""
@@ -192,7 +194,7 @@ def invoke(
192194
function_name: str,
193195
input: DurableExecutionInvocationInput,
194196
endpoint_url: str | None = None,
195-
) -> DurableExecutionInvocationOutput:
197+
) -> tuple[DurableExecutionInvocationOutput, str]:
196198
"""Invoke AWS Lambda function and return durable execution result.
197199
198200
Args:
@@ -201,7 +203,7 @@ def invoke(
201203
endpoint_url: Lambda endpoint url
202204
203205
Returns:
204-
DurableExecutionInvocationOutput: Result of the function execution
206+
tuple: (DurableExecutionInvocationOutput, request_id)
205207
206208
Raises:
207209
ResourceNotFoundException: If function does not exist
@@ -247,8 +249,17 @@ def invoke(
247249
response_payload = response["Payload"].read().decode("utf-8")
248250
response_dict = json.loads(response_payload)
249251

252+
# Extract request ID from response headers (x-amzn-RequestId or x-amzn-request-id)
253+
headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {})
254+
request_id = (
255+
headers.get("x-amzn-RequestId")
256+
or headers.get("x-amzn-request-id")
257+
or f"local-{uuid4()}"
258+
)
259+
250260
# Convert to DurableExecutionInvocationOutput
251-
return DurableExecutionInvocationOutput.from_dict(response_dict)
261+
output = DurableExecutionInvocationOutput.from_dict(response_dict)
262+
return output, request_id
252263

253264
except client.exceptions.ResourceNotFoundException as e:
254265
msg = f"Function not found: {function_name}"

src/aws_durable_execution_sdk_python_testing/model.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1329,6 +1329,7 @@ class Event:
13291329
callback_succeeded_details: CallbackSucceededDetails | None = None
13301330
callback_failed_details: CallbackFailedDetails | None = None
13311331
callback_timed_out_details: CallbackTimedOutDetails | None = None
1332+
invocation_completed_details: dict[str, Any] | None = None
13321333

13331334
@classmethod
13341335
def from_dict(cls, data: dict) -> Event:
@@ -1447,6 +1448,8 @@ def from_dict(cls, data: dict) -> Event:
14471448
if details_data := data.get("CallbackTimedOutDetails"):
14481449
callback_timed_out_details = CallbackTimedOutDetails.from_dict(details_data)
14491450

1451+
invocation_completed_details = data.get("InvocationCompletedDetails")
1452+
14501453
return cls(
14511454
event_type=data["EventType"],
14521455
event_timestamp=data["EventTimestamp"],
@@ -1479,6 +1482,7 @@ def from_dict(cls, data: dict) -> Event:
14791482
callback_succeeded_details=callback_succeeded_details,
14801483
callback_failed_details=callback_failed_details,
14811484
callback_timed_out_details=callback_timed_out_details,
1485+
invocation_completed_details=invocation_completed_details,
14821486
)
14831487

14841488
def to_dict(self) -> dict[str, Any]:
@@ -1563,6 +1567,8 @@ def to_dict(self) -> dict[str, Any]:
15631567
result["CallbackTimedOutDetails"] = (
15641568
self.callback_timed_out_details.to_dict()
15651569
)
1570+
if self.invocation_completed_details is not None:
1571+
result["InvocationCompletedDetails"] = self.invocation_completed_details
15661572
return result
15671573

15681574
# region execution

tests/executor_test.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ def test_should_complete_workflow_with_error_when_invocation_fails(
285285
failed_response = DurableExecutionInvocationOutput(
286286
status=InvocationStatus.FAILED, error=ErrorObject.from_message("Test error")
287287
)
288-
mock_invoker.invoke.return_value = failed_response
288+
mock_invoker.invoke.return_value = (failed_response, "test-request-id")
289289

290290
# Mock execution creation and store behavior
291291
with patch(
@@ -329,7 +329,7 @@ def test_should_complete_workflow_with_result_when_invocation_succeeds(
329329
success_response = DurableExecutionInvocationOutput(
330330
status=InvocationStatus.SUCCEEDED, result="success result"
331331
)
332-
mock_invoker.invoke.return_value = success_response
332+
mock_invoker.invoke.return_value = (success_response, "test-request-id")
333333

334334
# Mock execution creation and store behavior
335335
with patch(
@@ -372,7 +372,7 @@ def test_should_handle_pending_status_when_operations_exist(
372372
mock_invocation_input = Mock()
373373
mock_invoker.create_invocation_input.return_value = mock_invocation_input
374374
pending_response = DurableExecutionInvocationOutput(status=InvocationStatus.PENDING)
375-
mock_invoker.invoke.return_value = pending_response
375+
mock_invoker.invoke.return_value = (pending_response, "test-request-id")
376376

377377
# Mock execution creation and store behavior
378378
with patch(
@@ -409,8 +409,9 @@ def test_should_ignore_response_when_execution_already_complete(
409409

410410
# Mock invoker - this shouldn't be called since execution is complete
411411
mock_invoker.create_invocation_input.return_value = Mock()
412-
mock_invoker.invoke.return_value = DurableExecutionInvocationOutput(
413-
status=InvocationStatus.SUCCEEDED
412+
mock_invoker.invoke.return_value = (
413+
DurableExecutionInvocationOutput(status=InvocationStatus.SUCCEEDED),
414+
"test-request-id",
414415
)
415416

416417
# Mock execution creation and store behavior
@@ -452,7 +453,7 @@ def test_should_retry_when_response_has_no_status(
452453
mock_invocation_input = Mock()
453454
mock_invoker.create_invocation_input.return_value = mock_invocation_input
454455
no_status_response = DurableExecutionInvocationOutput(status=None)
455-
mock_invoker.invoke.return_value = no_status_response
456+
mock_invoker.invoke.return_value = (no_status_response, "test-request-id")
456457

457458
# Mock execution creation and store behavior
458459
with patch(
@@ -495,7 +496,7 @@ def test_should_retry_when_failed_response_has_result(
495496
invalid_response = DurableExecutionInvocationOutput(
496497
status=InvocationStatus.FAILED, result="should not have result"
497498
)
498-
mock_invoker.invoke.return_value = invalid_response
499+
mock_invoker.invoke.return_value = (invalid_response, "test-request-id")
499500

500501
# Mock execution creation and store behavior
501502
with patch(
@@ -539,7 +540,7 @@ def test_should_retry_when_success_response_has_error(
539540
status=InvocationStatus.SUCCEEDED,
540541
error=ErrorObject.from_message("should not have error"),
541542
)
542-
mock_invoker.invoke.return_value = invalid_response
543+
mock_invoker.invoke.return_value = (invalid_response, "test-request-id")
543544

544545
# Mock execution creation and store behavior
545546
with patch(
@@ -581,7 +582,7 @@ def test_should_retry_when_pending_response_has_no_operations(
581582
mock_invocation_input = Mock()
582583
mock_invoker.create_invocation_input.return_value = mock_invocation_input
583584
pending_response = DurableExecutionInvocationOutput(status=InvocationStatus.PENDING)
584-
mock_invoker.invoke.return_value = pending_response
585+
mock_invoker.invoke.return_value = (pending_response, "test-request-id")
585586

586587
# Mock execution creation and store behavior
587588
with patch(
@@ -622,7 +623,7 @@ def test_invoke_handler_success(
622623
mock_response = DurableExecutionInvocationOutput(
623624
status=InvocationStatus.SUCCEEDED, result="test"
624625
)
625-
mock_invoker.invoke.return_value = mock_response
626+
mock_invoker.invoke.return_value = (mock_response, "test-request-id")
626627

627628
# Mock execution creation and store behavior
628629
with patch(
@@ -694,7 +695,7 @@ def test_invoke_handler_execution_completed_during_invocation(
694695
mock_invocation_input = Mock()
695696
mock_invoker.create_invocation_input.return_value = mock_invocation_input
696697
mock_response = Mock()
697-
mock_invoker.invoke.return_value = mock_response
698+
mock_invoker.invoke.return_value = (mock_response, "test-request-id")
698699

699700
# Create a completed execution mock
700701
completed_execution = Mock()
@@ -1037,7 +1038,10 @@ def test_should_retry_invocation_when_under_limit_through_public_api(
10371038
success_response = DurableExecutionInvocationOutput(
10381039
status=InvocationStatus.SUCCEEDED, result="final success"
10391040
)
1040-
mock_invoker.invoke.side_effect = [invalid_response, success_response]
1041+
mock_invoker.invoke.side_effect = [
1042+
(invalid_response, "test-request-id-1"),
1043+
(success_response, "test-request-id-2"),
1044+
]
10411045

10421046
# Mock execution creation and store behavior
10431047
with patch(
@@ -1435,7 +1439,7 @@ def test_should_retry_when_response_has_unexpected_status(
14351439
mock_invoker.create_invocation_input.return_value = mock_invocation_input
14361440
unexpected_response = Mock()
14371441
unexpected_response.status = "UNKNOWN_STATUS"
1438-
mock_invoker.invoke.return_value = unexpected_response
1442+
mock_invoker.invoke.return_value = (unexpected_response, "test-request-id")
14391443

14401444
# Mock execution creation and store behavior
14411445
with patch(
@@ -1480,7 +1484,7 @@ def test_invoke_handler_execution_completed_during_invocation_async(
14801484
mock_invocation_input = Mock()
14811485
mock_invoker.create_invocation_input.return_value = mock_invocation_input
14821486
mock_response = Mock()
1483-
mock_invoker.invoke.return_value = mock_response
1487+
mock_invoker.invoke.return_value = (mock_response, "test-request-id")
14841488

14851489
# Mock execution creation
14861490
with patch(
@@ -1566,7 +1570,7 @@ def test_invoke_handler_general_exception_async(
15661570
success_response = DurableExecutionInvocationOutput(
15671571
status=InvocationStatus.SUCCEEDED, result="success"
15681572
)
1569-
mock_invoker.invoke.return_value = success_response
1573+
mock_invoker.invoke.return_value = (success_response, "test-request-id")
15701574

15711575
# Mock execution creation and store behavior
15721576
with patch(
@@ -2094,6 +2098,7 @@ def test_get_execution_history(executor, mock_store):
20942098
mock_execution = Mock()
20952099
mock_execution.operations = [] # Empty operations list
20962100
mock_execution.updates = []
2101+
mock_execution.invocation_completions = []
20972102
mock_execution.durable_execution_arn = ""
20982103
mock_execution.start_input = Mock()
20992104
mock_execution.result = Mock()
@@ -2123,6 +2128,7 @@ def test_get_execution_history_with_events(executor, mock_store):
21232128
mock_execution = Mock()
21242129
mock_execution.operations = [op1]
21252130
mock_execution.updates = []
2131+
mock_execution.invocation_completions = []
21262132
mock_execution.durable_execution_arn = ""
21272133
mock_execution.start_input = Mock()
21282134
mock_execution.result = Mock()
@@ -2148,6 +2154,7 @@ def test_get_execution_history_reverse_order(executor, mock_store):
21482154
mock_execution = Mock()
21492155
mock_execution.operations = [op1]
21502156
mock_execution.updates = []
2157+
mock_execution.invocation_completions = []
21512158
mock_execution.durable_execution_arn = ""
21522159
mock_execution.start_input = Mock()
21532160
mock_execution.result = Mock()
@@ -2178,6 +2185,7 @@ def test_get_execution_history_pagination(executor, mock_store):
21782185
mock_execution = Mock()
21792186
mock_execution.operations = operations
21802187
mock_execution.updates = []
2188+
mock_execution.invocation_completions = []
21812189
mock_execution.durable_execution_arn = ""
21822190
mock_execution.start_input = Mock()
21832191
mock_execution.result = Mock()
@@ -2206,6 +2214,7 @@ def test_get_execution_history_pagination_with_marker(executor, mock_store):
22062214
mock_execution = Mock()
22072215
mock_execution.operations = operations
22082216
mock_execution.updates = []
2217+
mock_execution.invocation_completions = []
22092218
mock_execution.durable_execution_arn = ""
22102219
mock_execution.start_input = Mock()
22112220
mock_execution.result = Mock()
@@ -2223,6 +2232,7 @@ def test_get_execution_history_invalid_marker(executor, mock_store):
22232232
mock_execution = Mock()
22242233
mock_execution.operations = []
22252234
mock_execution.updates = []
2235+
mock_execution.invocation_completions = []
22262236
mock_execution.durable_execution_arn = ""
22272237
mock_execution.start_input = Mock()
22282238
mock_execution.result = Mock()
@@ -2399,6 +2409,7 @@ def test_send_callback_heartbeat(executor, mock_store):
23992409
mock_operation.status = OperationStatus.STARTED
24002410
mock_execution.find_callback_operation.return_value = (0, mock_operation)
24012411
mock_execution.updates = [] # No callback options to reset timeout
2412+
mock_execution.invocation_completions = []
24022413
mock_store.load.return_value = mock_execution
24032414

24042415
result = executor.send_callback_heartbeat(callback_id)
@@ -2651,6 +2662,7 @@ def test_schedule_callback_timeouts_no_callback_options(executor, mock_store):
26512662
mock_execution = Mock()
26522663
mock_execution.find_operation.return_value = (0, operation)
26532664
mock_execution.updates = [] # No updates with callback options
2665+
mock_execution.invocation_completions = []
26542666
mock_store.load.return_value = mock_execution
26552667

26562668
# Should return early without scheduling

0 commit comments

Comments
 (0)