Skip to content

Commit 41902c1

Browse files
committed
feat: set ChainedInvoke to default to json serdes
ChainedInvoke now defaults to JSON serializer for payload and result. Make the JSON and Extended types serdes singletons public. Minor linting warning fix for ops script.
1 parent 0761a6a commit 41902c1

File tree

5 files changed

+75
-15
lines changed

5 files changed

+75
-15
lines changed

ops/__tests__/test_parse_sdk_branch.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,12 @@ def test():
7373

7474
for input_text, expected in test_cases:
7575
result = parse_sdk_branch(input_text)
76-
if result != expected:
77-
return False
78-
79-
return True
76+
# Assert is expected in test functions
77+
assert result == expected, ( # noqa: S101
78+
f"Expected '{expected}' but got '{result}' for input: {input_text[:50]}..."
79+
)
8080

8181

8282
if __name__ == "__main__":
83-
success = test_parse_sdk_branch()
84-
sys.exit(0 if success else 1)
83+
test_parse_sdk_branch()
84+
sys.exit(0)

src/aws_durable_execution_sdk_python/config.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -392,10 +392,12 @@ class InvokeConfig(Generic[P, R]):
392392
from blocking execution indefinitely.
393393
394394
serdes_payload: Custom serialization/deserialization for the payload
395-
sent to the invoked function. If None, uses default JSON serialization.
395+
sent to the invoked function. Defaults to DEFAULT_JSON_SERDES when
396+
not set.
396397
397398
serdes_result: Custom serialization/deserialization for the result
398-
returned from the invoked function. If None, uses default JSON serialization.
399+
returned from the invoked function. Defaults to DEFAULT_JSON_SERDES when
400+
not set.
399401
400402
tenant_id: Optional tenant identifier for multi-tenant isolation.
401403
If provided, the invocation will be scoped to this tenant.

src/aws_durable_execution_sdk_python/operation/invoke.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@
1111
ChainedInvokeOptions,
1212
OperationUpdate,
1313
)
14-
from aws_durable_execution_sdk_python.serdes import deserialize, serialize
14+
from aws_durable_execution_sdk_python.serdes import (
15+
DEFAULT_JSON_SERDES,
16+
deserialize,
17+
serialize,
18+
)
1519
from aws_durable_execution_sdk_python.suspend import suspend_with_optional_resume_delay
1620

1721
if TYPE_CHECKING:
@@ -53,7 +57,7 @@ def invoke_handler(
5357
and checkpointed_result.operation.chained_invoke_details.result
5458
):
5559
return deserialize(
56-
serdes=config.serdes_result,
60+
serdes=config.serdes_result or DEFAULT_JSON_SERDES,
5761
data=checkpointed_result.operation.chained_invoke_details.result,
5862
operation_id=operation_identifier.operation_id,
5963
durable_execution_arn=state.durable_execution_arn,
@@ -78,7 +82,7 @@ def invoke_handler(
7882
suspend_with_optional_resume_delay(msg, config.timeout_seconds)
7983

8084
serialized_payload: str = serialize(
81-
serdes=config.serdes_payload,
85+
serdes=config.serdes_payload or DEFAULT_JSON_SERDES,
8286
value=payload,
8387
operation_id=operation_identifier.operation_id,
8488
durable_execution_arn=state.durable_execution_arn,

src/aws_durable_execution_sdk_python/serdes.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -441,8 +441,8 @@ def _to_json_serializable(self, obj: Any) -> Any:
441441
return obj
442442

443443

444-
_DEFAULT_JSON_SERDES: SerDes[Any] = JsonSerDes()
445-
_EXTENDED_TYPES_SERDES: SerDes[Any] = ExtendedTypeSerDes()
444+
DEFAULT_JSON_SERDES: SerDes[Any] = JsonSerDes()
445+
EXTENDED_TYPES_SERDES: SerDes[Any] = ExtendedTypeSerDes()
446446

447447

448448
def serialize(
@@ -463,7 +463,7 @@ def serialize(
463463
FatalError: If serialization fails
464464
"""
465465
serdes_context: SerDesContext = SerDesContext(operation_id, durable_execution_arn)
466-
active_serdes: SerDes[T] = serdes or _EXTENDED_TYPES_SERDES
466+
active_serdes: SerDes[T] = serdes or EXTENDED_TYPES_SERDES
467467
try:
468468
return active_serdes.serialize(value, serdes_context)
469469
except Exception as e:
@@ -493,7 +493,7 @@ def deserialize(
493493
FatalError: If deserialization fails
494494
"""
495495
serdes_context: SerDesContext = SerDesContext(operation_id, durable_execution_arn)
496-
active_serdes: SerDes[T] = serdes or _EXTENDED_TYPES_SERDES
496+
active_serdes: SerDes[T] = serdes or EXTENDED_TYPES_SERDES
497497
try:
498498
return active_serdes.deserialize(data, serdes_context)
499499
except Exception as e:

tests/operation/invoke_test.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,3 +612,57 @@ def test_invoke_handler_default_config_no_tenant_id():
612612
chained_invoke_options = operation_update.to_dict()["ChainedInvokeOptions"]
613613
assert chained_invoke_options["FunctionName"] == "test_function"
614614
assert "TenantId" not in chained_invoke_options
615+
616+
617+
def test_invoke_handler_defaults_to_json_serdes():
618+
"""Test invoke_handler uses DEFAULT_JSON_SERDES when config has no serdes."""
619+
mock_state = Mock(spec=ExecutionState)
620+
mock_state.durable_execution_arn = "test_arn"
621+
mock_state.get_checkpoint_result.return_value = (
622+
CheckpointedResult.create_not_found()
623+
)
624+
625+
config = InvokeConfig[dict, dict](serdes_payload=None, serdes_result=None)
626+
payload = {"key": "value", "number": 42}
627+
628+
with pytest.raises(SuspendExecution):
629+
invoke_handler(
630+
function_name="test_function",
631+
payload=payload,
632+
state=mock_state,
633+
operation_identifier=OperationIdentifier("invoke_json", None, None),
634+
config=config,
635+
)
636+
637+
# Verify JSON serialization was used (not extended types)
638+
operation_update = mock_state.create_checkpoint.call_args[1]["operation_update"]
639+
assert operation_update.payload == json.dumps(payload)
640+
641+
642+
def test_invoke_handler_result_defaults_to_json_serdes():
643+
"""Test invoke_handler uses DEFAULT_JSON_SERDES for result deserialization."""
644+
mock_state = Mock(spec=ExecutionState)
645+
mock_state.durable_execution_arn = "test_arn"
646+
647+
result_data = {"key": "value", "number": 42}
648+
operation = Operation(
649+
operation_id="invoke_result_json",
650+
operation_type=OperationType.CHAINED_INVOKE,
651+
status=OperationStatus.SUCCEEDED,
652+
chained_invoke_details=ChainedInvokeDetails(result=json.dumps(result_data)),
653+
)
654+
mock_result = CheckpointedResult.create_from_operation(operation)
655+
mock_state.get_checkpoint_result.return_value = mock_result
656+
657+
config = InvokeConfig[dict, dict](serdes_payload=None, serdes_result=None)
658+
659+
result = invoke_handler(
660+
function_name="test_function",
661+
payload={"input": "data"},
662+
state=mock_state,
663+
operation_identifier=OperationIdentifier("invoke_result_json", None, None),
664+
config=config,
665+
)
666+
667+
# Verify JSON deserialization was used (not extended types)
668+
assert result == result_data

0 commit comments

Comments
 (0)