diff --git a/src/aws_durable_execution_sdk_python_testing/execution.py b/src/aws_durable_execution_sdk_python_testing/execution.py index b651bf1..8583efb 100644 --- a/src/aws_durable_execution_sdk_python_testing/execution.py +++ b/src/aws_durable_execution_sdk_python_testing/execution.py @@ -161,7 +161,7 @@ def start(self) -> None: operation_type=OperationType.EXECUTION, status=OperationStatus.STARTED, execution_details=ExecutionDetails( - input_payload=json.dumps(self.start_input.input) + input_payload=self.start_input.get_normalized_input() ), ) ) diff --git a/src/aws_durable_execution_sdk_python_testing/model.py b/src/aws_durable_execution_sdk_python_testing/model.py index 0b91153..a440ef9 100644 --- a/src/aws_durable_execution_sdk_python_testing/model.py +++ b/src/aws_durable_execution_sdk_python_testing/model.py @@ -6,6 +6,7 @@ from dataclasses import dataclass, replace from enum import Enum from typing import Any +import json from dateutil.tz import UTC @@ -166,6 +167,19 @@ def to_dict(self) -> dict[str, Any]: result["Input"] = self.input return result + def get_normalized_input(self): + """ + Normalize input string to be JSON deserializable. + Avoid double coding json input. + """ + # Try to parse once + try: + _ = json.loads(self.input) + return self.input + except (json.JSONDecodeError, TypeError): + # Not valid JSON, treat as plain string and encode it + return json.dumps(self.input) + @dataclass(frozen=True) class StartDurableExecutionOutput: diff --git a/tests/execution_test.py b/tests/execution_test.py index 602698b..0f599b6 100644 --- a/tests/execution_test.py +++ b/tests/execution_test.py @@ -99,7 +99,7 @@ def test_execution_start(mock_datetime): assert operation.start_timestamp == mock_now assert operation.operation_type == OperationType.EXECUTION assert operation.status == OperationStatus.STARTED - assert operation.execution_details.input_payload == '"{\\"key\\": \\"value\\"}"' + assert operation.execution_details.input_payload == '{"key": "value"}' def test_get_operation_execution_started(): diff --git a/tests/model_test.py b/tests/model_test.py index 5605c3f..447c340 100644 --- a/tests/model_test.py +++ b/tests/model_test.py @@ -77,21 +77,23 @@ TIMESTAMP_2023_01_01_00_02 = datetime.datetime(2023, 1, 1, 0, 2, 0, tzinfo=datetime.UTC) TIMESTAMP_2023_01_02_00_00 = datetime.datetime(2023, 1, 2, 0, 0, 0, tzinfo=datetime.UTC) +DEFAULT_START_DURABLE_EXECUTION_INPUT_DATA = { + "AccountId": "123456789012", + "FunctionName": "my-function", + "FunctionQualifier": "$LATEST", + "ExecutionName": "test-execution", + "ExecutionTimeoutSeconds": 300, + "ExecutionRetentionPeriodDays": 7, + "InvocationId": "invocation-123", + "TraceFields": {"key": "value"}, + "TenantId": "tenant-123", + "Input": "test-input", +} + def test_start_durable_execution_input_serialization(): """Test StartDurableExecutionInput from_dict/to_dict round-trip.""" - data = { - "AccountId": "123456789012", - "FunctionName": "my-function", - "FunctionQualifier": "$LATEST", - "ExecutionName": "test-execution", - "ExecutionTimeoutSeconds": 300, - "ExecutionRetentionPeriodDays": 7, - "InvocationId": "invocation-123", - "TraceFields": {"key": "value"}, - "TenantId": "tenant-123", - "Input": "test-input", - } + data = DEFAULT_START_DURABLE_EXECUTION_INPUT_DATA # Test from_dict input_obj = StartDurableExecutionInput.from_dict(data) @@ -115,6 +117,42 @@ def test_start_durable_execution_input_serialization(): assert round_trip == input_obj +def test_start_durable_execution_input_get_input_json_input(): + """Test StartDurableExecutionInput from_dict/to_dict round-trip.""" + data = DEFAULT_START_DURABLE_EXECUTION_INPUT_DATA + data["Input"] = '{"message": "hello"}' + + input_obj = StartDurableExecutionInput.from_dict(data) + assert '{"message": "hello"}' == input_obj.get_normalized_input() + + +def test_start_durable_execution_input_get_input_str_non_json_input(): + """Test StartDurableExecutionInput from_dict/to_dict round-trip.""" + data = DEFAULT_START_DURABLE_EXECUTION_INPUT_DATA + data["Input"] = "hello" + + input_obj = StartDurableExecutionInput.from_dict(data) + assert '"hello"' == input_obj.get_normalized_input() + + +def test_start_durable_execution_input_get_input_str_json_input(): + """Test StartDurableExecutionInput from_dict/to_dict round-trip.""" + data = DEFAULT_START_DURABLE_EXECUTION_INPUT_DATA + data["Input"] = '"hello"' + + input_obj = StartDurableExecutionInput.from_dict(data) + assert '"hello"' == input_obj.get_normalized_input() + + +def test_start_durable_execution_input_get_input_list_json_input(): + """Test StartDurableExecutionInput from_dict/to_dict round-trip.""" + data = DEFAULT_START_DURABLE_EXECUTION_INPUT_DATA + data["Input"] = "[1,2,3]" + + input_obj = StartDurableExecutionInput.from_dict(data) + assert "[1,2,3]" == input_obj.get_normalized_input() + + def test_start_durable_execution_input_minimal(): """Test StartDurableExecutionInput with only required fields.""" data = {