diff --git a/.gitignore b/.gitignore index c02de79..f013835 100644 --- a/.gitignore +++ b/.gitignore @@ -83,6 +83,8 @@ ipython_config.py # pyenv .python-version +pyvenv.cfg +/bin # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. @@ -135,3 +137,4 @@ ui-debug.log .idea .DS_Store +.vscode/ diff --git a/CHANGELOG.md b/CHANGELOG.md index aa177d0..8d50958 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,12 @@ this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm ## [Unreleased] +## [0.13.0] - 2025-12-12 + +### Added + +- Added support for multiple firestore clients. + ## [0.12.0] - 2025-11-20 ### Changed diff --git a/README.md b/README.md index e80d025..d7a16b8 100644 --- a/README.md +++ b/README.md @@ -16,8 +16,49 @@ The package is available on PyPI: pip install firedantic ``` +## Quick overview + +Firedantic provides simple Pydantic-based models for Firestore, with both sync and async +model classes, helpers for composite indexes and TTL policies, and a new configuration +system that supports multiple named Firestore connections. + ## Usage +### Migration Guide: Legacy `configure()` → New `Configuration` + +We introduced a new, more flexible `Configuration` registry to support multiple +Firestore clients, lazy client creation, and admin clients. The legacy `configure()` +helper is still supported for backwards compatibility, but it is now considered +**deprecated**. Scroll below for legacy instructions. 
+ +### New (Recommended) Usage + +```python +from firedantic.configurations import configuration + +# default config +configuration.add(prefix="app-", project="my-project") + +# extra config +configuration.add(name="billing", prefix="billing-", project="billing-project") + +# get clients +client = configuration.get_client() # sync client for "(default)" +async_client = configuration.get_async_client() # async client for "(default)" +billing_client = configuration.get_client("billing") # sync client for "billing" +``` + +Notes: + +- `configuration.add(...)` accepts either client/async_client (pre-built) or + project+credentials and will lazily create clients. +- Models can declare `__db_config__ = "custom-name"` to use a named configuration or + omit it to use the `"(default)"` config. +- Backwards-compatible helpers (`configure()`, `CONFIGURATIONS`) will still populate the + old surface but are deprecated. + +### Old (Legacy – Still Works, Deprecated) Usage + In your application you will need to configure the firestore db client and optionally the collection prefix, which by default is empty. @@ -41,21 +82,32 @@ else: configure(client, prefix="firedantic-test-") ``` -Once that is done, you can start defining your Pydantic models, e.g: +You may also still use: ```python -from pydantic import BaseModel +from firedantic.configurations import CONFIGURATIONS + +db = CONFIGURATIONS["db"] +prefix = CONFIGURATIONS["prefix"] +``` + +### Defining Models + +Once that is done, you can start defining your Pydantic models. 
Models are Pydantic +classes that extend Firedantic’s sync Model or async AsyncModel: + +#### Sync Model Example +```python +from pydantic import BaseModel from firedantic import Model class Owner(BaseModel): - """Dummy owner Pydantic model.""" first_name: str last_name: str class Company(Model): - """Dummy company Firedantic model.""" __collection__ = "companies" company_id: str owner: Owner @@ -65,7 +117,7 @@ owner = Owner(first_name="John", last_name="Doe") company = Company(company_id="1234567-8", owner=owner) company.save() -# Prints out the firestore ID of the Company model +# Access the company id print(company.id) # Reloads model data from the database @@ -97,44 +149,22 @@ Product.find(order_by=[('unit_value', Query.ASCENDING), ('stock', Query.DESCENDI The query operators are found at [https://firebase.google.com/docs/firestore/query-data/queries#query_operators](https://firebase.google.com/docs/firestore/query-data/queries#query_operators). -### Async usage +### Async Usage Firedantic can also be used in an async way, like this: -```python -import asyncio -from os import environ -from unittest.mock import Mock - -import google.auth.credentials -from google.cloud.firestore import AsyncClient - -from firedantic import AsyncModel, configure - -# Firestore emulator must be running if using locally. 
-if environ.get("FIRESTORE_EMULATOR_HOST"): - client = AsyncClient( - project="firedantic-test", - credentials=Mock(spec=google.auth.credentials.Credentials), - ) -else: - client = AsyncClient() - -configure(client, prefix="firedantic-test-") - +#### Async Model Example +```python class Person(AsyncModel): __collection__ = "persons" name: str - async def main(): alice = Person(name="Alice") await alice.save() - print(f"Saved Alice as {alice.id}") bob = Person(name="Bob") await bob.save() - print(f"Saved Bob as {bob.id}") found_alice = await Person.find_one({"name": "Alice"}) print(f"Found Alice: {found_alice.id}") @@ -145,10 +175,7 @@ async def main(): print(f"Found Bob: {found_bob.id}") await alice.delete() - print("Deleted Alice") await bob.delete() - print("Deleted Bob") - if __name__ == "__main__": asyncio.run(main()) @@ -164,10 +191,8 @@ reference to using the `model_for()` method. ```python from typing import Optional, Type - from firedantic import AsyncModel, AsyncSubCollection, AsyncSubModel, ModelNotFoundError - class UserStats(AsyncSubModel): id: Optional[str] = None purchases: int = 0 @@ -176,12 +201,10 @@ class UserStats(AsyncSubModel): # Can use any properties of the "parent" model __collection_tpl__ = "users/{id}/stats" - class User(AsyncModel): __collection__ = "users" name: str - async def get_user_purchases(user_id: str, period: str = "2021") -> int: user = await User.get_by_id(user_id) stats_model: Type[UserStats] = UserStats.model_for(user) @@ -195,10 +218,149 @@ async def get_user_purchases(user_id: str, period: str = "2021") -> int: ## Composite Indexes and TTL Policies -Firedantic has support for defining composite indexes and TTL policies as well as -creating them. +Firedantic supports defining and automatically creating Composite Indexes and TTL +Policies for your Firestore models. 
These can be created using either: + +- The new `Configuration` class (recommended) +- The legacy `configure()` method (deprecated but still supported) + +### Defining Composite Indexes + +Composite indexes are defined on your model using the `__composite_indexes__` attribute. + +It must be a list of composite index definitions created using: + +- `collection_index(...)` – for single-collection queries +- `collection_group_index(...)` – for collection group queries + +Each index definition takes an arbitrary number of (field_name, order) tuples. Order +must be `Query.ASCENDING` or `Query.DESCENDING`. + +#### Example: + +```python +__composite_indexes__ = [ + collection_index(("content", Query.ASCENDING), ("expire", Query.DESCENDING)), + collection_group_index(("content", Query.DESCENDING), ("expire", Query.ASCENDING)), +] +``` + +### Defining TTL Policies + +TTL (Time-To-Live) policies are defined using the `__ttl_field__` attribute. + +Rules: + +- The field must be a datetime object +- The field name must be assigned to `__ttl_field__` +- TTL policies cannot be created in the Firestore Emulator -### Composite indexes +#### Example: + +```python +__ttl_field__ = "expire" +``` + +#### Recommended Usage (New Configuration API) + +All index and TTL setup functions now automatically resolve: + +- The correct project +- The correct database +- The correct admin client …based on each model’s `__db_config__` + +This means you no longer need to pass projects, databases, or admin clients manually-- +And further, you can maintain multiple of each within your app. 
+ +#### Sync Example (Recommended) + +```python +from datetime import datetime +from google.cloud.firestore import Query + +from firedantic import ( + Model, + collection_index, + collection_group_index, + get_all_subclasses, + set_up_composite_indexes_and_ttl_policies, +) +from firedantic.configurations import configuration + +class ExpiringModel(Model): + __collection__ = "expiringModel" + __ttl_field__ = "expire" + __composite_indexes__ = [ + collection_index(("content", Query.ASCENDING), ("expire", Query.DESCENDING)), + collection_group_index(("content", Query.DESCENDING), ("expire", Query.ASCENDING)), + ] + + expire: datetime + content: str + +def main(): + configuration.add( + name="(default)", + project="my-project", + prefix="firedantic-test-", + # credentials=... optional + ) + + set_up_composite_indexes_and_ttl_policies( + models=get_all_subclasses(Model), + ) + +if __name__ == "__main__": + main() +``` + +#### Async Example (Recommended) + +```python +import asyncio +from datetime import datetime +from google.cloud.firestore import Query + +from firedantic import ( + AsyncModel, + async_set_up_composite_indexes_and_ttl_policies, + collection_index, + collection_group_index, + get_all_subclasses, +) +from firedantic.configurations import configuration + +class ExpiringModel(AsyncModel): + __collection__ = "expiringModel" + __ttl_field__ = "expire" + __composite_indexes__ = [ + collection_index(("content", Query.ASCENDING), ("expire", Query.DESCENDING)), + collection_group_index(("content", Query.DESCENDING), ("expire", Query.ASCENDING)), + ] + + expire: datetime + content: str + +async def main(): + configuration.add( + project="my-project", + prefix="firedantic-test-", + ) + + await async_set_up_composite_indexes_and_ttl_policies( + models=get_all_subclasses(AsyncModel), + ) + +if __name__ == "__main__": + asyncio.run(main()) +``` + +#### Legacy Usage (Still Supported, Deprecated) + +The old method using `configure()` and manually passing Firestore Admin 
clients is still +supported but deprecated. + +##### Composite indexes Composite indexes of a collection are defined in `__composite_indexes__`, which is a list of all indexes to be created. @@ -213,7 +375,7 @@ to create indexes. For more details, see the example further down. -### TTL Policies +##### TTL Policies The field used for the TTL policy should be a datetime field and the name of the field should be defined in `__ttl_field__`. The `set_up_ttl_policies` and @@ -221,31 +383,32 @@ should be defined in `__ttl_field__`. The `set_up_ttl_policies` and Note: The TTL policies can not be set up in the Firestore emulator. -### Examples +##### Examples Below are examples (both sync and async) to show how to use Firedantic to set up -composite indexes and TTL policies. +composite indexes and TTL policies with the legacy format. The examples use `async_set_up_composite_indexes_and_ttl_policies` and `set_up_composite_indexes_and_ttl_policies` functions to set up both composite indexes and TTL policies. However, you can use separate functions to set up only either one of them. 
-#### Composite Index and TTL Policy Example (sync) +###### Legacy Sync Example ```python from datetime import datetime +from google.cloud.firestore import Client, Query +from google.cloud.firestore_admin_v1 import FirestoreAdminClient + from firedantic import ( + Model, collection_index, collection_group_index, configure, get_all_subclasses, - Model, set_up_composite_indexes_and_ttl_policies, ) -from google.cloud.firestore import Client, Query -from google.cloud.firestore_admin_v1 import FirestoreAdminClient class ExpiringModel(Model): @@ -267,18 +430,18 @@ def main(): models=get_all_subclasses(Model), client=FirestoreAdminClient(), ) - # or use set_up_composite_indexes / set_up_ttl_policies functions separately - if __name__ == "__main__": main() ``` -#### Composite Index and TTL Policy Example (async) +###### Legacy Async Example ```python import asyncio from datetime import datetime +from google.cloud.firestore import AsyncClient, Query +from google.cloud.firestore_admin_v1.services.firestore_admin import FirestoreAdminAsyncClient from firedantic import ( AsyncModel, @@ -288,10 +451,6 @@ from firedantic import ( configure, get_all_subclasses, ) -from google.cloud.firestore import AsyncClient, Query -from google.cloud.firestore_admin_v1.services.firestore_admin import ( - FirestoreAdminAsyncClient, -) class ExpiringModel(AsyncModel): @@ -313,8 +472,6 @@ async def main(): models=get_all_subclasses(AsyncModel), client=FirestoreAdminAsyncClient(), ) - # or await async_set_up_composite_indexes / async_set_up_ttl_policies separately - if __name__ == "__main__": asyncio.run(main()) @@ -324,7 +481,8 @@ if __name__ == "__main__": Firedantic has basic support for [Firestore Transactions](https://firebase.google.com/docs/firestore/manage-data/transactions). 
-The following methods can be used in a transaction: +The following methods can be used in a transaction for both **sync** and **async** +models: - `Model.delete(transaction=transaction)` - `Model.find_one(transaction=transaction)` @@ -337,8 +495,77 @@ The following methods can be used in a transaction: When using transactions, note that read operations must come before write operations. +### Recommended Usage (New Configuration Class) + +With the new `Configuration` system, transactions automatically use the configured: + +- Project +- Database +- Client based on the active configuration. + ### Transaction examples +#### Sync Transaction Example (Recommended) + +```python +from firedantic import Model, get_transaction +from firedantic.configurations import configuration +from google.cloud.firestore_v1 import transactional, Transaction + + +# Configure once +configuration.add( + project="firedantic-test", + prefix="firedantic-test-", +) + + +class City(Model): + __collection__ = "cities" + population: int + + def increment_population(self, increment: int = 1): + @transactional + def _increment_population(transaction: Transaction) -> None: + self.reload(transaction=transaction) + self.population += increment + self.save(transaction=transaction) + + t = get_transaction() + _increment_population(transaction=t) + + +def main(): + @transactional + def decrement_population( + transaction: Transaction, city: City, decrement: int = 1 + ): + city.reload(transaction=transaction) + city.population = max(0, city.population - decrement) + city.save(transaction=transaction) + + c = City(id="SF", population=1) + c.save() + + c.increment_population(increment=1) + assert c.population == 2 + + t = get_transaction() + decrement_population(transaction=t, city=c, decrement=5) + assert c.population == 0 + + +if __name__ == "__main__": + main() +``` + +### Legacy Usage (Still Supported, Deprecated) + +The original configure() method and manual client wiring still works, but is now 
+deprecated in favor of `Configuration`. + +#### Legacy Async Transaction Example + In this example (async and sync version of it below), we are updating a `City` to increment and decrement the population of it, both using an instance method and a standalone function. Please note that the `@async_transactional` and `@transactional` @@ -346,8 +573,6 @@ decorators always expect the first argument of the wrapped function to be `trans i.e. you can not directly wrap an instance method that has `self` as the first argument or a class method that has `cls` as the first argument. -#### Transaction example (async) - ```python import asyncio from os import environ @@ -409,7 +634,7 @@ if __name__ == "__main__": asyncio.run(main()) ``` -#### Transaction example (sync) +#### Legacy Sync Transaction Example ```python from os import environ @@ -475,12 +700,13 @@ if __name__ == "__main__": PRs are welcome! -To run tests locally, you should run: +### Prerequisites: -```bash -poetry install -poetry run invoke test -``` +- [Node 22](https://nodejs.org/en/download) +- [Pre-commit](https://pre-commit.com/#install) +- [Python 3.10+](https://wiki.python.org/moin/BeginnersGuide/Download) +- [Poetry](https://python-poetry.org/docs/#installation) +- [Google Firebase Emulator Suite](https://firebase.google.com/docs/emulator-suite/install_and_configure#install_the_local_emulator_suite) ### Running Firestore emulator @@ -534,6 +760,82 @@ information about the release in [CHANGELOG.md](CHANGELOG.md): poetry run invoke make-changelog ``` + +### Running Tests + +To run tests locally, you should first: + +```bash +poetry install +poetry run invoke test +``` + +\*Note, when new functions, comments, variables etc. are added that will span across +sync and async directories, be sure to first declare the replacements in `unasync.py` +before running `poetry run invoke test`. I.e. indicating a replacement of 'async_client' +with 'client' text across both directories. 
+ +\*Note, the emulator must be set and running for all tests to pass. + +### Running Integration Tests + + +#### Environment and configuration + +- Ensure firestore emulator is running in another terminal window: + - `./start_emulator.sh` + +#### Files and purpose (replace placeholders with real filenames) + +- `integration_tests/configure_firestore_db_clients.py` — Purpose: shows how to create + and connect to various db clients. +- `integration_tests/full_sync_flow.py` — Purpose: configures clients, saves data to db, + finds the data, and deletes all data in a sync fashion. +- `integration_tests/full_async_flow.py` — Purpose: configures async clients, saves data + to db, finds the data, and deletes all data in an async fashion. +- `integration_tests/full_readme_examples.py` - Purpose: intended to run all examples + shown in the readme, new and legacy. + +#### How to run + +Run each individual test file: + +- `poetry run python integration_tests/configure_firestore_db_clients.py` +- `poetry run python integration_tests/full_sync_flow.py` +- `poetry run python integration_tests/full_async_flow.py` +- `poetry run python integration_tests/full_readme_examples.py` + +or run all: + +```bash +poetry run invoke integration +``` + +#### What to expect + +- For the configure_firestore_db_clients test, you should expect to see the following: + `All configure_firestore_db_client tests passed!` + +- For the `full_sync_flow` and `full_async_flow`, you should expect to see the following + output: \ + You can readily play around with the models to update the data as desired. 
+ +``` +Number of company owners with first name: 'Bill': 1 + +Number of companies with id: '1234567-7': 1 + +Number of company owners with first name: 'John': 1 + +Number of companies with id: '1234567-8a': 1 + +Number of company owners with first name: 'Alice': 1 + +Number of billing companies with id: '1234567-8c': 1 + +Number of billing accounts with billing_id: 801048: 1 +``` + ## License This code is released under the BSD 3-Clause license. Details in the diff --git a/firedantic/__init__.py b/firedantic/__init__.py index e4f79d2..f79f867 100644 --- a/firedantic/__init__.py +++ b/firedantic/__init__.py @@ -34,6 +34,8 @@ from firedantic.common import collection_group_index, collection_index from firedantic.configurations import ( CONFIGURATIONS, + ConfigItem, + Configuration, configure, get_async_transaction, get_transaction, diff --git a/firedantic/_async/indexes.py b/firedantic/_async/indexes.py index b7d0d52..c930da4 100644 --- a/firedantic/_async/indexes.py +++ b/firedantic/_async/indexes.py @@ -14,6 +14,7 @@ from firedantic._async.model import AsyncBareModel from firedantic._async.ttl_policy import set_up_ttl_policies from firedantic.common import IndexDefinition, IndexField +from firedantic.configurations import configuration logger = getLogger("firedantic") @@ -81,7 +82,7 @@ async def create_composite_index( async def set_up_composite_indexes( - gcloud_project: str, + gcloud_project: Optional[str], models: Iterable[Type[AsyncBareModel]], database: str = "(default)", client: Optional[FirestoreAdminAsyncClient] = None, @@ -100,24 +101,39 @@ async def set_up_composite_indexes( operations = [] for model in models: - if not model.__composite_indexes__: + if not getattr(model, "__composite_indexes__", None): continue - path = ( - f"projects/{gcloud_project}/databases/{database}/" - f"collectionGroups/{model.get_collection_name()}" + + # Resolve config name: prefer model __db_config__ if present; else default + config_name = getattr(model, "__db_config__", 
"(default)") + + # If caller did not pass gcloud_project, try to get it from config + project = gcloud_project or configuration.get_config(config_name).project + + # Build collection group path using configuration helper (includes prefix) + collection_group = configuration.get_collection_name( + model, config_name=config_name ) + path = f"projects/{project}/databases/{database}/collectionGroups/{collection_group}" + indexes_in_db = await get_existing_indexes(client, path=path) model_indexes = set(model.__composite_indexes__) existing_indexes = indexes_in_db.intersection(model_indexes) new_indexes = model_indexes.difference(indexes_in_db) for index in existing_indexes: - log_str = "Composite index already exists in DB: %s, collection: %s" - logger.debug(log_str, index, model.get_collection_name()) + logger.debug( + "Composite index already exists in DB: %s, collection: %s", + index, + collection_group, + ) for index in new_indexes: - log_str = "Creating new composite index: %s, collection: %s" - logger.info(log_str, index, model.get_collection_name()) + logger.info( + "Creating new composite index: %s, collection: %s", + index, + collection_group, + ) operation = await create_composite_index(client, index, path) operations.append(operation) diff --git a/firedantic/_async/model.py b/firedantic/_async/model.py index 84251df..11ab7f6 100644 --- a/firedantic/_async/model.py +++ b/firedantic/_async/model.py @@ -15,7 +15,7 @@ import firedantic.operators as op from firedantic import async_truncate_collection from firedantic.common import IndexDefinition, OrderDirection -from firedantic.configurations import CONFIGURATIONS +from firedantic.configurations import configuration from firedantic.exceptions import ( CollectionNotDefined, InvalidDocumentID, @@ -41,23 +41,58 @@ } -def get_collection_name(cls, name: Optional[str]) -> str: +def get_collection_name(cls, collection_name: Optional[str] = None) -> str: """ - Returns the collection name. 
+ Return the collection name for `cls`. - :raise CollectionNotDefined: If the collection name is not defined. + - If `collection_name` is provided, treat it as an explicit collection name + and prefix it using the configured prefix for the class (via __db_config__). + - Otherwise, use the class name and the configured prefix. + + :raises CollectionNotDefined: If neither a class collection nor derived name is available. + """ + # Resolve the config name from the model class (default to "(default)") + config_name = getattr(cls, "__db_config__", "(default)") + + # If caller provided an explicit collection string, apply prefix from config + if collection_name: + cfg = configuration.get_config(config_name) + prefix = cfg.prefix or "" + return f"{prefix}{collection_name}" + + if getattr(cls, "__collection__", None): + cfg = configuration.get_config(config_name) + return f"{cfg.prefix or ''}{cls.__collection__}" + + raise CollectionNotDefined(f"Missing collection name for {cls.__name__}") + + +def _get_col_ref( + cls, collection_name: Optional[str] = None +) -> AsyncCollectionReference: """ - if not name: - raise CollectionNotDefined(f"Missing collection name for {cls.__name__}") + Return an AsyncCollectionReference for the model class using the configured async client. 
- return f"{CONFIGURATIONS['prefix']}{name}" + :param cls: model class + :param collection_name: optional explicit collection name override + :raises CollectionNotDefined: when collection cannot be resolved + """ + # Build the prefixed collection name + col_name = get_collection_name(cls, collection_name) + # Resolve config name from class and fetch async client + config_name = getattr(cls, "__db_config__", "(default)") + async_client = configuration.get_async_client(config_name) -def _get_col_ref(cls, name: Optional[str]) -> AsyncCollectionReference: - collection: AsyncCollectionReference = CONFIGURATIONS["db"].collection( - get_collection_name(cls, name) - ) - return collection + # Return the AsyncCollectionReference (this is a real client call) + col_ref = async_client.collection(col_name) + + # Ensure we got the right object back + if not hasattr(col_ref, "document"): + raise RuntimeError( + f"_get_col_ref returned unexpected object for {cls}: {type(col_ref)!r}" + ) + return col_ref class AsyncBareModel(pydantic.BaseModel, ABC): @@ -71,10 +106,12 @@ class AsyncBareModel(pydantic.BaseModel, ABC): __document_id__: str __ttl_field__: Optional[str] = None __composite_indexes__: Optional[Iterable[IndexDefinition]] = None + __db_config__: str = "(default)" # override in subclasses when needed async def save( self, *, + config_name: Optional[str] = None, exclude_unset: bool = False, exclude_none: bool = False, transaction: Optional[AsyncTransaction] = None, @@ -87,29 +124,122 @@ async def save( :param transaction: Optional transaction to use. :raise DocumentIDError: If the document ID is not valid. 
""" + + # Resolve config to use (explicit -> instance -> class -> default) + if config_name is not None: + resolved = config_name + else: + resolved = getattr(self, "__db_config__", None) + if not resolved: + resolved = getattr(self.__class__, "__db_config__", "(default)") + + # Build payload data = self.model_dump( by_alias=True, exclude_unset=exclude_unset, exclude_none=exclude_none ) if self.__document_id__ in data: del data[self.__document_id__] - doc_ref = self._get_doc_ref() + async_client = configuration.get_async_client(resolved) + if async_client is None: + raise RuntimeError(f"No async client configured for config '{resolved}'") + + # Get collection reference from async_client with the collection_name + collection_name = self.get_collection_name() + col_ref = async_client.collection(collection_name) + + # Build doc ref (use provided id if set, otherwise let server generate) + doc_id = self.get_document_id() + if doc_id: + doc_ref = col_ref.document(doc_id) + else: + doc_ref = col_ref.document() + + # Use transaction if provided (assume it's compatible) otherwise do direct await set if transaction is not None: + # Transaction.delete/set expects DocumentReference from the same client. transaction.set(doc_ref, data) else: await doc_ref.set(data) + setattr(self, self.__document_id__, doc_ref.id) async def delete(self, transaction: Optional[AsyncTransaction] = None) -> None: """ - Deletes this model from the database. + Deletes this specific model instance from the database. :raise DocumentIDError: If the ID is not valid. 
""" + doc_ref = self._get_doc_ref() + + # try to extract client-like objects + doc_client = getattr(doc_ref, "_client", None) or getattr( + doc_ref, "client", None + ) + tx_client = ( + getattr(transaction, "_client", None) + or getattr(transaction, "_client_async", None) + if transaction is not None + else None + ) + if transaction is not None: - transaction.delete(self._get_doc_ref()) - else: - await self._get_doc_ref().delete() + # Defensive check: make sure the doc_ref is built from same client as the transaction. + tx_client = getattr(transaction, "_client", None) or getattr( + transaction, "_client_async", None + ) + doc_client = getattr(doc_ref, "_client", None) or getattr( + doc_ref, "client", None + ) + + # If both sides expose client objects, ensure they are same identity. + if ( + tx_client is not None + and doc_client is not None + and tx_client is not doc_client + ): + # Try to rebuild a document reference from the transaction's client using the same path + path = getattr(doc_ref, "path", None) + if path is None: + raise RuntimeError( + "Cannot resolve document path to rebuild doc_ref for transaction." + ) + + # For most firestores clients, client.document(path) works for sync client; + # for async, we try client.document(path) as well (it usually exists). + try: + # prefer a method that accepts full path + alt_doc_ref = None + if hasattr(tx_client, "document"): + alt_doc_ref = tx_client.document(path) + elif hasattr(tx_client, "collection"): + # fallback: split path to collection and doc id + parts = path.split("/") + if len(parts) >= 2: + collection_path = "/".join(parts[:-1]) + doc_id = parts[-1] + alt_doc_ref = tx_client.collection( + collection_path + ).document(doc_id) + if alt_doc_ref is None: + raise RuntimeError( + "Could not rebuild document reference from transaction client." 
+ ) + doc_ref = alt_doc_ref + + except Exception as exc: + raise RuntimeError( + "Document reference was created from a different Firestore client than " + "the provided transaction. Recreate doc_ref from the transaction's client " + "or call delete() without a transaction." + ) from exc + + # schedule delete on the transaction (this will be committed on transaction commit) + transaction.delete(doc_ref) + return + + # no transaction: do a direct delete + await doc_ref.delete() async def reload(self, transaction: Optional[AsyncTransaction] = None) -> None: """ @@ -128,7 +258,7 @@ async def reload(self, transaction: Optional[AsyncTransaction] = None) -> None: self.__dict__.update(updated_model.__dict__) - def get_document_id(self): + def get_document_id(self) -> Optional[str]: """ Returns the document ID for this model instance. @@ -141,6 +271,22 @@ def get_document_id(self): _OrderBy = List[Tuple[str, OrderDirection]] + @classmethod + async def delete_all(cls, config_name: Optional[str] = None) -> None: + """ + Deletes all models of this type from the database. 
+ """ + + # Resolve config to use (explicit -> instance -> class -> default) + config_name = cls.__db_config__ + + client = configuration.get_async_client(config_name) + col_name = get_collection_name(cls) + col_ref = client.collection(col_name) + + async for doc in col_ref.stream(): + await doc.reference.delete() + @classmethod async def find( # pylint: disable=too-many-arguments cls: Type[TAsyncBareModel], @@ -173,8 +319,7 @@ async def find( # pylint: disable=too-many-arguments query = cls._add_filter(query, key, value) if order_by is not None: - for order_by_item in order_by: - field, direction = order_by_item + for field, direction in order_by: query = query.order_by(field, direction=direction) # type: ignore if limit is not None: query = query.limit(limit) # type: ignore @@ -298,11 +443,13 @@ async def truncate_collection(cls, batch_size: int = 128) -> int: ) @classmethod - def _get_col_ref(cls) -> AsyncCollectionReference: + def _get_col_ref( + cls, collection_name: Optional[str] = None + ) -> AsyncCollectionReference: """ Returns the collection reference. """ - return _get_col_ref(cls, cls.__collection__) + return _get_col_ref(cls, collection_name) @classmethod def get_collection_name(cls) -> str: @@ -311,7 +458,9 @@ def get_collection_name(cls) -> str: """ return get_collection_name(cls, cls.__collection__) - def _get_doc_ref(self) -> AsyncDocumentReference: + def _get_doc_ref( + self, config_name: Optional[str] = "(default)" + ) -> AsyncDocumentReference: """ Returns the document reference. 
diff --git a/firedantic/_sync/indexes.py b/firedantic/_sync/indexes.py index a1eb7c5..e828233 100644 --- a/firedantic/_sync/indexes.py +++ b/firedantic/_sync/indexes.py @@ -14,6 +14,7 @@ from firedantic._sync.model import BareModel from firedantic._sync.ttl_policy import set_up_ttl_policies from firedantic.common import IndexDefinition, IndexField +from firedantic.configurations import configuration logger = getLogger("firedantic") @@ -81,7 +82,7 @@ def create_composite_index( def set_up_composite_indexes( - gcloud_project: str, + gcloud_project: Optional[str], models: Iterable[Type[BareModel]], database: str = "(default)", client: Optional[FirestoreAdminClient] = None, @@ -100,24 +101,39 @@ def set_up_composite_indexes( operations = [] for model in models: - if not model.__composite_indexes__: + if not getattr(model, "__composite_indexes__", None): continue - path = ( - f"projects/{gcloud_project}/databases/{database}/" - f"collectionGroups/{model.get_collection_name()}" + + # Resolve config name: prefer model __db_config__ if present; else default + config_name = getattr(model, "__db_config__", "(default)") + + # If caller did not pass gcloud_project, try to get it from config + project = gcloud_project or configuration.get_config(config_name).project + + # Build collection group path using configuration helper (includes prefix) + collection_group = configuration.get_collection_name( + model, config_name=config_name ) + path = f"projects/{project}/databases/{database}/collectionGroups/{collection_group}" + indexes_in_db = get_existing_indexes(client, path=path) model_indexes = set(model.__composite_indexes__) existing_indexes = indexes_in_db.intersection(model_indexes) new_indexes = model_indexes.difference(indexes_in_db) for index in existing_indexes: - log_str = "Composite index already exists in DB: %s, collection: %s" - logger.debug(log_str, index, model.get_collection_name()) + logger.debug( + "Composite index already exists in DB: %s, collection: %s", + 
index, + collection_group, + ) for index in new_indexes: - log_str = "Creating new composite index: %s, collection: %s" - logger.info(log_str, index, model.get_collection_name()) + logger.info( + "Creating new composite index: %s, collection: %s", + index, + collection_group, + ) operation = create_composite_index(client, index, path) operations.append(operation) diff --git a/firedantic/_sync/model.py b/firedantic/_sync/model.py index 78ec372..90c995a 100644 --- a/firedantic/_sync/model.py +++ b/firedantic/_sync/model.py @@ -15,7 +15,7 @@ import firedantic.operators as op from firedantic import truncate_collection from firedantic.common import IndexDefinition, OrderDirection -from firedantic.configurations import CONFIGURATIONS +from firedantic.configurations import configuration from firedantic.exceptions import ( CollectionNotDefined, InvalidDocumentID, @@ -41,23 +41,56 @@ } -def get_collection_name(cls, name: Optional[str]) -> str: +def get_collection_name(cls, collection_name: Optional[str] = None) -> str: """ - Returns the collection name. + Return the collection name for `cls`. - :raise CollectionNotDefined: If the collection name is not defined. + - If `collection_name` is provided, treat it as an explicit collection name + and prefix it using the configured prefix for the class (via __db_config__). + - Otherwise, use the class name and the configured prefix. + + :raises CollectionNotDefined: If neither a class collection nor derived name is available. 
+ """ + # Resolve the config name from the model class (default to "(default)") + config_name = getattr(cls, "__db_config__", "(default)") + + # If caller provided an explicit collection string, apply prefix from config + if collection_name: + cfg = configuration.get_config(config_name) + prefix = cfg.prefix or "" + return f"{prefix}{collection_name}" + + if getattr(cls, "__collection__", None): + cfg = configuration.get_config(config_name) + return f"{cfg.prefix or ''}{cls.__collection__}" + + raise CollectionNotDefined(f"Missing collection name for {cls.__name__}") + + +def _get_col_ref(cls, collection_name: Optional[str] = None) -> CollectionReference: """ - if not name: - raise CollectionNotDefined(f"Missing collection name for {cls.__name__}") + Return an CollectionReference for the model class using the configured client. - return f"{CONFIGURATIONS['prefix']}{name}" + :param cls: model class + :param collection_name: optional explicit collection name override + :raises CollectionNotDefined: when collection cannot be resolved + """ + # Build the prefixed collection name + col_name = get_collection_name(cls, collection_name) + # Resolve config name from class and fetch client + config_name = getattr(cls, "__db_config__", "(default)") + client = configuration.get_client(config_name) -def _get_col_ref(cls, name: Optional[str]) -> CollectionReference: - collection: CollectionReference = CONFIGURATIONS["db"].collection( - get_collection_name(cls, name) - ) - return collection + # Return the CollectionReference (this is a real client call) + col_ref = client.collection(col_name) + + # Ensure we got the right object back + if not hasattr(col_ref, "document"): + raise RuntimeError( + f"_get_col_ref returned unexpected object for {cls}: {type(col_ref)!r}" + ) + return col_ref class BareModel(pydantic.BaseModel, ABC): @@ -71,10 +104,12 @@ class BareModel(pydantic.BaseModel, ABC): __document_id__: str __ttl_field__: Optional[str] = None __composite_indexes__: 
Optional[Iterable[IndexDefinition]] = None + __db_config__: str = "(default)" # override in subclasses when needed def save( self, *, + config_name: Optional[str] = None, exclude_unset: bool = False, exclude_none: bool = False, transaction: Optional[Transaction] = None, @@ -87,29 +122,122 @@ def save( :param transaction: Optional transaction to use. :raise DocumentIDError: If the document ID is not valid. """ + + # Resolve config to use (explicit -> instance -> class -> default) + if config_name is not None: + resolved = config_name + else: + resolved = getattr(self, "__db_config__", None) + if not resolved: + resolved = getattr(self.__class__, "__db_config__", "(default)") + + # Build payload data = self.model_dump( by_alias=True, exclude_unset=exclude_unset, exclude_none=exclude_none ) if self.__document_id__ in data: del data[self.__document_id__] - doc_ref = self._get_doc_ref() + client = configuration.get_client(resolved) + if client is None: + raise RuntimeError(f"No client configured for config '{resolved}'") + + # Get collection reference from client with the collection_name + collection_name = self.get_collection_name() + col_ref = client.collection(collection_name) + + # Build doc ref (use provided id if set, otherwise let server generate) + doc_id = self.get_document_id() + if doc_id: + doc_ref = col_ref.document(doc_id) + else: + doc_ref = col_ref.document() + + # Use transaction if provided (assume it's compatible) otherwise do direct set if transaction is not None: + # Transaction.delete/set expects DocumentReference from the same client. transaction.set(doc_ref, data) else: doc_ref.set(data) + setattr(self, self.__document_id__, doc_ref.id) def delete(self, transaction: Optional[Transaction] = None) -> None: """ - Deletes this model from the database. + Deletes this specific model instance from the database. :raise DocumentIDError: If the ID is not valid. 
""" + doc_ref = self._get_doc_ref() + + # try to extract client-like objects + doc_client = getattr(doc_ref, "_client", None) or getattr( + doc_ref, "client", None + ) + tx_client = ( + getattr(transaction, "_client", None) + or getattr(transaction, "_client_async", None) + if transaction is not None + else None + ) + if transaction is not None: - transaction.delete(self._get_doc_ref()) - else: - self._get_doc_ref().delete() + # Defensive check: make sure the doc_ref is built from same client as the transaction. + tx_client = getattr(transaction, "_client", None) or getattr( + transaction, "_client_async", None + ) + doc_client = getattr(doc_ref, "_client", None) or getattr( + doc_ref, "client", None + ) + + # If both sides expose client objects, ensure they are same identity. + if ( + tx_client is not None + and doc_client is not None + and tx_client is not doc_client + ): + # Try to rebuild a document reference from the transaction's client using the same path + path = getattr(doc_ref, "path", None) + if path is None: + raise RuntimeError( + "Cannot resolve document path to rebuild doc_ref for transaction." + ) + + # For most firestores clients, client.document(path) works for sync client; + # for async, we try client.document(path) as well (it usually exists). + try: + # prefer a method that accepts full path + alt_doc_ref = None + if hasattr(tx_client, "document"): + alt_doc_ref = tx_client.document(path) + elif hasattr(tx_client, "collection"): + # fallback: split path to collection and doc id + parts = path.split("/") + if len(parts) >= 2: + collection_path = "/".join(parts[:-1]) + doc_id = parts[-1] + alt_doc_ref = tx_client.collection( + collection_path + ).document(doc_id) + if alt_doc_ref is None: + raise RuntimeError( + "Could not rebuild document reference from transaction client." 
+ ) + doc_ref = alt_doc_ref + + except Exception as exc: + raise RuntimeError( + "Document reference was created from a different Firestore client than " + "the provided transaction. Recreate doc_ref from the transaction's client " + "or call delete() without a transaction." + ) from exc + + # schedule delete on the transaction (this will be committed on transaction commit) + transaction.delete(doc_ref) + return + + # no transaction: do a direct delete + doc_ref.delete() def reload(self, transaction: Optional[Transaction] = None) -> None: """ @@ -128,7 +256,7 @@ def reload(self, transaction: Optional[Transaction] = None) -> None: self.__dict__.update(updated_model.__dict__) - def get_document_id(self): + def get_document_id(self) -> Optional[str]: """ Returns the document ID for this model instance. @@ -141,6 +269,22 @@ def get_document_id(self): _OrderBy = List[Tuple[str, OrderDirection]] + @classmethod + def delete_all(cls, config_name: Optional[str] = None) -> None: + """ + Deletes all models of this type from the database. 
+ """ + + # Resolve config to use (explicit -> instance -> class -> default) + config_name = cls.__db_config__ + + client = configuration.get_client(config_name) + col_name = get_collection_name(cls) + col_ref = client.collection(col_name) + + for doc in col_ref.stream(): + doc.reference.delete() + @classmethod def find( # pylint: disable=too-many-arguments cls: Type[TBareModel], @@ -173,8 +317,7 @@ def find( # pylint: disable=too-many-arguments query = cls._add_filter(query, key, value) if order_by is not None: - for order_by_item in order_by: - field, direction = order_by_item + for field, direction in order_by: query = query.order_by(field, direction=direction) # type: ignore if limit is not None: query = query.limit(limit) # type: ignore @@ -296,11 +439,11 @@ def truncate_collection(cls, batch_size: int = 128) -> int: ) @classmethod - def _get_col_ref(cls) -> CollectionReference: + def _get_col_ref(cls, collection_name: Optional[str] = None) -> CollectionReference: """ Returns the collection reference. """ - return _get_col_ref(cls, cls.__collection__) + return _get_col_ref(cls, collection_name) @classmethod def get_collection_name(cls) -> str: @@ -309,12 +452,16 @@ def get_collection_name(cls) -> str: """ return get_collection_name(cls, cls.__collection__) - def _get_doc_ref(self) -> DocumentReference: + def _get_doc_ref( + self, config_name: Optional[str] = "(default)" + ) -> DocumentReference: """ Returns the document reference. :raise DocumentIDError: If the ID is not valid. """ + # _get_col_ref takes collection_name, not config_name. + # Any specific config usage should likely be handled by context or passed properly if supported. 
return self._get_col_ref().document(self.get_document_id()) # type: ignore @staticmethod diff --git a/firedantic/configurations.py b/firedantic/configurations.py index 2f9e277..9663858 100644 --- a/firedantic/configurations.py +++ b/firedantic/configurations.py @@ -1,35 +1,313 @@ -from typing import Any, Dict, Union +import warnings +from os import environ +from typing import Any, Dict, Optional, Type, Union -from google.cloud.firestore_v1 import AsyncClient, AsyncTransaction, Client, Transaction +from google.api_core.client_options import ClientOptions +from google.api_core.gapic_v1.client_info import ClientInfo +from google.auth.credentials import Credentials +from google.cloud.firestore_admin_v1 import FirestoreAdminClient +from google.cloud.firestore_admin_v1.services.firestore_admin import ( + FirestoreAdminAsyncClient, +) +from google.cloud.firestore_admin_v1.services.firestore_admin.transports.base import ( + DEFAULT_CLIENT_INFO, +) +from google.cloud.firestore_v1 import ( + AsyncClient, + AsyncTransaction, + Client, + CollectionReference, + Transaction, +) +from pydantic import BaseModel +# --- Old compatibility surface (kept for backwards compatibility) --- CONFIGURATIONS: Dict[str, Any] = {} def configure(db: Union[Client, AsyncClient], prefix: str = "") -> None: - """Configures the prefix and DB. - + """ + Legacy helper: updates the module-level `configuration` default entry and preserves + the old CONFIGURATIONS mapping so old callers continue to work. + Configures the prefix and DB. :param db: The firestore client instance. :param prefix: The prefix to use for collection names. """ - global CONFIGURATIONS + + # soft deprecation notice for users (no stacktrace) + warnings.warn( + "firedantic.configure(db, ...) is deprecated and will be removed in a " + "future release. Use firedantic.configurations.configuration.add(...) 
instead.", + DeprecationWarning, + stacklevel=2, + ) + + if isinstance(db, AsyncClient): + configuration.add(name="(default)", prefix=prefix, async_client=db) + else: + # treat as sync client + configuration.add(name="(default)", prefix=prefix, client=db) CONFIGURATIONS["db"] = db CONFIGURATIONS["prefix"] = prefix def get_transaction() -> Transaction: + """Backward-compatible sync transaction getter, forwards to new implementation.""" + return configuration.get_transaction() + + +def get_async_transaction() -> AsyncTransaction: + """Backward-compatible async transaction getter, forwards to new implementation.""" + return configuration.get_async_transaction() + + +# --- New configuration system --- +class ConfigItem(BaseModel): """ - Get a new Firestore transaction for the configured DB. + Holds configuration for a named Firestore connection. + Clients may be provided directly, or created lazily from the stored params. """ - transaction = CONFIGURATIONS["db"].transaction() - assert isinstance(transaction, Transaction) - return transaction + name: str + prefix: str = "" + project: Optional[str] = None + database: str = "(default)" -def get_async_transaction() -> AsyncTransaction: + # client objects (may be None; created lazily) + client: Optional[Any] = None + async_client: Optional[Any] = None + admin_client: Optional[Any] = None + async_admin_client: Optional[Any] = None + + # creation params (kept so we can lazily instantiate clients) + credentials: Optional[Credentials] = None + client_info: Optional[Any] = None + client_options: Optional[Union[Dict[str, Any], Any]] = None + admin_transport: Optional[Any] = None + + model_config = {"arbitrary_types_allowed": True} + + +class Configuration: """ - Get a new Firestore transaction for the configured DB. + Registry for named Firestore configurations. 
+ + Usage Example: + configuration.add(name="billing", project="myproj", credentials=creds) + client = configuration.get_client("billing") """ - transaction = CONFIGURATIONS["db"].transaction() - assert isinstance(transaction, AsyncTransaction) - return transaction + + def __init__(self) -> None: + + # mapping name -> ConfigItem + self.config: Dict[str, ConfigItem] = {} + + # Create a sensible default config entry, but avoid passing empty string as project. + default_project = environ.get("GOOGLE_CLOUD_PROJECT") or None + + self.config["(default)"] = ConfigItem( + name="(default)", + prefix="", + project=default_project, + database="(default)", + ) + + def add( + self, + name: str = "(default)", # adding a config without a name results in overriding the default + *, + project: Optional[str] = None, + database: str = "(default)", + prefix: str = "", + client: Optional[Any] = None, + async_client: Optional[Any] = None, + admin_client: Optional[Any] = None, + async_admin_client: Optional[Any] = None, + credentials: Optional[Credentials] = None, + client_info: Optional[Any] = None, + client_options: Optional[Union[Dict[str, Any], Any]] = None, + admin_transport: Optional[Any] = None, + ) -> ConfigItem: + """ + Add a named configuration. + + You may either pass a pre-built client and/or an async_client, + or provide only project/credentials so clients will be constructed lazily. 
+ """ + + normalized_project = self._normalize_project(project) + + item = ConfigItem( + name=name, + project=normalized_project, + database=database, + prefix=prefix, + client=client, + async_client=async_client, + admin_client=admin_client, + async_admin_client=async_admin_client, + credentials=credentials, + client_info=client_info, + client_options=client_options, + admin_transport=admin_transport, + ) + self.config[name] = item + return item + + # dict-like accessors + def __getitem__(self, name: str) -> ConfigItem: + return self.get_config(name) + + def __contains__(self, name: str) -> bool: + return name in self.config + + def get(self, name: str, default=None): + return self.config.get(name, default) + + def get_config(self, name: str = "(default)") -> ConfigItem: + try: + return self.config[name] + except KeyError as err: + raise KeyError( + f"Configuration '{name}' not found. Available: {list(self.config.keys())}" + ) from err + + def get_config_names(self): + return list(self.config.keys()) + + def _normalize_project(self, project: Optional[str]) -> Optional[str]: + """ + Convert empty-string project to None so Client(...) doesn't get empty string for project (i.e. project="") + """ + if project: + return project + return environ.get("GOOGLE_CLOUD_PROJECT") or None + + # sync client accessor (lazy-create) + def get_client(self, name: Optional[str] = None) -> Client: + resolved = name if name is not None else "(default)" + cfg = self.get_config(resolved) + if cfg.client is None: + # don't try to create a client if we lack project info + if cfg.project is None: + raise RuntimeError( + f"No sync client configured for '{resolved}' and no project available; " + "call configuration.add(..., client=..., project=...) or set " + "GOOGLE_CLOUD_PROJECT and let add() build the client." 
+ ) + cfg.client = Client( + project=cfg.project, + credentials=cfg.credentials, + client_info=cfg.client_info, + client_options=cfg.client_options, # type: ignore[arg-type] + # NOTE: modern firestore clients may accept database param in constructor; keep for future. + ) + return cfg.client + + # async client accessor (lazy-create) + def get_async_client(self, name: Optional[str] = None) -> AsyncClient: + resolved = name if name is not None else "(default)" + cfg = self.get_config(resolved) + + if cfg.async_client is None: + if cfg.project is None: + raise RuntimeError( + f"No async client configured for '{resolved}' and no project available; " + "call configuration.add(..., async_client=..., project=...) or set " + "GOOGLE_CLOUD_PROJECT and let add() build the client." + ) + cfg.async_client = AsyncClient( + project=cfg.project, + credentials=cfg.credentials, + client_info=cfg.client_info, + client_options=cfg.client_options, # type: ignore[arg-type] + ) + return cfg.async_client + + # admin client accessor (lazy-create) + def get_admin_client(self, name: Optional[str] = None) -> FirestoreAdminClient: + resolved = name if name is not None else "(default)" + cfg = self.get_config(resolved) + + if cfg.admin_client is None: + cfg.admin_client = FirestoreAdminClient( + credentials=cfg.credentials, + transport=cfg.admin_transport, + client_options=cfg.client_options, # type: ignore[arg-type] + client_info=cfg.client_info or DEFAULT_CLIENT_INFO, + ) + return cfg.admin_client + + # async admin client accessor (lazy-create) + def get_async_admin_client( + self, name: Optional[str] = None + ) -> FirestoreAdminAsyncClient: + resolved = name if name is not None else "(default)" + cfg = self.get_config(resolved) + + if cfg.async_admin_client is None: + cfg.async_admin_client = FirestoreAdminAsyncClient( + credentials=cfg.credentials, + transport=cfg.admin_transport, + client_options=cfg.client_options, # type: ignore[arg-type] + client_info=cfg.client_info or 
DEFAULT_CLIENT_INFO, + ) + return cfg.async_admin_client + + # transactions + def get_transaction(self, name: Optional[str] = None) -> Transaction: + return self.get_client(name=name).transaction() + + def get_async_transaction(self, name: Optional[str] = None) -> AsyncTransaction: + return self.get_async_client(name=name).transaction() + + # helpers for models to derive collection name / reference + def get_collection_name( + self, model_class: Type, config_name: Optional[str] = None + ) -> str: + """ + Return the collection name string (prefix + model name). + """ + resolved = config_name if config_name is not None else "(default)" + cfg = self.get_config(resolved) + prefix = cfg.prefix or "" + + if hasattr(model_class, "__collection__"): + return prefix + model_class.__collection__ + else: + model_name = model_class.__name__ + return f"{prefix}{model_name[0].lower()}{model_name[1:]}" # (lower case first letter of model name) + + def get_collection_ref( + self, model_class: Type, name: Optional[str] = None + ) -> CollectionReference: + """ + Return a CollectionReference for the given model_class using the sync client. + """ + resolved = name if name is not None else "(default)" + cfg = self.get_config(resolved) + client = cfg.client + if client is None: + raise RuntimeError(f"No sync client configured for config '{resolved}'") + collection_name = self.get_collection_name(model_class, resolved) + return client.collection(collection_name) + + def get_async_collection_ref(self, model_class: Type, name: Optional[str] = None): + """ + Return an AsyncCollectionReference using the async client. 
+ """ + resolved = name if name is not None else "(default)" + cfg = self.get_config(resolved) + async_client = cfg.async_client + if async_client is None: + raise RuntimeError(f"No async client configured for config '{resolved}'") + + collection_name = self.get_collection_name(model_class, resolved) + + return async_client.collection(collection_name) + + +# make the module-level singleton available to models/tests +# other modules should: from firedantic.configurations import configuration +configuration = Configuration() diff --git a/firedantic/tests/test_clients.py b/firedantic/tests/test_clients.py new file mode 100644 index 0000000..197c980 --- /dev/null +++ b/firedantic/tests/test_clients.py @@ -0,0 +1,424 @@ +from os import environ +from unittest.mock import Mock + +import google.auth.credentials +import pytest +from google.cloud.firestore import AsyncClient, Client + +import firedantic.configurations as cfg_module +from firedantic import Configuration, configure + + +# --- tiny fake client classes to capture construction args --- +class FakeClient: + def __init__(self, *args, **kwargs): + self._constructed_args = args + self._constructed_kwargs = kwargs + + +class FakeAsyncClient: + def __init__(self, *args, **kwargs): + self._constructed_args = args + self._constructed_kwargs = kwargs + + +# --- Tests --------------------------------------------------------- +def test_get_client_lazy_creation_and_caching(monkeypatch): + """ + get_client() should create a Client lazily (on first call) and then cache it. 
+ """ + monkeypatch.setattr(cfg_module, "Client", FakeClient) + + cfg = Configuration() + # register a config with a project so that lazy creation is allowed + item = cfg.add(name="x", project="proj-x", prefix="px-") + + # no client yet + assert item.client is None + + # first call -> construct and cache + client1 = cfg.get_client("x") + assert isinstance(client1, FakeClient) + assert cfg.get_config("x").client is client1 + + # subsequent call -> same instance + client2 = cfg.get_client("x") + assert client2 is client1 + + # verify constructor received our project in kwargs (if implementation provides it) + # accept either 'project' in kwargs or stored on item.project + assert cfg.get_config("x").project == "proj-x" + + +def test_get_async_client_lazy_creation_and_caching(monkeypatch): + """ + get_async_client() should create an AsyncClient lazily and cache it. + """ + monkeypatch.setattr(cfg_module, "AsyncClient", FakeAsyncClient) + + cfg = Configuration() + item = cfg.add(name="y", project="proj-y", prefix="py-") + + assert item.async_client is None + + async_client1 = cfg.get_async_client("y") + assert isinstance(async_client1, FakeAsyncClient) + assert cfg.get_config("y").async_client is async_client1 + + async_client2 = cfg.get_async_client("y") + assert async_client2 is async_client1 + + +def test_preserve_prebuilt_clients(): + """ + If user supplies a prebuilt client / async_client to add(), ensure it uses those. 
+ """ + cfg = Configuration() + + prebuilt_sync = FakeClient() + prebuilt_async = FakeAsyncClient() + + cfg.add( + name="prebuilt", + project="proj-pre", + prefix="pp-", + client=prebuilt_sync, + async_client=prebuilt_async, + ) + + # get_client/get_async_client should return the exact instances given + got_sync = cfg.get_client("prebuilt") + got_async = cfg.get_async_client("prebuilt") + assert got_sync is prebuilt_sync + assert got_async is prebuilt_async + assert cfg.get_config("prebuilt").client is prebuilt_sync + assert cfg.get_config("prebuilt").async_client is prebuilt_async + + +def test_get_client_no_project_raises(monkeypatch): + """ + If no project was supplied and no client was given, get_client() should raise a clear error. + """ + # ensure environment does NOT provide a project fallback + monkeypatch.delenv("GOOGLE_CLOUD_PROJECT", raising=False) + + # add a config without project and without client + cfg = Configuration() + cfg.add(name="noproj", project=None, prefix="np-") + + with pytest.raises(RuntimeError): + cfg.get_client("noproj") + + with pytest.raises(RuntimeError): + cfg.get_async_client("noproj") + + +def test_multiple_clients_are_independent(monkeypatch): + """ + Ensure default and billing configs return distinct client objects and preserve stored metadata. 
+ """ + monkeypatch.setattr(cfg_module, "Client", FakeClient) + monkeypatch.setattr(cfg_module, "AsyncClient", FakeAsyncClient) + creds = Mock(spec=google.auth.credentials.Credentials) + + cfg = Configuration() + cfg.add(prefix="firedantic-test-", project="proj-default", credentials=creds) + cfg.add( + name="billing", prefix="billing-", project="proj-billing", credentials=creds + ) + + # clients for "(default)"" config + default_sync = cfg.get_client() + default_async = cfg.get_async_client() + assert isinstance(default_sync, FakeClient) + assert isinstance(default_async, FakeAsyncClient) + + # clients for billing config + billing_sync = cfg.get_client("billing") + billing_async = cfg.get_async_client("billing") + assert isinstance(billing_sync, FakeClient) + assert isinstance(billing_async, FakeAsyncClient) + + # Ensure the four objects are distinct (sync vs async and default vs billing) + assert default_sync is not billing_sync + assert default_async is not billing_async + assert default_sync is not default_async + assert billing_sync is not billing_async + + # verify metadata preserved + assert cfg.get_config().prefix == "firedantic-test-" + assert cfg.get_config().project == "proj-default" + assert cfg.get_config("billing").prefix == "billing-" + assert cfg.get_config("billing").project == "proj-billing" + + +def test___getitem___and_contains_behavior(monkeypatch): + """ + Ensure dict-like accessors work: __getitem__ and __contains__. 
+ """ + monkeypatch.setattr(cfg_module, "Client", FakeClient) + cfg = Configuration() + cfg.add(name="abc", project="p-abc", prefix="pa-") + assert "abc" in cfg + assert "(default)" in cfg + assert cfg["abc"].project == "p-abc" + assert cfg["(default)"].prefix == "" + + +def test_configure_client(): + + project = "firedantic-test" + prefix = "firedantic-test-" + creds = Mock(spec=google.auth.credentials.Credentials) + + if not environ.get("FIRESTORE_EMULATOR_HOST"): + raise "Firestore emulator must be running" + + config = Configuration() + config.add(prefix=prefix, project=project, credentials=creds) + + assert config["(default)"].project == project + assert config["(default)"].prefix == prefix + client = config.get_client() + assert isinstance(client, Client) + + +### Test sync/async client classes ### + + +def test_configure_async_client(): + # Firestore emulator must be running if using locally. + project = "firedantic-test" + prefix = "firedantic-test-" + creds = Mock(spec=google.auth.credentials.Credentials) + + if environ.get("FIRESTORE_EMULATOR_HOST"): + client = AsyncClient( + project=project, + credentials=creds, + ) + else: + raise "Firestore emulator must be running" + + config = Configuration() + config.add(prefix=prefix, project=project, credentials=creds) + + assert config["(default)"].project == project + assert config["(default)"].prefix == prefix + asyncClient = config.get_async_client() + assert isinstance(asyncClient, AsyncClient) + + +def test_configure_multiple_clients(): + config = Configuration() + mock_creds = Mock(spec=google.auth.credentials.Credentials) + + # name = (default) + config.add( + prefix="firedantic-test-", project="firedantic-test", credentials=mock_creds + ) + + # name = billing + config.add( + name="billing", + prefix="test-billing-", + project="test-billing", + credentials=mock_creds, + ) + + # assert that all 4 clients are unique + client = config.get_client() + assert isinstance(client, Client) + + asyncClient = 
config.get_async_client() + assert isinstance(asyncClient, AsyncClient) + + billingClient = config.get_client("billing") + assert isinstance(client, Client) + + billingAsyncClient = config.get_async_client("billing") + assert isinstance(asyncClient, AsyncClient) + + assert client != asyncClient != billingClient != billingAsyncClient + + # assert that CONFIGURATIONS holds two different client configs + assert config["(default)"].prefix == "firedantic-test-" + assert config["(default)"].project == "firedantic-test" + assert config["(default)"].credentials == mock_creds + + assert config["billing"].prefix == "test-billing-" + assert config["billing"].project == "test-billing" + assert config["billing"].credentials == mock_creds + + +### Test admin client classes ### + + +# helper fake admin client classes to capture construction arguments +class FakeAdminClient: + def __init__(self, *args, **kwargs): + self._constructed_args = args + self._constructed_kwargs = kwargs + + +class FakeAsyncAdminClient: + def __init__(self, *args, **kwargs): + self._constructed_args = args + self._constructed_kwargs = kwargs + + +def test_get_admin_client_lazy_creation(monkeypatch): + """ + get_admin_client() should create a FirestoreAdminClient lazily and cache it. 
+ """ + cfg = Configuration() + + # monkeypatch the admin client class used inside the module + monkeypatch.setattr( + "firedantic.configurations.FirestoreAdminClient", FakeAdminClient + ) + + # add a config that does not supply admin_client upfront + cfg.add(name="x", project="proj-a", database="(default)", prefix="p-") + + # initially there should be no admin client stored on the ConfigItem + item = cfg.get_config("x") + assert item.admin_client is None + + # first call should create it + admin = cfg.get_admin_client("x") + assert isinstance(admin, FakeAdminClient) + assert cfg.get_config("x").admin_client is admin # cached + + # subsequent call returns same instance (cached) + admin2 = cfg.get_admin_client("x") + assert admin2 is admin + + +def test_get_async_admin_client_lazy_creation(monkeypatch): + """ + get_async_admin_client() should create an async admin client lazily and cache it. + """ + cfg = Configuration() + + monkeypatch.setattr( + "firedantic.configurations.FirestoreAdminAsyncClient", FakeAsyncAdminClient + ) + + cfg.add(name="y", project="proj-b", database="billing", prefix="pb-") + + item = cfg.get_config("y") + assert item.async_admin_client is None + + async_admin = cfg.get_async_admin_client("y") + assert isinstance(async_admin, FakeAsyncAdminClient) + assert cfg.get_config("y").async_admin_client is async_admin + + # cached + assert cfg.get_async_admin_client("y") is async_admin + + +def test_preserve_prebuilt_admin_client(): + """ + If the user supplies an admin_client in add(), the registry should use it verbatim. 
+ """ + cfg = Configuration() + + # create a prebuilt fake admin client instance + prebuilt = FakeAdminClient() + cfg.add( + name="prebuilt", + project="proj-pre", + database="(default)", + prefix="pp-", + admin_client=prebuilt, + ) + + # get_admin_client should return the exact object we passed + got = cfg.get_admin_client("prebuilt") + assert got is prebuilt + assert cfg.get_config("prebuilt").admin_client is prebuilt + + +def test_admin_client_creation_receives_transport_and_client_options(monkeypatch): + """ + Ensure the admin-client constructor receives client_options and transport passed into add(). + """ + cfg = Configuration() + + captured = {} + + # define a fake constructor that records kwargs + def fake_constructor(*args, **kwargs): + inst = FakeAdminClient(*args, **kwargs) + captured["kwargs"] = kwargs + captured["args"] = args + return inst + + monkeypatch.setattr( + "firedantic.configurations.FirestoreAdminClient", fake_constructor + ) + + # pass some admin_transport and client_options into add() + fake_transport = object() + fake_client_options = {"api_endpoint": "test-endpoint"} + fake_client_info = Mock() + + cfg.add( + name="z", + project="proj-z", + database="(default)", + prefix="pz-", + credentials=None, + admin_transport=fake_transport, + client_options=fake_client_options, + client_info=fake_client_info, + ) + + # creating admin client should call fake_constructor with the supplied values (as kwargs) + admin = cfg.get_admin_client("z") + assert isinstance(admin, FakeAdminClient) + assert "kwargs" in captured + # check the captured kwargs including transport + client_options and a client_info (or default) + assert captured["kwargs"].get("transport") is fake_transport + assert captured["kwargs"].get("client_options") == fake_client_options + assert "client_info" in captured["kwargs"] + + +def test_get_admin_client_unknown_config_raises(): + cfg = Configuration() + with pytest.raises(KeyError): + cfg.get_admin_client("does-not-exist") + + +def 
test_legacy_configure_shim_creates_config(monkeypatch): + from firedantic.configurations import CONFIGURATIONS + from firedantic.configurations import Client as RealClient # monkeypatch below + from firedantic.configurations import configuration, configure + + # fake client class to avoid network + class FakeClient: + pass + + monkeypatch.setattr("firedantic.configurations.Client", FakeClient) + + fake = FakeClient() + configure(fake, prefix="legacy-") + cfg = configuration.get_config("(default)") + assert cfg.prefix == "legacy-" + assert CONFIGURATIONS["prefix"] == "legacy-" + assert CONFIGURATIONS["db"] is fake + + +def test_legacy_configure_warns(monkeypatch): + import warnings + + from firedantic.configurations import configure + + class FakeClient: + pass + + with warnings.catch_warnings(record=True) as rec: + warnings.simplefilter("always") + configure(FakeClient(), prefix="x") + assert any(issubclass(w.category, DeprecationWarning) for w in rec) diff --git a/firedantic/tests/test_mocked_model.py b/firedantic/tests/test_mocked_model.py new file mode 100644 index 0000000..64bd1b3 --- /dev/null +++ b/firedantic/tests/test_mocked_model.py @@ -0,0 +1,154 @@ +from unittest.mock import AsyncMock, MagicMock + +import pytest +from google.cloud.firestore import AsyncClient + +from firedantic import AsyncModel, configure +from firedantic.configurations import CONFIGURATIONS + + +@pytest.fixture +def MockModelClass(): + class TestModel(AsyncModel): + __collection__ = "mock_models" + name: str + age: int = 30 + + return TestModel + + +@pytest.fixture +def mock_client(): + client = AsyncMock(spec=AsyncClient) + # Reset configurations to ensure isolation + CONFIGURATIONS.clear() + configure(client, prefix="test_") + return client + + +@pytest.mark.asyncio +async def test_save_new_model(mock_client, MockModelClass): + collection_mock = MagicMock() + document_mock = MagicMock() + + mock_client.collection.return_value = collection_mock + collection_mock.document.return_value 
= document_mock + + document_mock.id = "generated-id" + # Ensure set/update are async + document_mock.set = AsyncMock() + document_mock.update = AsyncMock() + + model = MockModelClass(name="Alice") + await model.save() + + mock_client.collection.assert_called_with("test_mock_models") + assert collection_mock.document.called + document_mock.set.assert_called_once() + assert model.id == "generated-id" + + +@pytest.mark.asyncio +async def test_save_update_model(mock_client, MockModelClass): + collection_mock = MagicMock() + document_mock = MagicMock() + + mock_client.collection.return_value = collection_mock + collection_mock.document.return_value = document_mock + document_mock.set = AsyncMock() + document_mock.update = AsyncMock() + + model = MockModelClass(id="existing-id", name="Bob") + await model.save() + + mock_client.collection.assert_called_with("test_mock_models") + collection_mock.document.assert_called_with("existing-id") + # It might use update or set. + assert document_mock.update.called or document_mock.set.called + + +@pytest.mark.asyncio +async def test_delete_model(mock_client, MockModelClass): + collection_mock = MagicMock() + document_mock = MagicMock() + + mock_client.collection.return_value = collection_mock + collection_mock.document.return_value = document_mock + document_mock.delete = AsyncMock() + + # Verify we DO NOT call stream (which would indicate bulk delete) + collection_mock.stream = MagicMock( + side_effect=Exception("Should not stream collection for single delete!") + ) + + model = MockModelClass(id="to-delete", name="Charlie") + await model.delete() + + mock_client.collection.assert_called_with("test_mock_models") + collection_mock.document.assert_called_with("to-delete") + document_mock.delete.assert_called_once() + + +@pytest.mark.asyncio +async def test_delete_all(mock_client, MockModelClass): + collection_mock = MagicMock() + document_mock = MagicMock() + + mock_client.collection.return_value = collection_mock + + mock_doc = 
MagicMock() + mock_doc.reference = document_mock + document_mock.delete = AsyncMock() + + async def async_stream(): + yield mock_doc + yield mock_doc + + collection_mock.stream.side_effect = async_stream + + await MockModelClass.delete_all() + + assert collection_mock.stream.called + assert document_mock.delete.call_count == 2 + + +@pytest.mark.asyncio +async def test_find(mock_client, MockModelClass): + collection_mock = MagicMock() + + mock_client.collection.return_value = collection_mock + + # Mocking the query chain: .where().where().stream() + # If find() calls where(), we need to return the request (or new query object) to allow chaining. + # The final object calls stream(). + + # We can use a recursive magic mock for chaining or set return_value to self. + query_mock = MagicMock() + collection_mock.where.return_value = query_mock + query_mock.where.return_value = query_mock # For multiple wheres + query_mock.order_by.return_value = query_mock + query_mock.limit.return_value = query_mock + query_mock.offset.return_value = query_mock + + # Mock stream to return an async generator + mock_snapshot = MagicMock() + mock_snapshot.id = "doc-1" + mock_snapshot.to_dict.return_value = {"name": "Dave", "age": 40} + + async def async_stream(): + yield mock_snapshot + + query_mock.stream.return_value = async_stream() + + # In case it calls stream() directly on collection (no filters) + collection_mock.stream.side_effect = async_stream + + # Case 1: find({}) -> no filters, might call stream on collection mock directly or via "select" etc. 
+ # Assuming firedantic calls collection(...).stream() for empty query, + # or collection(...).where(..).stream() + + results = await MockModelClass.find({"name": "Dave"}) + + assert len(results) == 1 + assert results[0].name == "Dave" + assert results[0].id == "doc-1" diff --git a/firedantic/tests/tests_async/test_indexes.py b/firedantic/tests/tests_async/test_indexes.py index 5c0d3e9..f6d93fe 100644 --- a/firedantic/tests/tests_async/test_indexes.py +++ b/firedantic/tests/tests_async/test_indexes.py @@ -13,6 +13,7 @@ collection_index, ) from firedantic.common import IndexField +from firedantic.configurations import configuration from firedantic.tests.tests_async.conftest import MockListIndexOperation import pytest # noqa isort: skip @@ -28,6 +29,10 @@ class BaseModelWithIndexes(AsyncModel): @pytest.mark.asyncio async def test_set_up_composite_index(mock_admin_client) -> None: + configuration.add( + name="(default)", prefix="test_", project="proj", async_client=mock_admin_client + ) + class ModelWithIndexes(BaseModelWithIndexes): __composite_indexes__ = ( collection_index( @@ -46,9 +51,10 @@ class ModelWithIndexes(BaseModelWithIndexes): call_list = mock_admin_client.create_index.call_args_list # index is a protobuf structure sent via Google Cloud Admin path = call_list[0][1]["request"].parent + expected_prefix = configuration.get_config("(default)").prefix assert ( path - == f"projects/proj/databases/(default)/collectionGroups/{CONFIGURATIONS['prefix']}modelWithIndexes" + == f"projects/proj/databases/(default)/collectionGroups/{expected_prefix}modelWithIndexes" ) index = call_list[0][1]["request"].index assert index.query_scope.name == "COLLECTION" @@ -61,6 +67,10 @@ class ModelWithIndexes(BaseModelWithIndexes): @pytest.mark.asyncio async def test_set_up_collection_group_index(mock_admin_client) -> None: + configuration.add( + name="(default)", prefix="test_", project="proj", async_client=mock_admin_client + ) + class ModelWithIndexes(BaseModelWithIndexes): 
__composite_indexes__ = ( collection_group_index( @@ -79,9 +89,11 @@ class ModelWithIndexes(BaseModelWithIndexes): call_list = mock_admin_client.create_index.call_args_list # index is a protobuf structure sent via Google Cloud Admin path = call_list[0][1]["request"].parent + expected_prefix = configuration.get_config("(default)").prefix + assert ( path - == f"projects/proj/databases/(default)/collectionGroups/{CONFIGURATIONS['prefix']}modelWithIndexes" + == f"projects/proj/databases/(default)/collectionGroups/{expected_prefix}modelWithIndexes" ) index = call_list[0][1]["request"].index assert index.query_scope.name == "COLLECTION_GROUP" @@ -90,6 +102,10 @@ class ModelWithIndexes(BaseModelWithIndexes): @pytest.mark.asyncio async def test_set_up_composite_indexes_and_policies(mock_admin_client) -> None: + configuration.add( + name="(default)", prefix="test_", project="proj", async_client=mock_admin_client + ) + class ModelWithIndexes(BaseModelWithIndexes): __composite_indexes__ = ( collection_index( @@ -114,6 +130,10 @@ class ModelWithIndexes(BaseModelWithIndexes): @pytest.mark.asyncio async def test_set_up_many_composite_indexes(mock_admin_client) -> None: + configuration.add( + name="(default)", prefix="test_", project="proj", async_client=mock_admin_client + ) + class ModelWithIndexes(BaseModelWithIndexes): __composite_indexes__ = ( collection_index( @@ -141,6 +161,10 @@ class ModelWithIndexes(BaseModelWithIndexes): @pytest.mark.asyncio async def test_set_up_indexes_model_without_indexes(mock_admin_client) -> None: + configuration.add( + name="(default)", prefix="test_", project="proj", async_client=mock_admin_client + ) + class ModelWithoutIndexes(AsyncModel): __collection__ = "modelWithoutIndexes" @@ -159,13 +183,18 @@ class ModelWithoutIndexes(AsyncModel): @pytest.mark.asyncio async def test_existing_indexes_are_skipped(mock_admin_client) -> None: + configuration.add( + name="(default)", prefix="test_", project="proj", async_client=mock_admin_client + ) + 
expected_prefix = configuration.get_config("(default)").prefix + resp = ListIndexesResponse( { "indexes": [ { "name": ( "projects/fake-project/databases/(default)/collectionGroups/" - f"{CONFIGURATIONS['prefix']}modelWithIndexes/123456" + f"{expected_prefix}modelWithIndexes/123456" ), "query_scope": "COLLECTION", "fields": [ @@ -177,7 +206,7 @@ async def test_existing_indexes_are_skipped(mock_admin_client) -> None: { "name": ( "projects/fake-project/databases/(default)/collectionGroups/" - f"{CONFIGURATIONS['prefix']}modelWithIndexes/67889" + f"{expected_prefix}modelWithIndexes/67889" ), "query_scope": "COLLECTION", "fields": [ @@ -215,6 +244,11 @@ class ModelWithIndexes(BaseModelWithIndexes): @pytest.mark.asyncio async def test_same_fields_in_another_collection(mock_admin_client) -> None: + configuration.add( + name="(default)", prefix="test_", project="proj", async_client=mock_admin_client + ) + expected_prefix = configuration.get_config("(default)").prefix + # Test that when another collection has an index with exactly the same fields, # it won't affect creating an index in the target collection resp = ListIndexesResponse( @@ -223,7 +257,7 @@ async def test_same_fields_in_another_collection(mock_admin_client) -> None: { "name": ( "projects/fake-project/databases/(default)/collectionGroups/" - f"{CONFIGURATIONS['prefix']}anotherModel/123456" + f"{expected_prefix}anotherModel/123456" ), "query_scope": "COLLECTION", "fields": [ diff --git a/firedantic/tests/tests_async/test_model.py b/firedantic/tests/tests_async/test_model.py index ed27c0d..83fdba5 100644 --- a/firedantic/tests/tests_async/test_model.py +++ b/firedantic/tests/tests_async/test_model.py @@ -8,6 +8,7 @@ import firedantic.operators as op from firedantic import AsyncModel, get_async_transaction +from firedantic.configurations import configuration from firedantic.exceptions import ( CollectionNotDefined, InvalidDocumentID, @@ -19,6 +20,7 @@ CustomIDConflictModel, CustomIDModel, CustomIDModelExtra, + 
Owner, Product, Profile, TodoList, @@ -36,7 +38,7 @@ @pytest.mark.asyncio -async def test_save_model(configure_db, create_company) -> None: +async def test_save_model(create_company) -> None: company = await create_company() assert company.id is not None @@ -45,22 +47,7 @@ async def test_save_model(configure_db, create_company) -> None: @pytest.mark.asyncio -async def test_delete_model(configure_db, create_company) -> None: - company: Company = await create_company( - company_id="11223344-5", first_name="Jane", last_name="Doe" - ) - - _id = company.id - assert _id - - await company.delete() - - with pytest.raises(ModelNotFoundError): - await Company.get_by_id(_id) - - -@pytest.mark.asyncio -async def test_find_one(configure_db, create_company) -> None: +async def test_find_one(create_company) -> None: with pytest.raises(ModelNotFoundError): await Company.find_one() @@ -91,7 +78,7 @@ async def test_find_one(configure_db, create_company) -> None: @pytest.mark.asyncio -async def test_find(configure_db, create_company, create_product) -> None: +async def test_find(create_company, create_product) -> None: ids = ["1234555-1", "1234567-8", "2131232-4", "4124432-4"] for company_id in ids: await create_company(company_id=company_id) @@ -103,6 +90,12 @@ async def test_find(configure_db, create_company, create_product) -> None: d = await Company.find({"owner.first_name": "John"}) assert len(d) == 4 + d = await Company.find({"owner.first_name": {op.EQ: "John"}}) + assert len(d) == 4 + + d = await Company.find({"owner.first_name": {"==": "John"}}) + assert len(d) == 4 + for p in TEST_PRODUCTS: await create_product(**p) @@ -122,7 +115,7 @@ async def test_find(configure_db, create_company, create_product) -> None: @pytest.mark.asyncio -async def test_find_not_in(configure_db, create_company) -> None: +async def test_find_not_in(create_company) -> None: ids = ["1234555-1", "1234567-8", "2131232-4", "4124432-4"] for company_id in ids: await create_company(company_id=company_id) @@ 
-143,7 +136,7 @@ async def test_find_not_in(configure_db, create_company) -> None: @pytest.mark.asyncio -async def test_find_array_contains(configure_db, create_todolist) -> None: +async def test_find_array_contains(create_todolist) -> None: list_1 = await create_todolist("list_1", ["Work", "Eat", "Sleep"]) await create_todolist("list_2", ["Learn Python", "Walk the dog"]) @@ -153,7 +146,7 @@ async def test_find_array_contains(configure_db, create_todolist) -> None: @pytest.mark.asyncio -async def test_find_array_contains_any(configure_db, create_todolist) -> None: +async def test_find_array_contains_any(create_todolist) -> None: list_1 = await create_todolist("list_1", ["Work", "Eat"]) list_2 = await create_todolist("list_2", ["Relax", "Chill", "Sleep"]) await create_todolist("list_3", ["Learn Python", "Walk the dog"]) @@ -165,7 +158,7 @@ async def test_find_array_contains_any(configure_db, create_todolist) -> None: @pytest.mark.asyncio -async def test_find_limit(configure_db, create_company) -> None: +async def test_find_limit(create_company) -> None: ids = ["1234555-1", "1234567-8", "2131232-4", "4124432-4"] for company_id in ids: await create_company(company_id=company_id) @@ -178,7 +171,7 @@ async def test_find_limit(configure_db, create_company) -> None: @pytest.mark.asyncio -async def test_find_order_by(configure_db, create_company) -> None: +async def test_find_order_by(create_company) -> None: companies_and_owners = [ {"company_id": "1234555-1", "last_name": "A", "first_name": "A"}, {"company_id": "1234555-2", "last_name": "A", "first_name": "B"}, @@ -227,7 +220,7 @@ async def test_find_order_by(configure_db, create_company) -> None: @pytest.mark.asyncio -async def test_find_offset(configure_db, create_company) -> None: +async def test_find_offset(create_company) -> None: ids_and_lastnames = ( ("1234555-1", "A"), ("1234567-8", "B"), @@ -245,7 +238,7 @@ async def test_find_offset(configure_db, create_company) -> None: @pytest.mark.asyncio -async def 
test_get_by_id(configure_db, create_company) -> None: +async def test_get_by_id(create_company) -> None: c: Company = await create_company(company_id="1234567-8") assert c.id is not None @@ -260,22 +253,23 @@ async def test_get_by_id(configure_db, create_company) -> None: @pytest.mark.asyncio -async def test_get_by_empty_str_id(configure_db) -> None: +async def test_get_by_empty_str_id() -> None: with pytest.raises(ModelNotFoundError): await Company.get_by_id("") @pytest.mark.asyncio -async def test_missing_collection(configure_db) -> None: +async def test_missing_collection() -> None: class User(AsyncModel): name: str + # normally __collection__ would be defined here with pytest.raises(CollectionNotDefined): await User(name="John").save() @pytest.mark.asyncio -async def test_model_aliases(configure_db) -> None: +async def test_model_aliases() -> None: class User(AsyncModel): __collection__ = "User" @@ -320,7 +314,7 @@ class User(AsyncModel): "!:&+-*'()", ], ) -async def test_models_with_valid_custom_id(configure_db, model_id) -> None: +async def test_models_with_valid_custom_id(model_id) -> None: product_id = str(uuid4()) product = Product(product_id=product_id, price=123.45, stock=2) @@ -349,7 +343,7 @@ async def test_models_with_valid_custom_id(configure_db, model_id) -> None: "foo/bar/baz", ], ) -async def test_models_with_invalid_custom_id(configure_db, model_id: str) -> None: +async def test_models_with_invalid_custom_id(model_id: str) -> None: product = Product(product_id="product 123", price=123.45, stock=2) product.id = model_id with pytest.raises(InvalidDocumentID): @@ -360,7 +354,7 @@ async def test_models_with_invalid_custom_id(configure_db, model_id: str) -> Non @pytest.mark.asyncio -async def test_truncate_collection(configure_db, create_company) -> None: +async def test_truncate_collection(create_company) -> None: await create_company(company_id="1234567-8") await create_company(company_id="1234567-9") @@ -373,7 +367,7 @@ async def 
test_truncate_collection(configure_db, create_company) -> None: @pytest.mark.asyncio -async def test_custom_id_model(configure_db) -> None: +async def test_custom_id_model() -> None: c = CustomIDModel(bar="bar") # type: ignore await c.save() @@ -386,7 +380,7 @@ async def test_custom_id_model(configure_db) -> None: @pytest.mark.asyncio -async def test_custom_id_conflict(configure_db) -> None: +async def test_custom_id_conflict() -> None: await CustomIDConflictModel(foo="foo", bar="bar").save() models = await CustomIDModel.find({}) @@ -398,7 +392,7 @@ async def test_custom_id_conflict(configure_db) -> None: @pytest.mark.asyncio -async def test_model_id_persistency(configure_db) -> None: +async def test_model_id_persistency() -> None: c = CustomIDConflictModel(foo="foo", bar="bar") await c.save() assert c.id @@ -410,7 +404,7 @@ async def test_model_id_persistency(configure_db) -> None: @pytest.mark.asyncio -async def test_bare_model_document_id_persistency(configure_db) -> None: +async def test_bare_model_document_id_persistency() -> None: c = CustomIDModel(bar="bar") # type: ignore await c.save() assert c.foo @@ -422,20 +416,20 @@ async def test_bare_model_document_id_persistency(configure_db) -> None: @pytest.mark.asyncio -async def test_bare_model_get_by_empty_doc_id(configure_db) -> None: +async def test_bare_model_get_by_empty_doc_id() -> None: with pytest.raises(ModelNotFoundError): await CustomIDModel.get_by_doc_id("") @pytest.mark.asyncio -async def test_extra_fields(configure_db) -> None: +async def test_extra_fields() -> None: await CustomIDModelExtra(foo="foo", bar="bar", baz="baz").save() # type: ignore with pytest.raises(ValidationError): await CustomIDModel.find({}) @pytest.mark.asyncio -async def test_company_stats(configure_db, create_company) -> None: +async def test_company_stats(create_company) -> None: company: Company = await create_company(company_id="1234567-8") company_stats = company.stats() @@ -456,7 +450,7 @@ async def 
test_company_stats(configure_db, create_company) -> None: @pytest.mark.asyncio -async def test_subcollection_model_safety(configure_db) -> None: +async def test_subcollection_model_safety() -> None: """ Ensure you shouldn't be able to use unprepared subcollection models accidentally """ @@ -465,7 +459,7 @@ async def test_subcollection_model_safety(configure_db) -> None: @pytest.mark.asyncio -async def test_get_user_purchases(configure_db) -> None: +async def test_get_user_purchases() -> None: u = User(name="Foo") await u.save() assert u.id @@ -477,7 +471,7 @@ async def test_get_user_purchases(configure_db) -> None: @pytest.mark.asyncio -async def test_reload(configure_db) -> None: +async def test_reload() -> None: u = User(name="Foo") await u.save() @@ -496,7 +490,7 @@ async def test_reload(configure_db) -> None: @pytest.mark.asyncio -async def test_save_with_exclude_none(configure_db) -> None: +async def test_save_with_exclude_none() -> None: p = Profile(name="Foo") await p.save(exclude_none=True) @@ -518,7 +512,7 @@ async def test_save_with_exclude_none(configure_db) -> None: @pytest.mark.asyncio -async def test_save_with_exclude_unset(configure_db) -> None: +async def test_save_with_exclude_unset() -> None: p = Profile(photo_url=None) await p.save(exclude_unset=True) @@ -540,11 +534,9 @@ async def test_save_with_exclude_unset(configure_db) -> None: @pytest.mark.asyncio -async def test_update_city_in_transaction(configure_db) -> None: +async def test_update_city_in_transaction() -> None: """ Test updating a model in a transaction. Test case from README. - - :param: configure_db: pytest fixture """ @async_transactional @@ -566,38 +558,49 @@ async def decrement_population( @pytest.mark.asyncio -async def test_delete_in_transaction(configure_db) -> None: +async def test_delete_in_transaction(create_company): """ - Test deleting a model in a transaction. - - :param: configure_db: pytest fixture + Test deleting a Company model within a Firestore transaction. 
""" + # Create a company + company: Company = await create_company( + company_id="11223344-4", first_name="Joe", last_name="Day" + ) + _id = company.id + assert _id @async_transactional - async def delete_in_transaction( - transaction: AsyncTransaction, profile_id: str - ) -> None: - """Deletes a Profile in a transaction.""" - profile = await Profile.get_by_id(profile_id, transaction=transaction) - await profile.delete(transaction=transaction) - - p = Profile(name="Foo") - await p.save() - assert p.id + async def delete_company(transaction: AsyncTransaction) -> None: + await company.delete() + # Call the transactional function t = get_async_transaction() - await delete_in_transaction(t, p.id) + await delete_company(t) + + # Outside the transaction, the deletion should now be committed + with pytest.raises(ModelNotFoundError): + await Company.get_by_id(_id) + + +@pytest.mark.asyncio +async def test_delete_model(create_company) -> None: + company: Company = await create_company( + company_id="11223344-5", first_name="Jane", last_name="Doe" + ) + + _id = company.id + assert _id + + await company.delete() with pytest.raises(ModelNotFoundError): - await Profile.get_by_id(p.id) + await Company.get_by_id(_id) @pytest.mark.asyncio -async def test_update_model_in_transaction(configure_db) -> None: +async def test_update_model_in_transaction() -> None: """ Test updating a model in a transaction. - - :param: configure_db: pytest fixture """ @async_transactional @@ -620,11 +623,9 @@ async def update_in_transaction( @pytest.mark.asyncio -async def test_update_submodel_in_transaction(configure_db) -> None: +async def test_update_submodel_in_transaction() -> None: """ Test Updating a submodel in a transaction. 
- - :param: configure_db: pytest fixture """ @async_transactional diff --git a/firedantic/tests/tests_sync/test_indexes.py b/firedantic/tests/tests_sync/test_indexes.py index c4fb399..46d4b5c 100644 --- a/firedantic/tests/tests_sync/test_indexes.py +++ b/firedantic/tests/tests_sync/test_indexes.py @@ -13,6 +13,7 @@ set_up_composite_indexes_and_ttl_policies, ) from firedantic.common import IndexField +from firedantic.configurations import configuration from firedantic.tests.tests_sync.conftest import MockListIndexOperation import pytest # noqa isort: skip @@ -27,6 +28,10 @@ class BaseModelWithIndexes(Model): def test_set_up_composite_index(mock_admin_client) -> None: + configuration.add( + name="(default)", prefix="test_", project="proj", client=mock_admin_client + ) + class ModelWithIndexes(BaseModelWithIndexes): __composite_indexes__ = ( collection_index( @@ -45,9 +50,10 @@ class ModelWithIndexes(BaseModelWithIndexes): call_list = mock_admin_client.create_index.call_args_list # index is a protobuf structure sent via Google Cloud Admin path = call_list[0][1]["request"].parent + expected_prefix = configuration.get_config("(default)").prefix assert ( path - == f"projects/proj/databases/(default)/collectionGroups/{CONFIGURATIONS['prefix']}modelWithIndexes" + == f"projects/proj/databases/(default)/collectionGroups/{expected_prefix}modelWithIndexes" ) index = call_list[0][1]["request"].index assert index.query_scope.name == "COLLECTION" @@ -59,6 +65,10 @@ class ModelWithIndexes(BaseModelWithIndexes): def test_set_up_collection_group_index(mock_admin_client) -> None: + configuration.add( + name="(default)", prefix="test_", project="proj", client=mock_admin_client + ) + class ModelWithIndexes(BaseModelWithIndexes): __composite_indexes__ = ( collection_group_index( @@ -77,9 +87,11 @@ class ModelWithIndexes(BaseModelWithIndexes): call_list = mock_admin_client.create_index.call_args_list # index is a protobuf structure sent via Google Cloud Admin path = 
call_list[0][1]["request"].parent + expected_prefix = configuration.get_config("(default)").prefix + assert ( path - == f"projects/proj/databases/(default)/collectionGroups/{CONFIGURATIONS['prefix']}modelWithIndexes" + == f"projects/proj/databases/(default)/collectionGroups/{expected_prefix}modelWithIndexes" ) index = call_list[0][1]["request"].index assert index.query_scope.name == "COLLECTION_GROUP" @@ -87,6 +99,10 @@ class ModelWithIndexes(BaseModelWithIndexes): def test_set_up_composite_indexes_and_policies(mock_admin_client) -> None: + configuration.add( + name="(default)", prefix="test_", project="proj", client=mock_admin_client + ) + class ModelWithIndexes(BaseModelWithIndexes): __composite_indexes__ = ( collection_index( @@ -110,6 +126,10 @@ class ModelWithIndexes(BaseModelWithIndexes): def test_set_up_many_composite_indexes(mock_admin_client) -> None: + configuration.add( + name="(default)", prefix="test_", project="proj", client=mock_admin_client + ) + class ModelWithIndexes(BaseModelWithIndexes): __composite_indexes__ = ( collection_index( @@ -136,6 +156,10 @@ class ModelWithIndexes(BaseModelWithIndexes): def test_set_up_indexes_model_without_indexes(mock_admin_client) -> None: + configuration.add( + name="(default)", prefix="test_", project="proj", client=mock_admin_client + ) + class ModelWithoutIndexes(Model): __collection__ = "modelWithoutIndexes" @@ -153,13 +177,18 @@ class ModelWithoutIndexes(Model): def test_existing_indexes_are_skipped(mock_admin_client) -> None: + configuration.add( + name="(default)", prefix="test_", project="proj", client=mock_admin_client + ) + expected_prefix = configuration.get_config("(default)").prefix + resp = ListIndexesResponse( { "indexes": [ { "name": ( "projects/fake-project/databases/(default)/collectionGroups/" - f"{CONFIGURATIONS['prefix']}modelWithIndexes/123456" + f"{expected_prefix}modelWithIndexes/123456" ), "query_scope": "COLLECTION", "fields": [ @@ -171,7 +200,7 @@ def 
test_existing_indexes_are_skipped(mock_admin_client) -> None: { "name": ( "projects/fake-project/databases/(default)/collectionGroups/" - f"{CONFIGURATIONS['prefix']}modelWithIndexes/67889" + f"{expected_prefix}modelWithIndexes/67889" ), "query_scope": "COLLECTION", "fields": [ @@ -206,6 +235,11 @@ class ModelWithIndexes(BaseModelWithIndexes): def test_same_fields_in_another_collection(mock_admin_client) -> None: + configuration.add( + name="(default)", prefix="test_", project="proj", client=mock_admin_client + ) + expected_prefix = configuration.get_config("(default)").prefix + # Test that when another collection has an index with exactly the same fields, # it won't affect creating an index in the target collection resp = ListIndexesResponse( @@ -214,7 +248,7 @@ def test_same_fields_in_another_collection(mock_admin_client) -> None: { "name": ( "projects/fake-project/databases/(default)/collectionGroups/" - f"{CONFIGURATIONS['prefix']}anotherModel/123456" + f"{expected_prefix}anotherModel/123456" ), "query_scope": "COLLECTION", "fields": [ diff --git a/firedantic/tests/tests_sync/test_model.py b/firedantic/tests/tests_sync/test_model.py index 37e27a4..77031e2 100644 --- a/firedantic/tests/tests_sync/test_model.py +++ b/firedantic/tests/tests_sync/test_model.py @@ -8,6 +8,7 @@ import firedantic.operators as op from firedantic import Model, get_transaction +from firedantic.configurations import configuration from firedantic.exceptions import ( CollectionNotDefined, InvalidDocumentID, @@ -19,6 +20,7 @@ CustomIDConflictModel, CustomIDModel, CustomIDModelExtra, + Owner, Product, Profile, TodoList, @@ -35,7 +37,7 @@ ] -def test_save_model(configure_db, create_company) -> None: +def test_save_model(create_company) -> None: company = create_company() assert company.id is not None @@ -43,21 +45,7 @@ def test_save_model(configure_db, create_company) -> None: assert company.owner.last_name == "Doe" -def test_delete_model(configure_db, create_company) -> None: - company: 
Company = create_company( - company_id="11223344-5", first_name="Jane", last_name="Doe" - ) - - _id = company.id - assert _id - - company.delete() - - with pytest.raises(ModelNotFoundError): - Company.get_by_id(_id) - - -def test_find_one(configure_db, create_company) -> None: +def test_find_one(create_company) -> None: with pytest.raises(ModelNotFoundError): Company.find_one() @@ -85,7 +73,7 @@ def test_find_one(configure_db, create_company) -> None: assert first_desc.owner.first_name == "Foo" -def test_find(configure_db, create_company, create_product) -> None: +def test_find(create_company, create_product) -> None: ids = ["1234555-1", "1234567-8", "2131232-4", "4124432-4"] for company_id in ids: create_company(company_id=company_id) @@ -97,6 +85,12 @@ def test_find(configure_db, create_company, create_product) -> None: d = Company.find({"owner.first_name": "John"}) assert len(d) == 4 + d = Company.find({"owner.first_name": {op.EQ: "John"}}) + assert len(d) == 4 + + d = Company.find({"owner.first_name": {"==": "John"}}) + assert len(d) == 4 + for p in TEST_PRODUCTS: create_product(**p) @@ -115,7 +109,7 @@ def test_find(configure_db, create_company, create_product) -> None: Product.find({"product_id": {"<>": "a"}}) -def test_find_not_in(configure_db, create_company) -> None: +def test_find_not_in(create_company) -> None: ids = ["1234555-1", "1234567-8", "2131232-4", "4124432-4"] for company_id in ids: create_company(company_id=company_id) @@ -135,7 +129,7 @@ def test_find_not_in(configure_db, create_company) -> None: assert company.company_id in ("2131232-4", "4124432-4") -def test_find_array_contains(configure_db, create_todolist) -> None: +def test_find_array_contains(create_todolist) -> None: list_1 = create_todolist("list_1", ["Work", "Eat", "Sleep"]) create_todolist("list_2", ["Learn Python", "Walk the dog"]) @@ -144,7 +138,7 @@ def test_find_array_contains(configure_db, create_todolist) -> None: assert found[0].name == list_1.name -def 
test_find_array_contains_any(configure_db, create_todolist) -> None: +def test_find_array_contains_any(create_todolist) -> None: list_1 = create_todolist("list_1", ["Work", "Eat"]) list_2 = create_todolist("list_2", ["Relax", "Chill", "Sleep"]) create_todolist("list_3", ["Learn Python", "Walk the dog"]) @@ -155,7 +149,7 @@ def test_find_array_contains_any(configure_db, create_todolist) -> None: assert lst.name in (list_1.name, list_2.name) -def test_find_limit(configure_db, create_company) -> None: +def test_find_limit(create_company) -> None: ids = ["1234555-1", "1234567-8", "2131232-4", "4124432-4"] for company_id in ids: create_company(company_id=company_id) @@ -167,7 +161,7 @@ def test_find_limit(configure_db, create_company) -> None: assert len(companies_2) == 2 -def test_find_order_by(configure_db, create_company) -> None: +def test_find_order_by(create_company) -> None: companies_and_owners = [ {"company_id": "1234555-1", "last_name": "A", "first_name": "A"}, {"company_id": "1234555-2", "last_name": "A", "first_name": "B"}, @@ -211,7 +205,7 @@ def test_find_order_by(configure_db, create_company) -> None: assert companies_and_owners == lastname_ascending_firstname_ascending -def test_find_offset(configure_db, create_company) -> None: +def test_find_offset(create_company) -> None: ids_and_lastnames = ( ("1234555-1", "A"), ("1234567-8", "B"), @@ -228,7 +222,7 @@ def test_find_offset(configure_db, create_company) -> None: assert len(companies_ascending) == 2 -def test_get_by_id(configure_db, create_company) -> None: +def test_get_by_id(create_company) -> None: c: Company = create_company(company_id="1234567-8") assert c.id is not None @@ -242,20 +236,21 @@ def test_get_by_id(configure_db, create_company) -> None: assert c_2.owner.first_name == "John" -def test_get_by_empty_str_id(configure_db) -> None: +def test_get_by_empty_str_id() -> None: with pytest.raises(ModelNotFoundError): Company.get_by_id("") -def test_missing_collection(configure_db) -> None: +def 
test_missing_collection() -> None: class User(Model): name: str + # normally __collection__ would be defined here with pytest.raises(CollectionNotDefined): User(name="John").save() -def test_model_aliases(configure_db) -> None: +def test_model_aliases() -> None: class User(Model): __collection__ = "User" @@ -299,7 +294,7 @@ class User(Model): "!:&+-*'()", ], ) -def test_models_with_valid_custom_id(configure_db, model_id) -> None: +def test_models_with_valid_custom_id(model_id) -> None: product_id = str(uuid4()) product = Product(product_id=product_id, price=123.45, stock=2) @@ -327,7 +322,7 @@ def test_models_with_valid_custom_id(configure_db, model_id) -> None: "foo/bar/baz", ], ) -def test_models_with_invalid_custom_id(configure_db, model_id: str) -> None: +def test_models_with_invalid_custom_id(model_id: str) -> None: product = Product(product_id="product 123", price=123.45, stock=2) product.id = model_id with pytest.raises(InvalidDocumentID): @@ -337,7 +332,7 @@ def test_models_with_invalid_custom_id(configure_db, model_id: str) -> None: Product.get_by_id(model_id) -def test_truncate_collection(configure_db, create_company) -> None: +def test_truncate_collection(create_company) -> None: create_company(company_id="1234567-8") create_company(company_id="1234567-9") @@ -349,7 +344,7 @@ def test_truncate_collection(configure_db, create_company) -> None: assert len(new_companies) == 0 -def test_custom_id_model(configure_db) -> None: +def test_custom_id_model() -> None: c = CustomIDModel(bar="bar") # type: ignore c.save() @@ -361,7 +356,7 @@ def test_custom_id_model(configure_db) -> None: assert m.bar == "bar" -def test_custom_id_conflict(configure_db) -> None: +def test_custom_id_conflict() -> None: CustomIDConflictModel(foo="foo", bar="bar").save() models = CustomIDModel.find({}) @@ -372,7 +367,7 @@ def test_custom_id_conflict(configure_db) -> None: assert m.bar == "bar" -def test_model_id_persistency(configure_db) -> None: +def test_model_id_persistency() -> None: 
c = CustomIDConflictModel(foo="foo", bar="bar") c.save() assert c.id @@ -383,7 +378,7 @@ def test_model_id_persistency(configure_db) -> None: assert len(CustomIDConflictModel.find({})) == 1 -def test_bare_model_document_id_persistency(configure_db) -> None: +def test_bare_model_document_id_persistency() -> None: c = CustomIDModel(bar="bar") # type: ignore c.save() assert c.foo @@ -394,18 +389,18 @@ def test_bare_model_document_id_persistency(configure_db) -> None: assert len(CustomIDModel.find({})) == 1 -def test_bare_model_get_by_empty_doc_id(configure_db) -> None: +def test_bare_model_get_by_empty_doc_id() -> None: with pytest.raises(ModelNotFoundError): CustomIDModel.get_by_doc_id("") -def test_extra_fields(configure_db) -> None: +def test_extra_fields() -> None: CustomIDModelExtra(foo="foo", bar="bar", baz="baz").save() # type: ignore with pytest.raises(ValidationError): CustomIDModel.find({}) -def test_company_stats(configure_db, create_company) -> None: +def test_company_stats(create_company) -> None: company: Company = create_company(company_id="1234567-8") company_stats = company.stats() @@ -425,7 +420,7 @@ def test_company_stats(configure_db, create_company) -> None: assert stats.sales == 101 -def test_subcollection_model_safety(configure_db) -> None: +def test_subcollection_model_safety() -> None: """ Ensure you shouldn't be able to use unprepared subcollection models accidentally """ @@ -433,7 +428,7 @@ def test_subcollection_model_safety(configure_db) -> None: UserStats.find({}) -def test_get_user_purchases(configure_db) -> None: +def test_get_user_purchases() -> None: u = User(name="Foo") u.save() assert u.id @@ -444,7 +439,7 @@ def test_get_user_purchases(configure_db) -> None: assert get_user_purchases(u.id) == 42 -def test_reload(configure_db) -> None: +def test_reload() -> None: u = User(name="Foo") u.save() @@ -462,7 +457,7 @@ def test_reload(configure_db) -> None: another_user.reload() -def test_save_with_exclude_none(configure_db) -> None: +def 
def test_delete_in_transaction(create_company):
    """
    Test deleting a Company model within a Firestore transaction.
    """
    # Create a company outside the transaction
    company: Company = create_company(
        company_id="11223344-4", first_name="Joe", last_name="Day"
    )
    _id = company.id
    assert _id

    @transactional
    def delete_company(transaction: Transaction) -> None:
        # FIX: register the delete on the transaction. Without
        # transaction=..., the delete executed immediately and outside the
        # transaction, leaving the `transaction` parameter unused.
        company.delete(transaction=transaction)

    # Call the transactional function; the transaction commits on return
    t = get_transaction()
    delete_company(t)

    # Outside the transaction, the deletion should now be committed
    with pytest.raises(ModelNotFoundError):
        Company.get_by_id(_id)


def test_delete_model(create_company) -> None:
    """Test plain (non-transactional) deletion of a Company model."""
    company: Company = create_company(
        company_id="11223344-5", first_name="Jane", last_name="Doe"
    )

    _id = company.id
    assert _id

    company.delete()

    with pytest.raises(ModelNotFoundError):
        Company.get_by_id(_id)
## OLD WAY:
def configure_sync_client():
    """Legacy configure() with a sync Client: configure and verify CONFIGURATIONS."""
    # Firestore emulator must be running if using locally.
    if environ.get("FIRESTORE_EMULATOR_HOST"):
        client = Client(
            project="firedantic-test",
            credentials=Mock(spec=google.auth.credentials.Credentials),
        )
    else:
        client = Client()

    configure(client, prefix="firedantic-sync-")

    # ---- Assertions ----
    assert CONFIGURATIONS["prefix"] == "firedantic-sync-"
    assert isinstance(CONFIGURATIONS["db"], Client)

    # project check
    assert CONFIGURATIONS["db"].project == "firedantic-test"

    # emulator expectations (private attribute — acceptable in an integration check)
    assert isinstance(CONFIGURATIONS["db"]._credentials, Mock)

    # ensure no accidental async setting
    assert not isinstance(CONFIGURATIONS["db"], AsyncClient)


def configure_async_client():
    """Legacy configure() with an AsyncClient: configure and verify CONFIGURATIONS."""
    # Firestore emulator must be running if using locally.
    if environ.get("FIRESTORE_EMULATOR_HOST"):
        client = AsyncClient(
            project="firedantic-test",
            credentials=Mock(spec=google.auth.credentials.Credentials),
        )
    else:
        client = AsyncClient()

    configure(client, prefix="firedantic-async-")

    # ---- Assertions ----
    assert CONFIGURATIONS["prefix"] == "firedantic-async-"
    assert isinstance(CONFIGURATIONS["db"], AsyncClient)

    # project check
    assert CONFIGURATIONS["db"].project == "firedantic-test"

    # emulator expectations
    assert isinstance(CONFIGURATIONS["db"]._credentials, Mock)

    # ensure no accidental sync setting
    assert not isinstance(CONFIGURATIONS["db"], Client)


## NEW WAY:
def configure_multiple_clients():
    """New Configuration registry: add two named configs and verify isolation."""
    config = configuration

    # name = (default)
    config.add(
        prefix="firedantic-test-",
        project="firedantic-test",
        credentials=Mock(spec=google.auth.credentials.Credentials),
    )

    # name = billing
    config.add(
        name="billing",
        prefix="test-billing-",
        project="test-billing",
        credentials=Mock(spec=google.auth.credentials.Credentials),
    )

    # ===========================
    # VALIDATE DEFAULT CONFIG
    # ===========================
    default = config.get_config()  # loads "(default)"

    # Basic prefix & project
    assert default.prefix == "firedantic-test-"
    assert default.project == "firedantic-test"

    # Client objects exist when called (lazy creation)
    assert config.get_client() is not None
    assert config.get_async_client() is not None

    # Types are correct
    assert isinstance(default.client, Client)
    assert isinstance(default.async_client, AsyncClient)

    # Credentials are what we passed in
    assert isinstance(default.credentials, Mock)

    # Config can return correct client API
    assert isinstance(config.get_client(), Client)
    assert isinstance(config.get_async_client(), AsyncClient)

    # ===========================
    # VALIDATE BILLING CONFIG
    # ===========================
    billing = config.get_config("billing")

    # Prefix & project
    assert billing.prefix == "test-billing-"
    assert billing.project == "test-billing"

    # Client objects exist when called
    assert config.get_client("billing") is not None
    assert config.get_async_client("billing") is not None

    # Correct client types
    assert isinstance(billing.client, Client)
    assert isinstance(billing.async_client, AsyncClient)

    # Correct credentials object
    assert isinstance(billing.credentials, Mock)

    # Ensure default and billing configs are distinct objects
    assert billing is not default

    # Make sure nothing leaked between configs
    assert billing.prefix != default.prefix
    assert billing.project != default.project


try:
    ### ---- Running OLD way ----
    configure_sync_client()
    configure_async_client()

    ### ---- Running NEW way ----
    configure_multiple_clients()

    print("\nAll configure_firestore_db_client tests passed!\n")

except AssertionError as e:
    print(f"\nconfigure_firestore_db_client tests failed: {e}\n")
    # FIX: bare `raise` preserves the original traceback;
    # `raise e` would rewrite it from this frame.
    raise
## With single async client
async def test_with_default():
    """New Configuration API with a single async client under "(default)"."""

    class Owner(AsyncModel):
        """Dummy owner Pydantic model."""

        __collection__ = "owners"
        first_name: str
        last_name: str

    class Company(AsyncModel):
        """Dummy company Firedantic model."""

        __collection__ = "companies"
        company_id: str
        owner: Owner

    # Firestore emulator must be running if using locally, name defaults to "(default)"
    configuration.add(
        prefix="async-default-test-",
        project="async-default-test",
        credentials=Mock(spec=google.auth.credentials.Credentials),
    )

    # Now you can use the model to save it to Firestore
    owner = Owner(first_name="John", last_name="Doe")
    company = Company(company_id="1234567-8a", owner=owner)

    # Save the company/owner info to the DB
    await company.save()  # only need to include config_name when not using default

    # Reloads model data from the database to ensure most-up-to-date info
    await company.reload()

    # Assert that async client exists and configuration is correct
    assert isinstance(configuration.get_async_client(), AsyncClient)
    assert (
        configuration.get_collection_name(Owner)
        == configuration.get_config().prefix + "owners"
    )
    assert (
        configuration.get_collection_name(Company)
        == configuration.get_config().prefix + "companies"
    )
    assert configuration.get_config().prefix == "async-default-test-"
    assert configuration.get_config().project == "async-default-test"
    assert configuration.get_config().name == "(default)"

    # Finding data from DB.
    # FIX: queries hoisted out of the f-strings — reusing double quotes inside a
    # double-quoted f-string is a SyntaxError before Python 3.12 (PEP 701).
    owners_found = await Company.find({"owner.first_name": "John"})
    print(f"\nNumber of company owners with first name: 'John': {len(owners_found)}")

    companies_found = await Company.find({"company_id": "1234567-8a"})
    print(f"\nNumber of companies with id: '1234567-8a': {len(companies_found)}")

    # Delete everything from the database
    await company.delete()
    deletion_success = [] == await Company.find({"company_id": "1234567-8a"})
    if not deletion_success:
        print("\nDeletion of (default) DB failed\n")


# Now with multiple ASYNC clients/dbs:
async def test_with_multiple():
    """New Configuration API with two named async configs ("companies", "billing")."""

    config_name = "companies"

    class Owner(AsyncModel):
        """Dummy owner Pydantic model."""

        __db_config__ = config_name
        __collection__ = "owners"
        first_name: str
        last_name: str

    class Company(AsyncModel):
        """Dummy company Firedantic model."""

        __db_config__ = config_name
        __collection__ = config_name
        company_id: str
        owner: Owner

    # 1. Add first config with config_name = "companies"
    configuration.add(
        name=config_name,
        prefix="async-companies-test-",
        project="async-companies-test",
        credentials=Mock(spec=google.auth.credentials.Credentials),
    )

    # Now you can use the model to save it to Firestore
    owner = Owner(first_name="Alice", last_name="Begone")
    company = Company(company_id="1234567-9", owner=owner)

    # Save the company/owner info to the DB
    await company.save()  # only need to include config_name when not using default

    # Reloads model data from the database to ensure most-up-to-date info
    await company.reload()

    # #####################################################

    # 2. Add the second config with config_name = "billing"
    config_name = "billing"

    class BillingAccount(AsyncModel):
        """Dummy billing account Pydantic model."""

        __db_config__ = config_name
        __collection__ = "accounts"
        name: str
        billing_id: int
        owner: str

    class BillingCompany(AsyncModel):
        """Dummy company Firedantic model."""

        __db_config__ = config_name
        __collection__ = "companies"
        company_id: str
        billing_account: BillingAccount

    configuration.add(
        name=config_name,
        prefix="async-billing-test-",
        project="async-billing-test",
        credentials=Mock(spec=google.auth.credentials.Credentials),
    )

    # Assert that async clients exists and added configurations are correct
    assert isinstance(configuration.get_async_client("billing"), AsyncClient)

    assert (
        configuration.get_collection_name(BillingAccount, "billing")
        == configuration.get_config("billing").prefix + "accounts"
    )
    assert (
        configuration.get_collection_name(BillingCompany, "billing")
        == configuration.get_config("billing").prefix + "companies"
    )

    assert configuration.get_config("billing").prefix == "async-billing-test-"
    assert configuration.get_config("billing").project == "async-billing-test"
    assert configuration.get_config("billing").name == "billing"

    assert isinstance(configuration.get_async_client("companies"), AsyncClient)

    assert (
        configuration.get_collection_name(Company, "companies")
        == configuration.get_config("companies").prefix + "companies"
    )

    assert configuration.get_config("companies").prefix == "async-companies-test-"
    assert configuration.get_config("companies").project == "async-companies-test"
    assert configuration.get_config("companies").name == "companies"

    # Create and save billing account and billing company
    account = BillingAccount(name="ABC Billing", billing_id=801048, owner="MFisher")
    bc = BillingCompany(company_id="1234567-8c", billing_account=account)

    await account.save()
    await bc.save()

    # Reloads model data from the database
    await account.reload()
    await bc.reload()

    # 3. Finding data (queries hoisted out of f-strings — see PEP 701 note above)
    alices = await Company.find({"owner.first_name": "Alice"})
    print(f"\nNumber of company owners with first name: 'Alice': {len(alices)}")

    billing_companies = await BillingCompany.find({"company_id": "1234567-8c"})
    print(f"\nNumber of billing companies with id: '1234567-8c': {len(billing_companies)}")

    billing_accounts = await BillingCompany.find({"billing_account.billing_id": 801048})
    print(f"\nNumber of billing accounts with billing_id: 801048: {len(billing_accounts)}")

    # Delete everything from the database
    await company.delete()
    await bc.delete()
    await account.delete()  # FIX: the billing account was previously never cleaned up

    # FIX: was checking company_id "1234567-8a" (an id belonging to the other
    # test), so the check always passed vacuously; verify the id saved here.
    deletion_success = [] == await Company.find({"company_id": "1234567-9"})
    if not deletion_success:
        print("\nDeletion of Company DB failed\n")

    deletion_success = [] == await BillingCompany.find(
        {"billing_account.billing_id": 801048}
    )
    if not deletion_success:
        print("\nDeletion of BillingCompany DB failed\n")


# suppress silent error msg from asyncio threads
def dbg_hook(exctype, value, tb):
    """Print uncaught exceptions instead of letting them die silently."""
    print("=== Uncaught exception ===")
    print(exctype, value)
    import traceback

    traceback.print_tb(tb)


sys.excepthook = dbg_hook

# Run the tests
if __name__ == "__main__":
    import asyncio

    asyncio.run(test_old_way())
    asyncio.run(test_with_default())
asyncio.run(test_with_multiple()) diff --git a/integration_tests/full_readme_examples.py b/integration_tests/full_readme_examples.py new file mode 100644 index 0000000..033ac8b --- /dev/null +++ b/integration_tests/full_readme_examples.py @@ -0,0 +1,628 @@ +#!/usr/bin/env python3 +""" +integration_test_readme_full.py + +Integration-style test script covering README examples: +- new Configuration API and legacy configure() +- sync & async model operations: save / reload / find / delete +- transactions (sync & async) +- subcollections +- multiple named configurations / per-model __db_config__ +- composite indexes + TTL policies (attempted, skipped if admin client not available) + +Run with a running Firestore emulator: +export FIRESTORE_EMULATOR_HOST="127.0.0.1:8686" +python test_integration_all.py +""" +from __future__ import annotations + +import asyncio +import os +import sys +import time +from datetime import datetime, timedelta +from typing import Optional +from unittest.mock import Mock + +import google.auth.credentials + +if not os.environ.get("FIRESTORE_EMULATOR_HOST"): + os.environ["FIRESTORE_EMULATOR_HOST"] = "127.0.0.1:8686" + +# require firedantic package (adjust path or install locally) +try: + # recommended imports + from google.cloud.firestore import Client, Query + from google.cloud.firestore_admin_v1.services.firestore_admin import ( + FirestoreAdminAsyncClient, + FirestoreAdminClient, + ) + + # from google.cloud.firestore import AsyncClient as AsyncClientSyncName # alias to avoid name clash + from google.cloud.firestore_v1 import ( + AsyncClient, + Transaction, + async_transactional, + transactional, + ) + + from firedantic import ( # # IndexField, + AsyncModel, + Model, + async_set_up_composite_indexes_and_ttl_policies, + collection_group_index, + collection_index, + get_all_subclasses, + get_async_transaction, + get_transaction, + set_up_composite_indexes_and_ttl_policies, + ) + from firedantic.common import IndexField + from 
firedantic.configurations import CONFIGURATIONS, configuration, configure + +except Exception as e: + print( + "Failed to import required modules. Ensure `firedantic` is importable and google-cloud-* libs are installed." + ) + raise + + +EMULATOR = os.environ.get("FIRESTORE_EMULATOR_HOST") +USE_EMULATOR = bool(EMULATOR) + + +def must(msg: str): + print("ERROR:", msg) + sys.exit(2) + + +def info(msg: str): + print(msg) + + +# ----------------------- +# Helpers +# ----------------------- +def assert_or_exit(cond: bool, message: str): + if not cond: + print("ASSERTION FAILED:", message) + raise AssertionError(message) + + +def cleanup_sync_collection_by_model(model_cls: type[Model]) -> None: + """Delete everything from a model's collection (sync).""" + info(f"Cleaning collection for sync model {model_cls.__name__}") + # find all docs and delete + for obj in model_cls.find({}): + obj.delete() + + +async def cleanup_async_collection_by_model(model_cls: type[AsyncModel]) -> None: + info(f"Cleaning collection for async model {model_cls.__name__}") + items = await model_cls.find({}) + for it in items: + await it.delete() + + +# admin helpers +def get_admin_client_if_possible(): + """Return admin client or None (if emulator/admin not available).""" + try: + if USE_EMULATOR: + # Emulator usually accepts the same transport; credentials can be a mock + creds = Mock(spec=google.auth.credentials.Credentials) + client = FirestoreAdminClient(credentials=creds) + else: + client = FirestoreAdminClient() + return client + except Exception as e: + info(f"FirestoreAdminClient unavailable/skipping admin steps: {e}") + return None + + +async def get_async_admin_client_if_possible(): + try: + if USE_EMULATOR: + creds = Mock(spec=google.auth.credentials.Credentials) + client = FirestoreAdminAsyncClient(credentials=creds) + else: + client = FirestoreAdminAsyncClient() + return client + except Exception as e: + info(f"FirestoreAdminAsyncClient unavailable/skipping admin steps: {e}") + return 
def test_new_config_sync_flow():
    """New Configuration API, sync flow: register configs and CRUD a model
    bound to a named config via __db_config__."""
    info("\n=== new Configuration API: sync flow ===")

    # default config
    configuration.add(prefix="app-", project="my-project")

    # extra config
    configuration.add(name="billing", prefix="billing-", project="billing-project")

    # get clients (mirrors the README example; also forces lazy client creation)
    client = configuration.get_client()  # sync client for "(default)"
    billing_client = configuration.get_client("billing")  # sync client for "billing"

    class Billing(Model):
        __db_config__ = "billing"
        __collection__ = "billing_accounts"
        billing_id: str
        name: str

    # cleanup
    cleanup_sync_collection_by_model(Billing)

    # create and save
    b = Billing(billing_id="9901981", name="Firedantic Billing")
    b.save()
    print(f"billing.id after save: {b.id}")
    assert_or_exit(b.id is not None, "Billing save did not set id")

    # reload
    b.reload()
    assert_or_exit(b.billing_id == "9901981", "billing_id mismatch after reload")

    # find
    results = Billing.find({"name": "Firedantic Billing"})
    assert_or_exit(isinstance(results, list), "find returned not a list")
    assert_or_exit(
        any(r.billing_id == "9901981" for r in results),
        "find did not return saved billing account",
    )

    # delete
    b.delete()
    time.sleep(0.1)  # allow emulator a bit of time to delete
    remaining = Billing.find({"billing_id": "9901981"})
    # FIX: Billing has no `company_id` field — the previous check compared
    # `r.company_id`, which would raise AttributeError on any surviving doc.
    assert_or_exit(all(r.billing_id != "9901981" for r in remaining), "delete failed")

    print("new config sync flow OK")
async def test_new_config_async_flow():
    """New Configuration API, async flow: save / reload / find / delete."""
    info("\n=== new Configuration API: async flow ===")
    creds = Mock(spec=google.auth.credentials.Credentials)
    configuration.add(
        prefix="integ-async-", project="test-project", credentials=creds
    )

    class Owner(AsyncModel):
        first_name: str
        last_name: str

    class Company(AsyncModel):
        __collection__ = "companies"
        company_id: str
        owner: Owner

    # Start from an empty collection
    await cleanup_async_collection_by_model(Company)

    company = Company(
        company_id="A-1", owner=Owner(first_name="Alice", last_name="Smith")
    )
    await company.save()
    print(f"company.id after save: {company.id}")
    assert_or_exit(company.id is not None, "async save didn't set id")

    await company.reload()
    assert_or_exit(company.company_id == "A-1", "async reload company_id mismatch")
    assert_or_exit(
        company.owner.first_name == "Alice", "async owner first name mismatch"
    )

    matches = await Company.find({"company_id": "A-1"})
    assert_or_exit(
        len(matches) >= 1 and matches[0].company_id == "A-1", "async find failed"
    )

    await company.delete()
    leftovers = await Company.find({"company_id": "A-1"})
    assert_or_exit(
        all(item.company_id != "A-1" for item in leftovers), "async delete failed"
    )

    print("new config async flow OK")


async def test_legacy_config_async_flow():
    """Legacy configure() with an AsyncClient still works (backwards compat)."""
    info("\n=== legacy configure() async flow ===")
    creds = Mock(spec=google.auth.credentials.Credentials)
    # Build the async client against the emulator when it is available
    if USE_EMULATOR:
        legacy_client = AsyncClient(project="legacy-project", credentials=creds)
    else:
        legacy_client = AsyncClient()
    # legacy configure supports AsyncClient and sets CONFIGURATIONS
    configure(legacy_client, prefix="legacy-async-")
    assert_or_exit(
        CONFIGURATIONS["prefix"].startswith("legacy-async-"),
        "legacy async configure failed",
    )

    class LOwner(AsyncModel):
        first_name: str
        last_name: str

    class LCompany(AsyncModel):
        __collection__ = "companies"
        company_id: str
        owner: LOwner

    await cleanup_async_collection_by_model(LCompany)
    record = LCompany(
        company_id="LA-1", owner=LOwner(first_name="LA", last_name="User")
    )
    await record.save()
    await record.reload()
    assert_or_exit(record.company_id == "LA-1", "legacy async reload failed")
    await record.delete()
    info("legacy async flow OK")
def test_transactions_sync():
    """Sync transactions: transactionally increment/decrement a City's population."""
    info("\n=== sync transactions ===")

    # Register the configuration used by this test
    configuration.add(
        project="firedantic-test",
        prefix="firedantic-test-",
    )

    class City(Model):
        __collection__ = "cities"
        population: int = 0

        def increment_population(self, increment: int = 1):
            # Wrap the read-modify-write cycle in a Firestore transaction
            @transactional
            def _increment_population(transaction: Transaction) -> None:
                self.reload(transaction=transaction)
                self.population += increment
                self.save(transaction=transaction)

            _increment_population(transaction=get_transaction())

    # Start from an empty collection
    cleanup_sync_collection_by_model(City)

    @transactional
    def decrement_population(transaction: Transaction, city: City, decrement: int = 1):
        city.reload(transaction=transaction)
        city.population = max(0, city.population - decrement)
        city.save(transaction=transaction)

    # Create and persist the object under test
    city = City(id="SF", population=1)
    city.save()

    city.increment_population(increment=2)
    assert city.population == 3

    decrement_population(transaction=get_transaction(), city=city, decrement=1)
    assert city.population == 2

    # Re-read from the DB and verify the committed value
    city.reload()
    assert_or_exit(city.population == 2, "sync transaction increment failed")

    # Clean up
    city.delete()
    print("sync transactions OK")
def test_multi_config_usage():
    """Two named configs side by side; models route to a DB via __db_config__."""
    info("\n=== multi-config usage ===")

    shared_creds = Mock(spec=google.auth.credentials.Credentials)

    # "(default)" configuration
    configuration.add(
        prefix="multi-default-", project="proj-default", credentials=shared_creds
    )

    # "billing" configuration
    configuration.add(
        name="billing",
        prefix="multi-billing-",
        project="proj-billing",
        credentials=shared_creds,
    )

    class CompanyDefault(Model):
        __collection__ = "companies"
        company_id: str

    class BillingAccount(Model):
        __db_config__ = "billing"
        __collection__ = "billing_accounts"
        billing_id: str
        name: str

    # Start both collections empty
    cleanup_sync_collection_by_model(CompanyDefault)
    cleanup_sync_collection_by_model(BillingAccount)

    # Round-trip through the default config
    default_company = CompanyDefault(company_id="MD-1")
    default_company.save()
    default_company.reload()
    assert_or_exit(
        default_company.company_id == "MD-1", "default config save/reload failed"
    )

    # Round-trip through the billing config (selected by __db_config__)
    billing_account = BillingAccount(billing_id="B-1", name="Bill Co")
    billing_account.save()
    billing_account.reload()
    assert_or_exit(
        billing_account.billing_id == "B-1", "billing model save/reload failed"
    )

    # find() must hit the right backend for each model
    found_default = CompanyDefault.find({"company_id": "MD-1"})
    found_billing = BillingAccount.find({"billing_id": "B-1"})
    assert_or_exit(
        any(x.company_id == "MD-1" for x in found_default), "find default failed"
    )
    assert_or_exit(
        any(x.billing_id == "B-1" for x in found_billing), "find billing failed"
    )

    # Clean up after ourselves
    cleanup_sync_collection_by_model(CompanyDefault)
    cleanup_sync_collection_by_model(BillingAccount)
    info("multi-config usage OK")
get_async_admin_client_if_possible() + if not admin_client: + info("Skipping async index/TTL setup (async admin client not available).") + return + + class ExpiringAsync(AsyncModel): + __collection__ = "expiringModel" + __ttl_field__ = "expire" + __composite_indexes__ = [ + collection_index( + IndexField("content", Query.ASCENDING), + IndexField("expire", Query.DESCENDING), + ), + ] + content: str + expire: datetime + + mock_creds = Mock(spec=google.auth.credentials.Credentials) + configuration.add(prefix="idx-async-", project="proj-idx", credentials=mock_creds) + + try: + project = configuration.get_config().project + if not USE_EMULATOR: + # real GCP -> try to set up indexes + await async_set_up_composite_indexes_and_ttl_policies( + gcloud_project=project, models=get_all_subclasses(AsyncModel) + ) + info("async index/ttl setup called (no exception)") + else: + # emulator -> skip (or log) + print("Skipping index/TTL setup when using emulator.") + + except Exception as e: + info(f"async index/ttl setup raised (will continue): {e}") + + +# ----------------------- +# Runner +# ----------------------- +def main(): + info("Beginning full integration test following README examples") + if not USE_EMULATOR: + info( + "WARNING: FIRESTORE_EMULATOR_HOST not set — this script is designed for emulator runs; continue with caution." + ) + + # run sync new config flow + test_new_config_sync_flow() + + # legacy sync flow + test_legacy_config_sync_flow() + + # run async new config flow + asyncio.run(test_new_config_async_flow()) ## Currently deletion fails! 
+ + # legacy async flow + asyncio.run(test_legacy_config_async_flow()) + + # subcollections + test_subcollections_and_model_for() + + # transactions + test_transactions_sync() + asyncio.run(test_transactions_async()) + + # multi-config usage + test_multi_config_usage() + + # indexes & ttl (sync & async) (best-effort) + test_indexes_and_ttl_sync() + asyncio.run(test_indexes_and_ttl_async()) + + info("\nIntegration README full test finished successfully.") + + +if __name__ == "__main__": + main() diff --git a/integration_tests/full_sync_flow.py b/integration_tests/full_sync_flow.py new file mode 100644 index 0000000..16dd6fa --- /dev/null +++ b/integration_tests/full_sync_flow.py @@ -0,0 +1,248 @@ +from os import environ +from unittest.mock import Mock + +if not environ.get("FIRESTORE_EMULATOR_HOST"): + environ["FIRESTORE_EMULATOR_HOST"] = "127.0.0.1:8686" + +import google.auth.credentials +from pydantic import BaseModel + +from firedantic import Model +from firedantic.configurations import Client, configuration, configure + + +def test_old_way(): + # This is the old way of doing things, kept for backwards compatibility. + # Firestore emulator must be running if using locally. 
+ if environ.get("FIRESTORE_EMULATOR_HOST"): + client = Client( + project="firedantic-test", + credentials=Mock(spec=google.auth.credentials.Credentials), + ) + else: + client = Client() + + configure(client, prefix="firedantic-test-") + + class Owner(BaseModel): + """Dummy owner Pydantic model.""" + + first_name: str + last_name: str + + class Company(Model): + """Dummy company Firedantic model.""" + + __collection__ = "companies" + company_id: str + owner: Owner + + # Now you can use the model to save it to Firestore + owner = Owner(first_name="Bill", last_name="Smith") + company = Company(company_id="1234567-7", owner=owner) + company.save() + + # Reloads model data from the database + company.reload() + + # Finding data from DB + print( + f"\nNumber of company owners with first name: 'Bill': {len(Company.find({"owner.first_name": "Bill"}))}" + ) + print( + f"\nNumber of companies with id: '1234567-7': {len(Company.find({"company_id": "1234567-7"}))}" + ) + + # Delete everything from the database + company.delete() + + +## With single sync client +def test_with_default(): + + class Owner(Model): + """Dummy owner Pydantic model.""" + + __collection__ = "owners" + first_name: str + last_name: str + + class Company(Model): + """Dummy company Firedantic model.""" + + __collection__ = "companies" + company_id: str + owner: Owner + + # Firestore emulator must be running if using locally, name defaults to "(default)" + configuration.add( + prefix="sync-default-test-", + project="sync-default-test", + credentials=Mock(spec=google.auth.credentials.Credentials), + ) + + # Now you can use the model to save it to Firestore + owner = Owner(first_name="John", last_name="Doe") + company = Company(company_id="1234567-8", owner=owner) + + # Save the company/owner info to the DB + company.save() + + # Reloads model data from the database to ensure most-up-to-date info + company.reload() + + # Assert that sync client exists and configuration is correct + assert 
isinstance(configuration.get_client(), Client) + assert ( + configuration.get_collection_name(Owner) + == configuration.get_config().prefix + "owners" + ) + assert ( + configuration.get_collection_name(Company) + == configuration.get_config().prefix + "companies" + ) + assert configuration.get_config().prefix == "sync-default-test-" + assert configuration.get_config().project == "sync-default-test" + assert configuration.get_config().name == "(default)" + + # Finding data from DB + print( + f"\nNumber of company owners with first name: 'John': {len(Company.find({"owner.first_name": "John"}))}" + ) + print( + f"\nNumber of companies with id: '1234567-8': {len(Company.find({"company_id": "1234567-8"}))}" + ) + + # Delete everything from the database + company.delete() + + +# Now with multiple SYNC clients/dbs: +def test_with_multiple(): + + config_name = "companies" + + class Owner(Model): + """Dummy owner Pydantic model.""" + + __db_config__ = config_name + __collection__ = "owners" + first_name: str + last_name: str + + class Company(Model): + """Dummy company Firedantic model.""" + + __db_config__ = config_name + __collection__ = config_name + company_id: str + owner: Owner + + # 1. Add first config with config_name = "companies" + configuration.add( + name=config_name, + prefix="sync-companies-test-", + project="sync-companies-test", + credentials=Mock(spec=google.auth.credentials.Credentials), + ) + + # Now you can use the model to save it to Firestore + owner = Owner(first_name="Alice", last_name="Begone") + company = Company(company_id="1234567-9", owner=owner) + + company.save() + + # Reloads model data from the database + company.reload() + + # ##################################################### + + # 2. 
Add the second config with config_name = "billing" + config_name = "billing" + + class BillingAccount(Model): + """Dummy billing account Pydantic model.""" + + __db_config__ = config_name + __collection__ = "accounts" + name: str + billing_id: int + owner: str + + class BillingCompany(Model): + """Dummy company Firedantic model.""" + + __db_config__ = config_name + __collection__ = "companies" + company_id: str + billing_account: BillingAccount + + configuration.add( + name=config_name, + prefix="sync-billing-test-", + project="sync-billing-test", + credentials=Mock(spec=google.auth.credentials.Credentials), + ) + + # Assert that sync clients exists and added configurations are correct + assert isinstance(configuration.get_client("billing"), Client) + + assert ( + configuration.get_collection_name(BillingAccount, "billing") + == configuration.get_config("billing").prefix + "accounts" + ) + assert ( + configuration.get_collection_name(BillingCompany, "billing") + == configuration.get_config("billing").prefix + "companies" + ) + + assert configuration.get_config("billing").prefix == "sync-billing-test-" + assert configuration.get_config("billing").project == "sync-billing-test" + assert configuration.get_config("billing").name == "billing" + + assert isinstance(configuration.get_client("companies"), Client) + + assert ( + configuration.get_collection_name(Company, "companies") + == configuration.get_config("companies").prefix + "companies" + ) + + assert configuration.get_config("companies").prefix == "sync-companies-test-" + assert configuration.get_config("companies").project == "sync-companies-test" + assert configuration.get_config("companies").name == "companies" + + # Create and save billing account and company + account = BillingAccount(name="ABC Billing", billing_id=801048, owner="MFisher") + bc = BillingCompany(company_id="1234567-8c", billing_account=account) + + account.save() + bc.save() + + # Reloads model data from the database + account.reload() + 
bc.reload() + + # 3. Finding data + print( + f"\nNumber of company owners with first name: 'Alice': {len(Company.find({"owner.first_name": "Alice"}))}" + ) + print( + f"\nNumber of billing companies with id: '1234567-8c': {len(BillingCompany.find({"company_id": "1234567-8c"}))}" + ) + print( + f"\nNumber of billing accounts with billing_id: 801048: {len(BillingCompany.find({"billing_account.billing_id": 801048}))}" + ) + + # now delete everything from the DBs: + company.delete() + bc.delete() + + +## Ensure existing configure function works for backwards compatibility +test_old_way() + +## Ensure new way with single default client works +test_with_default() + +## Ensure new way with multiple clients works +test_with_multiple() diff --git a/poetry.lock b/poetry.lock index 3b6922f..50d55d3 100644 --- a/poetry.lock +++ b/poetry.lock @@ -126,6 +126,114 @@ files = [ {file = "colorama-0.4.4.tar.gz", hash = "sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b"}, ] +[[package]] +name = "coverage" +version = "7.13.0" +description = "Code coverage measurement for Python" +optional = false +python-versions = ">=3.10" +groups = ["dev"] +files = [ + {file = "coverage-7.13.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:02d9fb9eccd48f6843c98a37bd6817462f130b86da8660461e8f5e54d4c06070"}, + {file = "coverage-7.13.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:367449cf07d33dc216c083f2036bb7d976c6e4903ab31be400ad74ad9f85ce98"}, + {file = "coverage-7.13.0-cp310-cp310-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:cdb3c9f8fef0a954c632f64328a3935988d33a6604ce4bf67ec3e39670f12ae5"}, + {file = "coverage-7.13.0-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:d10fd186aac2316f9bbb46ef91977f9d394ded67050ad6d84d94ed6ea2e8e54e"}, + {file = "coverage-7.13.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = 
"sha256:7f88ae3e69df2ab62fb0bc5219a597cb890ba5c438190ffa87490b315190bb33"}, + {file = "coverage-7.13.0-cp310-cp310-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:c4be718e51e86f553bcf515305a158a1cd180d23b72f07ae76d6017c3cc5d791"}, + {file = "coverage-7.13.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a00d3a393207ae12f7c49bb1c113190883b500f48979abb118d8b72b8c95c032"}, + {file = "coverage-7.13.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:3a7b1cd820e1b6116f92c6128f1188e7afe421c7e1b35fa9836b11444e53ebd9"}, + {file = "coverage-7.13.0-cp310-cp310-musllinux_1_2_riscv64.whl", hash = "sha256:37eee4e552a65866f15dedd917d5e5f3d59805994260720821e2c1b51ac3248f"}, + {file = "coverage-7.13.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:62d7c4f13102148c78d7353c6052af6d899a7f6df66a32bddcc0c0eb7c5326f8"}, + {file = "coverage-7.13.0-cp310-cp310-win32.whl", hash = "sha256:24e4e56304fdb56f96f80eabf840eab043b3afea9348b88be680ec5986780a0f"}, + {file = "coverage-7.13.0-cp310-cp310-win_amd64.whl", hash = "sha256:74c136e4093627cf04b26a35dab8cbfc9b37c647f0502fc313376e11726ba303"}, + {file = "coverage-7.13.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0dfa3855031070058add1a59fdfda0192fd3e8f97e7c81de0596c145dea51820"}, + {file = "coverage-7.13.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4fdb6f54f38e334db97f72fa0c701e66d8479af0bc3f9bfb5b90f1c30f54500f"}, + {file = "coverage-7.13.0-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:7e442c013447d1d8d195be62852270b78b6e255b79b8675bad8479641e21fd96"}, + {file = "coverage-7.13.0-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:1ed5630d946859de835a85e9a43b721123a8a44ec26e2830b296d478c7fd4259"}, + {file = "coverage-7.13.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7f15a931a668e58087bc39d05d2b4bf4b14ff2875b49c994bbdb1c2217a8daeb"}, + {file = 
"coverage-7.13.0-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:30a3a201a127ea57f7e14ba43c93c9c4be8b7d17a26e03bb49e6966d019eede9"}, + {file = "coverage-7.13.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:7a485ff48fbd231efa32d58f479befce52dcb6bfb2a88bb7bf9a0b89b1bc8030"}, + {file = "coverage-7.13.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:22486cdafba4f9e471c816a2a5745337742a617fef68e890d8baf9f3036d7833"}, + {file = "coverage-7.13.0-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:263c3dbccc78e2e331e59e90115941b5f53e85cfcc6b3b2fbff1fd4e3d2c6ea8"}, + {file = "coverage-7.13.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:e5330fa0cc1f5c3c4c3bb8e101b742025933e7848989370a1d4c8c5e401ea753"}, + {file = "coverage-7.13.0-cp311-cp311-win32.whl", hash = "sha256:0f4872f5d6c54419c94c25dd6ae1d015deeb337d06e448cd890a1e89a8ee7f3b"}, + {file = "coverage-7.13.0-cp311-cp311-win_amd64.whl", hash = "sha256:51a202e0f80f241ccb68e3e26e19ab5b3bf0f813314f2c967642f13ebcf1ddfe"}, + {file = "coverage-7.13.0-cp311-cp311-win_arm64.whl", hash = "sha256:d2a9d7f1c11487b1c69367ab3ac2d81b9b3721f097aa409a3191c3e90f8f3dd7"}, + {file = "coverage-7.13.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:0b3d67d31383c4c68e19a88e28fc4c2e29517580f1b0ebec4a069d502ce1e0bf"}, + {file = "coverage-7.13.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:581f086833d24a22c89ae0fe2142cfaa1c92c930adf637ddf122d55083fb5a0f"}, + {file = "coverage-7.13.0-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:0a3a30f0e257df382f5f9534d4ce3d4cf06eafaf5192beb1a7bd066cb10e78fb"}, + {file = "coverage-7.13.0-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:583221913fbc8f53b88c42e8dbb8fca1d0f2e597cb190ce45916662b8b9d9621"}, + {file = "coverage-7.13.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = 
"sha256:5f5d9bd30756fff3e7216491a0d6d520c448d5124d3d8e8f56446d6412499e74"}, + {file = "coverage-7.13.0-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:a23e5a1f8b982d56fa64f8e442e037f6ce29322f1f9e6c2344cd9e9f4407ee57"}, + {file = "coverage-7.13.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:9b01c22bc74a7fb44066aaf765224c0d933ddf1f5047d6cdfe4795504a4493f8"}, + {file = "coverage-7.13.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:898cce66d0836973f48dda4e3514d863d70142bdf6dfab932b9b6a90ea5b222d"}, + {file = "coverage-7.13.0-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:3ab483ea0e251b5790c2aac03acde31bff0c736bf8a86829b89382b407cd1c3b"}, + {file = "coverage-7.13.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:1d84e91521c5e4cb6602fe11ece3e1de03b2760e14ae4fcf1a4b56fa3c801fcd"}, + {file = "coverage-7.13.0-cp312-cp312-win32.whl", hash = "sha256:193c3887285eec1dbdb3f2bd7fbc351d570ca9c02ca756c3afbc71b3c98af6ef"}, + {file = "coverage-7.13.0-cp312-cp312-win_amd64.whl", hash = "sha256:4f3e223b2b2db5e0db0c2b97286aba0036ca000f06aca9b12112eaa9af3d92ae"}, + {file = "coverage-7.13.0-cp312-cp312-win_arm64.whl", hash = "sha256:086cede306d96202e15a4b77ace8472e39d9f4e5f9fd92dd4fecdfb2313b2080"}, + {file = "coverage-7.13.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:28ee1c96109974af104028a8ef57cec21447d42d0e937c0275329272e370ebcf"}, + {file = "coverage-7.13.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:d1e97353dcc5587b85986cda4ff3ec98081d7e84dd95e8b2a6d59820f0545f8a"}, + {file = "coverage-7.13.0-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:99acd4dfdfeb58e1937629eb1ab6ab0899b131f183ee5f23e0b5da5cba2fec74"}, + {file = "coverage-7.13.0-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:ff45e0cd8451e293b63ced93161e189780baf444119391b3e7d25315060368a6"}, + {file = 
"coverage-7.13.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f4f72a85316d8e13234cafe0a9f81b40418ad7a082792fa4165bd7d45d96066b"}, + {file = "coverage-7.13.0-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:11c21557d0e0a5a38632cbbaca5f008723b26a89d70db6315523df6df77d6232"}, + {file = "coverage-7.13.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:76541dc8d53715fb4f7a3a06b34b0dc6846e3c69bc6204c55653a85dd6220971"}, + {file = "coverage-7.13.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:6e9e451dee940a86789134b6b0ffbe31c454ade3b849bb8a9d2cca2541a8e91d"}, + {file = "coverage-7.13.0-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:5c67dace46f361125e6b9cace8fe0b729ed8479f47e70c89b838d319375c8137"}, + {file = "coverage-7.13.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f59883c643cb19630500f57016f76cfdcd6845ca8c5b5ea1f6e17f74c8e5f511"}, + {file = "coverage-7.13.0-cp313-cp313-win32.whl", hash = "sha256:58632b187be6f0be500f553be41e277712baa278147ecb7559983c6d9faf7ae1"}, + {file = "coverage-7.13.0-cp313-cp313-win_amd64.whl", hash = "sha256:73419b89f812f498aca53f757dd834919b48ce4799f9d5cad33ca0ae442bdb1a"}, + {file = "coverage-7.13.0-cp313-cp313-win_arm64.whl", hash = "sha256:eb76670874fdd6091eedcc856128ee48c41a9bbbb9c3f1c7c3cf169290e3ffd6"}, + {file = "coverage-7.13.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:6e63ccc6e0ad8986386461c3c4b737540f20426e7ec932f42e030320896c311a"}, + {file = "coverage-7.13.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:494f5459ffa1bd45e18558cd98710c36c0b8fbfa82a5eabcbe671d80ecffbfe8"}, + {file = "coverage-7.13.0-cp313-cp313t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:06cac81bf10f74034e055e903f5f946e3e26fc51c09fc9f584e4a1605d977053"}, + {file = "coverage-7.13.0-cp313-cp313t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = 
"sha256:f2ffc92b46ed6e6760f1d47a71e56b5664781bc68986dbd1836b2b70c0ce2071"}, + {file = "coverage-7.13.0-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0602f701057c6823e5db1b74530ce85f17c3c5be5c85fc042ac939cbd909426e"}, + {file = "coverage-7.13.0-cp313-cp313t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:25dc33618d45456ccb1d37bce44bc78cf269909aa14c4db2e03d63146a8a1493"}, + {file = "coverage-7.13.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:71936a8b3b977ddd0b694c28c6a34f4fff2e9dd201969a4ff5d5fc7742d614b0"}, + {file = "coverage-7.13.0-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:936bc20503ce24770c71938d1369461f0c5320830800933bc3956e2a4ded930e"}, + {file = "coverage-7.13.0-cp313-cp313t-musllinux_1_2_riscv64.whl", hash = "sha256:af0a583efaacc52ae2521f8d7910aff65cdb093091d76291ac5820d5e947fc1c"}, + {file = "coverage-7.13.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:f1c23e24a7000da892a312fb17e33c5f94f8b001de44b7cf8ba2e36fbd15859e"}, + {file = "coverage-7.13.0-cp313-cp313t-win32.whl", hash = "sha256:5f8a0297355e652001015e93be345ee54393e45dc3050af4a0475c5a2b767d46"}, + {file = "coverage-7.13.0-cp313-cp313t-win_amd64.whl", hash = "sha256:6abb3a4c52f05e08460bd9acf04fec027f8718ecaa0d09c40ffbc3fbd70ecc39"}, + {file = "coverage-7.13.0-cp313-cp313t-win_arm64.whl", hash = "sha256:3ad968d1e3aa6ce5be295ab5fe3ae1bf5bb4769d0f98a80a0252d543a2ef2e9e"}, + {file = "coverage-7.13.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:453b7ec753cf5e4356e14fe858064e5520c460d3bbbcb9c35e55c0d21155c256"}, + {file = "coverage-7.13.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:af827b7cbb303e1befa6c4f94fd2bf72f108089cfa0f8abab8f4ca553cf5ca5a"}, + {file = "coverage-7.13.0-cp314-cp314-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:9987a9e4f8197a1000280f7cc089e3ea2c8b3c0a64d750537809879a7b4ceaf9"}, + {file = 
"coverage-7.13.0-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:3188936845cd0cb114fa6a51842a304cdbac2958145d03be2377ec41eb285d19"}, + {file = "coverage-7.13.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a2bdb3babb74079f021696cb46b8bb5f5661165c385d3a238712b031a12355be"}, + {file = "coverage-7.13.0-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:7464663eaca6adba4175f6c19354feea61ebbdd735563a03d1e472c7072d27bb"}, + {file = "coverage-7.13.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:8069e831f205d2ff1f3d355e82f511eb7c5522d7d413f5db5756b772ec8697f8"}, + {file = "coverage-7.13.0-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:6fb2d5d272341565f08e962cce14cdf843a08ac43bd621783527adb06b089c4b"}, + {file = "coverage-7.13.0-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:5e70f92ef89bac1ac8a99b3324923b4749f008fdbd7aa9cb35e01d7a284a04f9"}, + {file = "coverage-7.13.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:4b5de7d4583e60d5fd246dd57fcd3a8aa23c6e118a8c72b38adf666ba8e7e927"}, + {file = "coverage-7.13.0-cp314-cp314-win32.whl", hash = "sha256:a6c6e16b663be828a8f0b6c5027d36471d4a9f90d28444aa4ced4d48d7d6ae8f"}, + {file = "coverage-7.13.0-cp314-cp314-win_amd64.whl", hash = "sha256:0900872f2fdb3ee5646b557918d02279dc3af3dfb39029ac4e945458b13f73bc"}, + {file = "coverage-7.13.0-cp314-cp314-win_arm64.whl", hash = "sha256:3a10260e6a152e5f03f26db4a407c4c62d3830b9af9b7c0450b183615f05d43b"}, + {file = "coverage-7.13.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:9097818b6cc1cfb5f174e3263eba4a62a17683bcfe5c4b5d07f4c97fa51fbf28"}, + {file = "coverage-7.13.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:0018f73dfb4301a89292c73be6ba5f58722ff79f51593352759c1790ded1cabe"}, + {file = "coverage-7.13.0-cp314-cp314t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = 
"sha256:166ad2a22ee770f5656e1257703139d3533b4a0b6909af67c6b4a3adc1c98657"}, + {file = "coverage-7.13.0-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:f6aaef16d65d1787280943f1c8718dc32e9cf141014e4634d64446702d26e0ff"}, + {file = "coverage-7.13.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e999e2dcc094002d6e2c7bbc1fb85b58ba4f465a760a8014d97619330cdbbbf3"}, + {file = "coverage-7.13.0-cp314-cp314t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:00c3d22cf6fb1cf3bf662aaaa4e563be8243a5ed2630339069799835a9cc7f9b"}, + {file = "coverage-7.13.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:22ccfe8d9bb0d6134892cbe1262493a8c70d736b9df930f3f3afae0fe3ac924d"}, + {file = "coverage-7.13.0-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:9372dff5ea15930fea0445eaf37bbbafbc771a49e70c0aeed8b4e2c2614cc00e"}, + {file = "coverage-7.13.0-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = "sha256:69ac2c492918c2461bc6ace42d0479638e60719f2a4ef3f0815fa2df88e9f940"}, + {file = "coverage-7.13.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:739c6c051a7540608d097b8e13c76cfa85263ced467168dc6b477bae3df7d0e2"}, + {file = "coverage-7.13.0-cp314-cp314t-win32.whl", hash = "sha256:fe81055d8c6c9de76d60c94ddea73c290b416e061d40d542b24a5871bad498b7"}, + {file = "coverage-7.13.0-cp314-cp314t-win_amd64.whl", hash = "sha256:445badb539005283825959ac9fa4a28f712c214b65af3a2c464f1adc90f5fcbc"}, + {file = "coverage-7.13.0-cp314-cp314t-win_arm64.whl", hash = "sha256:de7f6748b890708578fc4b7bb967d810aeb6fcc9bff4bb77dbca77dab2f9df6a"}, + {file = "coverage-7.13.0-py3-none-any.whl", hash = "sha256:850d2998f380b1e266459ca5b47bc9e7daf9af1d070f66317972f382d46f1904"}, + {file = "coverage-7.13.0.tar.gz", hash = "sha256:a394aa27f2d7ff9bc04cf703817773a59ad6dfbd577032e690f961d2460ee936"}, +] + +[package.dependencies] +tomli = {version = "*", optional = true, markers = 
"python_full_version <= \"3.11.0a6\" and extra == \"toml\""} + +[package.extras] +toml = ["tomli ; python_full_version <= \"3.11.0a6\""] + [[package]] name = "exceptiongroup" version = "1.1.3" @@ -851,6 +959,26 @@ pytest = ">=8.2,<9" docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] +[[package]] +name = "pytest-cov" +version = "7.0.0" +description = "Pytest plugin for measuring coverage." +optional = false +python-versions = ">=3.9" +groups = ["dev"] +files = [ + {file = "pytest_cov-7.0.0-py3-none-any.whl", hash = "sha256:3b8e9558b16cc1479da72058bdecf8073661c7f57f7d3c5f22a1c23507f2d861"}, + {file = "pytest_cov-7.0.0.tar.gz", hash = "sha256:33c97eda2e049a0c5298e91f519302a1334c26ac65c1a483d6206fd458361af1"}, +] + +[package.dependencies] +coverage = {version = ">=7.10.6", extras = ["toml"]} +pluggy = ">=1.2" +pytest = ">=7" + +[package.extras] +testing = ["process-tests", "pytest-xdist", "virtualenv"] + [[package]] name = "requests" version = "2.32.5" @@ -992,4 +1120,4 @@ watchmedo = ["PyYAML (>=3.10)"] [metadata] lock-version = "2.1" python-versions = "^3.10" -content-hash = "12593d2241ebd57ce8ddc9c696596f1d4f434b620275149c63d7dcabb5601f6f" +content-hash = "1effa7f4c6ddc8981fe8ee168e3aa6514ddc1777f6701d7f8b96fc88c9f96f58" diff --git a/pyproject.toml b/pyproject.toml index a16227a..fc4ef3b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "firedantic" -version = "0.12.0" +version = "0.13.0" description = "Pydantic base models for Firestore" authors = ["IOXIO Ltd"] license = "BSD-3-Clause" @@ -29,6 +29,7 @@ pytest = "^8.3.4" pytest-asyncio = "^0.25.3" black = "^25.1.0" watchdog = "^6.0.0" +pytest-cov = "^7.0.0" [build-system] requires = ["poetry-core>=1.0.0"] diff --git a/tasks.py b/tasks.py index 5888fa9..0a1777d 100644 --- a/tasks.py +++ b/tasks.py @@ -158,3 +158,18 @@ def make_changelog(ctx): changelog_path.write_text(new_changelog) print(f"{changelog_path} was updated, 
please fill in release information") + + +@task +def integration(ctx): + """ + Run all integration tests + """ + files = [ + "integration_tests/configure_firestore_db_clients.py", + "integration_tests/full_sync_flow.py", + "integration_tests/full_async_flow.py", + "integration_tests/full_readme_examples.py", + ] + for f in files: + run_test_cmd(ctx, f"python {f}", env=DEV_ENV) diff --git a/unasync.py b/unasync.py index ffa7887..d13cabf 100644 --- a/unasync.py +++ b/unasync.py @@ -5,6 +5,8 @@ from pathlib import Path SUBS = [ + ("get_async_client", "get_client"), + ("async_client", "client"), ("get_async_transaction", "get_transaction"), ( "from google.cloud.firestore_v1.async_transaction", @@ -25,6 +27,7 @@ ("async def", "def"), ("async for", "for"), ("async with", "with"), + ("async client", "client"), ("StopAsyncIteration", "StopIteration"), ("__anext__", "__next__"), ("async_truncate_collection", "truncate_collection"),