From 6924531d190ec81acc893ee4664b91fc19f6e7fb Mon Sep 17 00:00:00 2001 From: Marissa Fisher Date: Tue, 26 Aug 2025 14:07:19 -0400 Subject: [PATCH 01/38] working model, not tested --- .gitignore | 1 + README.md | 31 +++++- firedantic/__init__.py | 1 + firedantic/_async/model.py | 4 +- firedantic/_sync/model.py | 6 +- firedantic/configurations.py | 95 +++++++++++++++++-- firedantic/tests/tests_async/test_model.py | 6 ++ firedantic/tests/tests_sync/test_model.py | 6 ++ multiple_clients_examples/billing_models.py | 19 ++++ multiple_clients_examples/company_models.py | 18 ++++ .../configure_firestore_db_client.py | 67 +++++++++++++ multiple_clients_examples/query_sample.py | 25 +++++ .../save_data_to_firestore.py | 84 ++++++++++++++++ 13 files changed, 344 insertions(+), 19 deletions(-) create mode 100644 multiple_clients_examples/billing_models.py create mode 100644 multiple_clients_examples/company_models.py create mode 100644 multiple_clients_examples/configure_firestore_db_client.py create mode 100644 multiple_clients_examples/query_sample.py create mode 100644 multiple_clients_examples/save_data_to_firestore.py diff --git a/.gitignore b/.gitignore index c02de79..1f28392 100644 --- a/.gitignore +++ b/.gitignore @@ -135,3 +135,4 @@ ui-debug.log .idea .DS_Store +.vscode/ diff --git a/README.md b/README.md index e80d025..9a9c532 100644 --- a/README.md +++ b/README.md @@ -475,12 +475,14 @@ if __name__ == "__main__": PRs are welcome! -To run tests locally, you should run: +### Prequisites: -```bash -poetry install -poetry run invoke test -``` +- [pipx](https://pipx.pypa.io/stable/installation/) (to install poetry, if needed) +- `pipx install poetry` +- `pip install pre-commit` +- [nvm](https://github.com/nvm-sh/nvm) +- `npm` (for firebase CLI installation) +- [java](http://www.java.com) via `brew install openjdk@21` ### Running Firestore emulator @@ -502,6 +504,25 @@ Run the Firestore emulator with a predictable port: start_emulator ``` +### Tests + +To run tests locally, you should first: + +Set the host: + +```bash +export FIRESTORE_EMULATOR_HOST= +``` + +Then run poetry: + +```bash +poetry install +poetry run invoke test +``` + +\*Note, the emulator must be set and running for all tests to pass. + ### About sync and async versions of library Although this library provides both sync and async versions of models, please keep in diff --git a/firedantic/__init__.py b/firedantic/__init__.py index e4f79d2..30a0bb8 100644 --- a/firedantic/__init__.py +++ b/firedantic/__init__.py @@ -34,6 +34,7 @@ from firedantic.common import collection_group_index, collection_index from firedantic.configurations import ( CONFIGURATIONS, + Configuration, configure, get_async_transaction, get_transaction, diff --git a/firedantic/_async/model.py b/firedantic/_async/model.py index 84251df..5b093df 100644 --- a/firedantic/_async/model.py +++ b/firedantic/_async/model.py @@ -50,11 +50,11 @@ def get_collection_name(cls, name: Optional[str]) -> str: if not name: raise CollectionNotDefined(f"Missing collection name for {cls.__name__}") - return f"{CONFIGURATIONS['prefix']}{name}" + return f"{CONFIGURATIONS[name].prefix}{name}" def _get_col_ref(cls, name: Optional[str]) -> AsyncCollectionReference: - collection: AsyncCollectionReference = CONFIGURATIONS["db"].collection( + collection: AsyncCollectionReference = CONFIGURATIONS[name].async_client.collection( get_collection_name(cls, name) ) return collection diff --git a/firedantic/_sync/model.py b/firedantic/_sync/model.py index 78ec372..fe684ad 100644 --- a/firedantic/_sync/model.py +++ b/firedantic/_sync/model.py @@ -50,11 +50,11 @@ def get_collection_name(cls, name: Optional[str]) -> str: if not name: raise CollectionNotDefined(f"Missing collection name for {cls.__name__}") - return f"{CONFIGURATIONS['prefix']}{name}" + return f"{CONFIGURATIONS[name].prefix}{name}" -def _get_col_ref(cls, name: Optional[str]) -> CollectionReference: - collection: CollectionReference = CONFIGURATIONS["db"].collection( +def _get_col_ref(cls, name: Optional[str] = "(default)") -> CollectionReference: + collection: CollectionReference = CONFIGURATIONS[name].client.collection( get_collection_name(cls, name) ) return collection diff --git a/firedantic/configurations.py b/firedantic/configurations.py index 2f9e277..5cc2eb7 100644 --- a/firedantic/configurations.py +++ b/firedantic/configurations.py @@ -1,35 +1,112 @@ -from typing import Any, Dict, Union +from os import environ +from typing import Dict, Optional, Union +from google.auth.credentials import Credentials from google.cloud.firestore_v1 import AsyncClient, AsyncTransaction, Client, Transaction +from pydantic import BaseModel -CONFIGURATIONS: Dict[str, Any] = {} +# for added support of multiple configurations/clients +class ConfigItem(BaseModel): + prefix: str + project: str + credentials: Optional[Credentials] = None + client: Optional[Client] = None + async_client: Optional[AsyncClient] = None -def configure(db: Union[Client, AsyncClient], prefix: str = "") -> None: + class Config: + arbitrary_types_allowed = True + + +# updating CONFIGURATIONS dict to be able to contain many/multiple configurations +CONFIGURATIONS: Dict[str, ConfigItem] = {} + + +def get_client(proj, creds) -> Union[Client, AsyncClient]: + # Firestore emulator must be running if using locally. + if environ.get("FIRESTORE_EMULATOR_HOST"): + client = Client(project=proj, credentials=creds) + else: + client = Client() + + return client + + +# Allow configure to work as it was for backwards compatibility. +def configure( + db: Union[Client, AsyncClient] = None, + prefix: str = "", +) -> None: """Configures the prefix and DB. :param db: The firestore client instance. :param prefix: The prefix to use for collection names. """ global CONFIGURATIONS + CONFIGURATIONS["(default)"] = ConfigItem - CONFIGURATIONS["db"] = db - CONFIGURATIONS["prefix"] = prefix + if isinstance(db, Client): + CONFIGURATIONS["(default)"].client = db + elif isinstance(db, AsyncClient): + CONFIGURATIONS["(default)"].async_client = db + # otherwise gets set to None by default + CONFIGURATIONS["(default)"].prefix = prefix def get_transaction() -> Transaction: """ - Get a new Firestore transaction for the configured DB. + Get a new Firestore transaction for the configured client. """ - transaction = CONFIGURATIONS["db"].transaction() + transaction = CONFIGURATIONS["(default)"].client.transaction() assert isinstance(transaction, Transaction) return transaction def get_async_transaction() -> AsyncTransaction: """ - Get a new Firestore transaction for the configured DB. + Get a new async Firestore transaction for the configured client. """ - transaction = CONFIGURATIONS["db"].transaction() + transaction = CONFIGURATIONS["(default)"].async_client.transaction() assert isinstance(transaction, AsyncTransaction) return transaction + + +class Configuration: + def __init__(self): + self.configurations: Dict[str, ConfigItem] = {} + + def create( + self, + name: str = "(default)", + prefix: str = "", + project: str = "", + credentials: Credentials = None, + ) -> None: + self.configurations[name] = ConfigItem( + prefix=prefix, + project=project, + credentials=credentials, + client=Client( + project=project, + credentials=credentials, + ), + async_client=AsyncClient( + project=project, + credentials=credentials, + ), + ) + # add to global CONFIGURATIONS + global CONFIGURATIONS + CONFIGURATIONS[name] = self.configurations[name] + + def get_client(self, name: str = "(default)") -> Client: + return self.configurations[name].client + + def get_async_client(self, name: str = "(default)") -> AsyncClient: + return self.configurations[name].async_client + + def get_transaction(self, name: str = "(default)") -> Transaction: + return self.get_client(name=name).transaction() + + def get_async_transaction(self, name: str = "(default)") -> AsyncTransaction: + return self.get_async_client(name=name).transaction() diff --git a/firedantic/tests/tests_async/test_model.py b/firedantic/tests/tests_async/test_model.py index ed27c0d..dc25a17 100644 --- a/firedantic/tests/tests_async/test_model.py +++ b/firedantic/tests/tests_async/test_model.py @@ -103,6 +103,12 @@ async def test_find(configure_db, create_company, create_product) -> None: d = await Company.find({"owner.first_name": "John"}) assert len(d) == 4 + d = await Company.find({"owner.first_name": {op.EQ: "John"}}) + assert len(d) == 4 + + d = await Company.find({"owner.first_name": {"==": "John"}}) + assert len(d) == 4 + for p in TEST_PRODUCTS: await create_product(**p) diff --git a/firedantic/tests/tests_sync/test_model.py b/firedantic/tests/tests_sync/test_model.py index 37e27a4..0a8af2b 100644 --- a/firedantic/tests/tests_sync/test_model.py +++ b/firedantic/tests/tests_sync/test_model.py @@ -97,6 +97,12 @@ def test_find(configure_db, create_company, create_product) -> None: d = Company.find({"owner.first_name": "John"}) assert len(d) == 4 + d = Company.find({"owner.first_name": {op.EQ: "John"}}) + assert len(d) == 4 + + d = Company.find({"owner.first_name": {"==": "John"}}) + assert len(d) == 4 + for p in TEST_PRODUCTS: create_product(**p) diff --git a/multiple_clients_examples/billing_models.py b/multiple_clients_examples/billing_models.py new file mode 100644 index 0000000..2ab2d43 --- /dev/null +++ b/multiple_clients_examples/billing_models.py @@ -0,0 +1,19 @@ +from pydantic import BaseModel + +from firedantic import Model + + +class BillingAccount(BaseModel): + """Dummy billing account Pydantic model.""" + + name: str + billing_id: str + owner: str + + +class BillingCompany(Model): + """Dummy company Firedantic model.""" + + __collection__ = "billing_companies" + company_id: str + billing_account: BillingAccount diff --git a/multiple_clients_examples/company_models.py b/multiple_clients_examples/company_models.py new file mode 100644 index 0000000..fe0a46d --- /dev/null +++ b/multiple_clients_examples/company_models.py @@ -0,0 +1,18 @@ +from pydantic import BaseModel + +from firedantic import Model + + +class Owner(BaseModel): + """Dummy owner Pydantic model.""" + + first_name: str + last_name: str + + +class Company(Model): + """Dummy company Firedantic model.""" + + __collection__ = "companies" + company_id: str + owner: Owner diff --git a/multiple_clients_examples/configure_firestore_db_client.py b/multiple_clients_examples/configure_firestore_db_client.py new file mode 100644 index 0000000..32eb206 --- /dev/null +++ b/multiple_clients_examples/configure_firestore_db_client.py @@ -0,0 +1,67 @@ +from os import environ +from unittest.mock import Mock + +import google.auth.credentials +from google.cloud.firestore import AsyncClient, Client + +from firedantic import Configuration, configure + + +## OLD WAY: +def configure_client(): + # Firestore emulator must be running if using locally. + if environ.get("FIRESTORE_EMULATOR_HOST"): + client = Client( + project="firedantic-test", + credentials=Mock(spec=google.auth.credentials.Credentials), + ) + else: + client = Client() + + configure(client, prefix="firedantic-test-") + print(client) + + +def configure_async_client(): + # Firestore emulator must be running if using locally. + if environ.get("FIRESTORE_EMULATOR_HOST"): + client = AsyncClient( + project="firedantic-test", + credentials=Mock(spec=google.auth.credentials.Credentials), + ) + else: + client = AsyncClient() + + configure(client, prefix="firedantic-test-") + print(client) + + +## NEW WAY: +def configure_multiple_clients(): + config = Configuration() + + # name = (default) + config.create( + prefix="firedantic-test-", + project="firedantic-test", + credentials=Mock(spec=google.auth.credentials.Credentials), + ) + + # name = billing + config.create( + name="billing", + prefix="test-billing-", + project="test-billing", + credentials=Mock(spec=google.auth.credentials.Credentials), + ) + + print(config.get_client()) ## will pull the default client + print(config.get_client("billing")) ## will pull the billing client + print(config.get_async_client("billing")) ## will pull the billing async client + + +# print("\n---- Running OLD way ----") +# configure_client() +# configure_async_client() +# print("\n---- Running NEW way ----") +# configure_multiple_clients() diff --git a/multiple_clients_examples/query_sample.py b/multiple_clients_examples/query_sample.py new file mode 100644 index 0000000..7661963 --- /dev/null +++ b/multiple_clients_examples/query_sample.py @@ -0,0 +1,25 @@ +from company_models import Company +from configure_firestore_db_client import configure_client, configure_multiple_clients + +import firedantic.operators as op + +# Firestore emulator must be running if using locally. +# Supported types are: <, >, array_contains, in, ==, !=, not-in, <=, array_contains_any, >= + + +print("\n---- Running OLD way ----") +configure_client() + +companies1 = Company.find({"owner.first_name": "John"}) +companies2 = Company.find({"owner.first_name": {op.EQ: "John"}}) +companies3 = Company.find({"owner.first_name": {"==": "John"}}) +assert companies1 == companies2 == companies3 + + +print("\n---- Running NEW way ----") +configure_multiple_clients() + +companies1 = Company.find({"owner.first_name": "Alice"}) +companies2 = Company.find({"owner.first_name": {op.EQ: "Alice"}}) +companies3 = Company.find({"owner.first_name": {"==": "Alice"}}) +assert companies1 == companies2 == companies3 diff --git a/multiple_clients_examples/save_data_to_firestore.py b/multiple_clients_examples/save_data_to_firestore.py new file mode 100644 index 0000000..fd48dae --- /dev/null +++ b/multiple_clients_examples/save_data_to_firestore.py @@ -0,0 +1,84 @@ +from unittest.mock import Mock + +import google.auth.credentials +from billing_models import BillingAccount, BillingCompany +from company_models import Company, Owner +from configure_firestore_db_client import configure_client + +from firedantic import Configuration + + +## With single client +def old_way(): + # Firestore emulator must be running if using locally. + configure_client() + + # Now you can use the model to save it to Firestore + owner = Owner(first_name="John", last_name="Doe") + company = Company(company_id="1234567-8", owner=owner) + company.save() + + # Prints out the firestore ID of the Company model + print(f"\nFirestore ID: {company.id}") + + # Reloads model data from the database + company.reload() + + +## Now with multiple clients/dbs: +def new_way(): + config = Configuration() + + # 1. Create first config with config_name = "(default)" + config.create( + prefix="firedantic-test-", + project="firedantic-test", + credentials=Mock(spec=google.auth.credentials.Credentials), + ) + + # Now you can use the model to save it to Firestore + owner = Owner(first_name="Alice", last_name="Begone") + company = Company(company_id="1234567-9", owner=owner) + company.save() # will use 'default' as config name + + # Reloads model data from the database + company.reload() # with no name supplied, config refers to "(default)" + + # 2. Create the second config with config_name = "billing" + config_name = "billing" + config.create( + name=config_name, + prefix="test-billing-", + project="test-billing", + credentials=Mock(spec=google.auth.credentials.Credentials), + ) + + # Now you can use the model to save it to Firestore + account = BillingAccount(name="ABC Billing", billing_id="801048", owner="MFisher") + bc = BillingCompany(company_id="1234567-8", billing_account=account) + + bc.save(config_name) + + # Reloads model data from the database + bc.reload(config_name) # with config name supplied, config refers to "billing" + + # 3. Finding data + # Can retrieve info from either database/client + # When config is not specified, it will default to '(default)' config + # The models do not know which config you intended to use them for, and they + # could be used for a multitude of configurations at once. + print(Company.find({"owner.first_name": "Alice"})) + + print(BillingCompany.find({"company_id": "1234567-8"}, config=config_name)) + + print( + BillingCompany.find( + {"billing_account.billing_id": "801048"}, config=config_name + ) + ) + + +print("\n---- Running OLD way ----") +old_way() +print("\n---- Running NEW way ----") +new_way() From cfd2d550bd2ceb41a37da2e1bd2452def2166a46 Mon Sep 17 00:00:00 2001 From: Marissa Fisher Date: Tue, 26 Aug 2025 15:35:53 -0400 Subject: [PATCH 02/38] save_data_to_firestore working again --- README.md | 16 ++++++- firedantic/_async/model.py | 10 ++--- firedantic/_sync/helpers.py | 4 +- firedantic/_sync/model.py | 42 +++++++++++-------- firedantic/configurations.py | 10 ++--- firedantic/tests/tests_sync/conftest.py | 4 +- firedantic/tests/tests_sync/test_indexes.py | 15 ++++++- firedantic/tests/tests_sync/test_model.py | 35 +++++++++++++++- .../tests/tests_sync/test_ttl_policy.py | 1 + 9 files changed, 103 insertions(+), 34 deletions(-) diff --git a/README.md b/README.md index 9a9c532..79e9955 100644 --- a/README.md +++ b/README.md @@ -514,13 +514,27 @@ Set the host: export FIRESTORE_EMULATOR_HOST= ``` -Then run poetry: +Setup a virtual environment(not required but will be easier to run poetry like so): + +```bash +python -m venv firedantic +source firedantic/bin/activate +cd firedantic/tests +``` + +Then run poetry within `tests` directory: ```bash poetry install poetry run invoke test ``` +to run a singular test: + +```bash +poetry run pytest tests_sync/.py:: +``` + \*Note, the emulator must be set and running for all tests to pass. ### About sync and async versions of library diff --git a/firedantic/_async/model.py b/firedantic/_async/model.py index 5b093df..f2fc2b8 100644 --- a/firedantic/_async/model.py +++ b/firedantic/_async/model.py @@ -41,7 +41,7 @@ } -def get_collection_name(cls, name: Optional[str]) -> str: +def get_collection_name(cls, name: Optional[str], config: str = "(default)") -> str: """ Returns the collection name. @@ -50,12 +50,12 @@ def get_collection_name(cls, name: Optional[str]) -> str: if not name: raise CollectionNotDefined(f"Missing collection name for {cls.__name__}") - return f"{CONFIGURATIONS[name].prefix}{name}" + return f"{CONFIGURATIONS[config].prefix}{name}" -def _get_col_ref(cls, name: Optional[str]) -> AsyncCollectionReference: - collection: AsyncCollectionReference = CONFIGURATIONS[name].async_client.collection( - get_collection_name(cls, name) +def _get_col_ref(cls, name: Optional[str], config: str = "(default)") -> AsyncCollectionReference: + collection: AsyncCollectionReference = CONFIGURATIONS[config].async_client.collection( + get_collection_name(cls, name, config) ) return collection diff --git a/firedantic/_sync/helpers.py b/firedantic/_sync/helpers.py index 57dfe99..3a5f92c 100644 --- a/firedantic/_sync/helpers.py +++ b/firedantic/_sync/helpers.py @@ -1,7 +1,9 @@ from google.cloud.firestore_v1 import CollectionReference -def truncate_collection(col_ref: CollectionReference, batch_size: int = 128) -> int: +def truncate_collection( + col_ref: CollectionReference, batch_size: int = 128 +) -> int: """ Removes all documents inside a collection. diff --git a/firedantic/_sync/model.py b/firedantic/_sync/model.py index fe684ad..9f7aa49 100644 --- a/firedantic/_sync/model.py +++ b/firedantic/_sync/model.py @@ -41,7 +41,7 @@ } -def get_collection_name(cls, name: Optional[str]) -> str: +def get_collection_name(cls, name: Optional[str], config: str = "(default)") -> str: """ Returns the collection name. @@ -50,12 +50,12 @@ def get_collection_name(cls, name: Optional[str]) -> str: if not name: raise CollectionNotDefined(f"Missing collection name for {cls.__name__}") - return f"{CONFIGURATIONS[name].prefix}{name}" + return f"{CONFIGURATIONS[config].prefix}{name}" -def _get_col_ref(cls, name: Optional[str] = "(default)") -> CollectionReference: - collection: CollectionReference = CONFIGURATIONS[name].client.collection( - get_collection_name(cls, name) +def _get_col_ref(cls, name: Optional[str], config: str = "(default)") -> CollectionReference: + collection: CollectionReference = CONFIGURATIONS[config].client.collection( + get_collection_name(cls, name, config) ) return collection @@ -74,10 +74,12 @@ class BareModel(pydantic.BaseModel, ABC): def save( self, + config: str = "(default)", *, exclude_unset: bool = False, exclude_none: bool = False, transaction: Optional[Transaction] = None, + ) -> None: """ Saves this model in the database. @@ -93,7 +95,7 @@ def save( if self.__document_id__ in data: del data[self.__document_id__] - doc_ref = self._get_doc_ref() + doc_ref = self._get_doc_ref(config) if transaction is not None: transaction.set(doc_ref, data) else: @@ -111,7 +113,7 @@ def delete(self, transaction: Optional[Transaction] = None) -> None: else: self._get_doc_ref().delete() - def reload(self, transaction: Optional[Transaction] = None) -> None: + def reload(self, config: str = "(default)", transaction: Optional[Transaction] = None) -> None: """ Reloads this model from the database. @@ -122,7 +124,7 @@ def reload(self, transaction: Optional[Transaction] = None) -> None: if doc_id is None: raise ModelNotFoundError("Can not reload unsaved model") - updated_model = self.get_by_doc_id(doc_id, transaction=transaction) + updated_model = self.get_by_doc_id(doc_id, config=config, transaction=transaction) updated_model_doc_id = updated_model.__dict__[self.__document_id__] assert doc_id == updated_model_doc_id @@ -149,6 +151,7 @@ def find( # pylint: disable=too-many-arguments limit: Optional[int] = None, offset: Optional[int] = None, transaction: Optional[Transaction] = None, + config: str = "(default)" ) -> List[TBareModel]: """ Returns a list of models from the database based on a filter. @@ -167,7 +170,7 @@ def find( # pylint: disable=too-many-arguments :param transaction: Optional transaction to use. :return: List of found models. """ - query: Union[BaseQuery, CollectionReference] = cls._get_col_ref() + query: Union[BaseQuery, CollectionReference] = cls._get_col_ref(config) if filter_: for key, value in filter_.items(): query = cls._add_filter(query, key, value) @@ -236,7 +239,9 @@ def find_one( :return: The model instance. :raise ModelNotFoundError: If the entry is not found. """ - model = cls.find(filter_, limit=1, order_by=order_by, transaction=transaction) + model = cls.find( + filter_, limit=1, order_by=order_by, transaction=transaction + ) try: return model[0] except IndexError as e: @@ -246,6 +251,7 @@ def find_one( def get_by_doc_id( cls: Type[TBareModel], doc_id: str, + config: str = "(default)", transaction: Optional[Transaction] = None, ) -> TBareModel: """ @@ -270,7 +276,7 @@ def get_by_doc_id( ) from e document: DocumentSnapshot = ( - cls._get_col_ref().document(doc_id).get(transaction=transaction) + cls._get_col_ref(config).document(doc_id).get(transaction=transaction) ) # type: ignore data = document.to_dict() if data is None: @@ -296,11 +302,11 @@ def truncate_collection(cls, batch_size: int = 128) -> int: ) @classmethod - def _get_col_ref(cls) -> CollectionReference: + def _get_col_ref(cls, config: str = "(default)") -> CollectionReference: """ Returns the collection reference. """ - return _get_col_ref(cls, cls.__collection__) + return _get_col_ref(cls, cls.__collection__, config) @classmethod def get_collection_name(cls) -> str: @@ -309,13 +315,13 @@ def get_collection_name(cls) -> str: """ return get_collection_name(cls, cls.__collection__) - def _get_doc_ref(self) -> DocumentReference: + def _get_doc_ref(self, config: str = "(default)") -> DocumentReference: """ Returns the document reference. :raise DocumentIDError: If the ID is not valid. """ - return self._get_col_ref().document(self.get_document_id()) # type: ignore + return self._get_col_ref(config).document(self.get_document_id()) # type: ignore @staticmethod def _validate_document_id(document_id: str): @@ -359,6 +365,7 @@ class Model(BareModel): def get_by_id( cls: Type[TBareModel], id_: str, + config: str = "(default)", transaction: Optional[Transaction] = None, ) -> TBareModel: """ @@ -368,7 +375,7 @@ def get_by_id( :param transaction: Optional transaction to use. :raises ModelNotFoundError: If no model was found by given id. """ - return cls.get_by_doc_id(id_, transaction=transaction) + return cls.get_by_doc_id(id_, config=config, transaction=transaction) class BareSubCollection(ABC): @@ -432,6 +439,7 @@ class SubModel(BareSubModel): def get_by_id( cls: Type[TBareModel], id_: str, + config: str = "(default)", transaction: Optional[Transaction] = None, ) -> TBareModel: """ @@ -441,7 +449,7 @@ def get_by_id( :param transaction: Optional transaction to use. :raises ModelNotFoundError: """ - return cls.get_by_doc_id(id_, transaction=transaction) + return cls.get_by_doc_id(id_, config=config, transaction=transaction) class SubCollection(BareSubCollection, ABC): diff --git a/firedantic/configurations.py b/firedantic/configurations.py index 5cc2eb7..6898f02 100644 --- a/firedantic/configurations.py +++ b/firedantic/configurations.py @@ -32,7 +32,7 @@ def get_client(proj, creds) -> Union[Client, AsyncClient]: return client -# Allow configure to work as it was for backwards compatibility. +# Allow configure method to work as it was for backwards compatibility. def configure( db: Union[Client, AsyncClient] = None, prefix: str = "", @@ -53,20 +53,20 @@ def configure( CONFIGURATIONS["(default)"].prefix = prefix -def get_transaction() -> Transaction: +def get_transaction(config: str = "(default)") -> Transaction: """ Get a new Firestore transaction for the configured client. """ - transaction = CONFIGURATIONS["(default)"].client.transaction() + transaction = CONFIGURATIONS[config].client.transaction() assert isinstance(transaction, Transaction) return transaction -def get_async_transaction() -> AsyncTransaction: +def get_async_transaction(config: str = "(default)") -> AsyncTransaction: """ Get a new async Firestore transaction for the configured client. """ - transaction = CONFIGURATIONS["(default)"].async_client.transaction() + transaction = CONFIGURATIONS[config].async_client.transaction() assert isinstance(transaction, AsyncTransaction) return transaction diff --git a/firedantic/tests/tests_sync/conftest.py b/firedantic/tests/tests_sync/conftest.py index 8debfc4..410769d 100644 --- a/firedantic/tests/tests_sync/conftest.py +++ b/firedantic/tests/tests_sync/conftest.py @@ -195,7 +195,9 @@ def _create( @pytest.fixture def create_product(): - def _create(product_id: Optional[str] = None, price: float = 1.23, stock: int = 3): + def _create( + product_id: Optional[str] = None, price: float = 1.23, stock: int = 3 + ): if not product_id: product_id = str(uuid.uuid4()) p = Product(product_id=product_id, price=price, stock=stock) diff --git a/firedantic/tests/tests_sync/test_indexes.py b/firedantic/tests/tests_sync/test_indexes.py index c4fb399..28d5e21 100644 --- a/firedantic/tests/tests_sync/test_indexes.py +++ b/firedantic/tests/tests_sync/test_indexes.py @@ -26,6 +26,7 @@ class BaseModelWithIndexes(Model): age: int + def test_set_up_composite_index(mock_admin_client) -> None: class ModelWithIndexes(BaseModelWithIndexes): __composite_indexes__ = ( @@ -58,6 +59,7 @@ class ModelWithIndexes(BaseModelWithIndexes): assert index.fields[1].order.name == Query.DESCENDING + def test_set_up_collection_group_index(mock_admin_client) -> None: class ModelWithIndexes(BaseModelWithIndexes): __composite_indexes__ = ( @@ -86,6 +88,7 @@ class ModelWithIndexes(BaseModelWithIndexes): assert len(index.fields) == 2 + def test_set_up_composite_indexes_and_policies(mock_admin_client) -> None: class ModelWithIndexes(BaseModelWithIndexes): __composite_indexes__ = ( @@ -109,6 +112,7 @@ class ModelWithIndexes(BaseModelWithIndexes): assert len(call_list) == 1 + def test_set_up_many_composite_indexes(mock_admin_client) -> None: class ModelWithIndexes(BaseModelWithIndexes): __composite_indexes__ = ( @@ -135,6 +139,7 @@ class ModelWithIndexes(BaseModelWithIndexes): assert len(result) == 3 + def test_set_up_indexes_model_without_indexes(mock_admin_client) -> None: class ModelWithoutIndexes(Model): __collection__ = "modelWithoutIndexes" @@ -152,6 +157,7 @@ class ModelWithoutIndexes(Model): assert len(call_list) == 0 + def test_existing_indexes_are_skipped(mock_admin_client) -> None: resp = ListIndexesResponse( { @@ -183,7 +189,9 @@ def test_existing_indexes_are_skipped(mock_admin_client) -> None: ] } ) - mock_admin_client.list_indexes = Mock(return_value=MockListIndexOperation([resp])) + mock_admin_client.list_indexes = Mock( + return_value=MockListIndexOperation([resp]) + ) class ModelWithIndexes(BaseModelWithIndexes): __composite_indexes__ = ( @@ -205,6 +213,7 @@ class ModelWithIndexes(BaseModelWithIndexes): assert len(result) == 0 + def test_same_fields_in_another_collection(mock_admin_client) -> None: # Test that when another collection has an index with exactly the same fields, # it won't affect creating an index in the target collection @@ -226,7 +235,9 @@ def test_same_fields_in_another_collection(mock_admin_client) -> None: ] } ) - mock_admin_client.list_indexes = Mock(return_value=MockListIndexOperation([resp])) + mock_admin_client.list_indexes = Mock( + return_value=MockListIndexOperation([resp]) + ) class ModelWithIndexes(BaseModelWithIndexes): __composite_indexes__ = ( diff --git a/firedantic/tests/tests_sync/test_model.py b/firedantic/tests/tests_sync/test_model.py index 0a8af2b..24231d1 100644 --- a/firedantic/tests/tests_sync/test_model.py +++ b/firedantic/tests/tests_sync/test_model.py @@ -35,6 +35,7 @@ ] + def test_save_model(configure_db, create_company) -> None: company = create_company() @@ -43,6 +44,7 @@ def test_save_model(configure_db, create_company) -> None: assert company.owner.last_name == "Doe" + def test_delete_model(configure_db, create_company) -> None: company: Company = create_company( company_id="11223344-5", first_name="Jane", last_name="Doe" @@ -57,6 +59,7 @@ def test_delete_model(configure_db, create_company) -> None: Company.get_by_id(_id) + def test_find_one(configure_db, create_company) -> None: with pytest.raises(ModelNotFoundError): Company.find_one() @@ -217,6 +220,7 @@ def test_find_order_by(configure_db, create_company) -> None: assert companies_and_owners == lastname_ascending_firstname_ascending + def test_find_offset(configure_db, create_company) -> None: ids_and_lastnames = ( ("1234555-1", "A"), @@ -234,6 +238,7 @@ def test_find_offset(configure_db, create_company) -> None: assert len(companies_ascending) == 2 + def test_get_by_id(configure_db, create_company) -> None: c: Company = create_company(company_id="1234567-8") @@ -248,11 +253,13 @@ def test_get_by_id(configure_db, create_company) -> None: assert c_2.owner.first_name == "John" + def test_get_by_empty_str_id(configure_db) -> None: with pytest.raises(ModelNotFoundError): Company.get_by_id("") + def test_missing_collection(configure_db) -> None: class User(Model): name: str @@ -261,6 +268,7 @@ class User(Model): User(name="John").save() + def test_model_aliases(configure_db) -> None: class User(Model): __collection__ = "User" @@ -277,6 +285,7 @@ class User(Model): assert user_from_db.city == "Helsinki" + @pytest.mark.parametrize( "model_id", [ @@ -318,6 +327,7 @@ def test_models_with_valid_custom_id(configure_db, model_id) -> None: found.delete() + @pytest.mark.parametrize( "model_id", [ @@ -343,6 +353,7 @@ def test_models_with_invalid_custom_id(configure_db, model_id: str) -> None: Product.get_by_id(model_id) + def test_truncate_collection(configure_db, create_company) -> None: create_company(company_id="1234567-8") create_company(company_id="1234567-9") @@ -355,6 +366,7 @@ def test_truncate_collection(configure_db, create_company) -> None: assert len(new_companies) == 0 + def test_custom_id_model(configure_db) -> None: c = CustomIDModel(bar="bar") # type: ignore c.save() @@ -367,6 +379,7 @@ def test_custom_id_model(configure_db) -> None: assert m.bar == "bar" + def test_custom_id_conflict(configure_db) -> None: CustomIDConflictModel(foo="foo", bar="bar").save() @@ -378,6 +391,7 @@ def test_custom_id_conflict(configure_db) -> None: assert m.bar == "bar" + def test_model_id_persistency(configure_db) -> None: c = CustomIDConflictModel(foo="foo", bar="bar") c.save() @@ -389,6 +403,7 @@ def test_model_id_persistency(configure_db) -> None: assert len(CustomIDConflictModel.find({})) == 1 + def test_bare_model_document_id_persistency(configure_db) -> None: c = CustomIDModel(bar="bar") # type: ignore c.save() @@ -400,17 +415,20 @@ def test_bare_model_document_id_persistency(configure_db) -> None: assert len(CustomIDModel.find({})) == 1 + def test_bare_model_get_by_empty_doc_id(configure_db) -> None: with pytest.raises(ModelNotFoundError): CustomIDModel.get_by_doc_id("") + def test_extra_fields(configure_db) -> None: CustomIDModelExtra(foo="foo", bar="bar", baz="baz").save() # type: ignore with pytest.raises(ValidationError): CustomIDModel.find({}) + def test_company_stats(configure_db, create_company) -> None: company: Company = create_company(company_id="1234567-8") company_stats = company.stats() @@ -431,6 +449,7 @@ def test_company_stats(configure_db, create_company) -> None: assert stats.sales == 101 + def test_subcollection_model_safety(configure_db) -> None: """ Ensure you shouldn't be able to use unprepared subcollection models accidentally @@ -439,6 +458,7 @@ def test_subcollection_model_safety(configure_db) -> None: UserStats.find({}) + def test_get_user_purchases(configure_db) -> None: u = User(name="Foo") u.save() @@ -450,6 +470,7 @@ def test_get_user_purchases(configure_db) -> None: assert get_user_purchases(u.id) == 42 + def test_reload(configure_db) -> None: u = User(name="Foo") u.save() @@ -468,6 +489,7 @@ def test_reload(configure_db) -> None: another_user.reload() + def test_save_with_exclude_none(configure_db) -> None: p = Profile(name="Foo") p.save(exclude_none=True) @@ -489,6 +511,7 @@ def test_save_with_exclude_none(configure_db) -> None: assert data == {"name": "Foo", "photo_url": None} + def test_save_with_exclude_unset(configure_db) -> None: p = Profile(photo_url=None) p.save(exclude_unset=True) @@ -510,6 +533,7 @@ def test_save_with_exclude_unset(configure_db) -> None: assert data == {"name": "", "photo_url": None} + def test_update_city_in_transaction(configure_db) -> None: """ Test updating a model in a transaction. Test case from README. @@ -518,7 +542,9 @@ def test_update_city_in_transaction(configure_db) -> None: """ @transactional - def decrement_population(transaction: Transaction, city: City, decrement: int = 1): + def decrement_population( + transaction: Transaction, city: City, decrement: int = 1 + ): city.reload(transaction=transaction) city.population = max(0, city.population - decrement) city.save(transaction=transaction) @@ -533,6 +559,7 @@ def decrement_population(transaction: Transaction, city: City, decrement: int = assert c.population == 0 + def test_delete_in_transaction(configure_db) -> None: """ Test deleting a model in a transaction. @@ -541,7 +568,9 @@ def test_delete_in_transaction(configure_db) -> None: """ @transactional - def delete_in_transaction(transaction: Transaction, profile_id: str) -> None: + def delete_in_transaction( + transaction: Transaction, profile_id: str + ) -> None: """Deletes a Profile in a transaction.""" profile = Profile.get_by_id(profile_id, transaction=transaction) profile.delete(transaction=transaction) @@ -557,6 +586,7 @@ def delete_in_transaction(transaction: Transaction, profile_id: str) -> None: Profile.get_by_id(p.id) + def test_update_model_in_transaction(configure_db) -> None: """ Test updating a model in a transaction. @@ -583,6 +613,7 @@ def update_in_transaction( assert p.name == "Bar" + def test_update_submodel_in_transaction(configure_db) -> None: """ Test Updating a submodel in a transaction. diff --git a/firedantic/tests/tests_sync/test_ttl_policy.py b/firedantic/tests/tests_sync/test_ttl_policy.py index 8d2d549..bcccf2b 100644 --- a/firedantic/tests/tests_sync/test_ttl_policy.py +++ b/firedantic/tests/tests_sync/test_ttl_policy.py @@ -26,6 +26,7 @@ def test_set_up_ttl_policies_new_policy(mock_admin_client): [Field.TtlConfig.State.NEEDS_REPAIR], ), ) + def test_set_up_ttl_policies_other_states(mock_admin_client, state): mock_admin_client.field_state = Field.TtlConfig.State.ACTIVE result = set_up_ttl_policies( From 2ad1b88ee711cd2d41787d8e681bb2ee35cb7307 Mon Sep 17 00:00:00 2001 From: Marissa Fisher Date: Tue, 26 Aug 2025 20:25:34 -0400 Subject: [PATCH 03/38] fixed sync model tests --- firedantic/_sync/model.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/firedantic/_sync/model.py b/firedantic/_sync/model.py index 9f7aa49..909bc1c 100644 --- a/firedantic/_sync/model.py +++ b/firedantic/_sync/model.py @@ -170,6 +170,8 @@ def find( # pylint: disable=too-many-arguments :param transaction: Optional transaction to use. :return: List of found models. """ + print(f'\nconfig is: {config}') + query: Union[BaseQuery, CollectionReference] = cls._get_col_ref(config) if filter_: for key, value in filter_.items(): @@ -413,7 +415,7 @@ def _create(cls: Type[TBareSubModel], **kwargs) -> TBareSubModel: ) @classmethod - def _get_col_ref(cls) -> CollectionReference: + def _get_col_ref(cls, config: str = "(default)") -> CollectionReference: """ Returns the collection reference. """ @@ -422,7 +424,7 @@ def _get_col_ref(cls) -> CollectionReference: f"{cls.__name__} is not properly prepared. " f"You should use {cls.__name__}.model_for(parent)" ) - return _get_col_ref(cls.__collection_cls__, cls.__collection__) + return _get_col_ref(cls.__collection_cls__, cls.__collection__, config) @classmethod def model_for(cls, parent): From cb4918e4439d9d40f2f79fa79871b1d9a9b94554 Mon Sep 17 00:00:00 2001 From: Marissa Fisher Date: Tue, 26 Aug 2025 20:34:32 -0400 Subject: [PATCH 04/38] all existing tests passing --- firedantic/tests/tests_async/test_indexes.py | 10 +++++----- firedantic/tests/tests_sync/test_indexes.py | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/firedantic/tests/tests_async/test_indexes.py b/firedantic/tests/tests_async/test_indexes.py index 5c0d3e9..d1bef9c 100644 --- a/firedantic/tests/tests_async/test_indexes.py +++ b/firedantic/tests/tests_async/test_indexes.py @@ -48,7 +48,7 @@ class ModelWithIndexes(BaseModelWithIndexes): path = call_list[0][1]["request"].parent assert ( path - == f"projects/proj/databases/(default)/collectionGroups/{CONFIGURATIONS['prefix']}modelWithIndexes" + == f"projects/proj/databases/(default)/collectionGroups/{CONFIGURATIONS["(default)"].prefix}modelWithIndexes" ) index = call_list[0][1]["request"].index assert index.query_scope.name == "COLLECTION" @@ -81,7 +81,7 @@ class ModelWithIndexes(BaseModelWithIndexes): path = call_list[0][1]["request"].parent assert ( path - == f"projects/proj/databases/(default)/collectionGroups/{CONFIGURATIONS['prefix']}modelWithIndexes" + == f"projects/proj/databases/(default)/collectionGroups/{CONFIGURATIONS["(default)"].prefix}modelWithIndexes" ) index = call_list[0][1]["request"].index assert index.query_scope.name == "COLLECTION_GROUP" @@ -165,7 +165,7 @@ async def test_existing_indexes_are_skipped(mock_admin_client) -> None: { "name": ( "projects/fake-project/databases/(default)/collectionGroups/" - f"{CONFIGURATIONS['prefix']}modelWithIndexes/123456" + f"{CONFIGURATIONS["(default)"].prefix}modelWithIndexes/123456" ), "query_scope": "COLLECTION", "fields": [ @@ -177,7 +177,7 @@ async def test_existing_indexes_are_skipped(mock_admin_client) -> None: { "name": ( "projects/fake-project/databases/(default)/collectionGroups/" - f"{CONFIGURATIONS['prefix']}modelWithIndexes/67889" + f"{CONFIGURATIONS["(default)"].prefix}modelWithIndexes/67889" ), "query_scope": "COLLECTION", "fields": [ @@ -223,7 +223,7 @@ async def test_same_fields_in_another_collection(mock_admin_client) -> None: { "name": ( "projects/fake-project/databases/(default)/collectionGroups/" - f"{CONFIGURATIONS['prefix']}anotherModel/123456" + f"{CONFIGURATIONS["(default)"].prefix}anotherModel/123456" ), "query_scope": "COLLECTION", "fields": [ diff --git a/firedantic/tests/tests_sync/test_indexes.py b/firedantic/tests/tests_sync/test_indexes.py index 28d5e21..2673c4c 100644 --- a/firedantic/tests/tests_sync/test_indexes.py +++ b/firedantic/tests/tests_sync/test_indexes.py @@ -48,7 +48,7 @@ class ModelWithIndexes(BaseModelWithIndexes): path = call_list[0][1]["request"].parent assert ( path - == f"projects/proj/databases/(default)/collectionGroups/{CONFIGURATIONS['prefix']}modelWithIndexes" + == f"projects/proj/databases/(default)/collectionGroups/{CONFIGURATIONS["(default)"].prefix}modelWithIndexes" ) index = call_list[0][1]["request"].index assert index.query_scope.name == "COLLECTION" @@ -81,7 +81,7 @@ class ModelWithIndexes(BaseModelWithIndexes): path = call_list[0][1]["request"].parent assert ( path - == f"projects/proj/databases/(default)/collectionGroups/{CONFIGURATIONS['prefix']}modelWithIndexes" + == f"projects/proj/databases/(default)/collectionGroups/{CONFIGURATIONS["(default)"].prefix}modelWithIndexes" ) index = call_list[0][1]["request"].index assert index.query_scope.name == "COLLECTION_GROUP" @@ -165,7 +165,7 @@ def test_existing_indexes_are_skipped(mock_admin_client) -> None: { "name": ( "projects/fake-project/databases/(default)/collectionGroups/" - f"{CONFIGURATIONS['prefix']}modelWithIndexes/123456" + f"{CONFIGURATIONS["(default)"].prefix}modelWithIndexes/123456" ), "query_scope": "COLLECTION", "fields": [ @@ -177,7 +177,7 @@ def test_existing_indexes_are_skipped(mock_admin_client) -> None: { "name": ( "projects/fake-project/databases/(default)/collectionGroups/" - f"{CONFIGURATIONS['prefix']}modelWithIndexes/67889" + f"{CONFIGURATIONS["(default)"].prefix}modelWithIndexes/67889" ), "query_scope": "COLLECTION", "fields": [ @@ -223,7 +223,7 @@ def test_same_fields_in_another_collection(mock_admin_client) -> None: { "name": ( "projects/fake-project/databases/(default)/collectionGroups/" - f"{CONFIGURATIONS['prefix']}anotherModel/123456" + f"{CONFIGURATIONS["(default)"].prefix}anotherModel/123456" ), "query_scope": "COLLECTION", "fields": [ From f64cee926c9d66ccb146bff41260b1f8d23e5312 Mon Sep 17 00:00:00 2001 From: Marissa Fisher Date: Tue, 26 Aug 2025 22:53:17 -0400 Subject: [PATCH 05/38] all 161 test cases passing with pytest --- firedantic/__init__.py | 1 + firedantic/_sync/model.py | 17 +- firedantic/configurations.py | 4 +- firedantic/tests/tests_async/conftest.py | 2 +- firedantic/tests/tests_async/test_model.py | 76 +++--- firedantic/tests/tests_sync/conftest.py | 35 ++- firedantic/tests/tests_sync/test_model.py | 257 +++++++++++++++++---- 7 files changed, 292 insertions(+), 100 deletions(-) diff --git a/firedantic/__init__.py b/firedantic/__init__.py index 30a0bb8..cc67a70 100644 --- a/firedantic/__init__.py +++ b/firedantic/__init__.py @@ -36,6 +36,7 @@ CONFIGURATIONS, Configuration, configure, + ConfigItem, get_async_transaction, get_transaction, ) diff --git a/firedantic/_sync/model.py b/firedantic/_sync/model.py index 909bc1c..3cf20e2 100644 --- a/firedantic/_sync/model.py +++ b/firedantic/_sync/model.py @@ -102,16 +102,16 @@ def save( doc_ref.set(data) setattr(self, self.__document_id__, doc_ref.id) - def delete(self, transaction: Optional[Transaction] = None) -> None: + def delete(self, config: str = "(default)", transaction: Optional[Transaction] = None) -> None: """ Deletes this model from the database. :raise DocumentIDError: If the ID is not valid. """ if transaction is not None: - transaction.delete(self._get_doc_ref()) + transaction.delete(self._get_doc_ref(config)) else: - self._get_doc_ref().delete() + self._get_doc_ref(config).delete() def reload(self, config: str = "(default)", transaction: Optional[Transaction] = None) -> None: """ @@ -151,7 +151,7 @@ def find( # pylint: disable=too-many-arguments limit: Optional[int] = None, offset: Optional[int] = None, transaction: Optional[Transaction] = None, - config: str = "(default)" + config: Optional[str] = "(default)" ) -> List[TBareModel]: """ Returns a list of models from the database based on a filter. @@ -168,9 +168,9 @@ def find( # pylint: disable=too-many-arguments :param limit: Maximum results to return. :param offset: Skip the first n results. :param transaction: Optional transaction to use. + :param config: Client configuration to use. :return: List of found models. """ - print(f'\nconfig is: {config}') query: Union[BaseQuery, CollectionReference] = cls._get_col_ref(config) if filter_: @@ -232,6 +232,7 @@ def find_one( filter_: Optional[Dict[str, Union[str, dict]]] = None, order_by: Optional[_OrderBy] = None, transaction: Optional[Transaction] = None, + config: Optional[str] = "(default)" ) -> TBareModel: """ Returns one model from the DB based on a filter. @@ -242,7 +243,7 @@ def find_one( :raise ModelNotFoundError: If the entry is not found. """ model = cls.find( - filter_, limit=1, order_by=order_by, transaction=transaction + filter_, limit=1, order_by=order_by, transaction=transaction, config=config ) try: return model[0] @@ -291,7 +292,7 @@ def get_by_doc_id( return model @classmethod - def truncate_collection(cls, batch_size: int = 128) -> int: + def truncate_collection(cls, config: str = "(default)", batch_size: int = 128) -> int: """ Removes all documents inside a collection. @@ -299,7 +300,7 @@ def truncate_collection(cls, batch_size: int = 128) -> int: :return: Number of removed documents. """ return truncate_collection( - col_ref=cls._get_col_ref(), + col_ref=cls._get_col_ref(config), batch_size=batch_size, ) diff --git a/firedantic/configurations.py b/firedantic/configurations.py index 6898f02..8f2f958 100644 --- a/firedantic/configurations.py +++ b/firedantic/configurations.py @@ -9,7 +9,7 @@ # for added support of multiple configurations/clients class ConfigItem(BaseModel): prefix: str - project: str + project: str = None credentials: Optional[Credentials] = None client: Optional[Client] = None async_client: Optional[AsyncClient] = None @@ -49,8 +49,8 @@ def configure( CONFIGURATIONS["(default)"].client = db elif isinstance(db, AsyncClient): CONFIGURATIONS["(default)"].async_client = db - # otherwise gets set to None by default CONFIGURATIONS["(default)"].prefix = prefix + # other params get set to None by default def get_transaction(config: str = "(default)") -> Transaction: diff --git a/firedantic/tests/tests_async/conftest.py b/firedantic/tests/tests_async/conftest.py index 09a4710..19b3bb4 100644 --- a/firedantic/tests/tests_async/conftest.py +++ b/firedantic/tests/tests_async/conftest.py @@ -170,7 +170,7 @@ class ExpiringModel(AsyncModel): @pytest.fixture(autouse=True) -def configure_db(): +def configure_client(): client = AsyncClient( project="ioxio-local-dev", credentials=Mock(spec=google.auth.credentials.Credentials), diff --git a/firedantic/tests/tests_async/test_model.py b/firedantic/tests/tests_async/test_model.py index dc25a17..bb3d514 100644 --- a/firedantic/tests/tests_async/test_model.py +++ b/firedantic/tests/tests_async/test_model.py @@ -19,6 +19,7 @@ CustomIDConflictModel, CustomIDModel, CustomIDModelExtra, + Owner, Product, Profile, TodoList, @@ -36,16 +37,15 @@ @pytest.mark.asyncio -async def test_save_model(configure_db, create_company) -> None: +async def test_save_model(configure_client, create_company) -> None: company = await create_company() assert company.id is not None assert company.owner.first_name == "John" assert company.owner.last_name == "Doe" - @pytest.mark.asyncio -async def test_delete_model(configure_db, create_company) -> None: +async def test_delete_model(configure_client, create_company) -> None: company: Company = await create_company( company_id="11223344-5", first_name="Jane", last_name="Doe" ) @@ -60,7 +60,7 @@ async def test_delete_model(configure_db, create_company) -> None: @pytest.mark.asyncio -async def test_find_one(configure_db, create_company) -> None: +async def test_find_one(configure_client, create_company) -> None: with pytest.raises(ModelNotFoundError): await Company.find_one() @@ -91,7 +91,7 @@ async def test_find_one(configure_db, create_company) -> None: @pytest.mark.asyncio -async def test_find(configure_db, create_company, create_product) -> None: +async def test_find(configure_client, create_company, create_product) -> None: ids = ["1234555-1", "1234567-8", "2131232-4", "4124432-4"] for company_id in ids: await create_company(company_id=company_id) @@ -128,7 +128,7 @@ async def test_find(configure_db, create_company, create_product) -> None: @pytest.mark.asyncio -async def test_find_not_in(configure_db, create_company) -> None: +async def test_find_not_in(configure_client, create_company) -> None: ids = ["1234555-1", "1234567-8", "2131232-4", "4124432-4"] for company_id in ids: await create_company(company_id=company_id) @@ -149,7 +149,7 @@ async def test_find_not_in(configure_db, create_company) -> None: @pytest.mark.asyncio -async def test_find_array_contains(configure_db, create_todolist) -> None: +async def test_find_array_contains(configure_client, create_todolist) -> None: list_1 = await create_todolist("list_1", ["Work", "Eat", "Sleep"]) await create_todolist("list_2", ["Learn Python", "Walk the dog"]) @@ -159,7 +159,7 @@ async def test_find_array_contains(configure_db, create_todolist) -> None: @pytest.mark.asyncio -async def test_find_array_contains_any(configure_db, create_todolist) -> None: +async def test_find_array_contains_any(configure_client, create_todolist) -> None: list_1 = await create_todolist("list_1", ["Work", "Eat"]) list_2 = await create_todolist("list_2", ["Relax", "Chill", "Sleep"]) await create_todolist("list_3", ["Learn Python", "Walk the dog"]) @@ -171,7 +171,7 @@ async def test_find_array_contains_any(configure_db, create_todolist) -> None: @pytest.mark.asyncio -async def test_find_limit(configure_db, create_company) -> None: +async def test_find_limit(configure_client, create_company) -> None: ids = ["1234555-1", "1234567-8", "2131232-4", "4124432-4"] for company_id in ids: await create_company(company_id=company_id) @@ -184,7 +184,7 @@ async def test_find_limit(configure_db, create_company) -> None: @pytest.mark.asyncio -async def test_find_order_by(configure_db, create_company) -> None: +async def test_find_order_by(configure_client, create_company) -> None: companies_and_owners = [ {"company_id": "1234555-1", "last_name": "A", "first_name": "A"}, {"company_id": "1234555-2", "last_name": "A", "first_name": "B"}, @@ -233,7 +233,7 @@ async def test_find_order_by(configure_db, create_company) -> None: @pytest.mark.asyncio -async def test_find_offset(configure_db, create_company) -> None: +async def test_find_offset(configure_client, create_company) -> None: ids_and_lastnames = ( ("1234555-1", "A"), ("1234567-8", "B"), @@ -251,7 +251,7 @@ async def test_find_offset(configure_db, create_company) -> None: @pytest.mark.asyncio -async def test_get_by_id(configure_db, create_company) -> None: +async def test_get_by_id(configure_client, create_company) -> None: c: Company = await create_company(company_id="1234567-8") assert c.id is not None @@ -266,13 +266,13 @@ async def test_get_by_id(configure_db, create_company) -> None: @pytest.mark.asyncio -async def test_get_by_empty_str_id(configure_db) -> None: +async def test_get_by_empty_str_id(configure_client) -> None: with pytest.raises(ModelNotFoundError): await Company.get_by_id("") @pytest.mark.asyncio -async def test_missing_collection(configure_db) -> None: +async def test_missing_collection(configure_client) -> None: class User(AsyncModel): name: str @@ -281,7 +281,7 @@ class User(AsyncModel): @pytest.mark.asyncio -async def test_model_aliases(configure_db) -> None: +async def test_model_aliases(configure_client) -> None: class User(AsyncModel): __collection__ = "User" @@ -326,7 +326,7 @@ class User(AsyncModel): "!:&+-*'()", ], ) -async def test_models_with_valid_custom_id(configure_db, model_id) -> None: +async def test_models_with_valid_custom_id(configure_client, model_id) -> None: product_id = str(uuid4()) product = Product(product_id=product_id, price=123.45, stock=2) @@ -355,7 +355,7 @@ async def test_models_with_valid_custom_id(configure_db, model_id) -> None: "foo/bar/baz", ], ) -async def test_models_with_invalid_custom_id(configure_db, model_id: str) -> None: +async def test_models_with_invalid_custom_id(configure_client, model_id: str) -> None: product = Product(product_id="product 123", price=123.45, stock=2) product.id = model_id with pytest.raises(InvalidDocumentID): @@ -366,7 +366,7 @@ async def test_models_with_invalid_custom_id(configure_db, model_id: str) -> Non @pytest.mark.asyncio -async def test_truncate_collection(configure_db, create_company) -> None: +async def test_truncate_collection(configure_client, create_company) -> None: await create_company(company_id="1234567-8") await create_company(company_id="1234567-9") @@ -379,7 +379,7 @@ async def test_truncate_collection(configure_db, create_company) -> None: @pytest.mark.asyncio -async def test_custom_id_model(configure_db) -> None: +async def test_custom_id_model(configure_client) -> None: c = CustomIDModel(bar="bar") # type: ignore await c.save() @@ -392,7 +392,7 @@ async def test_custom_id_model(configure_db) -> None: @pytest.mark.asyncio -async def test_custom_id_conflict(configure_db) -> None: +async def test_custom_id_conflict(configure_client) -> None: await CustomIDConflictModel(foo="foo", bar="bar").save() models = await CustomIDModel.find({}) @@ -404,7 +404,7 @@ async def test_custom_id_conflict(configure_db) -> None: @pytest.mark.asyncio -async def test_model_id_persistency(configure_db) -> None: +async def test_model_id_persistency(configure_client) -> None: c = CustomIDConflictModel(foo="foo", bar="bar") await c.save() assert c.id @@ -416,7 +416,7 @@ async def test_model_id_persistency(configure_db) -> None: @pytest.mark.asyncio -async def test_bare_model_document_id_persistency(configure_db) -> None: +async def test_bare_model_document_id_persistency(configure_client) -> None: c = CustomIDModel(bar="bar") # type: ignore await c.save() assert c.foo @@ -428,20 +428,20 @@ async def test_bare_model_document_id_persistency(configure_db) -> None: @pytest.mark.asyncio -async def test_bare_model_get_by_empty_doc_id(configure_db) -> None: +async def test_bare_model_get_by_empty_doc_id(configure_client) -> None: with pytest.raises(ModelNotFoundError): await CustomIDModel.get_by_doc_id("") @pytest.mark.asyncio -async def test_extra_fields(configure_db) -> None: +async def test_extra_fields(configure_client) -> None: await CustomIDModelExtra(foo="foo", bar="bar", baz="baz").save() # type: ignore with pytest.raises(ValidationError): await CustomIDModel.find({}) @pytest.mark.asyncio -async def test_company_stats(configure_db, create_company) -> None: +async def test_company_stats(configure_client, create_company) -> None: company: Company = await create_company(company_id="1234567-8") company_stats = company.stats() @@ -462,7 +462,7 @@ async def test_company_stats(configure_db, create_company) -> None: @pytest.mark.asyncio -async def test_subcollection_model_safety(configure_db) -> None: +async def test_subcollection_model_safety(configure_client) -> None: """ Ensure you shouldn't be able to use unprepared subcollection models accidentally """ @@ -471,7 +471,7 @@ async def test_subcollection_model_safety(configure_db) -> None: @pytest.mark.asyncio -async def test_get_user_purchases(configure_db) -> None: +async def test_get_user_purchases(configure_client) -> None: u = User(name="Foo") await u.save() assert u.id @@ -483,7 +483,7 @@ async def test_get_user_purchases(configure_db) -> None: @pytest.mark.asyncio -async def test_reload(configure_db) -> None: +async def test_reload(configure_client) -> None: u = User(name="Foo") await u.save() @@ -502,7 +502,7 @@ async def test_reload(configure_db) -> None: @pytest.mark.asyncio -async def test_save_with_exclude_none(configure_db) -> None: +async def test_save_with_exclude_none(configure_client) -> None: p = Profile(name="Foo") await p.save(exclude_none=True) @@ -524,7 +524,7 @@ async def test_save_with_exclude_none(configure_db) -> None: @pytest.mark.asyncio -async def test_save_with_exclude_unset(configure_db) -> None: +async def test_save_with_exclude_unset(configure_client) -> None: p = Profile(photo_url=None) await p.save(exclude_unset=True) @@ -546,11 +546,11 @@ async def test_save_with_exclude_unset(configure_db) -> None: @pytest.mark.asyncio -async def test_update_city_in_transaction(configure_db) -> None: +async def test_update_city_in_transaction(configure_client) -> None: """ Test updating a model in a transaction. Test case from README. - :param: configure_db: pytest fixture + :param: configure_client: pytest fixture """ @async_transactional @@ -572,11 +572,11 @@ async def decrement_population( @pytest.mark.asyncio -async def test_delete_in_transaction(configure_db) -> None: +async def test_delete_in_transaction(configure_client) -> None: """ Test deleting a model in a transaction. - :param: configure_db: pytest fixture + :param: configure_client: pytest fixture """ @async_transactional @@ -599,11 +599,11 @@ async def delete_in_transaction( @pytest.mark.asyncio -async def test_update_model_in_transaction(configure_db) -> None: +async def test_update_model_in_transaction(configure_client) -> None: """ Test updating a model in a transaction. - :param: configure_db: pytest fixture + :param: configure_client: pytest fixture """ @async_transactional @@ -626,11 +626,11 @@ async def update_in_transaction( @pytest.mark.asyncio -async def test_update_submodel_in_transaction(configure_db) -> None: +async def test_update_submodel_in_transaction(configure_client) -> None: """ Test Updating a submodel in a transaction. - :param: configure_db: pytest fixture + :param: configure_client: pytest fixture """ @async_transactional diff --git a/firedantic/tests/tests_sync/conftest.py b/firedantic/tests/tests_sync/conftest.py index 410769d..2924ae0 100644 --- a/firedantic/tests/tests_sync/conftest.py +++ b/firedantic/tests/tests_sync/conftest.py @@ -16,10 +16,10 @@ SubCollection, SubModel, ) -from firedantic.configurations import configure, get_transaction +from firedantic.configurations import configure, Configuration, get_transaction from firedantic.exceptions import ModelNotFoundError -from unittest.mock import Mock, Mock # noqa isort: skip +from unittest.mock import Mock # noqa isort: skip class CustomIDModel(BareModel): @@ -61,14 +61,14 @@ class City(Model): __collection__ = "cities" population: int - def increment_population(self, increment: int = 1): + def increment_population(self, increment: int = 1, config: str = "(default)"): @transactional def _increment_population(transaction: Transaction) -> None: - self.reload(transaction=transaction) + self.reload(config=config, transaction=transaction) self.population += increment - self.save(transaction=transaction) + self.save(config=config, transaction=transaction) - t = get_transaction() + t = get_transaction(config) _increment_population(transaction=t) @@ -170,7 +170,7 @@ class ExpiringModel(Model): @pytest.fixture(autouse=True) -def configure_db(): +def configure_client(): client = Client( project="ioxio-local-dev", credentials=Mock(spec=google.auth.credentials.Credentials), @@ -179,6 +179,27 @@ def configure_db(): prefix = str(uuid.uuid4()) + "-" configure(client, prefix) +@pytest.fixture +def configure_multiple_clients(): + config = Configuration() + mock_creds = Mock(spec=google.auth.credentials.Credentials) + + # name = (default) + config.create( + prefix="ioxio-local-dev-", + project=str(uuid.uuid4()) + "-", + credentials=mock_creds + ) + + # name = multi + config.create( + name="multi", + prefix="test-multi-", + project="test-multi", + credentials=mock_creds, + ) + + @pytest.fixture def create_company(): diff --git a/firedantic/tests/tests_sync/test_model.py b/firedantic/tests/tests_sync/test_model.py index 24231d1..b3fcbde 100644 --- a/firedantic/tests/tests_sync/test_model.py +++ b/firedantic/tests/tests_sync/test_model.py @@ -19,6 +19,7 @@ CustomIDConflictModel, CustomIDModel, CustomIDModelExtra, + Owner, Product, Profile, TodoList, @@ -36,7 +37,7 @@ -def test_save_model(configure_db, create_company) -> None: +def test_save_model(configure_client, create_company) -> None: company = create_company() assert company.id is not None @@ -44,8 +45,29 @@ def test_save_model(configure_db, create_company) -> None: assert company.owner.last_name == "Doe" +def test_save_multi_client_model(configure_multiple_clients): + owner = Owner(first_name="Alice", last_name="Begone") + company = Company(company_id="1234567-9", owner=owner) + company.save() # will use 'default' as config name + company.reload() # with no name supplied, config refers to "(default)" -def test_delete_model(configure_db, create_company) -> None: + config="multi" + product = Product(product_id=str(uuid4()), price=123.45, stock=2) + product.save(config) + product = Product(product_id=str(uuid4()), price=89.95, stock=8) + product.save(config) + product.reload(config) + + # Can retrieve info from either client: (default) or multi. + # When config= is not set, it will default to '(default)' config. + # The models do not know which config you intended to use them for, and they + # could be used for a multitude of configurations at once. + assert Company.find_one({"owner.first_name": "Alice"}).company_id == "1234567-9" + assert Product.find_one({"price": 123.45}, config=config).stock == 2 + assert Product.find_one({"stock": 8}, config=config).stock == 8 + + +def test_delete_model(configure_client, create_company) -> None: company: Company = create_company( company_id="11223344-5", first_name="Jane", last_name="Doe" ) @@ -59,8 +81,35 @@ def test_delete_model(configure_db, create_company) -> None: Company.get_by_id(_id) +def test_delete_multi_client_model(configure_multiple_clients, create_company) -> None: + # will use 'default' as config name + owner = Owner(first_name="Alice", last_name="Begone") + company = Company(company_id="12345678", owner=owner) + company.save() + + config="multi" + product = Product(product_id="1112", price=123.45, stock=2) + product.save(config) + product = Product(product_id="1114", price=89.95, stock=8) + product.save(config) + + _id = company.id + assert _id + + _pid = product.id + assert _pid + + company.delete() + product.delete(config) -def test_find_one(configure_db, create_company) -> None: + with pytest.raises(ModelNotFoundError): + Company.get_by_id(_id) + + with pytest.raises(ModelNotFoundError): + Product.get_by_id(_pid) + + +def test_find_one(configure_client, create_company) -> None: with pytest.raises(ModelNotFoundError): Company.find_one() @@ -88,7 +137,7 @@ def test_find_one(configure_db, create_company) -> None: assert first_desc.owner.first_name == "Foo" -def test_find(configure_db, create_company, create_product) -> None: +def test_find(configure_client, create_company, create_product) -> None: ids = ["1234555-1", "1234567-8", "2131232-4", "4124432-4"] for company_id in ids: create_company(company_id=company_id) @@ -97,14 +146,12 @@ def test_find(configure_db, create_company, create_product) -> None: assert c[0].company_id == "4124432-4" assert c[0].owner.first_name == "John" - d = Company.find({"owner.first_name": "John"}) - assert len(d) == 4 + d1 = Company.find({"owner.first_name": "John"}) - d = Company.find({"owner.first_name": {op.EQ: "John"}}) - assert len(d) == 4 + d2 = Company.find({"owner.first_name": {op.EQ: "John"}}) - d = Company.find({"owner.first_name": {"==": "John"}}) - assert len(d) == 4 + d3 = Company.find({"owner.first_name": {"==": "John"}}) + assert d1 == d2 == d3 for p in TEST_PRODUCTS: create_product(**p) @@ -124,7 +171,7 @@ def test_find(configure_db, create_company, create_product) -> None: Product.find({"product_id": {"<>": "a"}}) -def test_find_not_in(configure_db, create_company) -> None: +def test_find_not_in(configure_client, create_company) -> None: ids = ["1234555-1", "1234567-8", "2131232-4", "4124432-4"] for company_id in ids: create_company(company_id=company_id) @@ -144,7 +191,7 @@ def test_find_not_in(configure_db, create_company) -> None: assert company.company_id in ("2131232-4", "4124432-4") -def test_find_array_contains(configure_db, create_todolist) -> None: +def test_find_array_contains(configure_client, create_todolist) -> None: list_1 = create_todolist("list_1", ["Work", "Eat", "Sleep"]) create_todolist("list_2", ["Learn Python", "Walk the dog"]) @@ -153,7 +200,7 @@ def test_find_array_contains(configure_db, create_todolist) -> None: assert found[0].name == list_1.name -def test_find_array_contains_any(configure_db, create_todolist) -> None: +def test_find_array_contains_any(configure_client, create_todolist) -> None: list_1 = create_todolist("list_1", ["Work", "Eat"]) list_2 = create_todolist("list_2", ["Relax", "Chill", "Sleep"]) create_todolist("list_3", ["Learn Python", "Walk the dog"]) @@ -164,7 +211,7 @@ def test_find_array_contains_any(configure_db, create_todolist) -> None: assert lst.name in (list_1.name, list_2.name) -def test_find_limit(configure_db, create_company) -> None: +def test_find_limit(configure_client, create_company) -> None: ids = ["1234555-1", "1234567-8", "2131232-4", "4124432-4"] for company_id in ids: create_company(company_id=company_id) @@ -176,7 +223,7 @@ def test_find_limit(configure_db, create_company) -> None: assert len(companies_2) == 2 -def test_find_order_by(configure_db, create_company) -> None: +def test_find_order_by(configure_client, create_company) -> None: companies_and_owners = [ {"company_id": "1234555-1", "last_name": "A", "first_name": "A"}, {"company_id": "1234555-2", "last_name": "A", "first_name": "B"}, @@ -221,7 +268,7 @@ def test_find_order_by(configure_db, create_company) -> None: -def test_find_offset(configure_db, create_company) -> None: +def test_find_offset(configure_client, create_company) -> None: ids_and_lastnames = ( ("1234555-1", "A"), ("1234567-8", "B"), @@ -238,8 +285,7 @@ def test_find_offset(configure_db, create_company) -> None: assert len(companies_ascending) == 2 - -def test_get_by_id(configure_db, create_company) -> None: +def test_get_by_id(configure_client, create_company) -> None: c: Company = create_company(company_id="1234567-8") assert c.id is not None @@ -253,14 +299,55 @@ def test_get_by_id(configure_db, create_company) -> None: assert c_2.owner.first_name == "John" +def test_get_by_id_multi_client(configure_multiple_clients) -> None: + owner = Owner(first_name="Alice", last_name="Begone") + c1 = Company(company_id="1234567-8", owner=owner) + c1.save() + + config="multi" + owner = Owner(first_name="Crazy", last_name="Kitten") + c2 = Company(company_id="1234567-9", owner=owner) + c2.save(config) + + assert c1.id is not None + assert c1.company_id == "1234567-8" + assert c1.owner.last_name == "Begone" + + cid1 = Company.get_by_id(c1.id) + + assert cid1.id == c1.id + assert cid1.company_id == "1234567-8" + assert cid1.owner.first_name == "Alice" + + assert c2.id is not None + assert c2.company_id == "1234567-9" + assert c2.owner.last_name == "Kitten" + + cid2 = Company.get_by_id(c2.id, config) + + assert cid2.id == c2.id + assert cid2.company_id == "1234567-9" + assert cid2.owner.first_name == "Crazy" + + +def test_get_by_id_missing_config(configure_multiple_clients): + config="multi" + owner = Owner(first_name="Crazy", last_name="Kitten") + c2 = Company(company_id="1234567-9", owner=owner) + c2.save(config) -def test_get_by_empty_str_id(configure_db) -> None: + with pytest.raises(ModelNotFoundError): + cid2 = Company.get_by_id(c2.id) + + + +def test_get_by_empty_str_id(configure_client) -> None: with pytest.raises(ModelNotFoundError): Company.get_by_id("") -def test_missing_collection(configure_db) -> None: +def test_missing_collection(configure_client) -> None: class User(Model): name: str @@ -269,7 +356,7 @@ class User(Model): -def test_model_aliases(configure_db) -> None: +def test_model_aliases(configure_client) -> None: class User(Model): __collection__ = "User" @@ -314,7 +401,7 @@ class User(Model): "!:&+-*'()", ], ) -def test_models_with_valid_custom_id(configure_db, model_id) -> None: +def test_models_with_valid_custom_id(configure_client, model_id) -> None: product_id = str(uuid4()) product = Product(product_id=product_id, price=123.45, stock=2) @@ -343,7 +430,7 @@ def test_models_with_valid_custom_id(configure_db, model_id) -> None: "foo/bar/baz", ], ) -def test_models_with_invalid_custom_id(configure_db, model_id: str) -> None: +def test_models_with_invalid_custom_id(configure_client, model_id: str) -> None: product = Product(product_id="product 123", price=123.45, stock=2) product.id = model_id with pytest.raises(InvalidDocumentID): @@ -354,7 +441,7 @@ def test_models_with_invalid_custom_id(configure_db, model_id: str) -> None: -def test_truncate_collection(configure_db, create_company) -> None: +def test_truncate_collection(configure_client, create_company) -> None: create_company(company_id="1234567-8") create_company(company_id="1234567-9") @@ -366,8 +453,33 @@ def test_truncate_collection(configure_db, create_company) -> None: assert len(new_companies) == 0 +def test_truncate_multiple_collections(configure_multiple_clients) -> None: + owner = Owner(first_name="Alice", last_name="Begone") + c1 = Company(company_id="1234567-8", owner=owner) + c1.save() + c2 = Company(company_id="1234567-9", owner=owner) + c2.save() + + config="multi" + owner = Owner(first_name="Crazy", last_name="Kitten") + c3 = Company(company_id="1234567-0", owner=owner) + c3.save(config) -def test_custom_id_model(configure_db) -> None: + companies = Company.find({}) + assert len(companies) >= 2 + + companies = Company.find({}, config=config) + assert len(companies) >= 1 + + Company.truncate_collection() + assert len(Company.find({})) == 0 + + Company.truncate_collection(config) + assert len(Company.find({}, config=config)) == 0 + + + +def test_custom_id_model(configure_client) -> None: c = CustomIDModel(bar="bar") # type: ignore c.save() @@ -380,7 +492,7 @@ def test_custom_id_model(configure_db) -> None: -def test_custom_id_conflict(configure_db) -> None: +def test_custom_id_conflict(configure_client) -> None: CustomIDConflictModel(foo="foo", bar="bar").save() models = CustomIDModel.find({}) @@ -392,7 +504,7 @@ def test_custom_id_conflict(configure_db) -> None: -def test_model_id_persistency(configure_db) -> None: +def test_model_id_persistency(configure_client) -> None: c = CustomIDConflictModel(foo="foo", bar="bar") c.save() assert c.id @@ -404,7 +516,7 @@ def test_model_id_persistency(configure_db) -> None: -def test_bare_model_document_id_persistency(configure_db) -> None: +def test_bare_model_document_id_persistency(configure_client) -> None: c = CustomIDModel(bar="bar") # type: ignore c.save() assert c.foo @@ -416,20 +528,20 @@ def test_bare_model_document_id_persistency(configure_db) -> None: -def test_bare_model_get_by_empty_doc_id(configure_db) -> None: +def test_bare_model_get_by_empty_doc_id(configure_client) -> None: with pytest.raises(ModelNotFoundError): CustomIDModel.get_by_doc_id("") -def test_extra_fields(configure_db) -> None: +def test_extra_fields(configure_client) -> None: CustomIDModelExtra(foo="foo", bar="bar", baz="baz").save() # type: ignore with pytest.raises(ValidationError): CustomIDModel.find({}) -def test_company_stats(configure_db, create_company) -> None: +def test_company_stats(configure_client, create_company) -> None: company: Company = create_company(company_id="1234567-8") company_stats = company.stats() @@ -450,7 +562,7 @@ def test_company_stats(configure_db, create_company) -> None: -def test_subcollection_model_safety(configure_db) -> None: +def test_subcollection_model_safety(configure_client) -> None: """ Ensure you shouldn't be able to use unprepared subcollection models accidentally """ @@ -459,7 +571,7 @@ def test_subcollection_model_safety(configure_db) -> None: -def test_get_user_purchases(configure_db) -> None: +def test_get_user_purchases(configure_client) -> None: u = User(name="Foo") u.save() assert u.id @@ -471,7 +583,7 @@ def test_get_user_purchases(configure_db) -> None: -def test_reload(configure_db) -> None: +def test_reload(configure_client) -> None: u = User(name="Foo") u.save() @@ -488,9 +600,35 @@ def test_reload(configure_db) -> None: with pytest.raises(ModelNotFoundError): another_user.reload() +def test_reload_multiple(configure_multiple_clients) -> None: + u1 = User(name="Coco") + u1.save() + + u2 = User(name="Foo") + u2.save("multi") + # change the value in the 'multi' database + u2_ = User.find_one({"name": "Foo"}, config="multi") + u2_.name = "Bar" + u2_.save("multi") -def test_save_with_exclude_none(configure_db) -> None: + # assert that u2 was updated + assert u2.name == "Foo" + u2.reload("multi") + + assert u2.name == "Bar" + + # assert that u1 did not change even after reload + u1.reload() + assert u1.name == "Coco" + + another_user = User(name="Another") + with pytest.raises(ModelNotFoundError): + another_user.reload() + + + +def test_save_with_exclude_none(configure_client) -> None: p = Profile(name="Foo") p.save(exclude_none=True) @@ -512,7 +650,7 @@ def test_save_with_exclude_none(configure_db) -> None: -def test_save_with_exclude_unset(configure_db) -> None: +def test_save_with_exclude_unset(configure_client) -> None: p = Profile(photo_url=None) p.save(exclude_unset=True) @@ -534,11 +672,11 @@ def test_save_with_exclude_unset(configure_db) -> None: -def test_update_city_in_transaction(configure_db) -> None: +def test_update_city_in_transaction(configure_client) -> None: """ Test updating a model in a transaction. Test case from README. - :param: configure_db: pytest fixture + :param: configure_client: pytest fixture """ @transactional @@ -559,12 +697,43 @@ def decrement_population( assert c.population == 0 +def test_update_city_in_multiple_transactions(configure_multiple_clients) -> None: + """ + Test updating a model in 2 separate transactions. Based on test case from README. + """ + + @transactional + def decrement_population( + transaction: Transaction, city: City, decrement: int = 1, config: str = "(default)" + ): + city.reload(config, transaction=transaction) + city.population = max(0, city.population - decrement) + city.save(config, transaction=transaction) + + c = City(id="SF", population=1) + c.save() + c.increment_population(increment=1) + assert c.population == 2 + + t = get_transaction() + decrement_population(transaction=t, city=c, decrement=5) + assert c.population == 0 + + c = City(id="SF", population=8) + c.save("multi") + c.increment_population(increment=1, config="multi") + assert c.population == 9 + + t = get_transaction("multi") + decrement_population(transaction=t, city=c, decrement=5, config="multi") + assert c.population == 4 + -def test_delete_in_transaction(configure_db) -> None: +def test_delete_in_transaction(configure_client) -> None: """ Test deleting a model in a transaction. - :param: configure_db: pytest fixture + :param: configure_client: pytest fixture """ @transactional @@ -587,11 +756,11 @@ def delete_in_transaction( -def test_update_model_in_transaction(configure_db) -> None: +def test_update_model_in_transaction(configure_client) -> None: """ Test updating a model in a transaction. - :param: configure_db: pytest fixture + :param: configure_client: pytest fixture """ @transactional @@ -614,11 +783,11 @@ def update_in_transaction( -def test_update_submodel_in_transaction(configure_db) -> None: +def test_update_submodel_in_transaction(configure_client) -> None: """ Test Updating a submodel in a transaction. - :param: configure_db: pytest fixture + :param: configure_client: pytest fixture """ @transactional From 42a3b4acd44958d41e74449fa530961f1a46b5d3 Mon Sep 17 00:00:00 2001 From: Marissa Fisher Date: Tue, 26 Aug 2025 23:04:45 -0400 Subject: [PATCH 06/38] small readme update --- README.md | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 79e9955..1d59abd 100644 --- a/README.md +++ b/README.md @@ -522,15 +522,23 @@ source firedantic/bin/activate cd firedantic/tests ``` -Then run poetry within `tests` directory: +Run pytests within `tests` directory: +```bash +pip install pytest +pip install pytest-asyncio +pytest . +``` + +Poetry is configured to run on the published version of firedantic and will overwrite changes. +Not recommended for local development. For public use, poetry is recommended: +Run poetry within `tests` directory: ```bash poetry install poetry run invoke test ``` -to run a singular test: - +To run a singular test: ```bash poetry run pytest tests_sync/.py:: ``` From 4ab6b4327d8529a2bdaaad5050340ab3fd24d5b3 Mon Sep 17 00:00:00 2001 From: Marissa Fisher Date: Mon, 17 Nov 2025 14:52:43 -0300 Subject: [PATCH 07/38] reformatting a bit --- firedantic/_async/model.py | 12 +++++------- firedantic/_sync/helpers.py | 4 +--- firedantic/_sync/model.py | 14 +++++--------- .../configure_firestore_db_client.py | 10 +++++----- multiple_clients_examples/query_sample.py | 5 +++++ 5 files changed, 21 insertions(+), 24 deletions(-) diff --git a/firedantic/_async/model.py b/firedantic/_async/model.py index f2fc2b8..f6054c8 100644 --- a/firedantic/_async/model.py +++ b/firedantic/_async/model.py @@ -15,7 +15,7 @@ import firedantic.operators as op from firedantic import async_truncate_collection from firedantic.common import IndexDefinition, OrderDirection -from firedantic.configurations import CONFIGURATIONS +from firedantic.configurations import Configuration from firedantic.exceptions import ( CollectionNotDefined, InvalidDocumentID, @@ -25,6 +25,7 @@ TAsyncBareModel = TypeVar("TAsyncBareModel", bound="AsyncBareModel") TAsyncBareSubModel = TypeVar("TAsyncBareSubModel", bound="AsyncBareSubModel") logger = getLogger("firedantic") +configuration = Configuration() # https://firebase.google.com/docs/firestore/query-data/queries#query_operators FIND_TYPES = { @@ -49,14 +50,11 @@ def get_collection_name(cls, name: Optional[str], config: str = "(default)") -> """ if not name: raise CollectionNotDefined(f"Missing collection name for {cls.__name__}") + return f"{configuration.get_prefix}{name}" - return f"{CONFIGURATIONS[config].prefix}{name}" - -def _get_col_ref(cls, name: Optional[str], config: str = "(default)") -> AsyncCollectionReference: - collection: AsyncCollectionReference = CONFIGURATIONS[config].async_client.collection( - get_collection_name(cls, name, config) - ) +def _get_col_ref(cls, client, name: Optional[str], config: str = "(default)") -> AsyncCollectionReference: + collection: AsyncCollectionReference = configuration.get_collection_name(cls, client, name, config) return collection diff --git a/firedantic/_sync/helpers.py b/firedantic/_sync/helpers.py index 3a5f92c..57dfe99 100644 --- a/firedantic/_sync/helpers.py +++ b/firedantic/_sync/helpers.py @@ -1,9 +1,7 @@ from google.cloud.firestore_v1 import CollectionReference -def truncate_collection( - col_ref: CollectionReference, batch_size: int = 128 -) -> int: +def truncate_collection(col_ref: CollectionReference, batch_size: int = 128) -> int: """ Removes all documents inside a collection. diff --git a/firedantic/_sync/model.py b/firedantic/_sync/model.py index 3cf20e2..836a561 100644 --- a/firedantic/_sync/model.py +++ b/firedantic/_sync/model.py @@ -15,7 +15,7 @@ import firedantic.operators as op from firedantic import truncate_collection from firedantic.common import IndexDefinition, OrderDirection -from firedantic.configurations import CONFIGURATIONS +from firedantic.configurations import Configuration from firedantic.exceptions import ( CollectionNotDefined, InvalidDocumentID, @@ -25,6 +25,7 @@ TBareModel = TypeVar("TBareModel", bound="BareModel") TBareSubModel = TypeVar("TBareSubModel", bound="BareSubModel") logger = getLogger("firedantic") +configuration = Configuration() # https://firebase.google.com/docs/firestore/query-data/queries#query_operators FIND_TYPES = { @@ -49,14 +50,11 @@ def get_collection_name(cls, name: Optional[str], config: str = "(default)") -> """ if not name: raise CollectionNotDefined(f"Missing collection name for {cls.__name__}") + return f"{configuration.get_prefix}{name}" - return f"{CONFIGURATIONS[config].prefix}{name}" - -def _get_col_ref(cls, name: Optional[str], config: str = "(default)") -> CollectionReference: - collection: CollectionReference = CONFIGURATIONS[config].client.collection( - get_collection_name(cls, name, config) - ) +def _get_col_ref(cls, client, name: Optional[str], config: str = "(default)") -> CollectionReference: + collection: CollectionReference = configuration.get_collection_name(cls, client, name, config) return collection @@ -79,7 +77,6 @@ def save( exclude_unset: bool = False, exclude_none: bool = False, transaction: Optional[Transaction] = None, - ) -> None: """ Saves this model in the database. @@ -171,7 +168,6 @@ def find( # pylint: disable=too-many-arguments :param config: Client configuration to use. :return: List of found models. """ - query: Union[BaseQuery, CollectionReference] = cls._get_col_ref(config) if filter_: for key, value in filter_.items(): diff --git a/multiple_clients_examples/configure_firestore_db_client.py b/multiple_clients_examples/configure_firestore_db_client.py index 32eb206..0cf779d 100644 --- a/multiple_clients_examples/configure_firestore_db_client.py +++ b/multiple_clients_examples/configure_firestore_db_client.py @@ -60,8 +60,8 @@ def configure_multiple_clients(): print(config.get_async_client("billing")) ## will pull the billing async client -# print("\n---- Running OLD way ----") -# configure_client() -# configure_async_client() -# print("\n---- Running NEW way ----") -# configure_multiple_clients() +print("\n---- Running OLD way ----") +configure_client() +configure_async_client() +print("\n---- Running NEW way ----") +configure_multiple_clients() diff --git a/multiple_clients_examples/query_sample.py b/multiple_clients_examples/query_sample.py index 7661963..e3d441f 100644 --- a/multiple_clients_examples/query_sample.py +++ b/multiple_clients_examples/query_sample.py @@ -13,6 +13,9 @@ companies1 = Company.find({"owner.first_name": "John"}) companies2 = Company.find({"owner.first_name": {op.EQ: "John"}}) companies3 = Company.find({"owner.first_name": {"==": "John"}}) +print(companies1) + +assert companies1 != [] assert companies1 == companies2 == companies3 @@ -22,4 +25,6 @@ companies1 = Company.find({"owner.first_name": "Alice"}) companies2 = Company.find({"owner.first_name": {op.EQ: "Alice"}}) companies3 = Company.find({"owner.first_name": {"==": "Alice"}}) +print(companies1) +assert companies1 != [] assert companies1 == companies2 == companies3 From 5a80e80b1366408f7816d87a197dda3a70330b60 Mon Sep 17 00:00:00 2001 From: Marissa Fisher Date: Mon, 17 Nov 2025 16:54:37 -0300 Subject: [PATCH 08/38] reverting and adding some config stuff --- firedantic/configurations.py | 129 +++++++++++++++++------------------ 1 file changed, 61 insertions(+), 68 deletions(-) diff --git a/firedantic/configurations.py b/firedantic/configurations.py index 8f2f958..af76b0a 100644 --- a/firedantic/configurations.py +++ b/firedantic/configurations.py @@ -1,112 +1,105 @@ from os import environ -from typing import Dict, Optional, Union +from typing import Any, Dict, Optional, Union from google.auth.credentials import Credentials from google.cloud.firestore_v1 import AsyncClient, AsyncTransaction, Client, Transaction from pydantic import BaseModel +""" Old Way """ +CONFIGURATIONS: Dict[str, Any] = {} -# for added support of multiple configurations/clients -class ConfigItem(BaseModel): - prefix: str - project: str = None - credentials: Optional[Credentials] = None - client: Optional[Client] = None - async_client: Optional[AsyncClient] = None - - class Config: - arbitrary_types_allowed = True - - -# updating CONFIGURATIONS dict to be able to contain many/multiple configurations -CONFIGURATIONS: Dict[str, ConfigItem] = {} - - -def get_client(proj, creds) -> Union[Client, AsyncClient]: - # Firestore emulator must be running if using locally. - if environ.get("FIRESTORE_EMULATOR_HOST"): - client = Client(project=proj, credentials=creds) - else: - client = Client() - - return client - - -# Allow configure method to work as it was for backwards compatibility. -def configure( - db: Union[Client, AsyncClient] = None, - prefix: str = "", -) -> None: +def configure(client: Union[Client, AsyncClient], prefix: str = "") -> None: """Configures the prefix and DB. - :param db: The firestore client instance. :param prefix: The prefix to use for collection names. """ global CONFIGURATIONS - CONFIGURATIONS["(default)"] = ConfigItem + CONFIGURATIONS["db"] = client + CONFIGURATIONS["prefix"] = prefix - if isinstance(db, Client): - CONFIGURATIONS["(default)"].client = db - elif isinstance(db, AsyncClient): - CONFIGURATIONS["(default)"].async_client = db - CONFIGURATIONS["(default)"].prefix = prefix - # other params get set to None by default - -def get_transaction(config: str = "(default)") -> Transaction: +def get_transaction() -> Transaction: """ Get a new Firestore transaction for the configured client. """ - transaction = CONFIGURATIONS[config].client.transaction() + transaction = CONFIGURATIONS["db"].client.transaction() assert isinstance(transaction, Transaction) return transaction -def get_async_transaction(config: str = "(default)") -> AsyncTransaction: +def get_async_transaction() -> AsyncTransaction: """ Get a new async Firestore transaction for the configured client. """ - transaction = CONFIGURATIONS[config].async_client.transaction() + transaction = CONFIGURATIONS["db"].async_client.transaction() assert isinstance(transaction, AsyncTransaction) return transaction +""" New Way """ +class ConfigItem(BaseModel): + prefix: str + project: str = None + credentials: Optional[Credentials] = None + client: Optional[Client] = None + async_client: Optional[AsyncClient] = None class Configuration: def __init__(self): self.configurations: Dict[str, ConfigItem] = {} - def create( + """ + Add a named configuration. + + You may either pass pre-built client and/or async_client, + or provide only project/credentials so clients will be constructed here. + """ + def add( self, name: str = "(default)", prefix: str = "", project: str = "", - credentials: Credentials = None, - ) -> None: - self.configurations[name] = ConfigItem( + credentials: Optional[Credentials] = None, + client: Optional[Client] = None, + async_client: Optional[AsyncClient] = None, + ) -> ConfigItem: + # Construct clients only if they were not supplied + if client is None: + client = Client(project=project, credentials=credentials) + if async_client is None: + async_client = AsyncClient(project=project, credentials=credentials) + + item = ConfigItem( prefix=prefix, project=project, credentials=credentials, - client=Client( - project=project, - credentials=credentials, - ), - async_client=AsyncClient( - project=project, - credentials=credentials, - ), + client=client, + async_client=async_client, ) - # add to global CONFIGURATIONS - global CONFIGURATIONS - CONFIGURATIONS[name] = self.configurations[name] - - def get_client(self, name: str = "(default)") -> Client: - return self.configurations[name].client + self.configurations[name] = item + return item + + def get_config(self, name: str = "(default)") -> ConfigItem: + try: + return self.configurations[name] + except KeyError as err: + raise KeyError( + f"Configuration '{name}' not found. Available: {list(self.configurations.keys())}" + ) from err - def get_async_client(self, name: str = "(default)") -> AsyncClient: - return self.configurations[name].async_client - - def get_transaction(self, name: str = "(default)") -> Transaction: + """ + Resolve None -> "(default)" for clients since one may pass None to indicate default, + the same follows for transactions. + """ + def get_client(self, name: Optional[str] = None) -> Client: + resolved = name if name is not None else "(default)" + return self.get_config(resolved).client + + def get_async_client(self, name: Optional[str] = None) -> AsyncClient: + resolved = name if name is not None else "(default)" + return self.get_config(resolved).async_client + + def get_transaction(self, name: Optional[str] = None) -> Transaction: return self.get_client(name=name).transaction() - def get_async_transaction(self, name: str = "(default)") -> AsyncTransaction: + def get_async_transaction(self, name: Optional[str] = None) -> AsyncTransaction: return self.get_async_client(name=name).transaction() From 9675ce3ad32bd0066d7d69cbce1b226d3ec1a14e Mon Sep 17 00:00:00 2001 From: Marissa Fisher Date: Mon, 17 Nov 2025 16:58:15 -0300 Subject: [PATCH 09/38] one more small fix --- firedantic/configurations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/firedantic/configurations.py b/firedantic/configurations.py index af76b0a..a48ab48 100644 --- a/firedantic/configurations.py +++ b/firedantic/configurations.py @@ -22,7 +22,7 @@ def get_transaction() -> Transaction: """ Get a new Firestore transaction for the configured client. """ - transaction = CONFIGURATIONS["db"].client.transaction() + transaction = CONFIGURATIONS["db"].transaction() assert isinstance(transaction, Transaction) return transaction @@ -31,7 +31,7 @@ def get_async_transaction() -> AsyncTransaction: """ Get a new async Firestore transaction for the configured client. """ - transaction = CONFIGURATIONS["db"].async_client.transaction() + transaction = CONFIGURATIONS["db"].transaction() assert isinstance(transaction, AsyncTransaction) return transaction From b9d29814f593d71c56a0e79158780d4305d79cc2 Mon Sep 17 00:00:00 2001 From: Marissa Fisher Date: Tue, 18 Nov 2025 11:32:22 -0300 Subject: [PATCH 10/38] remove config as added parameter, too risky for bugs --- firedantic/_sync/model.py | 53 ++++++++++++++++-------------------- firedantic/configurations.py | 4 +++ 2 files changed, 28 insertions(+), 29 deletions(-) diff --git a/firedantic/_sync/model.py b/firedantic/_sync/model.py index 836a561..f85242c 100644 --- a/firedantic/_sync/model.py +++ b/firedantic/_sync/model.py @@ -42,7 +42,7 @@ } -def get_collection_name(cls, name: Optional[str], config: str = "(default)") -> str: +def get_collection_name(cls, name: Optional[str]) -> str: """ Returns the collection name. @@ -53,8 +53,8 @@ def get_collection_name(cls, name: Optional[str], config: str = "(default)") -> return f"{configuration.get_prefix}{name}" -def _get_col_ref(cls, client, name: Optional[str], config: str = "(default)") -> CollectionReference: - collection: CollectionReference = configuration.get_collection_name(cls, client, name, config) +def _get_col_ref(cls, name: Optional[str]) -> CollectionReference: + collection: CollectionReference = configuration.get_collection_name(cls, name) return collection @@ -69,10 +69,11 @@ class BareModel(pydantic.BaseModel, ABC): __document_id__: str __ttl_field__: Optional[str] = None __composite_indexes__: Optional[Iterable[IndexDefinition]] = None + __db_config__: str = "(default)" # can be overridden in subclasses + def save( self, - config: str = "(default)", *, exclude_unset: bool = False, exclude_none: bool = False, @@ -92,25 +93,25 @@ def save( if self.__document_id__ in data: del data[self.__document_id__] - doc_ref = self._get_doc_ref(config) + doc_ref = self._get_doc_ref() if transaction is not None: transaction.set(doc_ref, data) else: doc_ref.set(data) setattr(self, self.__document_id__, doc_ref.id) - def delete(self, config: str = "(default)", transaction: Optional[Transaction] = None) -> None: + def delete(self, transaction: Optional[Transaction] = None) -> None: """ Deletes this model from the database. :raise DocumentIDError: If the ID is not valid. """ if transaction is not None: - transaction.delete(self._get_doc_ref(config)) + transaction.delete(self._get_doc_ref()) else: - self._get_doc_ref(config).delete() + self._get_doc_ref().delete() - def reload(self, config: str = "(default)", transaction: Optional[Transaction] = None) -> None: + def reload(self, transaction: Optional[Transaction] = None) -> None: """ Reloads this model from the database. @@ -121,7 +122,7 @@ def reload(self, config: str = "(default)", transaction: Optional[Transaction] = if doc_id is None: raise ModelNotFoundError("Can not reload unsaved model") - updated_model = self.get_by_doc_id(doc_id, config=config, transaction=transaction) + updated_model = self.get_by_doc_id(doc_id, transaction=transaction) updated_model_doc_id = updated_model.__dict__[self.__document_id__] assert doc_id == updated_model_doc_id @@ -148,7 +149,6 @@ def find( # pylint: disable=too-many-arguments limit: Optional[int] = None, offset: Optional[int] = None, transaction: Optional[Transaction] = None, - config: Optional[str] = "(default)" ) -> List[TBareModel]: """ Returns a list of models from the database based on a filter. @@ -165,10 +165,9 @@ def find( # pylint: disable=too-many-arguments :param limit: Maximum results to return. :param offset: Skip the first n results. :param transaction: Optional transaction to use. - :param config: Client configuration to use. :return: List of found models. """ - query: Union[BaseQuery, CollectionReference] = cls._get_col_ref(config) + query: Union[BaseQuery, CollectionReference] = cls._get_col_ref() if filter_: for key, value in filter_.items(): query = cls._add_filter(query, key, value) @@ -228,7 +227,6 @@ def find_one( filter_: Optional[Dict[str, Union[str, dict]]] = None, order_by: Optional[_OrderBy] = None, transaction: Optional[Transaction] = None, - config: Optional[str] = "(default)" ) -> TBareModel: """ Returns one model from the DB based on a filter. @@ -239,7 +237,7 @@ def find_one( :raise ModelNotFoundError: If the entry is not found. """ model = cls.find( - filter_, limit=1, order_by=order_by, transaction=transaction, config=config + filter_, limit=1, order_by=order_by, transaction=transaction ) try: return model[0] @@ -250,7 +248,6 @@ def find_one( def get_by_doc_id( cls: Type[TBareModel], doc_id: str, - config: str = "(default)", transaction: Optional[Transaction] = None, ) -> TBareModel: """ @@ -275,7 +272,7 @@ def get_by_doc_id( ) from e document: DocumentSnapshot = ( - cls._get_col_ref(config).document(doc_id).get(transaction=transaction) + cls._get_col_ref().document(doc_id).get(transaction=transaction) ) # type: ignore data = document.to_dict() if data is None: @@ -288,7 +285,7 @@ def get_by_doc_id( return model @classmethod - def truncate_collection(cls, config: str = "(default)", batch_size: int = 128) -> int: + def truncate_collection(cls, batch_size: int = 128) -> int: """ Removes all documents inside a collection. @@ -296,16 +293,16 @@ def truncate_collection(cls, config: str = "(default)", batch_size: int = 128) - :return: Number of removed documents. """ return truncate_collection( - col_ref=cls._get_col_ref(config), + col_ref=cls._get_col_ref(), batch_size=batch_size, ) @classmethod - def _get_col_ref(cls, config: str = "(default)") -> CollectionReference: + def _get_col_ref(cls) -> CollectionReference: """ Returns the collection reference. """ - return _get_col_ref(cls, cls.__collection__, config) + return _get_col_ref(cls, cls.__collection__) @classmethod def get_collection_name(cls) -> str: @@ -314,13 +311,13 @@ def get_collection_name(cls) -> str: """ return get_collection_name(cls, cls.__collection__) - def _get_doc_ref(self, config: str = "(default)") -> DocumentReference: + def _get_doc_ref(self) -> DocumentReference: """ Returns the document reference. :raise DocumentIDError: If the ID is not valid. """ - return self._get_col_ref(config).document(self.get_document_id()) # type: ignore + return self._get_col_ref().document(self.get_document_id()) # type: ignore @staticmethod def _validate_document_id(document_id: str): @@ -364,7 +361,6 @@ class Model(BareModel): def get_by_id( cls: Type[TBareModel], id_: str, - config: str = "(default)", transaction: Optional[Transaction] = None, ) -> TBareModel: """ @@ -374,7 +370,7 @@ def get_by_id( :param transaction: Optional transaction to use. :raises ModelNotFoundError: If no model was found by given id. """ - return cls.get_by_doc_id(id_, config=config, transaction=transaction) + return cls.get_by_doc_id(id_, transaction=transaction) class BareSubCollection(ABC): @@ -412,7 +408,7 @@ def _create(cls: Type[TBareSubModel], **kwargs) -> TBareSubModel: ) @classmethod - def _get_col_ref(cls, config: str = "(default)") -> CollectionReference: + def _get_col_ref(cls) -> CollectionReference: """ Returns the collection reference. """ @@ -421,7 +417,7 @@ def _get_col_ref(cls, config: str = "(default)") -> CollectionReference: f"{cls.__name__} is not properly prepared. " f"You should use {cls.__name__}.model_for(parent)" ) - return _get_col_ref(cls.__collection_cls__, cls.__collection__, config) + return _get_col_ref(cls.__collection_cls__, cls.__collection__) @classmethod def model_for(cls, parent): @@ -438,7 +434,6 @@ class SubModel(BareSubModel): def get_by_id( cls: Type[TBareModel], id_: str, - config: str = "(default)", transaction: Optional[Transaction] = None, ) -> TBareModel: """ @@ -448,7 +443,7 @@ def get_by_id( :param transaction: Optional transaction to use. :raises ModelNotFoundError: """ - return cls.get_by_doc_id(id_, config=config, transaction=transaction) + return cls.get_by_doc_id(id_, transaction=transaction) class SubCollection(BareSubCollection, ABC): diff --git a/firedantic/configurations.py b/firedantic/configurations.py index a48ab48..654572f 100644 --- a/firedantic/configurations.py +++ b/firedantic/configurations.py @@ -43,6 +43,10 @@ class ConfigItem(BaseModel): client: Optional[Client] = None async_client: Optional[AsyncClient] = None + model_config = { + "arbitrary_types_allowed": True + } + class Configuration: def __init__(self): self.configurations: Dict[str, ConfigItem] = {} From 6abdf6fdbadf0652f4e02f145e46990d75b6b03c Mon Sep 17 00:00:00 2001 From: Marissa Fisher Date: Tue, 18 Nov 2025 20:52:05 -0300 Subject: [PATCH 11/38] more re-org, tests not passing but closer --- firedantic/_async/model.py | 34 +++++++++----- firedantic/_sync/model.py | 56 +++++++++++++++++++---- firedantic/configurations.py | 31 +++++++++---- firedantic/tests/tests_async/conftest.py | 7 +++ firedantic/tests/tests_sync/conftest.py | 29 +++++++++--- firedantic/tests/tests_sync/test_model.py | 9 +++- 6 files changed, 132 insertions(+), 34 deletions(-) diff --git a/firedantic/_async/model.py b/firedantic/_async/model.py index f6054c8..c52dc82 100644 --- a/firedantic/_async/model.py +++ b/firedantic/_async/model.py @@ -8,6 +8,7 @@ AsyncDocumentReference, DocumentSnapshot, FieldFilter, + AsyncClient ) from google.cloud.firestore_v1.async_query import AsyncQuery from google.cloud.firestore_v1.async_transaction import AsyncTransaction @@ -42,7 +43,7 @@ } -def get_collection_name(cls, name: Optional[str], config: str = "(default)") -> str: +def get_collection_name(cls, name: Optional[str]) -> str: """ Returns the collection name. @@ -50,11 +51,11 @@ def get_collection_name(cls, name: Optional[str], config: str = "(default)") -> """ if not name: raise CollectionNotDefined(f"Missing collection name for {cls.__name__}") - return f"{configuration.get_prefix}{name}" + return f"{configuration.get_config_name}" -def _get_col_ref(cls, client, name: Optional[str], config: str = "(default)") -> AsyncCollectionReference: - collection: AsyncCollectionReference = configuration.get_collection_name(cls, client, name, config) +def _get_col_ref(cls, name: Optional[str]) -> AsyncCollectionReference: + collection: AsyncCollectionReference = configuration.get_async_client().collection(get_collection_name(cls, name)) return collection @@ -69,6 +70,12 @@ class AsyncBareModel(pydantic.BaseModel, ABC): __document_id__: str __ttl_field__: Optional[str] = None __composite_indexes__: Optional[Iterable[IndexDefinition]] = None + __db_config__: str = "(default)" # can be overridden in subclasses + + @property + def _client(self) -> AsyncClient: + """Return the Firestore client bound to this model’s configured DB.""" + return configuration.get_async_client(self.__db_config__) async def save( self, @@ -268,9 +275,8 @@ async def get_by_doc_id( raise ModelNotFoundError( f"No '{cls.__name__}' found with {cls.__document_id__} '{doc_id}'" ) from e - document: DocumentSnapshot = ( - await cls._get_col_ref().document(doc_id).get(transaction=transaction) + await cls._get_doc_ref().get(transaction=transaction) ) # type: ignore data = document.to_dict() if data is None: @@ -308,14 +314,20 @@ def get_collection_name(cls) -> str: Returns the collection name. """ return get_collection_name(cls, cls.__collection__) - + def _get_doc_ref(self) -> AsyncDocumentReference: """ - Returns the document reference. - - :raise DocumentIDError: If the ID is not valid. + Return an AsyncDocumentReference for this instance. + If the instance already has an id, use it; otherwise call .document() with no args + so Firestore generates a new id. """ - return self._get_col_ref().document(self.get_document_id()) # type: ignore + col_ref = self._get_col_ref() + + doc_id = self.get_document_id() + if doc_id: + return col_ref.document(doc_id) + # no id -> create a new document reference (server will generate id on set) + return col_ref.document() @staticmethod def _validate_document_id(document_id: str): diff --git a/firedantic/_sync/model.py b/firedantic/_sync/model.py index f85242c..d7efa2b 100644 --- a/firedantic/_sync/model.py +++ b/firedantic/_sync/model.py @@ -8,6 +8,7 @@ DocumentReference, DocumentSnapshot, FieldFilter, + Client ) from google.cloud.firestore_v1.base_query import BaseQuery from google.cloud.firestore_v1.transaction import Transaction @@ -42,19 +43,19 @@ } -def get_collection_name(cls, name: Optional[str]) -> str: +def get_collection_name(model_class, name: Optional[str]) -> str: """ Returns the collection name. :raise CollectionNotDefined: If the collection name is not defined. """ if not name: - raise CollectionNotDefined(f"Missing collection name for {cls.__name__}") - return f"{configuration.get_prefix}{name}" + raise CollectionNotDefined(f"Missing collection name for {model_class.__name__}") + return f"{configuration.get_config_name(name)}" -def _get_col_ref(cls, name: Optional[str]) -> CollectionReference: - collection: CollectionReference = configuration.get_collection_name(cls, name) +def _get_col_ref(model_class, name: Optional[str]) -> CollectionReference: + collection: CollectionReference = get_collection_name(model_class, name) return collection @@ -71,6 +72,42 @@ class BareModel(pydantic.BaseModel, ABC): __composite_indexes__: Optional[Iterable[IndexDefinition]] = None __db_config__: str = "(default)" # can be overridden in subclasses + @property + def _resolve_config(self) -> str: + """Return the config name this class is bound to""" + return getattr(self.__class__, "__db_config__", "(default)") + + @property + def _client(self) -> Client: + """Return the Firestore client this class is bound to""" + return configuration.get_client(self.__db_config__) + + def _get_collection(self, config_name: Optional[str] = None) -> CollectionReference: + """ + Return a CollectionReference for this model based on the resolved config and + a collection name derived from the class (can override with a different naming scheme) + """ + resolved = config_name if config_name is not None else self._resolve_config() + client: Client = configuration.get_client(name=resolved) + + # Use the config item's prefix + class-name-based collection name. + cfg_item = configuration.get_config(resolved) + collection_name = f"{cfg_item.prefix}{self.__class__.__name__.lower()}s" + return client.collection(collection_name) + + def _get_doc_ref(self, config_name: Optional[str] = None) -> DocumentReference: + """ + Return a DocumentReference for this instance. + If the instance has an id (self.__document_id__), use it; otherwise create a new doc ref + (client.collection().document() without an id will generate one). + """ + collection = self._get_collection(config_name) + doc_id = getattr(self, self.__document_id__, None) + if doc_id: + return collection.document(doc_id) + # No id: create a new document reference (Firestore client will generate ID) + return collection.document() + def save( self, @@ -93,11 +130,16 @@ def save( if self.__document_id__ in data: del data[self.__document_id__] + # get the document ref bound to the models configuration (db_config) doc_ref = self._get_doc_ref() + + # Write data to firestore db using transaction or directly. if transaction is not None: transaction.set(doc_ref, data) else: doc_ref.set(data) + + # Ensure the instance has the document id set. setattr(self, self.__document_id__, doc_ref.id) def delete(self, transaction: Optional[Transaction] = None) -> None: @@ -236,9 +278,7 @@ def find_one( :return: The model instance. :raise ModelNotFoundError: If the entry is not found. """ - model = cls.find( - filter_, limit=1, order_by=order_by, transaction=transaction - ) + model = cls.find(filter_, limit=1, order_by=order_by, transaction=transaction) try: return model[0] except IndexError as e: diff --git a/firedantic/configurations.py b/firedantic/configurations.py index 654572f..5c7258b 100644 --- a/firedantic/configurations.py +++ b/firedantic/configurations.py @@ -49,7 +49,22 @@ class ConfigItem(BaseModel): class Configuration: def __init__(self): - self.configurations: Dict[str, ConfigItem] = {} + self.config: Dict[str, ConfigItem] = { + "(default)": ConfigItem( + name="(default)", + prefix="", + project=environ.get("GOOGLE_CLOUD_PROJECT", ""), + credentials=None, + client=Client( + project=environ.get("GOOGLE_CLOUD_PROJECT", ""), + credentials=None, + ), + async_client=AsyncClient( + project=environ.get("GOOGLE_CLOUD_PROJECT", ""), + credentials=None, + ), + ) + } """ Add a named configuration. @@ -59,7 +74,7 @@ def __init__(self): """ def add( self, - name: str = "(default)", + name: str = "(default)", # adding a config without a name results in overriding the default prefix: str = "", project: str = "", credentials: Optional[Credentials] = None, @@ -79,15 +94,15 @@ def add( client=client, async_client=async_client, ) - self.configurations[name] = item + self.config[name] = item return item - def get_config(self, name: str = "(default)") -> ConfigItem: + def get_config_name(self, name: str = "(default)") -> ConfigItem: try: - return self.configurations[name] + return self.config[name] except KeyError as err: raise KeyError( - f"Configuration '{name}' not found. Available: {list(self.configurations.keys())}" + f"Configuration '{name}' not found. Available: {list(self.config.keys())}" ) from err """ @@ -96,11 +111,11 @@ def get_config(self, name: str = "(default)") -> ConfigItem: """ def get_client(self, name: Optional[str] = None) -> Client: resolved = name if name is not None else "(default)" - return self.get_config(resolved).client + return self.get_config_name(resolved).client def get_async_client(self, name: Optional[str] = None) -> AsyncClient: resolved = name if name is not None else "(default)" - return self.get_config(resolved).async_client + return self.get_config_name(resolved).async_client def get_transaction(self, name: Optional[str] = None) -> Transaction: return self.get_client(name=name).transaction() diff --git a/firedantic/tests/tests_async/conftest.py b/firedantic/tests/tests_async/conftest.py index 19b3bb4..41b829c 100644 --- a/firedantic/tests/tests_async/conftest.py +++ b/firedantic/tests/tests_async/conftest.py @@ -169,6 +169,13 @@ class ExpiringModel(AsyncModel): content: str +@pytest.fixture(scope="session", autouse=True) +def use_emulator(monkeypatch): + monkeypatch.setenv("FIRESTORE_EMULATOR_HOST", "localhost:8080") + monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "test-project") + yield + + @pytest.fixture(autouse=True) def configure_client(): client = AsyncClient( diff --git a/firedantic/tests/tests_sync/conftest.py b/firedantic/tests/tests_sync/conftest.py index 2924ae0..660ddf8 100644 --- a/firedantic/tests/tests_sync/conftest.py +++ b/firedantic/tests/tests_sync/conftest.py @@ -19,7 +19,7 @@ from firedantic.configurations import configure, Configuration, get_transaction from firedantic.exceptions import ModelNotFoundError -from unittest.mock import Mock # noqa isort: skip +from unittest.mock import MagicMock, Mock # noqa isort: skip class CustomIDModel(BareModel): @@ -44,7 +44,6 @@ class CustomIDModelExtra(BareModel): class Config: extra = "forbid" - class CustomIDConflictModel(Model): __collection__ = "custom" @@ -75,6 +74,7 @@ def _increment_population(transaction: Transaction) -> None: class Owner(BaseModel): """Dummy owner Pydantic model.""" + __collection__ = "owners" first_name: str last_name: str @@ -146,7 +146,6 @@ class Profile(Model): class Config: extra = "forbid" - class TodoList(Model): """Dummy todo list Firedantic model.""" @@ -158,7 +157,6 @@ class TodoList(Model): class Config: extra = "forbid" - class ExpiringModel(Model): """Dummy expiring model Firedantic model.""" @@ -169,8 +167,16 @@ class ExpiringModel(Model): content: str -@pytest.fixture(autouse=True) -def configure_client(): + +# @pytest.fixture(scope="session", autouse=True) +# def use_emulator(monkeypatch): +# monkeypatch.setenv("FIRESTORE_EMULATOR_HOST", "localhost:8080") +# monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "test-project") +# yield + + +# allow configuring test client in this way to ensure old 'configure' method remains intact +def configure_client_old(): client = Client( project="ioxio-local-dev", credentials=Mock(spec=google.auth.credentials.Credentials), @@ -203,6 +209,17 @@ def configure_multiple_clients(): @pytest.fixture def create_company(): + # Register the config under the name "companies" + config = Configuration() + config.add( + name="companies", + prefix="test_", + project="test-project", + ) + + # # cleanup after test(s) + # config.configurations.pop("companies", None) + def _create( company_id: str = "1234567-8", first_name: str = "John", last_name: str = "Doe" ): diff --git a/firedantic/tests/tests_sync/test_model.py b/firedantic/tests/tests_sync/test_model.py index b3fcbde..496845c 100644 --- a/firedantic/tests/tests_sync/test_model.py +++ b/firedantic/tests/tests_sync/test_model.py @@ -37,7 +37,14 @@ -def test_save_model(configure_client, create_company) -> None: +def test_save_model_old_way(create_company) -> None: + company = create_company() + + assert company.id is not None + assert company.owner.first_name == "John" + assert company.owner.last_name == "Doe" + +def test_save_model_new_way(create_company) -> None: company = create_company() assert company.id is not None From e32a2f9387c34dc2ac035b65b3568ddfebc5c654 Mon Sep 17 00:00:00 2001 From: Marissa Fisher Date: Sun, 23 Nov 2025 22:00:05 -0300 Subject: [PATCH 12/38] add model async tests passing --- firedantic/_async/indexes.py | 26 ++- firedantic/_async/model.py | 90 +++++--- firedantic/_sync/indexes.py | 27 ++- firedantic/_sync/model.py | 3 +- firedantic/configurations.py | 213 +++++++++++++------ firedantic/tests/tests_async/conftest.py | 16 +- firedantic/tests/tests_async/test_indexes.py | 27 ++- firedantic/tests/tests_async/test_model.py | 76 +++---- firedantic/tests/tests_sync/conftest.py | 4 +- firedantic/tests/tests_sync/test_indexes.py | 23 +- firedantic/tests/tests_sync/test_model.py | 13 +- start_emulator.sh | 2 + 12 files changed, 337 insertions(+), 183 deletions(-) diff --git a/firedantic/_async/indexes.py b/firedantic/_async/indexes.py index b7d0d52..0f11be8 100644 --- a/firedantic/_async/indexes.py +++ b/firedantic/_async/indexes.py @@ -14,6 +14,7 @@ from firedantic._async.model import AsyncBareModel from firedantic._async.ttl_policy import set_up_ttl_policies from firedantic.common import IndexDefinition, IndexField +from firedantic.configurations import configuration logger = getLogger("firedantic") @@ -81,7 +82,7 @@ async def create_composite_index( async def set_up_composite_indexes( - gcloud_project: str, + gcloud_project: Optional[str], models: Iterable[Type[AsyncBareModel]], database: str = "(default)", client: Optional[FirestoreAdminAsyncClient] = None, @@ -100,24 +101,29 @@ async def set_up_composite_indexes( operations = [] for model in models: - if not model.__composite_indexes__: + if not getattr(model, "__composite_indexes__", None): continue - path = ( - f"projects/{gcloud_project}/databases/{database}/" - f"collectionGroups/{model.get_collection_name()}" - ) + + # Resolve config name: prefer model __db_config__ if present; else default + config_name = getattr(model, "__db_config__", "(default)") + + # If caller did not pass gcloud_project, try to get it from config + project = gcloud_project or configuration.get_config(config_name).project + + # Build collection group path using configuration helper (includes prefix) + collection_group = configuration.get_collection_name(model, name=config_name) + path = f"projects/{project}/databases/{database}/collectionGroups/{collection_group}" + indexes_in_db = await get_existing_indexes(client, path=path) model_indexes = set(model.__composite_indexes__) existing_indexes = indexes_in_db.intersection(model_indexes) new_indexes = model_indexes.difference(indexes_in_db) for index in existing_indexes: - log_str = "Composite index already exists in DB: %s, collection: %s" - logger.debug(log_str, index, model.get_collection_name()) + logger.debug("Composite index already exists in DB: %s, collection: %s", index, collection_group) for index in new_indexes: - log_str = "Creating new composite index: %s, collection: %s" - logger.info(log_str, index, model.get_collection_name()) + logger.info("Creating new composite index: %s, collection: %s", index, collection_group) operation = await create_composite_index(client, index, path) operations.append(operation) diff --git a/firedantic/_async/model.py b/firedantic/_async/model.py index c52dc82..6295e2e 100644 --- a/firedantic/_async/model.py +++ b/firedantic/_async/model.py @@ -8,7 +8,6 @@ AsyncDocumentReference, DocumentSnapshot, FieldFilter, - AsyncClient ) from google.cloud.firestore_v1.async_query import AsyncQuery from google.cloud.firestore_v1.async_transaction import AsyncTransaction @@ -16,7 +15,7 @@ import firedantic.operators as op from firedantic import async_truncate_collection from firedantic.common import IndexDefinition, OrderDirection -from firedantic.configurations import Configuration +from firedantic.configurations import configuration from firedantic.exceptions import ( CollectionNotDefined, InvalidDocumentID, @@ -26,7 +25,6 @@ TAsyncBareModel = TypeVar("TAsyncBareModel", bound="AsyncBareModel") TAsyncBareSubModel = TypeVar("TAsyncBareSubModel", bound="AsyncBareSubModel") logger = getLogger("firedantic") -configuration = Configuration() # https://firebase.google.com/docs/firestore/query-data/queries#query_operators FIND_TYPES = { @@ -43,20 +41,56 @@ } -def get_collection_name(cls, name: Optional[str]) -> str: +def get_collection_name(cls, collection_name: Optional[str]) -> str: """ - Returns the collection name. + Return the collection name for `cls`. - :raise CollectionNotDefined: If the collection name is not defined. + - If `collection_name` is provided, treat it as an explicit collection name + and prefix it using the configured prefix for the class (via __db_config__). + - Otherwise, use the class name and the configured prefix. + + :raises CollectionNotDefined: If neither a class collection nor derived name is available. + """ + # Resolve the config name from the model class (default to "(default)") + config_name = getattr(cls, "__db_config__", "(default)") + + # If caller provided an explicit collection string, apply prefix from config + if collection_name: + cfg = configuration.get_config(config_name) + prefix = cfg.prefix or "" + return f"{prefix}{collection_name}" + + if getattr(cls, "__collection__", None): + cfg = configuration.get_config(config_name) + return f"{cfg.prefix or ''}{cls.__collection__}" + + raise CollectionNotDefined(f"Missing collection name for {cls.__name__}") + + + +def _get_col_ref(cls, collection_name: Optional[str]) -> AsyncCollectionReference: + """ + Return an AsyncCollectionReference for the model class using the configured async client. + + :param cls: model class + :param collection_name: optional explicit collection name override + :raises CollectionNotDefined: when collection cannot be resolved """ - if not name: - raise CollectionNotDefined(f"Missing collection name for {cls.__name__}") - return f"{configuration.get_config_name}" + # Build the prefixed collection name + col_name = get_collection_name(cls, collection_name) + # Resolve config name from class and fetch async client + config_name = getattr(cls, "__db_config__", "(default)") + async_client = configuration.get_async_client(config_name) + + # Return the AsyncCollectionReference (this is a real client call) + col_ref = async_client.collection(col_name) + + # Ensure we got the right object back + if not hasattr(col_ref, "document"): + raise RuntimeError(f"_get_col_ref returned unexpected object for {cls}: {type(col_ref)!r}") + return col_ref -def _get_col_ref(cls, name: Optional[str]) -> AsyncCollectionReference: - collection: AsyncCollectionReference = configuration.get_async_client().collection(get_collection_name(cls, name)) - return collection class AsyncBareModel(pydantic.BaseModel, ABC): @@ -70,12 +104,8 @@ class AsyncBareModel(pydantic.BaseModel, ABC): __document_id__: str __ttl_field__: Optional[str] = None __composite_indexes__: Optional[Iterable[IndexDefinition]] = None - __db_config__: str = "(default)" # can be overridden in subclasses + __db_config__: str = "(default)" # override in subclasses when needed - @property - def _client(self) -> AsyncClient: - """Return the Firestore client bound to this model’s configured DB.""" - return configuration.get_async_client(self.__db_config__) async def save( self, @@ -133,7 +163,7 @@ async def reload(self, transaction: Optional[AsyncTransaction] = None) -> None: self.__dict__.update(updated_model.__dict__) - def get_document_id(self): + def get_document_id(self) -> Optional[str]: """ Returns the document ID for this model instance. @@ -178,8 +208,7 @@ async def find( # pylint: disable=too-many-arguments query = cls._add_filter(query, key, value) if order_by is not None: - for order_by_item in order_by: - field, direction = order_by_item + for field, direction in order_by: query = query.order_by(field, direction=direction) # type: ignore if limit is not None: query = query.limit(limit) # type: ignore @@ -275,8 +304,9 @@ async def get_by_doc_id( raise ModelNotFoundError( f"No '{cls.__name__}' found with {cls.__document_id__} '{doc_id}'" ) from e + document: DocumentSnapshot = ( - await cls._get_doc_ref().get(transaction=transaction) + await cls._get_col_ref().document(doc_id).get(transaction=transaction) ) # type: ignore data = document.to_dict() if data is None: @@ -314,20 +344,14 @@ def get_collection_name(cls) -> str: Returns the collection name. """ return get_collection_name(cls, cls.__collection__) - + def _get_doc_ref(self) -> AsyncDocumentReference: """ - Return an AsyncDocumentReference for this instance. - If the instance already has an id, use it; otherwise call .document() with no args - so Firestore generates a new id. - """ - col_ref = self._get_col_ref() + Returns the document reference. - doc_id = self.get_document_id() - if doc_id: - return col_ref.document(doc_id) - # no id -> create a new document reference (server will generate id on set) - return col_ref.document() + :raise DocumentIDError: If the ID is not valid. + """ + return self._get_col_ref().document(self.get_document_id()) # type: ignore @staticmethod def _validate_document_id(document_id: str): @@ -458,4 +482,4 @@ async def get_by_id( class AsyncSubCollection(AsyncBareSubCollection, ABC): __document_id__ = "id" - __model_cls__: Type[AsyncSubModel] + __model_cls__: Type[AsyncSubModel] \ No newline at end of file diff --git a/firedantic/_sync/indexes.py b/firedantic/_sync/indexes.py index a1eb7c5..d5ae4d3 100644 --- a/firedantic/_sync/indexes.py +++ b/firedantic/_sync/indexes.py @@ -14,6 +14,7 @@ from firedantic._sync.model import BareModel from firedantic._sync.ttl_policy import set_up_ttl_policies from firedantic.common import IndexDefinition, IndexField +from firedantic.configurations import configuration logger = getLogger("firedantic") @@ -81,7 +82,7 @@ def create_composite_index( def set_up_composite_indexes( - gcloud_project: str, + gcloud_project: Optional[str], models: Iterable[Type[BareModel]], database: str = "(default)", client: Optional[FirestoreAdminClient] = None, @@ -100,24 +101,30 @@ def set_up_composite_indexes( operations = [] for model in models: - if not model.__composite_indexes__: + + if not getattr(model, "__composite_indexes__", None): continue - path = ( - f"projects/{gcloud_project}/databases/{database}/" - f"collectionGroups/{model.get_collection_name()}" - ) + + # Resolve config name: prefer model __db_config__ if present; else default + config_name = getattr(model, "__db_config__", "(default)") + + # If caller did not pass gcloud_project, try to get it from config + project = gcloud_project or configuration.get_config(config_name).project + + # Build collection group path using configuration helper (includes prefix) + collection_group = configuration.get_collection_name(model, name=config_name) + path = f"projects/{project}/databases/{database}/collectionGroups/{collection_group}" + indexes_in_db = get_existing_indexes(client, path=path) model_indexes = set(model.__composite_indexes__) existing_indexes = indexes_in_db.intersection(model_indexes) new_indexes = model_indexes.difference(indexes_in_db) for index in existing_indexes: - log_str = "Composite index already exists in DB: %s, collection: %s" - logger.debug(log_str, index, model.get_collection_name()) + logger.debug("Composite index already exists in DB: %s, collection: %s", index, collection_group) for index in new_indexes: - log_str = "Creating new composite index: %s, collection: %s" - logger.info(log_str, index, model.get_collection_name()) + logger.info("Creating new composite index: %s, collection: %s", index, collection_group) operation = create_composite_index(client, index, path) operations.append(operation) diff --git a/firedantic/_sync/model.py b/firedantic/_sync/model.py index d7efa2b..7d7680f 100644 --- a/firedantic/_sync/model.py +++ b/firedantic/_sync/model.py @@ -51,7 +51,7 @@ def get_collection_name(model_class, name: Optional[str]) -> str: """ if not name: raise CollectionNotDefined(f"Missing collection name for {model_class.__name__}") - return f"{configuration.get_config_name(name)}" + return f"{configuration.get_config(name)}" def _get_col_ref(model_class, name: Optional[str]) -> CollectionReference: @@ -72,6 +72,7 @@ class BareModel(pydantic.BaseModel, ABC): __composite_indexes__: Optional[Iterable[IndexDefinition]] = None __db_config__: str = "(default)" # can be overridden in subclasses + @property def _resolve_config(self) -> str: """Return the config name this class is bound to""" diff --git a/firedantic/configurations.py b/firedantic/configurations.py index 5c7258b..35fa696 100644 --- a/firedantic/configurations.py +++ b/firedantic/configurations.py @@ -1,70 +1,114 @@ from os import environ -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Type, Union from google.auth.credentials import Credentials -from google.cloud.firestore_v1 import AsyncClient, AsyncTransaction, Client, Transaction +from google.cloud.firestore_v1 import ( + AsyncClient, + AsyncTransaction, + Client, + CollectionReference, + Transaction +) from pydantic import BaseModel -""" Old Way """ +# --- Old compatibility surface (kept for backwards compatibility) --- CONFIGURATIONS: Dict[str, Any] = {} def configure(client: Union[Client, AsyncClient], prefix: str = "") -> None: - """Configures the prefix and DB. + """ + Legacy helper: updates the module-level `configuration` default entry and preserves + the old CONFIGURATIONS mapping so old callers continue to work. + Configures the prefix and DB. :param db: The firestore client instance. :param prefix: The prefix to use for collection names. """ - global CONFIGURATIONS + if isinstance(client, AsyncClient): + configuration.add(name="(default)", prefix=prefix, async_client=client) + else: + # treat as sync client + configuration.add(name="(default)", prefix=prefix, client=client) + CONFIGURATIONS["db"] = client CONFIGURATIONS["prefix"] = prefix def get_transaction() -> Transaction: - """ - Get a new Firestore transaction for the configured client. - """ - transaction = CONFIGURATIONS["db"].transaction() - assert isinstance(transaction, Transaction) - return transaction + """Backward-compatible transaction getter (sync).""" + return configuration.get_transaction() def get_async_transaction() -> AsyncTransaction: - """ - Get a new async Firestore transaction for the configured client. - """ - transaction = CONFIGURATIONS["db"].transaction() - assert isinstance(transaction, AsyncTransaction) - return transaction + """Backward-compatible async transaction getter.""" + return configuration.get_async_transaction() -""" New Way """ + +# --- New configuration system --- class ConfigItem(BaseModel): + """ + Holds configuration for a named Firestore connection. + Use arbitrary_types_allowed so we can store Client/AsyncClient/Credentials objects. + """ + name: str prefix: str - project: str = None + project: Optional[str] = None credentials: Optional[Credentials] = None - client: Optional[Client] = None - async_client: Optional[AsyncClient] = None + client: Optional[Any] = None + async_client: Optional[Any] = None - model_config = { - "arbitrary_types_allowed": True - } + model_config = {"arbitrary_types_allowed": True} class Configuration: - def __init__(self): - self.config: Dict[str, ConfigItem] = { - "(default)": ConfigItem( - name="(default)", - prefix="", - project=environ.get("GOOGLE_CLOUD_PROJECT", ""), - credentials=None, - client=Client( - project=environ.get("GOOGLE_CLOUD_PROJECT", ""), - credentials=None, - ), - async_client=AsyncClient( - project=environ.get("GOOGLE_CLOUD_PROJECT", ""), - credentials=None, - ), - ) - } + """ + Registry for named Firestore configurations. + + Usage Example: + configuration.add(name="billing", project="myproj", credentials=creds) + client = configuration.get_client("billing") + """ + + def __init__(self) -> None: + + # mapping name -> ConfigItem + self.config: Dict[str, ConfigItem] = {} + + # Create a sensible default config entry, but avoid passing empty string as project. + default_project = environ.get("GOOGLE_CLOUD_PROJECT") or None + + self.config["(default)"] = ConfigItem( + name="(default)", + prefix="", + project=default_project, + credentials=None, + client=None, + async_client=None + ) + + # dict-like accessors + def __getitem__(self, name: str) -> ConfigItem: + return self.get_config(name) + + def __contains__(self, name: str) -> bool: + return name in self.config + + def get(self, name: str, default=None): + return self.config.get(name, default) + + def get_config(self, name: str = "(default)") -> ConfigItem: + try: + return self.config[name] + except KeyError as err: + raise KeyError( + f"Configuration '{name}' not found. Available: {list(self.config.keys())}" + ) from err + + def _normalize_project(self, project: Optional[str]) -> Optional[str]: + """ + Convert empty-string project to None so Client(...) doesn't get empty string for project (i.e. project="") + """ + if project: + return project + return environ.get("GOOGLE_CLOUD_PROJECT") or None + """ Add a named configuration. @@ -76,49 +120,94 @@ def add( self, name: str = "(default)", # adding a config without a name results in overriding the default prefix: str = "", - project: str = "", + project: Optional[str] = None, credentials: Optional[Credentials] = None, client: Optional[Client] = None, async_client: Optional[AsyncClient] = None, ) -> ConfigItem: - # Construct clients only if they were not supplied - if client is None: - client = Client(project=project, credentials=credentials) - if async_client is None: - async_client = AsyncClient(project=project, credentials=credentials) + + normalized_project = self._normalize_project(project) + + # Construct clients lazily, only when they were not supplied + if client is None and normalized_project is not None: + client = Client(project=normalized_project, credentials=credentials) + if async_client is None and normalized_project is not None: + async_client = AsyncClient(project=normalized_project, credentials=credentials) item = ConfigItem( + name=name, prefix=prefix, - project=project, + project=normalized_project, credentials=credentials, client=client, async_client=async_client, ) self.config[name] = item return item - - def get_config_name(self, name: str = "(default)") -> ConfigItem: - try: - return self.config[name] - except KeyError as err: - raise KeyError( - f"Configuration '{name}' not found. Available: {list(self.config.keys())}" - ) from err - """ - Resolve None -> "(default)" for clients since one may pass None to indicate default, - the same follows for transactions. - """ + # client accessors (lazy-check and raise errors) def get_client(self, name: Optional[str] = None) -> Client: resolved = name if name is not None else "(default)" - return self.get_config_name(resolved).client + cfg = self.get_config(resolved) + if cfg.client is None: + # give a clear error; tests should register a mock client or provide a project + raise RuntimeError( + f"No sync client configured for config '{resolved}'. " + "Call configuration.add(..., client=...) or set GOOGLE_CLOUD_PROJECT and let add() build the client." + ) + return cfg.client def get_async_client(self, name: Optional[str] = None) -> AsyncClient: resolved = name if name is not None else "(default)" - return self.get_config_name(resolved).async_client + cfg = self.get_config(resolved) + if cfg.async_client is None: + raise RuntimeError( + f"No async client configured for config '{resolved}'. " + "Call configuration.add(..., async_client=...) or set GOOGLE_CLOUD_PROJECT and let add() build the client." + ) + return cfg.async_client def get_transaction(self, name: Optional[str] = None) -> Transaction: return self.get_client(name=name).transaction() def get_async_transaction(self, name: Optional[str] = None) -> AsyncTransaction: return self.get_async_client(name=name).transaction() + + # helpers for models to derive collection name / reference + def get_collection_name(self, model_class: Type, name: Optional[str] = None) -> str: + """ + Return the collection name string (prefix + model name). + """ + resolved = name if name is not None else "(default)" + cfg = self.get_config(resolved) + prefix = cfg.prefix or "" + model_name = model_class.__name__ + return f"{prefix}{model_name[0].lower()}{model_name[1:]}" # (lower case first letter of model name) + + def get_collection_ref(self, model_class: Type, name: Optional[str] = None) -> CollectionReference: + """ + Return a CollectionReference for the given model_class using the sync client. + """ + resolved = name if name is not None else "(default)" + cfg = self.get_config(resolved) + client = cfg.client + if client is None: + raise RuntimeError(f"No sync client configured for config '{resolved}'") + collection_name = self.get_collection_name(model_class, resolved) + return client.collection(collection_name) + + def get_async_collection_ref(self, model_class: Type, name: Optional[str] = None): + """ + Return an AsyncCollectionReference using the async client. + """ + resolved = name if name is not None else "(default)" + cfg = self.get_config(resolved) + async_client = cfg.async_client + if async_client is None: + raise RuntimeError(f"No async client configured for config '{resolved}'") + collection_name = self.get_collection_name(model_class, resolved) + return async_client.collection(collection_name) + +# make the module-level singleton available to models/tests +# other modules should: from firedantic.configurations import configuration +configuration = Configuration() diff --git a/firedantic/tests/tests_async/conftest.py b/firedantic/tests/tests_async/conftest.py index 41b829c..cd99802 100644 --- a/firedantic/tests/tests_async/conftest.py +++ b/firedantic/tests/tests_async/conftest.py @@ -1,4 +1,5 @@ import uuid +import os from datetime import datetime from typing import Any, List, Optional, Type @@ -168,13 +169,16 @@ class ExpiringModel(AsyncModel): expire: datetime content: str +# @pytest.fixture +# def use_emulator(monkeypatch): +# monkeypatch.setenv("FIRESTORE_EMULATOR_HOST", "localhost:8080") +# monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "test-project") -@pytest.fixture(scope="session", autouse=True) -def use_emulator(monkeypatch): - monkeypatch.setenv("FIRESTORE_EMULATOR_HOST", "localhost:8080") - monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "test-project") - yield - +# # point to your running emulator host:port +# os.environ["FIRESTORE_EMULATOR_HOST"] = "localhost:8080" +# # any non-empty project id is fine for emulator +# os.environ["GOOGLE_CLOUD_PROJECT"] = "test-project" +# yield @pytest.fixture(autouse=True) def configure_client(): diff --git a/firedantic/tests/tests_async/test_indexes.py b/firedantic/tests/tests_async/test_indexes.py index d1bef9c..ee9caf2 100644 --- a/firedantic/tests/tests_async/test_indexes.py +++ b/firedantic/tests/tests_async/test_indexes.py @@ -14,6 +14,7 @@ ) from firedantic.common import IndexField from firedantic.tests.tests_async.conftest import MockListIndexOperation +from firedantic.configurations import configuration import pytest # noqa isort: skip @@ -28,6 +29,7 @@ class BaseModelWithIndexes(AsyncModel): @pytest.mark.asyncio async def test_set_up_composite_index(mock_admin_client) -> None: + configuration.add(name="(default)", prefix="test_", project="proj", async_client=mock_admin_client) class ModelWithIndexes(BaseModelWithIndexes): __composite_indexes__ = ( collection_index( @@ -46,9 +48,10 @@ class ModelWithIndexes(BaseModelWithIndexes): call_list = mock_admin_client.create_index.call_args_list # index is a protobuf structure sent via Google Cloud Admin path = call_list[0][1]["request"].parent + expected_prefix = configuration.get_config("(default)").prefix assert ( path - == f"projects/proj/databases/(default)/collectionGroups/{CONFIGURATIONS["(default)"].prefix}modelWithIndexes" + == f"projects/proj/databases/(default)/collectionGroups/{expected_prefix}modelWithIndexes" ) index = call_list[0][1]["request"].index assert index.query_scope.name == "COLLECTION" @@ -61,6 +64,7 @@ class ModelWithIndexes(BaseModelWithIndexes): @pytest.mark.asyncio async def test_set_up_collection_group_index(mock_admin_client) -> None: + configuration.add(name="(default)", prefix="test_", project="proj", async_client=mock_admin_client) class ModelWithIndexes(BaseModelWithIndexes): __composite_indexes__ = ( collection_group_index( @@ -79,9 +83,11 @@ class ModelWithIndexes(BaseModelWithIndexes): call_list = mock_admin_client.create_index.call_args_list # index is a protobuf structure sent via Google Cloud Admin path = call_list[0][1]["request"].parent + expected_prefix = configuration.get_config("(default)").prefix + assert ( path - == f"projects/proj/databases/(default)/collectionGroups/{CONFIGURATIONS["(default)"].prefix}modelWithIndexes" + == f"projects/proj/databases/(default)/collectionGroups/{expected_prefix}modelWithIndexes" ) index = call_list[0][1]["request"].index assert index.query_scope.name == "COLLECTION_GROUP" @@ -90,6 +96,7 @@ class ModelWithIndexes(BaseModelWithIndexes): @pytest.mark.asyncio async def test_set_up_composite_indexes_and_policies(mock_admin_client) -> None: + configuration.add(name="(default)", prefix="test_", project="proj", async_client=mock_admin_client) class ModelWithIndexes(BaseModelWithIndexes): __composite_indexes__ = ( collection_index( @@ -114,6 +121,7 @@ class ModelWithIndexes(BaseModelWithIndexes): @pytest.mark.asyncio async def test_set_up_many_composite_indexes(mock_admin_client) -> None: + configuration.add(name="(default)", prefix="test_", project="proj", async_client=mock_admin_client) class ModelWithIndexes(BaseModelWithIndexes): __composite_indexes__ = ( collection_index( @@ -141,6 +149,7 @@ class ModelWithIndexes(BaseModelWithIndexes): @pytest.mark.asyncio async def test_set_up_indexes_model_without_indexes(mock_admin_client) -> None: + configuration.add(name="(default)", prefix="test_", project="proj", async_client=mock_admin_client) class ModelWithoutIndexes(AsyncModel): __collection__ = "modelWithoutIndexes" @@ -159,13 +168,16 @@ class ModelWithoutIndexes(AsyncModel): @pytest.mark.asyncio async def test_existing_indexes_are_skipped(mock_admin_client) -> None: + configuration.add(name="(default)", prefix="test_", project="proj", async_client=mock_admin_client) + expected_prefix = configuration.get_config("(default)").prefix + resp = ListIndexesResponse( { "indexes": [ { "name": ( "projects/fake-project/databases/(default)/collectionGroups/" - f"{CONFIGURATIONS["(default)"].prefix}modelWithIndexes/123456" + f"{expected_prefix}modelWithIndexes/123456" ), "query_scope": "COLLECTION", "fields": [ @@ -177,7 +189,7 @@ async def test_existing_indexes_are_skipped(mock_admin_client) -> None: { "name": ( "projects/fake-project/databases/(default)/collectionGroups/" - f"{CONFIGURATIONS["(default)"].prefix}modelWithIndexes/67889" + f"{expected_prefix}modelWithIndexes/67889" ), "query_scope": "COLLECTION", "fields": [ @@ -215,6 +227,9 @@ class ModelWithIndexes(BaseModelWithIndexes): @pytest.mark.asyncio async def test_same_fields_in_another_collection(mock_admin_client) -> None: + configuration.add(name="(default)", prefix="test_", project="proj", async_client=mock_admin_client) + expected_prefix = configuration.get_config("(default)").prefix + # Test that when another collection has an index with exactly the same fields, # it won't affect creating an index in the target collection resp = ListIndexesResponse( @@ -223,7 +238,7 @@ async def test_same_fields_in_another_collection(mock_admin_client) -> None: { "name": ( "projects/fake-project/databases/(default)/collectionGroups/" - f"{CONFIGURATIONS["(default)"].prefix}anotherModel/123456" + f"{expected_prefix}anotherModel/123456" ), "query_scope": "COLLECTION", "fields": [ @@ -252,4 +267,4 @@ class ModelWithIndexes(BaseModelWithIndexes): models=[ModelWithIndexes], client=mock_admin_client, ) - assert len(result) == 1 + assert len(result) == 1 \ No newline at end of file diff --git a/firedantic/tests/tests_async/test_model.py b/firedantic/tests/tests_async/test_model.py index bb3d514..40c791c 100644 --- a/firedantic/tests/tests_async/test_model.py +++ b/firedantic/tests/tests_async/test_model.py @@ -37,7 +37,7 @@ @pytest.mark.asyncio -async def test_save_model(configure_client, create_company) -> None: +async def test_save_model(create_company) -> None: company = await create_company() assert company.id is not None @@ -45,7 +45,7 @@ async def test_save_model(configure_client, create_company) -> None: assert company.owner.last_name == "Doe" @pytest.mark.asyncio -async def test_delete_model(configure_client, create_company) -> None: +async def test_delete_model(create_company) -> None: company: Company = await create_company( company_id="11223344-5", first_name="Jane", last_name="Doe" ) @@ -60,7 +60,7 @@ async def test_delete_model(configure_client, create_company) -> None: @pytest.mark.asyncio -async def test_find_one(configure_client, create_company) -> None: +async def test_find_one(create_company) -> None: with pytest.raises(ModelNotFoundError): await Company.find_one() @@ -91,7 +91,7 @@ async def test_find_one(configure_client, create_company) -> None: @pytest.mark.asyncio -async def test_find(configure_client, create_company, create_product) -> None: +async def test_find(create_company, create_product) -> None: ids = ["1234555-1", "1234567-8", "2131232-4", "4124432-4"] for company_id in ids: await create_company(company_id=company_id) @@ -128,7 +128,7 @@ async def test_find(configure_client, create_company, create_product) -> None: @pytest.mark.asyncio -async def test_find_not_in(configure_client, create_company) -> None: +async def test_find_not_in(create_company) -> None: ids = ["1234555-1", "1234567-8", "2131232-4", "4124432-4"] for company_id in ids: await create_company(company_id=company_id) @@ -149,7 +149,7 @@ async def test_find_not_in(configure_client, create_company) -> None: @pytest.mark.asyncio -async def test_find_array_contains(configure_client, create_todolist) -> None: +async def test_find_array_contains(create_todolist) -> None: list_1 = await create_todolist("list_1", ["Work", "Eat", "Sleep"]) await create_todolist("list_2", ["Learn Python", "Walk the dog"]) @@ -159,7 +159,7 @@ async def test_find_array_contains(configure_client, create_todolist) -> None: @pytest.mark.asyncio -async def test_find_array_contains_any(configure_client, create_todolist) -> None: +async def test_find_array_contains_any(create_todolist) -> None: list_1 = await create_todolist("list_1", ["Work", "Eat"]) list_2 = await create_todolist("list_2", ["Relax", "Chill", "Sleep"]) await create_todolist("list_3", ["Learn Python", "Walk the dog"]) @@ -171,7 +171,7 @@ async def test_find_array_contains_any(configure_client, create_todolist) -> Non @pytest.mark.asyncio -async def test_find_limit(configure_client, create_company) -> None: +async def test_find_limit(create_company) -> None: ids = ["1234555-1", "1234567-8", "2131232-4", "4124432-4"] for company_id in ids: await create_company(company_id=company_id) @@ -184,7 +184,7 @@ async def test_find_limit(configure_client, create_company) -> None: @pytest.mark.asyncio -async def test_find_order_by(configure_client, create_company) -> None: +async def test_find_order_by(create_company) -> None: companies_and_owners = [ {"company_id": "1234555-1", "last_name": "A", "first_name": "A"}, {"company_id": "1234555-2", "last_name": "A", "first_name": "B"}, @@ -233,7 +233,7 @@ async def test_find_order_by(configure_client, create_company) -> None: @pytest.mark.asyncio -async def test_find_offset(configure_client, create_company) -> None: +async def test_find_offset(create_company) -> None: ids_and_lastnames = ( ("1234555-1", "A"), ("1234567-8", "B"), @@ -251,7 +251,7 @@ async def test_find_offset(configure_client, create_company) -> None: @pytest.mark.asyncio -async def test_get_by_id(configure_client, create_company) -> None: +async def test_get_by_id(create_company) -> None: c: Company = await create_company(company_id="1234567-8") assert c.id is not None @@ -266,13 +266,13 @@ async def test_get_by_id(configure_client, create_company) -> None: @pytest.mark.asyncio -async def test_get_by_empty_str_id(configure_client) -> None: +async def test_get_by_empty_str_id() -> None: with pytest.raises(ModelNotFoundError): await Company.get_by_id("") @pytest.mark.asyncio -async def test_missing_collection(configure_client) -> None: +async def test_missing_collection() -> None: class User(AsyncModel): name: str @@ -281,7 +281,7 @@ class User(AsyncModel): @pytest.mark.asyncio -async def test_model_aliases(configure_client) -> None: +async def test_model_aliases() -> None: class User(AsyncModel): __collection__ = "User" @@ -326,7 +326,7 @@ class User(AsyncModel): "!:&+-*'()", ], ) -async def test_models_with_valid_custom_id(configure_client, model_id) -> None: +async def test_models_with_valid_custom_id(model_id) -> None: product_id = str(uuid4()) product = Product(product_id=product_id, price=123.45, stock=2) @@ -355,7 +355,7 @@ async def test_models_with_valid_custom_id(configure_client, model_id) -> None: "foo/bar/baz", ], ) -async def test_models_with_invalid_custom_id(configure_client, model_id: str) -> None: +async def test_models_with_invalid_custom_id(model_id: str) -> None: product = Product(product_id="product 123", price=123.45, stock=2) product.id = model_id with pytest.raises(InvalidDocumentID): @@ -366,7 +366,7 @@ async def test_models_with_invalid_custom_id(configure_client, model_id: str) -> @pytest.mark.asyncio -async def test_truncate_collection(configure_client, create_company) -> None: +async def test_truncate_collection(create_company) -> None: await create_company(company_id="1234567-8") await create_company(company_id="1234567-9") @@ -379,7 +379,7 @@ async def test_truncate_collection(configure_client, create_company) -> None: @pytest.mark.asyncio -async def test_custom_id_model(configure_client) -> None: +async def test_custom_id_model() -> None: c = CustomIDModel(bar="bar") # type: ignore await c.save() @@ -392,7 +392,7 @@ async def test_custom_id_model(configure_client) -> None: @pytest.mark.asyncio -async def test_custom_id_conflict(configure_client) -> None: +async def test_custom_id_conflict() -> None: await CustomIDConflictModel(foo="foo", bar="bar").save() models = await CustomIDModel.find({}) @@ -404,7 +404,7 @@ async def test_custom_id_conflict(configure_client) -> None: @pytest.mark.asyncio -async def test_model_id_persistency(configure_client) -> None: +async def test_model_id_persistency() -> None: c = CustomIDConflictModel(foo="foo", bar="bar") await c.save() assert c.id @@ -416,7 +416,7 @@ async def test_model_id_persistency(configure_client) -> None: @pytest.mark.asyncio -async def test_bare_model_document_id_persistency(configure_client) -> None: +async def test_bare_model_document_id_persistency() -> None: c = CustomIDModel(bar="bar") # type: ignore await c.save() assert c.foo @@ -428,20 +428,20 @@ async def test_bare_model_document_id_persistency(configure_client) -> None: @pytest.mark.asyncio -async def test_bare_model_get_by_empty_doc_id(configure_client) -> None: +async def test_bare_model_get_by_empty_doc_id() -> None: with pytest.raises(ModelNotFoundError): await CustomIDModel.get_by_doc_id("") @pytest.mark.asyncio -async def test_extra_fields(configure_client) -> None: +async def test_extra_fields() -> None: await CustomIDModelExtra(foo="foo", bar="bar", baz="baz").save() # type: ignore with pytest.raises(ValidationError): await CustomIDModel.find({}) @pytest.mark.asyncio -async def test_company_stats(configure_client, create_company) -> None: +async def test_company_stats( create_company) -> None: company: Company = await create_company(company_id="1234567-8") company_stats = company.stats() @@ -462,7 +462,7 @@ async def test_company_stats(configure_client, create_company) -> None: @pytest.mark.asyncio -async def test_subcollection_model_safety(configure_client) -> None: +async def test_subcollection_model_safety() -> None: """ Ensure you shouldn't be able to use unprepared subcollection models accidentally """ @@ -471,7 +471,7 @@ async def test_subcollection_model_safety(configure_client) -> None: @pytest.mark.asyncio -async def test_get_user_purchases(configure_client) -> None: +async def test_get_user_purchases() -> None: u = User(name="Foo") await u.save() assert u.id @@ -483,7 +483,7 @@ async def test_get_user_purchases(configure_client) -> None: @pytest.mark.asyncio -async def test_reload(configure_client) -> None: +async def test_reload() -> None: u = User(name="Foo") await u.save() @@ -502,7 +502,7 @@ async def test_reload(configure_client) -> None: @pytest.mark.asyncio -async def test_save_with_exclude_none(configure_client) -> None: +async def test_save_with_exclude_none() -> None: p = Profile(name="Foo") await p.save(exclude_none=True) @@ -510,7 +510,7 @@ async def test_save_with_exclude_none(configure_client) -> None: assert document_id # pylint: disable=protected-access - document = await Profile._get_col_ref().document(document_id).get() + document = await Profile._get_col_ref().document(document_id).get() ### here data = document.to_dict() assert data == {"name": "Foo"} @@ -524,7 +524,7 @@ async def test_save_with_exclude_none(configure_client) -> None: @pytest.mark.asyncio -async def test_save_with_exclude_unset(configure_client) -> None: +async def test_save_with_exclude_unset() -> None: p = Profile(photo_url=None) await p.save(exclude_unset=True) @@ -546,11 +546,9 @@ async def test_save_with_exclude_unset(configure_client) -> None: @pytest.mark.asyncio -async def test_update_city_in_transaction(configure_client) -> None: +async def test_update_city_in_transaction() -> None: """ Test updating a model in a transaction. Test case from README. - - :param: configure_client: pytest fixture """ @async_transactional @@ -572,11 +570,9 @@ async def decrement_population( @pytest.mark.asyncio -async def test_delete_in_transaction(configure_client) -> None: +async def test_delete_in_transaction() -> None: """ Test deleting a model in a transaction. - - :param: configure_client: pytest fixture """ @async_transactional @@ -599,11 +595,9 @@ async def delete_in_transaction( @pytest.mark.asyncio -async def test_update_model_in_transaction(configure_client) -> None: +async def test_update_model_in_transaction() -> None: """ Test updating a model in a transaction. - - :param: configure_client: pytest fixture """ @async_transactional @@ -626,11 +620,9 @@ async def update_in_transaction( @pytest.mark.asyncio -async def test_update_submodel_in_transaction(configure_client) -> None: +async def test_update_submodel_in_transaction() -> None: """ Test Updating a submodel in a transaction. - - :param: configure_client: pytest fixture """ @async_transactional diff --git a/firedantic/tests/tests_sync/conftest.py b/firedantic/tests/tests_sync/conftest.py index 660ddf8..b30307f 100644 --- a/firedantic/tests/tests_sync/conftest.py +++ b/firedantic/tests/tests_sync/conftest.py @@ -191,14 +191,14 @@ def configure_multiple_clients(): mock_creds = Mock(spec=google.auth.credentials.Credentials) # name = (default) - config.create( + config.add( prefix="ioxio-local-dev-", project=str(uuid.uuid4()) + "-", credentials=mock_creds ) # name = multi - config.create( + config.add( name="multi", prefix="test-multi-", project="test-multi", diff --git a/firedantic/tests/tests_sync/test_indexes.py b/firedantic/tests/tests_sync/test_indexes.py index 2673c4c..7be99d9 100644 --- a/firedantic/tests/tests_sync/test_indexes.py +++ b/firedantic/tests/tests_sync/test_indexes.py @@ -14,6 +14,7 @@ ) from firedantic.common import IndexField from firedantic.tests.tests_sync.conftest import MockListIndexOperation +from firedantic.configurations import configuration import pytest # noqa isort: skip @@ -28,6 +29,7 @@ class BaseModelWithIndexes(Model): def test_set_up_composite_index(mock_admin_client) -> None: + configuration.add(name="(default)", prefix="test_", project="proj", client=mock_admin_client) class ModelWithIndexes(BaseModelWithIndexes): __composite_indexes__ = ( collection_index( @@ -46,9 +48,11 @@ class ModelWithIndexes(BaseModelWithIndexes): call_list = mock_admin_client.create_index.call_args_list # index is a protobuf structure sent via Google Cloud Admin path = call_list[0][1]["request"].parent + expected_prefix = configuration.get_config("(default)").prefix + assert ( path - == f"projects/proj/databases/(default)/collectionGroups/{CONFIGURATIONS["(default)"].prefix}modelWithIndexes" + == f"projects/proj/databases/(default)/collectionGroups/{expected_prefix}modelWithIndexes" ) index = call_list[0][1]["request"].index assert index.query_scope.name == "COLLECTION" @@ -61,6 +65,7 @@ class ModelWithIndexes(BaseModelWithIndexes): def test_set_up_collection_group_index(mock_admin_client) -> None: + configuration.add(name="(default)", prefix="test_", project="proj", client=mock_admin_client) class ModelWithIndexes(BaseModelWithIndexes): __composite_indexes__ = ( collection_group_index( @@ -79,9 +84,10 @@ class ModelWithIndexes(BaseModelWithIndexes): call_list = mock_admin_client.create_index.call_args_list # index is a protobuf structure sent via Google Cloud Admin path = call_list[0][1]["request"].parent + expected_prefix = configuration.get_config("(default)").prefix assert ( path - == f"projects/proj/databases/(default)/collectionGroups/{CONFIGURATIONS["(default)"].prefix}modelWithIndexes" + == f"projects/proj/databases/(default)/collectionGroups/{expected_prefix}modelWithIndexes" ) index = call_list[0][1]["request"].index assert index.query_scope.name == "COLLECTION_GROUP" @@ -90,6 +96,7 @@ class ModelWithIndexes(BaseModelWithIndexes): def test_set_up_composite_indexes_and_policies(mock_admin_client) -> None: + configuration.add(name="(default)", prefix="test_", project="proj", client=mock_admin_client) class ModelWithIndexes(BaseModelWithIndexes): __composite_indexes__ = ( collection_index( @@ -114,6 +121,7 @@ class ModelWithIndexes(BaseModelWithIndexes): def test_set_up_many_composite_indexes(mock_admin_client) -> None: + configuration.add(name="(default)", prefix="test_", project="proj", client=mock_admin_client) class ModelWithIndexes(BaseModelWithIndexes): __composite_indexes__ = ( collection_index( @@ -141,6 +149,7 @@ class ModelWithIndexes(BaseModelWithIndexes): def test_set_up_indexes_model_without_indexes(mock_admin_client) -> None: + configuration.add(name="(default)", prefix="test_", project="proj", client=mock_admin_client) class ModelWithoutIndexes(Model): __collection__ = "modelWithoutIndexes" @@ -159,13 +168,15 @@ class ModelWithoutIndexes(Model): def test_existing_indexes_are_skipped(mock_admin_client) -> None: + configuration.add(name="(default)", prefix="test_", project="proj", client=mock_admin_client) + expected_prefix = configuration.get_config("(default)").prefix resp = ListIndexesResponse( { "indexes": [ { "name": ( "projects/fake-project/databases/(default)/collectionGroups/" - f"{CONFIGURATIONS["(default)"].prefix}modelWithIndexes/123456" + f"{expected_prefix}modelWithIndexes/123456" ), "query_scope": "COLLECTION", "fields": [ @@ -177,7 +188,7 @@ def test_existing_indexes_are_skipped(mock_admin_client) -> None: { "name": ( "projects/fake-project/databases/(default)/collectionGroups/" - f"{CONFIGURATIONS["(default)"].prefix}modelWithIndexes/67889" + f"{expected_prefix}modelWithIndexes/67889" ), "query_scope": "COLLECTION", "fields": [ @@ -217,13 +228,15 @@ class ModelWithIndexes(BaseModelWithIndexes): def test_same_fields_in_another_collection(mock_admin_client) -> None: # Test that when another collection has an index with exactly the same fields, # it won't affect creating an index in the target collection + configuration.add(name="(default)", prefix="test_", project="proj", client=mock_admin_client) + expected_prefix = configuration.get_config("(default)").prefix resp = ListIndexesResponse( { "indexes": [ { "name": ( "projects/fake-project/databases/(default)/collectionGroups/" - f"{CONFIGURATIONS["(default)"].prefix}anotherModel/123456" + f"{expected_prefix}anotherModel/123456" ), "query_scope": "COLLECTION", "fields": [ diff --git a/firedantic/tests/tests_sync/test_model.py b/firedantic/tests/tests_sync/test_model.py index 496845c..4bbec22 100644 --- a/firedantic/tests/tests_sync/test_model.py +++ b/firedantic/tests/tests_sync/test_model.py @@ -38,6 +38,7 @@ def test_save_model_old_way(create_company) -> None: + company = create_company() assert company.id is not None @@ -569,7 +570,7 @@ def test_company_stats(configure_client, create_company) -> None: -def test_subcollection_model_safety(configure_client) -> None: +def test_subcollection_model_safety() -> None: """ Ensure you shouldn't be able to use unprepared subcollection models accidentally """ @@ -578,7 +579,7 @@ def test_subcollection_model_safety(configure_client) -> None: -def test_get_user_purchases(configure_client) -> None: +def test_get_user_purchases() -> None: u = User(name="Foo") u.save() assert u.id @@ -590,7 +591,7 @@ def test_get_user_purchases(configure_client) -> None: -def test_reload(configure_client) -> None: +def test_reload() -> None: u = User(name="Foo") u.save() @@ -607,7 +608,7 @@ def test_reload(configure_client) -> None: with pytest.raises(ModelNotFoundError): another_user.reload() -def test_reload_multiple(configure_multiple_clients) -> None: +def test_reload_multiple() -> None: u1 = User(name="Coco") u1.save() @@ -635,7 +636,7 @@ def test_reload_multiple(configure_multiple_clients) -> None: -def test_save_with_exclude_none(configure_client) -> None: +def test_save_with_exclude_none() -> None: p = Profile(name="Foo") p.save(exclude_none=True) @@ -657,7 +658,7 @@ def test_save_with_exclude_none(configure_client) -> None: -def test_save_with_exclude_unset(configure_client) -> None: +def test_save_with_exclude_unset() -> None: p = Profile(photo_url=None) p.save(exclude_unset=True) diff --git a/start_emulator.sh b/start_emulator.sh index c559ecc..0913753 100755 --- a/start_emulator.sh +++ b/start_emulator.sh @@ -1,2 +1,4 @@ #!/usr/bin/env sh +export GOOGLE_CLOUD_PROJECT=test-project + exec firebase -P firedantic-test emulators:start --only firestore From 5e1d656a330e880c1a9b43bcd478743cc2cbb412 Mon Sep 17 00:00:00 2001 From: Marissa Fisher Date: Sun, 23 Nov 2025 22:27:50 -0300 Subject: [PATCH 13/38] sync indexes test passing --- firedantic/_sync/indexes.py | 7 +- firedantic/_sync/model.py | 60 ++++- firedantic/tests/tests_async/conftest.py | 12 - firedantic/tests/tests_sync/test_model.py | 287 +++------------------- 4 files changed, 87 insertions(+), 279 deletions(-) diff --git a/firedantic/_sync/indexes.py b/firedantic/_sync/indexes.py index d5ae4d3..f39c61b 100644 --- a/firedantic/_sync/indexes.py +++ b/firedantic/_sync/indexes.py @@ -82,7 +82,7 @@ def create_composite_index( def set_up_composite_indexes( - gcloud_project: Optional[str], + gcloud_project: str, models: Iterable[Type[BareModel]], database: str = "(default)", client: Optional[FirestoreAdminClient] = None, @@ -101,10 +101,9 @@ def set_up_composite_indexes( operations = [] for model in models: - if not getattr(model, "__composite_indexes__", None): continue - + # Resolve config name: prefer model __db_config__ if present; else default config_name = getattr(model, "__db_config__", "(default)") @@ -149,4 +148,4 @@ def set_up_composite_indexes_and_ttl_policies( models = list(models) ops = set_up_composite_indexes(gcloud_project, models, database, client) ops.extend(set_up_ttl_policies(gcloud_project, models, database, client)) - return ops + return ops \ No newline at end of file diff --git a/firedantic/_sync/model.py b/firedantic/_sync/model.py index 7d7680f..7cad20d 100644 --- a/firedantic/_sync/model.py +++ b/firedantic/_sync/model.py @@ -43,20 +43,55 @@ } -def get_collection_name(model_class, name: Optional[str]) -> str: +def get_collection_name(cls, collection_name: Optional[str]) -> str: """ - Returns the collection name. + Return the collection name for `cls`. - :raise CollectionNotDefined: If the collection name is not defined. + - If `collection_name` is provided, treat it as an explicit collection name + and prefix it using the configured prefix for the class (via __db_config__). + - Otherwise, use the class name and the configured prefix. + + :raises CollectionNotDefined: If neither a class collection nor derived name is available. + """ + # Resolve the config name from the model class (default to "(default)") + config_name = getattr(cls, "__db_config__", "(default)") + + # If caller provided an explicit collection string, apply prefix from config + if collection_name: + cfg = configuration.get_config(config_name) + prefix = cfg.prefix or "" + return f"{prefix}{collection_name}" + + if getattr(cls, "__collection__", None): + cfg = configuration.get_config(config_name) + return f"{cfg.prefix or ''}{cls.__collection__}" + + raise CollectionNotDefined(f"Missing collection name for {cls.__name__}") + + +def _get_col_ref(cls, collection_name: Optional[str]) -> CollectionReference: + """ + Return a CollectionReference for the model class using the configured client. + + :param cls: model class + :param collection_name: optional explicit collection name override + :raises CollectionNotDefined: when collection cannot be resolved """ - if not name: - raise CollectionNotDefined(f"Missing collection name for {model_class.__name__}") - return f"{configuration.get_config(name)}" + # Build the prefixed collection name + col_name = get_collection_name(cls, collection_name) + + # Resolve config name from class and fetch client + config_name = getattr(cls, "__db_config__", "(default)") + client = configuration.get_client(config_name) + + # Return the AsyncCollectionReference (this is a real client call) + col_ref = client.collection(col_name) + # Ensure we got the right object back + if not hasattr(col_ref, "document"): + raise RuntimeError(f"_get_col_ref returned unexpected object for {cls}: {type(col_ref)!r}") + return col_ref -def _get_col_ref(model_class, name: Optional[str]) -> CollectionReference: - collection: CollectionReference = get_collection_name(model_class, name) - return collection class BareModel(pydantic.BaseModel, ABC): @@ -171,7 +206,7 @@ def reload(self, transaction: Optional[Transaction] = None) -> None: self.__dict__.update(updated_model.__dict__) - def get_document_id(self): + def get_document_id(self) -> Optional[str]: """ Returns the document ID for this model instance. @@ -216,8 +251,7 @@ def find( # pylint: disable=too-many-arguments query = cls._add_filter(query, key, value) if order_by is not None: - for order_by_item in order_by: - field, direction = order_by_item + for field, direction in order_by: query = query.order_by(field, direction=direction) # type: ignore if limit is not None: query = query.limit(limit) # type: ignore @@ -489,4 +523,4 @@ def get_by_id( class SubCollection(BareSubCollection, ABC): __document_id__ = "id" - __model_cls__: Type[SubModel] + __model_cls__: Type[SubModel] \ No newline at end of file diff --git a/firedantic/tests/tests_async/conftest.py b/firedantic/tests/tests_async/conftest.py index cd99802..91d37ea 100644 --- a/firedantic/tests/tests_async/conftest.py +++ b/firedantic/tests/tests_async/conftest.py @@ -1,5 +1,4 @@ import uuid -import os from datetime import datetime from typing import Any, List, Optional, Type @@ -169,17 +168,6 @@ class ExpiringModel(AsyncModel): expire: datetime content: str -# @pytest.fixture -# def use_emulator(monkeypatch): -# monkeypatch.setenv("FIRESTORE_EMULATOR_HOST", "localhost:8080") -# monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "test-project") - -# # point to your running emulator host:port -# os.environ["FIRESTORE_EMULATOR_HOST"] = "localhost:8080" -# # any non-empty project id is fine for emulator -# os.environ["GOOGLE_CLOUD_PROJECT"] = "test-project" -# yield - @pytest.fixture(autouse=True) def configure_client(): client = AsyncClient( diff --git a/firedantic/tests/tests_sync/test_model.py b/firedantic/tests/tests_sync/test_model.py index 4bbec22..0272250 100644 --- a/firedantic/tests/tests_sync/test_model.py +++ b/firedantic/tests/tests_sync/test_model.py @@ -36,16 +36,7 @@ ] - -def test_save_model_old_way(create_company) -> None: - - company = create_company() - - assert company.id is not None - assert company.owner.first_name == "John" - assert company.owner.last_name == "Doe" - -def test_save_model_new_way(create_company) -> None: +def test_save_model(create_company) -> None: company = create_company() assert company.id is not None @@ -53,29 +44,7 @@ def test_save_model_new_way(create_company) -> None: assert company.owner.last_name == "Doe" -def test_save_multi_client_model(configure_multiple_clients): - owner = Owner(first_name="Alice", last_name="Begone") - company = Company(company_id="1234567-9", owner=owner) - company.save() # will use 'default' as config name - company.reload() # with no name supplied, config refers to "(default)" - - config="multi" - product = Product(product_id=str(uuid4()), price=123.45, stock=2) - product.save(config) - product = Product(product_id=str(uuid4()), price=89.95, stock=8) - product.save(config) - product.reload(config) - - # Can retrieve info from either client: (default) or multi. - # When config= is not set, it will default to '(default)' config. - # The models do not know which config you intended to use them for, and they - # could be used for a multitude of configurations at once. - assert Company.find_one({"owner.first_name": "Alice"}).company_id == "1234567-9" - assert Product.find_one({"price": 123.45}, config=config).stock == 2 - assert Product.find_one({"stock": 8}, config=config).stock == 8 - - -def test_delete_model(configure_client, create_company) -> None: +def test_delete_model(create_company) -> None: company: Company = create_company( company_id="11223344-5", first_name="Jane", last_name="Doe" ) @@ -89,35 +58,7 @@ def test_delete_model(configure_client, create_company) -> None: Company.get_by_id(_id) -def test_delete_multi_client_model(configure_multiple_clients, create_company) -> None: - # will use 'default' as config name - owner = Owner(first_name="Alice", last_name="Begone") - company = Company(company_id="12345678", owner=owner) - company.save() - - config="multi" - product = Product(product_id="1112", price=123.45, stock=2) - product.save(config) - product = Product(product_id="1114", price=89.95, stock=8) - product.save(config) - - _id = company.id - assert _id - - _pid = product.id - assert _pid - - company.delete() - product.delete(config) - - with pytest.raises(ModelNotFoundError): - Company.get_by_id(_id) - - with pytest.raises(ModelNotFoundError): - Product.get_by_id(_pid) - - -def test_find_one(configure_client, create_company) -> None: +def test_find_one(create_company) -> None: with pytest.raises(ModelNotFoundError): Company.find_one() @@ -145,7 +86,7 @@ def test_find_one(configure_client, create_company) -> None: assert first_desc.owner.first_name == "Foo" -def test_find(configure_client, create_company, create_product) -> None: +def test_find(create_company, create_product) -> None: ids = ["1234555-1", "1234567-8", "2131232-4", "4124432-4"] for company_id in ids: create_company(company_id=company_id) @@ -154,12 +95,8 @@ def test_find(configure_client, create_company, create_product) -> None: assert c[0].company_id == "4124432-4" assert c[0].owner.first_name == "John" - d1 = Company.find({"owner.first_name": "John"}) - - d2 = Company.find({"owner.first_name": {op.EQ: "John"}}) - - d3 = Company.find({"owner.first_name": {"==": "John"}}) - assert d1 == d2 == d3 + d = Company.find({"owner.first_name": "John"}) + assert len(d) == 4 for p in TEST_PRODUCTS: create_product(**p) @@ -179,7 +116,7 @@ def test_find(configure_client, create_company, create_product) -> None: Product.find({"product_id": {"<>": "a"}}) -def test_find_not_in(configure_client, create_company) -> None: +def test_find_not_in(create_company) -> None: ids = ["1234555-1", "1234567-8", "2131232-4", "4124432-4"] for company_id in ids: create_company(company_id=company_id) @@ -199,7 +136,7 @@ def test_find_not_in(configure_client, create_company) -> None: assert company.company_id in ("2131232-4", "4124432-4") -def test_find_array_contains(configure_client, create_todolist) -> None: +def test_find_array_contains(create_todolist) -> None: list_1 = create_todolist("list_1", ["Work", "Eat", "Sleep"]) create_todolist("list_2", ["Learn Python", "Walk the dog"]) @@ -208,7 +145,7 @@ def test_find_array_contains(configure_client, create_todolist) -> None: assert found[0].name == list_1.name -def test_find_array_contains_any(configure_client, create_todolist) -> None: +def test_find_array_contains_any(create_todolist) -> None: list_1 = create_todolist("list_1", ["Work", "Eat"]) list_2 = create_todolist("list_2", ["Relax", "Chill", "Sleep"]) create_todolist("list_3", ["Learn Python", "Walk the dog"]) @@ -219,7 +156,7 @@ def test_find_array_contains_any(configure_client, create_todolist) -> None: assert lst.name in (list_1.name, list_2.name) -def test_find_limit(configure_client, create_company) -> None: +def test_find_limit(create_company) -> None: ids = ["1234555-1", "1234567-8", "2131232-4", "4124432-4"] for company_id in ids: create_company(company_id=company_id) @@ -231,7 +168,7 @@ def test_find_limit(configure_client, create_company) -> None: assert len(companies_2) == 2 -def test_find_order_by(configure_client, create_company) -> None: +def test_find_order_by(create_company) -> None: companies_and_owners = [ {"company_id": "1234555-1", "last_name": "A", "first_name": "A"}, {"company_id": "1234555-2", "last_name": "A", "first_name": "B"}, @@ -275,8 +212,7 @@ def test_find_order_by(configure_client, create_company) -> None: assert companies_and_owners == lastname_ascending_firstname_ascending - -def test_find_offset(configure_client, create_company) -> None: +def test_find_offset(create_company) -> None: ids_and_lastnames = ( ("1234555-1", "A"), ("1234567-8", "B"), @@ -293,7 +229,7 @@ def test_find_offset(configure_client, create_company) -> None: assert len(companies_ascending) == 2 -def test_get_by_id(configure_client, create_company) -> None: +def test_get_by_id(create_company) -> None: c: Company = create_company(company_id="1234567-8") assert c.id is not None @@ -307,55 +243,12 @@ def test_get_by_id(configure_client, create_company) -> None: assert c_2.owner.first_name == "John" -def test_get_by_id_multi_client(configure_multiple_clients) -> None: - owner = Owner(first_name="Alice", last_name="Begone") - c1 = Company(company_id="1234567-8", owner=owner) - c1.save() - - config="multi" - owner = Owner(first_name="Crazy", last_name="Kitten") - c2 = Company(company_id="1234567-9", owner=owner) - c2.save(config) - - assert c1.id is not None - assert c1.company_id == "1234567-8" - assert c1.owner.last_name == "Begone" - - cid1 = Company.get_by_id(c1.id) - - assert cid1.id == c1.id - assert cid1.company_id == "1234567-8" - assert cid1.owner.first_name == "Alice" - - assert c2.id is not None - assert c2.company_id == "1234567-9" - assert c2.owner.last_name == "Kitten" - - cid2 = Company.get_by_id(c2.id, config) - - assert cid2.id == c2.id - assert cid2.company_id == "1234567-9" - assert cid2.owner.first_name == "Crazy" - - -def test_get_by_id_missing_config(configure_multiple_clients): - config="multi" - owner = Owner(first_name="Crazy", last_name="Kitten") - c2 = Company(company_id="1234567-9", owner=owner) - c2.save(config) - - with pytest.raises(ModelNotFoundError): - cid2 = Company.get_by_id(c2.id) - - - -def test_get_by_empty_str_id(configure_client) -> None: +def test_get_by_empty_str_id() -> None: with pytest.raises(ModelNotFoundError): Company.get_by_id("") - -def test_missing_collection(configure_client) -> None: +def test_missing_collection() -> None: class User(Model): name: str @@ -363,8 +256,7 @@ class User(Model): User(name="John").save() - -def test_model_aliases(configure_client) -> None: +def test_model_aliases() -> None: class User(Model): __collection__ = "User" @@ -380,7 +272,6 @@ class User(Model): assert user_from_db.city == "Helsinki" - @pytest.mark.parametrize( "model_id", [ @@ -409,7 +300,7 @@ class User(Model): "!:&+-*'()", ], ) -def test_models_with_valid_custom_id(configure_client, model_id) -> None: +def test_models_with_valid_custom_id( model_id) -> None: product_id = str(uuid4()) product = Product(product_id=product_id, price=123.45, stock=2) @@ -422,7 +313,6 @@ def test_models_with_valid_custom_id(configure_client, model_id) -> None: found.delete() - @pytest.mark.parametrize( "model_id", [ @@ -438,7 +328,7 @@ def test_models_with_valid_custom_id(configure_client, model_id) -> None: "foo/bar/baz", ], ) -def test_models_with_invalid_custom_id(configure_client, model_id: str) -> None: +def test_models_with_invalid_custom_id(model_id: str) -> None: product = Product(product_id="product 123", price=123.45, stock=2) product.id = model_id with pytest.raises(InvalidDocumentID): @@ -448,8 +338,7 @@ def test_models_with_invalid_custom_id(configure_client, model_id: str) -> None: Product.get_by_id(model_id) - -def test_truncate_collection(configure_client, create_company) -> None: +def test_truncate_collection(create_company) -> None: create_company(company_id="1234567-8") create_company(company_id="1234567-9") @@ -461,33 +350,7 @@ def test_truncate_collection(configure_client, create_company) -> None: assert len(new_companies) == 0 -def test_truncate_multiple_collections(configure_multiple_clients) -> None: - owner = Owner(first_name="Alice", last_name="Begone") - c1 = Company(company_id="1234567-8", owner=owner) - c1.save() - c2 = Company(company_id="1234567-9", owner=owner) - c2.save() - - config="multi" - owner = Owner(first_name="Crazy", last_name="Kitten") - c3 = Company(company_id="1234567-0", owner=owner) - c3.save(config) - - companies = Company.find({}) - assert len(companies) >= 2 - - companies = Company.find({}, config=config) - assert len(companies) >= 1 - - Company.truncate_collection() - assert len(Company.find({})) == 0 - - Company.truncate_collection(config) - assert len(Company.find({}, config=config)) == 0 - - - -def test_custom_id_model(configure_client) -> None: +def test_custom_id_model() -> None: c = CustomIDModel(bar="bar") # type: ignore c.save() @@ -499,8 +362,7 @@ def test_custom_id_model(configure_client) -> None: assert m.bar == "bar" - -def test_custom_id_conflict(configure_client) -> None: +def test_custom_id_conflict() -> None: CustomIDConflictModel(foo="foo", bar="bar").save() models = CustomIDModel.find({}) @@ -511,8 +373,7 @@ def test_custom_id_conflict(configure_client) -> None: assert m.bar == "bar" - -def test_model_id_persistency(configure_client) -> None: +def test_model_id_persistency() -> None: c = CustomIDConflictModel(foo="foo", bar="bar") c.save() assert c.id @@ -523,8 +384,7 @@ def test_model_id_persistency(configure_client) -> None: assert len(CustomIDConflictModel.find({})) == 1 - -def test_bare_model_document_id_persistency(configure_client) -> None: +def test_bare_model_document_id_persistency() -> None: c = CustomIDModel(bar="bar") # type: ignore c.save() assert c.foo @@ -535,21 +395,18 @@ def test_bare_model_document_id_persistency(configure_client) -> None: assert len(CustomIDModel.find({})) == 1 - -def test_bare_model_get_by_empty_doc_id(configure_client) -> None: +def test_bare_model_get_by_empty_doc_id(configure_db) -> None: with pytest.raises(ModelNotFoundError): CustomIDModel.get_by_doc_id("") - -def test_extra_fields(configure_client) -> None: +def test_extra_fields() -> None: CustomIDModelExtra(foo="foo", bar="bar", baz="baz").save() # type: ignore with pytest.raises(ValidationError): CustomIDModel.find({}) - -def test_company_stats(configure_client, create_company) -> None: +def test_company_stats(create_company) -> None: company: Company = create_company(company_id="1234567-8") company_stats = company.stats() @@ -569,7 +426,6 @@ def test_company_stats(configure_client, create_company) -> None: assert stats.sales == 101 - def test_subcollection_model_safety() -> None: """ Ensure you shouldn't be able to use unprepared subcollection models accidentally @@ -578,7 +434,6 @@ def test_subcollection_model_safety() -> None: UserStats.find({}) - def test_get_user_purchases() -> None: u = User(name="Foo") u.save() @@ -590,7 +445,6 @@ def test_get_user_purchases() -> None: assert get_user_purchases(u.id) == 42 - def test_reload() -> None: u = User(name="Foo") u.save() @@ -608,33 +462,6 @@ def test_reload() -> None: with pytest.raises(ModelNotFoundError): another_user.reload() -def test_reload_multiple() -> None: - u1 = User(name="Coco") - u1.save() - - u2 = User(name="Foo") - u2.save("multi") - - # change the value in the 'multi' database - u2_ = User.find_one({"name": "Foo"}, config="multi") - u2_.name = "Bar" - u2_.save("multi") - - # assert that u2 was updated - assert u2.name == "Foo" - u2.reload("multi") - - assert u2.name == "Bar" - - # assert that u1 did not change even after reload - u1.reload() - assert u1.name == "Coco" - - another_user = User(name="Another") - with pytest.raises(ModelNotFoundError): - another_user.reload() - - def test_save_with_exclude_none() -> None: p = Profile(name="Foo") @@ -657,7 +484,6 @@ def test_save_with_exclude_none() -> None: assert data == {"name": "Foo", "photo_url": None} - def test_save_with_exclude_unset() -> None: p = Profile(photo_url=None) p.save(exclude_unset=True) @@ -679,18 +505,15 @@ def test_save_with_exclude_unset() -> None: assert data == {"name": "", "photo_url": None} - -def test_update_city_in_transaction(configure_client) -> None: +def test_update_city_in_transaction() -> None: """ Test updating a model in a transaction. Test case from README. - :param: configure_client: pytest fixture + :param: configure_db: pytest fixture """ @transactional - def decrement_population( - transaction: Transaction, city: City, decrement: int = 1 - ): + def decrement_population(transaction: Transaction, city: City, decrement: int = 1): city.reload(transaction=transaction) city.population = max(0, city.population - decrement) city.save(transaction=transaction) @@ -705,49 +528,15 @@ def decrement_population( assert c.population == 0 -def test_update_city_in_multiple_transactions(configure_multiple_clients) -> None: - """ - Test updating a model in 2 separate transactions. Based on test case from README. - """ - - @transactional - def decrement_population( - transaction: Transaction, city: City, decrement: int = 1, config: str = "(default)" - ): - city.reload(config, transaction=transaction) - city.population = max(0, city.population - decrement) - city.save(config, transaction=transaction) - - c = City(id="SF", population=1) - c.save() - c.increment_population(increment=1) - assert c.population == 2 - - t = get_transaction() - decrement_population(transaction=t, city=c, decrement=5) - assert c.population == 0 - - c = City(id="SF", population=8) - c.save("multi") - c.increment_population(increment=1, config="multi") - assert c.population == 9 - - t = get_transaction("multi") - decrement_population(transaction=t, city=c, decrement=5, config="multi") - assert c.population == 4 - - -def test_delete_in_transaction(configure_client) -> None: +def test_delete_in_transaction() -> None: """ Test deleting a model in a transaction. - :param: configure_client: pytest fixture + :param: configure_db: pytest fixture """ @transactional - def delete_in_transaction( - transaction: Transaction, profile_id: str - ) -> None: + def delete_in_transaction(transaction: Transaction, profile_id: str) -> None: """Deletes a Profile in a transaction.""" profile = Profile.get_by_id(profile_id, transaction=transaction) profile.delete(transaction=transaction) @@ -763,12 +552,11 @@ def delete_in_transaction( Profile.get_by_id(p.id) - -def test_update_model_in_transaction(configure_client) -> None: +def test_update_model_in_transaction() -> None: """ Test updating a model in a transaction. - :param: configure_client: pytest fixture + :param: configure_db: pytest fixture """ @transactional @@ -790,12 +578,11 @@ def update_in_transaction( assert p.name == "Bar" - -def test_update_submodel_in_transaction(configure_client) -> None: +def test_update_submodel_in_transaction() -> None: """ Test Updating a submodel in a transaction. - :param: configure_client: pytest fixture + :param: configure_db: pytest fixture """ @transactional @@ -820,4 +607,4 @@ def update_submodel_in_transaction( user_stats = update_submodel_in_transaction(t, u.id, "2021") assert isinstance(user_stats, UserStats) assert user_stats.purchases == 43 - assert get_user_purchases(u.id) == 43 + assert get_user_purchases(u.id) == 43 \ No newline at end of file From 9f9a08934e9053b203cc71377be65983cf221f0e Mon Sep 17 00:00:00 2001 From: Marissa Fisher Date: Sun, 23 Nov 2025 22:32:57 -0300 Subject: [PATCH 14/38] reverting conf test for sync to as it was --- firedantic/_sync/model.py | 5 +- firedantic/tests/tests_sync/conftest.py | 66 +++++-------------------- 2 files changed, 15 insertions(+), 56 deletions(-) diff --git a/firedantic/_sync/model.py b/firedantic/_sync/model.py index 7cad20d..8957f4d 100644 --- a/firedantic/_sync/model.py +++ b/firedantic/_sync/model.py @@ -16,7 +16,7 @@ import firedantic.operators as op from firedantic import truncate_collection from firedantic.common import IndexDefinition, OrderDirection -from firedantic.configurations import Configuration +from firedantic.configurations import configuration from firedantic.exceptions import ( CollectionNotDefined, InvalidDocumentID, @@ -26,7 +26,6 @@ TBareModel = TypeVar("TBareModel", bound="BareModel") TBareSubModel = TypeVar("TBareSubModel", bound="BareSubModel") logger = getLogger("firedantic") -configuration = Configuration() # https://firebase.google.com/docs/firestore/query-data/queries#query_operators FIND_TYPES = { @@ -45,7 +44,7 @@ def get_collection_name(cls, collection_name: Optional[str]) -> str: """ - Return the collection name for `cls`. + Returns the collection name for `cls`. - If `collection_name` is provided, treat it as an explicit collection name and prefix it using the configured prefix for the class (via __db_config__). diff --git a/firedantic/tests/tests_sync/conftest.py b/firedantic/tests/tests_sync/conftest.py index b30307f..fc59651 100644 --- a/firedantic/tests/tests_sync/conftest.py +++ b/firedantic/tests/tests_sync/conftest.py @@ -16,10 +16,10 @@ SubCollection, SubModel, ) -from firedantic.configurations import configure, Configuration, get_transaction +from firedantic.configurations import configure, get_transaction from firedantic.exceptions import ModelNotFoundError -from unittest.mock import MagicMock, Mock # noqa isort: skip +from unittest.mock import Mock, Mock # noqa isort: skip class CustomIDModel(BareModel): @@ -44,6 +44,7 @@ class CustomIDModelExtra(BareModel): class Config: extra = "forbid" + class CustomIDConflictModel(Model): __collection__ = "custom" @@ -60,21 +61,20 @@ class City(Model): __collection__ = "cities" population: int - def increment_population(self, increment: int = 1, config: str = "(default)"): + def increment_population(self, increment: int = 1): @transactional def _increment_population(transaction: Transaction) -> None: - self.reload(config=config, transaction=transaction) + self.reload(transaction=transaction) self.population += increment - self.save(config=config, transaction=transaction) + self.save(transaction=transaction) - t = get_transaction(config) + t = get_transaction() _increment_population(transaction=t) class Owner(BaseModel): """Dummy owner Pydantic model.""" - __collection__ = "owners" first_name: str last_name: str @@ -146,6 +146,7 @@ class Profile(Model): class Config: extra = "forbid" + class TodoList(Model): """Dummy todo list Firedantic model.""" @@ -157,6 +158,7 @@ class TodoList(Model): class Config: extra = "forbid" + class ExpiringModel(Model): """Dummy expiring model Firedantic model.""" @@ -167,16 +169,8 @@ class ExpiringModel(Model): content: str - -# @pytest.fixture(scope="session", autouse=True) -# def use_emulator(monkeypatch): -# monkeypatch.setenv("FIRESTORE_EMULATOR_HOST", "localhost:8080") -# monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "test-project") -# yield - - -# allow configuring test client in this way to ensure old 'configure' method remains intact -def configure_client_old(): +@pytest.fixture(autouse=True) +def configure_db(): client = Client( project="ioxio-local-dev", credentials=Mock(spec=google.auth.credentials.Credentials), @@ -185,41 +179,9 @@ def configure_client_old(): prefix = str(uuid.uuid4()) + "-" configure(client, prefix) -@pytest.fixture -def configure_multiple_clients(): - config = Configuration() - mock_creds = Mock(spec=google.auth.credentials.Credentials) - - # name = (default) - config.add( - prefix="ioxio-local-dev-", - project=str(uuid.uuid4()) + "-", - credentials=mock_creds - ) - - # name = multi - config.add( - name="multi", - prefix="test-multi-", - project="test-multi", - credentials=mock_creds, - ) - - @pytest.fixture def create_company(): - # Register the config under the name "companies" - config = Configuration() - config.add( - name="companies", - prefix="test_", - project="test-project", - ) - - # # cleanup after test(s) - # config.configurations.pop("companies", None) - def _create( company_id: str = "1234567-8", first_name: str = "John", last_name: str = "Doe" ): @@ -233,9 +195,7 @@ def _create( @pytest.fixture def create_product(): - def _create( - product_id: Optional[str] = None, price: float = 1.23, stock: int = 3 - ): + def _create(product_id: Optional[str] = None, price: float = 1.23, stock: int = 3): if not product_id: product_id = str(uuid.uuid4()) p = Product(product_id=product_id, price=price, stock=stock) @@ -345,4 +305,4 @@ def update_field(self, data) -> MockOperation: @pytest.fixture() def mock_admin_client(): - return MockFirestoreAdminClient() + return MockFirestoreAdminClient() \ No newline at end of file From 53cdb49b6b6357ff56747d12284102ef5f99a28d Mon Sep 17 00:00:00 2001 From: Marissa Fisher Date: Sun, 23 Nov 2025 22:36:10 -0300 Subject: [PATCH 15/38] cleanup --- firedantic/tests/tests_async/conftest.py | 5 +++-- firedantic/tests/tests_async/test_model.py | 2 +- firedantic/tests/tests_sync/test_ttl_policy.py | 1 - 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/firedantic/tests/tests_async/conftest.py b/firedantic/tests/tests_async/conftest.py index 91d37ea..b115a8c 100644 --- a/firedantic/tests/tests_async/conftest.py +++ b/firedantic/tests/tests_async/conftest.py @@ -168,8 +168,9 @@ class ExpiringModel(AsyncModel): expire: datetime content: str + @pytest.fixture(autouse=True) -def configure_client(): +def configure_db(): client = AsyncClient( project="ioxio-local-dev", credentials=Mock(spec=google.auth.credentials.Credentials), @@ -306,4 +307,4 @@ async def update_field(self, data) -> MockOperation: @pytest.fixture() def mock_admin_client(): - return AsyncMockFirestoreAdminClient() + return AsyncMockFirestoreAdminClient() \ No newline at end of file diff --git a/firedantic/tests/tests_async/test_model.py b/firedantic/tests/tests_async/test_model.py index 40c791c..dac4c2e 100644 --- a/firedantic/tests/tests_async/test_model.py +++ b/firedantic/tests/tests_async/test_model.py @@ -510,7 +510,7 @@ async def test_save_with_exclude_none() -> None: assert document_id # pylint: disable=protected-access - document = await Profile._get_col_ref().document(document_id).get() ### here + document = await Profile._get_col_ref().document(document_id).get() data = document.to_dict() assert data == {"name": "Foo"} diff --git a/firedantic/tests/tests_sync/test_ttl_policy.py b/firedantic/tests/tests_sync/test_ttl_policy.py index bcccf2b..8d2d549 100644 --- a/firedantic/tests/tests_sync/test_ttl_policy.py +++ b/firedantic/tests/tests_sync/test_ttl_policy.py @@ -26,7 +26,6 @@ def test_set_up_ttl_policies_new_policy(mock_admin_client): [Field.TtlConfig.State.NEEDS_REPAIR], ), ) - def test_set_up_ttl_policies_other_states(mock_admin_client, state): mock_admin_client.field_state = Field.TtlConfig.State.ACTIVE result = set_up_ttl_policies( From 922da2a9fb1eb76876cea49ffa6aa9c851823f26 Mon Sep 17 00:00:00 2001 From: Marissa Fisher Date: Sun, 23 Nov 2025 22:37:57 -0300 Subject: [PATCH 16/38] removing old examples --- multiple_clients_examples/billing_models.py | 19 ----- multiple_clients_examples/company_models.py | 18 ---- .../configure_firestore_db_client.py | 67 --------------- multiple_clients_examples/query_sample.py | 30 ------- .../save_data_to_firestore.py | 84 ------------------- 5 files changed, 218 deletions(-) delete mode 100644 multiple_clients_examples/billing_models.py delete mode 100644 multiple_clients_examples/company_models.py delete mode 100644 multiple_clients_examples/configure_firestore_db_client.py delete mode 100644 multiple_clients_examples/query_sample.py delete mode 100644 multiple_clients_examples/save_data_to_firestore.py diff --git a/multiple_clients_examples/billing_models.py b/multiple_clients_examples/billing_models.py deleted file mode 100644 index 2ab2d43..0000000 --- a/multiple_clients_examples/billing_models.py +++ /dev/null @@ -1,19 +0,0 @@ -from pydantic import BaseModel - -from firedantic import Model - - -class BillingAccount(BaseModel): - """Dummy billing account Pydantic model.""" - - name: str - billing_id: str - owner: str - - -class BillingCompany(Model): - """Dummy company Firedantic model.""" - - __collection__ = "billing_companies" - company_id: str - billing_account: BillingAccount diff --git a/multiple_clients_examples/company_models.py b/multiple_clients_examples/company_models.py deleted file mode 100644 index fe0a46d..0000000 --- a/multiple_clients_examples/company_models.py +++ /dev/null @@ -1,18 +0,0 @@ -from pydantic import BaseModel - -from firedantic import Model - - -class Owner(BaseModel): - """Dummy owner Pydantic model.""" - - first_name: str - last_name: str - - -class Company(Model): - """Dummy company Firedantic model.""" - - __collection__ = "companies" - company_id: str - owner: Owner diff --git a/multiple_clients_examples/configure_firestore_db_client.py b/multiple_clients_examples/configure_firestore_db_client.py deleted file mode 100644 index 0cf779d..0000000 --- a/multiple_clients_examples/configure_firestore_db_client.py +++ /dev/null @@ -1,67 +0,0 @@ -from os import environ -from unittest.mock import Mock - -import google.auth.credentials -from google.cloud.firestore import AsyncClient, Client - -from firedantic import Configuration, configure - - -## OLD WAY: -def configure_client(): - # Firestore emulator must be running if using locally. - if environ.get("FIRESTORE_EMULATOR_HOST"): - client = Client( - project="firedantic-test", - credentials=Mock(spec=google.auth.credentials.Credentials), - ) - else: - client = Client() - - configure(client, prefix="firedantic-test-") - print(client) - - -def configure_async_client(): - # Firestore emulator must be running if using locally. - if environ.get("FIRESTORE_EMULATOR_HOST"): - client = AsyncClient( - project="firedantic-test", - credentials=Mock(spec=google.auth.credentials.Credentials), - ) - else: - client = AsyncClient() - - configure(client, prefix="firedantic-test-") - print(client) - - -## NEW WAY: -def configure_multiple_clients(): - config = Configuration() - - # name = (default) - config.create( - prefix="firedantic-test-", - project="firedantic-test", - credentials=Mock(spec=google.auth.credentials.Credentials), - ) - - # name = billing - config.create( - name="billing", - prefix="test-billing-", - project="test-billing", - credentials=Mock(spec=google.auth.credentials.Credentials), - ) - - print(config.get_client()) ## will pull the default client - print(config.get_client("billing")) ## will pull the billing client - print(config.get_async_client("billing")) ## will pull the billing async client - - -print("\n---- Running OLD way ----") -configure_client() -configure_async_client() -print("\n---- Running NEW way ----") -configure_multiple_clients() diff --git a/multiple_clients_examples/query_sample.py b/multiple_clients_examples/query_sample.py deleted file mode 100644 index e3d441f..0000000 --- a/multiple_clients_examples/query_sample.py +++ /dev/null @@ -1,30 +0,0 @@ -from company_models import Company -from configure_firestore_db_client import configure_client, configure_multiple_clients - -import firedantic.operators as op - -# Firestore emulator must be running if using locally. -# Supported types are: <, >, array_contains, in, ==, !=, not-in, <=, array_contains_any, >= - - -print("\n---- Running OLD way ----") -configure_client() - -companies1 = Company.find({"owner.first_name": "John"}) -companies2 = Company.find({"owner.first_name": {op.EQ: "John"}}) -companies3 = Company.find({"owner.first_name": {"==": "John"}}) -print(companies1) - -assert companies1 != [] -assert companies1 == companies2 == companies3 - - -print("\n---- Running NEW way ----") -configure_multiple_clients() - -companies1 = Company.find({"owner.first_name": "Alice"}) -companies2 = Company.find({"owner.first_name": {op.EQ: "Alice"}}) -companies3 = Company.find({"owner.first_name": {"==": "Alice"}}) -print(companies1) -assert companies1 != [] -assert companies1 == companies2 == companies3 diff --git a/multiple_clients_examples/save_data_to_firestore.py b/multiple_clients_examples/save_data_to_firestore.py deleted file mode 100644 index fd48dae..0000000 --- a/multiple_clients_examples/save_data_to_firestore.py +++ /dev/null @@ -1,84 +0,0 @@ -from unittest.mock import Mock - -import google.auth.credentials -from billing_models import BillingAccount, BillingCompany -from company_models import Company, Owner -from configure_firestore_db_client import configure_client - -from firedantic import Configuration - - -## With single client -def old_way(): - # Firestore emulator must be running if using locally. - configure_client() - - # Now you can use the model to save it to Firestore - owner = Owner(first_name="John", last_name="Doe") - company = Company(company_id="1234567-8", owner=owner) - company.save() - - # Prints out the firestore ID of the Company model - print(f"\nFirestore ID: {company.id}") - - # Reloads model data from the database - company.reload() - - -## Now with multiple clients/dbs: -def new_way(): - config = Configuration() - - # 1. Create first config with config_name = "(default)" - config.create( - prefix="firedantic-test-", - project="firedantic-test", - credentials=Mock(spec=google.auth.credentials.Credentials), - ) - - # Now you can use the model to save it to Firestore - owner = Owner(first_name="Alice", last_name="Begone") - company = Company(company_id="1234567-9", owner=owner) - company.save() # will use 'default' as config name - - # Reloads model data from the database - company.reload() # with no name supplied, config refers to "(default)" - - # 2. Create the second config with config_name = "billing" - config_name = "billing" - config.create( - name=config_name, - prefix="test-billing-", - project="test-billing", - credentials=Mock(spec=google.auth.credentials.Credentials), - ) - - # Now you can use the model to save it to Firestore - account = BillingAccount(name="ABC Billing", billing_id="801048", owner="MFisher") - bc = BillingCompany(company_id="1234567-8", billing_account=account) - - bc.save(config_name) - - # Reloads model data from the database - bc.reload(config_name) # with config name supplied, config refers to "billing" - - # 3. Finding data - # Can retrieve info from either database/client - # When config is not specified, it will default to '(default)' config - # The models do not know which config you intended to use them for, and they - # could be used for a multitude of configurations at once. - print(Company.find({"owner.first_name": "Alice"})) - - print(BillingCompany.find({"company_id": "1234567-8"}, config=config_name)) - - print( - BillingCompany.find( - {"billing_account.billing_id": "801048"}, config=config_name - ) - ) - - -print("\n---- Running OLD way ----") -old_way() -print("\n---- Running NEW way ----") -new_way() From 1ee9a7988c3a84c138c34cd10cb91e5c9205c490 Mon Sep 17 00:00:00 2001 From: Marissa Fisher Date: Thu, 27 Nov 2025 19:25:47 -0300 Subject: [PATCH 17/38] sync and async example flows working --- firedantic/_async/model.py | 56 ++++- firedantic/_sync/model.py | 16 +- firedantic/configurations.py | 12 +- firedantic/tests/tests_async/conftest.py | 1 - multiple_clients_examples/billing_models.py | 24 +++ multiple_clients_examples/company_models.py | 19 ++ .../configure_firestore_db_client.py | 82 +++++++ multiple_clients_examples/full_async_flow.py | 200 ++++++++++++++++++ multiple_clients_examples/full_sync_flow.py | 174 +++++++++++++++ multiple_clients_examples/query_sample.py | 51 +++++ 10 files changed, 621 insertions(+), 14 deletions(-) create mode 100644 multiple_clients_examples/billing_models.py create mode 100644 multiple_clients_examples/company_models.py create mode 100644 multiple_clients_examples/configure_firestore_db_client.py create mode 100644 multiple_clients_examples/full_async_flow.py create mode 100644 multiple_clients_examples/full_sync_flow.py create mode 100644 multiple_clients_examples/query_sample.py diff --git a/firedantic/_async/model.py b/firedantic/_async/model.py index 6295e2e..db500f2 100644 --- a/firedantic/_async/model.py +++ b/firedantic/_async/model.py @@ -68,7 +68,7 @@ def get_collection_name(cls, collection_name: Optional[str]) -> str: -def _get_col_ref(cls, collection_name: Optional[str]) -> AsyncCollectionReference: +def _get_col_ref(cls, collection_name: Optional[str] = None) -> AsyncCollectionReference: """ Return an AsyncCollectionReference for the model class using the configured async client. @@ -110,6 +110,7 @@ class AsyncBareModel(pydantic.BaseModel, ABC): async def save( self, *, + config_name: Optional[str] = None, exclude_unset: bool = False, exclude_none: bool = False, transaction: Optional[AsyncTransaction] = None, @@ -122,17 +123,42 @@ async def save( :param transaction: Optional transaction to use. :raise DocumentIDError: If the document ID is not valid. """ + + # Resolve config to use (explicit -> instance -> class -> default) + if config_name is not None: + resolved = config_name + else: + resolved = getattr(self, "__db_config__", None) + if not resolved: + resolved = getattr(self.__class__, "__db_config__", "(default)") + + # Build payload - model_dump? data = self.model_dump( - by_alias=True, exclude_unset=exclude_unset, exclude_none=exclude_none + by_alias=True, + exclude_unset=exclude_unset, + exclude_none=exclude_none ) if self.__document_id__ in data: del data[self.__document_id__] + + # Get Async collection reference using configuration - what? + col_ref = configuration.get_async_collection_ref(self.__class__, name=resolved) + + # Build doc ref (use provided id if set, otherwise let server generate) + doc_id = self.get_document_id() + if doc_id: + doc_ref = col_ref.document(doc_id) + else: + doc_ref = col_ref.document() - doc_ref = self._get_doc_ref() + # Use transaction if provided (assume it's compatible) otherwise do direct await set if transaction is not None: + # Transaction.delete/set expects DocumentReference from the same client. transaction.set(doc_ref, data) else: await doc_ref.set(data) + + # what does this do? setattr(self, self.__document_id__, doc_ref.id) async def delete(self, transaction: Optional[AsyncTransaction] = None) -> None: @@ -176,6 +202,20 @@ def get_document_id(self) -> Optional[str]: _OrderBy = List[Tuple[str, OrderDirection]] + @classmethod + async def delete_all_for_model(cls, config_name: Optional[str] = None) -> None: + + # Resolve config to use (explicit -> instance -> class -> default) + config_name = cls.__db_config__ + + client = configuration.get_async_client(config_name) + col_name = configuration.get_collection_name(cls, config_name=config_name) + col_ref = client.collection(col_name) + + async for doc in col_ref.stream(): + await doc.reference.delete() + + @classmethod async def find( # pylint: disable=too-many-arguments cls: Type[TAsyncBareModel], @@ -332,11 +372,11 @@ async def truncate_collection(cls, batch_size: int = 128) -> int: ) @classmethod - def _get_col_ref(cls) -> AsyncCollectionReference: + def _get_col_ref(cls, collection_name: Optional[str] = None) -> AsyncCollectionReference: """ Returns the collection reference. """ - return _get_col_ref(cls, cls.__collection__) + return _get_col_ref(cls, collection_name) @classmethod def get_collection_name(cls) -> str: @@ -345,13 +385,13 @@ def get_collection_name(cls) -> str: """ return get_collection_name(cls, cls.__collection__) - def _get_doc_ref(self) -> AsyncDocumentReference: + def _get_doc_ref(self, config_name: Optional[str] = "(default)") -> AsyncDocumentReference: """ Returns the document reference. :raise DocumentIDError: If the ID is not valid. """ - return self._get_col_ref().document(self.get_document_id()) # type: ignore + return self._get_col_ref(config_name).document(self.get_document_id()) # type: ignore @staticmethod def _validate_document_id(document_id: str): @@ -442,7 +482,7 @@ def _create(cls: Type[TAsyncBareSubModel], **kwargs) -> TAsyncBareSubModel: ) @classmethod - def _get_col_ref(cls) -> AsyncCollectionReference: + def _get_col_ref(cls, collection_name: Optional[str]) -> AsyncCollectionReference: """ Returns the collection reference. """ diff --git a/firedantic/_sync/model.py b/firedantic/_sync/model.py index 8957f4d..eda0e10 100644 --- a/firedantic/_sync/model.py +++ b/firedantic/_sync/model.py @@ -218,6 +218,19 @@ def get_document_id(self) -> Optional[str]: _OrderBy = List[Tuple[str, OrderDirection]] + @classmethod + def delete_all_for_model(cls, config_name: Optional[str] = None) -> None: + + # Resolve config to use (explicit -> instance -> class -> default) + config_name = cls.__db_config__ + + client = configuration.get_client(config_name) + col_name = configuration.get_collection_name(cls, config_name=config_name) + col_ref = client.collection(col_name) + + for doc in col_ref.stream(): + doc.reference.delete() + @classmethod def find( # pylint: disable=too-many-arguments cls: Type[TBareModel], @@ -391,7 +404,8 @@ def _get_doc_ref(self) -> DocumentReference: :raise DocumentIDError: If the ID is not valid. """ - return self._get_col_ref().document(self.get_document_id()) # type: ignore + doc_id = self.get_document_id() + return self._get_col_ref().document(doc_id) # type: ignore @staticmethod def _validate_document_id(document_id: str): diff --git a/firedantic/configurations.py b/firedantic/configurations.py index 35fa696..010cf2e 100644 --- a/firedantic/configurations.py +++ b/firedantic/configurations.py @@ -174,15 +174,19 @@ def get_async_transaction(self, name: Optional[str] = None) -> AsyncTransaction: return self.get_async_client(name=name).transaction() # helpers for models to derive collection name / reference - def get_collection_name(self, model_class: Type, name: Optional[str] = None) -> str: + def get_collection_name(self, model_class: Type, config_name: Optional[str] = None) -> str: """ Return the collection name string (prefix + model name). """ - resolved = name if name is not None else "(default)" + resolved = config_name if config_name is not None else "(default)" cfg = self.get_config(resolved) prefix = cfg.prefix or "" - model_name = model_class.__name__ - return f"{prefix}{model_name[0].lower()}{model_name[1:]}" # (lower case first letter of model name) + + if hasattr(model_class, "__collection__"): + return prefix + model_class.__collection__ + else: + model_name = model_class.__name__ + return f"{prefix}{model_name[0].lower()}{model_name[1:]}" # (lower case first letter of model name) def get_collection_ref(self, model_class: Type, name: Optional[str] = None) -> CollectionReference: """ diff --git a/firedantic/tests/tests_async/conftest.py b/firedantic/tests/tests_async/conftest.py index b115a8c..6eabbfb 100644 --- a/firedantic/tests/tests_async/conftest.py +++ b/firedantic/tests/tests_async/conftest.py @@ -21,7 +21,6 @@ from unittest.mock import AsyncMock, Mock # noqa isort: skip - class CustomIDModel(AsyncBareModel): __collection__ = "custom" __document_id__ = "foo" diff --git a/multiple_clients_examples/billing_models.py b/multiple_clients_examples/billing_models.py new file mode 100644 index 0000000..b775e63 --- /dev/null +++ b/multiple_clients_examples/billing_models.py @@ -0,0 +1,24 @@ +from pydantic import BaseModel + +from firedantic import Model, AsyncModel + + +class MyModel(AsyncModel): + pass + + +class BillingAccount(BaseModel): + """Dummy billing account Pydantic model.""" + __db_config__ = "billing" + __collection__ = "billing_accounts" + name: str + billing_id: int + owner: str + + +class BillingCompany(Model): + """Dummy company Firedantic model.""" + __db_config__ = "billing" + __collection__ = "billing_companies" + company_id: str + billing_account: BillingAccount diff --git a/multiple_clients_examples/company_models.py b/multiple_clients_examples/company_models.py new file mode 100644 index 0000000..1a545e7 --- /dev/null +++ b/multiple_clients_examples/company_models.py @@ -0,0 +1,19 @@ +from pydantic import BaseModel + +from firedantic import Model + + +class Owner(BaseModel): + """Dummy owner Pydantic model.""" + __db_config__ = "owners" + __collection__ = "owners" + first_name: str + last_name: str + + +class Company(Model): + """Dummy company Firedantic model.""" + __db_config__ = "companies" + __collection__ = "companies" + company_id: str + owner: Owner diff --git a/multiple_clients_examples/configure_firestore_db_client.py b/multiple_clients_examples/configure_firestore_db_client.py new file mode 100644 index 0000000..c6eb0ac --- /dev/null +++ b/multiple_clients_examples/configure_firestore_db_client.py @@ -0,0 +1,82 @@ +from os import environ +from unittest.mock import Mock + +import google.auth.credentials +from google.cloud.firestore import AsyncClient, Client + +from firedantic import Configuration, configure, CONFIGURATIONS +from billing_models import BillingAccount, BillingCompany +from base_models import MyModel + + +## OLD WAY: +def configure_sync_client(): + # Firestore emulator must be running if using locally. + if environ.get("FIRESTORE_EMULATOR_HOST"): + client = Client( + project="firedantic-test", + credentials=Mock(spec=google.auth.credentials.Credentials), + ) + else: + client = Client() + + configure(client, prefix="firedantic-sync-") + assert CONFIGURATIONS["prefix"] == "firedantic-sync-" + assert isinstance(CONFIGURATIONS["db"], Client) + + +def configure_async_client(): + # Firestore emulator must be running if using locally. + if environ.get("FIRESTORE_EMULATOR_HOST"): + client = AsyncClient( + project="firedantic-test", + credentials=Mock(spec=google.auth.credentials.Credentials), + ) + else: + client = AsyncClient() + + configure(client, prefix="firedantic-async-") + assert CONFIGURATIONS["prefix"] == "firedantic-async-" + assert isinstance(CONFIGURATIONS["db"], AsyncClient) + + + +## NEW WAY: +def configure_multiple_clients(): + config = Configuration() + + # name = (default) + config.add( + prefix="firedantic-test-", + project="firedantic-test", + credentials=Mock(spec=google.auth.credentials.Credentials), + ) + + # name = billing + config.add( + name="billing", + prefix="test-billing-", + project="test-billing", + credentials=Mock(spec=google.auth.credentials.Credentials), + ) + + default = config.get_config() ## pulls default config + print(config.get_collection_ref(MyModel)) + + assert default.prefix == "firedantic-test-" + assert isinstance(config.get_client(), Client) + + billing = config.get_config("billing") ## pulls billing config + assert billing.prefix == "test-billing-" + + print(config.get_collection_ref(BillingAccount, "billing")) + assert isinstance(config.get_client("billing"), Client) + + print(config.get_async_collection_ref(BillingAccount, "billing")) + assert isinstance(config.get_async_client("billing"), AsyncClient) + +print("\n---- Running OLD way ----") +configure_sync_client() +configure_async_client() +print("\n---- Running NEW way ----") +configure_multiple_clients() diff --git a/multiple_clients_examples/full_async_flow.py b/multiple_clients_examples/full_async_flow.py new file mode 100644 index 0000000..0b90967 --- /dev/null +++ b/multiple_clients_examples/full_async_flow.py @@ -0,0 +1,200 @@ +from unittest.mock import Mock + +from pydantic import BaseAsyncModel +from firedantic import AsyncModel +from firedantic.configurations import configuration, AsyncClient + +import google.auth.credentials +import sys + + +## With single async client +async def old_way(): + + class Owner(AsyncModel): + """Dummy owner Pydantic model.""" + __collection__ = "owners" + first_name: str + last_name: str + + + class Company(AsyncModel): + """Dummy company Firedantic model.""" + __collection__ = "companies" + company_id: str + owner: Owner + + # Firestore emulator must be running if using locally, name defaults to "(default)" + configuration.add( + prefix="async-default-test-", + project="async-default-test", + credentials=Mock(spec=google.auth.credentials.Credentials), + ) + + # Now you can use the model to save it to Firestore + owner = Owner(first_name="John", last_name="Doe") + company = Company(company_id="1234567-8a", owner=owner) + + # Save the company/owner info to the DB + await company.save() # only need to include config_name when not using default + + # Reloads model data from the database to ensure most-up-to-date info + await company.reload() + + # Assert that async client exists and configuration is correct + assert isinstance(configuration.get_async_client(), AsyncClient) + assert configuration.get_collection_name(Owner) == configuration.get_config().prefix + "owners" + assert configuration.get_collection_name(Company) == configuration.get_config().prefix + "companies" + assert configuration.get_config().prefix == "async-default-test-" + assert configuration.get_config().project == "async-default-test" + assert configuration.get_config().name == "(default)" + + # Finding data from DB + print(f"\nNumber of company owners with first name: 'John': {len(await Company.find({"owner.first_name": "John"}))}") + + print(f"\nNumber of companies with id: '1234567-8a': {len(await Company.find({"company_id": "1234567-8a"}))}") + + # Delete everything from the database + await company.delete_all_for_model() + deletion_success = [] == await Company.find({"company_id": "1234567-8a"}) + print(f"\nDeletion of (default) DB succeeded: {deletion_success}\n") + + + + +# Now with multiple ASYNC clients/dbs: +async def new_way(): + + config_name = "companies" + + class Owner(AsyncModel): + """Dummy owner Pydantic model.""" + __db_config__ = config_name + __collection__ = "owners" + first_name: str + last_name: str + + + class Company(AsyncModel): + """Dummy company Firedantic model.""" + __db_config__ = config_name + __collection__ = config_name + company_id: str + owner: Owner + + # 1. Add first config with config_name = "companies" + configuration.add( + name=config_name, + prefix="async-companies-test-", + project="async-companies-test", + credentials=Mock(spec=google.auth.credentials.Credentials), + ) + + # Now you can use the model to save it to Firestore + owner = Owner(first_name="Alice", last_name="Begone") + company = Company(company_id="1234567-9", owner=owner) + + # Save the company/owner info to the DB + await company.save() # only need to include config_name when not using default + + # Reloads model data from the database to ensure most-up-to-date info + await company.reload() + + # ##################################################### + + # 2. Add the second config with config_name = "billing" + config_name = "billing" + + class BillingAccount(AsyncModel): + """Dummy billing account Pydantic model.""" + __db_config__ = config_name + __collection__ = "accounts" + name: str + billing_id: int + owner: str + + + class BillingCompany(AsyncModel): + """Dummy company Firedantic model.""" + __db_config__ = config_name + __collection__ = "companies" + company_id: str + billing_account: BillingAccount + + configuration.add( + name=config_name, + prefix="async-billing-test-", + project="async-billing-test", + credentials=Mock(spec=google.auth.credentials.Credentials), + ) + + # Assert that async clients exists and added configurations are correct + assert isinstance(configuration.get_async_client("billing"), AsyncClient) + + assert configuration.get_collection_name(BillingAccount, "billing") == configuration.get_config("billing").prefix + "accounts" + assert configuration.get_collection_name(BillingCompany, "billing") == configuration.get_config("billing").prefix + "companies" + + assert configuration.get_config("billing").prefix == "async-billing-test-" + assert configuration.get_config("billing").project == "async-billing-test" + assert configuration.get_config("billing").name == "billing" + + assert isinstance(configuration.get_async_client("companies"), AsyncClient) + + assert configuration.get_collection_name(Company, "companies") == configuration.get_config("companies").prefix + "companies" + + assert configuration.get_config("companies").prefix == "async-companies-test-" + assert configuration.get_config("companies").project == "async-companies-test" + assert configuration.get_config("companies").name == "companies" + + # Create and save billing account and billing company + account = BillingAccount(name="ABC Billing", billing_id=801048, owner="MFisher") + bc = BillingCompany(company_id="1234567-8c", billing_account=account) + + await account.save() + await bc.save() + + # Reloads model data from the database + await account.reload() + await bc.reload() + + # 3. Finding data + print(f"\nNumber of company owners with first name: 'Alice': {len(await Company.find({"owner.first_name": "Alice"}))}") + + print(f"\nNumber of billing companies with id: '1234567-8c': {len(await BillingCompany.find({"company_id": "1234567-8c"}))}") + + print(f"\nNumber of billing accounts with billing_id: 801048: {len(await BillingCompany.find({"billing_account.billing_id": 801048}))}") + + # Delete everything from the database + await company.delete_all_for_model() + await bc.delete_all_for_model() + + deletion_success = [] == await Company.find({"company_id": "1234567-8a"}) + print(f"\nDeletion of 'companies' DB succeeded: {deletion_success}") + + deletion_success = [] == await BillingCompany.find({"billing_account.billing_id": 801048}) + print(f"\nDeletion of 'billing' DB succeeded: {deletion_success}") + + + +# suppress silent error msg from asyncio threads +def dbg_hook(exctype, value, tb): + print("=== Uncaught exception ===") + print(exctype, value) + import traceback + traceback.print_tb(tb) + +sys.excepthook = dbg_hook + +if __name__ == "__main__": + import asyncio + + print("\n---- Running OLD way ----") + asyncio.run(old_way()) + + print("\n---- Running NEW way ----") + asyncio.run(new_way()) + + + + + diff --git a/multiple_clients_examples/full_sync_flow.py b/multiple_clients_examples/full_sync_flow.py new file mode 100644 index 0000000..4266001 --- /dev/null +++ b/multiple_clients_examples/full_sync_flow.py @@ -0,0 +1,174 @@ +from unittest.mock import Mock + +# from pydantic import BaseModel +from firedantic import Model +from firedantic.configurations import configuration, Client + +import google.auth.credentials + + +## With single sync client +def old_way(): + + class Owner(Model): + """Dummy owner Pydantic model.""" + __collection__ = "owners" + first_name: str + last_name: str + + + class Company(Model): + """Dummy company Firedantic model.""" + __collection__ = "companies" + company_id: str + owner: Owner + + # Firestore emulator must be running if using locally, name defaults to "(default)" + configuration.add( + prefix="sync-default-test-", + project="sync-default-test", + credentials=Mock(spec=google.auth.credentials.Credentials), + ) + + # Now you can use the model to save it to Firestore + owner = Owner(first_name="John", last_name="Doe") + company = Company(company_id="1234567-8", owner=owner) + + # Save the company/owner info to the DB + company.save() + + # Reloads model data from the database to ensure most-up-to-date info + company.reload() + + # Assert that sync client exists and configuration is correct + assert isinstance(configuration.get_client(), Client) + assert configuration.get_collection_name(Owner) == configuration.get_config().prefix + "owners" + assert configuration.get_collection_name(Company) == configuration.get_config().prefix + "companies" + assert configuration.get_config().prefix == "sync-default-test-" + assert configuration.get_config().project == "sync-default-test" + assert configuration.get_config().name == "(default)" + + # Finding data from DB + print(f"\nNumber of company owners with first name: 'John': {len(Company.find({"owner.first_name": "John"}))}") + + print(f"\nNumber of companies with id: '1234567-8': {len(Company.find({"company_id": "1234567-8"}))}") + + + # Delete everything from the database + company.delete_all_for_model() + +# Now with multiple SYNC clients/dbs: +def new_way(): + + config_name = "companies" + + class Owner(Model): + """Dummy owner Pydantic model.""" + __db_config__ = config_name + __collection__ = "owners" + first_name: str + last_name: str + + + class Company(Model): + """Dummy company Firedantic model.""" + __db_config__ = config_name + __collection__ = config_name + company_id: str + owner: Owner + + # 1. Add first config with config_name = "companies" + configuration.add( + name=config_name, + prefix="sync-companies-test-", + project="sync-companies-test", + credentials=Mock(spec=google.auth.credentials.Credentials), + ) + + # Now you can use the model to save it to Firestore + owner = Owner(first_name="Alice", last_name="Begone") + company = Company(company_id="1234567-9", owner=owner) + + company.save() + + # Reloads model data from the database + company.reload() + + # ##################################################### + + # 2. Add the second config with config_name = "billing" + config_name = "billing" + + class BillingAccount(Model): + """Dummy billing account Pydantic model.""" + __db_config__ = config_name + __collection__ = "accounts" + name: str + billing_id: int + owner: str + + + class BillingCompany(Model): + """Dummy company Firedantic model.""" + __db_config__ = config_name + __collection__ = "companies" + company_id: str + billing_account: BillingAccount + + configuration.add( + name=config_name, + prefix="sync-billing-test-", + project="sync-billing-test", + credentials=Mock(spec=google.auth.credentials.Credentials), + ) + + # Assert that sync clients exists and added configurations are correct + assert isinstance(configuration.get_client("billing"), Client) + + assert configuration.get_collection_name(BillingAccount, "billing") == configuration.get_config("billing").prefix + "accounts" + assert configuration.get_collection_name(BillingCompany, "billing") == configuration.get_config("billing").prefix + "companies" + + assert configuration.get_config("billing").prefix == "sync-billing-test-" + assert configuration.get_config("billing").project == "sync-billing-test" + assert configuration.get_config("billing").name == "billing" + + assert isinstance(configuration.get_client("companies"), Client) + + assert configuration.get_collection_name(Company, "companies") == configuration.get_config("companies").prefix + "companies" + + assert configuration.get_config("companies").prefix == "sync-companies-test-" + assert configuration.get_config("companies").project == "sync-companies-test" + assert configuration.get_config("companies").name == "companies" + + + # Create and save billing account and company + account = BillingAccount(name="ABC Billing", billing_id=801048, owner="MFisher") + bc = BillingCompany(company_id="1234567-8c", billing_account=account) + + account.save() + bc.save() + + # Reloads model data from the database + account.reload() + bc.reload() + + # 3. Finding data + print(f"\nNumber of company owners with first name: 'Alice': {len(Company.find({"owner.first_name": "Alice"}))}") + + print(f"\nNumber of billing companies with id: '1234567-8c': {len(BillingCompany.find({"company_id": "1234567-8c"}))}") + + print(f"\nNumber of billing accounts with billing_id: 801048: {len(BillingCompany.find({"billing_account.billing_id": 801048}))}") + + # now delete everything from the DBs: + company.delete_all_for_model() + bc.delete_all_for_model() + + + + +print("\n---- Running OLD way ----") +old_way() +print("\n---- Running NEW way ----") +new_way() + + diff --git a/multiple_clients_examples/query_sample.py b/multiple_clients_examples/query_sample.py new file mode 100644 index 0000000..c02a3eb --- /dev/null +++ b/multiple_clients_examples/query_sample.py @@ -0,0 +1,51 @@ +from unittest.mock import Mock + +from company_models import Company +from configure_firestore_db_client import configure_sync_client, configure_multiple_clients + +import firedantic.operators as op + +# Firestore emulator must be running if using locally. +# Supported types are: <, >, array_contains, in, ==, !=, not-in, <=, array_contains_any, >= + + +print("\n---- Running OLD way ----") +configure_sync_client() + +companies1 = Company.find({"owner.first_name": "John"}) +companies2 = Company.find({"owner.first_name": {op.EQ: "John"}}) +companies3 = Company.find({"owner.first_name": {"==": "John"}}) +print(companies1) + +assert companies1 != [] +assert companies1 == companies2 == companies3 + + +print("\n---- Running NEW way ----") +configure_multiple_clients() + +#config = Configuration() + +# # name = (default) +# config.add( +# prefix="firedantic-test-", +# project="firedantic-test", +# credentials=Mock(spec=google.auth.credentials.Credentials), +# ) + +# # name = billing +# config.add( + +# name="companies", +# prefix="test-companies-", +# project="test-companies", +# credentials=Mock(spec=google.auth.credentials.Credentials), +# ) + + +companies1 = Company.find({"owner.first_name": "Alice"}) +companies2 = Company.find({"owner.first_name": {op.EQ: "Alice"}}) +companies3 = Company.find({"owner.first_name": {"==": "Alice"}}) +print(companies1) +assert companies1 != [] +assert companies1 == companies2 == companies3 From 3c69944060c7ef5a843cf1d7bc319dee9782aa85 Mon Sep 17 00:00:00 2001 From: Marissa Fisher Date: Fri, 28 Nov 2025 16:31:22 -0300 Subject: [PATCH 18/38] removing some files, cleanup --- multiple_clients_examples/billing_models.py | 24 --------- multiple_clients_examples/company_models.py | 19 -------- multiple_clients_examples/full_async_flow.py | 29 ++++------- multiple_clients_examples/full_sync_flow.py | 15 ++---- multiple_clients_examples/query_sample.py | 51 -------------------- 5 files changed, 14 insertions(+), 124 deletions(-) delete mode 100644 multiple_clients_examples/billing_models.py delete mode 100644 multiple_clients_examples/company_models.py delete mode 100644 multiple_clients_examples/query_sample.py diff --git a/multiple_clients_examples/billing_models.py b/multiple_clients_examples/billing_models.py deleted file mode 100644 index b775e63..0000000 --- a/multiple_clients_examples/billing_models.py +++ /dev/null @@ -1,24 +0,0 @@ -from pydantic import BaseModel - -from firedantic import Model, AsyncModel - - -class MyModel(AsyncModel): - pass - - -class BillingAccount(BaseModel): - """Dummy billing account Pydantic model.""" - __db_config__ = "billing" - __collection__ = "billing_accounts" - name: str - billing_id: int - owner: str - - -class BillingCompany(Model): - """Dummy company Firedantic model.""" - __db_config__ = "billing" - __collection__ = "billing_companies" - company_id: str - billing_account: BillingAccount diff --git a/multiple_clients_examples/company_models.py b/multiple_clients_examples/company_models.py deleted file mode 100644 index 1a545e7..0000000 --- a/multiple_clients_examples/company_models.py +++ /dev/null @@ -1,19 +0,0 @@ -from pydantic import BaseModel - -from firedantic import Model - - -class Owner(BaseModel): - """Dummy owner Pydantic model.""" - __db_config__ = "owners" - __collection__ = "owners" - first_name: str - last_name: str - - -class Company(Model): - """Dummy company Firedantic model.""" - __db_config__ = "companies" - __collection__ = "companies" - company_id: str - owner: Owner diff --git a/multiple_clients_examples/full_async_flow.py b/multiple_clients_examples/full_async_flow.py index 0b90967..b7d7776 100644 --- a/multiple_clients_examples/full_async_flow.py +++ b/multiple_clients_examples/full_async_flow.py @@ -1,6 +1,5 @@ from unittest.mock import Mock -from pydantic import BaseAsyncModel from firedantic import AsyncModel from firedantic.configurations import configuration, AsyncClient @@ -9,7 +8,7 @@ ## With single async client -async def old_way(): +async def test_with_default(): class Owner(AsyncModel): """Dummy owner Pydantic model.""" @@ -57,13 +56,12 @@ class Company(AsyncModel): # Delete everything from the database await company.delete_all_for_model() deletion_success = [] == await Company.find({"company_id": "1234567-8a"}) - print(f"\nDeletion of (default) DB succeeded: {deletion_success}\n") - - + if not deletion_success: + print(f"\nDeletion of (default) DB failed\n") # Now with multiple ASYNC clients/dbs: -async def new_way(): +async def test_with_multiple(): config_name = "companies" @@ -169,10 +167,12 @@ class BillingCompany(AsyncModel): await bc.delete_all_for_model() deletion_success = [] == await Company.find({"company_id": "1234567-8a"}) - print(f"\nDeletion of 'companies' DB succeeded: {deletion_success}") + if not deletion_success: + print(f"\nDeletion of Company DB failed\n") deletion_success = [] == await BillingCompany.find({"billing_account.billing_id": 801048}) - print(f"\nDeletion of 'billing' DB succeeded: {deletion_success}") + if not deletion_success: + print(f"\nDeletion of BillingCompany DB failed\n") @@ -187,14 +187,5 @@ def dbg_hook(exctype, value, tb): if __name__ == "__main__": import asyncio - - print("\n---- Running OLD way ----") - asyncio.run(old_way()) - - print("\n---- Running NEW way ----") - asyncio.run(new_way()) - - - - - + asyncio.run(test_with_default()) + asyncio.run(test_with_multiple()) diff --git a/multiple_clients_examples/full_sync_flow.py b/multiple_clients_examples/full_sync_flow.py index 4266001..b27b62a 100644 --- a/multiple_clients_examples/full_sync_flow.py +++ b/multiple_clients_examples/full_sync_flow.py @@ -1,6 +1,5 @@ from unittest.mock import Mock -# from pydantic import BaseModel from firedantic import Model from firedantic.configurations import configuration, Client @@ -8,7 +7,7 @@ ## With single sync client -def old_way(): +def test_with_default(): class Owner(Model): """Dummy owner Pydantic model.""" @@ -58,7 +57,7 @@ class Company(Model): company.delete_all_for_model() # Now with multiple SYNC clients/dbs: -def new_way(): +def test_with_multiple(): config_name = "companies" @@ -164,11 +163,5 @@ class BillingCompany(Model): bc.delete_all_for_model() - - -print("\n---- Running OLD way ----") -old_way() -print("\n---- Running NEW way ----") -new_way() - - +test_with_default() +test_with_multiple() diff --git a/multiple_clients_examples/query_sample.py b/multiple_clients_examples/query_sample.py deleted file mode 100644 index c02a3eb..0000000 --- a/multiple_clients_examples/query_sample.py +++ /dev/null @@ -1,51 +0,0 @@ -from unittest.mock import Mock - -from company_models import Company -from configure_firestore_db_client import configure_sync_client, configure_multiple_clients - -import firedantic.operators as op - -# Firestore emulator must be running if using locally. -# Supported types are: <, >, array_contains, in, ==, !=, not-in, <=, array_contains_any, >= - - -print("\n---- Running OLD way ----") -configure_sync_client() - -companies1 = Company.find({"owner.first_name": "John"}) -companies2 = Company.find({"owner.first_name": {op.EQ: "John"}}) -companies3 = Company.find({"owner.first_name": {"==": "John"}}) -print(companies1) - -assert companies1 != [] -assert companies1 == companies2 == companies3 - - -print("\n---- Running NEW way ----") -configure_multiple_clients() - -#config = Configuration() - -# # name = (default) -# config.add( -# prefix="firedantic-test-", -# project="firedantic-test", -# credentials=Mock(spec=google.auth.credentials.Credentials), -# ) - -# # name = billing -# config.add( - -# name="companies", -# prefix="test-companies-", -# project="test-companies", -# credentials=Mock(spec=google.auth.credentials.Credentials), -# ) - - -companies1 = Company.find({"owner.first_name": "Alice"}) -companies2 = Company.find({"owner.first_name": {op.EQ: "Alice"}}) -companies3 = Company.find({"owner.first_name": {"==": "Alice"}}) -print(companies1) -assert companies1 != [] -assert companies1 == companies2 == companies3 From 55c90d8f8ce500a2164ca7a7d9e800c6a89a8ba0 Mon Sep 17 00:00:00 2001 From: Marissa Fisher Date: Fri, 28 Nov 2025 21:44:29 -0300 Subject: [PATCH 19/38] adding readme for integ tests --- integration_tests/README.md | 52 +++++++ .../configure_firestore_db_client.py | 147 ++++++++++++++++++ .../full_async_flow.py | 47 +++++- .../full_sync_flow.py | 54 ++++++- .../configure_firestore_db_client.py | 82 ---------- 5 files changed, 294 insertions(+), 88 deletions(-) create mode 100644 integration_tests/README.md create mode 100644 integration_tests/configure_firestore_db_client.py rename {multiple_clients_examples => integration_tests}/full_async_flow.py (81%) rename {multiple_clients_examples => integration_tests}/full_sync_flow.py (77%) delete mode 100644 multiple_clients_examples/configure_firestore_db_client.py diff --git a/integration_tests/README.md b/integration_tests/README.md new file mode 100644 index 0000000..edca292 --- /dev/null +++ b/integration_tests/README.md @@ -0,0 +1,52 @@ +# integration_tests/README.md + +## Overview +This folder contains three integration test files. This README explains how to run them and what each file is for. + +## Prerequisites +- Python 3.13 +- Run: +```bash +python -m venv firedantic +source firedantic/bin/activate +cd firedantic/integration_tests +``` + +## Files and purpose (replace placeholders with real filenames) +- `integration_tests/configure_firestore_db_clients.py` — Purpose: shows how to create and connect to various db clients. +- `integration_tests/full_sync_flow.py` — Purpose: configures clients, saves data to db, finds the data, and deletes all data in a sync fashion. +- `integration_tests/full_async_flow.py` — Purpose: configures async clients, saves data to db, finds the data, and deletes all data in an async fashion. + + +## How to run +Run each individual test file: + - `python configure_firestore_db_clients.py` + - `python full_sync_flow.py` + - `python full_async_flow.py` + + +## Environment and configuration +- Ensure firestore emulator is running in another terminal window: + - `./start_emulator.sh` + +## What to expect: +- For the configure_firestore_db_clients test, you should expect to see the following: +`All configure_firestore_db_client tests passed!` + +- For the `full_sync_flow` and `full_async_flow`, you should expect to see the following output: \ +You can readily play around with the models to update the data as desired. +``` +Number of company owners with first name: 'Bill': 1 + +Number of companies with id: '1234567-7': 1 + +Number of company owners with first name: 'John': 1 + +Number of companies with id: '1234567-8a': 1 + +Number of company owners with first name: 'Alice': 1 + +Number of billing companies with id: '1234567-8c': 1 + +Number of billing accounts with billing_id: 801048: 1 +``` \ No newline at end of file diff --git a/integration_tests/configure_firestore_db_client.py b/integration_tests/configure_firestore_db_client.py new file mode 100644 index 0000000..7830764 --- /dev/null +++ b/integration_tests/configure_firestore_db_client.py @@ -0,0 +1,147 @@ +from os import environ +from unittest.mock import Mock + +import google.auth.credentials +from google.cloud.firestore import AsyncClient, Client + +from firedantic.configurations import configuration, configure, CONFIGURATIONS + + +## OLD WAY: +def configure_sync_client(): + # Firestore emulator must be running if using locally. + if environ.get("FIRESTORE_EMULATOR_HOST"): + client = Client( + project="firedantic-test", + credentials=Mock(spec=google.auth.credentials.Credentials), + ) + else: + client = Client() + + configure(client, prefix="firedantic-sync-") + + # ---- Assertions ---- + assert CONFIGURATIONS["prefix"] == "firedantic-sync-" + assert isinstance(CONFIGURATIONS["db"], Client) + + # project check + assert CONFIGURATIONS["db"].project == "firedantic-test" + + # emulator expectations + assert isinstance(CONFIGURATIONS["db"]._credentials, Mock) + + # ensure no accidental async setting + assert not isinstance(CONFIGURATIONS["db"], AsyncClient) + + + +def configure_async_client(): + # Firestore emulator must be running if using locally. + if environ.get("FIRESTORE_EMULATOR_HOST"): + client = AsyncClient( + project="firedantic-test", + credentials=Mock(spec=google.auth.credentials.Credentials), + ) + else: + client = AsyncClient() + + configure(client, prefix="firedantic-async-") + + # ---- Assertions ---- + assert CONFIGURATIONS["prefix"] == "firedantic-async-" + assert isinstance(CONFIGURATIONS["db"], AsyncClient) + + # project check + assert CONFIGURATIONS["db"].project == "firedantic-test" + + # emulator expectations + assert isinstance(CONFIGURATIONS["db"]._credentials, Mock) + + # ensure no accidental async setting + assert not isinstance(CONFIGURATIONS["db"], Client) + + + +## NEW WAY: +def configure_multiple_clients(): + config = configuration + + # name = (default) + config.add( + prefix="firedantic-test-", + project="firedantic-test", + credentials=Mock(spec=google.auth.credentials.Credentials), + ) + + # name = billing + config.add( + name="billing", + prefix="test-billing-", + project="test-billing", + credentials=Mock(spec=google.auth.credentials.Credentials), + ) + + + # =========================== + # VALIDATE DEFAULT CONFIG + # =========================== + default = config.get_config() # loads "(default)" + + # Basic prefix & project + assert default.prefix == "firedantic-test-" + assert default.project == "firedantic-test" + + # Client objects exist + assert default.client is not None + assert default.async_client is not None + + # Types are correct + assert isinstance(default.client, Client) + assert isinstance(default.async_client, AsyncClient) + + # Credentials are what we passed in + assert isinstance(default.credentials, Mock) + + # Config can return correct client API + assert isinstance(config.get_client(), Client) + assert isinstance(config.get_async_client(), AsyncClient) + + + # =========================== + # VALIDATE BILLING CONFIG + # =========================== + billing = config.get_config("billing") + + # Prefix & project + assert billing.prefix == "test-billing-" + assert billing.project == "test-billing" + + # Correct client types + assert isinstance(billing.client, Client) + assert isinstance(billing.async_client, AsyncClient) + + # Correct credentials object + assert isinstance(billing.credentials, Mock) + + # Ensure default and billing configs are distinct objects + assert billing is not default + + # Make sure nothing leaked between configs + assert billing.prefix != default.prefix + assert billing.project != default.project + + + +try: + ### ---- Running OLD way ---- + configure_sync_client() + configure_async_client() + + ### ---- Running NEW way ---- + configure_multiple_clients() + + print("\nAll configure_firestore_db_client tests passed!\n") + +except AssertionError as e: + print(f"\nconfigure_firestore_db_client tests failed: {e}\n") + raise e diff --git a/multiple_clients_examples/full_async_flow.py b/integration_tests/full_async_flow.py similarity index 81% rename from multiple_clients_examples/full_async_flow.py rename to integration_tests/full_async_flow.py index b7d7776..81c8787 100644 --- a/multiple_clients_examples/full_async_flow.py +++ b/integration_tests/full_async_flow.py @@ -1,11 +1,54 @@ from unittest.mock import Mock from firedantic import AsyncModel -from firedantic.configurations import configuration, AsyncClient +from firedantic.configurations import configuration, AsyncClient, configure import google.auth.credentials import sys +from os import environ +async def test_old_way(): + # This is the old way of doing things, kept for backwards compatibility. + # Firestore emulator must be running if using locally. + if environ.get("FIRESTORE_EMULATOR_HOST"): + client = AsyncClient( + project="firedantic-test", + credentials=Mock(spec=google.auth.credentials.Credentials) + ) + else: + client = AsyncClient() + + configure(client, prefix="firedantic-test-") + + class Owner(AsyncModel): + """Dummy owner Pydantic model.""" + first_name: str + last_name: str + + + class Company(AsyncModel): + """Dummy company Firedantic model.""" + __collection__ = "companies" + company_id: str + owner: Owner + + # Now you can use the model to save it to Firestore + owner = Owner(first_name="Bill", last_name="Smith") + company = Company(company_id="1234567-7", owner=owner) + await company.save() + + # Reloads model data from the database + await company.reload() + + # Finding data from DB + print(f"\nNumber of company owners with first name: 'Bill': {len(await Company.find({"owner.first_name": "Bill"}))}") + print(f"\nNumber of companies with id: '1234567-7': {len(await Company.find({"company_id": "1234567-7"}))}") + + # Delete everything from the database + await company.delete_all_for_model() + deletion_success = [] == await Company.find({"company_id": "1234567-7"}) + if not deletion_success: + print(f"\nDeletion of (default) DB failed\n") ## With single async client async def test_with_default(): @@ -185,7 +228,9 @@ def dbg_hook(exctype, value, tb): sys.excepthook = dbg_hook +# Run the tests if __name__ == "__main__": import asyncio + asyncio.run(test_old_way()) asyncio.run(test_with_default()) asyncio.run(test_with_multiple()) diff --git a/multiple_clients_examples/full_sync_flow.py b/integration_tests/full_sync_flow.py similarity index 77% rename from multiple_clients_examples/full_sync_flow.py rename to integration_tests/full_sync_flow.py index b27b62a..02316f8 100644 --- a/multiple_clients_examples/full_sync_flow.py +++ b/integration_tests/full_sync_flow.py @@ -1,9 +1,52 @@ from unittest.mock import Mock +from pydantic import BaseModel from firedantic import Model -from firedantic.configurations import configuration, Client +from firedantic.configurations import configuration, Client, configure import google.auth.credentials +from os import environ + + +def test_old_way(): + # This is the old way of doing things, kept for backwards compatibility. + # Firestore emulator must be running if using locally. + if environ.get("FIRESTORE_EMULATOR_HOST"): + client = Client( + project="firedantic-test", + credentials=Mock(spec=google.auth.credentials.Credentials) + ) + else: + client = Client() + + configure(client, prefix="firedantic-test-") + + class Owner(BaseModel): + """Dummy owner Pydantic model.""" + first_name: str + last_name: str + + + class Company(Model): + """Dummy company Firedantic model.""" + __collection__ = "companies" + company_id: str + owner: Owner + + # Now you can use the model to save it to Firestore + owner = Owner(first_name="Bill", last_name="Smith") + company = Company(company_id="1234567-7", owner=owner) + company.save() + + # Reloads model data from the database + company.reload() + + # Finding data from DB + print(f"\nNumber of company owners with first name: 'Bill': {len(Company.find({"owner.first_name": "Bill"}))}") + print(f"\nNumber of companies with id: '1234567-7': {len(Company.find({"company_id": "1234567-7"}))}") + + # Delete everything from the database + company.delete_all_for_model() ## With single sync client @@ -49,10 +92,8 @@ class Company(Model): # Finding data from DB print(f"\nNumber of company owners with first name: 'John': {len(Company.find({"owner.first_name": "John"}))}") - print(f"\nNumber of companies with id: '1234567-8': {len(Company.find({"company_id": "1234567-8"}))}") - # Delete everything from the database company.delete_all_for_model() @@ -153,15 +194,18 @@ class BillingCompany(Model): # 3. Finding data print(f"\nNumber of company owners with first name: 'Alice': {len(Company.find({"owner.first_name": "Alice"}))}") - print(f"\nNumber of billing companies with id: '1234567-8c': {len(BillingCompany.find({"company_id": "1234567-8c"}))}") - print(f"\nNumber of billing accounts with billing_id: 801048: {len(BillingCompany.find({"billing_account.billing_id": 801048}))}") # now delete everything from the DBs: company.delete_all_for_model() bc.delete_all_for_model() +## Ensure existing configure function works for backwards compatibility +test_old_way() +## Ensure new way with single default client works test_with_default() + +## Ensure new way with multiple clients works test_with_multiple() diff --git a/multiple_clients_examples/configure_firestore_db_client.py b/multiple_clients_examples/configure_firestore_db_client.py deleted file mode 100644 index c6eb0ac..0000000 --- a/multiple_clients_examples/configure_firestore_db_client.py +++ /dev/null @@ -1,82 +0,0 @@ -from os import environ -from unittest.mock import Mock - -import google.auth.credentials -from google.cloud.firestore import AsyncClient, Client - -from firedantic import Configuration, configure, CONFIGURATIONS -from billing_models import BillingAccount, BillingCompany -from base_models import MyModel - - -## OLD WAY: -def configure_sync_client(): - # Firestore emulator must be running if using locally. - if environ.get("FIRESTORE_EMULATOR_HOST"): - client = Client( - project="firedantic-test", - credentials=Mock(spec=google.auth.credentials.Credentials), - ) - else: - client = Client() - - configure(client, prefix="firedantic-sync-") - assert CONFIGURATIONS["prefix"] == "firedantic-sync-" - assert isinstance(CONFIGURATIONS["db"], Client) - - -def configure_async_client(): - # Firestore emulator must be running if using locally. - if environ.get("FIRESTORE_EMULATOR_HOST"): - client = AsyncClient( - project="firedantic-test", - credentials=Mock(spec=google.auth.credentials.Credentials), - ) - else: - client = AsyncClient() - - configure(client, prefix="firedantic-async-") - assert CONFIGURATIONS["prefix"] == "firedantic-async-" - assert isinstance(CONFIGURATIONS["db"], AsyncClient) - - - -## NEW WAY: -def configure_multiple_clients(): - config = Configuration() - - # name = (default) - config.add( - prefix="firedantic-test-", - project="firedantic-test", - credentials=Mock(spec=google.auth.credentials.Credentials), - ) - - # name = billing - config.add( - name="billing", - prefix="test-billing-", - project="test-billing", - credentials=Mock(spec=google.auth.credentials.Credentials), - ) - - default = config.get_config() ## pulls default config - print(config.get_collection_ref(MyModel)) - - assert default.prefix == "firedantic-test-" - assert isinstance(config.get_client(), Client) - - billing = config.get_config("billing") ## pulls billing config - assert billing.prefix == "test-billing-" - - print(config.get_collection_ref(BillingAccount, "billing")) - assert isinstance(config.get_client("billing"), Client) - - print(config.get_async_collection_ref(BillingAccount, "billing")) - assert isinstance(config.get_async_client("billing"), AsyncClient) - -print("\n---- Running OLD way ----") -configure_sync_client() -configure_async_client() -print("\n---- Running NEW way ----") -configure_multiple_clients() From 6eae4af040c4c088bdbcecfe99208d0c90d034b1 Mon Sep 17 00:00:00 2001 From: Marissa Fisher Date: Mon, 1 Dec 2025 17:59:57 -0300 Subject: [PATCH 20/38] fix async tests, delete w transaction not working nor sync, yet --- firedantic/_async/indexes.py | 2 +- firedantic/_async/model.py | 20 ++++++++++++-------- firedantic/configurations.py | 9 ++++++++- firedantic/tests/tests_async/test_model.py | 7 ++++--- 4 files changed, 25 insertions(+), 13 deletions(-) diff --git a/firedantic/_async/indexes.py b/firedantic/_async/indexes.py index 0f11be8..51e6b67 100644 --- a/firedantic/_async/indexes.py +++ b/firedantic/_async/indexes.py @@ -111,7 +111,7 @@ async def set_up_composite_indexes( project = gcloud_project or configuration.get_config(config_name).project # Build collection group path using configuration helper (includes prefix) - collection_group = configuration.get_collection_name(model, name=config_name) + collection_group = configuration.get_collection_name(model, config_name=config_name) path = f"projects/{project}/databases/{database}/collectionGroups/{collection_group}" indexes_in_db = await get_existing_indexes(client, path=path) diff --git a/firedantic/_async/model.py b/firedantic/_async/model.py index db500f2..c438b04 100644 --- a/firedantic/_async/model.py +++ b/firedantic/_async/model.py @@ -41,7 +41,7 @@ } -def get_collection_name(cls, collection_name: Optional[str]) -> str: +def get_collection_name(cls, collection_name: Optional[str] = None) -> str: """ Return the collection name for `cls`. @@ -132,7 +132,7 @@ async def save( if not resolved: resolved = getattr(self.__class__, "__db_config__", "(default)") - # Build payload - model_dump? + # Build payload data = self.model_dump( by_alias=True, exclude_unset=exclude_unset, @@ -140,10 +140,15 @@ async def save( ) if self.__document_id__ in data: del data[self.__document_id__] + + async_client = configuration.get_async_client(resolved) + if async_client is None: + raise RuntimeError(f"No async client configured for config '{resolved}'") - # Get Async collection reference using configuration - what? - col_ref = configuration.get_async_collection_ref(self.__class__, name=resolved) - + # Get collection reference from async_client with the collection_name + collection_name = self.get_collection_name() + col_ref = async_client.collection(collection_name) + # Build doc ref (use provided id if set, otherwise let server generate) doc_id = self.get_document_id() if doc_id: @@ -158,7 +163,6 @@ async def save( else: await doc_ref.set(data) - # what does this do? setattr(self, self.__document_id__, doc_ref.id) async def delete(self, transaction: Optional[AsyncTransaction] = None) -> None: @@ -209,7 +213,7 @@ async def delete_all_for_model(cls, config_name: Optional[str] = None) -> None: config_name = cls.__db_config__ client = configuration.get_async_client(config_name) - col_name = configuration.get_collection_name(cls, config_name=config_name) + col_name = get_collection_name(cls) col_ref = client.collection(col_name) async for doc in col_ref.stream(): @@ -482,7 +486,7 @@ def _create(cls: Type[TAsyncBareSubModel], **kwargs) -> TAsyncBareSubModel: ) @classmethod - def _get_col_ref(cls, collection_name: Optional[str]) -> AsyncCollectionReference: + def _get_col_ref(cls) -> AsyncCollectionReference: """ Returns the collection reference. """ diff --git a/firedantic/configurations.py b/firedantic/configurations.py index 010cf2e..6bfea12 100644 --- a/firedantic/configurations.py +++ b/firedantic/configurations.py @@ -9,6 +9,7 @@ CollectionReference, Transaction ) +from firedantic.exceptions import CollectionNotDefined from pydantic import BaseModel # --- Old compatibility surface (kept for backwards compatibility) --- @@ -181,7 +182,11 @@ def get_collection_name(self, model_class: Type, config_name: Optional[str] = No resolved = config_name if config_name is not None else "(default)" cfg = self.get_config(resolved) prefix = cfg.prefix or "" - + + # Resolve collection to use (explicit -> instance -> class -> default) + # col_name = getattr(self, "__collection__", None) + # if not col_name: + # raise CollectionNotDefined(f"Missing collection name for {resolved}") if hasattr(model_class, "__collection__"): return prefix + model_class.__collection__ else: @@ -209,7 +214,9 @@ def get_async_collection_ref(self, model_class: Type, name: Optional[str] = None async_client = cfg.async_client if async_client is None: raise RuntimeError(f"No async client configured for config '{resolved}'") + collection_name = self.get_collection_name(model_class, resolved) + return async_client.collection(collection_name) # make the module-level singleton available to models/tests diff --git a/firedantic/tests/tests_async/test_model.py b/firedantic/tests/tests_async/test_model.py index dac4c2e..ce8040c 100644 --- a/firedantic/tests/tests_async/test_model.py +++ b/firedantic/tests/tests_async/test_model.py @@ -45,7 +45,7 @@ async def test_save_model(create_company) -> None: assert company.owner.last_name == "Doe" @pytest.mark.asyncio -async def test_delete_model(create_company) -> None: +async def test_delete_all_for_model(create_company) -> None: company: Company = await create_company( company_id="11223344-5", first_name="Jane", last_name="Doe" ) @@ -53,7 +53,7 @@ async def test_delete_model(create_company) -> None: _id = company.id assert _id - await company.delete() + await company.delete_all_for_model() with pytest.raises(ModelNotFoundError): await Company.get_by_id(_id) @@ -275,6 +275,7 @@ async def test_get_by_empty_str_id() -> None: async def test_missing_collection() -> None: class User(AsyncModel): name: str + # normally __collection__ would be defined here with pytest.raises(CollectionNotDefined): await User(name="John").save() @@ -441,7 +442,7 @@ async def test_extra_fields() -> None: @pytest.mark.asyncio -async def test_company_stats( create_company) -> None: +async def test_company_stats(create_company) -> None: company: Company = await create_company(company_id="1234567-8") company_stats = company.stats() From fc468dbd82d04fc9f6fd108763d85e65bbec2388 Mon Sep 17 00:00:00 2001 From: Marissa Fisher Date: Mon, 1 Dec 2025 18:09:12 -0300 Subject: [PATCH 21/38] all passing but one --- firedantic/_sync/indexes.py | 2 +- firedantic/_sync/model.py | 4 ++-- firedantic/configurations.py | 5 ----- firedantic/tests/tests_sync/test_model.py | 4 ++-- 4 files changed, 5 insertions(+), 10 deletions(-) diff --git a/firedantic/_sync/indexes.py b/firedantic/_sync/indexes.py index f39c61b..5ec5c66 100644 --- a/firedantic/_sync/indexes.py +++ b/firedantic/_sync/indexes.py @@ -111,7 +111,7 @@ def set_up_composite_indexes( project = gcloud_project or configuration.get_config(config_name).project # Build collection group path using configuration helper (includes prefix) - collection_group = configuration.get_collection_name(model, name=config_name) + collection_group = configuration.get_collection_name(model, config_name=config_name) path = f"projects/{project}/databases/{database}/collectionGroups/{collection_group}" indexes_in_db = get_existing_indexes(client, path=path) diff --git a/firedantic/_sync/model.py b/firedantic/_sync/model.py index eda0e10..99156fa 100644 --- a/firedantic/_sync/model.py +++ b/firedantic/_sync/model.py @@ -42,7 +42,7 @@ } -def get_collection_name(cls, collection_name: Optional[str]) -> str: +def get_collection_name(cls, collection_name: Optional[str] = None) -> str: """ Returns the collection name for `cls`. @@ -225,7 +225,7 @@ def delete_all_for_model(cls, config_name: Optional[str] = None) -> None: config_name = cls.__db_config__ client = configuration.get_client(config_name) - col_name = configuration.get_collection_name(cls, config_name=config_name) + col_name = get_collection_name(cls) col_ref = client.collection(col_name) for doc in col_ref.stream(): diff --git a/firedantic/configurations.py b/firedantic/configurations.py index 6bfea12..50a2471 100644 --- a/firedantic/configurations.py +++ b/firedantic/configurations.py @@ -9,7 +9,6 @@ CollectionReference, Transaction ) -from firedantic.exceptions import CollectionNotDefined from pydantic import BaseModel # --- Old compatibility surface (kept for backwards compatibility) --- @@ -183,10 +182,6 @@ def get_collection_name(self, model_class: Type, config_name: Optional[str] = No cfg = self.get_config(resolved) prefix = cfg.prefix or "" - # Resolve collection to use (explicit -> instance -> class -> default) - # col_name = getattr(self, "__collection__", None) - # if not col_name: - # raise CollectionNotDefined(f"Missing collection name for {resolved}") if hasattr(model_class, "__collection__"): return prefix + model_class.__collection__ else: diff --git a/firedantic/tests/tests_sync/test_model.py b/firedantic/tests/tests_sync/test_model.py index 0272250..321e51c 100644 --- a/firedantic/tests/tests_sync/test_model.py +++ b/firedantic/tests/tests_sync/test_model.py @@ -44,7 +44,7 @@ def test_save_model(create_company) -> None: assert company.owner.last_name == "Doe" -def test_delete_model(create_company) -> None: +def test_delete_all_for_model(create_company) -> None: company: Company = create_company( company_id="11223344-5", first_name="Jane", last_name="Doe" ) @@ -52,7 +52,7 @@ def test_delete_model(create_company) -> None: _id = company.id assert _id - company.delete() + company.delete_all_for_model() with pytest.raises(ModelNotFoundError): Company.get_by_id(_id) From d5541228538d7b09dd573eff094488d4230d9553 Mon Sep 17 00:00:00 2001 From: Marissa Fisher Date: Mon, 1 Dec 2025 18:52:59 -0300 Subject: [PATCH 22/38] delete still not working for async :/ --- firedantic/_async/model.py | 9 +++++++-- firedantic/tests/tests_async/test_model.py | 3 ++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/firedantic/_async/model.py b/firedantic/_async/model.py index c438b04..829dc1a 100644 --- a/firedantic/_async/model.py +++ b/firedantic/_async/model.py @@ -171,10 +171,15 @@ async def delete(self, transaction: Optional[AsyncTransaction] = None) -> None: :raise DocumentIDError: If the ID is not valid. """ + doc_ref = self._get_doc_ref() + if transaction is not None: - transaction.delete(self._get_doc_ref()) + print(f"\nDeleting document: {doc_ref.path}") + transaction.delete(doc_ref) + print(f"Deleted document in transaction: {doc_ref.path}\n") else: - await self._get_doc_ref().delete() + print(f"\nTransaction not provided, deleting document: {self._get_doc_ref().path}") + await doc_ref.delete() async def reload(self, transaction: Optional[AsyncTransaction] = None) -> None: """ diff --git a/firedantic/tests/tests_async/test_model.py b/firedantic/tests/tests_async/test_model.py index ce8040c..8e40652 100644 --- a/firedantic/tests/tests_async/test_model.py +++ b/firedantic/tests/tests_async/test_model.py @@ -589,7 +589,8 @@ async def delete_in_transaction( assert p.id t = get_async_transaction() - await delete_in_transaction(t, p.id) + async with t: + await delete_in_transaction(t, p.id) with pytest.raises(ModelNotFoundError): await Profile.get_by_id(p.id) From 22c641671bb39f2a85117bfb12821065b1994c2a Mon Sep 17 00:00:00 2001 From: Marissa Fisher Date: Wed, 3 Dec 2025 19:53:40 -0300 Subject: [PATCH 23/38] add admin clients support --- firedantic/_async/indexes.py | 1 + firedantic/configurations.py | 178 +++++++++++++++++++++++++---------- 2 files changed, 128 insertions(+), 51 deletions(-) diff --git a/firedantic/_async/indexes.py b/firedantic/_async/indexes.py index 51e6b67..322e1ec 100644 --- a/firedantic/_async/indexes.py +++ b/firedantic/_async/indexes.py @@ -130,6 +130,7 @@ async def set_up_composite_indexes( return operations +## Edit this to handle different project and databases async def set_up_composite_indexes_and_ttl_policies( gcloud_project: str, models: Iterable[Type[AsyncBareModel]], diff --git a/firedantic/configurations.py b/firedantic/configurations.py index 50a2471..3fcbfe6 100644 --- a/firedantic/configurations.py +++ b/firedantic/configurations.py @@ -2,6 +2,8 @@ from typing import Any, Dict, Optional, Type, Union from google.auth.credentials import Credentials +from google.api_core.client_options import ClientOptions +from google.api_core.gapic_v1.client_info import ClientInfo from google.cloud.firestore_v1 import ( AsyncClient, AsyncTransaction, @@ -9,6 +11,9 @@ CollectionReference, Transaction ) +from google.cloud.firestore_admin_v1 import FirestoreAdminClient +from google.cloud.firestore_admin_v1.services.firestore_admin import FirestoreAdminAsyncClient +from google.cloud.firestore_admin_v1.services.firestore_admin.transports.base import DEFAULT_CLIENT_INFO from pydantic import BaseModel # --- Old compatibility surface (kept for backwards compatibility) --- @@ -46,14 +51,24 @@ def get_async_transaction() -> AsyncTransaction: class ConfigItem(BaseModel): """ Holds configuration for a named Firestore connection. - Use arbitrary_types_allowed so we can store Client/AsyncClient/Credentials objects. + Clients may be provided directly, or created lazily from the stored params. """ name: str - prefix: str + prefix: str = "" project: Optional[str] = None - credentials: Optional[Credentials] = None + database: str = "(default)" + + # client objects (may be None; created lazily) client: Optional[Any] = None async_client: Optional[Any] = None + admin_client: Optional[Any] = None + async_admin_client: Optional[Any] = None + + # creation params (kept so we can lazily instantiate clients) + credentials: Optional[Credentials] = None + client_info: Optional[Any] = None + client_options: Optional[Union[Dict[str, Any], Any]] = None + admin_transport: Optional[Any] = None model_config = {"arbitrary_types_allowed": True} @@ -67,7 +82,7 @@ class Configuration: """ def __init__(self) -> None: - + # mapping name -> ConfigItem self.config: Dict[str, ConfigItem] = {} @@ -78,10 +93,56 @@ def __init__(self) -> None: name="(default)", prefix="", project=default_project, - credentials=None, - client=None, - async_client=None + database="(default)", ) + + def add( + self, + name: str = "(default)", # adding a config without a name results in overriding the default + *, + project: Optional[str] = None, + database: str = "(default)", + prefix: str = "", + client: Optional[Any] = None, + async_client: Optional[Any] = None, + admin_client: Optional[Any] = None, + async_admin_client: Optional[Any] = None, + credentials: Optional[Credentials] = None, + client_info: Optional[Any] = None, + client_options: Optional[Union[Dict[str, Any], Any]] = None, + admin_transport: Optional[Any] = None, + ) -> ConfigItem: + """ + Add a named configuration. + + You may either pass a pre-built client and/or an async_client, + or provide only project/credentials so clients will be constructed lazily. + """ + + normalized_project = self._normalize_project(project) + + # Construct clients lazily, only when they were not supplied + # if client is None and normalized_project is not None: + # client = Client(project=normalized_project, credentials=credentials) + # if async_client is None and normalized_project is not None: + # async_client = AsyncClient(project=normalized_project, credentials=credentials) + + item = ConfigItem( + name=name, + project=normalized_project, + database=database, + prefix=prefix, + client=client, + async_client=async_client, + admin_client=admin_client, + async_admin_client=async_admin_client, + credentials=credentials, + client_info=client_info, + client_options=client_options, + admin_transport=admin_transport, + ) + self.config[name] = item + return item # dict-like accessors def __getitem__(self, name: str) -> ConfigItem: @@ -101,6 +162,9 @@ def get_config(self, name: str = "(default)") -> ConfigItem: f"Configuration '{name}' not found. Available: {list(self.config.keys())}" ) from err + def get_config_names(self): + return list(self.config.keys()) + def _normalize_project(self, project: Optional[str]) -> Optional[str]: """ Convert empty-string project to None so Client(...) doesn't get empty string for project (i.e. project="") @@ -109,64 +173,76 @@ def _normalize_project(self, project: Optional[str]) -> Optional[str]: return project return environ.get("GOOGLE_CLOUD_PROJECT") or None - - """ - Add a named configuration. - - You may either pass pre-built client and/or async_client, - or provide only project/credentials so clients will be constructed here. - """ - def add( - self, - name: str = "(default)", # adding a config without a name results in overriding the default - prefix: str = "", - project: Optional[str] = None, - credentials: Optional[Credentials] = None, - client: Optional[Client] = None, - async_client: Optional[AsyncClient] = None, - ) -> ConfigItem: - - normalized_project = self._normalize_project(project) - - # Construct clients lazily, only when they were not supplied - if client is None and normalized_project is not None: - client = Client(project=normalized_project, credentials=credentials) - if async_client is None and normalized_project is not None: - async_client = AsyncClient(project=normalized_project, credentials=credentials) - - item = ConfigItem( - name=name, - prefix=prefix, - project=normalized_project, - credentials=credentials, - client=client, - async_client=async_client, - ) - self.config[name] = item - return item - - # client accessors (lazy-check and raise errors) + # sync client accessor (lazy-create) def get_client(self, name: Optional[str] = None) -> Client: resolved = name if name is not None else "(default)" cfg = self.get_config(resolved) if cfg.client is None: - # give a clear error; tests should register a mock client or provide a project - raise RuntimeError( - f"No sync client configured for config '{resolved}'. " - "Call configuration.add(..., client=...) or set GOOGLE_CLOUD_PROJECT and let add() build the client." + # don't try to create a client if we lack project info + if cfg.project is None: + raise RuntimeError( + f"No sync client configured for '{resolved}' and no project available; " + "call configuration.add(..., client=..., project=...) or set " \ + "GOOGLE_CLOUD_PROJECT and let add() build the client." + ) + cfg.client = Client( + project=cfg.project, + credentials=cfg.credentials, + client_info=cfg.client_info, + client_options=cfg.client_options, # type: ignore[arg-type] + # NOTE: modern firestore clients may accept database param in constructor; keep for future. ) return cfg.client + # async client accessor (lazy-create) def get_async_client(self, name: Optional[str] = None) -> AsyncClient: resolved = name if name is not None else "(default)" cfg = self.get_config(resolved) + if cfg.async_client is None: - raise RuntimeError( - f"No async client configured for config '{resolved}'. " - "Call configuration.add(..., async_client=...) or set GOOGLE_CLOUD_PROJECT and let add() build the client." + if cfg.project is None: + raise RuntimeError( + f"No async client configured for '{resolved}' and no project available; " + "call configuration.add(..., async_client=..., project=...) or set " \ + "GOOGLE_CLOUD_PROJECT and let add() build the client." + ) + cfg.async_client = AsyncClient( + project=cfg.project, + credentials=cfg.credentials, + client_info=cfg.client_info, + client_options=cfg.client_options, # type: ignore[arg-type] ) return cfg.async_client + + # admin client accessor (lazy-create) + def get_admin_client(self, name: Optional[str] = None) -> FirestoreAdminClient: + resolved = name if name is not None else "(default)" + cfg = self.get_config(resolved) + + if cfg.admin_client is None: + cfg.admin_client = FirestoreAdminClient( + credentials=cfg.credentials, + transport=cfg.admin_transport, + client_options=cfg.client_options, # type: ignore[arg-type] + client_info=cfg.client_info or DEFAULT_CLIENT_INFO, + ) + return cfg.admin_client + + # async admin client accessor (lazy-create) + def get_async_admin_client(self, name: Optional[str] = None) -> FirestoreAdminAsyncClient: + resolved = name if name is not None else "(default)" + cfg = self.get_config(resolved) + + if cfg.async_admin_client is None: + cfg.async_admin_client = FirestoreAdminAsyncClient( + credentials=cfg.credentials, + transport=cfg.admin_transport, + client_options=cfg.client_options, # type: ignore[arg-type] + client_info=cfg.client_info or DEFAULT_CLIENT_INFO, + ) + return cfg.async_admin_client + # transactions def get_transaction(self, name: Optional[str] = None) -> Transaction: return self.get_client(name=name).transaction() From 7a40e19e210fc624cef91b34271a5586cde26977 Mon Sep 17 00:00:00 2001 From: Marissa Fisher Date: Thu, 4 Dec 2025 12:33:11 -0300 Subject: [PATCH 24/38] adding and improving configurations unit tests --- firedantic/_async/indexes.py | 1 - firedantic/configurations.py | 6 ------ 2 files changed, 7 deletions(-) diff --git a/firedantic/_async/indexes.py b/firedantic/_async/indexes.py index 322e1ec..51e6b67 100644 --- a/firedantic/_async/indexes.py +++ b/firedantic/_async/indexes.py @@ -130,7 +130,6 @@ async def set_up_composite_indexes( return operations -## Edit this to handle different project and databases async def set_up_composite_indexes_and_ttl_policies( gcloud_project: str, models: Iterable[Type[AsyncBareModel]], diff --git a/firedantic/configurations.py b/firedantic/configurations.py index 3fcbfe6..edeaa26 100644 --- a/firedantic/configurations.py +++ b/firedantic/configurations.py @@ -120,12 +120,6 @@ def add( """ normalized_project = self._normalize_project(project) - - # Construct clients lazily, only when they were not supplied - # if client is None and normalized_project is not None: - # client = Client(project=normalized_project, credentials=credentials) - # if async_client is None and normalized_project is not None: - # async_client = AsyncClient(project=normalized_project, credentials=credentials) item = ConfigItem( name=name, From 4e10002c0c5f3cbcc7d5666a355301ce4e36843e Mon Sep 17 00:00:00 2001 From: Marissa Fisher Date: Thu, 4 Dec 2025 17:05:54 -0300 Subject: [PATCH 25/38] additions and clarifications to README --- README.md | 325 ++++++++++++++++++++++++++++------- firedantic/configurations.py | 14 +- 2 files changed, 276 insertions(+), 63 deletions(-) diff --git a/README.md b/README.md index 1d59abd..c6cb364 100644 --- a/README.md +++ b/README.md @@ -16,8 +16,41 @@ The package is available on PyPI: pip install firedantic ``` +## Quick overview + +Firedantic provides simple Pydantic-based models for Firestore, with both sync and async model classes, helpers for composite indexes and TTL policies, and an improved configuration system that supports multiple named Firestore connections. + ## Usage +### Migration Guide: Legacy `configure()` → New `Configuration` + +We introduced a new, more flexible `Configuration` registry to support multiple Firestore clients, lazy client creation, and admin clients. The legacy `configure()` helper is still supported for backwards compatibility, but it is now considered **deprecated**. Scroll below for legacy instructions. + + +### New (Recommended) Usage +```python +from firedantic.configurations import configuration + +# default config +configuration.add(prefix="app-", project="my-project") + +# extra config +configuration.add(name="billing", prefix="billing-", project="billing-project") + +# get clients +client = configuration.get_client() # sync client for "(default)" +async_client = configuration.get_async_client() # async client for "(default)" +billing_client = configuration.get_client("billing") # sync client for "billing" +``` + +Notes: + +- `configuration.add(...)` accepts either client/async_client (pre-built) or project+credentials and will lazily create clients. +- Models can declare `__db_config__ = "custom-name"` to use a named configuration or omit it to use the `"(default)"` config. +- Backwards-compatible helpers (`configure()`, `CONFIGURATIONS`) will still populate the old surface but are deprecated. + +### Old (Legacy – Still Works, Deprecated) Usage + In your application you will need to configure the firestore db client and optionally the collection prefix, which by default is empty. @@ -41,21 +74,29 @@ else: configure(client, prefix="firedantic-test-") ``` -Once that is done, you can start defining your Pydantic models, e.g: +You may also still use: +```python +from firedantic.configurations import CONFIGURATIONS + +db = CONFIGURATIONS["db"] +prefix = CONFIGURATIONS["prefix"] +``` +### Defining Models + +Once that is done, you can start defining your Pydantic models. Models are Pydantic classes that extend Firedantic’s sync Model or async AsyncModel: + +#### Sync Model Example ```python from pydantic import BaseModel - from firedantic import Model class Owner(BaseModel): - """Dummy owner Pydantic model.""" first_name: str last_name: str class Company(Model): - """Dummy company Firedantic model.""" __collection__ = "companies" company_id: str owner: Owner @@ -65,7 +106,7 @@ owner = Owner(first_name="John", last_name="Doe") company = Company(company_id="1234567-8", owner=owner) company.save() -# Prints out the firestore ID of the Company model +# Access the company id print(company.id) # Reloads model data from the database @@ -97,44 +138,22 @@ Product.find(order_by=[('unit_value', Query.ASCENDING), ('stock', Query.DESCENDI The query operators are found at [https://firebase.google.com/docs/firestore/query-data/queries#query_operators](https://firebase.google.com/docs/firestore/query-data/queries#query_operators). -### Async usage +### Async Usage Firedantic can also be used in an async way, like this: -```python -import asyncio -from os import environ -from unittest.mock import Mock - -import google.auth.credentials -from google.cloud.firestore import AsyncClient - -from firedantic import AsyncModel, configure - -# Firestore emulator must be running if using locally. -if environ.get("FIRESTORE_EMULATOR_HOST"): - client = AsyncClient( - project="firedantic-test", - credentials=Mock(spec=google.auth.credentials.Credentials), - ) -else: - client = AsyncClient() - -configure(client, prefix="firedantic-test-") - +#### Async Model Example +```python class Person(AsyncModel): __collection__ = "persons" name: str - async def main(): alice = Person(name="Alice") await alice.save() - print(f"Saved Alice as {alice.id}") bob = Person(name="Bob") await bob.save() - print(f"Saved Bob as {bob.id}") found_alice = await Person.find_one({"name": "Alice"}) print(f"Found Alice: {found_alice.id}") @@ -145,10 +164,7 @@ async def main(): print(f"Found Bob: {found_bob.id}") await alice.delete() - print("Deleted Alice") await bob.delete() - print("Deleted Bob") - if __name__ == "__main__": asyncio.run(main()) @@ -164,10 +180,8 @@ reference to using the `model_for()` method. ```python from typing import Optional, Type - from firedantic import AsyncModel, AsyncSubCollection, AsyncSubModel, ModelNotFoundError - class UserStats(AsyncSubModel): id: Optional[str] = None purchases: int = 0 @@ -176,12 +190,10 @@ class UserStats(AsyncSubModel): # Can use any properties of the "parent" model __collection_tpl__ = "users/{id}/stats" - class User(AsyncModel): __collection__ = "users" name: str - async def get_user_purchases(user_id: str, period: str = "2021") -> int: user = await User.get_by_id(user_id) stats_model: Type[UserStats] = UserStats.model_for(user) @@ -195,10 +207,141 @@ async def get_user_purchases(user_id: str, period: str = "2021") -> int: ## Composite Indexes and TTL Policies -Firedantic has support for defining composite indexes and TTL policies as well as -creating them. +Firedantic supports defining and automatically creating Composite Indexes and TTL Policies for your Firestore models. These can be created using either: + +- The new `Configuration` class (recommended) +- The legacy `configure()` method (deprecated but still supported) + +### Defining Composite Indexes + +Composite indexes are defined on your model using the `__composite_indexes__` attribute. + +It must be a list of composite index definitions created using: -### Composite indexes +- `collection_index(...)` – for single-collection queries +- `collection_group_index(...)` – for collection group queries + +Each index definition takes an arbitrary number of (field_name, order) tuples. Order must be `Query.ASCENDING` or `Query.DESCENDING`. + +#### Example: +```python +__composite_indexes__ = [ + collection_index(("content", Query.ASCENDING), ("expire", Query.DESCENDING)), + collection_group_index(("content", Query.DESCENDING), ("expire", Query.ASCENDING)), +] +``` + +### Defining TTL Policies + +TTL (Time-To-Live) policies are defined using the `__ttl_field__` attribute. + +Rules: +- The field must be a datetime object +- The field name must be assigned to `__ttl_field__` +- TTL policies cannot be created in the Firestore Emulator + +#### Example: +```python +__ttl_field__ = "expire" +``` + +#### Recommended Usage (New Configuration API) + +All index and TTL setup functions now automatically resolve: +- The correct project +- The correct database +- The correct admin client +…based on each model’s `__db_config__` + +This means you no longer need to pass projects, databases, or admin clients manually-- And further, you can maintain multiple of each within your app. + +#### Sync Example (Recommended) +```python +from datetime import datetime +from google.cloud.firestore import Query + +from firedantic import ( + Model, + collection_index, + collection_group_index, + get_all_subclasses, + set_up_composite_indexes_and_ttl_policies, +) +from firedantic.configurations import configuration + +class ExpiringModel(Model): + __collection__ = "expiringModel" + __ttl_field__ = "expire" + __composite_indexes__ = [ + collection_index(("content", Query.ASCENDING), ("expire", Query.DESCENDING)), + collection_group_index(("content", Query.DESCENDING), ("expire", Query.ASCENDING)), + ] + + expire: datetime + content: str + +def main(): + configuration.add( + name="(default)", + project="my-project", + prefix="firedantic-test-", + # credentials=... optional + ) + + set_up_composite_indexes_and_ttl_policies( + models=get_all_subclasses(Model), + ) + +if __name__ == "__main__": + main() +``` + +#### Async Example (Recommended) + +```python +import asyncio +from datetime import datetime +from google.cloud.firestore import Query + +from firedantic import ( + AsyncModel, + async_set_up_composite_indexes_and_ttl_policies, + collection_index, + collection_group_index, + get_all_subclasses, +) +from firedantic.configurations import configuration + +class ExpiringModel(AsyncModel): + __collection__ = "expiringModel" + __ttl_field__ = "expire" + __composite_indexes__ = [ + collection_index(("content", Query.ASCENDING), ("expire", Query.DESCENDING)), + collection_group_index(("content", Query.DESCENDING), ("expire", Query.ASCENDING)), + ] + + expire: datetime + content: str + +async def main(): + configuration.add( + project="my-project", + prefix="firedantic-test-", + ) + + await async_set_up_composite_indexes_and_ttl_policies( + models=get_all_subclasses(AsyncModel), + ) + +if __name__ == "__main__": + asyncio.run(main()) +``` + +#### Legacy Usage (Still Supported, Deprecated) + +The old method using `configure()` and manually passing Firestore Admin clients is still supported but deprecated. + +##### Composite indexes Composite indexes of a collection are defined in `__composite_indexes__`, which is a list of all indexes to be created. @@ -213,7 +356,7 @@ to create indexes. For more details, see the example further down. -### TTL Policies +##### TTL Policies The field used for the TTL policy should be a datetime field and the name of the field should be defined in `__ttl_field__`. The `set_up_ttl_policies` and @@ -221,31 +364,32 @@ should be defined in `__ttl_field__`. The `set_up_ttl_policies` and Note: The TTL policies can not be set up in the Firestore emulator. -### Examples +##### Examples Below are examples (both sync and async) to show how to use Firedantic to set up -composite indexes and TTL policies. +composite indexes and TTL policies with the legacy format. The examples use `async_set_up_composite_indexes_and_ttl_policies` and `set_up_composite_indexes_and_ttl_policies` functions to set up both composite indexes and TTL policies. However, you can use separate functions to set up only either one of them. -#### Composite Index and TTL Policy Example (sync) +### Legacy Sync Example ```python from datetime import datetime +from google.cloud.firestore import Client, Query +from google.cloud.firestore_admin_v1 import FirestoreAdminClient + from firedantic import ( + Model, collection_index, collection_group_index, configure, get_all_subclasses, - Model, set_up_composite_indexes_and_ttl_policies, ) -from google.cloud.firestore import Client, Query -from google.cloud.firestore_admin_v1 import FirestoreAdminClient class ExpiringModel(Model): @@ -267,18 +411,17 @@ def main(): models=get_all_subclasses(Model), client=FirestoreAdminClient(), ) - # or use set_up_composite_indexes / set_up_ttl_policies functions separately - if __name__ == "__main__": main() ``` -#### Composite Index and TTL Policy Example (async) - +### Legacy Async Example ```python import asyncio from datetime import datetime +from google.cloud.firestore import AsyncClient, Query +from google.cloud.firestore_admin_v1.services.firestore_admin import FirestoreAdminAsyncClient from firedantic import ( AsyncModel, @@ -288,10 +431,6 @@ from firedantic import ( configure, get_all_subclasses, ) -from google.cloud.firestore import AsyncClient, Query -from google.cloud.firestore_admin_v1.services.firestore_admin import ( - FirestoreAdminAsyncClient, -) class ExpiringModel(AsyncModel): @@ -313,8 +452,6 @@ async def main(): models=get_all_subclasses(AsyncModel), client=FirestoreAdminAsyncClient(), ) - # or await async_set_up_composite_indexes / async_set_up_ttl_policies separately - if __name__ == "__main__": asyncio.run(main()) @@ -324,7 +461,7 @@ if __name__ == "__main__": Firedantic has basic support for [Firestore Transactions](https://firebase.google.com/docs/firestore/manage-data/transactions). -The following methods can be used in a transaction: +The following methods can be used in a transaction for both **sync** and **async** models: - `Model.delete(transaction=transaction)` - `Model.find_one(transaction=transaction)` @@ -337,8 +474,76 @@ The following methods can be used in a transaction: When using transactions, note that read operations must come before write operations. +## Recommended Usage (New Configuration Class) + +With the new `Configuration` system, transactions automatically use the configured: +- Project +- Database +- Client +based on the active configuration. + + ### Transaction examples +#### Sync Transaction Example (Recommended) +```python +from firedantic import Model, get_transaction +from firedantic.configurations import configuration +from google.cloud.firestore_v1 import transactional, Transaction + + +# Configure once +configuration.add( + project="firedantic-test", + prefix="firedantic-test-", +) + + +class City(Model): + __collection__ = "cities" + population: int + + def increment_population(self, increment: int = 1): + @transactional + def _increment_population(transaction: Transaction) -> None: + self.reload(transaction=transaction) + self.population += increment + self.save(transaction=transaction) + + t = get_transaction() + _increment_population(transaction=t) + + +def main(): + @transactional + def decrement_population( + transaction: Transaction, city: City, decrement: int = 1 + ): + city.reload(transaction=transaction) + city.population = max(0, city.population - decrement) + city.save(transaction=transaction) + + c = City(id="SF", population=1) + c.save() + + c.increment_population(increment=1) + assert c.population == 2 + + t = get_transaction() + decrement_population(transaction=t, city=c, decrement=5) + assert c.population == 0 + + +if __name__ == "__main__": + main() +``` + +### Legacy Usage (Still Supported, Deprecated) + +The original configure() method and manual client wiring still works, but is now deprecated in favor of `Configuration`. + +#### Legacy Async Transaction Example + In this example (async and sync version of it below), we are updating a `City` to increment and decrement the population of it, both using an instance method and a standalone function. Please note that the `@async_transactional` and `@transactional` @@ -346,8 +551,6 @@ decorators always expect the first argument of the wrapped function to be `trans i.e. you can not directly wrap an instance method that has `self` as the first argument or a class method that has `cls` as the first argument. -#### Transaction example (async) - ```python import asyncio from os import environ @@ -409,7 +612,7 @@ if __name__ == "__main__": asyncio.run(main()) ``` -#### Transaction example (sync) +#### Legacy Sync Transaction Example ```python from os import environ diff --git a/firedantic/configurations.py b/firedantic/configurations.py index edeaa26..330aa06 100644 --- a/firedantic/configurations.py +++ b/firedantic/configurations.py @@ -1,5 +1,6 @@ from os import environ from typing import Any, Dict, Optional, Type, Union +import warnings from google.auth.credentials import Credentials from google.api_core.client_options import ClientOptions @@ -27,6 +28,15 @@ def configure(client: Union[Client, AsyncClient], prefix: str = "") -> None: :param db: The firestore client instance. :param prefix: The prefix to use for collection names. """ + + # soft deprecation notice for users (no stacktrace) + warnings.warn( + "firedantic.configure(client, ...) is deprecated and will be removed in a " + "future release. Use firedantic.configurations.configuration.add(...) instead.", + DeprecationWarning, + stacklevel=2, + ) + if isinstance(client, AsyncClient): configuration.add(name="(default)", prefix=prefix, async_client=client) else: @@ -38,12 +48,12 @@ def configure(client: Union[Client, AsyncClient], prefix: str = "") -> None: def get_transaction() -> Transaction: - """Backward-compatible transaction getter (sync).""" + """Backward-compatible sync transaction getter, forwards to new implementation.""" return configuration.get_transaction() def get_async_transaction() -> AsyncTransaction: - """Backward-compatible async transaction getter.""" + """Backward-compatible async transaction getter, forwards to new implementation.""" return configuration.get_async_transaction() From d5f021db070881c40a340ce22327f6e2c0097421 Mon Sep 17 00:00:00 2001 From: Marissa Fisher Date: Fri, 5 Dec 2025 16:21:38 -0300 Subject: [PATCH 26/38] testing readme examples, saving code for reference --- README.md | 8 +- firedantic/_async/model.py | 4 +- integration_tests/test_integration_all.py | 554 ++++++++++++++++++++++ 3 files changed, 559 insertions(+), 7 deletions(-) create mode 100644 integration_tests/test_integration_all.py diff --git a/README.md b/README.md index c6cb364..ed53dfa 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ pip install firedantic ## Quick overview -Firedantic provides simple Pydantic-based models for Firestore, with both sync and async model classes, helpers for composite indexes and TTL policies, and an improved configuration system that supports multiple named Firestore connections. +Firedantic provides simple Pydantic-based models for Firestore, with both sync and async model classes, helpers for composite indexes and TTL policies, and a new configuration system that supports multiple named Firestore connections. ## Usage @@ -375,7 +375,7 @@ and TTL policies. However, you can use separate functions to set up only either them. -### Legacy Sync Example +###### Legacy Sync Example ```python from datetime import datetime from google.cloud.firestore import Client, Query @@ -416,7 +416,7 @@ if __name__ == "__main__": main() ``` -### Legacy Async Example +###### Legacy Async Example ```python import asyncio from datetime import datetime @@ -474,7 +474,7 @@ The following methods can be used in a transaction for both **sync** and **async When using transactions, note that read operations must come before write operations. -## Recommended Usage (New Configuration Class) +### Recommended Usage (New Configuration Class) With the new `Configuration` system, transactions automatically use the configured: - Project diff --git a/firedantic/_async/model.py b/firedantic/_async/model.py index 829dc1a..c095ba5 100644 --- a/firedantic/_async/model.py +++ b/firedantic/_async/model.py @@ -174,11 +174,9 @@ async def delete(self, transaction: Optional[AsyncTransaction] = None) -> None: doc_ref = self._get_doc_ref() if transaction is not None: - print(f"\nDeleting document: {doc_ref.path}") transaction.delete(doc_ref) - print(f"Deleted document in transaction: {doc_ref.path}\n") else: - print(f"\nTransaction not provided, deleting document: {self._get_doc_ref().path}") + # print(f"\nTransaction not provided, deleting document: {self._get_doc_ref().path}") await doc_ref.delete() async def reload(self, transaction: Optional[AsyncTransaction] = None) -> None: diff --git a/integration_tests/test_integration_all.py b/integration_tests/test_integration_all.py new file mode 100644 index 0000000..c132818 --- /dev/null +++ b/integration_tests/test_integration_all.py @@ -0,0 +1,554 @@ +#!/usr/bin/env python3 +""" +integration_test_readme_full.py + +Integration-style test script covering README examples: +- new Configuration API and legacy configure() +- sync & async model operations: save / reload / find / delete +- transactions (sync & async) +- subcollections +- multiple named configurations / per-model __db_config__ +- composite indexes + TTL policies (attempted, skipped if admin client not available) + +Run with a running Firestore emulator: +export FIRESTORE_EMULATOR_HOST="127.0.0.1:8686" +python test_integration_all.py +""" +from __future__ import annotations + +import asyncio +import os +import sys +import time +from datetime import datetime, timedelta +from typing import Optional +from unittest.mock import Mock + +import google.auth.credentials + +# require firedantic package (adjust path or install locally) +try: + # recommended imports + from firedantic.common import IndexField + from firedantic.configurations import configuration, CONFIGURATIONS, configure + from firedantic import ( + Model, + AsyncModel, + collection_index, + collection_group_index, + ## IndexField, + set_up_composite_indexes_and_ttl_policies, + async_set_up_composite_indexes_and_ttl_policies, + get_all_subclasses, + get_transaction, + get_async_transaction, + ) + from google.cloud.firestore import Client, Query + # from google.cloud.firestore import AsyncClient as AsyncClientSyncName # alias to avoid name clash + from google.cloud.firestore_v1 import AsyncClient, transactional, Transaction, async_transactional + from google.cloud.firestore_admin_v1.services.firestore_admin import FirestoreAdminClient + from google.cloud.firestore_admin_v1.services.firestore_admin import FirestoreAdminAsyncClient + +except Exception as e: + print("Failed to import required modules. Ensure `firedantic` is importable and google-cloud-* libs are installed.") + raise + + +EMULATOR = os.environ.get("FIRESTORE_EMULATOR_HOST") +USE_EMULATOR = bool(EMULATOR) + +def must(msg: str): + print("ERROR:", msg) + sys.exit(2) + +def info(msg: str): + print(msg) + +# ----------------------- +# Helpers +# ----------------------- +def assert_or_exit(cond: bool, message: str): + if not cond: + print("ASSERTION FAILED:", message) + raise AssertionError(message) + +def cleanup_sync_collection_by_model(model_cls: type[Model]) -> None: + """Delete everything from a model's collection (sync).""" + info(f"Cleaning collection for sync model {model_cls.__name__}") + # find all docs and delete + for obj in model_cls.find({}): + obj.delete() + +async def cleanup_async_collection_by_model(model_cls: type[AsyncModel]) -> None: + info(f"Cleaning collection for async model {model_cls.__name__}") + items = await model_cls.find({}) + for it in items: + await it.delete() + +# admin helpers +def get_admin_client_if_possible(): + """Return admin client or None (if emulator/admin not available).""" + try: + if USE_EMULATOR: + # Emulator usually accepts the same transport; credentials can be a mock + creds = Mock(spec=google.auth.credentials.Credentials) + client = FirestoreAdminClient(credentials=creds) + else: + client = FirestoreAdminClient() + return client + except Exception as e: + info(f"FirestoreAdminClient unavailable/skipping admin steps: {e}") + return None + +async def get_async_admin_client_if_possible(): + try: + if USE_EMULATOR: + creds = Mock(spec=google.auth.credentials.Credentials) + client = FirestoreAdminAsyncClient(credentials=creds) + else: + client = FirestoreAdminAsyncClient() + return client + except Exception as e: + info(f"FirestoreAdminAsyncClient unavailable/skipping admin steps: {e}") + return None + +# ----------------------- +# Test cases +# ----------------------- +def test_new_config_sync_flow(): + info("\n=== new Configuration API: sync flow ===") + # register a default configuration (use mock creds for emulator) + mock_creds = Mock(spec=google.auth.credentials.Credentials) + + # default config + configuration.add(prefix="app-", project="my-project") + + # extra config + configuration.add(name="billing", prefix="billing-", project="billing-project") + + # get clients + client = configuration.get_client() # sync client for "(default)" + # async_client = configuration.get_async_client() # async client for "(default)" + billing_client = configuration.get_client("billing") # sync client for "billing" + + class Billing(Model): + __db_config__ = "billing" + __collection__ = "billing_accounts" + billing_id: str + name: str + + # cleanup + cleanup_sync_collection_by_model(Billing) + + # create and save + b = Billing(billing_id="9901981", name="Firedantic Billing") + b.save() + print(f"billing.id after save: {b.id}") + assert_or_exit(b.id is not None, "Billing save did not set id") + + # reload + b.reload() + assert_or_exit(b.billing_id == "9901981", "billing_id mismatch after reload") + + # find + results = Billing.find({"name": "Firedantic Billing"}) + assert_or_exit(isinstance(results, list), "find returned not a list") + assert_or_exit(any(r.billing_id == "9901981" for r in results), "find did not return saved billing account") + + # delete + b.delete() + time.sleep(0.1) # allow emulator a bit of time to delete + remaining = Billing.find({"billing_id": "9901981"}) + assert_or_exit(all(r.company_id != "9901981" for r in remaining), "delete failed") + + print("new config sync flow OK") + + +def test_legacy_config_sync_flow(): + info("\n=== legacy configure() sync flow ===") + # legacy configure() should still populate CONFIGURATIONS and work + mock_creds = Mock(spec=google.auth.credentials.Credentials) + # create a sync Client for legacy configure + if USE_EMULATOR: + client = Client(project="legacy-project", credentials=mock_creds) + else: + client = Client() + + configure(client, prefix="legacy-sync-") + assert_or_exit(CONFIGURATIONS["prefix"].startswith("legacy-sync-"), "legacy configure prefix mismatch") + + # use same model pattern as earlier + class LegacyOwner(Model): + first_name: str + last_name: str + + class LegacyCompany(Model): + __collection__ = "companies" + company_id: str + owner: LegacyOwner + + # cleanup and run + cleanup_sync_collection_by_model(LegacyCompany) + o = LegacyOwner(first_name="L", last_name="User") + comp = LegacyCompany(company_id="L-1", owner=o) + comp.save() + comp.reload() + assert_or_exit(comp.company_id == "L-1", "legacy reload failed") + comp.delete() + info("legacy sync flow OK") + + +async def test_new_config_async_flow(): + info("\n=== new Configuration API: async flow ===") + mock_creds = Mock(spec=google.auth.credentials.Credentials) + configuration.add(prefix="integ-async-", project="test-project", credentials=mock_creds) + + class Owner(AsyncModel): + first_name: str + last_name: str + + class Company(AsyncModel): + __collection__ = "companies" + company_id: str + owner: Owner + + await cleanup_async_collection_by_model(Company) + + owner = Owner(first_name="Alice", last_name="Smith") + c = Company(company_id="A-1", owner=owner) + await c.save() + print(f"company.id after save: {c.id}") + assert_or_exit(c.id is not None, "async save didn't set id") + + await c.reload() + assert_or_exit(c.company_id == "A-1", "async reload company_id mismatch") + assert_or_exit(c.owner.first_name == "Alice", "async owner first name mismatch") + + found = await Company.find({"company_id": "A-1"}) + assert_or_exit(len(found) >= 1 and found[0].company_id == "A-1", "async find failed") + + await c.delete() + remains = await Company.find({"company_id": "A-1"}) + assert_or_exit(all(x.company_id != "A-1" for x in remains), "async delete failed") + + print("new config async flow OK") + + + +async def test_legacy_config_async_flow(): + info("\n=== legacy configure() async flow ===") + mock_creds = Mock(spec=google.auth.credentials.Credentials) + # create async client and call legacy configure + if USE_EMULATOR: + aclient = AsyncClient(project="legacy-project", credentials=mock_creds) + else: + aclient = AsyncClient() + # legacy configure supports AsyncClient and sets CONFIGURATIONS + configure(aclient, prefix="legacy-async-") + assert_or_exit(CONFIGURATIONS["prefix"].startswith("legacy-async-"), "legacy async configure failed") + + class LOwner(AsyncModel): + first_name: str + last_name: str + + class LCompany(AsyncModel): + __collection__ = "companies" + company_id: str + owner: LOwner + + await cleanup_async_collection_by_model(LCompany) + o = LOwner(first_name="LA", last_name="User") + c = LCompany(company_id="LA-1", owner=o) + await c.save() + await c.reload() + assert_or_exit(c.company_id == "LA-1", "legacy async reload failed") + await c.delete() + info("legacy async flow OK") + +def test_subcollections_and_model_for(): + info("\n=== subcollections (model_for) ===") + # ensure configuration exists + mock_creds = Mock(spec=google.auth.credentials.Credentials) + configuration.add(prefix="sc-", project="test-project", credentials=mock_creds) + + class Parent(Model): + __collection__ = "parents" + name: str + + class Stats(Model): + __collection__ = None # will be supplied via Collection + __document_id__ = "id" + purchases: int = 0 + + class Collection: + __collection_tpl__ = "parents/{id}/stats" + + # create parent and get submodel + cleanup_sync_collection_by_model(Parent) + p = Parent(name="P1") + p.save() + # build subcollection model class for parent + StatsCollection = type("StatsForParent", (Stats,), {}) + # attach collection tpl like library would; simpler: use provided pattern via utility in README (model_for pattern) + # Here we emulate model_for behavior: + StatsCollection.__collection_cls__ = Stats.Collection + StatsCollection.__collection__ = Stats.Collection.__collection_tpl__.format(id=p.id) + StatsCollection.__document_id__ = "id" + + s = StatsCollection(purchases=3) + s.save() + # find back + found = StatsCollection.find({"purchases": 3}) + assert_or_exit(any(x.purchases == 3 for x in found), "subcollection save/find failed") + # cleanup + for x in StatsCollection.find({}): + x.delete() + for x in Parent.find({}): + x.delete() + info("subcollections OK") + +def test_transactions_sync(): + info("\n=== sync transactions ===") + # mock_creds = Mock(spec=google.auth.credentials.Credentials) + + # Configure once + configuration.add( + project="firedantic-test", + prefix="firedantic-test-", + ) + + class City(Model): + __collection__ = "cities" + population: int = 0 + + def increment_population(self, increment: int = 1): + @transactional + def _increment_population(transaction: Transaction) -> None: + self.reload(transaction=transaction) + self.population += increment + self.save(transaction=transaction) + + t = get_transaction() + _increment_population(transaction=t) + + # cleanup + cleanup_sync_collection_by_model(City) + + @transactional + def decrement_population( + transaction: Transaction, city: City, decrement: int = 1 + ): + city.reload(transaction=transaction) + city.population = max(0, city.population - decrement) + city.save(transaction=transaction) + + # create object + c = City(id="SF", population=1) + c.save() + + c.increment_population(increment=2) + assert c.population == 3 + + t = get_transaction() + decrement_population(transaction=t, city=c, decrement=1) + assert c.population == 2 + + # reload + c.reload() + assert_or_exit(c.population == 2, "sync transaction increment failed") + + # delete + c.delete() + print("sync transactions OK") + +async def test_transactions_async(): + info("\n=== async transactions ===") + + # Configure once + configuration.add( + project="async-tx-test", + prefix="async-tx-test-", + ) + + class CityA(AsyncModel): + __collection__ = "cities" + population: int = 0 + + await cleanup_async_collection_by_model(CityA) + + @async_transactional + async def increment_async(tx, city_id: str, inc: int = 1): + c = await CityA.get_by_id(city_id, transaction=tx) + c.population += inc + await c.save(transaction=tx) + + # create ans save object + c = CityA(id="AC1", population=2) + await c.save() + + # increment with transaction + t = get_async_transaction() + await increment_async(transaction=t, city_id=c.id, inc=3) + + # reload + await c.reload() + assert_or_exit(c.population == 5, "async transaction failed") + + # delete + await c.delete() + print("async transactions OK") + +def test_multi_config_usage(): + info("\n=== multi-config usage ===") + + mock_creds = Mock(spec=google.auth.credentials.Credentials) + + # default config + configuration.add(prefix="multi-default-", project="proj-default", credentials=mock_creds) + + # billing config + configuration.add(name="billing", prefix="multi-billing-", project="proj-billing", credentials=mock_creds) + + class CompanyDefault(Model): + __collection__ = "companies" + company_id: str + + class BillingAccount(Model): + __db_config__ = "billing" + __collection__ = "billing_accounts" + billing_id: str + name: str + + # cleanup + cleanup_sync_collection_by_model(CompanyDefault) + cleanup_sync_collection_by_model(BillingAccount) + + # save to default + cd = CompanyDefault(company_id="MD-1") + cd.save() + cd.reload() + assert_or_exit(cd.company_id == "MD-1", "default config save/reload failed") + + # save to billing (model defines __db_config__) + ba = BillingAccount(billing_id="B-1", name="Bill Co") + ba.save() # should use billing config + ba.reload() + assert_or_exit(ba.billing_id == "B-1", "billing model save/reload failed") + + # find calls (default vs billing) + found_default = CompanyDefault.find({"company_id": "MD-1"}) + found_billing = BillingAccount.find({"billing_id": "B-1"}) + assert_or_exit(any(x.company_id == "MD-1" for x in found_default), "find default failed") + assert_or_exit(any(x.billing_id == "B-1" for x in found_billing), "find billing failed") + + # cleanup + cleanup_sync_collection_by_model(CompanyDefault) + cleanup_sync_collection_by_model(BillingAccount) + info("multi-config usage OK") + +def test_indexes_and_ttl_sync(): + info("\n=== composite indexes and TTL (sync) ===") + + admin_client = get_admin_client_if_possible() + if not admin_client: + info("Skipping sync index/TTL setup (admin client not available).") + return + + class ExpiringModel(Model): ## uses "(default)" config name + __collection__ = "expiringModel" + __ttl_field__ = "expire" + __composite_indexes__ = [ + collection_index(IndexField("content", Query.ASCENDING), IndexField("expire", Query.DESCENDING)), + collection_group_index(IndexField("content", Query.DESCENDING), IndexField("expire", Query.ASCENDING)), + ] + content: str + expire: datetime + + # register config with admin client (best-effort) + mock_creds = Mock(spec=google.auth.credentials.Credentials) + configuration.add(prefix="idx-sync-", project="proj-idx", credentials=mock_creds, client=None, async_client=None, admin_client=admin_client) + + try: + # call setup function (may require admin permission; emulator may accept) + project = configuration.get_config().project + if not USE_EMULATOR: + # real GCP -> try to set up indexes + set_up_composite_indexes_and_ttl_policies(gcloud_project=project, models=get_all_subclasses(Model)) + info("sync index/ttl setup called (no exception)") + else: + # emulator -> skip (or log) + print("Skipping index/TTL setup when using emulator.") + + except Exception as e: + info(f"index/ttl setup raised (will continue): {e}") + +async def test_indexes_and_ttl_async(): + info("\n=== composite indexes and TTL (async) ===") + admin_client = await get_async_admin_client_if_possible() + if not admin_client: + info("Skipping async index/TTL setup (async admin client not available).") + return + + class ExpiringAsync(AsyncModel): + __collection__ = "expiringModel" + __ttl_field__ = "expire" + __composite_indexes__ = [ + collection_index(IndexField("content", Query.ASCENDING), IndexField("expire", Query.DESCENDING)), + ] + content: str + expire: datetime + + mock_creds = Mock(spec=google.auth.credentials.Credentials) + configuration.add(prefix="idx-async-", project="proj-idx", credentials=mock_creds) + + try: + project = configuration.get_config().project + if not USE_EMULATOR: + # real GCP -> try to set up indexes + await async_set_up_composite_indexes_and_ttl_policies(gcloud_project=project, models=get_all_subclasses(AsyncModel)) + info("async index/ttl setup called (no exception)") + else: + # emulator -> skip (or log) + print("Skipping index/TTL setup when using emulator.") + + except Exception as e: + info(f"async index/ttl setup raised (will continue): {e}") + +# ----------------------- +# Runner +# ----------------------- +def main(): + info("Beginning full integration test following README examples") + if not USE_EMULATOR: + info("WARNING: FIRESTORE_EMULATOR_HOST not set — this script is designed for emulator runs; continue with caution.") + + # run sync new config flow + test_new_config_sync_flow() + + # legacy sync flow + test_legacy_config_sync_flow() + + # run async new config flow + asyncio.run(test_new_config_async_flow()) ## Currently deletion fails! + + # legacy async flow + asyncio.run(test_legacy_config_async_flow()) + + # subcollections + test_subcollections_and_model_for() + + # transactions + test_transactions_sync() + asyncio.run(test_transactions_async()) + + # multi-config usage + test_multi_config_usage() + + # indexes & ttl (sync & async) (best-effort) + test_indexes_and_ttl_sync() + asyncio.run(test_indexes_and_ttl_async()) + + info("\nIntegration README full test finished successfully.") + +if __name__ == "__main__": + main() From d3e8e1a3f5de1886de13c39b70fc00c7055d30a2 Mon Sep 17 00:00:00 2001 From: Marissa Fisher Date: Fri, 5 Dec 2025 16:27:57 -0300 Subject: [PATCH 27/38] fixing configure clients integ test --- integration_tests/README.md | 4 +++- ..._db_client.py => configure_firestore_db_clients.py} | 10 +++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) rename integration_tests/{configure_firestore_db_client.py => configure_firestore_db_clients.py} (93%) diff --git a/integration_tests/README.md b/integration_tests/README.md index edca292..fb164f7 100644 --- a/integration_tests/README.md +++ b/integration_tests/README.md @@ -1,7 +1,7 @@ # integration_tests/README.md ## Overview -This folder contains three integration test files. This README explains how to run them and what each file is for. +This folder contains integration test files. This README explains how to run them and what each file is for. ## Prerequisites - Python 3.13 @@ -16,6 +16,7 @@ cd firedantic/integration_tests - `integration_tests/configure_firestore_db_clients.py` — Purpose: shows how to create and connect to various db clients. - `integration_tests/full_sync_flow.py` — Purpose: configures clients, saves data to db, finds the data, and deletes all data in a sync fashion. - `integration_tests/full_async_flow.py` — Purpose: configures async clients, saves data to db, finds the data, and deletes all data in an async fashion. +- `integration_tests/test_integration_all.py` - Purpose: intended to run all examples shown in the readme, new and legacy. ## How to run @@ -23,6 +24,7 @@ Run each individual test file: - `python configure_firestore_db_clients.py` - `python full_sync_flow.py` - `python full_async_flow.py` + - `python test_integration_all.py` ## Environment and configuration diff --git a/integration_tests/configure_firestore_db_client.py b/integration_tests/configure_firestore_db_clients.py similarity index 93% rename from integration_tests/configure_firestore_db_client.py rename to integration_tests/configure_firestore_db_clients.py index 7830764..8819044 100644 --- a/integration_tests/configure_firestore_db_client.py +++ b/integration_tests/configure_firestore_db_clients.py @@ -91,9 +91,9 @@ def configure_multiple_clients(): assert default.prefix == "firedantic-test-" assert default.project == "firedantic-test" - # Client objects exist - assert default.client is not None - assert default.async_client is not None + # Client objects exist when called + assert config.get_client() is not None + assert config.get_async_client() is not None # Types are correct assert isinstance(default.client, Client) @@ -116,6 +116,10 @@ def configure_multiple_clients(): assert billing.prefix == "test-billing-" assert billing.project == "test-billing" + # Client objects exist when called + assert config.get_client("billing") is not None + assert config.get_async_client("billing") is not None + # Correct client types assert isinstance(billing.client, Client) assert isinstance(billing.async_client, AsyncClient) From 6effb6ad526c72658c5e19687add4e4abe2efe83 Mon Sep 17 00:00:00 2001 From: Marissa Fisher Date: Mon, 8 Dec 2025 13:20:15 -0300 Subject: [PATCH 28/38] returning some readme aspects to how it was --- README.md | 32 +------------------------------- start_emulator.sh | 2 -- 2 files changed, 1 insertion(+), 33 deletions(-) diff --git a/README.md b/README.md index ed53dfa..4e50322 100644 --- a/README.md +++ b/README.md @@ -678,7 +678,7 @@ if __name__ == "__main__": PRs are welcome! -### Prequisites: +### Prerequisites: - [pipx](https://pipx.pypa.io/stable/installation/) (to install poetry, if needed) - `pipx install poetry` @@ -711,41 +711,11 @@ start_emulator To run tests locally, you should first: -Set the host: - -```bash -export FIRESTORE_EMULATOR_HOST= -``` - -Setup a virtual environment(not required but will be easier to run poetry like so): - -```bash -python -m venv firedantic -source firedantic/bin/activate -cd firedantic/tests -``` - -Run pytests within `tests` directory: -```bash -pip install pytest -pip install pytest-asyncio -pytest . -``` - -Poetry is configured to run on the published version of firedantic and will overwrite changes. -Not recommended for local development. For public use, poetry is recommended: -Run poetry within `tests` directory: - ```bash poetry install poetry run invoke test ``` -To run a singular test: -```bash -poetry run pytest tests_sync/.py:: -``` - \*Note, the emulator must be set and running for all tests to pass. ### About sync and async versions of library diff --git a/start_emulator.sh b/start_emulator.sh index 0913753..c559ecc 100755 --- a/start_emulator.sh +++ b/start_emulator.sh @@ -1,4 +1,2 @@ #!/usr/bin/env sh -export GOOGLE_CLOUD_PROJECT=test-project - exec firebase -P firedantic-test emulators:start --only firestore From dc5b588caf1174a8bafd2ff48df927104fa10d28 Mon Sep 17 00:00:00 2001 From: Marissa Fisher Date: Mon, 8 Dec 2025 13:23:09 -0300 Subject: [PATCH 29/38] fixing prereqs --- README.md | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 4e50322..6e5093e 100644 --- a/README.md +++ b/README.md @@ -680,12 +680,11 @@ PRs are welcome! ### Prerequisites: -- [pipx](https://pipx.pypa.io/stable/installation/) (to install poetry, if needed) -- `pipx install poetry` -- `pip install pre-commit` -- [nvm](https://github.com/nvm-sh/nvm) -- `npm` (for firebase CLI installation) -- [java](http://www.java.com) via `brew install openjdk@21` +- [Node 22](https://nodejs.org/en/download) +- [Pre-commit](https://pre-commit.com/#install) +- [Python 3.10+](https://wiki.python.org/moin/BeginnersGuide/Download) +- [Poetry](https://python-poetry.org/docs/#installation) +- [Google Firebase Emulator Suite](https://firebase.google.com/docs/emulator-suite/install_and_configure#install_the_local_emulator_suite) ### Running Firestore emulator @@ -707,7 +706,7 @@ Run the Firestore emulator with a predictable port: start_emulator ``` -### Tests +### Running Tests To run tests locally, you should first: From bdab1e095101c35e192d4a924b108fecf60c86d6 Mon Sep 17 00:00:00 2001 From: Marissa Fisher Date: Mon, 8 Dec 2025 14:48:13 -0300 Subject: [PATCH 30/38] updated unasync.py and instructions --- README.md | 62 +++++-- firedantic/__init__.py | 2 +- firedantic/_async/indexes.py | 16 +- firedantic/_async/model.py | 36 ++-- firedantic/_sync/indexes.py | 22 ++- firedantic/_sync/model.py | 111 ++++++------ firedantic/configurations.py | 70 +++++--- firedantic/tests/tests_async/conftest.py | 3 +- firedantic/tests/tests_async/test_indexes.py | 37 +++- firedantic/tests/tests_async/test_model.py | 1 + firedantic/tests/tests_sync/conftest.py | 2 +- firedantic/tests/tests_sync/test_indexes.py | 56 +++--- firedantic/tests/tests_sync/test_model.py | 24 +-- integration_tests/README.md | 42 +++-- .../configure_firestore_db_clients.py | 13 +- integration_tests/full_async_flow.py | 93 +++++++--- integration_tests/full_sync_flow.py | 86 ++++++--- integration_tests/test_integration_all.py | 163 +++++++++++++----- unasync.py | 3 + 19 files changed, 537 insertions(+), 305 deletions(-) diff --git a/README.md b/README.md index 6e5093e..acf3d1b 100644 --- a/README.md +++ b/README.md @@ -18,16 +18,21 @@ pip install firedantic ## Quick overview -Firedantic provides simple Pydantic-based models for Firestore, with both sync and async model classes, helpers for composite indexes and TTL policies, and a new configuration system that supports multiple named Firestore connections. +Firedantic provides simple Pydantic-based models for Firestore, with both sync and async +model classes, helpers for composite indexes and TTL policies, and a new configuration +system that supports multiple named Firestore connections. ## Usage ### Migration Guide: Legacy `configure()` → New `Configuration` -We introduced a new, more flexible `Configuration` registry to support multiple Firestore clients, lazy client creation, and admin clients. The legacy `configure()` helper is still supported for backwards compatibility, but it is now considered **deprecated**. Scroll below for legacy instructions. - +We introduced a new, more flexible `Configuration` registry to support multiple +Firestore clients, lazy client creation, and admin clients. The legacy `configure()` +helper is still supported for backwards compatibility, but it is now considered +**deprecated**. Scroll below for legacy instructions. ### New (Recommended) Usage + ```python from firedantic.configurations import configuration @@ -45,9 +50,12 @@ billing_client = configuration.get_client("billing") # sync client for "billing" Notes: -- `configuration.add(...)` accepts either client/async_client (pre-built) or project+credentials and will lazily create clients. -- Models can declare `__db_config__ = "custom-name"` to use a named configuration or omit it to use the `"(default)"` config. -- Backwards-compatible helpers (`configure()`, `CONFIGURATIONS`) will still populate the old surface but are deprecated. +- `configuration.add(...)` accepts either client/async_client (pre-built) or + project+credentials and will lazily create clients. +- Models can declare `__db_config__ = "custom-name"` to use a named configuration or + omit it to use the `"(default)"` config. +- Backwards-compatible helpers (`configure()`, `CONFIGURATIONS`) will still populate the + old surface but are deprecated. ### Old (Legacy – Still Works, Deprecated) Usage @@ -75,6 +83,7 @@ configure(client, prefix="firedantic-test-") ``` You may also still use: + ```python from firedantic.configurations import CONFIGURATIONS @@ -84,9 +93,11 @@ prefix = CONFIGURATIONS["prefix"] ### Defining Models -Once that is done, you can start defining your Pydantic models. Models are Pydantic classes that extend Firedantic’s sync Model or async AsyncModel: +Once that is done, you can start defining your Pydantic models. Models are Pydantic +classes that extend Firedantic’s sync Model or async AsyncModel: #### Sync Model Example + ```python from pydantic import BaseModel from firedantic import Model @@ -207,7 +218,8 @@ async def get_user_purchases(user_id: str, period: str = "2021") -> int: ## Composite Indexes and TTL Policies -Firedantic supports defining and automatically creating Composite Indexes and TTL Policies for your Firestore models. These can be created using either: +Firedantic supports defining and automatically creating Composite Indexes and TTL +Policies for your Firestore models. These can be created using either: - The new `Configuration` class (recommended) - The legacy `configure()` method (deprecated but still supported) @@ -221,9 +233,11 @@ It must be a list of composite index definitions created using: - `collection_index(...)` – for single-collection queries - `collection_group_index(...)` – for collection group queries -Each index definition takes an arbitrary number of (field_name, order) tuples. Order must be `Query.ASCENDING` or `Query.DESCENDING`. +Each index definition takes an arbitrary number of (field_name, order) tuples. Order +must be `Query.ASCENDING` or `Query.DESCENDING`. #### Example: + ```python __composite_indexes__ = [ collection_index(("content", Query.ASCENDING), ("expire", Query.DESCENDING)), @@ -236,11 +250,13 @@ __composite_indexes__ = [ TTL (Time-To-Live) policies are defined using the `__ttl_field__` attribute. Rules: + - The field must be a datetime object - The field name must be assigned to `__ttl_field__` - TTL policies cannot be created in the Firestore Emulator #### Example: + ```python __ttl_field__ = "expire" ``` @@ -248,14 +264,16 @@ __ttl_field__ = "expire" #### Recommended Usage (New Configuration API) All index and TTL setup functions now automatically resolve: + - The correct project - The correct database -- The correct admin client -…based on each model’s `__db_config__` +- The correct admin client …based on each model’s `__db_config__` -This means you no longer need to pass projects, databases, or admin clients manually-- And further, you can maintain multiple of each within your app. +This means you no longer need to pass projects, databases, or admin clients manually-- +And further, you can maintain multiple of each within your app. #### Sync Example (Recommended) + ```python from datetime import datetime from google.cloud.firestore import Query @@ -339,7 +357,8 @@ if __name__ == "__main__": #### Legacy Usage (Still Supported, Deprecated) -The old method using `configure()` and manually passing Firestore Admin clients is still supported but deprecated. +The old method using `configure()` and manually passing Firestore Admin clients is still +supported but deprecated. ##### Composite indexes @@ -374,8 +393,8 @@ The examples use `async_set_up_composite_indexes_and_ttl_policies` and and TTL policies. However, you can use separate functions to set up only either one of them. - ###### Legacy Sync Example + ```python from datetime import datetime from google.cloud.firestore import Client, Query @@ -417,6 +436,7 @@ if __name__ == "__main__": ``` ###### Legacy Async Example + ```python import asyncio from datetime import datetime @@ -461,7 +481,8 @@ if __name__ == "__main__": Firedantic has basic support for [Firestore Transactions](https://firebase.google.com/docs/firestore/manage-data/transactions). -The following methods can be used in a transaction for both **sync** and **async** models: +The following methods can be used in a transaction for both **sync** and **async** +models: - `Model.delete(transaction=transaction)` - `Model.find_one(transaction=transaction)` @@ -477,15 +498,15 @@ When using transactions, note that read operations must come before write operat ### Recommended Usage (New Configuration Class) With the new `Configuration` system, transactions automatically use the configured: + - Project - Database -- Client -based on the active configuration. - +- Client based on the active configuration. ### Transaction examples #### Sync Transaction Example (Recommended) + ```python from firedantic import Model, get_transaction from firedantic.configurations import configuration @@ -540,7 +561,8 @@ if __name__ == "__main__": ### Legacy Usage (Still Supported, Deprecated) -The original configure() method and manual client wiring still works, but is now deprecated in favor of `Configuration`. +The original configure() method and manual client wiring still works, but is now +deprecated in favor of `Configuration`. #### Legacy Async Transaction Example @@ -727,6 +749,8 @@ version is generated automatically by invoke task: poetry run invoke unasync ``` +If new functions, comments, variables etc. are added that will span across sync and async directories, be sure to first declare the replacements in `unasync.py` before running poetry. I.e. a replacement of 'async_client' with 'client' for the sync directory. + We decided to go this way in order to: - make sure both versions have the same API diff --git a/firedantic/__init__.py b/firedantic/__init__.py index cc67a70..f79f867 100644 --- a/firedantic/__init__.py +++ b/firedantic/__init__.py @@ -34,9 +34,9 @@ from firedantic.common import collection_group_index, collection_index from firedantic.configurations import ( CONFIGURATIONS, + ConfigItem, Configuration, configure, - ConfigItem, get_async_transaction, get_transaction, ) diff --git a/firedantic/_async/indexes.py b/firedantic/_async/indexes.py index 51e6b67..c930da4 100644 --- a/firedantic/_async/indexes.py +++ b/firedantic/_async/indexes.py @@ -111,7 +111,9 @@ async def set_up_composite_indexes( project = gcloud_project or configuration.get_config(config_name).project # Build collection group path using configuration helper (includes prefix) - collection_group = configuration.get_collection_name(model, config_name=config_name) + collection_group = configuration.get_collection_name( + model, config_name=config_name + ) path = f"projects/{project}/databases/{database}/collectionGroups/{collection_group}" indexes_in_db = await get_existing_indexes(client, path=path) @@ -120,10 +122,18 @@ async def set_up_composite_indexes( new_indexes = model_indexes.difference(indexes_in_db) for index in existing_indexes: - logger.debug("Composite index already exists in DB: %s, collection: %s", index, collection_group) + logger.debug( + "Composite index already exists in DB: %s, collection: %s", + index, + collection_group, + ) for index in new_indexes: - logger.info("Creating new composite index: %s, collection: %s", index, collection_group) + logger.info( + "Creating new composite index: %s, collection: %s", + index, + collection_group, + ) operation = await create_composite_index(client, index, path) operations.append(operation) diff --git a/firedantic/_async/model.py b/firedantic/_async/model.py index c095ba5..0cdd9b8 100644 --- a/firedantic/_async/model.py +++ b/firedantic/_async/model.py @@ -59,7 +59,7 @@ def get_collection_name(cls, collection_name: Optional[str] = None) -> str: cfg = configuration.get_config(config_name) prefix = cfg.prefix or "" return f"{prefix}{collection_name}" - + if getattr(cls, "__collection__", None): cfg = configuration.get_config(config_name) return f"{cfg.prefix or ''}{cls.__collection__}" @@ -67,8 +67,9 @@ def get_collection_name(cls, collection_name: Optional[str] = None) -> str: raise CollectionNotDefined(f"Missing collection name for {cls.__name__}") - -def _get_col_ref(cls, collection_name: Optional[str] = None) -> AsyncCollectionReference: +def _get_col_ref( + cls, collection_name: Optional[str] = None +) -> AsyncCollectionReference: """ Return an AsyncCollectionReference for the model class using the configured async client. @@ -88,11 +89,12 @@ def _get_col_ref(cls, collection_name: Optional[str] = None) -> AsyncCollectionR # Ensure we got the right object back if not hasattr(col_ref, "document"): - raise RuntimeError(f"_get_col_ref returned unexpected object for {cls}: {type(col_ref)!r}") + raise RuntimeError( + f"_get_col_ref returned unexpected object for {cls}: {type(col_ref)!r}" + ) return col_ref - class AsyncBareModel(pydantic.BaseModel, ABC): """ Base model class. @@ -106,7 +108,6 @@ class AsyncBareModel(pydantic.BaseModel, ABC): __composite_indexes__: Optional[Iterable[IndexDefinition]] = None __db_config__: str = "(default)" # override in subclasses when needed - async def save( self, *, @@ -126,7 +127,7 @@ async def save( # Resolve config to use (explicit -> instance -> class -> default) if config_name is not None: - resolved = config_name + resolved = config_name else: resolved = getattr(self, "__db_config__", None) if not resolved: @@ -134,9 +135,7 @@ async def save( # Build payload data = self.model_dump( - by_alias=True, - exclude_unset=exclude_unset, - exclude_none=exclude_none + by_alias=True, exclude_unset=exclude_unset, exclude_none=exclude_none ) if self.__document_id__ in data: del data[self.__document_id__] @@ -144,7 +143,7 @@ async def save( async_client = configuration.get_async_client(resolved) if async_client is None: raise RuntimeError(f"No async client configured for config '{resolved}'") - + # Get collection reference from async_client with the collection_name collection_name = self.get_collection_name() col_ref = async_client.collection(collection_name) @@ -152,7 +151,7 @@ async def save( # Build doc ref (use provided id if set, otherwise let server generate) doc_id = self.get_document_id() if doc_id: - doc_ref = col_ref.document(doc_id) + doc_ref = col_ref.document(doc_id) else: doc_ref = col_ref.document() @@ -212,7 +211,7 @@ def get_document_id(self) -> Optional[str]: @classmethod async def delete_all_for_model(cls, config_name: Optional[str] = None) -> None: - # Resolve config to use (explicit -> instance -> class -> default) + # Resolve config to use (explicit -> instance -> class -> default) config_name = cls.__db_config__ client = configuration.get_async_client(config_name) @@ -222,7 +221,6 @@ async def delete_all_for_model(cls, config_name: Optional[str] = None) -> None: async for doc in col_ref.stream(): await doc.reference.delete() - @classmethod async def find( # pylint: disable=too-many-arguments cls: Type[TAsyncBareModel], @@ -379,7 +377,9 @@ async def truncate_collection(cls, batch_size: int = 128) -> int: ) @classmethod - def _get_col_ref(cls, collection_name: Optional[str] = None) -> AsyncCollectionReference: + def _get_col_ref( + cls, collection_name: Optional[str] = None + ) -> AsyncCollectionReference: """ Returns the collection reference. """ @@ -392,7 +392,9 @@ def get_collection_name(cls) -> str: """ return get_collection_name(cls, cls.__collection__) - def _get_doc_ref(self, config_name: Optional[str] = "(default)") -> AsyncDocumentReference: + def _get_doc_ref( + self, config_name: Optional[str] = "(default)" + ) -> AsyncDocumentReference: """ Returns the document reference. @@ -529,4 +531,4 @@ async def get_by_id( class AsyncSubCollection(AsyncBareSubCollection, ABC): __document_id__ = "id" - __model_cls__: Type[AsyncSubModel] \ No newline at end of file + __model_cls__: Type[AsyncSubModel] diff --git a/firedantic/_sync/indexes.py b/firedantic/_sync/indexes.py index 5ec5c66..e828233 100644 --- a/firedantic/_sync/indexes.py +++ b/firedantic/_sync/indexes.py @@ -82,7 +82,7 @@ def create_composite_index( def set_up_composite_indexes( - gcloud_project: str, + gcloud_project: Optional[str], models: Iterable[Type[BareModel]], database: str = "(default)", client: Optional[FirestoreAdminClient] = None, @@ -103,7 +103,7 @@ def set_up_composite_indexes( for model in models: if not getattr(model, "__composite_indexes__", None): continue - + # Resolve config name: prefer model __db_config__ if present; else default config_name = getattr(model, "__db_config__", "(default)") @@ -111,7 +111,9 @@ def set_up_composite_indexes( project = gcloud_project or configuration.get_config(config_name).project # Build collection group path using configuration helper (includes prefix) - collection_group = configuration.get_collection_name(model, config_name=config_name) + collection_group = configuration.get_collection_name( + model, config_name=config_name + ) path = f"projects/{project}/databases/{database}/collectionGroups/{collection_group}" indexes_in_db = get_existing_indexes(client, path=path) @@ -120,10 +122,18 @@ def set_up_composite_indexes( new_indexes = model_indexes.difference(indexes_in_db) for index in existing_indexes: - logger.debug("Composite index already exists in DB: %s, collection: %s", index, collection_group) + logger.debug( + "Composite index already exists in DB: %s, collection: %s", + index, + collection_group, + ) for index in new_indexes: - logger.info("Creating new composite index: %s, collection: %s", index, collection_group) + logger.info( + "Creating new composite index: %s, collection: %s", + index, + collection_group, + ) operation = create_composite_index(client, index, path) operations.append(operation) @@ -148,4 +158,4 @@ def set_up_composite_indexes_and_ttl_policies( models = list(models) ops = set_up_composite_indexes(gcloud_project, models, database, client) ops.extend(set_up_ttl_policies(gcloud_project, models, database, client)) - return ops \ No newline at end of file + return ops diff --git a/firedantic/_sync/model.py b/firedantic/_sync/model.py index 99156fa..62273cd 100644 --- a/firedantic/_sync/model.py +++ b/firedantic/_sync/model.py @@ -8,7 +8,6 @@ DocumentReference, DocumentSnapshot, FieldFilter, - Client ) from google.cloud.firestore_v1.base_query import BaseQuery from google.cloud.firestore_v1.transaction import Transaction @@ -44,7 +43,7 @@ def get_collection_name(cls, collection_name: Optional[str] = None) -> str: """ - Returns the collection name for `cls`. + Return the collection name for `cls`. - If `collection_name` is provided, treat it as an explicit collection name and prefix it using the configured prefix for the class (via __db_config__). @@ -60,7 +59,7 @@ def get_collection_name(cls, collection_name: Optional[str] = None) -> str: cfg = configuration.get_config(config_name) prefix = cfg.prefix or "" return f"{prefix}{collection_name}" - + if getattr(cls, "__collection__", None): cfg = configuration.get_config(config_name) return f"{cfg.prefix or ''}{cls.__collection__}" @@ -68,9 +67,9 @@ def get_collection_name(cls, collection_name: Optional[str] = None) -> str: raise CollectionNotDefined(f"Missing collection name for {cls.__name__}") -def _get_col_ref(cls, collection_name: Optional[str]) -> CollectionReference: +def _get_col_ref(cls, collection_name: Optional[str] = None) -> CollectionReference: """ - Return a CollectionReference for the model class using the configured client. + Return an CollectionReference for the model class using the configured client. :param cls: model class :param collection_name: optional explicit collection name override @@ -83,16 +82,17 @@ def _get_col_ref(cls, collection_name: Optional[str]) -> CollectionReference: config_name = getattr(cls, "__db_config__", "(default)") client = configuration.get_client(config_name) - # Return the AsyncCollectionReference (this is a real client call) + # Return the CollectionReference (this is a real client call) col_ref = client.collection(col_name) # Ensure we got the right object back if not hasattr(col_ref, "document"): - raise RuntimeError(f"_get_col_ref returned unexpected object for {cls}: {type(col_ref)!r}") + raise RuntimeError( + f"_get_col_ref returned unexpected object for {cls}: {type(col_ref)!r}" + ) return col_ref - class BareModel(pydantic.BaseModel, ABC): """ Base model class. @@ -104,49 +104,12 @@ class BareModel(pydantic.BaseModel, ABC): __document_id__: str __ttl_field__: Optional[str] = None __composite_indexes__: Optional[Iterable[IndexDefinition]] = None - __db_config__: str = "(default)" # can be overridden in subclasses - - - @property - def _resolve_config(self) -> str: - """Return the config name this class is bound to""" - return getattr(self.__class__, "__db_config__", "(default)") - - @property - def _client(self) -> Client: - """Return the Firestore client this class is bound to""" - return configuration.get_client(self.__db_config__) - - def _get_collection(self, config_name: Optional[str] = None) -> CollectionReference: - """ - Return a CollectionReference for this model based on the resolved config and - a collection name derived from the class (can override with a different naming scheme) - """ - resolved = config_name if config_name is not None else self._resolve_config() - client: Client = configuration.get_client(name=resolved) - - # Use the config item's prefix + class-name-based collection name. - cfg_item = configuration.get_config(resolved) - collection_name = f"{cfg_item.prefix}{self.__class__.__name__.lower()}s" - return client.collection(collection_name) - - def _get_doc_ref(self, config_name: Optional[str] = None) -> DocumentReference: - """ - Return a DocumentReference for this instance. - If the instance has an id (self.__document_id__), use it; otherwise create a new doc ref - (client.collection().document() without an id will generate one). - """ - collection = self._get_collection(config_name) - doc_id = getattr(self, self.__document_id__, None) - if doc_id: - return collection.document(doc_id) - # No id: create a new document reference (Firestore client will generate ID) - return collection.document() - + __db_config__: str = "(default)" # override in subclasses when needed def save( self, *, + config_name: Optional[str] = None, exclude_unset: bool = False, exclude_none: bool = False, transaction: Optional[Transaction] = None, @@ -159,22 +122,44 @@ def save( :param transaction: Optional transaction to use. :raise DocumentIDError: If the document ID is not valid. """ + + # Resolve config to use (explicit -> instance -> class -> default) + if config_name is not None: + resolved = config_name + else: + resolved = getattr(self, "__db_config__", None) + if not resolved: + resolved = getattr(self.__class__, "__db_config__", "(default)") + + # Build payload data = self.model_dump( by_alias=True, exclude_unset=exclude_unset, exclude_none=exclude_none ) if self.__document_id__ in data: del data[self.__document_id__] - # get the document ref bound to the models configuration (db_config) - doc_ref = self._get_doc_ref() - - # Write data to firestore db using transaction or directly. + client = configuration.get_client(resolved) + if client is None: + raise RuntimeError(f"No client configured for config '{resolved}'") + + # Get collection reference from client with the collection_name + collection_name = self.get_collection_name() + col_ref = client.collection(collection_name) + + # Build doc ref (use provided id if set, otherwise let server generate) + doc_id = self.get_document_id() + if doc_id: + doc_ref = col_ref.document(doc_id) + else: + doc_ref = col_ref.document() + + # Use transaction if provided (assume it's compatible) otherwise do direct set if transaction is not None: + # Transaction.delete/set expects DocumentReference from the same client. transaction.set(doc_ref, data) else: doc_ref.set(data) - # Ensure the instance has the document id set. setattr(self, self.__document_id__, doc_ref.id) def delete(self, transaction: Optional[Transaction] = None) -> None: @@ -183,10 +168,13 @@ def delete(self, transaction: Optional[Transaction] = None) -> None: :raise DocumentIDError: If the ID is not valid. """ + doc_ref = self._get_doc_ref() + if transaction is not None: - transaction.delete(self._get_doc_ref()) + transaction.delete(doc_ref) else: - self._get_doc_ref().delete() + # print(f"\nTransaction not provided, deleting document: {self._get_doc_ref().path}") + doc_ref.delete() def reload(self, transaction: Optional[Transaction] = None) -> None: """ @@ -221,7 +209,7 @@ def get_document_id(self) -> Optional[str]: @classmethod def delete_all_for_model(cls, config_name: Optional[str] = None) -> None: - # Resolve config to use (explicit -> instance -> class -> default) + # Resolve config to use (explicit -> instance -> class -> default) config_name = cls.__db_config__ client = configuration.get_client(config_name) @@ -385,11 +373,11 @@ def truncate_collection(cls, batch_size: int = 128) -> int: ) @classmethod - def _get_col_ref(cls) -> CollectionReference: + def _get_col_ref(cls, collection_name: Optional[str] = None) -> CollectionReference: """ Returns the collection reference. """ - return _get_col_ref(cls, cls.__collection__) + return _get_col_ref(cls, collection_name) @classmethod def get_collection_name(cls) -> str: @@ -398,14 +386,15 @@ def get_collection_name(cls) -> str: """ return get_collection_name(cls, cls.__collection__) - def _get_doc_ref(self) -> DocumentReference: + def _get_doc_ref( + self, config_name: Optional[str] = "(default)" + ) -> DocumentReference: """ Returns the document reference. :raise DocumentIDError: If the ID is not valid. """ - doc_id = self.get_document_id() - return self._get_col_ref().document(doc_id) # type: ignore + return self._get_col_ref(config_name).document(self.get_document_id()) # type: ignore @staticmethod def _validate_document_id(document_id: str): @@ -536,4 +525,4 @@ def get_by_id( class SubCollection(BareSubCollection, ABC): __document_id__ = "id" - __model_cls__: Type[SubModel] \ No newline at end of file + __model_cls__: Type[SubModel] diff --git a/firedantic/configurations.py b/firedantic/configurations.py index 330aa06..28738fa 100644 --- a/firedantic/configurations.py +++ b/firedantic/configurations.py @@ -1,25 +1,30 @@ +import warnings from os import environ from typing import Any, Dict, Optional, Type, Union -import warnings -from google.auth.credentials import Credentials from google.api_core.client_options import ClientOptions from google.api_core.gapic_v1.client_info import ClientInfo +from google.auth.credentials import Credentials +from google.cloud.firestore_admin_v1 import FirestoreAdminClient +from google.cloud.firestore_admin_v1.services.firestore_admin import ( + FirestoreAdminAsyncClient, +) +from google.cloud.firestore_admin_v1.services.firestore_admin.transports.base import ( + DEFAULT_CLIENT_INFO, +) from google.cloud.firestore_v1 import ( - AsyncClient, - AsyncTransaction, - Client, + AsyncClient, + AsyncTransaction, + Client, CollectionReference, - Transaction + Transaction, ) -from google.cloud.firestore_admin_v1 import FirestoreAdminClient -from google.cloud.firestore_admin_v1.services.firestore_admin import FirestoreAdminAsyncClient -from google.cloud.firestore_admin_v1.services.firestore_admin.transports.base import DEFAULT_CLIENT_INFO from pydantic import BaseModel # --- Old compatibility surface (kept for backwards compatibility) --- CONFIGURATIONS: Dict[str, Any] = {} + def configure(client: Union[Client, AsyncClient], prefix: str = "") -> None: """ Legacy helper: updates the module-level `configuration` default entry and preserves @@ -42,7 +47,7 @@ def configure(client: Union[Client, AsyncClient], prefix: str = "") -> None: else: # treat as sync client configuration.add(name="(default)", prefix=prefix, client=client) - + CONFIGURATIONS["db"] = client CONFIGURATIONS["prefix"] = prefix @@ -63,6 +68,7 @@ class ConfigItem(BaseModel): Holds configuration for a named Firestore connection. Clients may be provided directly, or created lazily from the stored params. """ + name: str prefix: str = "" project: Optional[str] = None @@ -82,6 +88,7 @@ class ConfigItem(BaseModel): model_config = {"arbitrary_types_allowed": True} + class Configuration: """ Registry for named Firestore configurations. @@ -98,7 +105,7 @@ def __init__(self) -> None: # Create a sensible default config entry, but avoid passing empty string as project. default_project = environ.get("GOOGLE_CLOUD_PROJECT") or None - + self.config["(default)"] = ConfigItem( name="(default)", prefix="", @@ -128,7 +135,7 @@ def add( You may either pass a pre-built client and/or an async_client, or provide only project/credentials so clients will be constructed lazily. """ - + normalized_project = self._normalize_project(project) item = ConfigItem( @@ -147,14 +154,14 @@ def add( ) self.config[name] = item return item - + # dict-like accessors def __getitem__(self, name: str) -> ConfigItem: return self.get_config(name) def __contains__(self, name: str) -> bool: return name in self.config - + def get(self, name: str, default=None): return self.config.get(name, default) @@ -165,10 +172,10 @@ def get_config(self, name: str = "(default)") -> ConfigItem: raise KeyError( f"Configuration '{name}' not found. Available: {list(self.config.keys())}" ) from err - + def get_config_names(self): return list(self.config.keys()) - + def _normalize_project(self, project: Optional[str]) -> Optional[str]: """ Convert empty-string project to None so Client(...) doesn't get empty string for project (i.e. project="") @@ -186,7 +193,7 @@ def get_client(self, name: Optional[str] = None) -> Client: if cfg.project is None: raise RuntimeError( f"No sync client configured for '{resolved}' and no project available; " - "call configuration.add(..., client=..., project=...) or set " \ + "call configuration.add(..., client=..., project=...) or set " "GOOGLE_CLOUD_PROJECT and let add() build the client." ) cfg.client = Client( @@ -197,7 +204,7 @@ def get_client(self, name: Optional[str] = None) -> Client: # NOTE: modern firestore clients may accept database param in constructor; keep for future. ) return cfg.client - + # async client accessor (lazy-create) def get_async_client(self, name: Optional[str] = None) -> AsyncClient: resolved = name if name is not None else "(default)" @@ -207,7 +214,7 @@ def get_async_client(self, name: Optional[str] = None) -> AsyncClient: if cfg.project is None: raise RuntimeError( f"No async client configured for '{resolved}' and no project available; " - "call configuration.add(..., async_client=..., project=...) or set " \ + "call configuration.add(..., async_client=..., project=...) or set " "GOOGLE_CLOUD_PROJECT and let add() build the client." ) cfg.async_client = AsyncClient( @@ -217,12 +224,12 @@ def get_async_client(self, name: Optional[str] = None) -> AsyncClient: client_options=cfg.client_options, # type: ignore[arg-type] ) return cfg.async_client - + # admin client accessor (lazy-create) def get_admin_client(self, name: Optional[str] = None) -> FirestoreAdminClient: resolved = name if name is not None else "(default)" cfg = self.get_config(resolved) - + if cfg.admin_client is None: cfg.admin_client = FirestoreAdminClient( credentials=cfg.credentials, @@ -233,10 +240,12 @@ def get_admin_client(self, name: Optional[str] = None) -> FirestoreAdminClient: return cfg.admin_client # async admin client accessor (lazy-create) - def get_async_admin_client(self, name: Optional[str] = None) -> FirestoreAdminAsyncClient: + def get_async_admin_client( + self, name: Optional[str] = None + ) -> FirestoreAdminAsyncClient: resolved = name if name is not None else "(default)" cfg = self.get_config(resolved) - + if cfg.async_admin_client is None: cfg.async_admin_client = FirestoreAdminAsyncClient( credentials=cfg.credentials, @@ -252,9 +261,11 @@ def get_transaction(self, name: Optional[str] = None) -> Transaction: def get_async_transaction(self, name: Optional[str] = None) -> AsyncTransaction: return self.get_async_client(name=name).transaction() - + # helpers for models to derive collection name / reference - def get_collection_name(self, model_class: Type, config_name: Optional[str] = None) -> str: + def get_collection_name( + self, model_class: Type, config_name: Optional[str] = None + ) -> str: """ Return the collection name string (prefix + model name). """ @@ -268,7 +279,9 @@ def get_collection_name(self, model_class: Type, config_name: Optional[str] = No model_name = model_class.__name__ return f"{prefix}{model_name[0].lower()}{model_name[1:]}" # (lower case first letter of model name) - def get_collection_ref(self, model_class: Type, name: Optional[str] = None) -> CollectionReference: + def get_collection_ref( + self, model_class: Type, name: Optional[str] = None + ) -> CollectionReference: """ Return a CollectionReference for the given model_class using the sync client. """ @@ -289,11 +302,12 @@ def get_async_collection_ref(self, model_class: Type, name: Optional[str] = None async_client = cfg.async_client if async_client is None: raise RuntimeError(f"No async client configured for config '{resolved}'") - + collection_name = self.get_collection_name(model_class, resolved) return async_client.collection(collection_name) - + + # make the module-level singleton available to models/tests # other modules should: from firedantic.configurations import configuration configuration = Configuration() diff --git a/firedantic/tests/tests_async/conftest.py b/firedantic/tests/tests_async/conftest.py index 6eabbfb..09a4710 100644 --- a/firedantic/tests/tests_async/conftest.py +++ b/firedantic/tests/tests_async/conftest.py @@ -21,6 +21,7 @@ from unittest.mock import AsyncMock, Mock # noqa isort: skip + class CustomIDModel(AsyncBareModel): __collection__ = "custom" __document_id__ = "foo" @@ -306,4 +307,4 @@ async def update_field(self, data) -> MockOperation: @pytest.fixture() def mock_admin_client(): - return AsyncMockFirestoreAdminClient() \ No newline at end of file + return AsyncMockFirestoreAdminClient() diff --git a/firedantic/tests/tests_async/test_indexes.py b/firedantic/tests/tests_async/test_indexes.py index ee9caf2..f6d93fe 100644 --- a/firedantic/tests/tests_async/test_indexes.py +++ b/firedantic/tests/tests_async/test_indexes.py @@ -13,8 +13,8 @@ collection_index, ) from firedantic.common import IndexField -from firedantic.tests.tests_async.conftest import MockListIndexOperation from firedantic.configurations import configuration +from firedantic.tests.tests_async.conftest import MockListIndexOperation import pytest # noqa isort: skip @@ -29,7 +29,10 @@ class BaseModelWithIndexes(AsyncModel): @pytest.mark.asyncio async def test_set_up_composite_index(mock_admin_client) -> None: - configuration.add(name="(default)", prefix="test_", project="proj", async_client=mock_admin_client) + configuration.add( + name="(default)", prefix="test_", project="proj", async_client=mock_admin_client + ) + class ModelWithIndexes(BaseModelWithIndexes): __composite_indexes__ = ( collection_index( @@ -64,7 +67,10 @@ class ModelWithIndexes(BaseModelWithIndexes): @pytest.mark.asyncio async def test_set_up_collection_group_index(mock_admin_client) -> None: - configuration.add(name="(default)", prefix="test_", project="proj", async_client=mock_admin_client) + configuration.add( + name="(default)", prefix="test_", project="proj", async_client=mock_admin_client + ) + class ModelWithIndexes(BaseModelWithIndexes): __composite_indexes__ = ( collection_group_index( @@ -96,7 +102,10 @@ class ModelWithIndexes(BaseModelWithIndexes): @pytest.mark.asyncio async def test_set_up_composite_indexes_and_policies(mock_admin_client) -> None: - configuration.add(name="(default)", prefix="test_", project="proj", async_client=mock_admin_client) + configuration.add( + name="(default)", prefix="test_", project="proj", async_client=mock_admin_client + ) + class ModelWithIndexes(BaseModelWithIndexes): __composite_indexes__ = ( collection_index( @@ -121,7 +130,10 @@ class ModelWithIndexes(BaseModelWithIndexes): @pytest.mark.asyncio async def test_set_up_many_composite_indexes(mock_admin_client) -> None: - configuration.add(name="(default)", prefix="test_", project="proj", async_client=mock_admin_client) + configuration.add( + name="(default)", prefix="test_", project="proj", async_client=mock_admin_client + ) + class ModelWithIndexes(BaseModelWithIndexes): __composite_indexes__ = ( collection_index( @@ -149,7 +161,10 @@ class ModelWithIndexes(BaseModelWithIndexes): @pytest.mark.asyncio async def test_set_up_indexes_model_without_indexes(mock_admin_client) -> None: - configuration.add(name="(default)", prefix="test_", project="proj", async_client=mock_admin_client) + configuration.add( + name="(default)", prefix="test_", project="proj", async_client=mock_admin_client + ) + class ModelWithoutIndexes(AsyncModel): __collection__ = "modelWithoutIndexes" @@ -168,7 +183,9 @@ class ModelWithoutIndexes(AsyncModel): @pytest.mark.asyncio async def test_existing_indexes_are_skipped(mock_admin_client) -> None: - configuration.add(name="(default)", prefix="test_", project="proj", async_client=mock_admin_client) + configuration.add( + name="(default)", prefix="test_", project="proj", async_client=mock_admin_client + ) expected_prefix = configuration.get_config("(default)").prefix resp = ListIndexesResponse( @@ -227,7 +244,9 @@ class ModelWithIndexes(BaseModelWithIndexes): @pytest.mark.asyncio async def test_same_fields_in_another_collection(mock_admin_client) -> None: - configuration.add(name="(default)", prefix="test_", project="proj", async_client=mock_admin_client) + configuration.add( + name="(default)", prefix="test_", project="proj", async_client=mock_admin_client + ) expected_prefix = configuration.get_config("(default)").prefix # Test that when another collection has an index with exactly the same fields, @@ -267,4 +286,4 @@ class ModelWithIndexes(BaseModelWithIndexes): models=[ModelWithIndexes], client=mock_admin_client, ) - assert len(result) == 1 \ No newline at end of file + assert len(result) == 1 diff --git a/firedantic/tests/tests_async/test_model.py b/firedantic/tests/tests_async/test_model.py index 8e40652..78523e7 100644 --- a/firedantic/tests/tests_async/test_model.py +++ b/firedantic/tests/tests_async/test_model.py @@ -44,6 +44,7 @@ async def test_save_model(create_company) -> None: assert company.owner.first_name == "John" assert company.owner.last_name == "Doe" + @pytest.mark.asyncio async def test_delete_all_for_model(create_company) -> None: company: Company = await create_company( diff --git a/firedantic/tests/tests_sync/conftest.py b/firedantic/tests/tests_sync/conftest.py index fc59651..8debfc4 100644 --- a/firedantic/tests/tests_sync/conftest.py +++ b/firedantic/tests/tests_sync/conftest.py @@ -305,4 +305,4 @@ def update_field(self, data) -> MockOperation: @pytest.fixture() def mock_admin_client(): - return MockFirestoreAdminClient() \ No newline at end of file + return MockFirestoreAdminClient() diff --git a/firedantic/tests/tests_sync/test_indexes.py b/firedantic/tests/tests_sync/test_indexes.py index 7be99d9..46d4b5c 100644 --- a/firedantic/tests/tests_sync/test_indexes.py +++ b/firedantic/tests/tests_sync/test_indexes.py @@ -13,8 +13,8 @@ set_up_composite_indexes_and_ttl_policies, ) from firedantic.common import IndexField -from firedantic.tests.tests_sync.conftest import MockListIndexOperation from firedantic.configurations import configuration +from firedantic.tests.tests_sync.conftest import MockListIndexOperation import pytest # noqa isort: skip @@ -27,9 +27,11 @@ class BaseModelWithIndexes(Model): age: int - def test_set_up_composite_index(mock_admin_client) -> None: - configuration.add(name="(default)", prefix="test_", project="proj", client=mock_admin_client) + configuration.add( + name="(default)", prefix="test_", project="proj", client=mock_admin_client + ) + class ModelWithIndexes(BaseModelWithIndexes): __composite_indexes__ = ( collection_index( @@ -49,7 +51,6 @@ class ModelWithIndexes(BaseModelWithIndexes): # index is a protobuf structure sent via Google Cloud Admin path = call_list[0][1]["request"].parent expected_prefix = configuration.get_config("(default)").prefix - assert ( path == f"projects/proj/databases/(default)/collectionGroups/{expected_prefix}modelWithIndexes" @@ -63,9 +64,11 @@ class ModelWithIndexes(BaseModelWithIndexes): assert index.fields[1].order.name == Query.DESCENDING - def test_set_up_collection_group_index(mock_admin_client) -> None: - configuration.add(name="(default)", prefix="test_", project="proj", client=mock_admin_client) + configuration.add( + name="(default)", prefix="test_", project="proj", client=mock_admin_client + ) + class ModelWithIndexes(BaseModelWithIndexes): __composite_indexes__ = ( collection_group_index( @@ -85,6 +88,7 @@ class ModelWithIndexes(BaseModelWithIndexes): # index is a protobuf structure sent via Google Cloud Admin path = call_list[0][1]["request"].parent expected_prefix = configuration.get_config("(default)").prefix + assert ( path == f"projects/proj/databases/(default)/collectionGroups/{expected_prefix}modelWithIndexes" @@ -94,9 +98,11 @@ class ModelWithIndexes(BaseModelWithIndexes): assert len(index.fields) == 2 - def test_set_up_composite_indexes_and_policies(mock_admin_client) -> None: - configuration.add(name="(default)", prefix="test_", project="proj", client=mock_admin_client) + configuration.add( + name="(default)", prefix="test_", project="proj", client=mock_admin_client + ) + class ModelWithIndexes(BaseModelWithIndexes): __composite_indexes__ = ( collection_index( @@ -119,9 +125,11 @@ class ModelWithIndexes(BaseModelWithIndexes): assert len(call_list) == 1 - def test_set_up_many_composite_indexes(mock_admin_client) -> None: - configuration.add(name="(default)", prefix="test_", project="proj", client=mock_admin_client) + configuration.add( + name="(default)", prefix="test_", project="proj", client=mock_admin_client + ) + class ModelWithIndexes(BaseModelWithIndexes): __composite_indexes__ = ( collection_index( @@ -147,9 +155,11 @@ class ModelWithIndexes(BaseModelWithIndexes): assert len(result) == 3 - def test_set_up_indexes_model_without_indexes(mock_admin_client) -> None: - configuration.add(name="(default)", prefix="test_", project="proj", client=mock_admin_client) + configuration.add( + name="(default)", prefix="test_", project="proj", client=mock_admin_client + ) + class ModelWithoutIndexes(Model): __collection__ = "modelWithoutIndexes" @@ -166,10 +176,12 @@ class ModelWithoutIndexes(Model): assert len(call_list) == 0 - def test_existing_indexes_are_skipped(mock_admin_client) -> None: - configuration.add(name="(default)", prefix="test_", project="proj", client=mock_admin_client) + configuration.add( + name="(default)", prefix="test_", project="proj", client=mock_admin_client + ) expected_prefix = configuration.get_config("(default)").prefix + resp = ListIndexesResponse( { "indexes": [ @@ -200,9 +212,7 @@ def test_existing_indexes_are_skipped(mock_admin_client) -> None: ] } ) - mock_admin_client.list_indexes = Mock( - return_value=MockListIndexOperation([resp]) - ) + mock_admin_client.list_indexes = Mock(return_value=MockListIndexOperation([resp])) class ModelWithIndexes(BaseModelWithIndexes): __composite_indexes__ = ( @@ -224,12 +234,14 @@ class ModelWithIndexes(BaseModelWithIndexes): assert len(result) == 0 - def test_same_fields_in_another_collection(mock_admin_client) -> None: + configuration.add( + name="(default)", prefix="test_", project="proj", client=mock_admin_client + ) + expected_prefix = configuration.get_config("(default)").prefix + # Test that when another collection has an index with exactly the same fields, # it won't affect creating an index in the target collection - configuration.add(name="(default)", prefix="test_", project="proj", client=mock_admin_client) - expected_prefix = configuration.get_config("(default)").prefix resp = ListIndexesResponse( { "indexes": [ @@ -248,9 +260,7 @@ def test_same_fields_in_another_collection(mock_admin_client) -> None: ] } ) - mock_admin_client.list_indexes = Mock( - return_value=MockListIndexOperation([resp]) - ) + mock_admin_client.list_indexes = Mock(return_value=MockListIndexOperation([resp])) class ModelWithIndexes(BaseModelWithIndexes): __composite_indexes__ = ( diff --git a/firedantic/tests/tests_sync/test_model.py b/firedantic/tests/tests_sync/test_model.py index 321e51c..abdffce 100644 --- a/firedantic/tests/tests_sync/test_model.py +++ b/firedantic/tests/tests_sync/test_model.py @@ -98,6 +98,12 @@ def test_find(create_company, create_product) -> None: d = Company.find({"owner.first_name": "John"}) assert len(d) == 4 + d = Company.find({"owner.first_name": {op.EQ: "John"}}) + assert len(d) == 4 + + d = Company.find({"owner.first_name": {"==": "John"}}) + assert len(d) == 4 + for p in TEST_PRODUCTS: create_product(**p) @@ -251,6 +257,7 @@ def test_get_by_empty_str_id() -> None: def test_missing_collection() -> None: class User(Model): name: str + # normally __collection__ would be defined here with pytest.raises(CollectionNotDefined): User(name="John").save() @@ -300,7 +307,7 @@ class User(Model): "!:&+-*'()", ], ) -def test_models_with_valid_custom_id( model_id) -> None: +def test_models_with_valid_custom_id(model_id) -> None: product_id = str(uuid4()) product = Product(product_id=product_id, price=123.45, stock=2) @@ -395,7 +402,7 @@ def test_bare_model_document_id_persistency() -> None: assert len(CustomIDModel.find({})) == 1 -def test_bare_model_get_by_empty_doc_id(configure_db) -> None: +def test_bare_model_get_by_empty_doc_id() -> None: with pytest.raises(ModelNotFoundError): CustomIDModel.get_by_doc_id("") @@ -508,8 +515,6 @@ def test_save_with_exclude_unset() -> None: def test_update_city_in_transaction() -> None: """ Test updating a model in a transaction. Test case from README. - - :param: configure_db: pytest fixture """ @transactional @@ -531,8 +536,6 @@ def decrement_population(transaction: Transaction, city: City, decrement: int = def test_delete_in_transaction() -> None: """ Test deleting a model in a transaction. - - :param: configure_db: pytest fixture """ @transactional @@ -546,7 +549,8 @@ def delete_in_transaction(transaction: Transaction, profile_id: str) -> None: assert p.id t = get_transaction() - delete_in_transaction(t, p.id) + with t: + delete_in_transaction(t, p.id) with pytest.raises(ModelNotFoundError): Profile.get_by_id(p.id) @@ -555,8 +559,6 @@ def delete_in_transaction(transaction: Transaction, profile_id: str) -> None: def test_update_model_in_transaction() -> None: """ Test updating a model in a transaction. - - :param: configure_db: pytest fixture """ @transactional @@ -581,8 +583,6 @@ def update_in_transaction( def test_update_submodel_in_transaction() -> None: """ Test Updating a submodel in a transaction. - - :param: configure_db: pytest fixture """ @transactional @@ -607,4 +607,4 @@ def update_submodel_in_transaction( user_stats = update_submodel_in_transaction(t, u.id, "2021") assert isinstance(user_stats, UserStats) assert user_stats.purchases == 43 - assert get_user_purchases(u.id) == 43 \ No newline at end of file + assert get_user_purchases(u.id) == 43 diff --git a/integration_tests/README.md b/integration_tests/README.md index fb164f7..dfb4bc1 100644 --- a/integration_tests/README.md +++ b/integration_tests/README.md @@ -1,11 +1,15 @@ # integration_tests/README.md ## Overview -This folder contains integration test files. This README explains how to run them and what each file is for. + +This folder contains integration test files. This README explains how to run them and +what each file is for. ## Prerequisites + - Python 3.13 -- Run: +- Run: + ```bash python -m venv firedantic source firedantic/bin/activate @@ -13,30 +17,36 @@ cd firedantic/integration_tests ``` ## Files and purpose (replace placeholders with real filenames) -- `integration_tests/configure_firestore_db_clients.py` — Purpose: shows how to create and connect to various db clients. -- `integration_tests/full_sync_flow.py` — Purpose: configures clients, saves data to db, finds the data, and deletes all data in a sync fashion. -- `integration_tests/full_async_flow.py` — Purpose: configures async clients, saves data to db, finds the data, and deletes all data in an async fashion. -- `integration_tests/test_integration_all.py` - Purpose: intended to run all examples shown in the readme, new and legacy. +- `integration_tests/configure_firestore_db_clients.py` — Purpose: shows how to create + and connect to various db clients. +- `integration_tests/full_sync_flow.py` — Purpose: configures clients, saves data to db, + finds the data, and deletes all data in a sync fashion. +- `integration_tests/full_async_flow.py` — Purpose: configures async clients, saves data + to db, finds the data, and deletes all data in an async fashion. +- `integration_tests/test_integration_all.py` - Purpose: intended to run all examples + shown in the readme, new and legacy. ## How to run -Run each individual test file: - - `python configure_firestore_db_clients.py` - - `python full_sync_flow.py` - - `python full_async_flow.py` - - `python test_integration_all.py` +Run each individual test file: - `python configure_firestore_db_clients.py` - +`python full_sync_flow.py` - `python full_async_flow.py` - +`python test_integration_all.py` ## Environment and configuration + - Ensure firestore emulator is running in another terminal window: - - `./start_emulator.sh` + - `./start_emulator.sh` ## What to expect: + - For the configure_firestore_db_clients test, you should expect to see the following: -`All configure_firestore_db_client tests passed!` + `All configure_firestore_db_client tests passed!` + +- For the `full_sync_flow` and `full_async_flow`, you should expect to see the following + output: \ + You can readily play around with the models to update the data as desired. -- For the `full_sync_flow` and `full_async_flow`, you should expect to see the following output: \ -You can readily play around with the models to update the data as desired. ``` Number of company owners with first name: 'Bill': 1 @@ -51,4 +61,4 @@ Number of company owners with first name: 'Alice': 1 Number of billing companies with id: '1234567-8c': 1 Number of billing accounts with billing_id: 801048: 1 -``` \ No newline at end of file +``` diff --git a/integration_tests/configure_firestore_db_clients.py b/integration_tests/configure_firestore_db_clients.py index 8819044..50ee349 100644 --- a/integration_tests/configure_firestore_db_clients.py +++ b/integration_tests/configure_firestore_db_clients.py @@ -4,7 +4,7 @@ import google.auth.credentials from google.cloud.firestore import AsyncClient, Client -from firedantic.configurations import configuration, configure, CONFIGURATIONS +from firedantic.configurations import CONFIGURATIONS, configuration, configure ## OLD WAY: @@ -26,7 +26,7 @@ def configure_sync_client(): # project check assert CONFIGURATIONS["db"].project == "firedantic-test" - + # emulator expectations assert isinstance(CONFIGURATIONS["db"]._credentials, Mock) @@ -34,7 +34,6 @@ def configure_sync_client(): assert not isinstance(CONFIGURATIONS["db"], AsyncClient) - def configure_async_client(): # Firestore emulator must be running if using locally. if environ.get("FIRESTORE_EMULATOR_HOST"): @@ -53,7 +52,7 @@ def configure_async_client(): # project check assert CONFIGURATIONS["db"].project == "firedantic-test" - + # emulator expectations assert isinstance(CONFIGURATIONS["db"]._credentials, Mock) @@ -61,7 +60,6 @@ def configure_async_client(): assert not isinstance(CONFIGURATIONS["db"], Client) - ## NEW WAY: def configure_multiple_clients(): config = configuration @@ -81,7 +79,6 @@ def configure_multiple_clients(): credentials=Mock(spec=google.auth.credentials.Credentials), ) - # =========================== # VALIDATE DEFAULT CONFIG # =========================== @@ -106,7 +103,6 @@ def configure_multiple_clients(): assert isinstance(config.get_client(), Client) assert isinstance(config.get_async_client(), AsyncClient) - # =========================== # VALIDATE BILLING CONFIG # =========================== @@ -135,7 +131,6 @@ def configure_multiple_clients(): assert billing.project != default.project - try: ### ---- Running OLD way ---- configure_sync_client() @@ -143,7 +138,7 @@ def configure_multiple_clients(): ### ---- Running NEW way ---- configure_multiple_clients() - + print("\nAll configure_firestore_db_client tests passed!\n") except AssertionError as e: diff --git a/integration_tests/full_async_flow.py b/integration_tests/full_async_flow.py index 81c8787..e15a533 100644 --- a/integration_tests/full_async_flow.py +++ b/integration_tests/full_async_flow.py @@ -1,11 +1,12 @@ +import sys +from os import environ from unittest.mock import Mock +import google.auth.credentials + from firedantic import AsyncModel -from firedantic.configurations import configuration, AsyncClient, configure +from firedantic.configurations import AsyncClient, configuration, configure -import google.auth.credentials -import sys -from os import environ async def test_old_way(): # This is the old way of doing things, kept for backwards compatibility. @@ -13,7 +14,7 @@ async def test_old_way(): if environ.get("FIRESTORE_EMULATOR_HOST"): client = AsyncClient( project="firedantic-test", - credentials=Mock(spec=google.auth.credentials.Credentials) + credentials=Mock(spec=google.auth.credentials.Credentials), ) else: client = AsyncClient() @@ -22,12 +23,13 @@ async def test_old_way(): class Owner(AsyncModel): """Dummy owner Pydantic model.""" + first_name: str last_name: str - class Company(AsyncModel): """Dummy company Firedantic model.""" + __collection__ = "companies" company_id: str owner: Owner @@ -41,8 +43,12 @@ class Company(AsyncModel): await company.reload() # Finding data from DB - print(f"\nNumber of company owners with first name: 'Bill': {len(await Company.find({"owner.first_name": "Bill"}))}") - print(f"\nNumber of companies with id: '1234567-7': {len(await Company.find({"company_id": "1234567-7"}))}") + print( + f"\nNumber of company owners with first name: 'Bill': {len(await Company.find({"owner.first_name": "Bill"}))}" + ) + print( + f"\nNumber of companies with id: '1234567-7': {len(await Company.find({"company_id": "1234567-7"}))}" + ) # Delete everything from the database await company.delete_all_for_model() @@ -50,22 +56,24 @@ class Company(AsyncModel): if not deletion_success: print(f"\nDeletion of (default) DB failed\n") + ## With single async client async def test_with_default(): - + class Owner(AsyncModel): """Dummy owner Pydantic model.""" + __collection__ = "owners" first_name: str last_name: str - class Company(AsyncModel): """Dummy company Firedantic model.""" + __collection__ = "companies" company_id: str owner: Owner - + # Firestore emulator must be running if using locally, name defaults to "(default)" configuration.add( prefix="async-default-test-", @@ -77,7 +85,7 @@ class Company(AsyncModel): owner = Owner(first_name="John", last_name="Doe") company = Company(company_id="1234567-8a", owner=owner) - # Save the company/owner info to the DB + # Save the company/owner info to the DB await company.save() # only need to include config_name when not using default # Reloads model data from the database to ensure most-up-to-date info @@ -85,16 +93,26 @@ class Company(AsyncModel): # Assert that async client exists and configuration is correct assert isinstance(configuration.get_async_client(), AsyncClient) - assert configuration.get_collection_name(Owner) == configuration.get_config().prefix + "owners" - assert configuration.get_collection_name(Company) == configuration.get_config().prefix + "companies" + assert ( + configuration.get_collection_name(Owner) + == configuration.get_config().prefix + "owners" + ) + assert ( + configuration.get_collection_name(Company) + == configuration.get_config().prefix + "companies" + ) assert configuration.get_config().prefix == "async-default-test-" assert configuration.get_config().project == "async-default-test" assert configuration.get_config().name == "(default)" # Finding data from DB - print(f"\nNumber of company owners with first name: 'John': {len(await Company.find({"owner.first_name": "John"}))}") + print( + f"\nNumber of company owners with first name: 'John': {len(await Company.find({"owner.first_name": "John"}))}" + ) - print(f"\nNumber of companies with id: '1234567-8a': {len(await Company.find({"company_id": "1234567-8a"}))}") + print( + f"\nNumber of companies with id: '1234567-8a': {len(await Company.find({"company_id": "1234567-8a"}))}" + ) # Delete everything from the database await company.delete_all_for_model() @@ -110,14 +128,15 @@ async def test_with_multiple(): class Owner(AsyncModel): """Dummy owner Pydantic model.""" + __db_config__ = config_name __collection__ = "owners" first_name: str last_name: str - class Company(AsyncModel): """Dummy company Firedantic model.""" + __db_config__ = config_name __collection__ = config_name company_id: str @@ -135,7 +154,7 @@ class Company(AsyncModel): owner = Owner(first_name="Alice", last_name="Begone") company = Company(company_id="1234567-9", owner=owner) - # Save the company/owner info to the DB + # Save the company/owner info to the DB await company.save() # only need to include config_name when not using default # Reloads model data from the database to ensure most-up-to-date info @@ -148,15 +167,16 @@ class Company(AsyncModel): class BillingAccount(AsyncModel): """Dummy billing account Pydantic model.""" + __db_config__ = config_name __collection__ = "accounts" name: str billing_id: int owner: str - class BillingCompany(AsyncModel): """Dummy company Firedantic model.""" + __db_config__ = config_name __collection__ = "companies" company_id: str @@ -172,8 +192,14 @@ class BillingCompany(AsyncModel): # Assert that async clients exists and added configurations are correct assert isinstance(configuration.get_async_client("billing"), AsyncClient) - assert configuration.get_collection_name(BillingAccount, "billing") == configuration.get_config("billing").prefix + "accounts" - assert configuration.get_collection_name(BillingCompany, "billing") == configuration.get_config("billing").prefix + "companies" + assert ( + configuration.get_collection_name(BillingAccount, "billing") + == configuration.get_config("billing").prefix + "accounts" + ) + assert ( + configuration.get_collection_name(BillingCompany, "billing") + == configuration.get_config("billing").prefix + "companies" + ) assert configuration.get_config("billing").prefix == "async-billing-test-" assert configuration.get_config("billing").project == "async-billing-test" @@ -181,7 +207,10 @@ class BillingCompany(AsyncModel): assert isinstance(configuration.get_async_client("companies"), AsyncClient) - assert configuration.get_collection_name(Company, "companies") == configuration.get_config("companies").prefix + "companies" + assert ( + configuration.get_collection_name(Company, "companies") + == configuration.get_config("companies").prefix + "companies" + ) assert configuration.get_config("companies").prefix == "async-companies-test-" assert configuration.get_config("companies").project == "async-companies-test" @@ -199,11 +228,17 @@ class BillingCompany(AsyncModel): await bc.reload() # 3. Finding data - print(f"\nNumber of company owners with first name: 'Alice': {len(await Company.find({"owner.first_name": "Alice"}))}") + print( + f"\nNumber of company owners with first name: 'Alice': {len(await Company.find({"owner.first_name": "Alice"}))}" + ) - print(f"\nNumber of billing companies with id: '1234567-8c': {len(await BillingCompany.find({"company_id": "1234567-8c"}))}") + print( + f"\nNumber of billing companies with id: '1234567-8c': {len(await BillingCompany.find({"company_id": "1234567-8c"}))}" + ) - print(f"\nNumber of billing accounts with billing_id: 801048: {len(await BillingCompany.find({"billing_account.billing_id": 801048}))}") + print( + f"\nNumber of billing accounts with billing_id: 801048: {len(await BillingCompany.find({"billing_account.billing_id": 801048}))}" + ) # Delete everything from the database await company.delete_all_for_model() @@ -213,24 +248,28 @@ class BillingCompany(AsyncModel): if not deletion_success: print(f"\nDeletion of Company DB failed\n") - deletion_success = [] == await BillingCompany.find({"billing_account.billing_id": 801048}) + deletion_success = [] == await BillingCompany.find( + {"billing_account.billing_id": 801048} + ) if not deletion_success: print(f"\nDeletion of BillingCompany DB failed\n") - # suppress silent error msg from asyncio threads def dbg_hook(exctype, value, tb): print("=== Uncaught exception ===") print(exctype, value) import traceback + traceback.print_tb(tb) + sys.excepthook = dbg_hook # Run the tests if __name__ == "__main__": import asyncio + asyncio.run(test_old_way()) asyncio.run(test_with_default()) asyncio.run(test_with_multiple()) diff --git a/integration_tests/full_sync_flow.py b/integration_tests/full_sync_flow.py index 02316f8..6410b0b 100644 --- a/integration_tests/full_sync_flow.py +++ b/integration_tests/full_sync_flow.py @@ -1,11 +1,11 @@ +from os import environ from unittest.mock import Mock +import google.auth.credentials from pydantic import BaseModel -from firedantic import Model -from firedantic.configurations import configuration, Client, configure -import google.auth.credentials -from os import environ +from firedantic import Model +from firedantic.configurations import Client, configuration, configure def test_old_way(): @@ -14,7 +14,7 @@ def test_old_way(): if environ.get("FIRESTORE_EMULATOR_HOST"): client = Client( project="firedantic-test", - credentials=Mock(spec=google.auth.credentials.Credentials) + credentials=Mock(spec=google.auth.credentials.Credentials), ) else: client = Client() @@ -23,12 +23,13 @@ def test_old_way(): class Owner(BaseModel): """Dummy owner Pydantic model.""" + first_name: str last_name: str - class Company(Model): """Dummy company Firedantic model.""" + __collection__ = "companies" company_id: str owner: Owner @@ -42,8 +43,12 @@ class Company(Model): company.reload() # Finding data from DB - print(f"\nNumber of company owners with first name: 'Bill': {len(Company.find({"owner.first_name": "Bill"}))}") - print(f"\nNumber of companies with id: '1234567-7': {len(Company.find({"company_id": "1234567-7"}))}") + print( + f"\nNumber of company owners with first name: 'Bill': {len(Company.find({"owner.first_name": "Bill"}))}" + ) + print( + f"\nNumber of companies with id: '1234567-7': {len(Company.find({"company_id": "1234567-7"}))}" + ) # Delete everything from the database company.delete_all_for_model() @@ -51,20 +56,21 @@ class Company(Model): ## With single sync client def test_with_default(): - + class Owner(Model): """Dummy owner Pydantic model.""" + __collection__ = "owners" first_name: str last_name: str - class Company(Model): """Dummy company Firedantic model.""" + __collection__ = "companies" company_id: str owner: Owner - + # Firestore emulator must be running if using locally, name defaults to "(default)" configuration.add( prefix="sync-default-test-", @@ -76,7 +82,7 @@ class Company(Model): owner = Owner(first_name="John", last_name="Doe") company = Company(company_id="1234567-8", owner=owner) - # Save the company/owner info to the DB + # Save the company/owner info to the DB company.save() # Reloads model data from the database to ensure most-up-to-date info @@ -84,19 +90,30 @@ class Company(Model): # Assert that sync client exists and configuration is correct assert isinstance(configuration.get_client(), Client) - assert configuration.get_collection_name(Owner) == configuration.get_config().prefix + "owners" - assert configuration.get_collection_name(Company) == configuration.get_config().prefix + "companies" + assert ( + configuration.get_collection_name(Owner) + == configuration.get_config().prefix + "owners" + ) + assert ( + configuration.get_collection_name(Company) + == configuration.get_config().prefix + "companies" + ) assert configuration.get_config().prefix == "sync-default-test-" assert configuration.get_config().project == "sync-default-test" assert configuration.get_config().name == "(default)" - + # Finding data from DB - print(f"\nNumber of company owners with first name: 'John': {len(Company.find({"owner.first_name": "John"}))}") - print(f"\nNumber of companies with id: '1234567-8': {len(Company.find({"company_id": "1234567-8"}))}") + print( + f"\nNumber of company owners with first name: 'John': {len(Company.find({"owner.first_name": "John"}))}" + ) + print( + f"\nNumber of companies with id: '1234567-8': {len(Company.find({"company_id": "1234567-8"}))}" + ) # Delete everything from the database company.delete_all_for_model() + # Now with multiple SYNC clients/dbs: def test_with_multiple(): @@ -104,14 +121,15 @@ def test_with_multiple(): class Owner(Model): """Dummy owner Pydantic model.""" + __db_config__ = config_name __collection__ = "owners" first_name: str last_name: str - class Company(Model): """Dummy company Firedantic model.""" + __db_config__ = config_name __collection__ = config_name company_id: str @@ -141,15 +159,16 @@ class Company(Model): class BillingAccount(Model): """Dummy billing account Pydantic model.""" + __db_config__ = config_name __collection__ = "accounts" name: str billing_id: int owner: str - class BillingCompany(Model): """Dummy company Firedantic model.""" + __db_config__ = config_name __collection__ = "companies" company_id: str @@ -165,8 +184,14 @@ class BillingCompany(Model): # Assert that sync clients exists and added configurations are correct assert isinstance(configuration.get_client("billing"), Client) - assert configuration.get_collection_name(BillingAccount, "billing") == configuration.get_config("billing").prefix + "accounts" - assert configuration.get_collection_name(BillingCompany, "billing") == configuration.get_config("billing").prefix + "companies" + assert ( + configuration.get_collection_name(BillingAccount, "billing") + == configuration.get_config("billing").prefix + "accounts" + ) + assert ( + configuration.get_collection_name(BillingCompany, "billing") + == configuration.get_config("billing").prefix + "companies" + ) assert configuration.get_config("billing").prefix == "sync-billing-test-" assert configuration.get_config("billing").project == "sync-billing-test" @@ -174,13 +199,15 @@ class BillingCompany(Model): assert isinstance(configuration.get_client("companies"), Client) - assert configuration.get_collection_name(Company, "companies") == configuration.get_config("companies").prefix + "companies" + assert ( + configuration.get_collection_name(Company, "companies") + == configuration.get_config("companies").prefix + "companies" + ) assert configuration.get_config("companies").prefix == "sync-companies-test-" assert configuration.get_config("companies").project == "sync-companies-test" assert configuration.get_config("companies").name == "companies" - # Create and save billing account and company account = BillingAccount(name="ABC Billing", billing_id=801048, owner="MFisher") bc = BillingCompany(company_id="1234567-8c", billing_account=account) @@ -193,14 +220,21 @@ class BillingCompany(Model): bc.reload() # 3. Finding data - print(f"\nNumber of company owners with first name: 'Alice': {len(Company.find({"owner.first_name": "Alice"}))}") - print(f"\nNumber of billing companies with id: '1234567-8c': {len(BillingCompany.find({"company_id": "1234567-8c"}))}") - print(f"\nNumber of billing accounts with billing_id: 801048: {len(BillingCompany.find({"billing_account.billing_id": 801048}))}") + print( + f"\nNumber of company owners with first name: 'Alice': {len(Company.find({"owner.first_name": "Alice"}))}" + ) + print( + f"\nNumber of billing companies with id: '1234567-8c': {len(BillingCompany.find({"company_id": "1234567-8c"}))}" + ) + print( + f"\nNumber of billing accounts with billing_id: 801048: {len(BillingCompany.find({"billing_account.billing_id": 801048}))}" + ) # now delete everything from the DBs: company.delete_all_for_model() bc.delete_all_for_model() + ## Ensure existing configure function works for backwards compatibility test_old_way() diff --git a/integration_tests/test_integration_all.py b/integration_tests/test_integration_all.py index c132818..2ecae47 100644 --- a/integration_tests/test_integration_all.py +++ b/integration_tests/test_integration_all.py @@ -29,41 +29,54 @@ # require firedantic package (adjust path or install locally) try: # recommended imports - from firedantic.common import IndexField - from firedantic.configurations import configuration, CONFIGURATIONS, configure - from firedantic import ( - Model, + from google.cloud.firestore import Client, Query + from google.cloud.firestore_admin_v1.services.firestore_admin import ( + FirestoreAdminAsyncClient, + FirestoreAdminClient, + ) + + # from google.cloud.firestore import AsyncClient as AsyncClientSyncName # alias to avoid name clash + from google.cloud.firestore_v1 import ( + AsyncClient, + Transaction, + async_transactional, + transactional, + ) + + from firedantic import ( # # IndexField, AsyncModel, - collection_index, - collection_group_index, - ## IndexField, - set_up_composite_indexes_and_ttl_policies, + Model, async_set_up_composite_indexes_and_ttl_policies, + collection_group_index, + collection_index, get_all_subclasses, - get_transaction, get_async_transaction, + get_transaction, + set_up_composite_indexes_and_ttl_policies, ) - from google.cloud.firestore import Client, Query - # from google.cloud.firestore import AsyncClient as AsyncClientSyncName # alias to avoid name clash - from google.cloud.firestore_v1 import AsyncClient, transactional, Transaction, async_transactional - from google.cloud.firestore_admin_v1.services.firestore_admin import FirestoreAdminClient - from google.cloud.firestore_admin_v1.services.firestore_admin import FirestoreAdminAsyncClient + from firedantic.common import IndexField + from firedantic.configurations import CONFIGURATIONS, configuration, configure except Exception as e: - print("Failed to import required modules. Ensure `firedantic` is importable and google-cloud-* libs are installed.") + print( + "Failed to import required modules. Ensure `firedantic` is importable and google-cloud-* libs are installed." + ) raise EMULATOR = os.environ.get("FIRESTORE_EMULATOR_HOST") USE_EMULATOR = bool(EMULATOR) + def must(msg: str): print("ERROR:", msg) sys.exit(2) + def info(msg: str): print(msg) + # ----------------------- # Helpers # ----------------------- @@ -72,6 +85,7 @@ def assert_or_exit(cond: bool, message: str): print("ASSERTION FAILED:", message) raise AssertionError(message) + def cleanup_sync_collection_by_model(model_cls: type[Model]) -> None: """Delete everything from a model's collection (sync).""" info(f"Cleaning collection for sync model {model_cls.__name__}") @@ -79,12 +93,14 @@ def cleanup_sync_collection_by_model(model_cls: type[Model]) -> None: for obj in model_cls.find({}): obj.delete() + async def cleanup_async_collection_by_model(model_cls: type[AsyncModel]) -> None: info(f"Cleaning collection for async model {model_cls.__name__}") items = await model_cls.find({}) for it in items: await it.delete() + # admin helpers def get_admin_client_if_possible(): """Return admin client or None (if emulator/admin not available).""" @@ -100,6 +116,7 @@ def get_admin_client_if_possible(): info(f"FirestoreAdminClient unavailable/skipping admin steps: {e}") return None + async def get_async_admin_client_if_possible(): try: if USE_EMULATOR: @@ -112,6 +129,7 @@ async def get_async_admin_client_if_possible(): info(f"FirestoreAdminAsyncClient unavailable/skipping admin steps: {e}") return None + # ----------------------- # Test cases # ----------------------- @@ -127,9 +145,9 @@ def test_new_config_sync_flow(): configuration.add(name="billing", prefix="billing-", project="billing-project") # get clients - client = configuration.get_client() # sync client for "(default)" + client = configuration.get_client() # sync client for "(default)" # async_client = configuration.get_async_client() # async client for "(default)" - billing_client = configuration.get_client("billing") # sync client for "billing" + billing_client = configuration.get_client("billing") # sync client for "billing" class Billing(Model): __db_config__ = "billing" @@ -153,7 +171,10 @@ class Billing(Model): # find results = Billing.find({"name": "Firedantic Billing"}) assert_or_exit(isinstance(results, list), "find returned not a list") - assert_or_exit(any(r.billing_id == "9901981" for r in results), "find did not return saved billing account") + assert_or_exit( + any(r.billing_id == "9901981" for r in results), + "find did not return saved billing account", + ) # delete b.delete() @@ -175,7 +196,10 @@ def test_legacy_config_sync_flow(): client = Client() configure(client, prefix="legacy-sync-") - assert_or_exit(CONFIGURATIONS["prefix"].startswith("legacy-sync-"), "legacy configure prefix mismatch") + assert_or_exit( + CONFIGURATIONS["prefix"].startswith("legacy-sync-"), + "legacy configure prefix mismatch", + ) # use same model pattern as earlier class LegacyOwner(Model): @@ -201,7 +225,9 @@ class LegacyCompany(Model): async def test_new_config_async_flow(): info("\n=== new Configuration API: async flow ===") mock_creds = Mock(spec=google.auth.credentials.Credentials) - configuration.add(prefix="integ-async-", project="test-project", credentials=mock_creds) + configuration.add( + prefix="integ-async-", project="test-project", credentials=mock_creds + ) class Owner(AsyncModel): first_name: str @@ -225,7 +251,9 @@ class Company(AsyncModel): assert_or_exit(c.owner.first_name == "Alice", "async owner first name mismatch") found = await Company.find({"company_id": "A-1"}) - assert_or_exit(len(found) >= 1 and found[0].company_id == "A-1", "async find failed") + assert_or_exit( + len(found) >= 1 and found[0].company_id == "A-1", "async find failed" + ) await c.delete() remains = await Company.find({"company_id": "A-1"}) @@ -234,7 +262,6 @@ class Company(AsyncModel): print("new config async flow OK") - async def test_legacy_config_async_flow(): info("\n=== legacy configure() async flow ===") mock_creds = Mock(spec=google.auth.credentials.Credentials) @@ -245,7 +272,10 @@ async def test_legacy_config_async_flow(): aclient = AsyncClient() # legacy configure supports AsyncClient and sets CONFIGURATIONS configure(aclient, prefix="legacy-async-") - assert_or_exit(CONFIGURATIONS["prefix"].startswith("legacy-async-"), "legacy async configure failed") + assert_or_exit( + CONFIGURATIONS["prefix"].startswith("legacy-async-"), + "legacy async configure failed", + ) class LOwner(AsyncModel): first_name: str @@ -265,6 +295,7 @@ class LCompany(AsyncModel): await c.delete() info("legacy async flow OK") + def test_subcollections_and_model_for(): info("\n=== subcollections (model_for) ===") # ensure configuration exists @@ -299,7 +330,9 @@ class Collection: s.save() # find back found = StatsCollection.find({"purchases": 3}) - assert_or_exit(any(x.purchases == 3 for x in found), "subcollection save/find failed") + assert_or_exit( + any(x.purchases == 3 for x in found), "subcollection save/find failed" + ) # cleanup for x in StatsCollection.find({}): x.delete() @@ -307,16 +340,17 @@ class Collection: x.delete() info("subcollections OK") + def test_transactions_sync(): info("\n=== sync transactions ===") # mock_creds = Mock(spec=google.auth.credentials.Credentials) - + # Configure once configuration.add( project="firedantic-test", prefix="firedantic-test-", ) - + class City(Model): __collection__ = "cities" population: int = 0 @@ -335,9 +369,7 @@ def _increment_population(transaction: Transaction) -> None: cleanup_sync_collection_by_model(City) @transactional - def decrement_population( - transaction: Transaction, city: City, decrement: int = 1 - ): + def decrement_population(transaction: Transaction, city: City, decrement: int = 1): city.reload(transaction=transaction) city.population = max(0, city.population - decrement) city.save(transaction=transaction) @@ -356,14 +388,15 @@ def decrement_population( # reload c.reload() assert_or_exit(c.population == 2, "sync transaction increment failed") - + # delete c.delete() print("sync transactions OK") + async def test_transactions_async(): info("\n=== async transactions ===") - + # Configure once configuration.add( project="async-tx-test", @@ -398,16 +431,24 @@ async def increment_async(tx, city_id: str, inc: int = 1): await c.delete() print("async transactions OK") + def test_multi_config_usage(): info("\n=== multi-config usage ===") mock_creds = Mock(spec=google.auth.credentials.Credentials) - + # default config - configuration.add(prefix="multi-default-", project="proj-default", credentials=mock_creds) - + configuration.add( + prefix="multi-default-", project="proj-default", credentials=mock_creds + ) + # billing config - configuration.add(name="billing", prefix="multi-billing-", project="proj-billing", credentials=mock_creds) + configuration.add( + name="billing", + prefix="multi-billing-", + project="proj-billing", + credentials=mock_creds, + ) class CompanyDefault(Model): __collection__ = "companies" @@ -438,14 +479,19 @@ class BillingAccount(Model): # find calls (default vs billing) found_default = CompanyDefault.find({"company_id": "MD-1"}) found_billing = BillingAccount.find({"billing_id": "B-1"}) - assert_or_exit(any(x.company_id == "MD-1" for x in found_default), "find default failed") - assert_or_exit(any(x.billing_id == "B-1" for x in found_billing), "find billing failed") + assert_or_exit( + any(x.company_id == "MD-1" for x in found_default), "find default failed" + ) + assert_or_exit( + any(x.billing_id == "B-1" for x in found_billing), "find billing failed" + ) # cleanup cleanup_sync_collection_by_model(CompanyDefault) cleanup_sync_collection_by_model(BillingAccount) info("multi-config usage OK") + def test_indexes_and_ttl_sync(): info("\n=== composite indexes and TTL (sync) ===") @@ -458,22 +504,37 @@ class ExpiringModel(Model): ## uses "(default)" config name __collection__ = "expiringModel" __ttl_field__ = "expire" __composite_indexes__ = [ - collection_index(IndexField("content", Query.ASCENDING), IndexField("expire", Query.DESCENDING)), - collection_group_index(IndexField("content", Query.DESCENDING), IndexField("expire", Query.ASCENDING)), + collection_index( + IndexField("content", Query.ASCENDING), + IndexField("expire", Query.DESCENDING), + ), + collection_group_index( + IndexField("content", Query.DESCENDING), + IndexField("expire", Query.ASCENDING), + ), ] content: str expire: datetime # register config with admin client (best-effort) mock_creds = Mock(spec=google.auth.credentials.Credentials) - configuration.add(prefix="idx-sync-", project="proj-idx", credentials=mock_creds, client=None, async_client=None, admin_client=admin_client) + configuration.add( + prefix="idx-sync-", + project="proj-idx", + credentials=mock_creds, + client=None, + async_client=None, + admin_client=admin_client, + ) try: # call setup function (may require admin permission; emulator may accept) project = configuration.get_config().project if not USE_EMULATOR: # real GCP -> try to set up indexes - set_up_composite_indexes_and_ttl_policies(gcloud_project=project, models=get_all_subclasses(Model)) + set_up_composite_indexes_and_ttl_policies( + gcloud_project=project, models=get_all_subclasses(Model) + ) info("sync index/ttl setup called (no exception)") else: # emulator -> skip (or log) @@ -482,6 +543,7 @@ class ExpiringModel(Model): ## uses "(default)" config name except Exception as e: info(f"index/ttl setup raised (will continue): {e}") + async def test_indexes_and_ttl_async(): info("\n=== composite indexes and TTL (async) ===") admin_client = await get_async_admin_client_if_possible() @@ -493,7 +555,10 @@ class ExpiringAsync(AsyncModel): __collection__ = "expiringModel" __ttl_field__ = "expire" __composite_indexes__ = [ - collection_index(IndexField("content", Query.ASCENDING), IndexField("expire", Query.DESCENDING)), + collection_index( + IndexField("content", Query.ASCENDING), + IndexField("expire", Query.DESCENDING), + ), ] content: str expire: datetime @@ -505,22 +570,27 @@ class ExpiringAsync(AsyncModel): project = configuration.get_config().project if not USE_EMULATOR: # real GCP -> try to set up indexes - await async_set_up_composite_indexes_and_ttl_policies(gcloud_project=project, models=get_all_subclasses(AsyncModel)) + await async_set_up_composite_indexes_and_ttl_policies( + gcloud_project=project, models=get_all_subclasses(AsyncModel) + ) info("async index/ttl setup called (no exception)") else: # emulator -> skip (or log) print("Skipping index/TTL setup when using emulator.") - + except Exception as e: info(f"async index/ttl setup raised (will continue): {e}") + # ----------------------- # Runner # ----------------------- def main(): info("Beginning full integration test following README examples") if not USE_EMULATOR: - info("WARNING: FIRESTORE_EMULATOR_HOST not set — this script is designed for emulator runs; continue with caution.") + info( + "WARNING: FIRESTORE_EMULATOR_HOST not set — this script is designed for emulator runs; continue with caution." + ) # run sync new config flow test_new_config_sync_flow() @@ -529,7 +599,7 @@ def main(): test_legacy_config_sync_flow() # run async new config flow - asyncio.run(test_new_config_async_flow()) ## Currently deletion fails! + asyncio.run(test_new_config_async_flow()) ## Currently deletion fails! # legacy async flow asyncio.run(test_legacy_config_async_flow()) @@ -550,5 +620,6 @@ def main(): info("\nIntegration README full test finished successfully.") + if __name__ == "__main__": main() diff --git a/unasync.py b/unasync.py index ffa7887..d13cabf 100644 --- a/unasync.py +++ b/unasync.py @@ -5,6 +5,8 @@ from pathlib import Path SUBS = [ + ("get_async_client", "get_client"), + ("async_client", "client"), ("get_async_transaction", "get_transaction"), ( "from google.cloud.firestore_v1.async_transaction", @@ -25,6 +27,7 @@ ("async def", "def"), ("async for", "for"), ("async with", "with"), + ("async client", "client"), ("StopAsyncIteration", "StopIteration"), ("__anext__", "__next__"), ("async_truncate_collection", "truncate_collection"), From fa9fb84c95e7dcf6c2a4059998c078198ac067d6 Mon Sep 17 00:00:00 2001 From: Marissa Fisher Date: Mon, 8 Dec 2025 17:30:26 -0300 Subject: [PATCH 31/38] putting suggestion further up. --- README.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index acf3d1b..871efd7 100644 --- a/README.md +++ b/README.md @@ -737,6 +737,10 @@ poetry install poetry run invoke test ``` + +\*Note, when new functions, comments, variables etc. are added that will span across sync and async directories, be sure to first declare the replacements in `unasync.py` before running `poetry run invoke test`. I.e. indicating a replacement of 'async_client' with 'client' text across both directories. + + \*Note, the emulator must be set and running for all tests to pass. ### About sync and async versions of library @@ -749,8 +753,6 @@ version is generated automatically by invoke task: poetry run invoke unasync ``` -If new functions, comments, variables etc. are added that will span across sync and async directories, be sure to first declare the replacements in `unasync.py` before running poetry. I.e. a replacement of 'async_client' with 'client' for the sync directory. - We decided to go this way in order to: - make sure both versions have the same API From 781fb772d2b07f8b6cac02a5ba072c3d3cdfc75e Mon Sep 17 00:00:00 2001 From: Marissa Fisher Date: Tue, 9 Dec 2025 15:16:28 -0300 Subject: [PATCH 32/38] making the delete fail safely, more tests --- README.md | 7 ++- firedantic/_async/model.py | 71 ++++++++++++++++++++-- firedantic/_sync/model.py | 71 ++++++++++++++++++++-- firedantic/tests/tests_async/test_model.py | 60 +++++++++--------- firedantic/tests/tests_sync/test_model.py | 56 +++++++++-------- 5 files changed, 197 insertions(+), 68 deletions(-) diff --git a/README.md b/README.md index 871efd7..e08a2b7 100644 --- a/README.md +++ b/README.md @@ -737,9 +737,10 @@ poetry install poetry run invoke test ``` - -\*Note, when new functions, comments, variables etc. are added that will span across sync and async directories, be sure to first declare the replacements in `unasync.py` before running `poetry run invoke test`. I.e. indicating a replacement of 'async_client' with 'client' text across both directories. - +\*Note, when new functions, comments, variables etc. are added that will span across +sync and async directories, be sure to first declare the replacements in `unasync.py` +before running `poetry run invoke test`. I.e. indicating a replacement of 'async_client' +with 'client' text across both directories. \*Note, the emulator must be set and running for all tests to pass. diff --git a/firedantic/_async/model.py b/firedantic/_async/model.py index 0cdd9b8..994d2b1 100644 --- a/firedantic/_async/model.py +++ b/firedantic/_async/model.py @@ -172,11 +172,74 @@ async def delete(self, transaction: Optional[AsyncTransaction] = None) -> None: """ doc_ref = self._get_doc_ref() + # try to extract client-like objects + doc_client = getattr(doc_ref, "_client", None) or getattr( + doc_ref, "client", None + ) + tx_client = ( + getattr(transaction, "_client", None) + or getattr(transaction, "_client_async", None) + if transaction is not None + else None + ) + if transaction is not None: + # Defensive check: make sure the doc_ref is built from same client as the transaction. + tx_client = getattr(transaction, "_client", None) or getattr( + transaction, "_client_async", None + ) + doc_client = getattr(doc_ref, "_client", None) or getattr( + doc_ref, "client", None + ) + + # If both sides expose client objects, ensure they are same identity. + if ( + tx_client is not None + and doc_client is not None + and tx_client is not doc_client + ): + # Try to rebuild a document reference from the transaction's client using the same path + path = getattr(doc_ref, "path", None) + if path is None: + raise RuntimeError( + "Cannot resolve document path to rebuild doc_ref for transaction." + ) + + # For most firestores clients, client.document(path) works for sync client; + # for async, we try client.document(path) as well (it usually exists). + try: + # prefer a method that accepts full path + alt_doc_ref = None + if hasattr(tx_client, "document"): + alt_doc_ref = tx_client.document(path) + elif hasattr(tx_client, "collection"): + # fallback: split path to collection and doc id + parts = path.split("/") + if len(parts) >= 2: + collection_path = "/".join(parts[:-1]) + doc_id = parts[-1] + alt_doc_ref = tx_client.collection( + collection_path + ).document(doc_id) + if alt_doc_ref is None: + raise RuntimeError( + "Could not rebuild document reference from transaction client." + ) + doc_ref = alt_doc_ref + + except Exception as exc: + raise RuntimeError( + "Document reference was created from a different Firestore client than " + "the provided transaction. Recreate doc_ref from the transaction's client " + "or call delete() without a transaction." + ) from exc + + # schedule delete on the transaction (this will be committed on transaction commit) transaction.delete(doc_ref) - else: - # print(f"\nTransaction not provided, deleting document: {self._get_doc_ref().path}") - await doc_ref.delete() + return + + # no transaction: do a direct delete + await doc_ref.delete() async def reload(self, transaction: Optional[AsyncTransaction] = None) -> None: """ @@ -209,7 +272,7 @@ def get_document_id(self) -> Optional[str]: _OrderBy = List[Tuple[str, OrderDirection]] @classmethod - async def delete_all_for_model(cls, config_name: Optional[str] = None) -> None: + async def delete(cls, config_name: Optional[str] = None) -> None: # Resolve config to use (explicit -> instance -> class -> default) config_name = cls.__db_config__ diff --git a/firedantic/_sync/model.py b/firedantic/_sync/model.py index 62273cd..616138a 100644 --- a/firedantic/_sync/model.py +++ b/firedantic/_sync/model.py @@ -170,11 +170,74 @@ def delete(self, transaction: Optional[Transaction] = None) -> None: """ doc_ref = self._get_doc_ref() + # try to extract client-like objects + doc_client = getattr(doc_ref, "_client", None) or getattr( + doc_ref, "client", None + ) + tx_client = ( + getattr(transaction, "_client", None) + or getattr(transaction, "_client_async", None) + if transaction is not None + else None + ) + if transaction is not None: + # Defensive check: make sure the doc_ref is built from same client as the transaction. + tx_client = getattr(transaction, "_client", None) or getattr( + transaction, "_client_async", None + ) + doc_client = getattr(doc_ref, "_client", None) or getattr( + doc_ref, "client", None + ) + + # If both sides expose client objects, ensure they are same identity. + if ( + tx_client is not None + and doc_client is not None + and tx_client is not doc_client + ): + # Try to rebuild a document reference from the transaction's client using the same path + path = getattr(doc_ref, "path", None) + if path is None: + raise RuntimeError( + "Cannot resolve document path to rebuild doc_ref for transaction." + ) + + # For most firestores clients, client.document(path) works for sync client; + # for async, we try client.document(path) as well (it usually exists). + try: + # prefer a method that accepts full path + alt_doc_ref = None + if hasattr(tx_client, "document"): + alt_doc_ref = tx_client.document(path) + elif hasattr(tx_client, "collection"): + # fallback: split path to collection and doc id + parts = path.split("/") + if len(parts) >= 2: + collection_path = "/".join(parts[:-1]) + doc_id = parts[-1] + alt_doc_ref = tx_client.collection( + collection_path + ).document(doc_id) + if alt_doc_ref is None: + raise RuntimeError( + "Could not rebuild document reference from transaction client." + ) + doc_ref = alt_doc_ref + + except Exception as exc: + raise RuntimeError( + "Document reference was created from a different Firestore client than " + "the provided transaction. Recreate doc_ref from the transaction's client " + "or call delete() without a transaction." + ) from exc + + # schedule delete on the transaction (this will be committed on transaction commit) transaction.delete(doc_ref) - else: - # print(f"\nTransaction not provided, deleting document: {self._get_doc_ref().path}") - doc_ref.delete() + return + + # no transaction: do a direct delete + doc_ref.delete() def reload(self, transaction: Optional[Transaction] = None) -> None: """ @@ -207,7 +270,7 @@ def get_document_id(self) -> Optional[str]: _OrderBy = List[Tuple[str, OrderDirection]] @classmethod - def delete_all_for_model(cls, config_name: Optional[str] = None) -> None: + def delete(cls, config_name: Optional[str] = None) -> None: # Resolve config to use (explicit -> instance -> class -> default) config_name = cls.__db_config__ diff --git a/firedantic/tests/tests_async/test_model.py b/firedantic/tests/tests_async/test_model.py index 78523e7..83fdba5 100644 --- a/firedantic/tests/tests_async/test_model.py +++ b/firedantic/tests/tests_async/test_model.py @@ -8,6 +8,7 @@ import firedantic.operators as op from firedantic import AsyncModel, get_async_transaction +from firedantic.configurations import configuration from firedantic.exceptions import ( CollectionNotDefined, InvalidDocumentID, @@ -45,21 +46,6 @@ async def test_save_model(create_company) -> None: assert company.owner.last_name == "Doe" -@pytest.mark.asyncio -async def test_delete_all_for_model(create_company) -> None: - company: Company = await create_company( - company_id="11223344-5", first_name="Jane", last_name="Doe" - ) - - _id = company.id - assert _id - - await company.delete_all_for_model() - - with pytest.raises(ModelNotFoundError): - await Company.get_by_id(_id) - - @pytest.mark.asyncio async def test_find_one(create_company) -> None: with pytest.raises(ModelNotFoundError): @@ -572,29 +558,43 @@ async def decrement_population( @pytest.mark.asyncio -async def test_delete_in_transaction() -> None: +async def test_delete_in_transaction(create_company): """ - Test deleting a model in a transaction. + Test deleting a Company model within a Firestore transaction. """ + # Create a company + company: Company = await create_company( + company_id="11223344-4", first_name="Joe", last_name="Day" + ) + _id = company.id + assert _id @async_transactional - async def delete_in_transaction( - transaction: AsyncTransaction, profile_id: str - ) -> None: - """Deletes a Profile in a transaction.""" - profile = await Profile.get_by_id(profile_id, transaction=transaction) - await profile.delete(transaction=transaction) - - p = Profile(name="Foo") - await p.save() - assert p.id + async def delete_company(transaction: AsyncTransaction) -> None: + await company.delete() + # Call the transactional function t = get_async_transaction() - async with t: - await delete_in_transaction(t, p.id) + await delete_company(t) + # Outside the transaction, the deletion should now be committed with pytest.raises(ModelNotFoundError): - await Profile.get_by_id(p.id) + await Company.get_by_id(_id) + + +@pytest.mark.asyncio +async def test_delete_model(create_company) -> None: + company: Company = await create_company( + company_id="11223344-5", first_name="Jane", last_name="Doe" + ) + + _id = company.id + assert _id + + await company.delete() + + with pytest.raises(ModelNotFoundError): + await Company.get_by_id(_id) @pytest.mark.asyncio diff --git a/firedantic/tests/tests_sync/test_model.py b/firedantic/tests/tests_sync/test_model.py index abdffce..77031e2 100644 --- a/firedantic/tests/tests_sync/test_model.py +++ b/firedantic/tests/tests_sync/test_model.py @@ -8,6 +8,7 @@ import firedantic.operators as op from firedantic import Model, get_transaction +from firedantic.configurations import configuration from firedantic.exceptions import ( CollectionNotDefined, InvalidDocumentID, @@ -44,20 +45,6 @@ def test_save_model(create_company) -> None: assert company.owner.last_name == "Doe" -def test_delete_all_for_model(create_company) -> None: - company: Company = create_company( - company_id="11223344-5", first_name="Jane", last_name="Doe" - ) - - _id = company.id - assert _id - - company.delete_all_for_model() - - with pytest.raises(ModelNotFoundError): - Company.get_by_id(_id) - - def test_find_one(create_company) -> None: with pytest.raises(ModelNotFoundError): Company.find_one() @@ -533,27 +520,42 @@ def decrement_population(transaction: Transaction, city: City, decrement: int = assert c.population == 0 -def test_delete_in_transaction() -> None: +def test_delete_in_transaction(create_company): """ - Test deleting a model in a transaction. + Test deleting a Company model within a Firestore transaction. """ + # Create a company + company: Company = create_company( + company_id="11223344-4", first_name="Joe", last_name="Day" + ) + _id = company.id + assert _id @transactional - def delete_in_transaction(transaction: Transaction, profile_id: str) -> None: - """Deletes a Profile in a transaction.""" - profile = Profile.get_by_id(profile_id, transaction=transaction) - profile.delete(transaction=transaction) - - p = Profile(name="Foo") - p.save() - assert p.id + def delete_company(transaction: Transaction) -> None: + company.delete() + # Call the transactional function t = get_transaction() - with t: - delete_in_transaction(t, p.id) + delete_company(t) + + # Outside the transaction, the deletion should now be committed + with pytest.raises(ModelNotFoundError): + Company.get_by_id(_id) + + +def test_delete_model(create_company) -> None: + company: Company = create_company( + company_id="11223344-5", first_name="Jane", last_name="Doe" + ) + + _id = company.id + assert _id + + company.delete() with pytest.raises(ModelNotFoundError): - Profile.get_by_id(p.id) + Company.get_by_id(_id) def test_update_model_in_transaction() -> None: From 00dae137969571b25d25e13e84e52b53245d0702 Mon Sep 17 00:00:00 2001 From: Marissa Fisher Date: Tue, 9 Dec 2025 15:29:07 -0300 Subject: [PATCH 33/38] clean up integ tests --- integration_tests/full_async_flow.py | 8 ++++---- integration_tests/full_sync_flow.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/integration_tests/full_async_flow.py b/integration_tests/full_async_flow.py index e15a533..8618897 100644 --- a/integration_tests/full_async_flow.py +++ b/integration_tests/full_async_flow.py @@ -51,7 +51,7 @@ class Company(AsyncModel): ) # Delete everything from the database - await company.delete_all_for_model() + await company.delete() deletion_success = [] == await Company.find({"company_id": "1234567-7"}) if not deletion_success: print(f"\nDeletion of (default) DB failed\n") @@ -115,7 +115,7 @@ class Company(AsyncModel): ) # Delete everything from the database - await company.delete_all_for_model() + await company.delete() deletion_success = [] == await Company.find({"company_id": "1234567-8a"}) if not deletion_success: print(f"\nDeletion of (default) DB failed\n") @@ -241,8 +241,8 @@ class BillingCompany(AsyncModel): ) # Delete everything from the database - await company.delete_all_for_model() - await bc.delete_all_for_model() + await company.delete() + await bc.delete() deletion_success = [] == await Company.find({"company_id": "1234567-8a"}) if not deletion_success: diff --git a/integration_tests/full_sync_flow.py b/integration_tests/full_sync_flow.py index 6410b0b..b62ab1f 100644 --- a/integration_tests/full_sync_flow.py +++ b/integration_tests/full_sync_flow.py @@ -51,7 +51,7 @@ class Company(Model): ) # Delete everything from the database - company.delete_all_for_model() + company.delete() ## With single sync client @@ -111,7 +111,7 @@ class Company(Model): ) # Delete everything from the database - company.delete_all_for_model() + company.delete() # Now with multiple SYNC clients/dbs: @@ -231,8 +231,8 @@ class BillingCompany(Model): ) # now delete everything from the DBs: - company.delete_all_for_model() - bc.delete_all_for_model() + company.delete() + bc.delete() ## Ensure existing configure function works for backwards compatibility From 25735a4413ad7f6f1098d6fab0789df441ab559b Mon Sep 17 00:00:00 2001 From: Marissa Fisher Date: Fri, 12 Dec 2025 18:26:14 -0300 Subject: [PATCH 34/38] update version to 0.13.0 --- firedantic/configurations.py | 12 ++++++------ pyproject.toml | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/firedantic/configurations.py b/firedantic/configurations.py index 28738fa..9663858 100644 --- a/firedantic/configurations.py +++ b/firedantic/configurations.py @@ -25,7 +25,7 @@ CONFIGURATIONS: Dict[str, Any] = {} -def configure(client: Union[Client, AsyncClient], prefix: str = "") -> None: +def configure(db: Union[Client, AsyncClient], prefix: str = "") -> None: """ Legacy helper: updates the module-level `configuration` default entry and preserves the old CONFIGURATIONS mapping so old callers continue to work. @@ -36,19 +36,19 @@ def configure(client: Union[Client, AsyncClient], prefix: str = "") -> None: # soft deprecation notice for users (no stacktrace) warnings.warn( - "firedantic.configure(client, ...) is deprecated and will be removed in a " + "firedantic.configure(db, ...) is deprecated and will be removed in a " "future release. Use firedantic.configurations.configuration.add(...) instead.", DeprecationWarning, stacklevel=2, ) - if isinstance(client, AsyncClient): - configuration.add(name="(default)", prefix=prefix, async_client=client) + if isinstance(db, AsyncClient): + configuration.add(name="(default)", prefix=prefix, async_client=db) else: # treat as sync client - configuration.add(name="(default)", prefix=prefix, client=client) + configuration.add(name="(default)", prefix=prefix, client=db) - CONFIGURATIONS["db"] = client + CONFIGURATIONS["db"] = db CONFIGURATIONS["prefix"] = prefix diff --git a/pyproject.toml b/pyproject.toml index a16227a..a1fb56e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "firedantic" -version = "0.12.0" +version = "0.13.0" description = "Pydantic base models for Firestore" authors = ["IOXIO Ltd"] license = "BSD-3-Clause" From f548cc0a6429cc85344237713ebe55cc95d023d7 Mon Sep 17 00:00:00 2001 From: Marissa Fisher Date: Fri, 12 Dec 2025 18:26:50 -0300 Subject: [PATCH 35/38] changelog update --- CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index aa177d0..8d50958 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,12 @@ this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm ## [Unreleased] +## [0.13.0] - 2025-12-12 + +### Added + +- Added support for multiple firestore clients. + ## [0.12.0] - 2025-11-20 ### Changed From 47ecdbcc3903799567fd86a0a59ed331f0d47ef1 Mon Sep 17 00:00:00 2001 From: Marissa Fisher Date: Thu, 18 Dec 2025 16:44:04 -0300 Subject: [PATCH 36/38] found a bug in delete process, adding more unit tests and codecov --- .gitignore | 2 + firedantic/_async/model.py | 11 +- firedantic/_sync/model.py | 11 +- firedantic/tests/test_clients.py | 424 ++++++++++++++++++++++++++ firedantic/tests/test_mocked_model.py | 154 ++++++++++ poetry.lock | 132 +++++++- pyproject.toml | 1 + 7 files changed, 727 insertions(+), 8 deletions(-) create mode 100644 firedantic/tests/test_clients.py create mode 100644 firedantic/tests/test_mocked_model.py diff --git a/.gitignore b/.gitignore index 1f28392..f013835 100644 --- a/.gitignore +++ b/.gitignore @@ -83,6 +83,8 @@ ipython_config.py # pyenv .python-version +pyvenv.cfg +/bin # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. diff --git a/firedantic/_async/model.py b/firedantic/_async/model.py index 994d2b1..be48d6d 100644 --- a/firedantic/_async/model.py +++ b/firedantic/_async/model.py @@ -166,7 +166,7 @@ async def save( async def delete(self, transaction: Optional[AsyncTransaction] = None) -> None: """ - Deletes this model from the database. + Deletes this specific model instance from the database. :raise DocumentIDError: If the ID is not valid. """ @@ -272,7 +272,10 @@ def get_document_id(self) -> Optional[str]: _OrderBy = List[Tuple[str, OrderDirection]] @classmethod - async def delete(cls, config_name: Optional[str] = None) -> None: + async def delete_all(cls, config_name: Optional[str] = None) -> None: + """ + Deletes all models of this type from the database. + """ # Resolve config to use (explicit -> instance -> class -> default) config_name = cls.__db_config__ @@ -463,7 +466,9 @@ def _get_doc_ref( :raise DocumentIDError: If the ID is not valid. """ - return self._get_col_ref(config_name).document(self.get_document_id()) # type: ignore + # _get_col_ref takes collection_name, not config_name. + # Any specific config usage should likely be handled by context or passed properly if supported. + return self._get_col_ref().document(self.get_document_id()) # type: ignore @staticmethod def _validate_document_id(document_id: str): diff --git a/firedantic/_sync/model.py b/firedantic/_sync/model.py index 616138a..90c995a 100644 --- a/firedantic/_sync/model.py +++ b/firedantic/_sync/model.py @@ -164,7 +164,7 @@ def save( def delete(self, transaction: Optional[Transaction] = None) -> None: """ - Deletes this model from the database. + Deletes this specific model instance from the database. :raise DocumentIDError: If the ID is not valid. """ @@ -270,7 +270,10 @@ def get_document_id(self) -> Optional[str]: _OrderBy = List[Tuple[str, OrderDirection]] @classmethod - def delete(cls, config_name: Optional[str] = None) -> None: + def delete_all(cls, config_name: Optional[str] = None) -> None: + """ + Deletes all models of this type from the database. + """ # Resolve config to use (explicit -> instance -> class -> default) config_name = cls.__db_config__ @@ -457,7 +460,9 @@ def _get_doc_ref( :raise DocumentIDError: If the ID is not valid. """ - return self._get_col_ref(config_name).document(self.get_document_id()) # type: ignore + # _get_col_ref takes collection_name, not config_name. + # Any specific config usage should likely be handled by context or passed properly if supported. + return self._get_col_ref().document(self.get_document_id()) # type: ignore @staticmethod def _validate_document_id(document_id: str): diff --git a/firedantic/tests/test_clients.py b/firedantic/tests/test_clients.py new file mode 100644 index 0000000..e5ac91b --- /dev/null +++ b/firedantic/tests/test_clients.py @@ -0,0 +1,424 @@ +from os import environ +from unittest.mock import Mock + +import google.auth.credentials +import pytest +from google.cloud.firestore import AsyncClient, Client + +import firedantic.configurations as cfg_module +from firedantic import Configuration, configure + + +# --- tiny fake client classes to capture construction args --- +class FakeClient: + def __init__(self, *args, **kwargs): + self._constructed_args = args + self._constructed_kwargs = kwargs + + +class FakeAsyncClient: + def __init__(self, *args, **kwargs): + self._constructed_args = args + self._constructed_kwargs = kwargs + + +# --- Tests --------------------------------------------------------- +def test_get_client_lazy_creation_and_caching(monkeypatch): + """ + get_client() should create a Client lazily (on first call) and then cache it. + """ + monkeypatch.setattr(cfg_module, "Client", FakeClient) + + cfg = Configuration() + # register a config with a project so that lazy creation is allowed + item = cfg.add(name="x", project="proj-x", prefix="px-") + + # no client yet + assert item.client is None + + # first call -> construct and cache + client1 = cfg.get_client("x") + assert isinstance(client1, FakeClient) + assert cfg.get_config("x").client is client1 + + # subsequent call -> same instance + client2 = cfg.get_client("x") + assert client2 is client1 + + # verify constructor received our project in kwargs (if implementation provides it) + # accept either 'project' in kwargs or stored on item.project + assert cfg.get_config("x").project == "proj-x" + + +def test_get_async_client_lazy_creation_and_caching(monkeypatch): + """ + get_async_client() should create an AsyncClient lazily and cache it. + """ + monkeypatch.setattr(cfg_module, "AsyncClient", FakeAsyncClient) + + cfg = Configuration() + item = cfg.add(name="y", project="proj-y", prefix="py-") + + assert item.async_client is None + + async_client1 = cfg.get_async_client("y") + assert isinstance(async_client1, FakeAsyncClient) + assert cfg.get_config("y").async_client is async_client1 + + async_client2 = cfg.get_async_client("y") + assert async_client2 is async_client1 + + +def test_preserve_prebuilt_clients(): + """ + If user supplies a prebuilt client / async_client to add(), the registry uses exactly them. + """ + cfg = Configuration() + + prebuilt_sync = FakeClient() + prebuilt_async = FakeAsyncClient() + + cfg.add( + name="prebuilt", + project="proj-pre", + prefix="pp-", + client=prebuilt_sync, + async_client=prebuilt_async, + ) + + # get_client/get_async_client should return the exact instances given + got_sync = cfg.get_client("prebuilt") + got_async = cfg.get_async_client("prebuilt") + assert got_sync is prebuilt_sync + assert got_async is prebuilt_async + assert cfg.get_config("prebuilt").client is prebuilt_sync + assert cfg.get_config("prebuilt").async_client is prebuilt_async + + +def test_get_client_no_project_raises(monkeypatch): + """ + If no project was supplied and no client was given, get_client() should raise a clear error. + """ + # ensure environment does NOT provide a project fallback + monkeypatch.delenv("GOOGLE_CLOUD_PROJECT", raising=False) + + # add a config without project and without client + cfg = Configuration() + cfg.add(name="noproj", project=None, prefix="np-") + + with pytest.raises(RuntimeError): + cfg.get_client("noproj") + + with pytest.raises(RuntimeError): + cfg.get_async_client("noproj") + + +def test_multiple_clients_are_independent(monkeypatch): + """ + Ensure default and billing configs return distinct client objects and preserve stored metadata. + """ + monkeypatch.setattr(cfg_module, "Client", FakeClient) + monkeypatch.setattr(cfg_module, "AsyncClient", FakeAsyncClient) + creds = Mock(spec=google.auth.credentials.Credentials) + + cfg = Configuration() + cfg.add(prefix="firedantic-test-", project="proj-default", credentials=creds) + cfg.add( + name="billing", prefix="billing-", project="proj-billing", credentials=creds + ) + + # clients for "(default)"" config + default_sync = cfg.get_client() + default_async = cfg.get_async_client() + assert isinstance(default_sync, FakeClient) + assert isinstance(default_async, FakeAsyncClient) + + # clients for billing config + billing_sync = cfg.get_client("billing") + billing_async = cfg.get_async_client("billing") + assert isinstance(billing_sync, FakeClient) + assert isinstance(billing_async, FakeAsyncClient) + + # Ensure the four objects are distinct (sync vs async and default vs billing) + assert default_sync is not billing_sync + assert default_async is not billing_async + assert default_sync is not default_async + assert billing_sync is not billing_async + + # verify metadata preserved + assert cfg.get_config().prefix == "firedantic-test-" + assert cfg.get_config().project == "proj-default" + assert cfg.get_config("billing").prefix == "billing-" + assert cfg.get_config("billing").project == "proj-billing" + + +def test___getitem___and_contains_behavior(monkeypatch): + """ + Ensure dict-like accessors work: __getitem__ and __contains__. + """ + monkeypatch.setattr(cfg_module, "Client", FakeClient) + cfg = Configuration() + cfg.add(name="abc", project="p-abc", prefix="pa-") + assert "abc" in cfg + assert "(default)" in cfg + assert cfg["abc"].project == "p-abc" + assert cfg["(default)"].prefix == "" + + +def test_configure_client(): + + project = "firedantic-test" + prefix = "firedantic-test-" + creds = Mock(spec=google.auth.credentials.Credentials) + + if not environ.get("FIRESTORE_EMULATOR_HOST"): + raise "Firestore emulator must be running" + + config = Configuration() + config.add(prefix=prefix, project=project, credentials=creds) + + assert config["(default)"].project == project + assert config["(default)"].prefix == prefix + client = config.get_client() + assert isinstance(client, Client) + + +### Test sync/async client classes ### + + +def test_configure_async_client(): + # Firestore emulator must be running if using locally. + project = "firedantic-test" + prefix = "firedantic-test-" + creds = Mock(spec=google.auth.credentials.Credentials) + + if environ.get("FIRESTORE_EMULATOR_HOST"): + client = AsyncClient( + project=project, + credentials=creds, + ) + else: + raise "Firestore emulator must be running" + + config = Configuration() + config.add(prefix=prefix, project=project, credentials=creds) + + assert config["(default)"].project == project + assert config["(default)"].prefix == prefix + asyncClient = config.get_async_client() + assert isinstance(asyncClient, AsyncClient) + + +def test_configure_multiple_clients(): + config = Configuration() + mock_creds = Mock(spec=google.auth.credentials.Credentials) + + # name = (default) + config.add( + prefix="firedantic-test-", project="firedantic-test", credentials=mock_creds + ) + + # name = billing + config.add( + name="billing", + prefix="test-billing-", + project="test-billing", + credentials=mock_creds, + ) + + # assert that all 4 clients are unique + client = config.get_client() + assert isinstance(client, Client) + + asyncClient = config.get_async_client() + assert isinstance(asyncClient, AsyncClient) + + billingClient = config.get_client("billing") + assert isinstance(client, Client) + + billingAsyncClient = config.get_async_client("billing") + assert isinstance(asyncClient, AsyncClient) + + assert client != asyncClient != billingClient != billingAsyncClient + + # assert that CONFIGURATIONS holds two different client configs + assert config["(default)"].prefix == "firedantic-test-" + assert config["(default)"].project == "firedantic-test" + assert config["(default)"].credentials == mock_creds + + assert config["billing"].prefix == "test-billing-" + assert config["billing"].project == "test-billing" + assert config["billing"].credentials == mock_creds + + +### Test admin client classes ### + + +# helper fake admin client classes to capture construction arguments +class FakeAdminClient: + def __init__(self, *args, **kwargs): + self._constructed_args = args + self._constructed_kwargs = kwargs + + +class FakeAsyncAdminClient: + def __init__(self, *args, **kwargs): + self._constructed_args = args + self._constructed_kwargs = kwargs + + +def test_get_admin_client_lazy_creation(monkeypatch): + """ + get_admin_client() should create a FirestoreAdminClient lazily and cache it. + """ + cfg = Configuration() + + # monkeypatch the admin client class used inside the module + monkeypatch.setattr( + "firedantic.configurations.FirestoreAdminClient", FakeAdminClient + ) + + # add a config that does not supply admin_client upfront + cfg.add(name="x", project="proj-a", database="(default)", prefix="p-") + + # initially there should be no admin client stored on the ConfigItem + item = cfg.get_config("x") + assert item.admin_client is None + + # first call should create it + admin = cfg.get_admin_client("x") + assert isinstance(admin, FakeAdminClient) + assert cfg.get_config("x").admin_client is admin # cached + + # subsequent call returns same instance (cached) + admin2 = cfg.get_admin_client("x") + assert admin2 is admin + + +def test_get_async_admin_client_lazy_creation(monkeypatch): + """ + get_async_admin_client() should create an async admin client lazily and cache it. + """ + cfg = Configuration() + + monkeypatch.setattr( + "firedantic.configurations.FirestoreAdminAsyncClient", FakeAsyncAdminClient + ) + + cfg.add(name="y", project="proj-b", database="billing", prefix="pb-") + + item = cfg.get_config("y") + assert item.async_admin_client is None + + async_admin = cfg.get_async_admin_client("y") + assert isinstance(async_admin, FakeAsyncAdminClient) + assert cfg.get_config("y").async_admin_client is async_admin + + # cached + assert cfg.get_async_admin_client("y") is async_admin + + +def test_preserve_prebuilt_admin_client(): + """ + If the user supplies an admin_client in add(), the registry should use it verbatim. + """ + cfg = Configuration() + + # create a prebuilt fake admin client instance + prebuilt = FakeAdminClient() + cfg.add( + name="prebuilt", + project="proj-pre", + database="(default)", + prefix="pp-", + admin_client=prebuilt, + ) + + # get_admin_client should return the exact object we passed + got = cfg.get_admin_client("prebuilt") + assert got is prebuilt + assert cfg.get_config("prebuilt").admin_client is prebuilt + + +def test_admin_client_creation_receives_transport_and_client_options(monkeypatch): + """ + Ensure the admin-client constructor receives client_options and transport passed into add(). + """ + cfg = Configuration() + + captured = {} + + # define a fake constructor that records kwargs + def fake_constructor(*args, **kwargs): + inst = FakeAdminClient(*args, **kwargs) + captured["kwargs"] = kwargs + captured["args"] = args + return inst + + monkeypatch.setattr( + "firedantic.configurations.FirestoreAdminClient", fake_constructor + ) + + # pass some admin_transport and client_options into add() + fake_transport = object() + fake_client_options = {"api_endpoint": "test-endpoint"} + fake_client_info = Mock() + + cfg.add( + name="z", + project="proj-z", + database="(default)", + prefix="pz-", + credentials=None, + admin_transport=fake_transport, + client_options=fake_client_options, + client_info=fake_client_info, + ) + + # creating admin client should call fake_constructor with the supplied values (as kwargs) + admin = cfg.get_admin_client("z") + assert isinstance(admin, FakeAdminClient) + assert "kwargs" in captured + # check the captured kwargs include our transport + client_options and a client_info (or default) + assert captured["kwargs"].get("transport") is fake_transport + assert captured["kwargs"].get("client_options") == fake_client_options + assert "client_info" in captured["kwargs"] + + +def test_get_admin_client_unknown_config_raises(): + cfg = Configuration() + with pytest.raises(KeyError): + cfg.get_admin_client("does-not-exist") + + +def test_legacy_configure_shim_creates_config(monkeypatch): + from firedantic.configurations import CONFIGURATIONS + from firedantic.configurations import Client as RealClient # monkeypatch below + from firedantic.configurations import configuration, configure + + # fake client class to avoid network + class FakeClient: + pass + + monkeypatch.setattr("firedantic.configurations.Client", FakeClient) + + fake = FakeClient() + configure(fake, prefix="legacy-") + cfg = configuration.get_config("(default)") + assert cfg.prefix == "legacy-" + assert CONFIGURATIONS["prefix"] == "legacy-" + assert CONFIGURATIONS["db"] is fake + + +def test_legacy_configure_warns(monkeypatch): + import warnings + + from firedantic.configurations import configure + + class FakeClient: + pass + + with warnings.catch_warnings(record=True) as rec: + warnings.simplefilter("always") + configure(FakeClient(), prefix="x") + assert any(issubclass(w.category, DeprecationWarning) for w in rec) diff --git a/firedantic/tests/test_mocked_model.py b/firedantic/tests/test_mocked_model.py new file mode 100644 index 0000000..64bd1b3 --- /dev/null +++ b/firedantic/tests/test_mocked_model.py @@ -0,0 +1,154 @@ +from unittest.mock import AsyncMock, MagicMock + +import pytest +from google.cloud.firestore import AsyncClient + +from firedantic import AsyncModel, configure +from firedantic.configurations import CONFIGURATIONS + + +@pytest.fixture +def MockModelClass(): + class TestModel(AsyncModel): + __collection__ = "mock_models" + name: str + age: int = 30 + + return TestModel + + +@pytest.fixture +def mock_client(): + client = AsyncMock(spec=AsyncClient) + # Reset configurations to ensure isolation + CONFIGURATIONS.clear() + configure(client, prefix="test_") + return client + + +@pytest.mark.asyncio +async def test_save_new_model(mock_client, MockModelClass): + collection_mock = MagicMock() + document_mock = MagicMock() + + mock_client.collection.return_value = collection_mock + collection_mock.document.return_value = document_mock + + document_mock.id = "generated-id" + # Ensure set/update are async + document_mock.set = AsyncMock() + document_mock.update = AsyncMock() + + model = MockModelClass(name="Alice") + await model.save() + + mock_client.collection.assert_called_with("test_mock_models") + assert collection_mock.document.called + document_mock.set.assert_called_once() + assert model.id == "generated-id" + + +@pytest.mark.asyncio +async def test_save_update_model(mock_client, MockModelClass): + collection_mock = MagicMock() + document_mock = MagicMock() + + mock_client.collection.return_value = collection_mock + collection_mock.document.return_value = document_mock + document_mock.set = AsyncMock() + document_mock.update = AsyncMock() + + model = MockModelClass(id="existing-id", name="Bob") + await model.save() + + mock_client.collection.assert_called_with("test_mock_models") + collection_mock.document.assert_called_with("existing-id") + # It might use update or set. + assert document_mock.update.called or document_mock.set.called + + +@pytest.mark.asyncio +async def test_delete_model(mock_client, MockModelClass): + collection_mock = MagicMock() + document_mock = MagicMock() + + mock_client.collection.return_value = collection_mock + collection_mock.document.return_value = document_mock + document_mock.delete = AsyncMock() + + # Verify we DO NOT call stream (which would indicate bulk delete) + collection_mock.stream = MagicMock( + side_effect=Exception("Should not stream collection for single delete!") + ) + + model = MockModelClass(id="to-delete", name="Charlie") + await model.delete() + + mock_client.collection.assert_called_with("test_mock_models") + collection_mock.document.assert_called_with("to-delete") + document_mock.delete.assert_called_once() + + +@pytest.mark.asyncio +async def test_delete_all(mock_client, MockModelClass): + collection_mock = MagicMock() + document_mock = MagicMock() + + mock_client.collection.return_value = collection_mock + + mock_doc = MagicMock() + mock_doc.reference = document_mock + document_mock.delete = AsyncMock() + + async def async_stream(): + yield mock_doc + yield mock_doc + + collection_mock.stream.side_effect = async_stream + + await MockModelClass.delete_all() + + assert collection_mock.stream.called + assert document_mock.delete.call_count == 2 + + +@pytest.mark.asyncio +async def test_find(mock_client, MockModelClass): + collection_mock = MagicMock() + + mock_client.collection.return_value = collection_mock + + # Mocking the query chain: .where().where().stream() + # If find() calls where(), we need to return the request (or new query object) to allow chaining. + # The final object calls stream(). + + # We can use a recursive magic mock for chaining or set return_value to self. + query_mock = MagicMock() + collection_mock.where.return_value = query_mock + query_mock.where.return_value = query_mock # For multiple wheres + query_mock.order_by.return_value = query_mock + query_mock.limit.return_value = query_mock + query_mock.offset.return_value = query_mock + + # Mock stream to return an async generator + mock_snapshot = MagicMock() + mock_snapshot.id = "doc-1" + mock_snapshot.to_dict.return_value = {"name": "Dave", "age": 40} + + async def async_stream(): + yield mock_snapshot + + query_mock.stream.return_value = async_stream() + + # In case it calls stream() directly on collection (no filters) + collection_mock.stream.side_effect = async_stream + + # Case 1: find({}) -> no filters, might call stream on collection mock directly or via "select" etc. + # Assuming firedantic calls collection(...).stream() for empty query, + # or collection(...).where(..).stream() + + results = await MockModelClass.find({"name": "Dave"}) + + assert len(results) == 1 + assert results[0].name == "Dave" + assert results[0].id == "doc-1" diff --git a/poetry.lock b/poetry.lock index 3b6922f..9127c69 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.4 and should not be changed by hand. [[package]] name = "annotated-types" @@ -126,6 +126,114 @@ files = [ {file = "colorama-0.4.4.tar.gz", hash = "sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b"}, ] +[[package]] +name = "coverage" +version = "7.13.0" +description = "Code coverage measurement for Python" +optional = false +python-versions = ">=3.10" +groups = ["dev"] +files = [ + {file = "coverage-7.13.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:02d9fb9eccd48f6843c98a37bd6817462f130b86da8660461e8f5e54d4c06070"}, + {file = "coverage-7.13.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:367449cf07d33dc216c083f2036bb7d976c6e4903ab31be400ad74ad9f85ce98"}, + {file = "coverage-7.13.0-cp310-cp310-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:cdb3c9f8fef0a954c632f64328a3935988d33a6604ce4bf67ec3e39670f12ae5"}, + {file = "coverage-7.13.0-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:d10fd186aac2316f9bbb46ef91977f9d394ded67050ad6d84d94ed6ea2e8e54e"}, + {file = "coverage-7.13.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7f88ae3e69df2ab62fb0bc5219a597cb890ba5c438190ffa87490b315190bb33"}, + {file = "coverage-7.13.0-cp310-cp310-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:c4be718e51e86f553bcf515305a158a1cd180d23b72f07ae76d6017c3cc5d791"}, + {file = "coverage-7.13.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a00d3a393207ae12f7c49bb1c113190883b500f48979abb118d8b72b8c95c032"}, + {file = "coverage-7.13.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:3a7b1cd820e1b6116f92c6128f1188e7afe421c7e1b35fa9836b11444e53ebd9"}, + {file = "coverage-7.13.0-cp310-cp310-musllinux_1_2_riscv64.whl", hash = "sha256:37eee4e552a65866f15dedd917d5e5f3d59805994260720821e2c1b51ac3248f"}, + {file = "coverage-7.13.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:62d7c4f13102148c78d7353c6052af6d899a7f6df66a32bddcc0c0eb7c5326f8"}, + {file = "coverage-7.13.0-cp310-cp310-win32.whl", hash = "sha256:24e4e56304fdb56f96f80eabf840eab043b3afea9348b88be680ec5986780a0f"}, + {file = "coverage-7.13.0-cp310-cp310-win_amd64.whl", hash = "sha256:74c136e4093627cf04b26a35dab8cbfc9b37c647f0502fc313376e11726ba303"}, + {file = "coverage-7.13.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0dfa3855031070058add1a59fdfda0192fd3e8f97e7c81de0596c145dea51820"}, + {file = "coverage-7.13.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4fdb6f54f38e334db97f72fa0c701e66d8479af0bc3f9bfb5b90f1c30f54500f"}, + {file = "coverage-7.13.0-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:7e442c013447d1d8d195be62852270b78b6e255b79b8675bad8479641e21fd96"}, + {file = "coverage-7.13.0-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:1ed5630d946859de835a85e9a43b721123a8a44ec26e2830b296d478c7fd4259"}, + {file = "coverage-7.13.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7f15a931a668e58087bc39d05d2b4bf4b14ff2875b49c994bbdb1c2217a8daeb"}, + {file = "coverage-7.13.0-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:30a3a201a127ea57f7e14ba43c93c9c4be8b7d17a26e03bb49e6966d019eede9"}, + {file = "coverage-7.13.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:7a485ff48fbd231efa32d58f479befce52dcb6bfb2a88bb7bf9a0b89b1bc8030"}, + {file = "coverage-7.13.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:22486cdafba4f9e471c816a2a5745337742a617fef68e890d8baf9f3036d7833"}, + {file = "coverage-7.13.0-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:263c3dbccc78e2e331e59e90115941b5f53e85cfcc6b3b2fbff1fd4e3d2c6ea8"}, + {file = "coverage-7.13.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:e5330fa0cc1f5c3c4c3bb8e101b742025933e7848989370a1d4c8c5e401ea753"}, + {file = "coverage-7.13.0-cp311-cp311-win32.whl", hash = "sha256:0f4872f5d6c54419c94c25dd6ae1d015deeb337d06e448cd890a1e89a8ee7f3b"}, + {file = "coverage-7.13.0-cp311-cp311-win_amd64.whl", hash = "sha256:51a202e0f80f241ccb68e3e26e19ab5b3bf0f813314f2c967642f13ebcf1ddfe"}, + {file = "coverage-7.13.0-cp311-cp311-win_arm64.whl", hash = "sha256:d2a9d7f1c11487b1c69367ab3ac2d81b9b3721f097aa409a3191c3e90f8f3dd7"}, + {file = "coverage-7.13.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:0b3d67d31383c4c68e19a88e28fc4c2e29517580f1b0ebec4a069d502ce1e0bf"}, + {file = "coverage-7.13.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:581f086833d24a22c89ae0fe2142cfaa1c92c930adf637ddf122d55083fb5a0f"}, + {file = "coverage-7.13.0-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:0a3a30f0e257df382f5f9534d4ce3d4cf06eafaf5192beb1a7bd066cb10e78fb"}, + {file = "coverage-7.13.0-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:583221913fbc8f53b88c42e8dbb8fca1d0f2e597cb190ce45916662b8b9d9621"}, + {file = "coverage-7.13.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5f5d9bd30756fff3e7216491a0d6d520c448d5124d3d8e8f56446d6412499e74"}, + {file = "coverage-7.13.0-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:a23e5a1f8b982d56fa64f8e442e037f6ce29322f1f9e6c2344cd9e9f4407ee57"}, + {file = "coverage-7.13.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:9b01c22bc74a7fb44066aaf765224c0d933ddf1f5047d6cdfe4795504a4493f8"}, + {file = "coverage-7.13.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:898cce66d0836973f48dda4e3514d863d70142bdf6dfab932b9b6a90ea5b222d"}, + {file = "coverage-7.13.0-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:3ab483ea0e251b5790c2aac03acde31bff0c736bf8a86829b89382b407cd1c3b"}, + {file = "coverage-7.13.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:1d84e91521c5e4cb6602fe11ece3e1de03b2760e14ae4fcf1a4b56fa3c801fcd"}, + {file = "coverage-7.13.0-cp312-cp312-win32.whl", hash = "sha256:193c3887285eec1dbdb3f2bd7fbc351d570ca9c02ca756c3afbc71b3c98af6ef"}, + {file = "coverage-7.13.0-cp312-cp312-win_amd64.whl", hash = "sha256:4f3e223b2b2db5e0db0c2b97286aba0036ca000f06aca9b12112eaa9af3d92ae"}, + {file = "coverage-7.13.0-cp312-cp312-win_arm64.whl", hash = "sha256:086cede306d96202e15a4b77ace8472e39d9f4e5f9fd92dd4fecdfb2313b2080"}, + {file = "coverage-7.13.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:28ee1c96109974af104028a8ef57cec21447d42d0e937c0275329272e370ebcf"}, + {file = "coverage-7.13.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:d1e97353dcc5587b85986cda4ff3ec98081d7e84dd95e8b2a6d59820f0545f8a"}, + {file = "coverage-7.13.0-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:99acd4dfdfeb58e1937629eb1ab6ab0899b131f183ee5f23e0b5da5cba2fec74"}, + {file = "coverage-7.13.0-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:ff45e0cd8451e293b63ced93161e189780baf444119391b3e7d25315060368a6"}, + {file = "coverage-7.13.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f4f72a85316d8e13234cafe0a9f81b40418ad7a082792fa4165bd7d45d96066b"}, + {file = "coverage-7.13.0-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:11c21557d0e0a5a38632cbbaca5f008723b26a89d70db6315523df6df77d6232"}, + {file = "coverage-7.13.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:76541dc8d53715fb4f7a3a06b34b0dc6846e3c69bc6204c55653a85dd6220971"}, + {file = "coverage-7.13.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:6e9e451dee940a86789134b6b0ffbe31c454ade3b849bb8a9d2cca2541a8e91d"}, + {file = "coverage-7.13.0-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:5c67dace46f361125e6b9cace8fe0b729ed8479f47e70c89b838d319375c8137"}, + {file = "coverage-7.13.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f59883c643cb19630500f57016f76cfdcd6845ca8c5b5ea1f6e17f74c8e5f511"}, + {file = "coverage-7.13.0-cp313-cp313-win32.whl", hash = "sha256:58632b187be6f0be500f553be41e277712baa278147ecb7559983c6d9faf7ae1"}, + {file = "coverage-7.13.0-cp313-cp313-win_amd64.whl", hash = "sha256:73419b89f812f498aca53f757dd834919b48ce4799f9d5cad33ca0ae442bdb1a"}, + {file = "coverage-7.13.0-cp313-cp313-win_arm64.whl", hash = "sha256:eb76670874fdd6091eedcc856128ee48c41a9bbbb9c3f1c7c3cf169290e3ffd6"}, + {file = "coverage-7.13.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:6e63ccc6e0ad8986386461c3c4b737540f20426e7ec932f42e030320896c311a"}, + {file = "coverage-7.13.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:494f5459ffa1bd45e18558cd98710c36c0b8fbfa82a5eabcbe671d80ecffbfe8"}, + {file = "coverage-7.13.0-cp313-cp313t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:06cac81bf10f74034e055e903f5f946e3e26fc51c09fc9f584e4a1605d977053"}, + {file = "coverage-7.13.0-cp313-cp313t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:f2ffc92b46ed6e6760f1d47a71e56b5664781bc68986dbd1836b2b70c0ce2071"}, + {file = "coverage-7.13.0-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0602f701057c6823e5db1b74530ce85f17c3c5be5c85fc042ac939cbd909426e"}, + {file = "coverage-7.13.0-cp313-cp313t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:25dc33618d45456ccb1d37bce44bc78cf269909aa14c4db2e03d63146a8a1493"}, + {file = "coverage-7.13.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:71936a8b3b977ddd0b694c28c6a34f4fff2e9dd201969a4ff5d5fc7742d614b0"}, + {file = "coverage-7.13.0-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:936bc20503ce24770c71938d1369461f0c5320830800933bc3956e2a4ded930e"}, + {file = "coverage-7.13.0-cp313-cp313t-musllinux_1_2_riscv64.whl", hash = "sha256:af0a583efaacc52ae2521f8d7910aff65cdb093091d76291ac5820d5e947fc1c"}, + {file = "coverage-7.13.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:f1c23e24a7000da892a312fb17e33c5f94f8b001de44b7cf8ba2e36fbd15859e"}, + {file = "coverage-7.13.0-cp313-cp313t-win32.whl", hash = "sha256:5f8a0297355e652001015e93be345ee54393e45dc3050af4a0475c5a2b767d46"}, + {file = "coverage-7.13.0-cp313-cp313t-win_amd64.whl", hash = "sha256:6abb3a4c52f05e08460bd9acf04fec027f8718ecaa0d09c40ffbc3fbd70ecc39"}, + {file = "coverage-7.13.0-cp313-cp313t-win_arm64.whl", hash = "sha256:3ad968d1e3aa6ce5be295ab5fe3ae1bf5bb4769d0f98a80a0252d543a2ef2e9e"}, + {file = "coverage-7.13.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:453b7ec753cf5e4356e14fe858064e5520c460d3bbbcb9c35e55c0d21155c256"}, + {file = "coverage-7.13.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:af827b7cbb303e1befa6c4f94fd2bf72f108089cfa0f8abab8f4ca553cf5ca5a"}, + {file = "coverage-7.13.0-cp314-cp314-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:9987a9e4f8197a1000280f7cc089e3ea2c8b3c0a64d750537809879a7b4ceaf9"}, + {file = "coverage-7.13.0-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:3188936845cd0cb114fa6a51842a304cdbac2958145d03be2377ec41eb285d19"}, + {file = "coverage-7.13.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a2bdb3babb74079f021696cb46b8bb5f5661165c385d3a238712b031a12355be"}, + {file = "coverage-7.13.0-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:7464663eaca6adba4175f6c19354feea61ebbdd735563a03d1e472c7072d27bb"}, + {file = "coverage-7.13.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:8069e831f205d2ff1f3d355e82f511eb7c5522d7d413f5db5756b772ec8697f8"}, + {file = "coverage-7.13.0-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:6fb2d5d272341565f08e962cce14cdf843a08ac43bd621783527adb06b089c4b"}, + {file = "coverage-7.13.0-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:5e70f92ef89bac1ac8a99b3324923b4749f008fdbd7aa9cb35e01d7a284a04f9"}, + {file = "coverage-7.13.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:4b5de7d4583e60d5fd246dd57fcd3a8aa23c6e118a8c72b38adf666ba8e7e927"}, + {file = "coverage-7.13.0-cp314-cp314-win32.whl", hash = "sha256:a6c6e16b663be828a8f0b6c5027d36471d4a9f90d28444aa4ced4d48d7d6ae8f"}, + {file = "coverage-7.13.0-cp314-cp314-win_amd64.whl", hash = "sha256:0900872f2fdb3ee5646b557918d02279dc3af3dfb39029ac4e945458b13f73bc"}, + {file = "coverage-7.13.0-cp314-cp314-win_arm64.whl", hash = "sha256:3a10260e6a152e5f03f26db4a407c4c62d3830b9af9b7c0450b183615f05d43b"}, + {file = "coverage-7.13.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:9097818b6cc1cfb5f174e3263eba4a62a17683bcfe5c4b5d07f4c97fa51fbf28"}, + {file = "coverage-7.13.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:0018f73dfb4301a89292c73be6ba5f58722ff79f51593352759c1790ded1cabe"}, + {file = "coverage-7.13.0-cp314-cp314t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:166ad2a22ee770f5656e1257703139d3533b4a0b6909af67c6b4a3adc1c98657"}, + {file = "coverage-7.13.0-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:f6aaef16d65d1787280943f1c8718dc32e9cf141014e4634d64446702d26e0ff"}, + {file = "coverage-7.13.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e999e2dcc094002d6e2c7bbc1fb85b58ba4f465a760a8014d97619330cdbbbf3"}, + {file = "coverage-7.13.0-cp314-cp314t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:00c3d22cf6fb1cf3bf662aaaa4e563be8243a5ed2630339069799835a9cc7f9b"}, + {file = "coverage-7.13.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:22ccfe8d9bb0d6134892cbe1262493a8c70d736b9df930f3f3afae0fe3ac924d"}, + {file = "coverage-7.13.0-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:9372dff5ea15930fea0445eaf37bbbafbc771a49e70c0aeed8b4e2c2614cc00e"}, + {file = "coverage-7.13.0-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = "sha256:69ac2c492918c2461bc6ace42d0479638e60719f2a4ef3f0815fa2df88e9f940"}, + {file = "coverage-7.13.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:739c6c051a7540608d097b8e13c76cfa85263ced467168dc6b477bae3df7d0e2"}, + {file = "coverage-7.13.0-cp314-cp314t-win32.whl", hash = "sha256:fe81055d8c6c9de76d60c94ddea73c290b416e061d40d542b24a5871bad498b7"}, + {file = "coverage-7.13.0-cp314-cp314t-win_amd64.whl", hash = "sha256:445badb539005283825959ac9fa4a28f712c214b65af3a2c464f1adc90f5fcbc"}, + {file = "coverage-7.13.0-cp314-cp314t-win_arm64.whl", hash = "sha256:de7f6748b890708578fc4b7bb967d810aeb6fcc9bff4bb77dbca77dab2f9df6a"}, + {file = "coverage-7.13.0-py3-none-any.whl", hash = "sha256:850d2998f380b1e266459ca5b47bc9e7daf9af1d070f66317972f382d46f1904"}, + {file = "coverage-7.13.0.tar.gz", hash = "sha256:a394aa27f2d7ff9bc04cf703817773a59ad6dfbd577032e690f961d2460ee936"}, +] + +[package.dependencies] +tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.11.0a6\" and extra == \"toml\""} + +[package.extras] +toml = ["tomli ; python_full_version <= \"3.11.0a6\""] + [[package]] name = "exceptiongroup" version = "1.1.3" @@ -851,6 +959,26 @@ pytest = ">=8.2,<9" docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] +[[package]] +name = "pytest-cov" +version = "7.0.0" +description = "Pytest plugin for measuring coverage." +optional = false +python-versions = ">=3.9" +groups = ["dev"] +files = [ + {file = "pytest_cov-7.0.0-py3-none-any.whl", hash = "sha256:3b8e9558b16cc1479da72058bdecf8073661c7f57f7d3c5f22a1c23507f2d861"}, + {file = "pytest_cov-7.0.0.tar.gz", hash = "sha256:33c97eda2e049a0c5298e91f519302a1334c26ac65c1a483d6206fd458361af1"}, +] + +[package.dependencies] +coverage = {version = ">=7.10.6", extras = ["toml"]} +pluggy = ">=1.2" +pytest = ">=7" + +[package.extras] +testing = ["process-tests", "pytest-xdist", "virtualenv"] + [[package]] name = "requests" version = "2.32.5" @@ -992,4 +1120,4 @@ watchmedo = ["PyYAML (>=3.10)"] [metadata] lock-version = "2.1" python-versions = "^3.10" -content-hash = "12593d2241ebd57ce8ddc9c696596f1d4f434b620275149c63d7dcabb5601f6f" +content-hash = "1effa7f4c6ddc8981fe8ee168e3aa6514ddc1777f6701d7f8b96fc88c9f96f58" diff --git a/pyproject.toml b/pyproject.toml index a1fb56e..fc4ef3b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ pytest = "^8.3.4" pytest-asyncio = "^0.25.3" black = "^25.1.0" watchdog = "^6.0.0" +pytest-cov = "^7.0.0" [build-system] requires = ["poetry-core>=1.0.0"] From bb458a191518fd5867957d5a697966e2afb89261 Mon Sep 17 00:00:00 2001 From: Marissa Fisher Date: Thu, 18 Dec 2025 16:50:49 -0300 Subject: [PATCH 37/38] couple minor edits --- firedantic/_async/model.py | 2 -- firedantic/tests/test_clients.py | 4 ++-- poetry.lock | 2 +- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/firedantic/_async/model.py b/firedantic/_async/model.py index be48d6d..11ab7f6 100644 --- a/firedantic/_async/model.py +++ b/firedantic/_async/model.py @@ -466,8 +466,6 @@ def _get_doc_ref( :raise DocumentIDError: If the ID is not valid. """ - # _get_col_ref takes collection_name, not config_name. - # Any specific config usage should likely be handled by context or passed properly if supported. return self._get_col_ref().document(self.get_document_id()) # type: ignore @staticmethod diff --git a/firedantic/tests/test_clients.py b/firedantic/tests/test_clients.py index e5ac91b..197c980 100644 --- a/firedantic/tests/test_clients.py +++ b/firedantic/tests/test_clients.py @@ -71,7 +71,7 @@ def test_get_async_client_lazy_creation_and_caching(monkeypatch): def test_preserve_prebuilt_clients(): """ - If user supplies a prebuilt client / async_client to add(), the registry uses exactly them. + If user supplies a prebuilt client / async_client to add(), ensure it uses those. """ cfg = Configuration() @@ -379,7 +379,7 @@ def fake_constructor(*args, **kwargs): admin = cfg.get_admin_client("z") assert isinstance(admin, FakeAdminClient) assert "kwargs" in captured - # check the captured kwargs include our transport + client_options and a client_info (or default) + # check the captured kwargs including transport + client_options and a client_info (or default) assert captured["kwargs"].get("transport") is fake_transport assert captured["kwargs"].get("client_options") == fake_client_options assert "client_info" in captured["kwargs"] diff --git a/poetry.lock b/poetry.lock index 9127c69..50d55d3 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.4 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. [[package]] name = "annotated-types" From 08078fccc102221911c3c09ae455627dbb49ee34 Mon Sep 17 00:00:00 2001 From: Marissa Fisher Date: Tue, 6 Jan 2026 19:54:47 -0300 Subject: [PATCH 38/38] fixing up way to run integ tests smoothly --- README.md | 92 +++++++++++++++---- integration_tests/README.md | 64 ------------- .../configure_firestore_db_clients.py | 3 + integration_tests/full_async_flow.py | 3 + ...gration_all.py => full_readme_examples.py} | 3 + integration_tests/full_sync_flow.py | 3 + tasks.py | 15 +++ 7 files changed, 103 insertions(+), 80 deletions(-) delete mode 100644 integration_tests/README.md rename integration_tests/{test_integration_all.py => full_readme_examples.py} (99%) diff --git a/README.md b/README.md index e08a2b7..d7a16b8 100644 --- a/README.md +++ b/README.md @@ -728,22 +728,6 @@ Run the Firestore emulator with a predictable port: start_emulator ``` -### Running Tests - -To run tests locally, you should first: - -```bash -poetry install -poetry run invoke test -``` - -\*Note, when new functions, comments, variables etc. are added that will span across -sync and async directories, be sure to first declare the replacements in `unasync.py` -before running `poetry run invoke test`. I.e. indicating a replacement of 'async_client' -with 'client' text across both directories. - -\*Note, the emulator must be set and running for all tests to pass. - ### About sync and async versions of library Although this library provides both sync and async versions of models, please keep in @@ -776,6 +760,82 @@ information about the release in [CHANGELOG.md](CHANGELOG.md): poetry run invoke make-changelog ``` + +### Running Tests + +To run tests locally, you should first: + +```bash +poetry install +poetry run invoke test +``` + +\*Note, when new functions, comments, variables etc. are added that will span across +sync and async directories, be sure to first declare the replacements in `unasync.py` +before running `poetry run invoke test`. I.e. indicating a replacement of 'async_client' +with 'client' text across both directories. + +\*Note, the emulator must be set and running for all tests to pass. + +### Running Integration Tests + + +#### Environment and configuration + +- Ensure firestore emulator is running in another terminal window: + - `./start_emulator.sh` + +#### Files and purpose (replace placeholders with real filenames) + +- `integration_tests/configure_firestore_db_clients.py` — Purpose: shows how to create + and connect to various db clients. +- `integration_tests/full_sync_flow.py` — Purpose: configures clients, saves data to db, + finds the data, and deletes all data in a sync fashion. +- `integration_tests/full_async_flow.py` — Purpose: configures async clients, saves data + to db, finds the data, and deletes all data in an async fashion. +- `integration_tests/full_readme_examples.py` - Purpose: intended to run all examples + shown in the readme, new and legacy. + +#### How to run + +Run each individual test file: + +- `poetry run python integration_tests/configure_firestore_db_clients.py` +- `poetry run python integration_tests/full_sync_flow.py` +- `poetry run python integration_tests/full_async_flow.py` +- `poetry run python integration_tests/full_readme_examples.py` + +or run all: + +```bash +poetry run invoke integration +``` + +#### What to expect + +- For the configure_firestore_db_clients test, you should expect to see the following: + `All configure_firestore_db_client tests passed!` + +- For the `full_sync_flow` and `full_async_flow`, you should expect to see the following + output: \ + You can readily play around with the models to update the data as desired. + +``` +Number of company owners with first name: 'Bill': 1 + +Number of companies with id: '1234567-7': 1 + +Number of company owners with first name: 'John': 1 + +Number of companies with id: '1234567-8a': 1 + +Number of company owners with first name: 'Alice': 1 + +Number of billing companies with id: '1234567-8c': 1 + +Number of billing accounts with billing_id: 801048: 1 +``` + ## License This code is released under the BSD 3-Clause license. Details in the diff --git a/integration_tests/README.md b/integration_tests/README.md deleted file mode 100644 index dfb4bc1..0000000 --- a/integration_tests/README.md +++ /dev/null @@ -1,64 +0,0 @@ -# integration_tests/README.md - -## Overview - -This folder contains integration test files. This README explains how to run them and -what each file is for. - -## Prerequisites - -- Python 3.13 -- Run: - -```bash -python -m venv firedantic -source firedantic/bin/activate -cd firedantic/integration_tests -``` - -## Files and purpose (replace placeholders with real filenames) - -- `integration_tests/configure_firestore_db_clients.py` — Purpose: shows how to create - and connect to various db clients. -- `integration_tests/full_sync_flow.py` — Purpose: configures clients, saves data to db, - finds the data, and deletes all data in a sync fashion. -- `integration_tests/full_async_flow.py` — Purpose: configures async clients, saves data - to db, finds the data, and deletes all data in an async fashion. -- `integration_tests/test_integration_all.py` - Purpose: intended to run all examples - shown in the readme, new and legacy. - -## How to run - -Run each individual test file: - `python configure_firestore_db_clients.py` - -`python full_sync_flow.py` - `python full_async_flow.py` - -`python test_integration_all.py` - -## Environment and configuration - -- Ensure firestore emulator is running in another terminal window: - - `./start_emulator.sh` - -## What to expect: - -- For the configure_firestore_db_clients test, you should expect to see the following: - `All configure_firestore_db_client tests passed!` - -- For the `full_sync_flow` and `full_async_flow`, you should expect to see the following - output: \ - You can readily play around with the models to update the data as desired. - -``` -Number of company owners with first name: 'Bill': 1 - -Number of companies with id: '1234567-7': 1 - -Number of company owners with first name: 'John': 1 - -Number of companies with id: '1234567-8a': 1 - -Number of company owners with first name: 'Alice': 1 - -Number of billing companies with id: '1234567-8c': 1 - -Number of billing accounts with billing_id: 801048: 1 -``` diff --git a/integration_tests/configure_firestore_db_clients.py b/integration_tests/configure_firestore_db_clients.py index 50ee349..216b379 100644 --- a/integration_tests/configure_firestore_db_clients.py +++ b/integration_tests/configure_firestore_db_clients.py @@ -1,6 +1,9 @@ from os import environ from unittest.mock import Mock +if not environ.get("FIRESTORE_EMULATOR_HOST"): + environ["FIRESTORE_EMULATOR_HOST"] = "127.0.0.1:8686" + import google.auth.credentials from google.cloud.firestore import AsyncClient, Client diff --git a/integration_tests/full_async_flow.py b/integration_tests/full_async_flow.py index 8618897..76f4b18 100644 --- a/integration_tests/full_async_flow.py +++ b/integration_tests/full_async_flow.py @@ -2,6 +2,9 @@ from os import environ from unittest.mock import Mock +if not environ.get("FIRESTORE_EMULATOR_HOST"): + environ["FIRESTORE_EMULATOR_HOST"] = "127.0.0.1:8686" + import google.auth.credentials from firedantic import AsyncModel diff --git a/integration_tests/test_integration_all.py b/integration_tests/full_readme_examples.py similarity index 99% rename from integration_tests/test_integration_all.py rename to integration_tests/full_readme_examples.py index 2ecae47..033ac8b 100644 --- a/integration_tests/test_integration_all.py +++ b/integration_tests/full_readme_examples.py @@ -26,6 +26,9 @@ import google.auth.credentials +if not os.environ.get("FIRESTORE_EMULATOR_HOST"): + os.environ["FIRESTORE_EMULATOR_HOST"] = "127.0.0.1:8686" + # require firedantic package (adjust path or install locally) try: # recommended imports diff --git a/integration_tests/full_sync_flow.py b/integration_tests/full_sync_flow.py index b62ab1f..16dd6fa 100644 --- a/integration_tests/full_sync_flow.py +++ b/integration_tests/full_sync_flow.py @@ -1,6 +1,9 @@ from os import environ from unittest.mock import Mock +if not environ.get("FIRESTORE_EMULATOR_HOST"): + environ["FIRESTORE_EMULATOR_HOST"] = "127.0.0.1:8686" + import google.auth.credentials from pydantic import BaseModel diff --git a/tasks.py b/tasks.py index 5888fa9..0a1777d 100644 --- a/tasks.py +++ b/tasks.py @@ -158,3 +158,18 @@ def make_changelog(ctx): changelog_path.write_text(new_changelog) print(f"{changelog_path} was updated, please fill in release information") + + +@task +def integration(ctx): + """ + Run all integration tests + """ + files = [ + "integration_tests/configure_firestore_db_clients.py", + "integration_tests/full_sync_flow.py", + "integration_tests/full_async_flow.py", + "integration_tests/full_readme_examples.py", + ] + for f in files: + run_test_cmd(ctx, f"python {f}", env=DEV_ENV)