diff --git a/sqlalchemy_pydantic_orm/main.py b/sqlalchemy_pydantic_orm/main.py index cd17da6..76c71ba 100644 --- a/sqlalchemy_pydantic_orm/main.py +++ b/sqlalchemy_pydantic_orm/main.py @@ -6,7 +6,7 @@ The ORMBaseSchema is an extension of the Pydantic's BaseModel. It can use the -fields defined in it's own schema to create a SQLAlchemy model, it can do that +fields defined in its own schema to create a SQLAlchemy model, it can do that by using a mandatory predefined link to a corresponding SQLAlchemy model. References: @@ -35,13 +35,13 @@ def __init__(self, **data: Any): """The init is used for validation and throwing errors where needed. Pydantic catches all ValueError's in initialization, and then outputs - the error message in a easy to read format with the specific class + the error message in an easy-to-read format with the specific class name displayed. Every error is given the "sqlalchemy-pydantic-orm" identifier to distinguish between Pydantic's or SQLAlchemy's own errors and those of this package. - For performance its better to execute the `super().__init__()` as late + For performance, it's better to execute the `super().__init__()` as late as possible, only the _orm_model check requires it to work properly. Args: @@ -88,7 +88,7 @@ def _orm_model(self) -> Type[DeclarativeMeta]: This variable/property has a leading underscore and can only be assigned as PrivateAttr (Pydantic). This is because a Pydantic schema - iterates over it's own fields and would otherwise cause problems when + iterates over its own fields and would otherwise cause problems when encountering this variable/property. Returns: @@ -96,13 +96,15 @@ def _orm_model(self) -> Type[DeclarativeMeta]: """ pass - def orm_create(self, **extra_fields: Any) -> DeclarativeMeta: + def orm_create(self, db: Session, **extra_fields: Any) -> DeclarativeMeta: """Method to convert a (nested) pydantic schema to a SQLAlchemy model. Using the validated fields in this class, together with the defined _orm_model, this recursive methods creates a (nested) SQLAlchemy model. Args: + db (Session): + Database session used for `.add()` and `.delete()`. extra_fields (Any): Extra fields (keyword arguments) not defined in the pydantic schema used by the top level ORM model. The fields in the @@ -122,38 +124,41 @@ def orm_create(self, **extra_fields: Any) -> DeclarativeMeta: When a list is not fully consisted of other ORM schemas. """ current_level_fields = {} - for field in self.__fields_set__: - field_name = self.__fields__[field].alias - value = getattr(self, field) - if isinstance(value, ORMBaseSchema): # One-to-one - current_level_fields[field_name] = value.orm_create() - - elif isinstance(value, SUPPORTED_ITERABLES): # One-to-many - models = [] - for schema in value: - if not isinstance(schema, ORMBaseSchema): - raise TypeError( - "Lists should only contain other schemas " - f"inherited from '{ORMBaseSchema.__name__}' " - "(sqlalchemy-pydantic-orm)" - ) - models.append(schema.orm_create()) - current_level_fields[field_name] = models - - else: # value without relation - current_level_fields[field_name] = value + for key, field in self.__fields__.items(): + field_name = field.alias + value = getattr(self, key) + if value is not None: + # nullable value not provided must be passed. + if isinstance(value, ORMBaseSchema): # One-to-one + current_level_fields[field_name] = value.to_orm(db=db) + # current_level_fields[field_name] = value.orm_create() + elif isinstance(value, SUPPORTED_ITERABLES): # One-to-many + models = [] + for schema in value: + if not isinstance(schema, ORMBaseSchema): + raise TypeError( + "Lists should only contain other schemas " + f"inherited from '{ORMBaseSchema.__name__}' " + "(sqlalchemy-pydantic-orm)" + ) + # models.append(schema.orm_create()) + models.append(schema.to_orm(db=db, **extra_fields)) + current_level_fields[field_name] = models + + else: # value without relation + current_level_fields[field_name] = value return self._orm_model(**extra_fields, **current_level_fields) def orm_update(self, db: Session, db_model: DeclarativeMeta) -> None: """Method to update a (nested) orm structure. - This method recursively updates an orm model with it's relationships. + This method recursively updates an orm model with its relationships. In one-to-many relationships, each provided item without an id gets added as new item with the `orm_create()` method. When a valid id is provided it updates the item with the `orm_update()` method. It also - keep track of the parsed database items, and afterwards deletes any + keeps track of the parsed database items, and afterwards deletes any unparsed item. Args: @@ -178,50 +183,51 @@ def orm_update(self, db: Session, db_model: DeclarativeMeta) -> None: f"defined _orm_model '{self._orm_model.__name__}' " "(sqlalchemy-pydantic-orm)" ) - for field in self.__fields_set__: - field_name = self.__fields__[field].alias + for key, field in self.__fields__.items(): + field_name = field.alias db_value = getattr(db_model, field_name) - update_value = getattr(self, field) - if isinstance(update_value, ORMBaseSchema): # One-to-one - if db_value: - update_value.orm_update(db, db_value) - else: - setattr(db_model, field_name, update_value.orm_create()) - - elif isinstance(update_value, SUPPORTED_ITERABLES): # One-to-many - parsed_items = set() - for schema in update_value: - if not isinstance(schema, ORMBaseSchema): - raise TypeError( - "Lists should only contain other schemas " - f"inherited from '{ORMBaseSchema.__name__}' " - "(sqlalchemy-pydantic-orm)" - ) - if item_id := getattr(schema, "id", None): - try: - db_item = next( - item for item in db_value if item.id == item_id - ) - except StopIteration: - raise ValueError( - f"Provided id '{item_id}' " - f"for field '{field_name}' " - "can't be found in the database " - "(sqlalchemy-pydantic-orm)" - ) from None # removes unnecessary traceback - - schema.orm_update(db, db_item) - parsed_items.add(db_item) + update_value = getattr(self, key) + if update_value is not None: + if isinstance(update_value, ORMBaseSchema): # One-to-one + if db_value: + update_value.orm_update(db, db_value) else: - new_item = schema.orm_create() - parsed_items.add(new_item) - db_value.append(new_item) - - for db_item in db_value: - if db_item not in parsed_items: - db.delete(db_item) - else: - setattr(db_model, field_name, update_value) + setattr(db_model, field_name, update_value.orm_create(db=db)) + + elif isinstance(update_value, SUPPORTED_ITERABLES): # One-to-many + parsed_items = set() + for schema in update_value: + if not isinstance(schema, ORMBaseSchema): + raise TypeError( + "Lists should only contain other schemas " + f"inherited from '{ORMBaseSchema.__name__}' " + "(sqlalchemy-pydantic-orm)" + ) + if item_id := getattr(schema, "id", None): + try: + db_item = next( + item for item in db_value if item.id == item_id + ) + except StopIteration: + raise ValueError( + f"Provided id '{item_id}' " + f"for field '{field_name}' " + "can't be found in the database " + "(sqlalchemy-pydantic-orm)" + ) from None # removes unnecessary traceback + + schema.orm_update(db, db_item) + parsed_items.add(db_item) + else: + new_item = schema.orm_create(db=db) + parsed_items.add(new_item) + db_value.append(new_item) + + for db_item in db_value: + if db_item not in parsed_items: + db.delete(db_item) + else: + setattr(db_model, field_name, update_value) def to_orm(self, db: Session, **extra_fields: Any) -> DeclarativeMeta: """Method that combines the functionality of orm_create & orm_update. @@ -231,11 +237,11 @@ def to_orm(self, db: Session, **extra_fields: Any) -> DeclarativeMeta: provided, it retrieves and updates that model. In contrary to the orm_create function on its own, this function does - add the newly created model to the database. So after the this method + add the newly created model to the database. So after this method has been executed you only need to call `db.commit()` after. Args: - db (Session): + db (Session): Database session used for `.add()` and `.delete()`. **extra_fields (Any): Returns: @@ -246,9 +252,7 @@ def to_orm(self, db: Session, **extra_fields: Any) -> DeclarativeMeta: ValueError: When the provided id is not found in the database """ - id_ = getattr(self, "id", None) - if not id_ and "id" in extra_fields: # Pydantic field has priority - id_ = extra_fields["id"] + id_ = self.__detect_id(**extra_fields) if id_: db_model = db.query(self._orm_model).get(id_) if not db_model: @@ -260,7 +264,13 @@ def to_orm(self, db: Session, **extra_fields: Any) -> DeclarativeMeta: ) self.orm_update(db, db_model) else: - db_model = self.orm_create(**extra_fields) + db_model = self.orm_create(db=db, **extra_fields) db.add(db_model) return db_model + + def __detect_id(self, **extra_fields: Any): + id_ = getattr(self, "id", None) + if not id_ and "id" in extra_fields: # Pydantic field has priority + id_ = extra_fields["id"] + return id_ diff --git a/tests/test_orm_specific_methods.py b/tests/test_orm_specific_methods.py index 6343234..434006e 100644 --- a/tests/test_orm_specific_methods.py +++ b/tests/test_orm_specific_methods.py @@ -20,7 +20,7 @@ def test_orm_create() -> None: schema_in = PydanticParent.parse_obj(orm_create_input_data) - db_model = schema_in.orm_create() + db_model = schema_in.orm_create(db) db.add(db_model) db.commit() db.refresh(db_model)