diff --git a/CHANGELOG.md b/CHANGELOG.md index 5e7fb0f..212c829 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,13 @@ this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm ## [Unreleased] +## [0.11.0] - 2025-06-06 + +### Added + +- Support for transactions, a big thanks to `@lukwam` for this! For more details of how + to use this, see the new section in the README. + ## [0.10.0] - 2025-02-26 ### Added @@ -283,7 +290,8 @@ this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm - Update README.md - Update .gitignore -[unreleased]: https://github.com/ioxiocom/firedantic/compare/0.10.0...HEAD +[unreleased]: https://github.com/ioxiocom/firedantic/compare/0.11.0...HEAD +[0.11.0]: https://github.com/ioxiocom/firedantic/compare/0.10.0...0.11.0 [0.10.0]: https://github.com/ioxiocom/firedantic/compare/0.9.0...0.10.0 [0.9.0]: https://github.com/ioxiocom/firedantic/compare/0.8.1...0.9.0 [0.8.1]: https://github.com/ioxiocom/firedantic/compare/0.8.0...0.8.1 diff --git a/README.md b/README.md index feda0fd..e80d025 100644 --- a/README.md +++ b/README.md @@ -320,6 +320,157 @@ if __name__ == "__main__": asyncio.run(main()) ``` +## Transactions + +Firedantic has basic support for +[Firestore Transactions](https://firebase.google.com/docs/firestore/manage-data/transactions). +The following methods can be used in a transaction: + +- `Model.delete(transaction=transaction)` +- `Model.find_one(transaction=transaction)` +- `Model.find(transaction=transaction)` +- `Model.get_by_doc_id(transaction=transaction)` +- `Model.get_by_id(transaction=transaction)` +- `Model.reload(transaction=transaction)` +- `Model.save(transaction=transaction)` +- `SubModel.get_by_id(transaction=transaction)` + +When using transactions, note that read operations must come before write operations. + +### Transaction examples + +In this example (async and sync version of it below), we are updating a `City` to +increment and decrement the population of it, both using an instance method and a +standalone function. Please note that the `@async_transactional` and `@transactional` +decorators always expect the first argument of the wrapped function to be `transaction`; +i.e. you can not directly wrap an instance method that has `self` as the first argument +or a class method that has `cls` as the first argument. + +#### Transaction example (async) + +```python +import asyncio +from os import environ +from unittest.mock import Mock + +import google.auth.credentials +from google.cloud.firestore import AsyncClient +from google.cloud.firestore_v1 import async_transactional, AsyncTransaction + +from firedantic import AsyncModel, configure, get_async_transaction + +# Firestore emulator must be running if using locally. +if environ.get("FIRESTORE_EMULATOR_HOST"): + client = AsyncClient( + project="firedantic-test", + credentials=Mock(spec=google.auth.credentials.Credentials), + ) +else: + client = AsyncClient() + +configure(client, prefix="firedantic-test-") + + +class City(AsyncModel): + __collection__ = "cities" + population: int + + async def increment_population(self, increment: int = 1): + @async_transactional + async def _increment_population(transaction: AsyncTransaction) -> None: + await self.reload(transaction=transaction) + self.population += increment + await self.save(transaction=transaction) + + t = get_async_transaction() + await _increment_population(transaction=t) + + +async def main(): + @async_transactional + async def decrement_population( + transaction: AsyncTransaction, city: City, decrement: int = 1 + ): + await city.reload(transaction=transaction) + city.population = max(0, city.population - decrement) + await city.save(transaction=transaction) + + c = City(id="SF", population=1) + await c.save() + await c.increment_population(increment=1) + assert c.population == 2 + + t = get_async_transaction() + await decrement_population(transaction=t, city=c, decrement=5) + assert c.population == 0 + + +if __name__ == "__main__": + asyncio.run(main()) +``` + +#### Transaction example (sync) + +```python +from os import environ +from unittest.mock import Mock + +import google.auth.credentials +from google.cloud.firestore import Client +from google.cloud.firestore_v1 import transactional, Transaction + +from firedantic import Model, configure, get_transaction + +# Firestore emulator must be running if using locally. +if environ.get("FIRESTORE_EMULATOR_HOST"): + client = Client( + project="firedantic-test", + credentials=Mock(spec=google.auth.credentials.Credentials), + ) +else: + client = Client() + +configure(client, prefix="firedantic-test-") + + +class City(Model): + __collection__ = "cities" + population: int + + def increment_population(self, increment: int = 1): + @transactional + def _increment_population(transaction: Transaction) -> None: + self.reload(transaction=transaction) + self.population += increment + self.save(transaction=transaction) + + t = get_transaction() + _increment_population(transaction=t) + + +def main(): + @transactional + def decrement_population( + transaction: Transaction, city: City, decrement: int = 1 + ): + city.reload(transaction=transaction) + city.population = max(0, city.population - decrement) + city.save(transaction=transaction) + + c = City(id="SF", population=1) + c.save() + c.increment_population(increment=1) + assert c.population == 2 + + t = get_transaction() + decrement_population(transaction=t, city=c, decrement=5) + assert c.population == 0 + + +if __name__ == "__main__": + main() +``` + ## Development PRs are welcome! diff --git a/firedantic/__init__.py b/firedantic/__init__.py index 7b0e16b..e4f79d2 100644 --- a/firedantic/__init__.py +++ b/firedantic/__init__.py @@ -32,6 +32,11 @@ ) from firedantic._sync.ttl_policy import set_up_ttl_policies from firedantic.common import collection_group_index, collection_index -from firedantic.configurations import CONFIGURATIONS, configure +from firedantic.configurations import ( + CONFIGURATIONS, + configure, + get_async_transaction, + get_transaction, +) from firedantic.exceptions import * from firedantic.utils import get_all_subclasses diff --git a/firedantic/_async/helpers.py b/firedantic/_async/helpers.py index 10bd5ef..c77b650 100644 --- a/firedantic/_async/helpers.py +++ b/firedantic/_async/helpers.py @@ -4,7 +4,8 @@ async def truncate_collection( col_ref: AsyncCollectionReference, batch_size: int = 128 ) -> int: - """Removes all documents inside a collection. + """ + Removes all documents inside a collection. :param col_ref: A collection reference to the collection to be truncated. :param batch_size: Batch size for listing documents. diff --git a/firedantic/_async/model.py b/firedantic/_async/model.py index 9c6af1c..84251df 100644 --- a/firedantic/_async/model.py +++ b/firedantic/_async/model.py @@ -10,6 +10,7 @@ FieldFilter, ) from google.cloud.firestore_v1.async_query import AsyncQuery +from google.cloud.firestore_v1.async_transaction import AsyncTransaction import firedantic.operators as op from firedantic import async_truncate_collection @@ -41,6 +42,11 @@ def get_collection_name(cls, name: Optional[str]) -> str: + """ + Returns the collection name. + + :raise CollectionNotDefined: If the collection name is not defined. + """ if not name: raise CollectionNotDefined(f"Missing collection name for {cls.__name__}") @@ -55,7 +61,8 @@ def _get_col_ref(cls, name: Optional[str]) -> AsyncCollectionReference: class AsyncBareModel(pydantic.BaseModel, ABC): - """Base model class. + """ + Base model class. Implements basic functionality for Pydantic models, such as save, delete, find etc. """ @@ -66,13 +73,18 @@ class AsyncBareModel(pydantic.BaseModel, ABC): __composite_indexes__: Optional[Iterable[IndexDefinition]] = None async def save( - self, *, exclude_unset: bool = False, exclude_none: bool = False + self, + *, + exclude_unset: bool = False, + exclude_none: bool = False, + transaction: Optional[AsyncTransaction] = None, ) -> None: """ Saves this model in the database. :param exclude_unset: Whether to exclude fields that have not been explicitly set. :param exclude_none: Whether to exclude fields that have a value of `None`. + :param transaction: Optional transaction to use. :raise DocumentIDError: If the document ID is not valid. """ data = self.model_dump( @@ -82,28 +94,35 @@ async def save( del data[self.__document_id__] doc_ref = self._get_doc_ref() - await doc_ref.set(data) + if transaction is not None: + transaction.set(doc_ref, data) + else: + await doc_ref.set(data) setattr(self, self.__document_id__, doc_ref.id) - async def delete(self) -> None: + async def delete(self, transaction: Optional[AsyncTransaction] = None) -> None: """ Deletes this model from the database. :raise DocumentIDError: If the ID is not valid. """ - await self._get_doc_ref().delete() + if transaction is not None: + transaction.delete(self._get_doc_ref()) + else: + await self._get_doc_ref().delete() - async def reload(self) -> None: + async def reload(self, transaction: Optional[AsyncTransaction] = None) -> None: """ Reloads this model from the database. + :param transaction: Optional transaction to use. :raise ModelNotFoundError: If the document ID is missing in the model. """ doc_id = self.__dict__.get(self.__document_id__) if doc_id is None: raise ModelNotFoundError("Can not reload unsaved model") - updated_model = await self.get_by_doc_id(doc_id) + updated_model = await self.get_by_doc_id(doc_id, transaction=transaction) updated_model_doc_id = updated_model.__dict__[self.__document_id__] assert doc_id == updated_model_doc_id @@ -111,7 +130,7 @@ async def reload(self) -> None: def get_document_id(self): """ - Get the document ID for this model instance + Returns the document ID for this model instance. :raise DocumentIDError: If the ID is not valid. """ @@ -123,14 +142,17 @@ def get_document_id(self): _OrderBy = List[Tuple[str, OrderDirection]] @classmethod - async def find( + async def find( # pylint: disable=too-many-arguments cls: Type[TAsyncBareModel], filter_: Optional[Dict[str, Union[str, dict]]] = None, order_by: Optional[_OrderBy] = None, limit: Optional[int] = None, offset: Optional[int] = None, + transaction: Optional[AsyncTransaction] = None, ) -> List[TAsyncBareModel]: - """Returns a list of models from the database based on a filter. + """ + Returns a list of models from the database based on a filter. + The list can be sorted with the order_by parameter, limits and offets can also be applied. Example: `Company.find({"company_id": "1234567-8"})`. @@ -142,6 +164,7 @@ async def find( :param order_by: List of columns and direction to order results by. :param limit: Maximum results to return. :param offset: Skip the first n results. + :param transaction: Optional transaction to use. :return: List of found models. """ query: Union[AsyncQuery, AsyncCollectionReference] = cls._get_col_ref() @@ -175,7 +198,7 @@ def _cls(doc_id: str, data: Dict[str, Any]) -> TAsyncBareModel: return [ _cls(doc_id, doc_dict) async for doc_id, doc_dict in ( - (doc.id, doc.to_dict()) async for doc in query.stream() # type: ignore + (doc.id, doc.to_dict()) async for doc in query.stream(transaction=transaction) # type: ignore ) if doc_dict is not None ] @@ -184,7 +207,7 @@ def _cls(doc_id: str, data: Dict[str, Any]) -> TAsyncBareModel: def _add_filter( cls, query: Union[AsyncQuery, AsyncCollectionReference], field: str, value: Any ) -> Union[AsyncQuery, AsyncCollectionReference]: - if type(value) is dict: + if isinstance(value, dict): for f_type in value: if f_type not in FIND_TYPES: raise ValueError( @@ -203,32 +226,42 @@ async def find_one( cls: Type[TAsyncBareModel], filter_: Optional[Dict[str, Union[str, dict]]] = None, order_by: Optional[_OrderBy] = None, + transaction: Optional[AsyncTransaction] = None, ) -> TAsyncBareModel: - """Returns one model from the DB based on a filter. + """ + Returns one model from the DB based on a filter. :param filter_: The filter criteria. :param order_by: List of columns and direction to order results by. :return: The model instance. :raise ModelNotFoundError: If the entry is not found. """ - model = await cls.find(filter_, limit=1, order_by=order_by) + model = await cls.find( + filter_, limit=1, order_by=order_by, transaction=transaction + ) try: return model[0] - except IndexError: - raise ModelNotFoundError(f"No '{cls.__name__}' found") + except IndexError as e: + raise ModelNotFoundError(f"No '{cls.__name__}' found") from e @classmethod - async def get_by_doc_id(cls: Type[TAsyncBareModel], doc_id: str) -> TAsyncBareModel: - """Returns a model based on the document ID. + async def get_by_doc_id( + cls: Type[TAsyncBareModel], + doc_id: str, + transaction: Optional[AsyncTransaction] = None, + ) -> TAsyncBareModel: + """ + Returns a model based on the document ID. :param doc_id: The document ID of the entry. + :param transaction: Optional transaction to use. :return: The model. :raise ModelNotFoundError: Raised if no matching document is found. """ try: cls._validate_document_id(doc_id) - except InvalidDocumentID: + except InvalidDocumentID as e: # Getting a document with doc_id set to an empty string would raise a # google.api_core.exceptions.InvalidArgument exception and a doc_id # containing an uneven number of slashes would raise a @@ -236,10 +269,10 @@ async def get_by_doc_id(cls: Type[TAsyncBareModel], doc_id: str) -> TAsyncBareMo # could even load data from a sub collection instead of the desired one. raise ModelNotFoundError( f"No '{cls.__name__}' found with {cls.__document_id__} '{doc_id}'" - ) + ) from e document: DocumentSnapshot = ( - await cls._get_col_ref().document(doc_id).get() + await cls._get_col_ref().document(doc_id).get(transaction=transaction) ) # type: ignore data = document.to_dict() if data is None: @@ -253,7 +286,8 @@ async def get_by_doc_id(cls: Type[TAsyncBareModel], doc_id: str) -> TAsyncBareMo @classmethod async def truncate_collection(cls, batch_size: int = 128) -> int: - """Removes all documents inside a collection. + """ + Removes all documents inside a collection. :param batch_size: Batch size for listing documents. :return: Number of removed documents. @@ -265,11 +299,16 @@ async def truncate_collection(cls, batch_size: int = 128) -> int: @classmethod def _get_col_ref(cls) -> AsyncCollectionReference: - """Returns the collection reference.""" + """ + Returns the collection reference. + """ return _get_col_ref(cls, cls.__collection__) @classmethod def get_collection_name(cls) -> str: + """ + Returns the collection name. + """ return get_collection_name(cls, cls.__collection__) def _get_doc_ref(self) -> AsyncDocumentReference: @@ -319,8 +358,19 @@ class AsyncModel(AsyncBareModel): id: Optional[str] = None @classmethod - async def get_by_id(cls: Type[TAsyncBareModel], id_: str) -> TAsyncBareModel: - return await cls.get_by_doc_id(id_) + async def get_by_id( + cls: Type[TAsyncBareModel], + id_: str, + transaction: Optional[AsyncTransaction] = None, + ) -> TAsyncBareModel: + """ + Get single model by document ID. + + :param id_: Document ID. + :param transaction: Optional transaction to use. + :raises ModelNotFoundError: If no model was found by given id. + """ + return await cls.get_by_doc_id(id_, transaction=transaction) class AsyncBareSubCollection(ABC): @@ -329,6 +379,9 @@ class AsyncBareSubCollection(ABC): @classmethod def model_for(cls, parent, model_class): + """ + Returns the model for this subcollection. + """ parent_props = parent.model_dump(by_alias=True) name = model_class.__name__ @@ -356,7 +409,9 @@ def _create(cls: Type[TAsyncBareSubModel], **kwargs) -> TAsyncBareSubModel: @classmethod def _get_col_ref(cls) -> AsyncCollectionReference: - """Returns the collection reference.""" + """ + Returns the collection reference. + """ if cls.__collection__ is None or "{" in cls.__collection__: raise CollectionNotDefined( f"{cls.__name__} is not properly prepared. " @@ -366,6 +421,9 @@ def _get_col_ref(cls) -> AsyncCollectionReference: @classmethod def model_for(cls, parent): + """ + Returns the model for this submodel. + """ return cls.Collection.model_for(parent, cls) @@ -373,12 +431,19 @@ class AsyncSubModel(AsyncBareSubModel): id: Optional[str] = None @classmethod - async def get_by_id(cls: Type[TAsyncBareModel], id_: str) -> TAsyncBareModel: + async def get_by_id( + cls: Type[TAsyncBareModel], + id_: str, + transaction: Optional[AsyncTransaction] = None, + ) -> TAsyncBareModel: """ Get single item by document ID + + :param id_: Document ID. + :param transaction: Optional transaction to use. :raises ModelNotFoundError: """ - return await cls.get_by_doc_id(id_) + return await cls.get_by_doc_id(id_, transaction=transaction) class AsyncSubCollection(AsyncBareSubCollection, ABC): diff --git a/firedantic/_sync/helpers.py b/firedantic/_sync/helpers.py index 3feb0d5..57dfe99 100644 --- a/firedantic/_sync/helpers.py +++ b/firedantic/_sync/helpers.py @@ -2,7 +2,8 @@ def truncate_collection(col_ref: CollectionReference, batch_size: int = 128) -> int: - """Removes all documents inside a collection. + """ + Removes all documents inside a collection. :param col_ref: A collection reference to the collection to be truncated. :param batch_size: Batch size for listing documents. diff --git a/firedantic/_sync/model.py b/firedantic/_sync/model.py index 4fa7d6d..78ec372 100644 --- a/firedantic/_sync/model.py +++ b/firedantic/_sync/model.py @@ -10,6 +10,7 @@ FieldFilter, ) from google.cloud.firestore_v1.base_query import BaseQuery +from google.cloud.firestore_v1.transaction import Transaction import firedantic.operators as op from firedantic import truncate_collection @@ -41,6 +42,11 @@ def get_collection_name(cls, name: Optional[str]) -> str: + """ + Returns the collection name. + + :raise CollectionNotDefined: If the collection name is not defined. + """ if not name: raise CollectionNotDefined(f"Missing collection name for {cls.__name__}") @@ -55,7 +61,8 @@ def _get_col_ref(cls, name: Optional[str]) -> CollectionReference: class BareModel(pydantic.BaseModel, ABC): - """Base model class. + """ + Base model class. Implements basic functionality for Pydantic models, such as save, delete, find etc. """ @@ -65,12 +72,19 @@ class BareModel(pydantic.BaseModel, ABC): __ttl_field__: Optional[str] = None __composite_indexes__: Optional[Iterable[IndexDefinition]] = None - def save(self, *, exclude_unset: bool = False, exclude_none: bool = False) -> None: + def save( + self, + *, + exclude_unset: bool = False, + exclude_none: bool = False, + transaction: Optional[Transaction] = None, + ) -> None: """ Saves this model in the database. :param exclude_unset: Whether to exclude fields that have not been explicitly set. :param exclude_none: Whether to exclude fields that have a value of `None`. + :param transaction: Optional transaction to use. :raise DocumentIDError: If the document ID is not valid. """ data = self.model_dump( @@ -80,28 +94,35 @@ def save(self, *, exclude_unset: bool = False, exclude_none: bool = False) -> No del data[self.__document_id__] doc_ref = self._get_doc_ref() - doc_ref.set(data) + if transaction is not None: + transaction.set(doc_ref, data) + else: + doc_ref.set(data) setattr(self, self.__document_id__, doc_ref.id) - def delete(self) -> None: + def delete(self, transaction: Optional[Transaction] = None) -> None: """ Deletes this model from the database. :raise DocumentIDError: If the ID is not valid. """ - self._get_doc_ref().delete() + if transaction is not None: + transaction.delete(self._get_doc_ref()) + else: + self._get_doc_ref().delete() - def reload(self) -> None: + def reload(self, transaction: Optional[Transaction] = None) -> None: """ Reloads this model from the database. + :param transaction: Optional transaction to use. :raise ModelNotFoundError: If the document ID is missing in the model. """ doc_id = self.__dict__.get(self.__document_id__) if doc_id is None: raise ModelNotFoundError("Can not reload unsaved model") - updated_model = self.get_by_doc_id(doc_id) + updated_model = self.get_by_doc_id(doc_id, transaction=transaction) updated_model_doc_id = updated_model.__dict__[self.__document_id__] assert doc_id == updated_model_doc_id @@ -109,7 +130,7 @@ def reload(self) -> None: def get_document_id(self): """ - Get the document ID for this model instance + Returns the document ID for this model instance. :raise DocumentIDError: If the ID is not valid. """ @@ -121,14 +142,17 @@ def get_document_id(self): _OrderBy = List[Tuple[str, OrderDirection]] @classmethod - def find( + def find( # pylint: disable=too-many-arguments cls: Type[TBareModel], filter_: Optional[Dict[str, Union[str, dict]]] = None, order_by: Optional[_OrderBy] = None, limit: Optional[int] = None, offset: Optional[int] = None, + transaction: Optional[Transaction] = None, ) -> List[TBareModel]: - """Returns a list of models from the database based on a filter. + """ + Returns a list of models from the database based on a filter. + The list can be sorted with the order_by parameter, limits and offets can also be applied. Example: `Company.find({"company_id": "1234567-8"})`. @@ -140,6 +164,7 @@ def find( :param order_by: List of columns and direction to order results by. :param limit: Maximum results to return. :param offset: Skip the first n results. + :param transaction: Optional transaction to use. :return: List of found models. """ query: Union[BaseQuery, CollectionReference] = cls._get_col_ref() @@ -173,7 +198,7 @@ def _cls(doc_id: str, data: Dict[str, Any]) -> TBareModel: return [ _cls(doc_id, doc_dict) for doc_id, doc_dict in ( - (doc.id, doc.to_dict()) for doc in query.stream() # type: ignore + (doc.id, doc.to_dict()) for doc in query.stream(transaction=transaction) # type: ignore ) if doc_dict is not None ] @@ -182,7 +207,7 @@ def _cls(doc_id: str, data: Dict[str, Any]) -> TBareModel: def _add_filter( cls, query: Union[BaseQuery, CollectionReference], field: str, value: Any ) -> Union[BaseQuery, CollectionReference]: - if type(value) is dict: + if isinstance(value, dict): for f_type in value: if f_type not in FIND_TYPES: raise ValueError( @@ -201,32 +226,40 @@ def find_one( cls: Type[TBareModel], filter_: Optional[Dict[str, Union[str, dict]]] = None, order_by: Optional[_OrderBy] = None, + transaction: Optional[Transaction] = None, ) -> TBareModel: - """Returns one model from the DB based on a filter. + """ + Returns one model from the DB based on a filter. :param filter_: The filter criteria. :param order_by: List of columns and direction to order results by. :return: The model instance. :raise ModelNotFoundError: If the entry is not found. """ - model = cls.find(filter_, limit=1, order_by=order_by) + model = cls.find(filter_, limit=1, order_by=order_by, transaction=transaction) try: return model[0] - except IndexError: - raise ModelNotFoundError(f"No '{cls.__name__}' found") + except IndexError as e: + raise ModelNotFoundError(f"No '{cls.__name__}' found") from e @classmethod - def get_by_doc_id(cls: Type[TBareModel], doc_id: str) -> TBareModel: - """Returns a model based on the document ID. + def get_by_doc_id( + cls: Type[TBareModel], + doc_id: str, + transaction: Optional[Transaction] = None, + ) -> TBareModel: + """ + Returns a model based on the document ID. :param doc_id: The document ID of the entry. + :param transaction: Optional transaction to use. :return: The model. :raise ModelNotFoundError: Raised if no matching document is found. """ try: cls._validate_document_id(doc_id) - except InvalidDocumentID: + except InvalidDocumentID as e: # Getting a document with doc_id set to an empty string would raise a # google.api_core.exceptions.InvalidArgument exception and a doc_id # containing an uneven number of slashes would raise a @@ -234,10 +267,10 @@ def get_by_doc_id(cls: Type[TBareModel], doc_id: str) -> TBareModel: # could even load data from a sub collection instead of the desired one. raise ModelNotFoundError( f"No '{cls.__name__}' found with {cls.__document_id__} '{doc_id}'" - ) + ) from e document: DocumentSnapshot = ( - cls._get_col_ref().document(doc_id).get() + cls._get_col_ref().document(doc_id).get(transaction=transaction) ) # type: ignore data = document.to_dict() if data is None: @@ -251,7 +284,8 @@ def get_by_doc_id(cls: Type[TBareModel], doc_id: str) -> TBareModel: @classmethod def truncate_collection(cls, batch_size: int = 128) -> int: - """Removes all documents inside a collection. + """ + Removes all documents inside a collection. :param batch_size: Batch size for listing documents. :return: Number of removed documents. @@ -263,11 +297,16 @@ def truncate_collection(cls, batch_size: int = 128) -> int: @classmethod def _get_col_ref(cls) -> CollectionReference: - """Returns the collection reference.""" + """ + Returns the collection reference. + """ return _get_col_ref(cls, cls.__collection__) @classmethod def get_collection_name(cls) -> str: + """ + Returns the collection name. + """ return get_collection_name(cls, cls.__collection__) def _get_doc_ref(self) -> DocumentReference: @@ -317,8 +356,19 @@ class Model(BareModel): id: Optional[str] = None @classmethod - def get_by_id(cls: Type[TBareModel], id_: str) -> TBareModel: - return cls.get_by_doc_id(id_) + def get_by_id( + cls: Type[TBareModel], + id_: str, + transaction: Optional[Transaction] = None, + ) -> TBareModel: + """ + Get single model by document ID. + + :param id_: Document ID. + :param transaction: Optional transaction to use. + :raises ModelNotFoundError: If no model was found by given id. + """ + return cls.get_by_doc_id(id_, transaction=transaction) class BareSubCollection(ABC): @@ -327,6 +377,9 @@ class BareSubCollection(ABC): @classmethod def model_for(cls, parent, model_class): + """ + Returns the model for this subcollection. + """ parent_props = parent.model_dump(by_alias=True) name = model_class.__name__ @@ -354,7 +407,9 @@ def _create(cls: Type[TBareSubModel], **kwargs) -> TBareSubModel: @classmethod def _get_col_ref(cls) -> CollectionReference: - """Returns the collection reference.""" + """ + Returns the collection reference. + """ if cls.__collection__ is None or "{" in cls.__collection__: raise CollectionNotDefined( f"{cls.__name__} is not properly prepared. " @@ -364,6 +419,9 @@ def _get_col_ref(cls) -> CollectionReference: @classmethod def model_for(cls, parent): + """ + Returns the model for this submodel. + """ return cls.Collection.model_for(parent, cls) @@ -371,12 +429,19 @@ class SubModel(BareSubModel): id: Optional[str] = None @classmethod - def get_by_id(cls: Type[TBareModel], id_: str) -> TBareModel: + def get_by_id( + cls: Type[TBareModel], + id_: str, + transaction: Optional[Transaction] = None, + ) -> TBareModel: """ Get single item by document ID + + :param id_: Document ID. + :param transaction: Optional transaction to use. :raises ModelNotFoundError: """ - return cls.get_by_doc_id(id_) + return cls.get_by_doc_id(id_, transaction=transaction) class SubCollection(BareSubCollection, ABC): diff --git a/firedantic/configurations.py b/firedantic/configurations.py index 321a4af..2f9e277 100644 --- a/firedantic/configurations.py +++ b/firedantic/configurations.py @@ -1,6 +1,6 @@ from typing import Any, Dict, Union -from google.cloud.firestore_v1 import AsyncClient, Client +from google.cloud.firestore_v1 import AsyncClient, AsyncTransaction, Client, Transaction CONFIGURATIONS: Dict[str, Any] = {} @@ -15,3 +15,21 @@ def configure(db: Union[Client, AsyncClient], prefix: str = "") -> None: CONFIGURATIONS["db"] = db CONFIGURATIONS["prefix"] = prefix + + +def get_transaction() -> Transaction: + """ + Get a new Firestore transaction for the configured DB. + """ + transaction = CONFIGURATIONS["db"].transaction() + assert isinstance(transaction, Transaction) + return transaction + + +def get_async_transaction() -> AsyncTransaction: + """ + Get a new Firestore transaction for the configured DB. + """ + transaction = CONFIGURATIONS["db"].transaction() + assert isinstance(transaction, AsyncTransaction) + return transaction diff --git a/firedantic/exceptions.py b/firedantic/exceptions.py index af36989..5f8dbd4 100644 --- a/firedantic/exceptions.py +++ b/firedantic/exceptions.py @@ -1,16 +1,14 @@ class ModelError(Exception): """Generic model error class.""" - pass - class InvalidDocumentID(ModelError): - pass + """Raised when a document ID is invalid.""" class ModelNotFoundError(ModelError): - pass + """Raised when a model is not found.""" class CollectionNotDefined(ModelError): - pass + """Raised when the model collection is not defined.""" diff --git a/firedantic/tests/tests_async/conftest.py b/firedantic/tests/tests_async/conftest.py index ec1aaeb..09a4710 100644 --- a/firedantic/tests/tests_async/conftest.py +++ b/firedantic/tests/tests_async/conftest.py @@ -5,8 +5,8 @@ import google.auth.credentials import pytest from google.cloud.firestore_admin_v1 import Field, FirestoreAdminClient -from google.cloud.firestore_v1 import AsyncClient -from pydantic import BaseModel, Extra, PrivateAttr +from google.cloud.firestore_v1 import AsyncClient, AsyncTransaction, async_transactional +from pydantic import BaseModel, PrivateAttr from firedantic import ( AsyncBareModel, @@ -16,7 +16,7 @@ AsyncSubCollection, AsyncSubModel, ) -from firedantic.configurations import configure +from firedantic.configurations import configure, get_async_transaction from firedantic.exceptions import ModelNotFoundError from unittest.mock import AsyncMock, Mock # noqa isort: skip @@ -30,7 +30,7 @@ class CustomIDModel(AsyncBareModel): bar: str class Config: - extra = Extra.forbid + extra = "forbid" class CustomIDModelExtra(AsyncBareModel): @@ -42,7 +42,7 @@ class CustomIDModelExtra(AsyncBareModel): baz: str class Config: - extra = Extra.forbid + extra = "forbid" class CustomIDConflictModel(AsyncModel): @@ -52,7 +52,24 @@ class CustomIDConflictModel(AsyncModel): bar: str class Config: - extra = Extra.forbid + extra = "forbid" + + +class City(AsyncModel): + """City model used for test with transactions""" + + __collection__ = "cities" + population: int + + async def increment_population(self, increment: int = 1): + @async_transactional + async def _increment_population(transaction: AsyncTransaction) -> None: + await self.reload(transaction=transaction) + self.population += increment + await self.save(transaction=transaction) + + t = get_async_transaction() + await _increment_population(transaction=t) class Owner(BaseModel): @@ -62,7 +79,7 @@ class Owner(BaseModel): last_name: str class Config: - extra = Extra.forbid + extra = "forbid" class CompanyStats(AsyncBareSubModel): @@ -99,7 +116,7 @@ class Company(AsyncModel): owner: Owner class Config: - extra = Extra.forbid + extra = "forbid" def stats(self) -> Type[CompanyStats]: return CompanyStats.model_for(self) # type: ignore @@ -115,7 +132,7 @@ class Product(AsyncModel): stock: int class Config: - extra = Extra.forbid + extra = "forbid" class Profile(AsyncModel): @@ -127,7 +144,7 @@ class Profile(AsyncModel): photo_url: Optional[str] = None class Config: - extra = Extra.forbid + extra = "forbid" class TodoList(AsyncModel): @@ -139,7 +156,7 @@ class TodoList(AsyncModel): items: List[str] class Config: - extra = Extra.forbid + extra = "forbid" class ExpiringModel(AsyncModel): diff --git a/firedantic/tests/tests_async/test_model.py b/firedantic/tests/tests_async/test_model.py index 501dad7..ed27c0d 100644 --- a/firedantic/tests/tests_async/test_model.py +++ b/firedantic/tests/tests_async/test_model.py @@ -2,17 +2,19 @@ from uuid import uuid4 import pytest -from google.cloud.firestore import Query +from google.cloud.firestore import Query, async_transactional +from google.cloud.firestore_v1.async_transaction import AsyncTransaction from pydantic import Field, ValidationError import firedantic.operators as op -from firedantic import AsyncModel +from firedantic import AsyncModel, get_async_transaction from firedantic.exceptions import ( CollectionNotDefined, InvalidDocumentID, ModelNotFoundError, ) from firedantic.tests.tests_async.conftest import ( + City, Company, CustomIDConflictModel, CustomIDModel, @@ -535,3 +537,116 @@ async def test_save_with_exclude_unset(configure_db) -> None: data = document.to_dict() assert data == {"name": "", "photo_url": None} + + +@pytest.mark.asyncio +async def test_update_city_in_transaction(configure_db) -> None: + """ + Test updating a model in a transaction. Test case from README. + + :param: configure_db: pytest fixture + """ + + @async_transactional + async def decrement_population( + transaction: AsyncTransaction, city: City, decrement: int = 1 + ): + await city.reload(transaction=transaction) + city.population = max(0, city.population - decrement) + await city.save(transaction=transaction) + + c = City(id="SF", population=1) + await c.save() + await c.increment_population(increment=1) + assert c.population == 2 + + t = get_async_transaction() + await decrement_population(transaction=t, city=c, decrement=5) + assert c.population == 0 + + +@pytest.mark.asyncio +async def test_delete_in_transaction(configure_db) -> None: + """ + Test deleting a model in a transaction. + + :param: configure_db: pytest fixture + """ + + @async_transactional + async def delete_in_transaction( + transaction: AsyncTransaction, profile_id: str + ) -> None: + """Deletes a Profile in a transaction.""" + profile = await Profile.get_by_id(profile_id, transaction=transaction) + await profile.delete(transaction=transaction) + + p = Profile(name="Foo") + await p.save() + assert p.id + + t = get_async_transaction() + await delete_in_transaction(t, p.id) + + with pytest.raises(ModelNotFoundError): + await Profile.get_by_id(p.id) + + +@pytest.mark.asyncio +async def test_update_model_in_transaction(configure_db) -> None: + """ + Test updating a model in a transaction. + + :param: configure_db: pytest fixture + """ + + @async_transactional + async def update_in_transaction( + transaction: AsyncTransaction, profile_id: str, name: str + ) -> None: + """Updates a Profile in a transaction.""" + profile = Profile(id=profile_id) + await profile.reload(transaction=transaction) + profile.name = name + await profile.save(transaction=transaction) + + p = Profile(name="Foo") + await p.save() + + t = get_async_transaction() + await update_in_transaction(t, p.id, name="Bar") + await p.reload() + assert p.name == "Bar" + + +@pytest.mark.asyncio +async def test_update_submodel_in_transaction(configure_db) -> None: + """ + Test Updating a submodel in a transaction. + + :param: configure_db: pytest fixture + """ + + @async_transactional + async def update_submodel_in_transaction( + transaction: AsyncTransaction, user_id: str, period: str + ) -> UserStats: + """Updates a UserStats in a transaction.""" + u = await User.get_by_id(user_id, transaction=transaction) + us = UserStats.model_for(u) + user_stats: UserStats = await us.get_by_id(period) # pylint: disable=no-member + user_stats.purchases += 1 + await user_stats.save(transaction=transaction) + return user_stats + + u = User(name="Foo") + await u.save() + assert u.id + us = UserStats.model_for(u) + await us(id="2021", purchases=42).save() # pylint: disable=no-member + + t = get_async_transaction() + user_stats = await update_submodel_in_transaction(t, u.id, "2021") + assert isinstance(user_stats, UserStats) + assert user_stats.purchases == 43 + assert await get_user_purchases(u.id) == 43 diff --git a/firedantic/tests/tests_sync/conftest.py b/firedantic/tests/tests_sync/conftest.py index ea7d4c2..8debfc4 100644 --- a/firedantic/tests/tests_sync/conftest.py +++ b/firedantic/tests/tests_sync/conftest.py @@ -5,8 +5,8 @@ import google.auth.credentials import pytest from google.cloud.firestore_admin_v1 import Field, FirestoreAdminClient -from google.cloud.firestore_v1 import Client -from pydantic import BaseModel, Extra, PrivateAttr +from google.cloud.firestore_v1 import Client, Transaction, transactional +from pydantic import BaseModel, PrivateAttr from firedantic import ( BareModel, @@ -16,7 +16,7 @@ SubCollection, SubModel, ) -from firedantic.configurations import configure +from firedantic.configurations import configure, get_transaction from firedantic.exceptions import ModelNotFoundError from unittest.mock import Mock, Mock # noqa isort: skip @@ -30,7 +30,7 @@ class CustomIDModel(BareModel): bar: str class Config: - extra = Extra.forbid + extra = "forbid" class CustomIDModelExtra(BareModel): @@ -42,7 +42,7 @@ class CustomIDModelExtra(BareModel): baz: str class Config: - extra = Extra.forbid + extra = "forbid" class CustomIDConflictModel(Model): @@ -52,7 +52,24 @@ class CustomIDConflictModel(Model): bar: str class Config: - extra = Extra.forbid + extra = "forbid" + + +class City(Model): + """City model used for test with transactions""" + + __collection__ = "cities" + population: int + + def increment_population(self, increment: int = 1): + @transactional + def _increment_population(transaction: Transaction) -> None: + self.reload(transaction=transaction) + self.population += increment + self.save(transaction=transaction) + + t = get_transaction() + _increment_population(transaction=t) class Owner(BaseModel): @@ -62,7 +79,7 @@ class Owner(BaseModel): last_name: str class Config: - extra = Extra.forbid + extra = "forbid" class CompanyStats(BareSubModel): @@ -99,7 +116,7 @@ class Company(Model): owner: Owner class Config: - extra = Extra.forbid + extra = "forbid" def stats(self) -> Type[CompanyStats]: return CompanyStats.model_for(self) # type: ignore @@ -115,7 +132,7 @@ class Product(Model): stock: int class Config: - extra = Extra.forbid + extra = "forbid" class Profile(Model): @@ -127,7 +144,7 @@ class Profile(Model): photo_url: Optional[str] = None class Config: - extra = Extra.forbid + extra = "forbid" class TodoList(Model): @@ -139,7 +156,7 @@ class TodoList(Model): items: List[str] class Config: - extra = Extra.forbid + extra = "forbid" class ExpiringModel(Model): diff --git a/firedantic/tests/tests_sync/test_model.py b/firedantic/tests/tests_sync/test_model.py index 0f62a5b..37e27a4 100644 --- a/firedantic/tests/tests_sync/test_model.py +++ b/firedantic/tests/tests_sync/test_model.py @@ -2,17 +2,19 @@ from uuid import uuid4 import pytest -from google.cloud.firestore import Query +from google.cloud.firestore import Query, transactional +from google.cloud.firestore_v1.transaction import Transaction from pydantic import Field, ValidationError import firedantic.operators as op -from firedantic import Model +from firedantic import Model, get_transaction from firedantic.exceptions import ( CollectionNotDefined, InvalidDocumentID, ModelNotFoundError, ) from firedantic.tests.tests_sync.conftest import ( + City, Company, CustomIDConflictModel, CustomIDModel, @@ -500,3 +502,108 @@ def test_save_with_exclude_unset(configure_db) -> None: data = document.to_dict() assert data == {"name": "", "photo_url": None} + + +def test_update_city_in_transaction(configure_db) -> None: + """ + Test updating a model in a transaction. Test case from README. + + :param: configure_db: pytest fixture + """ + + @transactional + def decrement_population(transaction: Transaction, city: City, decrement: int = 1): + city.reload(transaction=transaction) + city.population = max(0, city.population - decrement) + city.save(transaction=transaction) + + c = City(id="SF", population=1) + c.save() + c.increment_population(increment=1) + assert c.population == 2 + + t = get_transaction() + decrement_population(transaction=t, city=c, decrement=5) + assert c.population == 0 + + +def test_delete_in_transaction(configure_db) -> None: + """ + Test deleting a model in a transaction. + + :param: configure_db: pytest fixture + """ + + @transactional + def delete_in_transaction(transaction: Transaction, profile_id: str) -> None: + """Deletes a Profile in a transaction.""" + profile = Profile.get_by_id(profile_id, transaction=transaction) + profile.delete(transaction=transaction) + + p = Profile(name="Foo") + p.save() + assert p.id + + t = get_transaction() + delete_in_transaction(t, p.id) + + with pytest.raises(ModelNotFoundError): + Profile.get_by_id(p.id) + + +def test_update_model_in_transaction(configure_db) -> None: + """ + Test updating a model in a transaction. + + :param: configure_db: pytest fixture + """ + + @transactional + def update_in_transaction( + transaction: Transaction, profile_id: str, name: str + ) -> None: + """Updates a Profile in a transaction.""" + profile = Profile(id=profile_id) + profile.reload(transaction=transaction) + profile.name = name + profile.save(transaction=transaction) + + p = Profile(name="Foo") + p.save() + + t = get_transaction() + update_in_transaction(t, p.id, name="Bar") + p.reload() + assert p.name == "Bar" + + +def test_update_submodel_in_transaction(configure_db) -> None: + """ + Test Updating a submodel in a transaction. + + :param: configure_db: pytest fixture + """ + + @transactional + def update_submodel_in_transaction( + transaction: Transaction, user_id: str, period: str + ) -> UserStats: + """Updates a UserStats in a transaction.""" + u = User.get_by_id(user_id, transaction=transaction) + us = UserStats.model_for(u) + user_stats: UserStats = us.get_by_id(period) # pylint: disable=no-member + user_stats.purchases += 1 + user_stats.save(transaction=transaction) + return user_stats + + u = User(name="Foo") + u.save() + assert u.id + us = UserStats.model_for(u) + us(id="2021", purchases=42).save() # pylint: disable=no-member + + t = get_transaction() + user_stats = update_submodel_in_transaction(t, u.id, "2021") + assert isinstance(user_stats, UserStats) + assert user_stats.purchases == 43 + assert get_user_purchases(u.id) == 43 diff --git a/pyproject.toml b/pyproject.toml index 36e2901..18e2cc6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "firedantic" -version = "0.10.0" +version = "0.11.0" description = "Pydantic base models for Firestore" authors = ["IOXIO Ltd"] license = "BSD-3-Clause" diff --git a/unasync.py b/unasync.py index e6ac175..ffa7887 100644 --- a/unasync.py +++ b/unasync.py @@ -5,6 +5,13 @@ from pathlib import Path SUBS = [ + ("get_async_transaction", "get_transaction"), + ( + "from google.cloud.firestore_v1.async_transaction", + "from google.cloud.firestore_v1.transaction", + ), + ("async_transactional", "transactional"), + ("AsyncTransaction", "Transaction"), ("google.cloud.firestore_v1.async_query", "google.cloud.firestore_v1.base_query"), ("AsyncQuery", "BaseQuery"), ("AsyncCollectionReference", "CollectionReference"),