Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 85 additions & 75 deletions sqlalchemy_pydantic_orm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


The ORMBaseSchema is an extension of the Pydantic's BaseModel. It can use the
fields defined in it's own schema to create a SQLAlchemy model, it can do that
fields defined in its own schema to create a SQLAlchemy model, it can do that
by using a mandatory predefined link to a corresponding SQLAlchemy model.

References:
Expand Down Expand Up @@ -35,13 +35,13 @@ def __init__(self, **data: Any):
"""The init is used for validation and throwing errors where needed.

Pydantic catches all ValueError's in initialization, and then outputs
the error message in a easy to read format with the specific class
the error message in an easy-to-read format with the specific class
name displayed.
Every error is given the "sqlalchemy-pydantic-orm" identifier to
distinguish between Pydantic's or SQLAlchemy's own errors
and those of this package.

For performance its better to execute the `super().__init__()` as late
For performance, it's better to execute the `super().__init__()` as late
as possible, only the _orm_model check requires it to work properly.

Args:
Expand Down Expand Up @@ -88,21 +88,23 @@ def _orm_model(self) -> Type[DeclarativeMeta]:

This variable/property has a leading underscore and can only be
assigned as PrivateAttr (Pydantic). This is because a Pydantic schema
iterates over it's own fields and would otherwise cause problems when
iterates over its own fields and would otherwise cause problems when
encountering this variable/property.

Returns:
A SQLAlchemy model (indirectly) inherited from DeclarativeMeta
"""
pass

def orm_create(self, **extra_fields: Any) -> DeclarativeMeta:
def orm_create(self, db: Session, **extra_fields: Any) -> DeclarativeMeta:
"""Method to convert a (nested) pydantic schema to a SQLAlchemy model.

Using the validated fields in this class, together with the defined
_orm_model, this recursive methods creates a (nested) SQLAlchemy model.

Args:
db (Session):
Database session used for `.add()` and `.delete()`.
extra_fields (Any):
Extra fields (keyword arguments) not defined in the pydantic
schema used by the top level ORM model. The fields in the
Expand All @@ -122,38 +124,41 @@ def orm_create(self, **extra_fields: Any) -> DeclarativeMeta:
When a list is not fully consisted of other ORM schemas.
"""
current_level_fields = {}
for field in self.__fields_set__:
field_name = self.__fields__[field].alias
value = getattr(self, field)
if isinstance(value, ORMBaseSchema): # One-to-one
current_level_fields[field_name] = value.orm_create()

elif isinstance(value, SUPPORTED_ITERABLES): # One-to-many
models = []
for schema in value:
if not isinstance(schema, ORMBaseSchema):
raise TypeError(
"Lists should only contain other schemas "
f"inherited from '{ORMBaseSchema.__name__}' "
"(sqlalchemy-pydantic-orm)"
)
models.append(schema.orm_create())
current_level_fields[field_name] = models

else: # value without relation
current_level_fields[field_name] = value
for key, field in self.__fields__.items():
field_name = field.alias
value = getattr(self, key)
if value is not None:
# nullable value not provided must be passed.
if isinstance(value, ORMBaseSchema): # One-to-one
current_level_fields[field_name] = value.to_orm(db=db)
# current_level_fields[field_name] = value.orm_create()
elif isinstance(value, SUPPORTED_ITERABLES): # One-to-many
models = []
for schema in value:
if not isinstance(schema, ORMBaseSchema):
raise TypeError(
"Lists should only contain other schemas "
f"inherited from '{ORMBaseSchema.__name__}' "
"(sqlalchemy-pydantic-orm)"
)
# models.append(schema.orm_create())
models.append(schema.to_orm(db=db, **extra_fields))
current_level_fields[field_name] = models

else: # value without relation
current_level_fields[field_name] = value

return self._orm_model(**extra_fields, **current_level_fields)

def orm_update(self, db: Session, db_model: DeclarativeMeta) -> None:
"""Method to update a (nested) orm structure.

This method recursively updates an orm model with it's relationships.
This method recursively updates an orm model with its relationships.

In one-to-many relationships, each provided item without an id gets
added as new item with the `orm_create()` method. When a valid id is
provided it updates the item with the `orm_update()` method. It also
keep track of the parsed database items, and afterwards deletes any
keeps track of the parsed database items, and afterwards deletes any
unparsed item.

Args:
Expand All @@ -178,50 +183,51 @@ def orm_update(self, db: Session, db_model: DeclarativeMeta) -> None:
f"defined _orm_model '{self._orm_model.__name__}' "
"(sqlalchemy-pydantic-orm)"
)
for field in self.__fields_set__:
field_name = self.__fields__[field].alias
for key, field in self.__fields__.items():
field_name = field.alias
db_value = getattr(db_model, field_name)
update_value = getattr(self, field)
if isinstance(update_value, ORMBaseSchema): # One-to-one
if db_value:
update_value.orm_update(db, db_value)
else:
setattr(db_model, field_name, update_value.orm_create())

elif isinstance(update_value, SUPPORTED_ITERABLES): # One-to-many
parsed_items = set()
for schema in update_value:
if not isinstance(schema, ORMBaseSchema):
raise TypeError(
"Lists should only contain other schemas "
f"inherited from '{ORMBaseSchema.__name__}' "
"(sqlalchemy-pydantic-orm)"
)
if item_id := getattr(schema, "id", None):
try:
db_item = next(
item for item in db_value if item.id == item_id
)
except StopIteration:
raise ValueError(
f"Provided id '{item_id}' "
f"for field '{field_name}' "
"can't be found in the database "
"(sqlalchemy-pydantic-orm)"
) from None # removes unnecessary traceback

schema.orm_update(db, db_item)
parsed_items.add(db_item)
update_value = getattr(self, key)
if update_value is not None:
if isinstance(update_value, ORMBaseSchema): # One-to-one
if db_value:
update_value.orm_update(db, db_value)
else:
new_item = schema.orm_create()
parsed_items.add(new_item)
db_value.append(new_item)

for db_item in db_value:
if db_item not in parsed_items:
db.delete(db_item)
else:
setattr(db_model, field_name, update_value)
setattr(db_model, field_name, update_value.orm_create(db=db))

elif isinstance(update_value, SUPPORTED_ITERABLES): # One-to-many
parsed_items = set()
for schema in update_value:
if not isinstance(schema, ORMBaseSchema):
raise TypeError(
"Lists should only contain other schemas "
f"inherited from '{ORMBaseSchema.__name__}' "
"(sqlalchemy-pydantic-orm)"
)
if item_id := getattr(schema, "id", None):
try:
db_item = next(
item for item in db_value if item.id == item_id
)
except StopIteration:
raise ValueError(
f"Provided id '{item_id}' "
f"for field '{field_name}' "
"can't be found in the database "
"(sqlalchemy-pydantic-orm)"
) from None # removes unnecessary traceback

schema.orm_update(db, db_item)
parsed_items.add(db_item)
else:
new_item = schema.orm_create(db=db)
parsed_items.add(new_item)
db_value.append(new_item)

for db_item in db_value:
if db_item not in parsed_items:
db.delete(db_item)
else:
setattr(db_model, field_name, update_value)

def to_orm(self, db: Session, **extra_fields: Any) -> DeclarativeMeta:
"""Method that combines the functionality of orm_create & orm_update.
Expand All @@ -231,11 +237,11 @@ def to_orm(self, db: Session, **extra_fields: Any) -> DeclarativeMeta:
provided, it retrieves and updates that model.

In contrary to the orm_create function on its own, this function does
add the newly created model to the database. So after the this method
add the newly created model to the database. So after this method
has been executed you only need to call `db.commit()` after.

Args:
db (Session):
db (Session): Database session used for `.add()` and `.delete()`.
**extra_fields (Any):

Returns:
Expand All @@ -246,9 +252,7 @@ def to_orm(self, db: Session, **extra_fields: Any) -> DeclarativeMeta:
ValueError:
When the provided id is not found in the database
"""
id_ = getattr(self, "id", None)
if not id_ and "id" in extra_fields: # Pydantic field has priority
id_ = extra_fields["id"]
id_ = self.__detect_id(**extra_fields)
if id_:
db_model = db.query(self._orm_model).get(id_)
if not db_model:
Expand All @@ -260,7 +264,13 @@ def to_orm(self, db: Session, **extra_fields: Any) -> DeclarativeMeta:
)
self.orm_update(db, db_model)
else:
db_model = self.orm_create(**extra_fields)
db_model = self.orm_create(db=db, **extra_fields)
db.add(db_model)

return db_model

def __detect_id(self, **extra_fields: Any):
id_ = getattr(self, "id", None)
if not id_ and "id" in extra_fields: # Pydantic field has priority
id_ = extra_fields["id"]
return id_
2 changes: 1 addition & 1 deletion tests/test_orm_specific_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

def test_orm_create() -> None:
schema_in = PydanticParent.parse_obj(orm_create_input_data)
db_model = schema_in.orm_create()
db_model = schema_in.orm_create(db)
db.add(db_model)
db.commit()
db.refresh(db_model)
Expand Down