Skip to content

Commit a48197c

Browse files
committed
Update test code coverage
1 parent ad8dec6 commit a48197c

File tree

13 files changed

+301
-227
lines changed

13 files changed

+301
-227
lines changed

src/sqlalchemyseed/class_registry.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from inspect import isclass
2828
from sqlalchemy import inspect
2929
from . import errors
30+
from . import util
3031

3132

3233
def parse_class_path(class_path: str):
@@ -41,14 +42,10 @@ def parse_class_path(class_path: str):
4142
except AttributeError:
4243
raise errors.NotInModuleError(f"{class_name} is not found in module {module_name}.")
4344

44-
try:
45-
if isclass(class_) and inspect(class_):
46-
return class_
47-
48-
raise errors.NotClassError("'{}' is not a class".format(class_name))
49-
except NoInspectionAvailable:
50-
raise errors.UnsupportedClassError(
51-
"'{}' is an unsupported class".format(class_name))
45+
if util.is_supported_class(class_):
46+
return class_
47+
else:
48+
raise errors.UnsupportedClassError("'{}' is an unsupported class".format(class_name))
5249

5350

5451
class ClassRegistry:

src/sqlalchemyseed/errors.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,17 @@ class InvalidKeyError(Exception):
2727
"""Raised when an invalid key is invoked"""
2828
pass
2929

30+
3031
class ParseError(Exception):
3132
"""Raised when parsing string fails"""
3233
pass
3334

34-
class NotClassError(Exception):
35-
"""Raised when a value is not a class"""
36-
pass
3735

3836
class UnsupportedClassError(Exception):
3937
"""Raised when an unsupported class is invoked"""
4038
pass
4139

40+
4241
class NotInModuleError(Exception):
4342
"""Raised when a value is not found in module"""
44-
pass
43+
pass

src/sqlalchemyseed/loader.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,9 @@
2626
import json
2727
import sys
2828

29-
from . import validator
30-
3129
try:
3230
import yaml
33-
except ModuleNotFoundError: # pragma: no cover
31+
except ModuleNotFoundError: # pragma: no cover
3432
pass
3533

3634

@@ -41,8 +39,6 @@ def load_entities_from_json(json_filepath):
4139
except FileNotFoundError as error:
4240
raise FileNotFoundError(error)
4341

44-
validator.SchemaValidator.validate(entities)
45-
4642
return entities
4743

4844

@@ -59,8 +55,6 @@ def load_entities_from_yaml(yaml_filepath):
5955
except FileNotFoundError as error:
6056
raise FileNotFoundError(error)
6157

62-
validator.SchemaValidator.validate(entities)
63-
6458
return entities
6559

6660

@@ -80,6 +74,4 @@ def load_entities_from_csv(csv_filepath: str, model) -> dict:
8074

8175
entities = {'model': model_name, 'data': source_data}
8276

83-
validator.SchemaValidator.validate(entities)
84-
8577
return entities

src/sqlalchemyseed/seeder.py

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -88,23 +88,16 @@ def is_relationship_attribute(self):
8888

8989
@property
9090
def referenced_class(self):
91-
# if self.is_column_attribute():
92-
# return
9391
if self.is_relationship_attribute():
9492
return self.class_attribute.mapper.class_
9593

96-
if self.is_column_attribute():
97-
table_name = get_foreign_key_column(self.class_attribute).table.name
98-
return next(
99-
(
100-
mapper.class_
101-
for mapper in object_mapper(self.instance).registry.mappers
102-
if mapper.class_.__tablename__ == table_name
103-
),
104-
errors.ClassNotFoundError(
105-
"A class with table name '{}' is not found in the mappers".format(table_name)
106-
)
107-
)
94+
# if self.is_column_attribute():
95+
table_name = get_foreign_key_column(self.class_attribute).table.name
96+
97+
return next(filter(
98+
lambda mapper: mapper.class_.__tablename__ == table_name,
99+
object_mapper(self.instance).registry.mappers
100+
)).class_
108101

109102

110103
def get_foreign_key_column(attr, idx=0) -> schema.Column:
@@ -125,7 +118,7 @@ def set_parent_attr_value(instance, parent: Entity):
125118
else:
126119
parent.instance_attribute = instance
127120

128-
if parent.is_column_attribute():
121+
else: # if parent.is_column_attribute():
129122
parent.instance_attribute = instance
130123

131124

@@ -150,8 +143,7 @@ def get_model_class(self, entity, parent: Entity):
150143
return parent.referenced_class
151144

152145
def seed(self, entities, add_to_session=True):
153-
validator.SchemaValidator.validate(
154-
entities, ref_prefix=self.ref_prefix, source_keys=[validator.Key.data()])
146+
validator.validate(entities=entities, ref_prefix=self.ref_prefix)
155147

156148
self._instances.clear()
157149
self._class_registry.clear()
@@ -238,8 +230,7 @@ def get_model_class(self, entity, parent: Entity):
238230
return parent.referenced_class
239231

240232
def seed(self, entities):
241-
validator.SchemaValidator.validate(
242-
entities, ref_prefix=self.ref_prefix)
233+
validator.hybrid_validate(entities=entities, ref_prefix=self.ref_prefix)
243234

244235
self._instances.clear()
245236
self._class_registry.clear()
@@ -256,17 +247,17 @@ def _pre_seed(self, entity, parent=None):
256247
def _seed(self, entity, parent):
257248
class_ = self.get_model_class(entity, parent)
258249

259-
source_key: validator.Key = next(
260-
(sk for sk in self.__source_keys if sk in entity),
261-
None
250+
source_key = next(
251+
filter(lambda sk: sk in entity, self.__source_keys)
262252
)
263253

264254
source_data = entity[source_key]
265255

266256
# source_data is list
267257
if isinstance(source_data, list):
268258
for kwargs in source_data:
269-
instance = self._setup_instance(class_, kwargs, source_key, parent)
259+
instance = self._setup_instance(
260+
class_, kwargs, source_key, parent)
270261
self._seed_children(instance, kwargs)
271262
return
272263

@@ -295,11 +286,13 @@ def _setup_instance(self, class_, kwargs: dict, key, parent):
295286

296287
def _setup_data_instance(self, class_, filtered_kwargs, parent: Entity):
297288
if parent is not None and parent.is_column_attribute():
298-
raise errors.InvalidKeyError("'data' key is invalid for a column attribute.")
289+
raise errors.InvalidKeyError(
290+
"'data' key is invalid for a column attribute.")
299291

300292
instance = class_(**filtered_kwargs)
301-
self.session.add(instance)
293+
302294
if parent is None:
295+
self.session.add(instance)
303296
self._instances.append(instance)
304297

305298
return instance

src/sqlalchemyseed/util.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
from inspect import isclass
2+
3+
from sqlalchemy import inspect
4+
5+
16
def iter_ref_kwargs(kwargs: dict, ref_prefix: str):
27
"""Iterate kwargs with name prefix or references"""
38
for attr_name, value in kwargs.items():
@@ -11,3 +16,7 @@ def iter_non_ref_kwargs(kwargs: dict, ref_prefix: str):
1116
for attr_name, value in kwargs.items():
1217
if not attr_name.startswith(ref_prefix):
1318
yield attr_name, value
19+
20+
21+
def is_supported_class(class_):
22+
return True if isclass(class_) and inspect(class_, raiseerr=False) else False

src/sqlalchemyseed/validator.py

Lines changed: 40 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
SOFTWARE.
2323
"""
2424

25+
import abc
2526
from . import errors, util
2627

2728

@@ -83,17 +84,20 @@ def check_source_key(entity: dict, source_keys: list) -> Key:
8384

8485
# check if current keys has at least, data or filter key
8586
if source_key is None:
86-
raise errors.MissingKeyError(f"Missing {', '.join(map(str, source_keys))} key(s).")
87+
raise errors.MissingKeyError(
88+
f"Missing {', '.join(map(str, source_keys))} key(s).")
8789

8890
return source_key
8991

9092

9193
def check_source_data(source_data, source_key: Key):
9294
if not isinstance(source_data, dict) and not isinstance(source_data, list):
93-
raise errors.InvalidTypeError(f"Invalid type_, {str(source_key)} should be either 'dict' or 'list'.")
95+
raise errors.InvalidTypeError(
96+
f"Invalid type_, {str(source_key)} should be either 'dict' or 'list'.")
9497

9598
if isinstance(source_data, list) and len(source_data) == 0:
96-
raise errors.EmptyDataError("Empty list, 'data' or 'filter' list should not be empty.")
99+
raise errors.EmptyDataError(
100+
"Empty list, 'data' or 'filter' list should not be empty.")
97101

98102

99103
def check_data_type(item, source_key: Key):
@@ -102,37 +106,38 @@ def check_data_type(item, source_key: Key):
102106
f"Invalid type_, '{source_key.name}' should be '{source_key.type_}'")
103107

104108

105-
class SchemaValidator:
106-
_source_keys = None
107-
_ref_prefix = None
109+
class SchemaValidator(abc.ABC):
110+
111+
def __init__(self, source_keys, ref_prefix):
112+
self._source_keys = source_keys
113+
self._ref_prefix = ref_prefix
108114

109115
@classmethod
110-
def validate(cls, entities, ref_prefix='!', source_keys=None):
111-
if source_keys is None:
112-
cls._source_keys = [Key.data(), Key.filter()]
113-
cls._ref_prefix = ref_prefix
116+
def validate(cls, entities, source_keys, ref_prefix='!'):
117+
self = cls(source_keys, ref_prefix)
118+
self._source_keys = source_keys
119+
self._ref_prefix = ref_prefix
114120

115-
cls._pre_validate(entities, entity_is_parent=True)
121+
self._pre_validate(entities, entity_is_parent=True)
116122

117-
@classmethod
118-
def _pre_validate(cls, entities: dict, entity_is_parent=True):
123+
def _pre_validate(self, entities: dict, entity_is_parent=True):
119124
if not isinstance(entities, dict) and not isinstance(entities, list):
120-
raise errors.InvalidTypeError("Invalid type, should be list or dict")
125+
raise errors.InvalidTypeError(
126+
"Invalid type, should be list or dict")
121127
if len(entities) == 0:
122128
return
123129
if isinstance(entities, dict):
124-
return cls._validate(entities, entity_is_parent)
130+
return self._validate(entities, entity_is_parent)
125131
# iterate list
126132
for entity in entities:
127-
cls._pre_validate(entity, entity_is_parent)
133+
self._pre_validate(entity, entity_is_parent)
128134

129-
@classmethod
130-
def _validate(cls, entity: dict, entity_is_parent=True):
135+
def _validate(self, entity: dict, entity_is_parent=True):
131136
check_max_length(entity)
132137
check_model_key(entity, entity_is_parent)
133138

134139
# get source key, either data or filter key
135-
source_key = check_source_key(entity, cls._source_keys)
140+
source_key = check_source_key(entity, self._source_keys)
136141
source_data = entity[source_key]
137142

138143
check_source_data(source_data, source_key)
@@ -141,13 +146,23 @@ def _validate(cls, entity: dict, entity_is_parent=True):
141146
for item in source_data:
142147
check_data_type(item, source_key)
143148
# check if item is a relationship attribute
144-
cls.check_attributes(item)
149+
self.check_attributes(item)
145150
else:
146151
# source_data is dict
147152
# check if item is a relationship attribute
148-
cls.check_attributes(source_data)
153+
self.check_attributes(source_data)
149154

150-
@classmethod
151-
def check_attributes(cls, source_data: dict):
152-
for _, value in util.iter_ref_kwargs(source_data, cls._ref_prefix):
153-
cls._pre_validate(value, entity_is_parent=False)
155+
def check_attributes(self, source_data: dict):
156+
for _, value in util.iter_ref_kwargs(source_data, self._ref_prefix):
157+
self._pre_validate(value, entity_is_parent=False)
158+
159+
160+
def validate(entities, ref_prefix='!'):
161+
SchemaValidator.validate(
162+
entities, ref_prefix=ref_prefix, source_keys=[Key.data()])
163+
164+
165+
def hybrid_validate(entities, ref_prefix='!'):
166+
SchemaValidator.validate(entities,
167+
ref_prefix=ref_prefix,
168+
source_keys=[Key.data(), Key.filter()])

0 commit comments

Comments
 (0)