Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion core_utils/serialization.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import traceback
from enum import Enum
from numbers import Number
from typing import (
Any,
Iterable,
Expand Down Expand Up @@ -213,15 +214,25 @@ def deserialize(
return type_value[value] # type: ignore

else:

# Check to see that the expected value is present & has the expected type,
# iff the expected type is one of JSON's value types.
if (
(value is None and not _is_optional(checking_type_value))
or (any(map(lambda t: t == type_value, (int, float, str, bool))))
and not isinstance(value, checking_type_value)
):
fail = False
# exception: encoded a number into a string; attempt trivial decoding to avoid failure
# JSON will encode numbers that are dictionary keys as strings
# so that it can make them compliant as objects.
if issubclass(checking_type_value, Number) and isinstance(value, str):
try:
value = type_value(value.strip())
except ValueError:
fail = True
# numeric check: some ints can be a float
if (
elif (
float == type_value
and isinstance(value, int)
and int(float(value)) == value
Expand All @@ -233,6 +244,9 @@ def deserialize(
# but in general we just identified a value that
# didn't deserialize to its expected type
else:
fail = True

if fail:
raise FieldDeserializeFail(
field_name="", expected_type=type_value, actual_value=value
)
Expand Down
22 changes: 21 additions & 1 deletion tests/test_custom_serialization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from dataclasses import dataclass
from typing import Tuple, NamedTuple, Mapping, Sequence
from typing import Tuple, NamedTuple, Mapping, Sequence, Dict, List

import numpy as np
import torch
Expand Down Expand Up @@ -149,3 +149,23 @@ def check(*, actual, expected):

_roundtrip(mnt, custom_serialize, custom_deserialize, check)
_roundtrip(mdc, custom_serialize, custom_deserialize, check)


def test_nested_array_dict_int_keys(custom_serialize, custom_deserialize):
N = 4
M = 3

def check(*, actual, expected):
assert isinstance(actual, type(expected))
assert len(actual) == N
for xs in actual:
assert len(xs) == M
for i, arr in xs.items():
assert isinstance(i, int)
_check_array_like(actual=arr, expected=np.ones(i))

m: List[List[Dict[int, np.ndarray]]] = [
[{i: np.ones(i)} for i in range(M)] for _ in range(N)
]

_roundtrip(m, custom_serialize, custom_deserialize, check)
31 changes: 30 additions & 1 deletion tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,18 @@
from collections import namedtuple
from dataclasses import dataclass
from enum import Enum, IntEnum
from typing import NamedTuple, Sequence, Optional, Mapping, Set, Tuple, Union, Any
from typing import (
NamedTuple,
Sequence,
Optional,
Mapping,
Set,
Tuple,
Union,
Any,
List,
Dict,
)

from pytest import raises, fixture
import yaml
Expand Down Expand Up @@ -580,3 +591,21 @@ def test_serialized_nested_defaults_advanced():

s = serialize(nested)
assert nested == deserialize(NestedDefaultsMixed, s)


def test_serialize_dict_with_numeric_keys():
d1: Dict[int, List[str]] = {
i: [x for x in "hello world! how are you today?"] for i in range(10)
}
s1 = serialize(d1)
assert deserialize(Dict[int, List[str]], s1) == d1

j1 = json.dumps(s1)
assert deserialize(Dict[int, List[str]], json.loads(j1)) == d1

d2: Dict[float, List[str]] = {float(i): xs for i, xs in d1.items()}
s2 = serialize(d2)
assert deserialize(Dict[int, List[str]], s2) == d2

j2 = json.dumps(s2)
assert deserialize(Dict[int, List[str]], json.loads(j2)) == d2