Skip to content

Commit 44a2de4

Browse files
authored
bug(medcat): CU-869b9dx49 Fix pydantic model serialisation (#242)
* CU-869b9dx49: Add test for config serialisation * CU-869b9dx49: Fix issue with pydantic model serialisation when there are extra attributes * CU-869b9dx49: Fix typing issue
1 parent 6923838 commit 44a2de4

File tree

2 files changed

+53
-0
lines changed

2 files changed

+53
-0
lines changed

medcat-v2/medcat/storage/serialisables.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from typing import Any, Union, Protocol, runtime_checkable, Iterable
22
from enum import Enum, auto
33

4+
from pydantic import BaseModel
5+
46

57
class SerialisingStrategy(Enum):
68
"""Describes the strategy for serialising."""
@@ -50,6 +52,10 @@ def _iter_obj_items(self, obj: 'Serialisable'
5052
# ignore privates
5153
continue
5254
yield attr_name, attr
55+
# deal with extras in pydantic models
56+
if isinstance(obj, BaseModel) and obj.__pydantic_extra__:
57+
for attr_name, attr in obj.__pydantic_extra__.items():
58+
yield attr_name, attr
5359

5460
def _iter_obj_values(self, obj: 'Serialisable') -> Iterable[Any]:
5561
for _, val in self._iter_obj_items(obj):
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from medcat.config import Ner
2+
from medcat.storage.serialisers import serialise, deserialise
3+
from medcat.storage.serialisers import AvailableSerialisers
4+
5+
import os
6+
import tempfile
7+
8+
import unittest
9+
10+
11+
class SaveWithExtraTests(unittest.TestCase):
12+
EXTRA_KEY = "some_extra_key"
13+
EXTRA_VAL = {"val": 1, "f": ''}
14+
15+
def setUp(self):
16+
self.base = Ner()
17+
self.base.some_extra_key = self.EXTRA_VAL
18+
self.temp_dir = tempfile.TemporaryDirectory()
19+
20+
def do_save(self) -> tuple[str, str]:
21+
"""Do the save and return folder path and raw dict path.
22+
23+
Returns:
24+
tuple[str, str]: The folder and the path to raw dict.
25+
"""
26+
serialise(AvailableSerialisers.dill, self.base, self.temp_dir.name)
27+
return self.temp_dir.name, os.path.join(self.temp_dir.name,
28+
"raw_dict.dat")
29+
30+
def tearDown(self):
31+
self.temp_dir.cleanup()
32+
33+
def test_value_is_set(self):
34+
self.assertTrue(hasattr(self.base, self.EXTRA_KEY))
35+
self.assertIs(getattr(self.base, self.EXTRA_KEY), self.EXTRA_VAL)
36+
37+
def test_can_save_and_load_obj(self):
38+
folder, _ = self.do_save()
39+
other = deserialise(folder)
40+
self.assertIsInstance(other, type(self.base))
41+
self.assertEqual(other, self.base)
42+
43+
def test_loaded_has_extra_key(self):
44+
folder, _ = self.do_save()
45+
other = deserialise(folder)
46+
self.assertTrue(hasattr(other, self.EXTRA_KEY))
47+
self.assertEqual(getattr(other, self.EXTRA_KEY), self.EXTRA_VAL)

0 commit comments

Comments
 (0)