diff --git a/core_utils/serialization.py b/core_utils/serialization.py index 88c9a33..e7136ec 100644 --- a/core_utils/serialization.py +++ b/core_utils/serialization.py @@ -1,5 +1,6 @@ import traceback from enum import Enum +from numbers import Number from typing import ( Any, Iterable, @@ -213,6 +214,7 @@ def deserialize( return type_value[value] # type: ignore else: + # Check to see that the expected value is present & has the expected type, # iff the expected type is one of JSON's value types. if ( @@ -220,8 +222,17 @@ def deserialize( or (any(map(lambda t: t == type_value, (int, float, str, bool)))) and not isinstance(value, checking_type_value) ): + fail = False + # exception: encoded a number into a string; attempt trivial decoding to avoid failure + # JSON will encode numbers that are dictionary keys as strings + # so that it can make them compliant as objects. + if issubclass(checking_type_value, Number) and isinstance(value, str): + try: + value = type_value(value.strip()) + except ValueError: + fail = True # numeric check: some ints can be a float - if ( + elif ( float == type_value and isinstance(value, int) and int(float(value)) == value @@ -233,6 +244,9 @@ def deserialize( # but in general we just identified a value that # didn't deserialize to its expected type else: + fail = True + + if fail: raise FieldDeserializeFail( field_name="", expected_type=type_value, actual_value=value ) diff --git a/tests/test_custom_serialization.py b/tests/test_custom_serialization.py index 56565bf..46e1813 100644 --- a/tests/test_custom_serialization.py +++ b/tests/test_custom_serialization.py @@ -1,6 +1,6 @@ import json from dataclasses import dataclass -from typing import Tuple, NamedTuple, Mapping, Sequence +from typing import Tuple, NamedTuple, Mapping, Sequence, Dict, List import numpy as np import torch @@ -149,3 +149,23 @@ def check(*, actual, expected): _roundtrip(mnt, custom_serialize, custom_deserialize, check) _roundtrip(mdc, custom_serialize, 
custom_deserialize, check) + + +def test_nested_array_dict_int_keys(custom_serialize, custom_deserialize): + N = 4 + M = 3 + + def check(*, actual, expected): + assert isinstance(actual, type(expected)) + assert len(actual) == N + for xs in actual: + assert len(xs) == M + for i, arr in (kv for d in xs for kv in d.items()): + assert isinstance(i, int) + _check_array_like(actual=arr, expected=np.ones(i)) + + m: List[List[Dict[int, np.ndarray]]] = [ + [{i: np.ones(i)} for i in range(M)] for _ in range(N) + ] + + _roundtrip(m, custom_serialize, custom_deserialize, check) diff --git a/tests/test_serialization.py b/tests/test_serialization.py index dfc0880..9ca297b 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -3,7 +3,18 @@ from collections import namedtuple from dataclasses import dataclass from enum import Enum, IntEnum -from typing import NamedTuple, Sequence, Optional, Mapping, Set, Tuple, Union, Any +from typing import ( + NamedTuple, + Sequence, + Optional, + Mapping, + Set, + Tuple, + Union, + Any, + List, + Dict, +) from pytest import raises, fixture import yaml @@ -580,3 +591,21 @@ def test_serialized_nested_defaults_advanced(): s = serialize(nested) assert nested == deserialize(NestedDefaultsMixed, s) + + +def test_serialize_dict_with_numeric_keys(): + d1: Dict[int, List[str]] = { + i: [x for x in "hello world! how are you today?"] for i in range(10) + } + s1 = serialize(d1) + assert deserialize(Dict[int, List[str]], s1) == d1 + + j1 = json.dumps(s1) + assert deserialize(Dict[int, List[str]], json.loads(j1)) == d1 + + d2: Dict[float, List[str]] = {float(i): xs for i, xs in d1.items()} + s2 = serialize(d2) + assert deserialize(Dict[float, List[str]], s2) == d2 + + j2 = json.dumps(s2) + assert deserialize(Dict[float, List[str]], json.loads(j2)) == d2