From ead014326dba3afe62024a2968bad0d658f10241 Mon Sep 17 00:00:00 2001 From: Rares Polenciuc Date: Sat, 15 Nov 2025 16:41:59 +0000 Subject: [PATCH] refactor: replace dataclasses.replace with instance factories --- .../execution.py | 51 +++--------- .../model.py | 77 +++++++------------ 2 files changed, 39 insertions(+), 89 deletions(-) diff --git a/src/aws_durable_execution_sdk_python_testing/execution.py b/src/aws_durable_execution_sdk_python_testing/execution.py index b651bf1..56caa6b 100644 --- a/src/aws_durable_execution_sdk_python_testing/execution.py +++ b/src/aws_durable_execution_sdk_python_testing/execution.py @@ -1,7 +1,7 @@ from __future__ import annotations import json -from dataclasses import replace + from datetime import UTC, datetime from enum import Enum from threading import Lock @@ -13,12 +13,14 @@ InvocationStatus, ) from aws_durable_execution_sdk_python.lambda_service import ( + CallbackDetails, ErrorObject, ExecutionDetails, Operation, OperationStatus, OperationType, OperationUpdate, + StepDetails, ) from aws_durable_execution_sdk_python_testing.exceptions import ( @@ -289,10 +291,8 @@ def complete_wait(self, operation_id: str) -> Operation: with self._state_lock: self._token_sequence += 1 # Build and assign updated operation - self.operations[index] = replace( - operation, - status=OperationStatus.SUCCEEDED, - end_timestamp=datetime.now(UTC), + self.operations[index] = operation.create_succeeded( + end_timestamp=datetime.now(UTC) ) return self.operations[index] @@ -313,17 +313,7 @@ def complete_retry(self, operation_id: str) -> Operation: # Thread-safe increment sequence and operation update with self._state_lock: self._token_sequence += 1 - # Build updated step_details with cleared next_attempt_timestamp - new_step_details = None - if operation.step_details: - new_step_details = replace( - operation.step_details, next_attempt_timestamp=None - ) - - # Build updated operation - updated_operation = replace( - operation, status=OperationStatus.READY, step_details=new_step_details - ) + updated_operation = operation.create_completed_retry() # Assign self.operations[index] = updated_operation @@ -337,21 +327,11 @@ def complete_callback_success( if operation.status != OperationStatus.STARTED: msg: str = f"Callback operation [{callback_id}] is not in STARTED state" raise IllegalStateException(msg) - with self._state_lock: self._token_sequence += 1 - updated_callback_details = None - if operation.callback_details: - updated_callback_details = replace( - operation.callback_details, - result=result.decode() if result else None, - ) - - self.operations[index] = replace( - operation, - status=OperationStatus.SUCCEEDED, + self.operations[index] = operation.create_callback_result( + result=result.decode() if result else None, end_timestamp=datetime.now(UTC), - callback_details=updated_callback_details, ) return self.operations[index] @@ -367,17 +347,9 @@ def complete_callback_failure( with self._state_lock: self._token_sequence += 1 - updated_callback_details = None - if operation.callback_details: - updated_callback_details = replace( - operation.callback_details, error=error - ) - - self.operations[index] = replace( - operation, - status=OperationStatus.FAILED, + self.operations[index] = operation.create_callback_failure( + error=error, end_timestamp=datetime.now(UTC), - callback_details=updated_callback_details, ) return self.operations[index] @@ -386,8 +358,7 @@ def _end_execution(self, status: OperationStatus) -> None: execution_op: Operation = self.get_operation_execution_started() if execution_op.operation_type == OperationType.EXECUTION: with self._state_lock: - self.operations[0] = replace( - execution_op, + self.operations[0] = execution_op.create_execution_end( status=status, end_timestamp=datetime.now(UTC), ) diff --git a/src/aws_durable_execution_sdk_python_testing/model.py b/src/aws_durable_execution_sdk_python_testing/model.py index 27da2d8..a5c1f2c 100644 --- a/src/aws_durable_execution_sdk_python_testing/model.py +++ b/src/aws_durable_execution_sdk_python_testing/model.py @@ -3,7 +3,7 @@ from __future__ import annotations import datetime -from dataclasses import dataclass, replace +from dataclasses import dataclass from enum import Enum from typing import Any @@ -2547,26 +2547,13 @@ def events_to_operations(events: list[Event]) -> list[Operation]: # Merge with previous operation if it exists # Most fields are immutable, so they get preserved from previous events if previous_operation: - operation = replace( - operation, - name=operation.name or previous_operation.name, - parent_id=operation.parent_id or previous_operation.parent_id, - sub_type=operation.sub_type or previous_operation.sub_type, - start_timestamp=previous_operation.start_timestamp, - end_timestamp=previous_operation.end_timestamp, - execution_details=previous_operation.execution_details, - context_details=previous_operation.context_details, - step_details=previous_operation.step_details, - wait_details=previous_operation.wait_details, - callback_details=previous_operation.callback_details, - chained_invoke_details=previous_operation.chained_invoke_details, - ) + operation = operation.create_merged_from_previous(previous_operation) # Set timestamps based on event configuration if event_config.is_start_event: - operation = replace(operation, start_timestamp=event.event_timestamp) + operation = operation.create_with_start_timestamp(event.event_timestamp) if event_config.is_end_event: - operation = replace(operation, end_timestamp=event.event_timestamp) + operation = operation.create_with_end_timestamp(event.event_timestamp) # Add operation-specific details incrementally # Each event type contributes only the fields it has @@ -2577,11 +2564,10 @@ def events_to_operations(events: list[Event]) -> list[Operation]: and event.execution_started_details and event.execution_started_details.input ): - operation = replace( - operation, - execution_details=ExecutionDetails( + operation = operation.create_with_execution_details( + ExecutionDetails( input_payload=event.execution_started_details.input.payload - ), + ) ) # CALLBACK details - merge callback_id, result, and error from different events @@ -2613,13 +2599,12 @@ def events_to_operations(events: list[Event]) -> list[Operation]: ): error = event.callback_timed_out_details.error.payload - operation = replace( - operation, - callback_details=CallbackDetails( + operation = operation.create_with_callback_details( + CallbackDetails( callback_id=callback_id, result=result, error=error, - ), + ) ) # STEP details - only update if this event type has result data @@ -2655,23 +2640,21 @@ def events_to_operations(events: list[Event]) -> list[Operation]: seconds=event.step_failed_details.retry_details.next_attempt_delay_seconds ) - operation = replace( - operation, - step_details=StepDetails( + operation = operation.create_with_step_details( + StepDetails( result=result_val, error=error_val, attempt=attempt, next_attempt_timestamp=next_attempt_ts, - ), + ) ) # WAIT details if operation_type == OperationType.WAIT and event.wait_started_details: - operation = replace( - operation, - wait_details=WaitDetails( + operation = operation.create_with_wait_details( + WaitDetails( scheduled_end_timestamp=event.wait_started_details.scheduled_end_timestamp - ), + ) ) # CONTEXT details - only update if this event type has result data (matching TypeScript hasResult) @@ -2680,20 +2663,18 @@ def events_to_operations(events: list[Event]) -> list[Operation]: event.context_succeeded_details and event.context_succeeded_details.result ): - operation = replace( - operation, - context_details=ContextDetails( + operation = operation.create_with_context_details( + ContextDetails( result=event.context_succeeded_details.result.payload, error=None, - ), + ) ) elif event.context_failed_details and event.context_failed_details.error: - operation = replace( - operation, - context_details=ContextDetails( + operation = operation.create_with_context_details( + ContextDetails( result=None, error=event.context_failed_details.error.payload, - ), + ) ) # CHAINED_INVOKE details - only update if this event type has result data (matching TypeScript hasResult) @@ -2702,23 +2683,21 @@ def events_to_operations(events: list[Event]) -> list[Operation]: event.chained_invoke_succeeded_details and event.chained_invoke_succeeded_details.result ): - operation = replace( - operation, - chained_invoke_details=ChainedInvokeDetails( + operation = operation.create_with_chained_invoke_details( + ChainedInvokeDetails( result=event.chained_invoke_succeeded_details.result.payload, error=None, - ), + ) ) elif ( event.chained_invoke_failed_details and event.chained_invoke_failed_details.error ): - operation = replace( - operation, - chained_invoke_details=ChainedInvokeDetails( + operation = operation.create_with_chained_invoke_details( + ChainedInvokeDetails( result=None, error=event.chained_invoke_failed_details.error.payload, - ), + ) ) # Store in map