Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 11 additions & 40 deletions src/aws_durable_execution_sdk_python_testing/execution.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import json
from dataclasses import replace

from datetime import UTC, datetime
from enum import Enum
from threading import Lock
Expand All @@ -13,12 +13,14 @@
InvocationStatus,
)
from aws_durable_execution_sdk_python.lambda_service import (
CallbackDetails,
ErrorObject,
ExecutionDetails,
Operation,
OperationStatus,
OperationType,
OperationUpdate,
StepDetails,
)

from aws_durable_execution_sdk_python_testing.exceptions import (
Expand Down Expand Up @@ -289,10 +291,8 @@ def complete_wait(self, operation_id: str) -> Operation:
with self._state_lock:
self._token_sequence += 1
# Build and assign updated operation
self.operations[index] = replace(
operation,
status=OperationStatus.SUCCEEDED,
end_timestamp=datetime.now(UTC),
self.operations[index] = operation.create_succeeded(
end_timestamp=datetime.now(UTC)
)
return self.operations[index]

Expand All @@ -313,17 +313,7 @@ def complete_retry(self, operation_id: str) -> Operation:
# Thread-safe increment sequence and operation update
with self._state_lock:
self._token_sequence += 1
# Build updated step_details with cleared next_attempt_timestamp
new_step_details = None
if operation.step_details:
new_step_details = replace(
operation.step_details, next_attempt_timestamp=None
)

# Build updated operation
updated_operation = replace(
operation, status=OperationStatus.READY, step_details=new_step_details
)
updated_operation = operation.create_completed_retry()

# Assign
self.operations[index] = updated_operation
Expand All @@ -337,21 +327,11 @@ def complete_callback_success(
if operation.status != OperationStatus.STARTED:
msg: str = f"Callback operation [{callback_id}] is not in STARTED state"
raise IllegalStateException(msg)

with self._state_lock:
self._token_sequence += 1
updated_callback_details = None
if operation.callback_details:
updated_callback_details = replace(
operation.callback_details,
result=result.decode() if result else None,
)

self.operations[index] = replace(
operation,
status=OperationStatus.SUCCEEDED,
self.operations[index] = operation.create_callback_result(
result=result.decode() if result else None,
end_timestamp=datetime.now(UTC),
callback_details=updated_callback_details,
)
return self.operations[index]

Expand All @@ -367,17 +347,9 @@ def complete_callback_failure(

with self._state_lock:
self._token_sequence += 1
updated_callback_details = None
if operation.callback_details:
updated_callback_details = replace(
operation.callback_details, error=error
)

self.operations[index] = replace(
operation,
status=OperationStatus.FAILED,
self.operations[index] = operation.create_callback_failure(
error=error,
end_timestamp=datetime.now(UTC),
callback_details=updated_callback_details,
)
return self.operations[index]

Expand All @@ -386,8 +358,7 @@ def _end_execution(self, status: OperationStatus) -> None:
execution_op: Operation = self.get_operation_execution_started()
if execution_op.operation_type == OperationType.EXECUTION:
with self._state_lock:
self.operations[0] = replace(
execution_op,
self.operations[0] = execution_op.create_execution_end(
status=status,
end_timestamp=datetime.now(UTC),
)
77 changes: 28 additions & 49 deletions src/aws_durable_execution_sdk_python_testing/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

import datetime
from dataclasses import dataclass, replace
from dataclasses import dataclass
from enum import Enum
from typing import Any

Expand Down Expand Up @@ -2547,26 +2547,13 @@ def events_to_operations(events: list[Event]) -> list[Operation]:
# Merge with previous operation if it exists
# Most fields are immutable, so they get preserved from previous events
if previous_operation:
operation = replace(
operation,
name=operation.name or previous_operation.name,
parent_id=operation.parent_id or previous_operation.parent_id,
sub_type=operation.sub_type or previous_operation.sub_type,
start_timestamp=previous_operation.start_timestamp,
end_timestamp=previous_operation.end_timestamp,
execution_details=previous_operation.execution_details,
context_details=previous_operation.context_details,
step_details=previous_operation.step_details,
wait_details=previous_operation.wait_details,
callback_details=previous_operation.callback_details,
chained_invoke_details=previous_operation.chained_invoke_details,
)
operation = operation.create_merged_from_previous(previous_operation)

# Set timestamps based on event configuration
if event_config.is_start_event:
operation = replace(operation, start_timestamp=event.event_timestamp)
operation = operation.create_with_start_timestamp(event.event_timestamp)
if event_config.is_end_event:
operation = replace(operation, end_timestamp=event.event_timestamp)
operation = operation.create_with_end_timestamp(event.event_timestamp)

# Add operation-specific details incrementally
# Each event type contributes only the fields it has
Expand All @@ -2577,11 +2564,10 @@ def events_to_operations(events: list[Event]) -> list[Operation]:
and event.execution_started_details
and event.execution_started_details.input
):
operation = replace(
operation,
execution_details=ExecutionDetails(
operation = operation.create_with_execution_details(
ExecutionDetails(
input_payload=event.execution_started_details.input.payload
),
)
)

# CALLBACK details - merge callback_id, result, and error from different events
Expand Down Expand Up @@ -2613,13 +2599,12 @@ def events_to_operations(events: list[Event]) -> list[Operation]:
):
error = event.callback_timed_out_details.error.payload

operation = replace(
operation,
callback_details=CallbackDetails(
operation = operation.create_with_callback_details(
CallbackDetails(
callback_id=callback_id,
result=result,
error=error,
),
)
)

# STEP details - only update if this event type has result data
Expand Down Expand Up @@ -2655,23 +2640,21 @@ def events_to_operations(events: list[Event]) -> list[Operation]:
seconds=event.step_failed_details.retry_details.next_attempt_delay_seconds
)

operation = replace(
operation,
step_details=StepDetails(
operation = operation.create_with_step_details(
StepDetails(
result=result_val,
error=error_val,
attempt=attempt,
next_attempt_timestamp=next_attempt_ts,
),
)
)

# WAIT details
if operation_type == OperationType.WAIT and event.wait_started_details:
operation = replace(
operation,
wait_details=WaitDetails(
operation = operation.create_with_wait_details(
WaitDetails(
scheduled_end_timestamp=event.wait_started_details.scheduled_end_timestamp
),
)
)

# CONTEXT details - only update if this event type has result data (matching TypeScript hasResult)
Expand All @@ -2680,20 +2663,18 @@ def events_to_operations(events: list[Event]) -> list[Operation]:
event.context_succeeded_details
and event.context_succeeded_details.result
):
operation = replace(
operation,
context_details=ContextDetails(
operation = operation.create_with_context_details(
ContextDetails(
result=event.context_succeeded_details.result.payload,
error=None,
),
)
)
elif event.context_failed_details and event.context_failed_details.error:
operation = replace(
operation,
context_details=ContextDetails(
operation = operation.create_with_context_details(
ContextDetails(
result=None,
error=event.context_failed_details.error.payload,
),
)
)

# CHAINED_INVOKE details - only update if this event type has result data (matching TypeScript hasResult)
Expand All @@ -2702,23 +2683,21 @@ def events_to_operations(events: list[Event]) -> list[Operation]:
event.chained_invoke_succeeded_details
and event.chained_invoke_succeeded_details.result
):
operation = replace(
operation,
chained_invoke_details=ChainedInvokeDetails(
operation = operation.create_with_chained_invoke_details(
ChainedInvokeDetails(
result=event.chained_invoke_succeeded_details.result.payload,
error=None,
),
)
)
elif (
event.chained_invoke_failed_details
and event.chained_invoke_failed_details.error
):
operation = replace(
operation,
chained_invoke_details=ChainedInvokeDetails(
operation = operation.create_with_chained_invoke_details(
ChainedInvokeDetails(
result=None,
error=event.chained_invoke_failed_details.error.payload,
),
)
)

# Store in map
Expand Down
Loading