From e8667894df1b98aaecd3103ecd8969a383a8cf47 Mon Sep 17 00:00:00 2001 From: Joe Black Date: Thu, 30 Mar 2023 18:24:25 -0400 Subject: [PATCH 1/9] simulation --- src/quart_sqlalchemy/model/mixins.py | 2 +- src/quart_sqlalchemy/model/model.py | 3 +- src/quart_sqlalchemy/sim/__init__.py | 2 + src/quart_sqlalchemy/sim/app.py | 145 +++ src/quart_sqlalchemy/sim/builder.py | 96 ++ src/quart_sqlalchemy/sim/handle.py | 594 ++++++++++++ src/quart_sqlalchemy/sim/legacy.py | 404 ++++++++ src/quart_sqlalchemy/sim/logic.py | 883 ++++++++++++++++++ src/quart_sqlalchemy/sim/main.py | 9 + src/quart_sqlalchemy/sim/model.py | 163 ++++ src/quart_sqlalchemy/sim/repo.py | 285 ++++++ src/quart_sqlalchemy/sim/repo_adapter.py | 309 ++++++ src/quart_sqlalchemy/sim/schema.py | 14 + src/quart_sqlalchemy/sim/signals.py | 33 + src/quart_sqlalchemy/sim/util.py | 101 ++ src/quart_sqlalchemy/sim/views/__init__.py | 18 + src/quart_sqlalchemy/sim/views/auth_user.py | 7 + src/quart_sqlalchemy/sim/views/auth_wallet.py | 62 ++ src/quart_sqlalchemy/sim/views/decorator.py | 131 +++ .../sim/views/magic_client.py | 7 + 20 files changed, 3266 insertions(+), 2 deletions(-) create mode 100644 src/quart_sqlalchemy/sim/__init__.py create mode 100644 src/quart_sqlalchemy/sim/app.py create mode 100644 src/quart_sqlalchemy/sim/builder.py create mode 100644 src/quart_sqlalchemy/sim/handle.py create mode 100644 src/quart_sqlalchemy/sim/legacy.py create mode 100644 src/quart_sqlalchemy/sim/logic.py create mode 100644 src/quart_sqlalchemy/sim/main.py create mode 100644 src/quart_sqlalchemy/sim/model.py create mode 100644 src/quart_sqlalchemy/sim/repo.py create mode 100644 src/quart_sqlalchemy/sim/repo_adapter.py create mode 100644 src/quart_sqlalchemy/sim/schema.py create mode 100644 src/quart_sqlalchemy/sim/signals.py create mode 100644 src/quart_sqlalchemy/sim/util.py create mode 100644 src/quart_sqlalchemy/sim/views/__init__.py create mode 100644 src/quart_sqlalchemy/sim/views/auth_user.py create mode 100644 src/quart_sqlalchemy/sim/views/auth_wallet.py create mode 100644 src/quart_sqlalchemy/sim/views/decorator.py create mode 100644 src/quart_sqlalchemy/sim/views/magic_client.py diff --git a/src/quart_sqlalchemy/model/mixins.py b/src/quart_sqlalchemy/model/mixins.py index 55c0c3f..3dc3834 100644 --- a/src/quart_sqlalchemy/model/mixins.py +++ b/src/quart_sqlalchemy/model/mixins.py @@ -79,7 +79,7 @@ def to_dict(self): class RecursiveDictMixin: __abstract__ = True - def model_to_dict( + def to_dict( self, obj: t.Optional[t.Any] = None, max_depth: int = 3, diff --git a/src/quart_sqlalchemy/model/model.py b/src/quart_sqlalchemy/model/model.py index cd82457..bbcdc21 100644 --- a/src/quart_sqlalchemy/model/model.py +++ b/src/quart_sqlalchemy/model/model.py @@ -14,13 +14,14 @@ from .mixins import ComparableMixin from .mixins import DynamicArgsMixin from .mixins import ReprMixin +from .mixins import SimpleDictMixin from .mixins import TableNameMixin sa = sqlalchemy -class Base(DynamicArgsMixin, ReprMixin, ComparableMixin, TableNameMixin): +class Base(DynamicArgsMixin, ReprMixin, SimpleDictMixin, ComparableMixin, TableNameMixin): __abstract__ = True type_annotation_map = { diff --git a/src/quart_sqlalchemy/sim/__init__.py b/src/quart_sqlalchemy/sim/__init__.py new file mode 100644 index 0000000..6a3f395 --- /dev/null +++ b/src/quart_sqlalchemy/sim/__init__.py @@ -0,0 +1,2 @@ +from . import app +from . import model diff --git a/src/quart_sqlalchemy/sim/app.py b/src/quart_sqlalchemy/sim/app.py new file mode 100644 index 0000000..91b0ec8 --- /dev/null +++ b/src/quart_sqlalchemy/sim/app.py @@ -0,0 +1,145 @@ +import json +import logging +import re +import typing as t + +import sqlalchemy as sa +from pydantic import BaseModel +from quart import g +from quart import Quart +from quart import request +from quart import Request +from quart import Response +from quart_schema import QuartSchema + +from .. import Base +from .. import SQLAlchemyConfig +from ..framework import QuartSQLAlchemy +from .util import ObjectID + + +AUTHORIZATION_PATTERN = re.compile(r"Bearer (?P.+)") +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class MyBase(Base): + type_annotation_map = {ObjectID: sa.Integer} + + +app = Quart(__name__) +db = QuartSQLAlchemy( + SQLAlchemyConfig.parse_obj( + { + "model_class": MyBase, + "binds": { + "default": { + "engine": {"url": "sqlite:///file:mem.db?mode=memory&cache=shared&uri=true"}, + "session": {"expire_on_commit": False}, + }, + "read-replica": { + "engine": {"url": "sqlite:///file:mem.db?mode=memory&cache=shared&uri=true"}, + "session": {"expire_on_commit": False}, + "read_only": True, + }, + "async": { + "engine": { + "url": "sqlite+aiosqlite:///file:mem.db?mode=memory&cache=shared&uri=true" + }, + "session": {"expire_on_commit": False}, + }, + }, + } + ) +) +openapi = QuartSchema(app) + + +class RequestAuth(BaseModel): + client: t.Optional[t.Any] = None + user: t.Optional[t.Any] = None + + @property + def has_client(self): + return self.client is not None + + @property + def has_user(self): + return self.user is not None + + @property + def is_anonymous(self): + return all([self.has_client is False, self.has_user is False]) + + +def get_request_client(request: Request): + api_key = request.headers.get("X-Public-API-Key") + if not api_key: + return + + with g.bind.Session() as session: + try: + magic_client = g.h.MagicClient(session).get_by_public_api_key(api_key) + except ValueError: + return + else: + return magic_client + + +def get_request_user(request: Request): + auth_header = request.headers.get("Authorization") + + if not auth_header: + return + m = AUTHORIZATION_PATTERN.match(auth_header) + if m is None: + raise RuntimeError("invalid authorization header") + + auth_token = m.group("auth_token") + + with g.bind.Session() as session: + try: + auth_user = g.h.AuthUser(session).get_by_session_token(auth_token) + except ValueError: + return + else: + return auth_user + + +@app.before_request +def set_ethereum_network(): + g.request_network = request.headers.get("X-Fortmatic-Network", "GOERLI").upper() + + +@app.before_request +def set_bind_handlers_for_request(): + from quart_sqlalchemy.sim.handle import Handlers + + g.db = db + + method = request.method + if method in ["GET", "OPTIONS", "TRACE", "HEAD"]: + bind = "read-replica" + else: + bind = "default" + + g.bind = db.get_bind(bind) + g.h = Handlers(g.bind) + + +@app.before_request +def set_request_auth(): + g.auth = RequestAuth( + client=get_request_client(request), + user=get_request_user(request), + ) + + +@app.after_request +async def add_json_response_envelope(response: Response) -> Response: + if response.mimetype != "application/json": + return response + data = await response.get_json() + payload = dict(status="ok", message="", data=data) + response.set_data(json.dumps(payload)) + return response diff --git a/src/quart_sqlalchemy/sim/builder.py b/src/quart_sqlalchemy/sim/builder.py new file mode 100644 index 0000000..ccaffeb --- /dev/null +++ b/src/quart_sqlalchemy/sim/builder.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +import typing as t + +import sqlalchemy +import sqlalchemy.event +import sqlalchemy.exc +import sqlalchemy.orm +import sqlalchemy.sql +from sqlalchemy.orm.interfaces import ORMOption + +from quart_sqlalchemy.types import ColumnExpr +from quart_sqlalchemy.types import DMLTable +from quart_sqlalchemy.types import EntityT +from quart_sqlalchemy.types import Selectable + + +sa = sqlalchemy + + +class StatementBuilder(t.Generic[EntityT]): + model: t.Type[EntityT] + + def __init__(self, model: t.Type[EntityT]): + self.model = model + + def complex_select( + self, + selectables: t.Sequence[Selectable] = (), + conditions: t.Sequence[ColumnExpr] = (), + group_by: t.Sequence[t.Union[ColumnExpr, str]] = (), + order_by: t.Sequence[t.Union[ColumnExpr, str]] = (), + options: t.Sequence[ORMOption] = (), + execution_options: t.Optional[t.Dict[str, t.Any]] = None, + offset: t.Optional[int] = None, + limit: t.Optional[int] = None, + distinct: bool = False, + for_update: bool = False, + ) -> sa.Select: + statement = sa.select(*selectables or self.model).where(*conditions) + + if for_update: + statement = statement.with_for_update() + if offset: + statement = statement.offset(offset) + if limit: + statement = statement.limit(limit) + if group_by: + statement = statement.group_by(*group_by) + if order_by: + statement = statement.order_by(*order_by) + + for option in options: + for context in option.context: + for strategy in context.strategy: + if "joined" in strategy: + distinct = True + + statement = statement.options(option) + + if distinct: + statement = statement.distinct() + + if execution_options: + statement = statement.execution_options(**execution_options) + + return statement + + def insert( + self, + target: t.Optional[DMLTable] = None, + values: t.Optional[t.Dict[str, t.Any]] = None, + ) -> sa.Insert: + return sa.insert(target or self.model).values(**values or {}) + + def bulk_insert( + self, + target: t.Optional[DMLTable] = None, + values: t.Sequence[t.Dict[str, t.Any]] = (), + ) -> sa.Insert: + return sa.insert(target or self.model).values(*values) + + def bulk_update( + self, + target: t.Optional[DMLTable] = None, + conditions: t.Sequence[ColumnExpr] = (), + values: t.Optional[t.Dict[str, t.Any]] = None, + ) -> sa.Update: + return sa.update(target or self.model).where(*conditions).values(**values or {}) + + def bulk_delete( + self, + target: t.Optional[DMLTable] = None, + conditions: t.Sequence[ColumnExpr] = (), + ) -> sa.Delete: + return sa.delete(target or self.model).where(*conditions) diff --git a/src/quart_sqlalchemy/sim/handle.py b/src/quart_sqlalchemy/sim/handle.py new file mode 100644 index 0000000..7076309 --- /dev/null +++ b/src/quart_sqlalchemy/sim/handle.py @@ -0,0 +1,594 @@ +import logging +import typing as t +from datetime import datetime + +from sqlalchemy.orm import Session + +from quart_sqlalchemy import Bind + +from . import signals +from .logic import LogicComponent as Logic +from .model import AuthUser +from .model import AuthWallet +from .model import EntityType +from .model import WalletType +from .util import ObjectID + + +logger = logging.getLogger(__name__) + +CLIENTS_PER_API_USER_LIMIT = 50 + + +class MaxClientsExceeded(Exception): + pass + + +class AuthUserBaseError(Exception): + pass + + +class InvalidSubstringError(AuthUserBaseError): + pass + + +class APIKeySet(t.NamedTuple): + public_key: str + secret_key: str + + +def get_session_proxy(): + from .app import db + + return db.bind.Session() + + +class HandlerBase: + logic: Logic + session: Session + """The base class for all handler classes. It provides handler with access + to our logic object. + """ + + def __init__(self, session: t.Optional[Session], logic: t.Optional[Logic] = None): + self.session = session or get_session_proxy() + self.logic = logic or Logic() + + +def get_product_type_by_client_id(client_id): + return EntityType.MAGIC.value + + +class MagicClientHandler(HandlerBase): + def add( + self, + magic_api_user_id, + magic_team_id, + app_name=None, + is_magic_connect_enabled=False, + ): + """Registers a new client. + + Args: + is_magic_connect_enabled (boolean): if True, it will create a Magic Connect app. + + Returns: + A ``MagicClient``. + """ + magic_clients_count = self.logic.MagicClientAPIUser.count_by_magic_api_user_id( + magic_api_user_id, + ) + + if magic_clients_count >= CLIENTS_PER_API_USER_LIMIT: + raise MaxClientsExceeded() + + return self.add_client( + magic_api_user_id, + magic_team_id, + app_name, + is_magic_connect_enabled, + ) + + def get_by_public_api_key(self, public_api_key): + return self.logic.MagicClientAPIKey.get_by_public_api_key(public_api_key) + + def add_client( + self, + magic_api_user_id, + magic_team_id, + app_name=None, + is_magic_connect_enabled=False, + ): + live_api_key = APIKeySet(public_key="xxx", secret_key="yyy") + + with self.logic.begin(ro=False) as session: + magic_client = self.logic.MagicClient._add( + session, + app_name=app_name, + ) + # self.logic.MagicClientAPIKey._add( + # session, + # magic_client.id, + # live_api_key_pair=live_api_key, + # ) + # self.logic.MagicClientAPIUser._add( + # session, + # magic_api_user_id, + # magic_client.id, + # ) + + # self.logic.MagicClientAuthMethods._add( + # session, + # magic_client_id=magic_client.id, + # is_magic_connect_enabled=is_magic_connect_enabled, + # is_metamask_wallet_enabled=(True if is_magic_connect_enabled else False), + # is_wallet_connect_enabled=(True if is_magic_connect_enabled else False), + # is_coinbase_wallet_enabled=(True if is_magic_connect_enabled else False), + # ) + + # self.logic.MagicClientTeam._add(session, magic_client.id, magic_team_id) + + return magic_client, live_api_key + + def get_magic_api_user_id_by_client_id(self, magic_client_id): + return self.logic.MagicClient.get_magic_api_user_id_by_client_id(magic_client_id) + + def get_by_id(self, magic_client_id): + return self.logic.MagicClient.get_by_id(magic_client_id) + + def update_app_name_by_id(self, magic_client_id, app_name): + """ + Args: + magic_client_id (ObjectID|int|str): self explanatory. + app_name (str): Desired application name. + + Returns: + None if `magic_client_id` doesn't exist in the db + app_name if update was successful + """ + client = self.logic.MagicClient.update_by_id(magic_client_id, app_name=app_name) + + if not client: + return None + + return client.app_name + + def update_by_id(self, magic_client_id, **kwargs): + client = self.logic.MagicClient.update_by_id(magic_client_id, **kwargs) + + return client + + def set_inactive_by_id(self, magic_client_id): + """ + Args: + magic_client_id (ObjectID|int|str): self explanatory. + + Returns: + None + """ + self.logic.MagicClient.update_by_id(magic_client_id, is_active=False) + + def get_users_for_client( + self, + magic_client_id, + offset=None, + limit=None, + include_count=False, + ): + """ + Returns emails and signup timestamps for all auth users belonging to a given client + """ + auth_user_handler = AuthUserHandler() + product_type = get_product_type_by_client_id(magic_client_id) + auth_users = auth_user_handler.get_by_client_id_and_user_type( + magic_client_id, + product_type, + offset=offset, + limit=limit, + ) + + # Here we blindly load from oauth users table because we only provide + # two login methods right now. If not email link then it is oauth. + # TODO(ajen#ch22926|2020-08-14): rely on the `login_method` column to + # deterministically load from correct source (oauth, webauthn, etc.). + # emails_from_oauth = OAuthUserHandler().get_emails_by_auth_user_ids( + # [auth_user.id for auth_user in auth_users if auth_user.email is None], + # ) + + data = { + "users": [ + dict(email=u.email or "none", signup_ts=int(datetime.timestamp(u.time_created))) + for u in auth_users + ] + } + + if include_count: + data["count"] = auth_user_handler.get_user_count_by_client_id_and_user_type( + magic_client_id, + product_type, + ) + + return data + + def get_users_for_client_v2( + self, + magic_client_id, + offset=None, + limit=None, + include_count=False, + ): + """ + Returns emails, signup timestamps, provenance and MFA enablement for all auth users + belonging to a given client. + """ + auth_user_handler = AuthUserHandler() + product_type = get_product_type_by_client_id(magic_client_id) + auth_users = auth_user_handler.get_by_client_id_and_user_type( + magic_client_id, + product_type, + offset=offset, + limit=limit, + ) + + data = { + "users": [ + dict(email=u.email or "none", signup_ts=int(datetime.timestamp(u.time_created))) + for u in auth_users + ] + } + + if include_count: + data["count"] = auth_user_handler.get_user_count_by_client_id_and_user_type( + magic_client_id, + product_type, + ) + + return data + + # def get_user_logins_for_client(self, magic_client_id, limit=None): + # logins = AuthUserLoginHandler().get_logins_by_magic_client_id( + # magic_client_id, + # limit=limit or 20, + # ) + # user_logins = get_user_logins_response(logins) + + # return sorted( + # user_logins, + # key=lambda x: x["login_ts"], + # reverse=True, + # )[:limit] + + +class AuthUserHandler(HandlerBase): + # auth_user_mfa_handler: AuthUserMfaHandler + + def __init__(self, *args, auth_user_mfa_handler=None, **kwargs): + super().__init__(*args, **kwargs) + # self.auth_user_mfa_handler = auth_user_mfa_handler or AuthUserMfaHandler() + + def get_by_session_token(self, session_token): + return self.logic.AuthUser.get_by_session_token(session_token) + + def get_or_create_by_email_and_client_id( + self, + email, + client_id, + user_type=EntityType.MAGIC.value, + ): + auth_user = self.logic.AuthUser.get_by_email_and_client_id( + email, + client_id, + user_type=user_type, + ) + if not auth_user: + # try: + # email = enhanced_email_validation( + # email, + # source=MAGIC, + # # So we don't affect sign-up. + # silence_network_error=True, + # ) + # except ( + # EnhanceEmailValidationError, + # EnhanceEmailSuggestionError, + # ) as e: + # logger.warning( + # "Email Start Attempt.", + # exc_info=True, + # ) + # raise EnhancedEmailValidation(error_message=str(e)) from e + + auth_user = self.logic.AuthUser.add_by_email_and_client_id( + client_id, + email=email, + user_type=user_type, + ) + return auth_user + + def get_by_id_and_validate_exists(self, auth_user_id): + """This function helps formalize how a non-existent auth user should be handled.""" + auth_user = self.logic.AuthUser.get_by_id(auth_user_id) + if auth_user is None: + raise RuntimeError('resource_name="auth_user"') + return auth_user + + # This function is reserved for consolidating into a canonical user. Do not + # call this function under other circumstances as it will automatically set + # the user as verified. See ch-25343 for additional details. + def create_verified_user( + self, + client_id, + email, + user_type=EntityType.FORTMATIC.value, + ): + with self.logic.begin(ro=False) as session: + auid = self.logic.AuthUser._add_by_email_and_client_id( + session, + client_id, + email, + user_type=user_type, + ).id + auth_user = self.logic.AuthUser._update_by_id( + session, + auid, + date_verified=datetime.utcnow(), + ) + + return auth_user + + # def get_auth_user_from_public_address(self, public_address): + # wallet = self.logic.AuthWallet.get_by_public_address(public_address) + + # if not wallet: + # return None + + # return self.logic.AuthUser.get_by_id(wallet.auth_user_id) + + def get_by_id(self, auth_user_id, load_mfa_methods=False) -> AuthUser: + # join_list = ["mfa_methods"] if load_mfa_methods else None + return self.logic.AuthUser.get_by_id(auth_user_id) + + def get_by_client_id_and_user_type( + self, + client_id, + user_type, + offset=None, + limit=None, + ): + if user_type == EntityType.CONNECT.value: + return self.logic.AuthUser.get_by_client_id_for_connect( + client_id, + offset=offset, + limit=limit, + ) + else: + return self.logic.AuthUser.get_by_client_id_and_user_type( + client_id, + user_type, + offset=offset, + limit=limit, + ) + + def get_by_client_ids_and_user_type( + self, + client_ids, + user_type, + offset=None, + limit=None, + ): + return self.logic.AuthUser.get_by_client_ids_and_user_type( + client_ids, + user_type, + offset=offset, + limit=limit, + ) + + def get_user_count_by_client_id_and_user_type(self, client_id, user_type): + if user_type == EntityType.CONNECT.value: + return self.logic.AuthUser.get_user_count_by_client_id_for_connect( + client_id, + ) + else: + return self.logic.AuthUser.get_user_count_by_client_id_and_user_type( + client_id, + user_type, + ) + + def exist_by_email_client_id_and_user_type(self, email, client_id, user_type): + return self.logic.AuthUser.exist_by_email_and_client_id( + email, + client_id, + user_type=user_type, + ) + + def update_email_by_id(self, model_id, email): + return self.logic.AuthUser.update_by_id(model_id, email=email) + + def update_phone_number_by_id(self, model_id, phone_number): + return self.logic.AuthUser.update_by_id(model_id, phone_number=phone_number) + + def get_by_email_client_id_and_user_type(self, email, client_id, user_type): + return self.logic.AuthUser.get_by_email_and_client_id( + email, + client_id, + user_type, + ) + + def mark_date_verified_by_id(self, model_id): + return self.logic.AuthUser.update_by_id( + model_id, + date_verified=datetime.utcnow(), + ) + + def set_role_by_email_magic_client_id(self, email, magic_client_id, role): + auth_user = self.logic.AuthUser.get_by_email_and_client_id( + email, + magic_client_id, + EntityType.MAGIC.value, + ) + + if not auth_user: + auth_user = self.logic.AuthUser.add_by_email_and_client_id( + magic_client_id, + email, + user_type=EntityType.MAGIC.value, + ) + + return self.logic.AuthUser.update_by_id(auth_user.id, **{role: True}) + + def search_by_client_id_and_substring( + self, + client_id, + substring, + offset=None, + limit=10, + load_mfa_methods=False, + ): + # join_list = ["mfa_methods"] if load_mfa_methods is True else None + + if not isinstance(substring, str) or len(substring) < 3: + raise InvalidSubstringError() + + auth_users = self.logic.AuthUser.get_by_client_id_with_substring_search( + client_id, + substring, + offset=offset, + limit=limit, + # join_list=join_list, + ) + + # mfa_enablements = self.auth_user_mfa_handler.is_active_batch( + # [auth_user.id for auth_user in auth_users], + # ) + # for auth_user in auth_users: + # if mfa_enablements[auth_user.id] is False: + # auth_user.mfa_methods = [] + + return auth_users + + def is_magic_connect_enabled(self, auth_user_id=None, auth_user=None): + if auth_user is None and auth_user_id is None: + raise Exception("At least one argument needed: auth_user_id or auth_user.") + + if auth_user is None: + auth_user = self.get_by_id(auth_user_id) + + return auth_user.user_type == EntityType.CONNECT.value + + def mark_as_inactive(self, auth_user_id): + self.logic.AuthUser.update_by_id(auth_user_id, is_active=False) + + def get_by_email_and_wallet_type_for_interop(self, email, wallet_type, network): + """ + Opinionated method for fetching AuthWallets by email address, wallet_type and network. + """ + return self.logic.AuthUser.get_by_email_for_interop( + email=email, + wallet_type=wallet_type, + network=network, + ) + + def get_magic_connect_auth_user(self, auth_user_id): + auth_user = self.get_by_id_and_validate_exists(auth_user_id) + if not auth_user.is_magic_connect_user: + raise RuntimeError("RequestForbidden") + return auth_user + + +@signals.auth_user_duplicate.connect +def handle_duplicate_auth_users( + current_app, + original_auth_user_id, + duplicate_auth_user_ids, + auth_user_handler: t.Optional[AuthUserHandler] = None, +) -> None: + logger.info(f"{len(duplicate_auth_user_ids)} dupe(s) found for {original_auth_user_id}") + + auth_user_handler = auth_user_handler or AuthUserHandler() + + for dupe_id in duplicate_auth_user_ids: + logger.info( + f"marking auth_user_id {dupe_id} as inactive, in favor of original {original_auth_user_id}", + ) + auth_user_handler.mark_as_inactive(dupe_id) + + +class AuthWalletHandler(HandlerBase): + # account_linking_feature = LDFeatureFlag("is-account-linking-enabled", anonymous_user=True) + + def __init__(self, network, *args, wallet_type=WalletType.ETH, **kwargs): + super().__init__(*args, **kwargs) + self.wallet_network = network + self.wallet_type = wallet_type + + def get_by_id(self, model_id): + return self.logic.AuthWallet.get_by_id(model_id) + + def get_by_public_address(self, public_address): + return self.logic.AuthWallet.get_by_public_address(public_address) + + def get_by_auth_user_id( + self, + auth_user_id: ObjectID, + network: t.Optional[str] = None, + wallet_type: t.Optional[WalletType] = None, + **kwargs, + ) -> t.List[AuthWallet]: + auth_user = self.logic.AuthUser.get_by_id( + auth_user_id, + join_list=["linked_primary_auth_user"], + ) + + if auth_user.has_linked_primary_auth_user: + logger.info( + "Linked primary_auth_user found for wallet delegation", + extra=dict( + auth_user_id=auth_user.id, + delegated_to=auth_user.linked_primary_auth_user_id, + ), + ) + auth_user = auth_user.linked_primary_auth_user + + return self.logic.AuthWallet.get_by_auth_user_id( + auth_user.id, + network=network, + wallet_type=wallet_type, + **kwargs, + ) + + def sync_auth_wallet( + self, + auth_user_id, + public_address, + encrypted_private_address, + wallet_management_type, + ): + existing_wallet = self.logic.AuthWallet.get_by_auth_user_id( + auth_user_id, + ) + if existing_wallet: + raise RuntimeError("WalletExistsForNetworkAndWalletType") + + return self.logic.AuthWallet.add( + public_address, + encrypted_private_address, + self.wallet_type, + self.wallet_network, + management_type=wallet_management_type, + auth_user_id=auth_user_id, + ) + + +class Handlers: + bind: Bind + + def __init__(self, bind: Bind): + self.bind = bind + + def __getattr__(self, name): + handlers = { + cls.__name__.replace("Handler", ""): cls for cls in HandlerBase.__subclasses__() + } + if name in handlers: + return handlers[name] + raise AttributeError() diff --git a/src/quart_sqlalchemy/sim/legacy.py b/src/quart_sqlalchemy/sim/legacy.py new file mode 100644 index 0000000..5d4662f --- /dev/null +++ b/src/quart_sqlalchemy/sim/legacy.py @@ -0,0 +1,404 @@ +from sqlalchemy import exists +from sqlalchemy.orm import joinedload +from sqlalchemy.sql.expression import func +from sqlalchemy.sql.expression import label + + +def one(input_list): + (item,) = input_list + return item + + +class SQLRepository: + def __init__(self, model): + self._model = model + assert self._model is not None + + @property + def _has_is_active_field(self): + return bool(getattr(self._model, "is_active", None)) + + def get_by_id( + self, + session, + model_id, + allow_inactive=False, + join_list=None, + for_update=False, + ): + """SQL get interface to retrieve by model's id column. + + Args: + session: A database session object. + model_id: The id of the given model to be retrieved. + allow_inactive: Whether to include inactive or not. + join_list: A list of attributes to be joined in the same session for + given model. This is normally the attributes that have + relationship defined and referenced to other models. + for_update: Locks the table for update. + + Returns: + Data retrieved from the database for the model. + """ + query = session.query(self._model) + + if join_list: + for to_join in join_list: + query = query.options(joinedload(to_join)) + + if for_update: + query = query.with_for_update() + + row = query.get(model_id) + + if row is None: + return None + + if self._has_is_active_field and not row.is_active and not allow_inactive: + return None + + return row + + def get_by( + self, + session, + filters=None, + join_list=None, + order_by_clause=None, + for_update=False, + offset=None, + limit=None, + ): + """SQL get_by interface to retrieve model instances based on the given + filters. + + Args: + session: A database session object. + filters: A list of filters on the models. + join_list: A list of attributes to be joined in the same session for + given model. This is normally the attributes that have + relationship defined and referenced to other models. + order_by_clause: An order by clause. + for_update: Locks the table for update. + + Returns: + Modified rows. + + TODO(ajen#ch21549|2020-07-21): Filter out `is_active == False` row. This + will not be a trivial change as many places rely on this method and the + handlers/logics sometimes filter by in_active. Sometimes endpoints might + get affected. Proceed with caution. + """ + # If no filter is given, just return. Prevent table scan. + if filters is None: + return None + + query = session.query(self._model).filter(*filters).order_by(order_by_clause) + + if for_update: + query = query.with_for_update() + + if offset: + query = query.offset(offset) + + if limit: + query = query.limit(limit) + + # Prevent loading all the rows. + if limit == 0: + return [] + + if join_list: + for to_join in join_list: + query = query.options(joinedload(to_join)) + + return query.all() + + def count_by( + self, + session, + filters=None, + group_by=None, + distinct_column=None, + ): + """SQL count_by interface to retrieve model instance count based on the given + filters. + + Args: + session: A database session object. + filters (list): Required; a list of filters on the models. + group_by (list): A list of optional group by expressions. + Returns: + A list of counts of rows. + Raises: + ValueError: Returns a value error, when no filters are provided + """ + # Prevent table scans + if filters is None: + raise ValueError("Full table scans are prohibited. Please provide filters") + + select = [label("count", func.count(self._model.id))] + + if distinct_column: + select = [label("count", func.count(func.distinct(distinct_column)))] + + if group_by: + for group in group_by: + select.append(group.expression) + + query = session.query(*select).filter(*filters) + + if group_by: + query = query.group_by(*group_by) + + return query.all() + + def sum_by( + self, + session, + column, + filters=None, + group_by=None, + ): + """SQL sum_by interface to retrieve aggregate sum of column values for given + filters. + + Args: + session: A database session object. + column (sqlalchemy.Column): Required; the column to sum by. + filters (list): Required; a list of filters to apply to the query + group_by (list): A list of optional group by expressions. + + Returns: + A scalar value representing the sum or None. + + Raises: + ValueError: Returns a value error, when no filters are provided + """ + + # Prevent table scans + if filters is None: + raise ValueError("Full table scans are prohibited. Please provide filters") + + query = session.query(func.sum(column)).filter(*filters) + + if group_by: + query = query.group_by(*group_by) + + return query.scalar() + + def one(self, session, filters=None, join_list=None, for_update=False): + """SQL filtering interface to retrieve the single model instance matching + filter criteria. + + If there are more than one instances, an exception is raised. + + Args: + session: A database session object. + filters: A list of filters on the models. + for_update: Locks the table for update. + + Returns: + A model instance: If one row is found in the db. + None: If no result is found. + """ + row = self.get_by(session, filters=filters, join_list=join_list, for_update=for_update) + + if not row: + return None + + return one(row) + + def update(self, session, model_id, **kwargs): + """SQL update interface to modify data in a given model instance. + + Args: + session: A database session object. + model_id: The id of the given model to be modified. + kwargs: Any fields defined on the models. + + Returns: + Modified rows. + + Note: + We use session.flush() here to move the changes from the application + to SQL database. However, those changes will be in the pending changes + state. Meaning, it is in the queue to be inserted but yet to be done + so until session.commit() is called, which has been taken care of + in our ``with_db_session`` decorator or ``LogicComponent.begin`` + contextmanager. + """ + modified_row = session.query(self._model).get(model_id) + if modified_row is None: + return None + + for key, value in kwargs.items(): + setattr(modified_row, key, value) + + # Flush out our changes to DB transaction buffer but don't commit it yet. + # This is useful in the case when we want to rollback atomically on multiple + # sql operations in the same transaction which may or may not have + # dependencies. + session.flush() + + return modified_row + + def update_by(self, session, filters=None, **kwargs): + """SQL update_by interface to modify data for a given list of filters. + The filters should be provided so it can narrow down to one row. + + Args: + session: A database session object. + filters: A list of filters on the models. + kwargs: Any fields defined on the models. + + Returns: + Modified row. + + Raises: + sqlalchemy.orm.exc.NoResultFound - when no result is found. + sqlalchemy.orm.exc.MultipleResultsFound - when multiple result is found. + """ + # If no filter is given, just return. Prevent table scan. + if filters is None: + return None + + modified_row = session.query(self._model).filter(*filters).one() + for key, value in kwargs.items(): + setattr(modified_row, key, value) + + # Flush out our changes to DB transaction buffer but don't commit it yet. + # This is useful in the case when we want to rollback atomically on multiple + # sql operations in the same transaction which may or may not have + # dependencies. + session.flush() + + return modified_row + + def delete_one_by(self, session, filters=None, optional=False): + """SQL update_by interface to delete data for a given list of filters. + The filters should be provided so it can narrow down to one row. + + Note: Careful consideration should be had prior to using this function. + Always consider setting rows as inactive instead before choosing to use + this function. + + Args: + session: A database session object. + filters: A list of filters on the models. + optional: Whether deletion is optional; i.e. it's OK for the model not to exist + + Returns: + None. + + Raises: + sqlalchemy.orm.exc.NoResultFound - when no result is found and optional is False. + sqlalchemy.orm.exc.MultipleResultsFound - when multiple result is found. + """ + # If no filter is given, just return. Prevent table scan. + if filters is None: + return None + + if optional: + rows = session.query(self._model).filter(*filters).all() + + if not rows: + return None + + row = one(rows) + + else: + row = session.query(self._model).filter(*filters).one() + + session.delete(row) + + # Flush out our changes to DB transaction buffer but don't commit it yet. + # This is useful in the case when we want to rollback atomically on multiple + # sql operations in the same transaction which may or may not have + # dependencies. + session.flush() + + def delete_by_id(self, session, model_id): + return session.query(self._model).get(model_id).delete() + + def add(self, session, **kwargs): + """SQL add interface to insert data to the given model. + + Args: + session: A database session object. + kwargs: Any fields defined on the models. + + Returns: + Newly inserted rows. + + Note: + We use session.flush() here to move the changes from the application + to SQL database. However, those changes will be in the pending changes + state. Meaning, it is in the queue to be inserted but yet to be done + so until session.commit() is called, which has been taken care of + in our ``with_db_session`` decorator or ``LogicComponent.begin`` + contextmanager. + """ + new_row = self._model(**kwargs) + session.add(new_row) + + # Flush out our changes to DB transaction buffer but don't commit it yet. + # This is useful in the case when we want to rollback atomically on multiple + # sql operations in the same transaction which may or may not have + # dependencies. + session.flush() + + return new_row + + def exist(self, session, filters=None): + """SQL exist interface to check if any row exists at all for the given + filters. + + Args: + session: A database session object. + filters: A list of filters on the models. + + Returns: + A boolean. True if any row exists else False. + """ + exist_query = exists() + + for query_filter in filters: + exist_query = exist_query.where(query_filter) + + return session.query(exist_query).scalar() + + def yield_by_chunk(self, session, chunk_size, join_list=None, filters=None): + """This yields a batch of the model objects for the given chunk_size. + + Args: + session: A database session object. + chunk_size (int): The size of the chunk. + filters: A list of filters on the model. + join_list: A list of attributes to be joined in the same session for + given model. This is normally the attributes that have + relationship defined and referenced to other models. + + Returns: + A batch for the given chunk size. + """ + query = session.query(self._model) + + if filters is not None: + query = query.filter(*filters) + + if join_list: + for to_join in join_list: + query = query.options(joinedload(to_join)) + + start = 0 + + while True: + stop = start + chunk_size + model_objs = query.slice(start, stop).all() + if not model_objs: + break + + yield model_objs + + start += chunk_size diff --git a/src/quart_sqlalchemy/sim/logic.py b/src/quart_sqlalchemy/sim/logic.py new file mode 100644 index 0000000..dee3990 --- /dev/null +++ b/src/quart_sqlalchemy/sim/logic.py @@ -0,0 +1,883 @@ +import logging +import typing as t +from datetime import datetime +from functools import wraps + +from pydantic import BaseModel +from pydantic import Field +from sqlalchemy import or_ +from sqlalchemy import ScalarResult +from sqlalchemy.orm import contains_eager +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import selectinload +from sqlalchemy.orm import Session +from sqlalchemy.sql.expression import func +from sqlalchemy.sql.expression import label + +from quart_sqlalchemy.model import Base +from quart_sqlalchemy.types import ColumnExpr +from quart_sqlalchemy.types import EntityIdT +from quart_sqlalchemy.types import EntityT +from quart_sqlalchemy.types import ORMOption +from quart_sqlalchemy.types import Selectable + +from . import signals +from .model import AuthUser as auth_user_model +from .model import AuthWallet as auth_wallet_model +from .model import ConnectInteropStatus +from .model import EntityType +from .model import MagicClient as magic_client_model +from .model import Provenance +from .model import WalletType +from .repo import SQLAlchemyRepository +from .repo_adapter import RepositoryLegacyAdapter +from .util import ObjectID +from .util import one + + +logger = logging.getLogger(__name__) + + +class LogicMeta(type): + """This is metaclass provides registry pattern where all the available + logics will be accessible through any instantiated logic object. + + Note: + Don't use this metaclass at another places. This is only intended to be + used by LogicComponent. If you want your own registry, please create + your own. + """ + + def __init__(cls, name, bases, cls_dict): + if not hasattr(cls, "_registry"): + cls._registry = {} + else: + cls._registry[name] = cls() + + super().__init__(name, bases, cls_dict) + + +class LogicComponent(metaclass=LogicMeta): + """This is the base class for any logic class. This overrides the getattr + method for registry lookup. + + Example: + + ``` + class TrollGoat(LogicComponent): + + def add(x): + print(x) + ``` + + Once you have a logic object, you can directly do something like: + + ``` + logic.TrollGoat.add('troll_goat') + ``` + + Note: + You will have to explicitly import your newly created logic in + ``fortmatic.logic.__init__.py``. When the logic is imported, it is created + the first time; hence, it is then registered. If this is unclear to you, + read https://blog.ionelmc.ro/2015/02/09/understanding-python-metaclasses/ + It has all the info you need to understand. For example, everything in + python is an object :P. Enjoy. + """ + + def __dir__(self): + return super().__dir__() + list(self._registry.keys()) + + def __getattr__(self, logic_name): + if logic_name in self._registry: + return self._registry[logic_name] + else: + raise AttributeError( + "{object_name} has no attribute '{logic_name}'".format( + object_name=self.__class__.__name__, + logic_name=logic_name, + ), + ) + + +def with_db_session( + ro: t.Optional[bool] = None, + is_stale_tolerant: t.Optional[bool] = None, +): + """Stub decorator to ease transition with legacy code""" + + def wrapper(func): + @wraps(func) + def inner_wrapper(self, *args, **kwargs): + session = None + return func(self, None, *args, **kwargs) + + return inner_wrapper + + return wrapper + + +class MagicClient(LogicComponent): + def __init__(self, session: Session): + # self._repository = SQLAlchemyRepository[magic_client_model, ObjectID](session) + + self._repository = RepositoryLegacyAdapter(magic_client_model, ObjectID, session) + + def _add(self, session, app_name=None): + return self._repository.add( + session, + app_name=app_name, + ) + + add = with_db_session(ro=False)(_add) + + @with_db_session(ro=True) + def get_by_id( + self, + session, + model_id, + allow_inactive=False, + join_list=None, + ) -> t.Optional[magic_client_model]: + return self._repository.get_by_id( + session, + model_id, + allow_inactive=allow_inactive, + join_list=join_list, + ) + + @with_db_session(ro=True) + def get_by_public_api_key( + self, + session, + public_api_key, + ): + return one( + self._repository.get_by( + session, + filters=[magic_client_model.public_api_key == public_api_key], + limit=1, + ) + ) + + # @with_db_session(ro=True) + # def get_magic_api_user_id_by_client_id(self, session, magic_client_id): + # client = self._repository.get_by_id( + # session, + # magic_client_id, + # allow_inactive=False, + # join_list=None, + # ) + + # if client is None: + # return None + + # if client.magic_client_api_user is None: + # return None + + # return client.magic_client_api_user.magic_api_user_id + + @with_db_session(ro=False) + def update_by_id(self, session, model_id, **update_params): + modified_row = self._repository.update(session, model_id, **update_params) + session.refresh(modified_row) + return modified_row + + @with_db_session(ro=True) + def yield_all_clients_by_chunk(self, session, chunk_size): + yield from self._repository.yield_by_chunk(session, chunk_size) + + @with_db_session(ro=True) + def yield_by_chunk(self, session, chunk_size, filters=None, join_list=None): + yield from self._repository.yield_by_chunk( + session, + chunk_size, + filters=filters, + join_list=join_list, + ) + + +class DuplicateAuthUser(Exception): + pass + + +class AuthUserDoesNotExist(Exception): + pass + + +class MissingEmail(Exception): + pass + + +class MissingPhoneNumber(Exception): + pass + + +class AuthUser(LogicComponent): + def __init__(self, session: Session): + # self._repository = SQLRepository(auth_user_model) + self._repository = RepositoryLegacyAdapter(magic_client_model, ObjectID, session) + + @with_db_session(ro=True) + def get_by_session_token( + self, + session, + session_token, + ): + return one( + self._repository.get_by( + session, + filters=[auth_user_model.current_session_token == session_token], + limit=1, + ) + ) + + def _get_or_add_by_phone_number_and_client_id( + self, + session, + client_id, + phone_number, + user_type=EntityType.FORTMATIC.value, + ): + if phone_number is None: + raise MissingPhoneNumber() + + row = self._get_by_phone_number_and_client_id( + session=session, + phone_number=phone_number, + client_id=client_id, + user_type=user_type, + ) + + if row: + return row + + row = self._repository.add( + session=session, + phone_number=phone_number, + client_id=client_id, + user_type=user_type, + provenance=Provenance.SMS, + ) + logger.info( + "New auth user (id: {}) created by phone number (client_id: {})".format( + row.id, + client_id, + ), + ) + + return row + + get_or_add_by_phone_number_and_client_id = with_db_session(ro=False)( + _get_or_add_by_phone_number_and_client_id, + ) + + def _add_by_email_and_client_id( + self, + session, + client_id, + email=None, + user_type=EntityType.FORTMATIC.value, + **kwargs, + ): + if email is None: + raise MissingEmail() + + if self._exist_by_email_and_client_id( + session, + email, + client_id, + user_type=user_type, + ): + logger.exception( + "User duplication for email: {} (client_id: {})".format( + email, + client_id, + ), + ) + raise DuplicateAuthUser() + + row = self._repository.add( + session, + email=email, + client_id=client_id, + user_type=user_type, + **kwargs, + ) + logger.info( + "New auth user (id: {}) created by email (client_id: {})".format( + row.id, + client_id, + ), + ) + + return row + + add_by_email_and_client_id = with_db_session(ro=False)(_add_by_email_and_client_id) + + def _add_by_client_id( + self, + session, + client_id, + user_type=EntityType.FORTMATIC.value, + provenance=None, + global_auth_user_id=None, + is_verified=False, + ): + row = self._repository.add( + session, + client_id=client_id, + user_type=user_type, + provenance=provenance, + global_auth_user_id=global_auth_user_id, + date_verified=datetime.utcnow() if is_verified else None, + ) + logger.info( + "New auth user (id: {}) created by (client_id: {})".format( + row.id, + client_id, + ), + ) + + return row + + add_by_client_id = with_db_session(ro=False)(_add_by_client_id) + + def _get_by_active_identifier_and_client_id( + self, + session, + identifier_field, + identifier_value, + client_id, + user_type, + ) -> auth_user_model: + """There should only be one active identifier where all the parameters match for a given client ID. In the case of multiple results, the subsequent entries / "dupes" will be marked as inactive.""" + filters = [ + identifier_field == identifier_value, + auth_user_model.client_id == client_id, + auth_user_model.user_type == user_type, + # auth_user_model.is_active == True, # noqa: E712 + ] + + results = self._repository.get_by( + session, + filters=filters, + order_by_clause=auth_user_model.id.asc(), + ) + + if not results: + return None + + original, *duplicates = results + + if duplicates: + signals.auth_user_duplicate.send( + original_auth_user_id=original.id, + duplicate_auth_user_ids=[dupe.id for dupe in duplicates], + ) + + return original + + @with_db_session(ro=True) + def get_by_email_and_client_id( + self, + session, + email, + client_id, + user_type=EntityType.FORTMATIC.value, + ): + return self._get_by_active_identifier_and_client_id( + session=session, + identifier_field=auth_user_model.email, + identifier_value=email, + client_id=client_id, + user_type=user_type, + ) + + def _get_by_phone_number_and_client_id( + self, + session, + phone_number, + client_id, + user_type=EntityType.FORTMATIC.value, + ): + if phone_number is None: + raise MissingPhoneNumber() + + return self._get_by_active_identifier_and_client_id( + session=session, + identifier_field=auth_user_model.phone_number, + identifier_value=phone_number, + client_id=client_id, + user_type=user_type, + ) + + get_by_phone_number_and_client_id = with_db_session(ro=True)( + _get_by_phone_number_and_client_id, + ) + + def _exist_by_email_and_client_id( + self, + session, + email, + client_id, + user_type=EntityType.FORTMATIC.value, + ): + return bool( + self._repository.exist( + session, + filters=[ + auth_user_model.email == email, + auth_user_model.client_id == client_id, + auth_user_model.user_type == user_type, + ], + ), + ) + + exist_by_email_and_client_id = with_db_session(ro=True)(_exist_by_email_and_client_id) + + def _get_by_id(self, session, model_id, join_list=None, for_update=False) -> auth_user_model: + return self._repository.get_by_id( + session, + model_id, + join_list=join_list, + for_update=for_update, + ) + + get_by_id = with_db_session(ro=True)(_get_by_id) + + def _update_by_id(self, session, auth_user_id, **kwargs): + modified_user = self._repository.update(session, auth_user_id, **kwargs) + + if modified_user is None: + raise AuthUserDoesNotExist() + + return modified_user + + update_by_id = with_db_session(ro=False)(_update_by_id) + + @with_db_session(ro=True) + def get_user_count_by_client_id_and_user_type(self, session, client_id, user_type): + query = ( + session.query(auth_user_model) + .filter( + auth_user_model.client_id == client_id, + auth_user_model.user_type == user_type, + # auth_user_model.is_active == True, # noqa: E712 + auth_user_model.date_verified.is_not(None), + ) + .statement.with_only_columns([func.count()]) + .order_by(None) + ) + + return session.execute(query).scalar() + + def _get_by_client_id_and_global_auth_user(self, session, client_id, global_auth_user_id): + return self._repository.get_by( + session=session, + filters=[ + auth_user_model.client_id == client_id, + auth_user_model.user_type == EntityType.CONNECT.value, + # auth_user_model.is_active == True, # noqa: E712 + auth_user_model.global_auth_user_id == global_auth_user_id, + ], + ) + + get_by_client_id_and_global_auth_user = with_db_session(ro=True)( + _get_by_client_id_and_global_auth_user, + ) + + # @with_db_session(ro=True) + # def get_by_client_id_for_connect( + # self, + # session, + # client_id, + # offset=None, + # limit=None, + # ): + # # TODO(thomas|2022-07-12): Determine where/if is the right place to split + # # connect/magic logic based on user type as part of https://app.shortcut.com/magic-labs/story/53323. + # # See https://github.com/fortmatic/fortmatic/pull/6173#discussion_r919529540. + # return ( + # session.query(auth_user_model) + # .join( + # identifier_model, + # auth_user_model.global_auth_user_id == identifier_model.global_auth_user_id, + # ) + # .filter( + # auth_user_model.client_id == client_id, + # auth_user_model.user_type == EntityType.CONNECT.value, + # auth_user_model.is_active == True, # noqa: E712, + # auth_user_model.provenance == Provenance.IDENTIFIER, + # or_( + # identifier_model.identifier_type.in_( + # GlobalAuthUserIdentifierType.get_public_address_enums(), + # ), + # identifier_model.date_verified != None, + # ), + # ) + # .order_by(auth_user_model.id.desc()) + # .limit(limit) + # .offset(offset) + # ).all() + + # @with_db_session(ro=True) + # def get_user_count_by_client_id_for_connect( + # self, + # session, + # client_id, + # ): + # # TODO(thomas|2022-07-12): Determine where/if is the right place to split + # # connect/magic logic based on user type as part of https://app.shortcut.com/magic-labs/story/53323. + # # See https://github.com/fortmatic/fortmatic/pull/6173#discussion_r919529540. + # query = ( + # session.query(auth_user_model) + # .join( + # identifier_model, + # auth_user_model.global_auth_user_id == identifier_model.global_auth_user_id, + # ) + # .filter( + # auth_user_model.client_id == client_id, + # auth_user_model.user_type == EntityType.CONNECT.value, + # auth_user_model.is_active == True, # noqa: E712, + # auth_user_model.provenance == Provenance.IDENTIFIER, + # or_( + # identifier_model.identifier_type.in_( + # GlobalAuthUserIdentifierType.get_public_address_enums(), + # ), + # identifier_model.date_verified != None, + # ), + # ) + # .statement.with_only_columns( + # [func.count(distinct(auth_user_model.global_auth_user_id))], + # ) + # .order_by(None) + # ) + + # return session.execute(query).scalar() + + @with_db_session(ro=True) + def get_by_client_id_and_user_type( + self, + session, + client_id, + user_type, + offset=None, + limit=None, + ): + return self._get_by_client_ids_and_user_type( + session, + [client_id], + user_type, + offset=offset, + limit=limit, + ) + + def _get_by_client_ids_and_user_type( + self, + session, + client_ids, + user_type, + offset=None, + limit=None, + ): + if not client_ids: + return [] + + return self._repository.get_by( + session, + filters=[ + auth_user_model.client_id.in_(client_ids), + auth_user_model.user_type == user_type, + # auth_user_model.is_active == True, # noqa: E712, + auth_user_model.date_verified != None, + ], + offset=offset, + limit=limit, + order_by_clause=auth_user_model.id.desc(), + ) + + get_by_client_ids_and_user_type = with_db_session(ro=True)( + _get_by_client_ids_and_user_type, + ) + + def _get_by_client_id_with_substring_search( + self, + session, + client_id, + substring, + offset=None, + limit=10, + join_list=None, + ): + return self._repository.get_by( + session, + filters=[ + auth_user_model.client_id == client_id, + auth_user_model.user_type == EntityType.MAGIC.value, + or_( + auth_user_model.provenance == Provenance.SMS, + auth_user_model.provenance == Provenance.LINK, + auth_user_model.provenance == None, # noqa: E711 + ), + or_( + auth_user_model.phone_number.contains(substring), + auth_user_model.email.contains(substring), + ), + ], + offset=offset, + limit=limit, + order_by_clause=auth_user_model.id.desc(), + join_list=join_list, + ) + + get_by_client_id_with_substring_search = with_db_session(ro=True)( + _get_by_client_id_with_substring_search, + ) + + @with_db_session(ro=True) + def yield_by_chunk(self, session, chunk_size, filters=None, join_list=None): + yield from self._repository.yield_by_chunk( + session, + chunk_size, + filters=filters, + join_list=join_list, + ) + + @with_db_session(ro=True) + def get_by_emails_and_client_id( + self, + session, + email_ids, + client_id, + ): + return self._repository.get_by( + session, + filters=[ + auth_user_model.email.in_(email_ids), + auth_user_model.client_id == client_id, + ], + ) + + def _get_by_email( + self, + session, + email: str, + join_list=None, + filters=None, + for_update: bool = False, + ) -> t.List[auth_user_model]: + filters = filters or [] + combined_filters = filters + [auth_user_model.email == email] + + return self._repository.get_by( + session, + filters=combined_filters, + for_update=for_update, + join_list=join_list, + ) + + get_by_email = with_db_session(ro=True)(_get_by_email) + + def _add(self, session, **kwargs) -> ObjectID: + return self._repository.add(session, **kwargs).id + + add = with_db_session(ro=False)(_add) + + def _get_by_email_for_interop( + self, + session, + email: str, + wallet_type: WalletType, + network: str, + ) -> List[auth_user_model]: + """ + Custom method for searching for users eligible for interop. Unfortunately, this can't be done with the current + abstractions in our sql_repository, so this is a one-off bespoke method. + If we need to add more similar queries involving eager loading and multiple joins, we can add an abstraction + inside the repository. + """ + + query = ( + session.query(auth_user_model) + .join( + auth_user_model.wallets.and_( + auth_wallet_model.wallet_type == str(wallet_type) + ).and_(auth_wallet_model.network == network) + # .and_(auth_wallet_model.is_active == 1), + ) + .options(contains_eager(auth_user_model.wallets)) + .join( + auth_user_model.magic_client.and_( + magic_client_model.connect_interop == ConnectInteropStatus.ENABLED, + ), + ) + .options(contains_eager(auth_user_model.magic_client)) + # TODO(magic-ravi#67899|2022-12-30): Uncomment to allow account-linked users to use interop + # .options( + # joinedload( + # auth_user_model.linked_primary_auth_user, + # ).joinedload("auth_wallets"), + # ) + .filter( + auth_wallet_model.wallet_type == wallet_type, + auth_wallet_model.network == network, + ) + .filter( + auth_user_model.email == email, + auth_user_model.user_type == EntityType.MAGIC.value, + # auth_user_model.is_active == 1, + auth_user_model.linked_primary_auth_user_id == None, # noqa: E711 + ) + .populate_existing() + ) + + return query.all() + + get_by_email_for_interop = with_db_session(ro=True)( + _get_by_email_for_interop, + ) + + def _get_linked_users(self, session, primary_auth_user_id, join_list, no_op=False): + # TODO(magic-ravi#67899|2022-12-30): Re-enable account linked users for interop. Remove no_op flag. + if no_op: + return [] + else: + return self._repository.get_by( + session, + filters=[ + # auth_user_model.is_active == True, # noqa: E712 + auth_user_model.user_type == EntityType.MAGIC.value, + auth_user_model.linked_primary_auth_user_id == primary_auth_user_id, + ], + join_list=join_list, + ) + + get_linked_users = with_db_session(ro=True)(_get_linked_users) + + @with_db_session(ro=True) + def get_by_phone_number(self, session, phone_number): + return self._repository.get_by( + session, + filters=[ + auth_user_model.phone_number == phone_number, + ], + ) + + +class AuthWallet(LogicComponent): + def __init__(self, session: Session): + # self._repository = SQLAlchemyRepository[magic_client_model, ObjectID](session) + self._repository = RepositoryLegacyAdapter(auth_wallet_model, ObjectID, session) + + def _add( + self, + session, + public_address, + encrypted_private_address, + wallet_type, + network, + management_type=None, + auth_user_id=None, + ): + new_row = self._repository.add( + session, + auth_user_id=auth_user_id, + public_address=public_address, + encrypted_private_address=encrypted_private_address, + wallet_type=wallet_type, + management_type=management_type, + network=network, + ) + + return new_row + + add = with_db_session(ro=False)(_add) + + @with_db_session(ro=True) + def get_by_id(self, session, model_id, allow_inactive=False, join_list=None): + return self._repository.get_by_id( + session, + model_id, + allow_inactive=allow_inactive, + join_list=join_list, + ) + + @with_db_session(ro=True) + def get_by_public_address(self, session, public_address, network=None, is_active=True): + """Public address is unique in our system. In any case, we should only + find one row for the given public address. + + Args: + session: A database session object. + public_address (str): A public address. + network (str): A network name. + is_active (boolean): A boolean value to denote if the query should + retrieve active or inactive rows. + + Returns: + A formatted row, either in presenter form or raw db row. + """ + filters = [ + auth_wallet_model.public_address == public_address, + # auth_wallet_model.is_active == is_active, + ] + + if network: + filters.append(auth_wallet_model.network == network) + + row = self._repository.get_by(session, filters=filters, allow_inactive=not is_active) + + if not row: + return None + + return one(row) + + @with_db_session(ro=True) + def get_by_auth_user_id( + self, + session, + auth_user_id, + network=None, + wallet_type=None, + is_active=True, + join_list=None, + ): + """Return all the associated wallets for the given user id. + + Args: + session: A database session object. + auth_user_id (ObjectID): A auth_user id. + network (str|None): A network name. + wallet_type (str|None): a wallet type like ETH or BTC + is_active (boolean): A boolean value to denote if the query should + retrieve active or inactive rows. + join_list (None|List): Table you wish to join. + + Returns: + An empty list or a list of wallets. + """ + filters = [ + auth_wallet_model.auth_user_id == auth_user_id, + # auth_wallet_model.is_active == is_active, + ] + + if network: + filters.append(auth_wallet_model.network == network) + + if wallet_type: + filters.append(auth_wallet_model.wallet_type == wallet_type) + + rows = self._repository.get_by( + session, filters=filters, join_list=join_list, allow_inactive=not is_active + ) + + if not rows: + return [] + + return rows + + def _update_by_id(self, session, model_id, **kwargs): + self._repository.update(session, model_id, **kwargs) + + update_by_id = with_db_session(ro=False)(_update_by_id) diff --git a/src/quart_sqlalchemy/sim/main.py b/src/quart_sqlalchemy/sim/main.py new file mode 100644 index 0000000..7bd73cd --- /dev/null +++ b/src/quart_sqlalchemy/sim/main.py @@ -0,0 +1,9 @@ +from .app import app +from .views import api + + +app.register_blueprint(api) + + +if __name__ == "__main__": + app.run(port=8080) diff --git a/src/quart_sqlalchemy/sim/model.py b/src/quart_sqlalchemy/sim/model.py new file mode 100644 index 0000000..efee04c --- /dev/null +++ b/src/quart_sqlalchemy/sim/model.py @@ -0,0 +1,163 @@ +import secrets +import typing as t +from datetime import datetime +from enum import Enum +from enum import IntEnum + +import sqlalchemy +import sqlalchemy.orm +from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy.orm import Mapped + +from quart_sqlalchemy.model import SoftDeleteMixin +from quart_sqlalchemy.model import TimestampMixin +from quart_sqlalchemy.sim.app import db +from quart_sqlalchemy.sim.util import ObjectID + + +sa = sqlalchemy + + +class StrEnum(str, Enum): + def __str__(self) -> str: + return str.__str__(self) + + +class ConnectInteropStatus(StrEnum): + ENABLED = "ENABLED" + DISABLED = "DISABLED" + + +class Provenance(Enum): + LINK = 1 + OAUTH = 2 + WEBAUTHN = 3 + SMS = 4 + IDENTIFIER = 5 + FEDERATED = 6 + + +class EntityType(Enum): + FORTMATIC = 1 + MAGIC = 2 + CONNECT = 3 + + +class WalletManagementType(IntEnum): + UNDELEGATED = 1 + DELEGATED = 2 + + +class WalletType(StrEnum): + ETH = "ETH" + HARMONY = "HARMONY" + ICON = "ICON" + FLOW = "FLOW" + TEZOS = "TEZOS" + ZILLIQA = "ZILLIQA" + POLKADOT = "POLKADOT" + SOLANA = "SOLANA" + AVAX = "AVAX" + ALGOD = "ALGOD" + COSMOS = "COSMOS" + CELO = "CELO" + BITCOIN = "BITCOIN" + NEAR = "NEAR" + HELIUM = "HELIUM" + CONFLUX = "CONFLUX" + TERRA = "TERRA" + TAQUITO = "TAQUITO" + ED = "ED" + HEDERA = "HEDERA" + + +class MagicClient(db.Model, SoftDeleteMixin, TimestampMixin): + __tablename__ = "magic_client" + + id: Mapped[ObjectID] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + app_name: Mapped[str] = sa.orm.mapped_column(default="my new app") + rate_limit_tier: Mapped[t.Optional[str]] + connect_interop: Mapped[t.Optional[ConnectInteropStatus]] + is_signing_modal_enabled: Mapped[bool] = sa.orm.mapped_column(default=False) + global_audience_enabled: Mapped[bool] = sa.orm.mapped_column(default=False) + + public_api_key: Mapped[str] = sa.orm.mapped_column(default_factory=secrets.token_hex) + secret_api_key: Mapped[str] = sa.orm.mapped_column(default_factory=secrets.token_hex) + + auth_users: Mapped[t.List["AuthUser"]] = sa.orm.relationship( + back_populates="magic_client", + primaryjoin="and_(foreign(AuthUser.client_id) == MagicClient.id, AuthUser.user_type != 1)", + ) + + +class AuthUser(db.Model, SoftDeleteMixin, TimestampMixin): + __tablename__ = "auth_user" + + id: Mapped[ObjectID] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + email: Mapped[t.Optional[str]] = sa.orm.mapped_column(index=True) + phone_number: Mapped[t.Optional[str]] = sa.orm.mapped_column(index=True) + user_type: Mapped[int] = sa.orm.mapped_column(default=EntityType.FORTMATIC.value) + date_verified: Mapped[t.Optional[datetime]] + provenance: Mapped[t.Optional[Provenance]] + is_admin: Mapped[bool] = sa.orm.mapped_column(default=False) + client_id: Mapped[ObjectID] + linked_primary_auth_user_id: Mapped[t.Optional[ObjectID]] + global_auth_user_id: Mapped[t.Optional[ObjectID]] + + delegated_user_id: Mapped[t.Optional[str]] + delegated_identity_pool_id: Mapped[t.Optional[str]] + + current_session_token: Mapped[t.Optional[str]] + + magic_client: Mapped[MagicClient] = sa.orm.relationship( + back_populates="auth_user", + uselist=False, + ) + linked_primary_auth_user = sa.orm.relationship( + "AuthUser", + remote_side=[id], + lazy="joined", + join_depth=1, + uselist=False, + ) + wallets: Mapped[t.List["AuthWallet"]] = sa.orm.relationship(back_populates="auth_user") + + @hybrid_property + def is_email_verified(self): + return self.email is not None and self.date_verified is not None + + @hybrid_property + def is_waiting_on_email_verification(self): + return self.email is not None and self.date_verified is None + + @hybrid_property + def is_new_signup(self): + return self.date_verified is None + + @hybrid_property + def has_linked_primary_auth_user(self): + return bool(self.linked_primary_auth_user_id) + + @hybrid_property + def is_magic_connect_user(self): + return self.global_auth_user_id is not None and self.user_type == EntityType.CONNECT.value + + +class AuthWallet(db.Model, SoftDeleteMixin, TimestampMixin): + __tablename__ = "auth_user" + + id: Mapped[ObjectID] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + auth_user_id: Mapped[ObjectID] = sa.orm.mapped_column(sa.ForeignKey("auth_user.id")) + wallet_type: Mapped[str] = sa.orm.mapped_column(default=WalletType.ETH.value) + management_type: Mapped[int] = sa.orm.mapped_column( + default=WalletManagementType.UNDELEGATED.value + ) + public_address: Mapped[t.Optional[str]] = sa.orm.mapped_column(index=True) + encrypted_private_address: Mapped[t.Optional[str]] + network: Mapped[str] + is_exported: Mapped[bool] = sa.orm.mapped_column(default=False) + + auth_user: Mapped[AuthUser] = sa.orm.relationship( + back_populates="auth_wallets", + uselist=False, + ) diff --git a/src/quart_sqlalchemy/sim/repo.py b/src/quart_sqlalchemy/sim/repo.py new file mode 100644 index 0000000..d257b5f --- /dev/null +++ b/src/quart_sqlalchemy/sim/repo.py @@ -0,0 +1,285 @@ +from __future__ import annotations + +import typing as t +from abc import ABCMeta +from abc import abstractmethod + +import sqlalchemy +import sqlalchemy.event +import sqlalchemy.exc +import sqlalchemy.orm +import sqlalchemy.sql + +from quart_sqlalchemy.types import ColumnExpr +from quart_sqlalchemy.types import EntityIdT +from quart_sqlalchemy.types import EntityT +from quart_sqlalchemy.types import ORMOption +from quart_sqlalchemy.types import Selectable +from quart_sqlalchemy.types import SessionT + +from .builder import StatementBuilder + + +sa = sqlalchemy + + +class AbstractRepository(t.Generic[EntityT, EntityIdT], metaclass=ABCMeta): + """A repository interface.""" + + # identity: t.Type[EntityIdT] + + # def __init__(self, model: t.Type[EntityT]): + # self.model = model + + @property + def model(self) -> t.Type[EntityT]: + return self.__orig_class__.__args__[0] # type: ignore + + @property + def identity(self) -> t.Type[EntityIdT]: + return self.__orig_class__.__args__[1] # type: ignore + + +class AbstractBulkRepository(t.Generic[EntityT, EntityIdT], metaclass=ABCMeta): + """A repository interface for bulk operations. + + Note: this interface circumvents ORM internals, breaking commonly expected behavior in order + to gain performance benefits. Only use this class whenever absolutely necessary. + """ + + @property + def model(self) -> t.Type[EntityT]: + return self.__orig_class__.__args__[0] # type: ignore + + @property + def identity(self) -> t.Type[EntityIdT]: + return self.__orig_class__.__args__[1] # type: ignore + + + +class SQLAlchemyRepository( + AbstractRepository[EntityT, EntityIdT], + t.Generic[EntityT, EntityIdT], +): + """A repository that uses SQLAlchemy to persist data. + + The biggest change with this repository is that for methods returning multiple results, we + return the sa.ScalarResult so that the caller has maximum flexibility in how it's consumed. + + As a result, when calling a method such as get_by, you then need to decide how to fetch the + result. + + Methods of fetching results: + - .all() to return a list of results + - .first() to return the first result + - .one() to return the first result or raise an exception if there are no results + - .one_or_none() to return the first result or None if there are no results + - .partitions(n) to return a results as a list of n-sized sublists + + Additionally, there are methods for transforming the results prior to fetching. + + Methods of transforming results: + - .unique() to apply unique filtering to the result + + """ + + session: sa.orm.Session + builder: StatementBuilder + + def __init__(self, session: sa.orm.Session, **kwargs): + super().__init__(**kwargs) + self.session = session + self.builder = StatementBuilder(None) + + def insert(self, values: t.Dict[str, t.Any]) -> EntityT: + """Insert a new model into the database.""" + new = self.model(**values) + self.session.add(new) + self.session.flush() + self.session.refresh(new) + return new + + def update(self, id_: EntityIdT, values: t.Dict[str, t.Any]) -> EntityT: + """Update existing model with new values.""" + obj = self.session.get(self.model, id_) + if obj is None: + raise ValueError(f"Object with id {id_} not found") + for field, value in values.items(): + if getattr(obj, field) != value: + setattr(obj, field, value) + self.session.flush() + self.session.refresh(obj) + return obj + + def merge( + self, id_: EntityIdT, values: t.Dict[str, t.Any], for_update: bool = False + ) -> EntityT: + """Merge model in session/db having id_ with values.""" + self.session.get(self.model, id_) + values.update(id=id_) + merged = self.session.merge(self.model(**values)) + self.session.flush() + self.session.refresh(merged, with_for_update=for_update) # type: ignore + return merged + + def get( + self, + id_: EntityIdT, + options: t.Sequence[ORMOption] = (), + execution_options: t.Optional[t.Dict[str, t.Any]] = None, + for_update: bool = False, + include_inactive: bool = False, + ) -> t.Optional[EntityT]: + """Get object identified by id_ from the database. + + Note: It's a common misconception that session.get(Model, id) is akin to a shortcut for + a select(Model).where(Model.id == id) like statement. However this is not the case. + + Session.get is actually used for looking an object up in the sessions identity map. When + present it will be returned directly, when not, a database lookup will be performed. + + For use cases where this is what you actually want, you can still access the original get + method on self.session. For most uses cases, this behavior can introduce non-determinism + and because of that this method performs lookup using a select statement. Additionally, + to satisfy the expected interface's return type: Optional[EntityT], one_or_none is called + on the result before returning. + """ + execution_options = execution_options or {} + if include_inactive: + execution_options.setdefault("include_inactive", include_inactive) + + statement = sa.select(self.model).where(self.model.id == id_).limit(1) # type: ignore + + for option in options: + statement = statement.options(option) + + if for_update: + statement = statement.with_for_update() + + return self.session.scalars(statement, execution_options=execution_options).one_or_none() + + def select( + self, + selectables: t.Sequence[Selectable] = (), + conditions: t.Sequence[ColumnExpr] = (), + group_by: t.Sequence[t.Union[ColumnExpr, str]] = (), + order_by: t.Sequence[t.Union[ColumnExpr, str]] = (), + options: t.Sequence[ORMOption] = (), + execution_options: t.Optional[t.Dict[str, t.Any]] = None, + offset: t.Optional[int] = None, + limit: t.Optional[int] = None, + distinct: bool = False, + for_update: bool = False, + include_inactive: bool = False, + yield_by_chunk: t.Optional[int] = None, + ) -> t.Union[sa.ScalarResult[EntityT], t.Iterator[t.Sequence[EntityT]]]: + """Select from the database. + + Note: yield_by_chunk is not compatible with the subquery and joined loader strategies, use selectinload for eager loading. + """ + selectables = selectables or (self.model,) # type: ignore + + execution_options = execution_options or {} + if include_inactive: + execution_options.setdefault("include_inactive", include_inactive) + if yield_by_chunk: + execution_options.setdefault("yield_per", yield_by_chunk) + + statement = self.builder.complex_select( + selectables, + conditions=conditions, + group_by=group_by, + order_by=order_by, + options=options, + execution_options=execution_options, + offset=offset, + limit=limit, + distinct=distinct, + for_update=for_update, + ) + + results = self.session.scalars(statement) + if yield_by_chunk: + results = results.partitions() + return results + + def delete(self, id_: EntityIdT, include_inactive: bool = False) -> None: + # if self.has_soft_delete: + # raise RuntimeError("Can't delete entity that uses soft-delete semantics.") + + entity = self.get(id_, include_inactive=include_inactive) + if not entity: + raise RuntimeError(f"Entity with id {id_} not found.") + + self.session.delete(entity) + self.session.flush() + + def deactivate(self, id_: EntityIdT) -> EntityT: + # if not self.has_soft_delete: + # raise RuntimeError("Can't delete entity that uses soft-delete semantics.") + + return self.update(id_, dict(is_active=False)) + + def reactivate(self, id_: EntityIdT) -> EntityT: + # if not self.has_soft_delete: + # raise RuntimeError("Can't delete entity that uses soft-delete semantics.") + + return self.update(id_, dict(is_active=False)) + + def exists( + self, + conditions: t.Sequence[ColumnExpr] = (), + for_update: bool = False, + include_inactive: bool = False, + ) -> bool: + """Return whether an object matching conditions exists. + + Note: This performs better than simply trying to select an object since there is no + overhead in sending the selected object and deserializing it. + """ + selectable = sa.sql.literal(True) + + execution_options = {} + if include_inactive: + execution_options.setdefault("include_inactive", include_inactive) + + statement = sa.select(selectable).where(*conditions) # type: ignore + + if for_update: + statement = statement.with_for_update() + + result = self.session.execute(statement, execution_options=execution_options).scalar() + + return bool(result) + + +class SQLAlchemyBulkRepository(AbstractBulkRepository, t.Generic[SessionT, EntityT, EntityIdT]): + def __init__(self, session: SessionT, **kwargs: t.Any): + super().__init__(**kwargs) + self.builder = StatementBuilder(self.model) + self.session = session + + def bulk_insert( + self, + values: t.Sequence[t.Dict[str, t.Any]] = (), + execution_options: t.Optional[t.Dict[str, t.Any]] = None, + ) -> sa.Result[t.Any]: + statement = self.builder.bulk_insert(self.model, values) + return self.session.execute(statement, execution_options=execution_options or {}) + + def bulk_update( + self, + conditions: t.Sequence[ColumnExpr] = (), + values: t.Optional[t.Dict[str, t.Any]] = None, + execution_options: t.Optional[t.Dict[str, t.Any]] = None, + ) -> sa.Result[t.Any]: + statement = self.builder.bulk_update(self.model, conditions, values) + return self.session.execute(statement, execution_options=execution_options or {}) + + def bulk_delete( + self, + conditions: t.Sequence[ColumnExpr] = (), + execution_options: t.Optional[t.Dict[str, t.Any]] = None, + ) -> sa.Result[t.Any]: + statement = self.builder.bulk_delete(self.model, conditions) + return self.session.execute(statement, execution_options=execution_options or {}) diff --git a/src/quart_sqlalchemy/sim/repo_adapter.py b/src/quart_sqlalchemy/sim/repo_adapter.py new file mode 100644 index 0000000..bbd96b0 --- /dev/null +++ b/src/quart_sqlalchemy/sim/repo_adapter.py @@ -0,0 +1,309 @@ +import typing as t + +from pydantic import BaseModel +from sqlalchemy import ScalarResult +from sqlalchemy.orm import selectinload, Session +from sqlalchemy.sql.expression import func +from sqlalchemy.sql.expression import label +from quart_sqlalchemy.model import Base +from quart_sqlalchemy.types import ColumnExpr +from quart_sqlalchemy.types import EntityIdT +from quart_sqlalchemy.types import EntityT +from quart_sqlalchemy.types import ORMOption +from quart_sqlalchemy.types import Selectable + +from .repo import SQLAlchemyRepository + + +class BaseModelSchema(BaseModel): + class Config: + from_orm = True + + +class BaseCreateSchema(BaseModelSchema): + pass + + +class BaseUpdateSchema(BaseModelSchema): + pass + + +ModelSchemaT = t.TypeVar("ModelSchemaT", bound=BaseModelSchema) +CreateSchemaT = t.TypeVar("CreateSchemaT", bound=BaseCreateSchema) +UpdateSchemaT = t.TypeVar("UpdateSchemaT", bound=BaseUpdateSchema) + + +class RepositoryLegacyAdapter(t.Generic[EntityT, EntityIdT]): + def __init__(self, model: t.Type[EntityT], identity: t.Type[EntityIdT], session: Session,): + self.model = model + self._identity = identity + self._session = session + self.repo = SQLAlchemyRepository[model, identity](session) + + def get_by( + self, + session: t.Optional[Session] = None, + filters=None, + allow_inactive=False, + join_list=None, + order_by_clause=None, + for_update=False, + offset=None, + limit=None, + ) -> t.Sequence[EntityT]: + if filters is None: + raise ValueError("Full table scans are prohibited. Please provide filters") + + join_list = join_list or () + + if order_by_clause is not None: + order_by_clause = (order_by_clause,) + else: + order_by_clause = () + + return self.repo.select( + conditions=filters, + options=[selectinload(getattr(self.model, attr)) for attr in join_list], + for_update=for_update, + order_by=order_by_clause, + offset=offset, + limit=limit, + include_inactive=allow_inactive, + ).all() + + def get_by_id( + self, + session = None, + model_id = None, + allow_inactive=False, + join_list=None, + for_update=False, + ) -> t.Optional[EntityT]: + if model_id is None: + raise ValueError("model_id is required") + join_list = join_list or () + return self.repo.get( + id_=model_id, + options=[selectinload(getattr(self.model, attr)) for attr in join_list], + for_update=for_update, + include_inactive=allow_inactive, + ) + + def one(self, session = None, filters=None, join_list=None, for_update=False, include_inactive=False) -> EntityT: + filters = filters or () + join_list = join_list or () + return self.repo.select( + conditions=filters, + options=[selectinload(getattr(self.model, attr)) for attr in join_list], + for_update=for_update, + include_inactive=include_inactive, + ).one() + + def count_by( + self, + session = None, + filters=None, + group_by=None, + distinct_column=None, + ): + if filters is None: + raise ValueError("Full table scans are prohibited. Please provide filters") + + group_by = group_by or () + + if distinct_column: + selectables = [label("count", func.count(func.distinct(distinct_column)))] + else: + selectables = [label("count", func.count(self.model.id))] + + for group in group_by: + selectables.append(group.expression) + + result = self.repo.select(selectables, conditions=filters, group_by=group_by) + + return result.all() + + def add(self, session = None, **kwargs) -> EntityT: + return self.repo.insert(kwargs) + + def update(self, session = None, model_id, **kwargs) -> EntityT: + return self.repo.update(id_=model_id, values=kwargs) + + def update_by(self, session = None, filters=None, **kwargs) -> EntityT: + if not filters: + raise ValueError("Full table scans are prohibited. Please provide filters") + + row = self.repo.select(conditions=filters, limit=2).one() + return self.repo.update(id_=row.id, values=kwargs) + + def delete_by_id(self, session = None, model_id) -> None: + self.repo.delete(id_=model_id, include_inactive=True) + + def delete_one_by(self, session = None, filters=None, optional=False) -> None: + filters = filters or () + result = self.repo.select(conditions=filters, limit=1) + + if optional: + row = result.one_or_none() + if row is None: + return + else: + row = result.one() + + self.repo.delete(id_=row.id) + + def exist(self, session = None, filters=None, allow_inactive=False) -> bool: + filters = filters or () + return self.repo.exists( + conditions=filters, + include_inactive=allow_inactive, + ) + + def yield_by_chunk( + self, session = None, chunk_size = 100, join_list=None, filters=None, allow_inactive=False + ): + filters = filters or () + join_list = join_list or () + results = self.repo.select( + conditions=filters, + options=[selectinload(getattr(self.model, attr)) for attr in join_list], + include_inactive=allow_inactive, + yield_by_chunk=chunk_size, + ) + for result in results: + yield result + + +class PydanticScalarResult(ScalarResult): + pydantic_schema: t.Type[ModelSchemaT] + + def __init__(self, scalar_result, pydantic_schema: t.Type[ModelSchemaT]): + for attribute in scalar_result.__slots__: + setattr(self, attribute, getattr(scalar_result, attribute)) + self.pydantic_schema = pydantic_schema + + def _translate_many(self, rows): + return [self.pydantic_schema.from_orm(row) for row in rows] + + def _translate_one(self, row): + if row is None: + return + return self.pydantic_schema.from_orm(row) + + def all(self): + return self._translate_many(super().all()) + + def fetchall(self): + return self._translate_many(super().fetchall()) + + def fetchmany(self, *args, **kwargs): + return self._translate_many(super().fetchmany(*args, **kwargs)) + + def first(self): + return self._translate_one(super().first()) + + def one(self): + return self._translate_one(super().one()) + + def one_or_none(self): + return self._translate_one(super().one_or_none()) + + def partitions(self, *args, **kwargs): + for partition in super().partitions(*args, **kwargs): + yield self._translate_many(partition) + + +class PydanticRepository(SQLAlchemyRepository, t.Generic[EntityT, EntityIdT, ModelSchemaT]): + @property + def schema(self) -> t.Type[ModelSchemaT]: + return self.__orig_class__.__args__[2] # type: ignore + + def insert( + self, + create_schema: CreateSchemaT, + sqla_model=False, + ): + create_data = create_schema.dict() + result = super().insert(create_data) + + if sqla_model: + return result + return self.schema.from_orm(result) + + def update( + self, + id_: EntityIdT, + update_schema: UpdateSchemaT, + sqla_model=False, + ): + existing = self.session.query(self.model).get(id_) + if existing is None: + raise ValueError("Model not found") + + update_data = update_schema.dict(exclude_unset=True) + for key, value in update_data.items(): + setattr(existing, key, value) + + self.session.add(existing) + self.session.flush() + self.session.refresh(existing) + if sqla_model: + return existing + return self.schema.from_orm(existing) + + def get( + self, + id_: EntityIdT, + options: t.Sequence[ORMOption] = (), + execution_options: t.Optional[t.Dict[str, t.Any]] = None, + for_update: bool = False, + include_inactive: bool = False, + sqla_model: bool = False, + ): + row = super().get( + id_, + options, + execution_options, + for_update, + include_inactive, + ) + if row is None: + return + + if sqla_model: + return row + return self.schema.from_orm(row) + + def select( + self, + selectables: t.Sequence[Selectable] = (), + conditions: t.Sequence[ColumnExpr] = (), + group_by: t.Sequence[t.Union[ColumnExpr, str]] = (), + order_by: t.Sequence[t.Union[ColumnExpr, str]] = (), + options: t.Sequence[ORMOption] = (), + execution_options: t.Optional[t.Dict[str, t.Any]] = None, + offset: t.Optional[int] = None, + limit: t.Optional[int] = None, + distinct: bool = False, + for_update: bool = False, + include_inactive: bool = False, + yield_by_chunk: t.Optional[int] = None, + sqla_model: bool = False, + ): + result = super().select( + selectables, + conditions, + group_by, + order_by, + options, + execution_options, + offset, + limit, + distinct, + for_update, + include_inactive, + yield_by_chunk, + ) + if sqla_model: + return result + return PydanticScalarResult(result, self.schema) diff --git a/src/quart_sqlalchemy/sim/schema.py b/src/quart_sqlalchemy/sim/schema.py new file mode 100644 index 0000000..f47225a --- /dev/null +++ b/src/quart_sqlalchemy/sim/schema.py @@ -0,0 +1,14 @@ +from datetime import datetime + +from pydantic import BaseModel + +from .util import ObjectID + + +class BaseSchema(BaseModel): + class Config: + arbitrary_types_allowed = True + json_encoders = { + ObjectID: lambda v: v.encode(), + datetime: lambda dt: int(dt.timestamp()), + } diff --git a/src/quart_sqlalchemy/sim/signals.py b/src/quart_sqlalchemy/sim/signals.py new file mode 100644 index 0000000..1981cea --- /dev/null +++ b/src/quart_sqlalchemy/sim/signals.py @@ -0,0 +1,33 @@ +from blinker import Namespace + + +# Synchronous signals +_sync = Namespace() + +auth_user_duplicate = _sync.signal( + "auth_user_duplicate", + doc="""Called on discovery of at least one duplicate auth user. + + Handlers should have the following signature: + def handler( + current_app: Quart, + original_auth_user_id: ObjectID, + duplicate_auth_user_ids: List[ObjectID], + ) -> None: + ... + """, +) + +keys_rolled = _sync.signal( + "keys_rolled", + doc="""Called after api keys are rolled. + + Handlers should have the following signature: + def handler( + app: Quart, + deactivated_keys: Dict[str, Any], + redis_client: Redis, + ) -> None: + ... + """, +) diff --git a/src/quart_sqlalchemy/sim/util.py b/src/quart_sqlalchemy/sim/util.py new file mode 100644 index 0000000..463b350 --- /dev/null +++ b/src/quart_sqlalchemy/sim/util.py @@ -0,0 +1,101 @@ +import logging +import numbers + +from hashids import Hashids + + +logger = logging.getLogger(__name__) + + +class CryptographyError(Exception): + pass + + +class DecryptionError(CryptographyError): + pass + + +def one(input_list): + if len(input_list) != 1: + raise ValueError(f"Expected a list of length 1, got {len(input_list)}") + return input_list[0] + + +class ObjectID: + hashids = Hashids(min_length=12) + + def __init__(self, input_value): + if input_value is None: + raise ValueError("ObjectID cannot be None") + elif isinstance(input_value, ObjectID): + self._source_id = input_value._decoded_id + elif isinstance(input_value, str): + self._source_id = input_value + self._decode() + elif isinstance(input_value, numbers.Number): + try: + input_value = int(input_value) + except (ValueError, TypeError): + pass + + self._source_id = input_value + self._encode() + + @property + def _encoded_id(self): + return self._encode() + + @property + def _decoded_id(self): + return self._decode() + + def __eq__(self, other): + if isinstance(other, ObjectID): + return self._decoded_id == other._decoded_id and self._encoded_id == other._encoded_id + elif isinstance(other, int): + return self._decoded_id == other + elif isinstance(other, str): + return self._encoded_id == other + else: + return False + + def __lt__(self, other): + if isinstance(other, ObjectID): + return self._decoded_id < other._decoded_id + return False + + def __hash__(self): + return hash(tuple([self._encoded_id, self._decoded_id])) + + def __str__(self): + return "{encoded_id}".format(encoded_id=self._encoded_id) + + def __int__(self): + return self._decoded_id + + def __repr__(self): + return f"{type(self).__name__}({self._decoded_id})" + + def __json__(self): + return self.__str__() + + def _encode(self): + if isinstance(self._source_id, int): + return self.hashids.encode(self._source_id) + else: + return self._source_id + + def encode(self): + return self._encoded_id + + def _decode(self): + if isinstance(self._source_id, int): + return self._source_id + else: + return self.hashids.decode(self._source_id) + + def decode(self): + return self._decoded_id + + def decode_str(self): + return str(self._decoded_id) diff --git a/src/quart_sqlalchemy/sim/views/__init__.py b/src/quart_sqlalchemy/sim/views/__init__.py new file mode 100644 index 0000000..d110bc2 --- /dev/null +++ b/src/quart_sqlalchemy/sim/views/__init__.py @@ -0,0 +1,18 @@ +from quart import Blueprint +from quart import g + +from .auth_user import api as auth_user_api +from .auth_wallet import api as auth_wallet_api +from .magic_client import api as magic_client_api + + +api = Blueprint("api", __name__, url_prefix="api") + +api.register_blueprint(auth_user_api) +api.register_blueprint(auth_wallet_api) +api.register_blueprint(magic_client_api) + + +@api.before_request +def set_feature_owner(): + g.request_feature_owner = "magic" diff --git a/src/quart_sqlalchemy/sim/views/auth_user.py b/src/quart_sqlalchemy/sim/views/auth_user.py new file mode 100644 index 0000000..ee456e6 --- /dev/null +++ b/src/quart_sqlalchemy/sim/views/auth_user.py @@ -0,0 +1,7 @@ +import logging + +from quart import Blueprint + + +logger = logging.getLogger(__name__) +api = Blueprint("auth_user", __name__, url_prefix="auth_user") diff --git a/src/quart_sqlalchemy/sim/views/auth_wallet.py b/src/quart_sqlalchemy/sim/views/auth_wallet.py new file mode 100644 index 0000000..24fc9fc --- /dev/null +++ b/src/quart_sqlalchemy/sim/views/auth_wallet.py @@ -0,0 +1,62 @@ +import logging +import typing as t + +from quart import Blueprint +from quart import g +from quart.utils import run_sync +from quart_schema.validation import validate + +from ..model import WalletManagementType +from ..model import WalletType +from ..schema import BaseSchema +from ..util import ObjectID +from .decorator import authorized_request + + +logger = logging.getLogger(__name__) +api = Blueprint("auth_wallet", __name__, url_prefix="auth_wallet") + + +@api.before_request +def set_feature_owner(): + g.request_feature_owner = "wallet" + + +class WalletSyncRequest(BaseSchema): + public_address: str + encrypted_private_address: str + wallet_type: str + hd_path: t.Optional[str] = None + encrypted_seed_phrase: t.Optional[str] = None + + +class WalletSyncResponse(BaseSchema): + wallet_id: ObjectID + auth_user_id: ObjectID + wallet_type: WalletType + public_address: str + encrypted_private_address: str + + +@authorized_request(authenticate_client=True, authenticate_user=True) +@validate(request=WalletSyncRequest, responses={200: (WalletSyncResponse, None)}) +@api.route("/sync", methods=["POST"]) +async def sync_auth_user_wallet(data: WalletSyncRequest): + try: + with g.bind.Session() as session: + wallet = await run_sync(g.h.AuthWallet(session).sync_auth_wallet)( + g.auth.user.id, + data.public_address, + data.encrypted_private_address, + WalletManagementType.DELEGATED.value, + ) + except RuntimeError: + raise RuntimeError("Unsupported wallet type or network") + + return WalletSyncResponse( + wallet_id=wallet.id, + auth_user_id=wallet.auth_user_id, + wallet_type=wallet.wallet_type, + public_address=wallet.public_address, + encrypted_private_address=wallet.encrypted_private_address, + ) diff --git a/src/quart_sqlalchemy/sim/views/decorator.py b/src/quart_sqlalchemy/sim/views/decorator.py new file mode 100644 index 0000000..357a819 --- /dev/null +++ b/src/quart_sqlalchemy/sim/views/decorator.py @@ -0,0 +1,131 @@ +# import inspect +# import typing as t +# from dataclasses import asdict +# from dataclasses import is_dataclass +from functools import wraps + +# from pydantic import BaseModel +# from pydantic import ValidationError +# from pydantic.dataclasses import dataclass as pydantic_dataclass +# from pydantic.schema import model_schema +from quart import current_app +from quart import g + + +# from quart import request +# from quart import ResponseReturnValue as QuartResponseReturnValue +# from quart_schema.typing import Model +# from quart_schema.typing import PydanticModel +# from quart_schema.typing import ResponseReturnValue +# from quart_schema.validation import _convert_headers +# from quart_schema.validation import DataSource +# from quart_schema.validation import QUART_SCHEMA_RESPONSE_ATTRIBUTE +# from quart_schema.validation import ResponseHeadersValidationError +# from quart_schema.validation import ResponseSchemaValidationError +# from quart_schema.validation import validate_headers +# from quart_schema.validation import validate_querystring +# from quart_schema.validation import validate_request + + +def authorized_request(authenticate_client: bool = False, authenticate_user: bool = False): + def decorator(func): + @wraps(func) + async def wrapper(*args, **kwargs): + if authenticate_client: + if not g.auth.client: + raise RuntimeError("Unable to authenticate client") + kwargs.update(client_id=g.auth.client.id) + + if authenticate_user: + if not g.auth.user: + raise RuntimeError("Unable to authenticate user") + kwargs.update(user_id=g.auth.user.id) + + return await current_app.ensure_async(func)(*args, **kwargs) + + return wrapper + + return decorator + + +# def validate_response() -> t.Callable: +# def decorator( +# func: t.Callable[..., ResponseReturnValue] +# ) -> t.Callable[..., QuartResponseReturnValue]: +# undecorated = func +# while hasattr(undecorated, "__wrapped__"): +# undecorated = undecorated.__wrapped__ + +# signature = inspect.signature(undecorated) +# derived_schema = signature.return_annotation or dict + +# schemas = getattr(func, QUART_SCHEMA_RESPONSE_ATTRIBUTE, {}) +# schemas[200] = (derived_schema, None) +# setattr(func, QUART_SCHEMA_RESPONSE_ATTRIBUTE, schemas) + +# @wraps(func) +# async def wrapper(*args: t.Any, **kwargs: t.Any) -> t.Any: +# result = await current_app.ensure_async(func)(*args, **kwargs) + +# status_or_headers = None +# headers = None +# if isinstance(result, tuple): +# value, status_or_headers, headers = result + (None,) * (3 - len(result)) +# else: +# value = result + +# status = 200 +# if isinstance(status_or_headers, int): +# status = int(status_or_headers) + +# schemas = getattr(func, QUART_SCHEMA_RESPONSE_ATTRIBUTE, {200: dict}) +# model_class = schemas.get(status, dict) + +# try: +# if isinstance(value, dict): +# model_value = model_class(**value) +# elif type(value) == model_class: +# model_value = value +# elif is_dataclass(value): +# model_value = model_class(**asdict(value)) +# else: +# return result, status, headers + +# except ValidationError as error: +# raise ResponseHeadersValidationError(error) + +# headers_value = headers +# return model_value, status, headers_value + +# return wrapper + +# return decorator + + +# def validate( +# *, +# querystring: t.Optional[Model] = None, +# request: t.Optional[Model] = None, +# request_source: DataSource = DataSource.JSON, +# headers: t.Optional[Model] = None, +# responses: t.Dict[int, t.Tuple[Model, t.Optional[Model]]], +# ) -> t.Callable: +# """Validate the route. + +# This is a shorthand combination of of the validate_querystring, +# validate_request, validate_headers, and validate_response +# decorators. Please see the docstrings for those decorators. +# """ + +# def decorator(func: t.Callable) -> t.Callable: +# if querystring is not None: +# func = validate_querystring(querystring)(func) +# if request is not None: +# func = validate_request(request, source=request_source)(func) +# if headers is not None: +# func = validate_headers(headers)(func) +# for status, models in responses.items(): +# func = validate_response(models[0], status, models[1]) +# return func + +# return decorator diff --git a/src/quart_sqlalchemy/sim/views/magic_client.py b/src/quart_sqlalchemy/sim/views/magic_client.py new file mode 100644 index 0000000..a1d28e1 --- /dev/null +++ b/src/quart_sqlalchemy/sim/views/magic_client.py @@ -0,0 +1,7 @@ +import logging + +from quart import Blueprint + + +logger = logging.getLogger(__name__) +api = Blueprint("magic_client", __name__, url_prefix="magic_client") From fc9ae801a8cb5716894ab4ce0f1cd01b786b52ff Mon Sep 17 00:00:00 2001 From: Joe Black Date: Thu, 30 Mar 2023 18:34:08 -0400 Subject: [PATCH 2/9] clean --- src/quart_sqlalchemy/sim/handle.py | 8 +- src/quart_sqlalchemy/sim/legacy.py | 404 ----------------------------- 2 files changed, 1 insertion(+), 411 deletions(-) delete mode 100644 src/quart_sqlalchemy/sim/legacy.py diff --git a/src/quart_sqlalchemy/sim/handle.py b/src/quart_sqlalchemy/sim/handle.py index 7076309..6ed473c 100644 --- a/src/quart_sqlalchemy/sim/handle.py +++ b/src/quart_sqlalchemy/sim/handle.py @@ -37,12 +37,6 @@ class APIKeySet(t.NamedTuple): secret_key: str -def get_session_proxy(): - from .app import db - - return db.bind.Session() - - class HandlerBase: logic: Logic session: Session @@ -51,7 +45,7 @@ class HandlerBase: """ def __init__(self, session: t.Optional[Session], logic: t.Optional[Logic] = None): - self.session = session or get_session_proxy() + self.session = session self.logic = logic or Logic() diff --git a/src/quart_sqlalchemy/sim/legacy.py b/src/quart_sqlalchemy/sim/legacy.py deleted file mode 100644 index 5d4662f..0000000 --- a/src/quart_sqlalchemy/sim/legacy.py +++ /dev/null @@ -1,404 +0,0 @@ -from sqlalchemy import exists -from sqlalchemy.orm import joinedload -from sqlalchemy.sql.expression import func -from sqlalchemy.sql.expression import label - - -def one(input_list): - (item,) = input_list - return item - - -class SQLRepository: - def __init__(self, model): - self._model = model - assert self._model is not None - - @property - def _has_is_active_field(self): - return bool(getattr(self._model, "is_active", None)) - - def get_by_id( - self, - session, - model_id, - allow_inactive=False, - join_list=None, - for_update=False, - ): - """SQL get interface to retrieve by model's id column. - - Args: - session: A database session object. - model_id: The id of the given model to be retrieved. - allow_inactive: Whether to include inactive or not. - join_list: A list of attributes to be joined in the same session for - given model. This is normally the attributes that have - relationship defined and referenced to other models. - for_update: Locks the table for update. - - Returns: - Data retrieved from the database for the model. - """ - query = session.query(self._model) - - if join_list: - for to_join in join_list: - query = query.options(joinedload(to_join)) - - if for_update: - query = query.with_for_update() - - row = query.get(model_id) - - if row is None: - return None - - if self._has_is_active_field and not row.is_active and not allow_inactive: - return None - - return row - - def get_by( - self, - session, - filters=None, - join_list=None, - order_by_clause=None, - for_update=False, - offset=None, - limit=None, - ): - """SQL get_by interface to retrieve model instances based on the given - filters. - - Args: - session: A database session object. - filters: A list of filters on the models. - join_list: A list of attributes to be joined in the same session for - given model. This is normally the attributes that have - relationship defined and referenced to other models. - order_by_clause: An order by clause. - for_update: Locks the table for update. - - Returns: - Modified rows. - - TODO(ajen#ch21549|2020-07-21): Filter out `is_active == False` row. This - will not be a trivial change as many places rely on this method and the - handlers/logics sometimes filter by in_active. Sometimes endpoints might - get affected. Proceed with caution. - """ - # If no filter is given, just return. Prevent table scan. - if filters is None: - return None - - query = session.query(self._model).filter(*filters).order_by(order_by_clause) - - if for_update: - query = query.with_for_update() - - if offset: - query = query.offset(offset) - - if limit: - query = query.limit(limit) - - # Prevent loading all the rows. - if limit == 0: - return [] - - if join_list: - for to_join in join_list: - query = query.options(joinedload(to_join)) - - return query.all() - - def count_by( - self, - session, - filters=None, - group_by=None, - distinct_column=None, - ): - """SQL count_by interface to retrieve model instance count based on the given - filters. - - Args: - session: A database session object. - filters (list): Required; a list of filters on the models. - group_by (list): A list of optional group by expressions. - Returns: - A list of counts of rows. - Raises: - ValueError: Returns a value error, when no filters are provided - """ - # Prevent table scans - if filters is None: - raise ValueError("Full table scans are prohibited. Please provide filters") - - select = [label("count", func.count(self._model.id))] - - if distinct_column: - select = [label("count", func.count(func.distinct(distinct_column)))] - - if group_by: - for group in group_by: - select.append(group.expression) - - query = session.query(*select).filter(*filters) - - if group_by: - query = query.group_by(*group_by) - - return query.all() - - def sum_by( - self, - session, - column, - filters=None, - group_by=None, - ): - """SQL sum_by interface to retrieve aggregate sum of column values for given - filters. - - Args: - session: A database session object. - column (sqlalchemy.Column): Required; the column to sum by. - filters (list): Required; a list of filters to apply to the query - group_by (list): A list of optional group by expressions. - - Returns: - A scalar value representing the sum or None. - - Raises: - ValueError: Returns a value error, when no filters are provided - """ - - # Prevent table scans - if filters is None: - raise ValueError("Full table scans are prohibited. Please provide filters") - - query = session.query(func.sum(column)).filter(*filters) - - if group_by: - query = query.group_by(*group_by) - - return query.scalar() - - def one(self, session, filters=None, join_list=None, for_update=False): - """SQL filtering interface to retrieve the single model instance matching - filter criteria. - - If there are more than one instances, an exception is raised. - - Args: - session: A database session object. - filters: A list of filters on the models. - for_update: Locks the table for update. - - Returns: - A model instance: If one row is found in the db. - None: If no result is found. - """ - row = self.get_by(session, filters=filters, join_list=join_list, for_update=for_update) - - if not row: - return None - - return one(row) - - def update(self, session, model_id, **kwargs): - """SQL update interface to modify data in a given model instance. - - Args: - session: A database session object. - model_id: The id of the given model to be modified. - kwargs: Any fields defined on the models. - - Returns: - Modified rows. - - Note: - We use session.flush() here to move the changes from the application - to SQL database. However, those changes will be in the pending changes - state. Meaning, it is in the queue to be inserted but yet to be done - so until session.commit() is called, which has been taken care of - in our ``with_db_session`` decorator or ``LogicComponent.begin`` - contextmanager. - """ - modified_row = session.query(self._model).get(model_id) - if modified_row is None: - return None - - for key, value in kwargs.items(): - setattr(modified_row, key, value) - - # Flush out our changes to DB transaction buffer but don't commit it yet. - # This is useful in the case when we want to rollback atomically on multiple - # sql operations in the same transaction which may or may not have - # dependencies. - session.flush() - - return modified_row - - def update_by(self, session, filters=None, **kwargs): - """SQL update_by interface to modify data for a given list of filters. - The filters should be provided so it can narrow down to one row. - - Args: - session: A database session object. - filters: A list of filters on the models. - kwargs: Any fields defined on the models. - - Returns: - Modified row. - - Raises: - sqlalchemy.orm.exc.NoResultFound - when no result is found. - sqlalchemy.orm.exc.MultipleResultsFound - when multiple result is found. - """ - # If no filter is given, just return. Prevent table scan. - if filters is None: - return None - - modified_row = session.query(self._model).filter(*filters).one() - for key, value in kwargs.items(): - setattr(modified_row, key, value) - - # Flush out our changes to DB transaction buffer but don't commit it yet. - # This is useful in the case when we want to rollback atomically on multiple - # sql operations in the same transaction which may or may not have - # dependencies. - session.flush() - - return modified_row - - def delete_one_by(self, session, filters=None, optional=False): - """SQL update_by interface to delete data for a given list of filters. - The filters should be provided so it can narrow down to one row. - - Note: Careful consideration should be had prior to using this function. - Always consider setting rows as inactive instead before choosing to use - this function. - - Args: - session: A database session object. - filters: A list of filters on the models. - optional: Whether deletion is optional; i.e. it's OK for the model not to exist - - Returns: - None. - - Raises: - sqlalchemy.orm.exc.NoResultFound - when no result is found and optional is False. - sqlalchemy.orm.exc.MultipleResultsFound - when multiple result is found. - """ - # If no filter is given, just return. Prevent table scan. - if filters is None: - return None - - if optional: - rows = session.query(self._model).filter(*filters).all() - - if not rows: - return None - - row = one(rows) - - else: - row = session.query(self._model).filter(*filters).one() - - session.delete(row) - - # Flush out our changes to DB transaction buffer but don't commit it yet. - # This is useful in the case when we want to rollback atomically on multiple - # sql operations in the same transaction which may or may not have - # dependencies. - session.flush() - - def delete_by_id(self, session, model_id): - return session.query(self._model).get(model_id).delete() - - def add(self, session, **kwargs): - """SQL add interface to insert data to the given model. - - Args: - session: A database session object. - kwargs: Any fields defined on the models. - - Returns: - Newly inserted rows. - - Note: - We use session.flush() here to move the changes from the application - to SQL database. However, those changes will be in the pending changes - state. Meaning, it is in the queue to be inserted but yet to be done - so until session.commit() is called, which has been taken care of - in our ``with_db_session`` decorator or ``LogicComponent.begin`` - contextmanager. - """ - new_row = self._model(**kwargs) - session.add(new_row) - - # Flush out our changes to DB transaction buffer but don't commit it yet. - # This is useful in the case when we want to rollback atomically on multiple - # sql operations in the same transaction which may or may not have - # dependencies. - session.flush() - - return new_row - - def exist(self, session, filters=None): - """SQL exist interface to check if any row exists at all for the given - filters. - - Args: - session: A database session object. - filters: A list of filters on the models. - - Returns: - A boolean. True if any row exists else False. - """ - exist_query = exists() - - for query_filter in filters: - exist_query = exist_query.where(query_filter) - - return session.query(exist_query).scalar() - - def yield_by_chunk(self, session, chunk_size, join_list=None, filters=None): - """This yields a batch of the model objects for the given chunk_size. - - Args: - session: A database session object. - chunk_size (int): The size of the chunk. - filters: A list of filters on the model. - join_list: A list of attributes to be joined in the same session for - given model. This is normally the attributes that have - relationship defined and referenced to other models. - - Returns: - A batch for the given chunk size. - """ - query = session.query(self._model) - - if filters is not None: - query = query.filter(*filters) - - if join_list: - for to_join in join_list: - query = query.options(joinedload(to_join)) - - start = 0 - - while True: - stop = start + chunk_size - model_objs = query.slice(start, stop).all() - if not model_objs: - break - - yield model_objs - - start += chunk_size From d660a7fa136082c1649520a3ff94c425e296acab Mon Sep 17 00:00:00 2001 From: Joe Black Date: Thu, 30 Mar 2023 19:16:56 -0400 Subject: [PATCH 3/9] fix circular dependencies --- src/quart_sqlalchemy/sim/app.py | 8 ++--- src/quart_sqlalchemy/sim/handle.py | 15 +++++----- src/quart_sqlalchemy/sim/logic.py | 38 ++++++++---------------- src/quart_sqlalchemy/sim/main.py | 8 ++--- src/quart_sqlalchemy/sim/model.py | 6 ++-- src/quart_sqlalchemy/sim/repo.py | 4 +-- src/quart_sqlalchemy/sim/repo_adapter.py | 2 +- src/quart_sqlalchemy/sim/schema.py | 2 +- src/quart_sqlalchemy/sqla.py | 8 +++++ 9 files changed, 41 insertions(+), 50 deletions(-) diff --git a/src/quart_sqlalchemy/sim/app.py b/src/quart_sqlalchemy/sim/app.py index 91b0ec8..e462fa0 100644 --- a/src/quart_sqlalchemy/sim/app.py +++ b/src/quart_sqlalchemy/sim/app.py @@ -12,10 +12,10 @@ from quart import Response from quart_schema import QuartSchema -from .. import Base -from .. import SQLAlchemyConfig -from ..framework import QuartSQLAlchemy -from .util import ObjectID +from quart_sqlalchemy import Base +from quart_sqlalchemy import SQLAlchemyConfig +from quart_sqlalchemy.framework import QuartSQLAlchemy +from quart_sqlalchemy.sim.util import ObjectID AUTHORIZATION_PATTERN = re.compile(r"Bearer (?P.+)") diff --git a/src/quart_sqlalchemy/sim/handle.py b/src/quart_sqlalchemy/sim/handle.py index 6ed473c..c81b71a 100644 --- a/src/quart_sqlalchemy/sim/handle.py +++ b/src/quart_sqlalchemy/sim/handle.py @@ -5,14 +5,13 @@ from sqlalchemy.orm import Session from quart_sqlalchemy import Bind - -from . import signals -from .logic import LogicComponent as Logic -from .model import AuthUser -from .model import AuthWallet -from .model import EntityType -from .model import WalletType -from .util import ObjectID +from quart_sqlalchemy.sim import signals +from quart_sqlalchemy.sim.logic import LogicComponent as Logic +from quart_sqlalchemy.sim.model import AuthUser +from quart_sqlalchemy.sim.model import AuthWallet +from quart_sqlalchemy.sim.model import EntityType +from quart_sqlalchemy.sim.model import WalletType +from quart_sqlalchemy.sim.util import ObjectID logger = logging.getLogger(__name__) diff --git a/src/quart_sqlalchemy/sim/logic.py b/src/quart_sqlalchemy/sim/logic.py index dee3990..a13397d 100644 --- a/src/quart_sqlalchemy/sim/logic.py +++ b/src/quart_sqlalchemy/sim/logic.py @@ -3,36 +3,22 @@ from datetime import datetime from functools import wraps -from pydantic import BaseModel -from pydantic import Field from sqlalchemy import or_ -from sqlalchemy import ScalarResult from sqlalchemy.orm import contains_eager -from sqlalchemy.orm import DeclarativeBase -from sqlalchemy.orm import selectinload from sqlalchemy.orm import Session from sqlalchemy.sql.expression import func -from sqlalchemy.sql.expression import label - -from quart_sqlalchemy.model import Base -from quart_sqlalchemy.types import ColumnExpr -from quart_sqlalchemy.types import EntityIdT -from quart_sqlalchemy.types import EntityT -from quart_sqlalchemy.types import ORMOption -from quart_sqlalchemy.types import Selectable - -from . import signals -from .model import AuthUser as auth_user_model -from .model import AuthWallet as auth_wallet_model -from .model import ConnectInteropStatus -from .model import EntityType -from .model import MagicClient as magic_client_model -from .model import Provenance -from .model import WalletType -from .repo import SQLAlchemyRepository -from .repo_adapter import RepositoryLegacyAdapter -from .util import ObjectID -from .util import one + +from quart_sqlalchemy.sim import signals +from quart_sqlalchemy.sim.model import AuthUser as auth_user_model +from quart_sqlalchemy.sim.model import AuthWallet as auth_wallet_model +from quart_sqlalchemy.sim.model import ConnectInteropStatus +from quart_sqlalchemy.sim.model import EntityType +from quart_sqlalchemy.sim.model import MagicClient as magic_client_model +from quart_sqlalchemy.sim.model import Provenance +from quart_sqlalchemy.sim.model import WalletType +from quart_sqlalchemy.sim.repo_adapter import RepositoryLegacyAdapter +from quart_sqlalchemy.sim.util import ObjectID +from quart_sqlalchemy.sim.util import one logger = logging.getLogger(__name__) diff --git a/src/quart_sqlalchemy/sim/main.py b/src/quart_sqlalchemy/sim/main.py index 7bd73cd..2370b4f 100644 --- a/src/quart_sqlalchemy/sim/main.py +++ b/src/quart_sqlalchemy/sim/main.py @@ -1,9 +1,9 @@ -from .app import app -from .views import api +from quart_sqlalchemy.sim.app import app +from quart_sqlalchemy.sim.views import api -app.register_blueprint(api) +app.register_blueprint(api, url_prefix="/v1") if __name__ == "__main__": - app.run(port=8080) + app.run(port=8081) diff --git a/src/quart_sqlalchemy/sim/model.py b/src/quart_sqlalchemy/sim/model.py index efee04c..7a6566b 100644 --- a/src/quart_sqlalchemy/sim/model.py +++ b/src/quart_sqlalchemy/sim/model.py @@ -81,8 +81,8 @@ class MagicClient(db.Model, SoftDeleteMixin, TimestampMixin): is_signing_modal_enabled: Mapped[bool] = sa.orm.mapped_column(default=False) global_audience_enabled: Mapped[bool] = sa.orm.mapped_column(default=False) - public_api_key: Mapped[str] = sa.orm.mapped_column(default_factory=secrets.token_hex) - secret_api_key: Mapped[str] = sa.orm.mapped_column(default_factory=secrets.token_hex) + public_api_key: Mapped[str] = sa.orm.mapped_column(default=secrets.token_hex) + secret_api_key: Mapped[str] = sa.orm.mapped_column(default=secrets.token_hex) auth_users: Mapped[t.List["AuthUser"]] = sa.orm.relationship( back_populates="magic_client", @@ -144,7 +144,7 @@ def is_magic_connect_user(self): class AuthWallet(db.Model, SoftDeleteMixin, TimestampMixin): - __tablename__ = "auth_user" + __tablename__ = "auth_wallet" id: Mapped[ObjectID] = sa.orm.mapped_column(primary_key=True, autoincrement=True) auth_user_id: Mapped[ObjectID] = sa.orm.mapped_column(sa.ForeignKey("auth_user.id")) diff --git a/src/quart_sqlalchemy/sim/repo.py b/src/quart_sqlalchemy/sim/repo.py index d257b5f..9e41c70 100644 --- a/src/quart_sqlalchemy/sim/repo.py +++ b/src/quart_sqlalchemy/sim/repo.py @@ -10,6 +10,7 @@ import sqlalchemy.orm import sqlalchemy.sql +from quart_sqlalchemy.sim.builder import StatementBuilder from quart_sqlalchemy.types import ColumnExpr from quart_sqlalchemy.types import EntityIdT from quart_sqlalchemy.types import EntityT @@ -17,8 +18,6 @@ from quart_sqlalchemy.types import Selectable from quart_sqlalchemy.types import SessionT -from .builder import StatementBuilder - sa = sqlalchemy @@ -56,7 +55,6 @@ def identity(self) -> t.Type[EntityIdT]: return self.__orig_class__.__args__[1] # type: ignore - class SQLAlchemyRepository( AbstractRepository[EntityT, EntityIdT], t.Generic[EntityT, EntityIdT], diff --git a/src/quart_sqlalchemy/sim/repo_adapter.py b/src/quart_sqlalchemy/sim/repo_adapter.py index bbd96b0..89230b0 100644 --- a/src/quart_sqlalchemy/sim/repo_adapter.py +++ b/src/quart_sqlalchemy/sim/repo_adapter.py @@ -12,7 +12,7 @@ from quart_sqlalchemy.types import ORMOption from quart_sqlalchemy.types import Selectable -from .repo import SQLAlchemyRepository +from quart_sqlalchemy.sim.repo import SQLAlchemyRepository class BaseModelSchema(BaseModel): diff --git a/src/quart_sqlalchemy/sim/schema.py b/src/quart_sqlalchemy/sim/schema.py index f47225a..f2311ca 100644 --- a/src/quart_sqlalchemy/sim/schema.py +++ b/src/quart_sqlalchemy/sim/schema.py @@ -2,7 +2,7 @@ from pydantic import BaseModel -from .util import ObjectID +from quart_sqlalchemy.sim.util import ObjectID class BaseSchema(BaseModel): diff --git a/src/quart_sqlalchemy/sqla.py b/src/quart_sqlalchemy/sqla.py index 1df492a..56b8b60 100644 --- a/src/quart_sqlalchemy/sqla.py +++ b/src/quart_sqlalchemy/sqla.py @@ -38,6 +38,14 @@ def initialize(self): class Model(self.config.model_class, sa.orm.DeclarativeBase): pass + type_annotation_map = {} + for base_class in Model.__mro__[::-1]: + if base_class is Model: + continue + base_map = getattr(base_class, "type_annotation_map", {}).copy() + type_annotation_map.update(base_map) + + Model.registry.type_annotation_map.update(type_annotation_map) self.Model = Model self.binds = {} From b68fe9d0d3cd2532fceec04eec22eb2d3661b38b Mon Sep 17 00:00:00 2001 From: Joe Black Date: Fri, 31 Mar 2023 19:03:38 -0400 Subject: [PATCH 4/9] update sim --- src/quart_sqlalchemy/retry.py | 1 + src/quart_sqlalchemy/sim/app.py | 167 +++------ src/quart_sqlalchemy/sim/auth.py | 323 ++++++++++++++++++ src/quart_sqlalchemy/sim/db.py | 100 ++++++ src/quart_sqlalchemy/sim/handle.py | 182 +++++----- src/quart_sqlalchemy/sim/logic.py | 136 ++++---- src/quart_sqlalchemy/sim/main.py | 5 +- src/quart_sqlalchemy/sim/model.py | 12 +- src/quart_sqlalchemy/sim/repo.py | 78 +++-- src/quart_sqlalchemy/sim/repo_adapter.py | 82 +++-- src/quart_sqlalchemy/sim/schema.py | 31 +- src/quart_sqlalchemy/sim/testing.py | 13 + src/quart_sqlalchemy/sim/views/__init__.py | 4 +- src/quart_sqlalchemy/sim/views/auth_wallet.py | 53 ++- src/quart_sqlalchemy/sim/views/decorator.py | 131 ------- .../sim/views/util/__init__.py | 12 + .../sim/views/util/blueprint.py | 102 ++++++ .../sim/views/util/decorator.py | 120 +++++++ 18 files changed, 1066 insertions(+), 486 deletions(-) create mode 100644 src/quart_sqlalchemy/sim/auth.py create mode 100644 src/quart_sqlalchemy/sim/db.py create mode 100644 src/quart_sqlalchemy/sim/testing.py delete mode 100644 src/quart_sqlalchemy/sim/views/decorator.py create mode 100644 src/quart_sqlalchemy/sim/views/util/__init__.py create mode 100644 src/quart_sqlalchemy/sim/views/util/blueprint.py create mode 100644 src/quart_sqlalchemy/sim/views/util/decorator.py diff --git a/src/quart_sqlalchemy/retry.py b/src/quart_sqlalchemy/retry.py index 664dc66..8353dc3 100644 --- a/src/quart_sqlalchemy/retry.py +++ b/src/quart_sqlalchemy/retry.py @@ -99,6 +99,7 @@ async def add_user_post(db, user_id, post_values): import sqlalchemy.exc import sqlalchemy.orm import tenacity +from tenacity import RetryError sa = sqlalchemy diff --git a/src/quart_sqlalchemy/sim/app.py b/src/quart_sqlalchemy/sim/app.py index e462fa0..9f12d22 100644 --- a/src/quart_sqlalchemy/sim/app.py +++ b/src/quart_sqlalchemy/sim/app.py @@ -1,145 +1,88 @@ -import json import logging -import re import typing as t +from copy import deepcopy +from functools import wraps -import sqlalchemy as sa -from pydantic import BaseModel from quart import g from quart import Quart from quart import request -from quart import Request from quart import Response +from quart.typing import ResponseReturnValue +from quart_schema import APIKeySecurityScheme +from quart_schema import HttpSecurityScheme from quart_schema import QuartSchema +from werkzeug.utils import import_string -from quart_sqlalchemy import Base -from quart_sqlalchemy import SQLAlchemyConfig -from quart_sqlalchemy.framework import QuartSQLAlchemy -from quart_sqlalchemy.sim.util import ObjectID - -AUTHORIZATION_PATTERN = re.compile(r"Bearer (?P.+)") logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -class MyBase(Base): - type_annotation_map = {ObjectID: sa.Integer} - - -app = Quart(__name__) -db = QuartSQLAlchemy( - SQLAlchemyConfig.parse_obj( - { - "model_class": MyBase, - "binds": { - "default": { - "engine": {"url": "sqlite:///file:mem.db?mode=memory&cache=shared&uri=true"}, - "session": {"expire_on_commit": False}, - }, - "read-replica": { - "engine": {"url": "sqlite:///file:mem.db?mode=memory&cache=shared&uri=true"}, - "session": {"expire_on_commit": False}, - "read_only": True, - }, - "async": { - "engine": { - "url": "sqlite+aiosqlite:///file:mem.db?mode=memory&cache=shared&uri=true" - }, - "session": {"expire_on_commit": False}, - }, - }, - } - ) +BLUEPRINTS = ("quart_sqlalchemy.sim.views.api",) +EXTENSIONS = ( + "quart_sqlalchemy.sim.db.db", + "quart_sqlalchemy.sim.app.schema", + "quart_sqlalchemy.sim.auth.auth", ) -openapi = QuartSchema(app) - - -class RequestAuth(BaseModel): - client: t.Optional[t.Any] = None - user: t.Optional[t.Any] = None - - @property - def has_client(self): - return self.client is not None - - @property - def has_user(self): - return self.user is not None - - @property - def is_anonymous(self): - return all([self.has_client is False, self.has_user is False]) - -def get_request_client(request: Request): - api_key = request.headers.get("X-Public-API-Key") - if not api_key: - return +DEFAULT_CONFIG = { + "QUART_AUTH_SECURITY_SCHEMES": { + "public-api-key": APIKeySecurityScheme(in_="header", name="X-Public-API-Key"), + "session-token-bearer": HttpSecurityScheme(scheme="bearer", bearer_format="opaque"), + }, + "REGISTER_BLUEPRINTS": ["quart_sqlalchemy.sim.views.api"], +} - with g.bind.Session() as session: - try: - magic_client = g.h.MagicClient(session).get_by_public_api_key(api_key) - except ValueError: - return - else: - return magic_client +schema = QuartSchema(security_schemes=DEFAULT_CONFIG["QUART_AUTH_SECURITY_SCHEMES"]) -def get_request_user(request: Request): - auth_header = request.headers.get("Authorization") - if not auth_header: - return - m = AUTHORIZATION_PATTERN.match(auth_header) - if m is None: - raise RuntimeError("invalid authorization header") +def wrap_response(func: t.Callable) -> t.Callable: + @wraps(func) + async def decorator(result: ResponseReturnValue) -> Response: + # import pdb - auth_token = m.group("auth_token") + # pdb.set_trace() + return await func(result) - with g.bind.Session() as session: - try: - auth_user = g.h.AuthUser(session).get_by_session_token(auth_token) - except ValueError: - return - else: - return auth_user + return decorator -@app.before_request -def set_ethereum_network(): - g.request_network = request.headers.get("X-Fortmatic-Network", "GOERLI").upper() +def create_app( + override_config: t.Optional[t.Dict[str, t.Any]] = None, + extensions: t.Sequence[str] = EXTENSIONS, + blueprints: t.Sequence[str] = BLUEPRINTS, +): + override_config = override_config or {} + config = deepcopy(DEFAULT_CONFIG) + config.update(override_config) -@app.before_request -def set_bind_handlers_for_request(): - from quart_sqlalchemy.sim.handle import Handlers + app = Quart(__name__) + app.config.from_mapping(config) - g.db = db + for path in extensions: + extension = import_string(path) + extension.init_app(app) - method = request.method - if method in ["GET", "OPTIONS", "TRACE", "HEAD"]: - bind = "read-replica" - else: - bind = "default" + for path in blueprints: + bp = import_string(path) + app.register_blueprint(bp) - g.bind = db.get_bind(bind) - g.h = Handlers(g.bind) + @app.before_request + def set_ethereum_network(): + g.network = request.headers.get("X-Ethereum-Network", "GOERLI").upper() + # app.make_response = wrap_response(app.make_response) -@app.before_request -def set_request_auth(): - g.auth = RequestAuth( - client=get_request_client(request), - user=get_request_user(request), - ) + return app -@app.after_request -async def add_json_response_envelope(response: Response) -> Response: - if response.mimetype != "application/json": - return response - data = await response.get_json() - payload = dict(status="ok", message="", data=data) - response.set_data(json.dumps(payload)) - return response +# @app.after_request +# async def add_json_response_envelope(response: Response) -> Response: +# if response.mimetype != "application/json": +# return response +# data = await response.get_json() +# payload = dict(status="ok", message="", data=data) +# response.set_data(json.dumps(payload)) +# return response diff --git a/src/quart_sqlalchemy/sim/auth.py b/src/quart_sqlalchemy/sim/auth.py new file mode 100644 index 0000000..51c3962 --- /dev/null +++ b/src/quart_sqlalchemy/sim/auth.py @@ -0,0 +1,323 @@ +import logging +import re +import secrets +import typing as t + +import click +import sqlalchemy +import sqlalchemy.orm +import sqlalchemy.orm.exc +from quart import current_app +from quart import g +from quart import Quart +from quart import request +from quart import Request +from quart.cli import AppGroup +from quart.cli import pass_script_info +from quart.cli import ScriptInfo +from quart_schema.extension import QUART_SCHEMA_SECURITY_ATTRIBUTE +from quart_schema.extension import security_scheme +from quart_schema.openapi import APIKeySecurityScheme +from quart_schema.openapi import HttpSecurityScheme +from quart_schema.openapi import SecuritySchemeBase +from sqlalchemy.orm import Session +from werkzeug.exceptions import Forbidden + +from .model import AuthUser +from .model import MagicClient +from .schema import BaseSchema +from .util import ObjectID + + +sa = sqlalchemy + +cli = AppGroup("auth") + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def authorized_request(security_schemes: t.Sequence[t.Dict[str, t.List[t.Any]]]): + def decorator(func): + return security_scheme(security_schemes)(func) + + return decorator + + +class MyRequest(Request): + @property + def ip_addr(self): + return self.remote_addr + + @property + def locale(self): + return self.accept_languages.best_match(["en"]) or "en" + + @property + def redirect_url(self): + return self.args.get("redirect_url") or self.headers.get("x-redirect-url") + + +class ValidatorError(RuntimeError): + pass + + +class SubjectNotFound(ValidatorError): + pass + + +class CredentialNotFound(ValidatorError): + pass + + +class Credential(BaseSchema): + scheme: SecuritySchemeBase + value: t.Optional[str] = None + subject: t.Union[MagicClient, AuthUser] + + +class AuthenticationValidator: + name: str + scheme: SecuritySchemeBase + + def extract(self, request: Request) -> str: + ... + + def lookup(self, value: str, session: Session) -> t.Any: + ... + + def authenticate(self, request: Request) -> Credential: + ... + + +class PublicAPIKeyValidator(AuthenticationValidator): + name = "public-api-key" + scheme = APIKeySecurityScheme(in_="header", name="X-Public-API-Key") + + def extract(self, request: Request) -> str: + if self.scheme.in_ == "header": + return request.headers.get(self.scheme.name, None) + elif self.scheme.in_ == "cookie": + return request.cookies.get(self.scheme.name, None) + elif self.scheme.in_ == "query": + return request.args.get(self.scheme.name, None) + else: + raise ValueError(f"No token found for {self.scheme}") + + def lookup(self, value: str, session: Session) -> t.Any: + statement = sa.select(MagicClient).where(MagicClient.public_api_key == value).limit(1) + + try: + result = session.scalars(statement).one() + except sa.orm.exc.NoResultFound: + raise SubjectNotFound(f"No MagicClient found for public_api_key {value}") + + return result + + def authenticate(self, request: Request, session: Session) -> Credential: + value = self.extract(request) + if value is None: + raise CredentialNotFound() + subject = self.lookup(value, session) + return Credential(scheme=self.scheme, value=value, subject=subject) + + +class SessionTokenValidator(AuthenticationValidator): + name = "session-token-bearer" + scheme = HttpSecurityScheme(scheme="bearer", bearer_format="opaque") + + AUTHORIZATION_PATTERN = re.compile(r"Bearer (?P.+)") + + def extract(self, request: Request) -> str: + if self.scheme.scheme != "bearer": + return + + value = request.headers.get("authorization") + m = self.AUTHORIZATION_PATTERN.match(value) + if m is None: + raise ValueError("Bearer token failed validation") + + return m.group("token") + + def lookup(self, value: str, session: Session) -> t.Any: + statement = sa.select(AuthUser).where(AuthUser.current_session_token == value).limit(1) + + try: + result = session.scalars(statement).one() + except sa.orm.exc.NoResultFound: + raise SubjectNotFound(f"No AuthUser found for session_token {value}") + + return result + + def authenticate(self, request: Request, session: Session) -> Credential: + value = self.extract(request) + if value is None: + raise CredentialNotFound() + subject = self.lookup(value, session) + return Credential(scheme=self.scheme, value=value, subject=subject) + + +class RequestAuthenticator: + validators = [PublicAPIKeyValidator(), SessionTokenValidator()] + validator_scheme_map = {v.name: v for v in validators} + + def enforce(self, security_schemes: t.Sequence[t.Dict[str, t.List[t.Any]]], session: Session): + passed, failed = [], [] + for scheme_credential in self.validate_security(security_schemes, session): + if all(scheme_credential.values()): + passed.append(scheme_credential) + else: + failed.append(scheme_credential) + if passed: + return passed + raise Forbidden() + + def validate_security( + self, security_schemes: t.Sequence[t.Dict[str, t.List[t.Any]]], session: Session + ): + if not security_schemes: + return + + for scheme in security_schemes: + scheme_credentials = {} + for name, _ in scheme.items(): + validator = self.validator_scheme_map[name] + credential = None + try: + credential = validator.authenticate(request, session) + except ValidatorError: + pass + except: + logger.exception(f"Unknown error while validating {name}") + raise + finally: + scheme_credentials[name] = credential + yield scheme_credentials + + +# def convert_model_result(func: t.Callable) -> t.Callable: +# @wraps(func) +# async def decorator(result: ResponseReturnValue) -> Response: +# status_or_headers = None +# headers = None +# if isinstance(result, tuple): +# value, status_or_headers, headers = result + (None,) * (3 - len(result)) +# else: +# value = result + +# was_model = False +# if is_dataclass(value): +# dict_or_value = asdict(value) +# was_model = True +# elif isinstance(value, BaseModel): +# dict_or_value = value.dict(by_alias=True) +# was_model = True +# else: +# dict_or_value = value + +# if was_model: +# dict_or_value = camelize(dict_or_value) + +# return await func((dict_or_value, status_or_headers, headers)) + +# return decorator + + +class QuartAuth: + authenticator = RequestAuthenticator() + + def __init__(self, app: t.Optional[Quart] = None): + if app is not None: + self.init_app(app) + + def init_app(self, app: Quart): + app.before_request(self.auth_endpoint_security) + + app.request_class = MyRequest + + self.security_schemes = app.config.get("QUART_AUTH_SECURITY_SCHEMES", {}) + app.cli.add_command(cli) + + def auth_endpoint_security(self): + db = current_app.extensions.get("sqlalchemy") + view_function = current_app.view_functions[request.endpoint] + security_schemes = getattr(view_function, QUART_SCHEMA_SECURITY_ATTRIBUTE, None) + if security_schemes is None: + g.authorized_credentials = {} + + with db.bind.Session() as session: + results = self.authenticator.enforce(security_schemes, session) + authorized_credentials = {} + for result in results: + authorized_credentials.update(result) + g.authorized_credentials = authorized_credentials + + +from .model import EntityType +from .model import MagicClient +from .model import Provenance + + +@cli.command("add-user") +@click.option( + "--email", + type=str, + default="default@none.com", + help="email", +) +@click.option( + "--user-type", + # type=click.Choice(list(EntityType.__members__)), + type=click.Choice(["FORTMATIC", "MAGIC", "CONNECT"]), + default="MAGIC", + help="user type", +) +@click.option( + "--client-id", + type=int, + required=True, + help="client id", +) +@pass_script_info +def add_user(info: ScriptInfo, email: str, user_type: str, client_id: int) -> None: + app = info.load_app() + db = app.extensions.get("sqlalchemy") + + with db.bind.Session() as s: + with s.begin(): + user = AuthUser( + email=email, + user_type=EntityType[user_type].value, + client_id=ObjectID(client_id), + provenance=Provenance.LINK, + current_session_token=secrets.token_hex(16), + ) + s.add(user) + s.flush() + s.refresh(user) + + click.echo(f"Created user {user.id} with session_token: {user.current_session_token}") + + +@cli.command("add-client") +@click.option( + "--name", + type=str, + default="My App", + help="app name", +) +@pass_script_info +def add_client(info: ScriptInfo, name: str) -> None: + app = info.load_app() + db = app.extensions.get("sqlalchemy") + with db.bind.Session() as s: + with s.begin(): + client = MagicClient(app_name=name, public_api_key=secrets.token_hex(16)) + s.add(client) + s.flush() + s.refresh(client) + + click.echo(f"Created client {client.id} with public_api_key: {client.public_api_key}") + + +auth = QuartAuth() diff --git a/src/quart_sqlalchemy/sim/db.py b/src/quart_sqlalchemy/sim/db.py new file mode 100644 index 0000000..e7677f6 --- /dev/null +++ b/src/quart_sqlalchemy/sim/db.py @@ -0,0 +1,100 @@ +import click +import sqlalchemy as sa +from quart import g +from quart import request +from quart.cli import AppGroup +from quart.cli import pass_script_info +from quart.cli import ScriptInfo +from sqlalchemy.types import Integer +from sqlalchemy.types import TypeDecorator + +from quart_sqlalchemy import Base +from quart_sqlalchemy import SQLAlchemyConfig +from quart_sqlalchemy.framework import QuartSQLAlchemy +from quart_sqlalchemy.sim.util import ObjectID + + +cli = AppGroup("db-schema") + + +class ObjectIDType(TypeDecorator): + """A custom database column type that converts integer value to our ObjectID. + This allows us to pass around ObjectID type in the application for easy + frontend encoding and database decoding on the integer value. + + Note: all id db column type should use this type for its column. + """ + + impl = Integer + cache_ok = False + + def process_bind_param(self, value, dialect): + """Data going into to the database will be transformed by this method. + See ``ObjectID`` for the design and rational for this. + """ + if value is None: + return None + + return ObjectID(value).decode() + + def process_result_value(self, value, dialect): + """Data going out from the database will be explicitly casted to the + ``ObjectID``. + """ + if value is None: + return None + + return ObjectID(value) + + +class MyBase(Base): + type_annotation_map = {ObjectID: ObjectIDType} + + +class MyQuartSQLAlchemy(QuartSQLAlchemy): + def init_app(self, app): + super().init_app(app) + + @app.before_request + def set_bind(): + if request.method in ["GET", "OPTIONS", "HEAD", "TRACE"]: + g.bind = self.get_bind("read-replica") + else: + g.bind = self.get_bind("default") + + app.cli.add_command(cli) + + +@cli.command("load") +@pass_script_info +def schema_load(info: ScriptInfo) -> None: + app = info.load_app() + db = app.extensions.get("sqlalchemy") + db.create_all() + + click.echo(f"Initialized database schema for {db}") + + +# sqlite:///file:mem.db?mode=memory&cache=shared&uri=true +db = MyQuartSQLAlchemy( + SQLAlchemyConfig.parse_obj( + { + "model_class": MyBase, + "binds": { + "default": { + "engine": {"url": "sqlite:///file:sim.db?cache=shared&uri=true"}, + "session": {"expire_on_commit": False}, + }, + "read-replica": { + "engine": {"url": "sqlite:///file:sim.db?cache=shared&uri=true"}, + "session": {"expire_on_commit": False}, + "read_only": True, + }, + "async": { + "engine": {"url": "sqlite+aiosqlite:///file:sim.db?cache=shared&uri=true"}, + "session": {"expire_on_commit": False}, + }, + }, + } + ) +) diff --git a/src/quart_sqlalchemy/sim/handle.py b/src/quart_sqlalchemy/sim/handle.py index c81b71a..e079da4 100644 --- a/src/quart_sqlalchemy/sim/handle.py +++ b/src/quart_sqlalchemy/sim/handle.py @@ -4,8 +4,6 @@ from sqlalchemy.orm import Session -from quart_sqlalchemy import Bind -from quart_sqlalchemy.sim import signals from quart_sqlalchemy.sim.logic import LogicComponent as Logic from quart_sqlalchemy.sim.model import AuthUser from quart_sqlalchemy.sim.model import AuthWallet @@ -69,6 +67,7 @@ def add( A ``MagicClient``. """ magic_clients_count = self.logic.MagicClientAPIUser.count_by_magic_api_user_id( + self.session, magic_api_user_id, ) @@ -83,7 +82,7 @@ def add( ) def get_by_public_api_key(self, public_api_key): - return self.logic.MagicClientAPIKey.get_by_public_api_key(public_api_key) + return self.logic.MagicClientAPIKey.get_by_public_api_key(self.session, public_api_key) def add_client( self, @@ -94,40 +93,43 @@ def add_client( ): live_api_key = APIKeySet(public_key="xxx", secret_key="yyy") - with self.logic.begin(ro=False) as session: - magic_client = self.logic.MagicClient._add( - session, - app_name=app_name, - ) - # self.logic.MagicClientAPIKey._add( - # session, - # magic_client.id, - # live_api_key_pair=live_api_key, - # ) - # self.logic.MagicClientAPIUser._add( - # session, - # magic_api_user_id, - # magic_client.id, - # ) - - # self.logic.MagicClientAuthMethods._add( - # session, - # magic_client_id=magic_client.id, - # is_magic_connect_enabled=is_magic_connect_enabled, - # is_metamask_wallet_enabled=(True if is_magic_connect_enabled else False), - # is_wallet_connect_enabled=(True if is_magic_connect_enabled else False), - # is_coinbase_wallet_enabled=(True if is_magic_connect_enabled else False), - # ) - - # self.logic.MagicClientTeam._add(session, magic_client.id, magic_team_id) - - return magic_client, live_api_key + # with self.logic.begin(ro=False) as session: + return self.logic.MagicClient._add( + self.session, + app_name=app_name, + ) + + # self.logic.MagicClientAPIKey._add( + # session, + # magic_client.id, + # live_api_key_pair=live_api_key, + # ) + # self.logic.MagicClientAPIUser._add( + # session, + # magic_api_user_id, + # magic_client.id, + # ) + + # self.logic.MagicClientAuthMethods._add( + # session, + # magic_client_id=magic_client.id, + # is_magic_connect_enabled=is_magic_connect_enabled, + # is_metamask_wallet_enabled=(True if is_magic_connect_enabled else False), + # is_wallet_connect_enabled=(True if is_magic_connect_enabled else False), + # is_coinbase_wallet_enabled=(True if is_magic_connect_enabled else False), + # ) + + # self.logic.MagicClientTeam._add(session, magic_client.id, magic_team_id) + + # return magic_client, live_api_key def get_magic_api_user_id_by_client_id(self, magic_client_id): - return self.logic.MagicClient.get_magic_api_user_id_by_client_id(magic_client_id) + return self.logic.MagicClient.get_magic_api_user_id_by_client_id( + self.session, magic_client_id + ) def get_by_id(self, magic_client_id): - return self.logic.MagicClient.get_by_id(magic_client_id) + return self.logic.MagicClient.get_by_id(self.session, magic_client_id) def update_app_name_by_id(self, magic_client_id, app_name): """ @@ -139,7 +141,9 @@ def update_app_name_by_id(self, magic_client_id, app_name): None if `magic_client_id` doesn't exist in the db app_name if update was successful """ - client = self.logic.MagicClient.update_by_id(magic_client_id, app_name=app_name) + client = self.logic.MagicClient.update_by_id( + self.session, magic_client_id, app_name=app_name + ) if not client: return None @@ -147,7 +151,7 @@ def update_app_name_by_id(self, magic_client_id, app_name): return client.app_name def update_by_id(self, magic_client_id, **kwargs): - client = self.logic.MagicClient.update_by_id(magic_client_id, **kwargs) + client = self.logic.MagicClient.update_by_id(self.session, magic_client_id, **kwargs) return client @@ -159,7 +163,7 @@ def set_inactive_by_id(self, magic_client_id): Returns: None """ - self.logic.MagicClient.update_by_id(magic_client_id, is_active=False) + self.logic.MagicClient.update_by_id(self.session, magic_client_id, is_active=False) def get_users_for_client( self, @@ -171,7 +175,7 @@ def get_users_for_client( """ Returns emails and signup timestamps for all auth users belonging to a given client """ - auth_user_handler = AuthUserHandler() + auth_user_handler = AuthUserHandler(session=self.session) product_type = get_product_type_by_client_id(magic_client_id) auth_users = auth_user_handler.get_by_client_id_and_user_type( magic_client_id, @@ -214,7 +218,7 @@ def get_users_for_client_v2( Returns emails, signup timestamps, provenance and MFA enablement for all auth users belonging to a given client. """ - auth_user_handler = AuthUserHandler() + auth_user_handler = AuthUserHandler(session=self.session) product_type = get_product_type_by_client_id(magic_client_id) auth_users = auth_user_handler.get_by_client_id_and_user_type( magic_client_id, @@ -260,7 +264,7 @@ def __init__(self, *args, auth_user_mfa_handler=None, **kwargs): # self.auth_user_mfa_handler = auth_user_mfa_handler or AuthUserMfaHandler() def get_by_session_token(self, session_token): - return self.logic.AuthUser.get_by_session_token(session_token) + return self.logic.AuthUser.get_by_session_token(self.session, session_token) def get_or_create_by_email_and_client_id( self, @@ -269,6 +273,7 @@ def get_or_create_by_email_and_client_id( user_type=EntityType.MAGIC.value, ): auth_user = self.logic.AuthUser.get_by_email_and_client_id( + self.session, email, client_id, user_type=user_type, @@ -292,6 +297,7 @@ def get_or_create_by_email_and_client_id( # raise EnhancedEmailValidation(error_message=str(e)) from e auth_user = self.logic.AuthUser.add_by_email_and_client_id( + self.session, client_id, email=email, user_type=user_type, @@ -300,7 +306,7 @@ def get_or_create_by_email_and_client_id( def get_by_id_and_validate_exists(self, auth_user_id): """This function helps formalize how a non-existent auth user should be handled.""" - auth_user = self.logic.AuthUser.get_by_id(auth_user_id) + auth_user = self.logic.AuthUser.get_by_id(self.session, auth_user_id) if auth_user is None: raise RuntimeError('resource_name="auth_user"') return auth_user @@ -314,18 +320,18 @@ def create_verified_user( email, user_type=EntityType.FORTMATIC.value, ): - with self.logic.begin(ro=False) as session: - auid = self.logic.AuthUser._add_by_email_and_client_id( - session, - client_id, - email, - user_type=user_type, - ).id - auth_user = self.logic.AuthUser._update_by_id( - session, - auid, - date_verified=datetime.utcnow(), - ) + # with self.logic.begin(ro=False) as session: + auid = self.logic.AuthUser._add_by_email_and_client_id( + self.session, + client_id, + email, + user_type=user_type, + ).id + auth_user = self.logic.AuthUser._update_by_id( + self.session, + auid, + date_verified=datetime.utcnow(), + ) return auth_user @@ -339,7 +345,7 @@ def create_verified_user( def get_by_id(self, auth_user_id, load_mfa_methods=False) -> AuthUser: # join_list = ["mfa_methods"] if load_mfa_methods else None - return self.logic.AuthUser.get_by_id(auth_user_id) + return self.logic.AuthUser.get_by_id(self.session, auth_user_id) def get_by_client_id_and_user_type( self, @@ -350,12 +356,14 @@ def get_by_client_id_and_user_type( ): if user_type == EntityType.CONNECT.value: return self.logic.AuthUser.get_by_client_id_for_connect( + self.session, client_id, offset=offset, limit=limit, ) else: return self.logic.AuthUser.get_by_client_id_and_user_type( + self.session, client_id, user_type, offset=offset, @@ -370,6 +378,7 @@ def get_by_client_ids_and_user_type( limit=None, ): return self.logic.AuthUser.get_by_client_ids_and_user_type( + self.session, client_ids, user_type, offset=offset, @@ -379,29 +388,33 @@ def get_by_client_ids_and_user_type( def get_user_count_by_client_id_and_user_type(self, client_id, user_type): if user_type == EntityType.CONNECT.value: return self.logic.AuthUser.get_user_count_by_client_id_for_connect( + self.session, client_id, ) else: return self.logic.AuthUser.get_user_count_by_client_id_and_user_type( + self.session, client_id, user_type, ) def exist_by_email_client_id_and_user_type(self, email, client_id, user_type): return self.logic.AuthUser.exist_by_email_and_client_id( + self.session, email, client_id, user_type=user_type, ) def update_email_by_id(self, model_id, email): - return self.logic.AuthUser.update_by_id(model_id, email=email) + return self.logic.AuthUser.update_by_id(self.session, model_id, email=email) def update_phone_number_by_id(self, model_id, phone_number): - return self.logic.AuthUser.update_by_id(model_id, phone_number=phone_number) + return self.logic.AuthUser.update_by_id(self.session, model_id, phone_number=phone_number) def get_by_email_client_id_and_user_type(self, email, client_id, user_type): return self.logic.AuthUser.get_by_email_and_client_id( + self.session, email, client_id, user_type, @@ -409,12 +422,14 @@ def get_by_email_client_id_and_user_type(self, email, client_id, user_type): def mark_date_verified_by_id(self, model_id): return self.logic.AuthUser.update_by_id( + self.session, model_id, date_verified=datetime.utcnow(), ) def set_role_by_email_magic_client_id(self, email, magic_client_id, role): auth_user = self.logic.AuthUser.get_by_email_and_client_id( + self.session, email, magic_client_id, EntityType.MAGIC.value, @@ -422,12 +437,13 @@ def set_role_by_email_magic_client_id(self, email, magic_client_id, role): if not auth_user: auth_user = self.logic.AuthUser.add_by_email_and_client_id( + self.session, magic_client_id, email, user_type=EntityType.MAGIC.value, ) - return self.logic.AuthUser.update_by_id(auth_user.id, **{role: True}) + return self.logic.AuthUser.update_by_id(self.session, auth_user.id, **{role: True}) def search_by_client_id_and_substring( self, @@ -443,6 +459,7 @@ def search_by_client_id_and_substring( raise InvalidSubstringError() auth_users = self.logic.AuthUser.get_by_client_id_with_substring_search( + self.session, client_id, substring, offset=offset, @@ -469,7 +486,7 @@ def is_magic_connect_enabled(self, auth_user_id=None, auth_user=None): return auth_user.user_type == EntityType.CONNECT.value def mark_as_inactive(self, auth_user_id): - self.logic.AuthUser.update_by_id(auth_user_id, is_active=False) + self.logic.AuthUser.update_by_id(self.session, auth_user_id, is_active=False) def get_by_email_and_wallet_type_for_interop(self, email, wallet_type, network): """ @@ -488,22 +505,22 @@ def get_magic_connect_auth_user(self, auth_user_id): return auth_user -@signals.auth_user_duplicate.connect -def handle_duplicate_auth_users( - current_app, - original_auth_user_id, - duplicate_auth_user_ids, - auth_user_handler: t.Optional[AuthUserHandler] = None, -) -> None: - logger.info(f"{len(duplicate_auth_user_ids)} dupe(s) found for {original_auth_user_id}") +# @signals.auth_user_duplicate.connect +# def handle_duplicate_auth_users( +# current_app, +# original_auth_user_id, +# duplicate_auth_user_ids, +# auth_user_handler: t.Optional[AuthUserHandler] = None, +# ) -> None: +# logger.info(f"{len(duplicate_auth_user_ids)} dupe(s) found for {original_auth_user_id}") - auth_user_handler = auth_user_handler or AuthUserHandler() +# auth_user_handler = auth_user_handler or AuthUserHandler() - for dupe_id in duplicate_auth_user_ids: - logger.info( - f"marking auth_user_id {dupe_id} as inactive, in favor of original {original_auth_user_id}", - ) - auth_user_handler.mark_as_inactive(dupe_id) +# for dupe_id in duplicate_auth_user_ids: +# logger.info( +# f"marking auth_user_id {dupe_id} as inactive, in favor of original {original_auth_user_id}", +# ) +# auth_user_handler.mark_as_inactive(dupe_id) class AuthWalletHandler(HandlerBase): @@ -515,10 +532,10 @@ def __init__(self, network, *args, wallet_type=WalletType.ETH, **kwargs): self.wallet_type = wallet_type def get_by_id(self, model_id): - return self.logic.AuthWallet.get_by_id(model_id) + return self.logic.AuthWallet.get_by_id(self.session, model_id) def get_by_public_address(self, public_address): - return self.logic.AuthWallet.get_by_public_address(public_address) + return self.logic.AuthWallet.get_by_public_address(self.session, public_address) def get_by_auth_user_id( self, @@ -528,6 +545,7 @@ def get_by_auth_user_id( **kwargs, ) -> t.List[AuthWallet]: auth_user = self.logic.AuthUser.get_by_id( + self.session, auth_user_id, join_list=["linked_primary_auth_user"], ) @@ -543,6 +561,7 @@ def get_by_auth_user_id( auth_user = auth_user.linked_primary_auth_user return self.logic.AuthWallet.get_by_auth_user_id( + self.session, auth_user.id, network=network, wallet_type=wallet_type, @@ -557,12 +576,14 @@ def sync_auth_wallet( wallet_management_type, ): existing_wallet = self.logic.AuthWallet.get_by_auth_user_id( + self.session, auth_user_id, ) if existing_wallet: raise RuntimeError("WalletExistsForNetworkAndWalletType") return self.logic.AuthWallet.add( + self.session, public_address, encrypted_private_address, self.wallet_type, @@ -570,18 +591,3 @@ def sync_auth_wallet( management_type=wallet_management_type, auth_user_id=auth_user_id, ) - - -class Handlers: - bind: Bind - - def __init__(self, bind: Bind): - self.bind = bind - - def __getattr__(self, name): - handlers = { - cls.__name__.replace("Handler", ""): cls for cls in HandlerBase.__subclasses__() - } - if name in handlers: - return handlers[name] - raise AttributeError() diff --git a/src/quart_sqlalchemy/sim/logic.py b/src/quart_sqlalchemy/sim/logic.py index a13397d..ceb6ffd 100644 --- a/src/quart_sqlalchemy/sim/logic.py +++ b/src/quart_sqlalchemy/sim/logic.py @@ -86,28 +86,11 @@ def __getattr__(self, logic_name): ) -def with_db_session( - ro: t.Optional[bool] = None, - is_stale_tolerant: t.Optional[bool] = None, -): - """Stub decorator to ease transition with legacy code""" - - def wrapper(func): - @wraps(func) - def inner_wrapper(self, *args, **kwargs): - session = None - return func(self, None, *args, **kwargs) - - return inner_wrapper - - return wrapper - - class MagicClient(LogicComponent): - def __init__(self, session: Session): + def __init__(self): # self._repository = SQLAlchemyRepository[magic_client_model, ObjectID](session) - self._repository = RepositoryLegacyAdapter(magic_client_model, ObjectID, session) + self._repository = RepositoryLegacyAdapter(magic_client_model, ObjectID) def _add(self, session, app_name=None): return self._repository.add( @@ -115,9 +98,10 @@ def _add(self, session, app_name=None): app_name=app_name, ) - add = with_db_session(ro=False)(_add) + # add = with_db_session(ro=False)(_add) + add = _add - @with_db_session(ro=True) + # @with_db_session(ro=True) def get_by_id( self, session, @@ -132,7 +116,7 @@ def get_by_id( join_list=join_list, ) - @with_db_session(ro=True) + # @with_db_session(ro=True) def get_by_public_api_key( self, session, @@ -163,17 +147,17 @@ def get_by_public_api_key( # return client.magic_client_api_user.magic_api_user_id - @with_db_session(ro=False) + # @with_db_session(ro=False) def update_by_id(self, session, model_id, **update_params): modified_row = self._repository.update(session, model_id, **update_params) session.refresh(modified_row) return modified_row - @with_db_session(ro=True) + # @with_db_session(ro=True) def yield_all_clients_by_chunk(self, session, chunk_size): yield from self._repository.yield_by_chunk(session, chunk_size) - @with_db_session(ro=True) + # @with_db_session(ro=True) def yield_by_chunk(self, session, chunk_size, filters=None, join_list=None): yield from self._repository.yield_by_chunk( session, @@ -200,11 +184,11 @@ class MissingPhoneNumber(Exception): class AuthUser(LogicComponent): - def __init__(self, session: Session): + def __init__(self): # self._repository = SQLRepository(auth_user_model) - self._repository = RepositoryLegacyAdapter(magic_client_model, ObjectID, session) + self._repository = RepositoryLegacyAdapter(magic_client_model, ObjectID) - @with_db_session(ro=True) + # @with_db_session(ro=True) def get_by_session_token( self, session, @@ -254,9 +238,10 @@ def _get_or_add_by_phone_number_and_client_id( return row - get_or_add_by_phone_number_and_client_id = with_db_session(ro=False)( - _get_or_add_by_phone_number_and_client_id, - ) + # get_or_add_by_phone_number_and_client_id = with_db_session(ro=False)( + # _get_or_add_by_phone_number_and_client_id, + # ) + get_or_add_by_phone_number_and_client_id = _get_or_add_by_phone_number_and_client_id def _add_by_email_and_client_id( self, @@ -299,7 +284,8 @@ def _add_by_email_and_client_id( return row - add_by_email_and_client_id = with_db_session(ro=False)(_add_by_email_and_client_id) + # add_by_email_and_client_id = with_db_session(ro=False)(_add_by_email_and_client_id) + add_by_email_and_client_id = _add_by_email_and_client_id def _add_by_client_id( self, @@ -327,7 +313,8 @@ def _add_by_client_id( return row - add_by_client_id = with_db_session(ro=False)(_add_by_client_id) + # add_by_client_id = with_db_session(ro=False)(_add_by_client_id) + add_by_client_id = _add_by_client_id def _get_by_active_identifier_and_client_id( self, @@ -364,7 +351,7 @@ def _get_by_active_identifier_and_client_id( return original - @with_db_session(ro=True) + # @with_db_session(ro=True) def get_by_email_and_client_id( self, session, @@ -398,9 +385,10 @@ def _get_by_phone_number_and_client_id( user_type=user_type, ) - get_by_phone_number_and_client_id = with_db_session(ro=True)( - _get_by_phone_number_and_client_id, - ) + # get_by_phone_number_and_client_id = with_db_session(ro=True)( + # _get_by_phone_number_and_client_id, + # ) + get_by_phone_number_and_client_id = _get_by_phone_number_and_client_id def _exist_by_email_and_client_id( self, @@ -420,7 +408,8 @@ def _exist_by_email_and_client_id( ), ) - exist_by_email_and_client_id = with_db_session(ro=True)(_exist_by_email_and_client_id) + # exist_by_email_and_client_id = with_db_session(ro=True)(_exist_by_email_and_client_id) + exist_by_email_and_client_id = _exist_by_email_and_client_id def _get_by_id(self, session, model_id, join_list=None, for_update=False) -> auth_user_model: return self._repository.get_by_id( @@ -430,7 +419,8 @@ def _get_by_id(self, session, model_id, join_list=None, for_update=False) -> aut for_update=for_update, ) - get_by_id = with_db_session(ro=True)(_get_by_id) + get_by_id = _get_by_id + # get_by_id = with_db_session(ro=True)(_get_by_id) def _update_by_id(self, session, auth_user_id, **kwargs): modified_user = self._repository.update(session, auth_user_id, **kwargs) @@ -440,9 +430,10 @@ def _update_by_id(self, session, auth_user_id, **kwargs): return modified_user - update_by_id = with_db_session(ro=False)(_update_by_id) + # update_by_id = with_db_session(ro=False)(_update_by_id) + update_by_id = _update_by_id - @with_db_session(ro=True) + # @with_db_session(ro=True) def get_user_count_by_client_id_and_user_type(self, session, client_id, user_type): query = ( session.query(auth_user_model) @@ -469,9 +460,10 @@ def _get_by_client_id_and_global_auth_user(self, session, client_id, global_auth ], ) - get_by_client_id_and_global_auth_user = with_db_session(ro=True)( - _get_by_client_id_and_global_auth_user, - ) + # get_by_client_id_and_global_auth_user = with_db_session(ro=True)( + # _get_by_client_id_and_global_auth_user, + # ) + get_by_client_id_and_global_auth_user = _get_by_client_id_and_global_auth_user # @with_db_session(ro=True) # def get_by_client_id_for_connect( @@ -542,7 +534,7 @@ def _get_by_client_id_and_global_auth_user(self, session, client_id, global_auth # return session.execute(query).scalar() - @with_db_session(ro=True) + # @with_db_session(ro=True) def get_by_client_id_and_user_type( self, session, @@ -583,9 +575,10 @@ def _get_by_client_ids_and_user_type( order_by_clause=auth_user_model.id.desc(), ) - get_by_client_ids_and_user_type = with_db_session(ro=True)( - _get_by_client_ids_and_user_type, - ) + # get_by_client_ids_and_user_type = with_db_session(ro=True)( + # _get_by_client_ids_and_user_type, + # ) + get_by_client_ids_and_user_type = _get_by_client_ids_and_user_type def _get_by_client_id_with_substring_search( self, @@ -617,11 +610,12 @@ def _get_by_client_id_with_substring_search( join_list=join_list, ) - get_by_client_id_with_substring_search = with_db_session(ro=True)( - _get_by_client_id_with_substring_search, - ) + # get_by_client_id_with_substring_search = with_db_session(ro=True)( + # _get_by_client_id_with_substring_search, + # ) + get_by_client_id_with_substring_search = _get_by_client_id_with_substring_search - @with_db_session(ro=True) + # @with_db_session(ro=True) def yield_by_chunk(self, session, chunk_size, filters=None, join_list=None): yield from self._repository.yield_by_chunk( session, @@ -630,7 +624,7 @@ def yield_by_chunk(self, session, chunk_size, filters=None, join_list=None): join_list=join_list, ) - @with_db_session(ro=True) + # @with_db_session(ro=True) def get_by_emails_and_client_id( self, session, @@ -663,12 +657,14 @@ def _get_by_email( join_list=join_list, ) - get_by_email = with_db_session(ro=True)(_get_by_email) + # get_by_email = with_db_session(ro=True)(_get_by_email) + get_by_email = _get_by_email def _add(self, session, **kwargs) -> ObjectID: return self._repository.add(session, **kwargs).id - add = with_db_session(ro=False)(_add) + # add = with_db_session(ro=False)(_add) + add = _add def _get_by_email_for_interop( self, @@ -676,7 +672,7 @@ def _get_by_email_for_interop( email: str, wallet_type: WalletType, network: str, - ) -> List[auth_user_model]: + ) -> t.List[auth_user_model]: """ Custom method for searching for users eligible for interop. Unfortunately, this can't be done with the current abstractions in our sql_repository, so this is a one-off bespoke method. @@ -720,9 +716,10 @@ def _get_by_email_for_interop( return query.all() - get_by_email_for_interop = with_db_session(ro=True)( - _get_by_email_for_interop, - ) + # get_by_email_for_interop = with_db_session(ro=True)( + # _get_by_email_for_interop, + # ) + get_by_email_for_interop = _get_by_email_for_interop def _get_linked_users(self, session, primary_auth_user_id, join_list, no_op=False): # TODO(magic-ravi#67899|2022-12-30): Re-enable account linked users for interop. Remove no_op flag. @@ -739,9 +736,10 @@ def _get_linked_users(self, session, primary_auth_user_id, join_list, no_op=Fals join_list=join_list, ) - get_linked_users = with_db_session(ro=True)(_get_linked_users) + # get_linked_users = with_db_session(ro=True)(_get_linked_users) + get_linked_users = _get_linked_users - @with_db_session(ro=True) + # @with_db_session(ro=True) def get_by_phone_number(self, session, phone_number): return self._repository.get_by( session, @@ -752,9 +750,9 @@ def get_by_phone_number(self, session, phone_number): class AuthWallet(LogicComponent): - def __init__(self, session: Session): + def __init__(self): # self._repository = SQLAlchemyRepository[magic_client_model, ObjectID](session) - self._repository = RepositoryLegacyAdapter(auth_wallet_model, ObjectID, session) + self._repository = RepositoryLegacyAdapter(auth_wallet_model, ObjectID) def _add( self, @@ -778,9 +776,10 @@ def _add( return new_row - add = with_db_session(ro=False)(_add) + # add = with_db_session(ro=False)(_add) + add = _add - @with_db_session(ro=True) + # @with_db_session(ro=True) def get_by_id(self, session, model_id, allow_inactive=False, join_list=None): return self._repository.get_by_id( session, @@ -789,7 +788,7 @@ def get_by_id(self, session, model_id, allow_inactive=False, join_list=None): join_list=join_list, ) - @with_db_session(ro=True) + # @with_db_session(ro=True) def get_by_public_address(self, session, public_address, network=None, is_active=True): """Public address is unique in our system. In any case, we should only find one row for the given public address. @@ -819,7 +818,7 @@ def get_by_public_address(self, session, public_address, network=None, is_active return one(row) - @with_db_session(ro=True) + # @with_db_session(ro=True) def get_by_auth_user_id( self, session, @@ -866,4 +865,5 @@ def get_by_auth_user_id( def _update_by_id(self, session, model_id, **kwargs): self._repository.update(session, model_id, **kwargs) - update_by_id = with_db_session(ro=False)(_update_by_id) + # update_by_id = with_db_session(ro=False)(_update_by_id) + update_by_id = _update_by_id diff --git a/src/quart_sqlalchemy/sim/main.py b/src/quart_sqlalchemy/sim/main.py index 2370b4f..23d1e0c 100644 --- a/src/quart_sqlalchemy/sim/main.py +++ b/src/quart_sqlalchemy/sim/main.py @@ -1,8 +1,7 @@ -from quart_sqlalchemy.sim.app import app -from quart_sqlalchemy.sim.views import api +from quart_sqlalchemy.sim.app import create_app -app.register_blueprint(api, url_prefix="/v1") +app = create_app() if __name__ == "__main__": diff --git a/src/quart_sqlalchemy/sim/model.py b/src/quart_sqlalchemy/sim/model.py index 7a6566b..b0199ab 100644 --- a/src/quart_sqlalchemy/sim/model.py +++ b/src/quart_sqlalchemy/sim/model.py @@ -11,7 +11,7 @@ from quart_sqlalchemy.model import SoftDeleteMixin from quart_sqlalchemy.model import TimestampMixin -from quart_sqlalchemy.sim.app import db +from quart_sqlalchemy.sim.db import db from quart_sqlalchemy.sim.util import ObjectID @@ -100,8 +100,10 @@ class AuthUser(db.Model, SoftDeleteMixin, TimestampMixin): date_verified: Mapped[t.Optional[datetime]] provenance: Mapped[t.Optional[Provenance]] is_admin: Mapped[bool] = sa.orm.mapped_column(default=False) - client_id: Mapped[ObjectID] - linked_primary_auth_user_id: Mapped[t.Optional[ObjectID]] + client_id: Mapped[ObjectID] = sa.orm.mapped_column(sa.ForeignKey("magic_client.id")) + linked_primary_auth_user_id: Mapped[t.Optional[ObjectID]] = sa.orm.mapped_column( + sa.ForeignKey("auth_user.id"), default=None + ) global_auth_user_id: Mapped[t.Optional[ObjectID]] delegated_user_id: Mapped[t.Optional[str]] @@ -110,7 +112,7 @@ class AuthUser(db.Model, SoftDeleteMixin, TimestampMixin): current_session_token: Mapped[t.Optional[str]] magic_client: Mapped[MagicClient] = sa.orm.relationship( - back_populates="auth_user", + back_populates="auth_users", uselist=False, ) linked_primary_auth_user = sa.orm.relationship( @@ -158,6 +160,6 @@ class AuthWallet(db.Model, SoftDeleteMixin, TimestampMixin): is_exported: Mapped[bool] = sa.orm.mapped_column(default=False) auth_user: Mapped[AuthUser] = sa.orm.relationship( - back_populates="auth_wallets", + back_populates="wallets", uselist=False, ) diff --git a/src/quart_sqlalchemy/sim/repo.py b/src/quart_sqlalchemy/sim/repo.py index 9e41c70..5844b9d 100644 --- a/src/quart_sqlalchemy/sim/repo.py +++ b/src/quart_sqlalchemy/sim/repo.py @@ -2,7 +2,6 @@ import typing as t from abc import ABCMeta -from abc import abstractmethod import sqlalchemy import sqlalchemy.event @@ -19,6 +18,9 @@ from quart_sqlalchemy.types import SessionT +# from abc import abstractmethod + + sa = sqlalchemy @@ -81,47 +83,54 @@ class SQLAlchemyRepository( """ - session: sa.orm.Session + # session: sa.orm.Session builder: StatementBuilder - def __init__(self, session: sa.orm.Session, **kwargs): + def __init__(self, **kwargs): super().__init__(**kwargs) - self.session = session + # self.session = session self.builder = StatementBuilder(None) - def insert(self, values: t.Dict[str, t.Any]) -> EntityT: + def insert(self, session: sa.orm.Session, values: t.Dict[str, t.Any]) -> EntityT: """Insert a new model into the database.""" new = self.model(**values) - self.session.add(new) - self.session.flush() - self.session.refresh(new) + session.add(new) + session.flush() + session.refresh(new) return new - def update(self, id_: EntityIdT, values: t.Dict[str, t.Any]) -> EntityT: + def update( + self, session: sa.orm.Session, id_: EntityIdT, values: t.Dict[str, t.Any] + ) -> EntityT: """Update existing model with new values.""" - obj = self.session.get(self.model, id_) + obj = session.get(self.model, id_) if obj is None: raise ValueError(f"Object with id {id_} not found") for field, value in values.items(): if getattr(obj, field) != value: setattr(obj, field, value) - self.session.flush() - self.session.refresh(obj) + session.flush() + session.refresh(obj) return obj def merge( - self, id_: EntityIdT, values: t.Dict[str, t.Any], for_update: bool = False + self, + session: sa.orm.Session, + id_: EntityIdT, + values: t.Dict[str, t.Any], + for_update: bool = False, ) -> EntityT: """Merge model in session/db having id_ with values.""" - self.session.get(self.model, id_) + session.get(self.model, id_) values.update(id=id_) - merged = self.session.merge(self.model(**values)) - self.session.flush() - self.session.refresh(merged, with_for_update=for_update) # type: ignore + merged = session.merge(self.model(**values)) + session.flush() + session.refresh(merged, with_for_update=for_update) # type: ignore return merged def get( self, + session: sa.orm.Session, id_: EntityIdT, options: t.Sequence[ORMOption] = (), execution_options: t.Optional[t.Dict[str, t.Any]] = None, @@ -137,7 +146,7 @@ def get( present it will be returned directly, when not, a database lookup will be performed. For use cases where this is what you actually want, you can still access the original get - method on self.session. For most uses cases, this behavior can introduce non-determinism + method on session. For most uses cases, this behavior can introduce non-determinism and because of that this method performs lookup using a select statement. Additionally, to satisfy the expected interface's return type: Optional[EntityT], one_or_none is called on the result before returning. @@ -154,10 +163,11 @@ def get( if for_update: statement = statement.with_for_update() - return self.session.scalars(statement, execution_options=execution_options).one_or_none() + return session.scalars(statement, execution_options=execution_options).one_or_none() def select( self, + session: sa.orm.Session, selectables: t.Sequence[Selectable] = (), conditions: t.Sequence[ColumnExpr] = (), group_by: t.Sequence[t.Union[ColumnExpr, str]] = (), @@ -196,12 +206,14 @@ def select( for_update=for_update, ) - results = self.session.scalars(statement) + results = session.scalars(statement) if yield_by_chunk: results = results.partitions() return results - def delete(self, id_: EntityIdT, include_inactive: bool = False) -> None: + def delete( + self, session: sa.orm.Session, id_: EntityIdT, include_inactive: bool = False + ) -> None: # if self.has_soft_delete: # raise RuntimeError("Can't delete entity that uses soft-delete semantics.") @@ -209,16 +221,16 @@ def delete(self, id_: EntityIdT, include_inactive: bool = False) -> None: if not entity: raise RuntimeError(f"Entity with id {id_} not found.") - self.session.delete(entity) - self.session.flush() + session.delete(entity) + session.flush() - def deactivate(self, id_: EntityIdT) -> EntityT: + def deactivate(self, session: sa.orm.Session, id_: EntityIdT) -> EntityT: # if not self.has_soft_delete: # raise RuntimeError("Can't delete entity that uses soft-delete semantics.") return self.update(id_, dict(is_active=False)) - def reactivate(self, id_: EntityIdT) -> EntityT: + def reactivate(self, session: sa.orm.Session, id_: EntityIdT) -> EntityT: # if not self.has_soft_delete: # raise RuntimeError("Can't delete entity that uses soft-delete semantics.") @@ -226,6 +238,7 @@ def reactivate(self, id_: EntityIdT) -> EntityT: def exists( self, + session: sa.orm.Session, conditions: t.Sequence[ColumnExpr] = (), for_update: bool = False, include_inactive: bool = False, @@ -246,38 +259,41 @@ def exists( if for_update: statement = statement.with_for_update() - result = self.session.execute(statement, execution_options=execution_options).scalar() + result = session.execute(statement, execution_options=execution_options).scalar() return bool(result) class SQLAlchemyBulkRepository(AbstractBulkRepository, t.Generic[SessionT, EntityT, EntityIdT]): - def __init__(self, session: SessionT, **kwargs: t.Any): + def __init__(self, **kwargs: t.Any): super().__init__(**kwargs) self.builder = StatementBuilder(self.model) - self.session = session + # session = session def bulk_insert( self, + session: sa.orm.Session, values: t.Sequence[t.Dict[str, t.Any]] = (), execution_options: t.Optional[t.Dict[str, t.Any]] = None, ) -> sa.Result[t.Any]: statement = self.builder.bulk_insert(self.model, values) - return self.session.execute(statement, execution_options=execution_options or {}) + return session.execute(statement, execution_options=execution_options or {}) def bulk_update( self, + session: sa.orm.Session, conditions: t.Sequence[ColumnExpr] = (), values: t.Optional[t.Dict[str, t.Any]] = None, execution_options: t.Optional[t.Dict[str, t.Any]] = None, ) -> sa.Result[t.Any]: statement = self.builder.bulk_update(self.model, conditions, values) - return self.session.execute(statement, execution_options=execution_options or {}) + return session.execute(statement, execution_options=execution_options or {}) def bulk_delete( self, + session: sa.orm.Session, conditions: t.Sequence[ColumnExpr] = (), execution_options: t.Optional[t.Dict[str, t.Any]] = None, ) -> sa.Result[t.Any]: statement = self.builder.bulk_delete(self.model, conditions) - return self.session.execute(statement, execution_options=execution_options or {}) + return session.execute(statement, execution_options=execution_options or {}) diff --git a/src/quart_sqlalchemy/sim/repo_adapter.py b/src/quart_sqlalchemy/sim/repo_adapter.py index 89230b0..3d6b48e 100644 --- a/src/quart_sqlalchemy/sim/repo_adapter.py +++ b/src/quart_sqlalchemy/sim/repo_adapter.py @@ -1,18 +1,24 @@ import typing as t +import sqlalchemy +import sqlalchemy.orm from pydantic import BaseModel from sqlalchemy import ScalarResult -from sqlalchemy.orm import selectinload, Session +from sqlalchemy.orm import selectinload +from sqlalchemy.orm import Session from sqlalchemy.sql.expression import func from sqlalchemy.sql.expression import label + from quart_sqlalchemy.model import Base +from quart_sqlalchemy.sim.repo import SQLAlchemyRepository from quart_sqlalchemy.types import ColumnExpr from quart_sqlalchemy.types import EntityIdT from quart_sqlalchemy.types import EntityT from quart_sqlalchemy.types import ORMOption from quart_sqlalchemy.types import Selectable -from quart_sqlalchemy.sim.repo import SQLAlchemyRepository + +sa = sqlalchemy class BaseModelSchema(BaseModel): @@ -34,11 +40,16 @@ class BaseUpdateSchema(BaseModelSchema): class RepositoryLegacyAdapter(t.Generic[EntityT, EntityIdT]): - def __init__(self, model: t.Type[EntityT], identity: t.Type[EntityIdT], session: Session,): + def __init__( + self, + model: t.Type[EntityT], + identity: t.Type[EntityIdT], + # session: Session, + ): self.model = model self._identity = identity - self._session = session - self.repo = SQLAlchemyRepository[model, identity](session) + # self._session = session + self.repo = SQLAlchemyRepository[model, identity]() def get_by( self, @@ -62,6 +73,7 @@ def get_by( order_by_clause = () return self.repo.select( + session, conditions=filters, options=[selectinload(getattr(self.model, attr)) for attr in join_list], for_update=for_update, @@ -73,8 +85,8 @@ def get_by( def get_by_id( self, - session = None, - model_id = None, + session=None, + model_id=None, allow_inactive=False, join_list=None, for_update=False, @@ -83,16 +95,20 @@ def get_by_id( raise ValueError("model_id is required") join_list = join_list or () return self.repo.get( + session, id_=model_id, options=[selectinload(getattr(self.model, attr)) for attr in join_list], for_update=for_update, include_inactive=allow_inactive, ) - def one(self, session = None, filters=None, join_list=None, for_update=False, include_inactive=False) -> EntityT: + def one( + self, session=None, filters=None, join_list=None, for_update=False, include_inactive=False + ) -> EntityT: filters = filters or () join_list = join_list or () return self.repo.select( + session, conditions=filters, options=[selectinload(getattr(self.model, attr)) for attr in join_list], for_update=for_update, @@ -101,7 +117,7 @@ def one(self, session = None, filters=None, join_list=None, for_update=False, in def count_by( self, - session = None, + session=None, filters=None, group_by=None, distinct_column=None, @@ -119,29 +135,29 @@ def count_by( for group in group_by: selectables.append(group.expression) - result = self.repo.select(selectables, conditions=filters, group_by=group_by) + result = self.repo.select(session, selectables, conditions=filters, group_by=group_by) return result.all() - def add(self, session = None, **kwargs) -> EntityT: - return self.repo.insert(kwargs) + def add(self, session=None, **kwargs) -> EntityT: + return self.repo.insert(session, kwargs) - def update(self, session = None, model_id, **kwargs) -> EntityT: - return self.repo.update(id_=model_id, values=kwargs) + def update(self, session=None, model_id=None, **kwargs) -> EntityT: + return self.repo.update(session, id_=model_id, values=kwargs) - def update_by(self, session = None, filters=None, **kwargs) -> EntityT: + def update_by(self, session=None, filters=None, **kwargs) -> EntityT: if not filters: raise ValueError("Full table scans are prohibited. Please provide filters") - row = self.repo.select(conditions=filters, limit=2).one() - return self.repo.update(id_=row.id, values=kwargs) + row = self.repo.select(session, conditions=filters, limit=2).one() + return self.repo.update(session, id_=row.id, values=kwargs) - def delete_by_id(self, session = None, model_id) -> None: - self.repo.delete(id_=model_id, include_inactive=True) + def delete_by_id(self, session=None, model_id=None) -> None: + self.repo.delete(session, id_=model_id, include_inactive=True) - def delete_one_by(self, session = None, filters=None, optional=False) -> None: + def delete_one_by(self, session=None, filters=None, optional=False) -> None: filters = filters or () - result = self.repo.select(conditions=filters, limit=1) + result = self.repo.select(session, conditions=filters, limit=1) if optional: row = result.one_or_none() @@ -150,21 +166,23 @@ def delete_one_by(self, session = None, filters=None, optional=False) -> None: else: row = result.one() - self.repo.delete(id_=row.id) + self.repo.delete(session, id_=row.id) - def exist(self, session = None, filters=None, allow_inactive=False) -> bool: + def exist(self, session=None, filters=None, allow_inactive=False) -> bool: filters = filters or () return self.repo.exists( + session, conditions=filters, include_inactive=allow_inactive, ) def yield_by_chunk( - self, session = None, chunk_size = 100, join_list=None, filters=None, allow_inactive=False + self, session=None, chunk_size=100, join_list=None, filters=None, allow_inactive=False ): filters = filters or () join_list = join_list or () results = self.repo.select( + session, conditions=filters, options=[selectinload(getattr(self.model, attr)) for attr in join_list], include_inactive=allow_inactive, @@ -220,11 +238,12 @@ def schema(self) -> t.Type[ModelSchemaT]: def insert( self, + session: sa.orm.Session, create_schema: CreateSchemaT, sqla_model=False, ): create_data = create_schema.dict() - result = super().insert(create_data) + result = super().insert(session, create_data) if sqla_model: return result @@ -232,11 +251,12 @@ def insert( def update( self, + session: sa.orm.Session, id_: EntityIdT, update_schema: UpdateSchemaT, sqla_model=False, ): - existing = self.session.query(self.model).get(id_) + existing = session.query(self.model).get(id_) if existing is None: raise ValueError("Model not found") @@ -244,15 +264,16 @@ def update( for key, value in update_data.items(): setattr(existing, key, value) - self.session.add(existing) - self.session.flush() - self.session.refresh(existing) + session.add(existing) + session.flush() + session.refresh(existing) if sqla_model: return existing return self.schema.from_orm(existing) def get( self, + session: sa.orm.Session, id_: EntityIdT, options: t.Sequence[ORMOption] = (), execution_options: t.Optional[t.Dict[str, t.Any]] = None, @@ -261,6 +282,7 @@ def get( sqla_model: bool = False, ): row = super().get( + session, id_, options, execution_options, @@ -276,6 +298,7 @@ def get( def select( self, + session: sa.orm.Session, selectables: t.Sequence[Selectable] = (), conditions: t.Sequence[ColumnExpr] = (), group_by: t.Sequence[t.Union[ColumnExpr, str]] = (), @@ -291,6 +314,7 @@ def select( sqla_model: bool = False, ): result = super().select( + session, selectables, conditions, group_by, diff --git a/src/quart_sqlalchemy/sim/schema.py b/src/quart_sqlalchemy/sim/schema.py index f2311ca..e0304b7 100644 --- a/src/quart_sqlalchemy/sim/schema.py +++ b/src/quart_sqlalchemy/sim/schema.py @@ -1,8 +1,11 @@ +import typing as t from datetime import datetime from pydantic import BaseModel +from pydantic import Field +from pydantic import validator -from quart_sqlalchemy.sim.util import ObjectID +from .util import ObjectID class BaseSchema(BaseModel): @@ -12,3 +15,29 @@ class Config: ObjectID: lambda v: v.encode(), datetime: lambda dt: int(dt.timestamp()), } + + @classmethod + def _get_value(cls, v: t.Any, *args: t.Any, **kwargs: t.Any) -> t.Any: + if hasattr(v, "__serialize__"): + return v.__serialize__() + for type_, converter in cls.__config__.json_encoders.items(): + if isinstance(v, type_): + return converter(v) + + return super()._get_value(v, *args, **kwargs) + + +class ResponseWrapper(BaseSchema): + """Generic response wrapper""" + + error_code: str = "" + status: str = "" + message: str = "" + data: t.Any = Field(default_factory=dict) + + @validator("status") + def set_status_by_error_code(cls, v, values): + error_code = values.get("error_code") + if error_code: + return "failed" + return "ok" diff --git a/src/quart_sqlalchemy/sim/testing.py b/src/quart_sqlalchemy/sim/testing.py new file mode 100644 index 0000000..347b422 --- /dev/null +++ b/src/quart_sqlalchemy/sim/testing.py @@ -0,0 +1,13 @@ +from contextlib import contextmanager + +from quart import g +from quart import signals + + +@contextmanager +def user_set(app, user): + def handler(sender, **kwargs): + g.user = user + + with signals.appcontext_pushed.connected_to(handler, app): + yield diff --git a/src/quart_sqlalchemy/sim/views/__init__.py b/src/quart_sqlalchemy/sim/views/__init__.py index d110bc2..656f265 100644 --- a/src/quart_sqlalchemy/sim/views/__init__.py +++ b/src/quart_sqlalchemy/sim/views/__init__.py @@ -6,7 +6,7 @@ from .magic_client import api as magic_client_api -api = Blueprint("api", __name__, url_prefix="api") +api = Blueprint("api", __name__, url_prefix="/api") api.register_blueprint(auth_user_api) api.register_blueprint(auth_wallet_api) @@ -15,4 +15,4 @@ @api.before_request def set_feature_owner(): - g.request_feature_owner = "magic" + g.request_feature_owner = "auth-team" diff --git a/src/quart_sqlalchemy/sim/views/auth_wallet.py b/src/quart_sqlalchemy/sim/views/auth_wallet.py index 24fc9fc..9c6f052 100644 --- a/src/quart_sqlalchemy/sim/views/auth_wallet.py +++ b/src/quart_sqlalchemy/sim/views/auth_wallet.py @@ -1,25 +1,28 @@ import logging import typing as t -from quart import Blueprint from quart import g from quart.utils import run_sync -from quart_schema.validation import validate +from quart_sqlalchemy.retry import retry_context +from quart_sqlalchemy.retry import RetryError + +from ..auth import authorized_request +from ..handle import AuthWalletHandler from ..model import WalletManagementType from ..model import WalletType from ..schema import BaseSchema from ..util import ObjectID -from .decorator import authorized_request +from .util import APIBlueprint logger = logging.getLogger(__name__) -api = Blueprint("auth_wallet", __name__, url_prefix="auth_wallet") +api = APIBlueprint("auth_wallet", __name__, url_prefix="auth_wallet") @api.before_request def set_feature_owner(): - g.request_feature_owner = "wallet" + g.request_feature_owner = "wallet-team" class WalletSyncRequest(BaseSchema): @@ -38,18 +41,36 @@ class WalletSyncResponse(BaseSchema): encrypted_private_address: str -@authorized_request(authenticate_client=True, authenticate_user=True) -@validate(request=WalletSyncRequest, responses={200: (WalletSyncResponse, None)}) -@api.route("/sync", methods=["POST"]) -async def sync_auth_user_wallet(data: WalletSyncRequest): +@api.post( + "/sync", + authorizer=authorized_request( + [ + # We use the OpenAPI security scheme metadata to know which kind of authorization to enforce. + # + # Together in the same dict implies logical AND requirement so both public-api-key and + # session-token will be enforced + { + "public-api-key": [], + "session-token-bearer": [], + } + ], + ), +) +async def sync(data: WalletSyncRequest) -> WalletSyncResponse: + user_credential = g.authorized_credentials.get("session-token-bearer") + try: - with g.bind.Session() as session: - wallet = await run_sync(g.h.AuthWallet(session).sync_auth_wallet)( - g.auth.user.id, - data.public_address, - data.encrypted_private_address, - WalletManagementType.DELEGATED.value, - ) + for attempt in retry_context: + with attempt: + with g.bind.Session() as session: + wallet = AuthWalletHandler(g.network, session).sync_auth_wallet( + user_credential.subject.id, + data.public_address, + data.encrypted_private_address, + WalletManagementType.DELEGATED.value, + ) + except RetryError: + pass except RuntimeError: raise RuntimeError("Unsupported wallet type or network") diff --git a/src/quart_sqlalchemy/sim/views/decorator.py b/src/quart_sqlalchemy/sim/views/decorator.py deleted file mode 100644 index 357a819..0000000 --- a/src/quart_sqlalchemy/sim/views/decorator.py +++ /dev/null @@ -1,131 +0,0 @@ -# import inspect -# import typing as t -# from dataclasses import asdict -# from dataclasses import is_dataclass -from functools import wraps - -# from pydantic import BaseModel -# from pydantic import ValidationError -# from pydantic.dataclasses import dataclass as pydantic_dataclass -# from pydantic.schema import model_schema -from quart import current_app -from quart import g - - -# from quart import request -# from quart import ResponseReturnValue as QuartResponseReturnValue -# from quart_schema.typing import Model -# from quart_schema.typing import PydanticModel -# from quart_schema.typing import ResponseReturnValue -# from quart_schema.validation import _convert_headers -# from quart_schema.validation import DataSource -# from quart_schema.validation import QUART_SCHEMA_RESPONSE_ATTRIBUTE -# from quart_schema.validation import ResponseHeadersValidationError -# from quart_schema.validation import ResponseSchemaValidationError -# from quart_schema.validation import validate_headers -# from quart_schema.validation import validate_querystring -# from quart_schema.validation import validate_request - - -def authorized_request(authenticate_client: bool = False, authenticate_user: bool = False): - def decorator(func): - @wraps(func) - async def wrapper(*args, **kwargs): - if authenticate_client: - if not g.auth.client: - raise RuntimeError("Unable to authenticate client") - kwargs.update(client_id=g.auth.client.id) - - if authenticate_user: - if not g.auth.user: - raise RuntimeError("Unable to authenticate user") - kwargs.update(user_id=g.auth.user.id) - - return await current_app.ensure_async(func)(*args, **kwargs) - - return wrapper - - return decorator - - -# def validate_response() -> t.Callable: -# def decorator( -# func: t.Callable[..., ResponseReturnValue] -# ) -> t.Callable[..., QuartResponseReturnValue]: -# undecorated = func -# while hasattr(undecorated, "__wrapped__"): -# undecorated = undecorated.__wrapped__ - -# signature = inspect.signature(undecorated) -# derived_schema = signature.return_annotation or dict - -# schemas = getattr(func, QUART_SCHEMA_RESPONSE_ATTRIBUTE, {}) -# schemas[200] = (derived_schema, None) -# setattr(func, QUART_SCHEMA_RESPONSE_ATTRIBUTE, schemas) - -# @wraps(func) -# async def wrapper(*args: t.Any, **kwargs: t.Any) -> t.Any: -# result = await current_app.ensure_async(func)(*args, **kwargs) - -# status_or_headers = None -# headers = None -# if isinstance(result, tuple): -# value, status_or_headers, headers = result + (None,) * (3 - len(result)) -# else: -# value = result - -# status = 200 -# if isinstance(status_or_headers, int): -# status = int(status_or_headers) - -# schemas = getattr(func, QUART_SCHEMA_RESPONSE_ATTRIBUTE, {200: dict}) -# model_class = schemas.get(status, dict) - -# try: -# if isinstance(value, dict): -# model_value = model_class(**value) -# elif type(value) == model_class: -# model_value = value -# elif is_dataclass(value): -# model_value = model_class(**asdict(value)) -# else: -# return result, status, headers - -# except ValidationError as error: -# raise ResponseHeadersValidationError(error) - -# headers_value = headers -# return model_value, status, headers_value - -# return wrapper - -# return decorator - - -# def validate( -# *, -# querystring: t.Optional[Model] = None, -# request: t.Optional[Model] = None, -# request_source: DataSource = DataSource.JSON, -# headers: t.Optional[Model] = None, -# responses: t.Dict[int, t.Tuple[Model, t.Optional[Model]]], -# ) -> t.Callable: -# """Validate the route. - -# This is a shorthand combination of of the validate_querystring, -# validate_request, validate_headers, and validate_response -# decorators. Please see the docstrings for those decorators. -# """ - -# def decorator(func: t.Callable) -> t.Callable: -# if querystring is not None: -# func = validate_querystring(querystring)(func) -# if request is not None: -# func = validate_request(request, source=request_source)(func) -# if headers is not None: -# func = validate_headers(headers)(func) -# for status, models in responses.items(): -# func = validate_response(models[0], status, models[1]) -# return func - -# return decorator diff --git a/src/quart_sqlalchemy/sim/views/util/__init__.py b/src/quart_sqlalchemy/sim/views/util/__init__.py new file mode 100644 index 0000000..e677b31 --- /dev/null +++ b/src/quart_sqlalchemy/sim/views/util/__init__.py @@ -0,0 +1,12 @@ +from .blueprint import APIBlueprint +from .decorator import inject_request +from .decorator import validate_request +from .decorator import validate_response + + +__all__ = ( + "APIBlueprint", + "inject_request", + "validate_request", + "validate_response", +) diff --git a/src/quart_sqlalchemy/sim/views/util/blueprint.py b/src/quart_sqlalchemy/sim/views/util/blueprint.py new file mode 100644 index 0000000..35927ec --- /dev/null +++ b/src/quart_sqlalchemy/sim/views/util/blueprint.py @@ -0,0 +1,102 @@ +import inspect +import typing as t + +from quart import Blueprint +from quart import Request +from quart_schema.validation import validate_headers +from quart_schema.validation import validate_querystring + +from ...schema import BaseSchema +from .decorator import inject_request +from .decorator import validate_request +from .decorator import validate_response + + +class APIBlueprint(Blueprint): + def _endpoint( + self, + uri: str, + methods: t.Optional[t.Sequence[str]] = ("GET",), + authorizer: t.Optional[t.Callable] = None, + **route_kwargs, + ): + def decorator(func): + sig = inspect.signature(func) + + param_annotation_map = { + name: param.annotation for name, param in sig.parameters.items() + } + has_request_schema = "data" in sig.parameters and issubclass( + param_annotation_map["data"], + BaseSchema, + ) + has_query_schema = "query_args" in sig.parameters and issubclass( + param_annotation_map["query_args"], + BaseSchema, + ) + has_headers_schema = "headers" in sig.parameters and issubclass( + param_annotation_map["headers"], + BaseSchema, + ) + + has_response_schema = isinstance(sig.return_annotation, BaseSchema) + + should_inject_request, request_param_name = False, None + for name in param_annotation_map: + if isinstance(param_annotation_map[name], Request): + should_inject_request, request_param_name = True, name + break + + decorated = func + + if should_inject_request: + decorated = inject_request(request_param_name)(decorated) + + if has_query_schema: + decorated = validate_querystring(param_annotation_map["query_args"])(decorated) + + if has_headers_schema: + decorated = validate_headers(param_annotation_map["headers"])(decorated) + + if has_request_schema: + decorated = validate_request(param_annotation_map["data"])(decorated) + + if has_response_schema: + decorated = validate_response(sig.return_annotation)(decorated) + + if authorizer: + decorated = authorizer(decorated) + + return self.route(uri, t.cast(t.List[str], methods), **route_kwargs)(decorated) + + return decorator + + def get(self, *args, **kwargs): + if "methods" in kwargs: + del kwargs["methods"] + + return self._endpoint(*args, methods=["GET"], **kwargs) + + def post(self, *args, **kwargs): + if "methods" in kwargs: + del kwargs["methods"] + + return self._endpoint(*args, methods=["POST"], **kwargs) + + def put(self, *args, **kwargs): + if "methods" in kwargs: + del kwargs["methods"] + + return self._endpoint(*args, methods=["PUT"], **kwargs) + + def patch(self, *args, **kwargs): + if "methods" in kwargs: + del kwargs["methods"] + + return self._endpoint(*args, methods=["PATCH"], **kwargs) + + def delete(self, *args, **kwargs): + if "methods" in kwargs: + del kwargs["methods"] + + return self._endpoint(*args, methods=["DELETE"], **kwargs) diff --git a/src/quart_sqlalchemy/sim/views/util/decorator.py b/src/quart_sqlalchemy/sim/views/util/decorator.py new file mode 100644 index 0000000..55323d2 --- /dev/null +++ b/src/quart_sqlalchemy/sim/views/util/decorator.py @@ -0,0 +1,120 @@ +import typing as t +from dataclasses import asdict +from dataclasses import is_dataclass +from functools import wraps + +from humps import camelize +from humps import decamelize +from pydantic import BaseModel +from pydantic import ValidationError +from quart import current_app +from quart import request +from quart import Response +from quart_schema.typing import Model +from quart_schema.typing import ResponseReturnValue +from quart_schema.validation import QUART_SCHEMA_REQUEST_ATTRIBUTE +from quart_schema.validation import QUART_SCHEMA_RESPONSE_ATTRIBUTE +from quart_schema.validation import RequestSchemaValidationError +from quart_schema.validation import ResponseSchemaValidationError + + +def convert_model_result(func: t.Callable) -> t.Callable: + @wraps(func) + async def decorator(result: ResponseReturnValue) -> Response: + status_or_headers = None + headers = None + if isinstance(result, tuple): + value, status_or_headers, headers = result + (None,) * (3 - len(result)) + else: + value = result + + was_model = False + if is_dataclass(value): + dict_or_value = asdict(value) + was_model = True + elif isinstance(value, BaseModel): + dict_or_value = value.dict(by_alias=True) + was_model = True + else: + dict_or_value = value + + if was_model: + dict_or_value = camelize(dict_or_value) + + return await func((dict_or_value, status_or_headers, headers)) + + return decorator + + +def validate_request(model_class: Model) -> t.Callable: + def decorator(func: t.Callable) -> t.Callable: + setattr(func, QUART_SCHEMA_REQUEST_ATTRIBUTE, (model_class, None)) + + @wraps(func) + async def wrapper(*args, **kwargs): + data = await request.get_json() + data = decamelize(data) + + try: + model = model_class(**data) + except (TypeError, ValidationError) as error: + raise RequestSchemaValidationError(error) + else: + return await current_app.ensure_async(func)(*args, data=model, **kwargs) + + return wrapper + + return decorator + + +def validate_response(model_class: Model, status_code: int = 200) -> t.Callable: + def decorator(func): + schemas = getattr(func, QUART_SCHEMA_RESPONSE_ATTRIBUTE, {}) + schemas[status_code] = (model_class, None) + setattr(func, QUART_SCHEMA_RESPONSE_ATTRIBUTE, schemas) + + @wraps(func) + async def wrapper(*args, **kwargs): + result = await current_app.ensure_async(func)(*args, **kwargs) + + status_or_headers = None + headers = None + if isinstance(result, tuple): + value, status_or_headers, headers = result + (None,) * (3 - len(result)) + else: + value = result + + status = 200 + if isinstance(status_or_headers, int): + status = int(status_or_headers) + + if status == status_code: + try: + if isinstance(value, dict): + model_value = model_class(**value) + elif type(value) == model_class: + model_value = value + else: + raise ResponseSchemaValidationError() + except ValidationError as error: + raise ResponseSchemaValidationError(error) + + return model_value, status, headers + else: + return result + + return wrapper + + return decorator + + +def inject_request(key: str): + def decorator(func): + @wraps(func) + async def wrapper(*args, **kwargs): + kwargs[key] = request._get_current_object() + return await current_app.ensure_async(func)(*args, **kwargs) + + return wrapper + + return decorator From bc9f565699f2cc746b4e18e8da91c9973a289591 Mon Sep 17 00:00:00 2001 From: Joe Black Date: Fri, 31 Mar 2023 19:06:50 -0400 Subject: [PATCH 5/9] fix --- src/quart_sqlalchemy/sim/auth.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/quart_sqlalchemy/sim/auth.py b/src/quart_sqlalchemy/sim/auth.py index 51c3962..80a908a 100644 --- a/src/quart_sqlalchemy/sim/auth.py +++ b/src/quart_sqlalchemy/sim/auth.py @@ -274,12 +274,12 @@ def auth_endpoint_security(self): ) @click.option( "--client-id", - type=int, + type=str, required=True, help="client id", ) @pass_script_info -def add_user(info: ScriptInfo, email: str, user_type: str, client_id: int) -> None: +def add_user(info: ScriptInfo, email: str, user_type: str, client_id: str) -> None: app = info.load_app() db = app.extensions.get("sqlalchemy") From 51ccb5cd412d7bf97dd13fdf6a83de965f1ccc1e Mon Sep 17 00:00:00 2001 From: Joe Black Date: Fri, 31 Mar 2023 19:14:40 -0400 Subject: [PATCH 6/9] fix --- src/quart_sqlalchemy/sim/util.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/quart_sqlalchemy/sim/util.py b/src/quart_sqlalchemy/sim/util.py index 463b350..d4d33d1 100644 --- a/src/quart_sqlalchemy/sim/util.py +++ b/src/quart_sqlalchemy/sim/util.py @@ -30,8 +30,7 @@ def __init__(self, input_value): elif isinstance(input_value, ObjectID): self._source_id = input_value._decoded_id elif isinstance(input_value, str): - self._source_id = input_value - self._decode() + self._source_id = self._decode(input_value) elif isinstance(input_value, numbers.Number): try: input_value = int(input_value) @@ -47,7 +46,7 @@ def _encoded_id(self): @property def _decoded_id(self): - return self._decode() + return self._source_id def __eq__(self, other): if isinstance(other, ObjectID): @@ -88,11 +87,11 @@ def _encode(self): def encode(self): return self._encoded_id - def _decode(self): - if isinstance(self._source_id, int): - return self._source_id + def _decode(self, value): + if isinstance(value, int): + return value else: - return self.hashids.decode(self._source_id) + return self.hashids.decode(value)[0] def decode(self): return self._decoded_id From c7e453dc1571530da00bf93961955b9373419147 Mon Sep 17 00:00:00 2001 From: Joe Black Date: Fri, 31 Mar 2023 19:41:41 -0400 Subject: [PATCH 7/9] added sim extra deps --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 8e2348c..eb2e1be 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,9 @@ build-backend = "setuptools.build_meta" [project.optional-dependencies] +sim = [ + "quart-schema", "hashids" +] tests = [ "pytest", # "pytest-asyncio~=0.20.3", From 21e2c2447a77daa9805ceaaa0a5692996fdd73df Mon Sep 17 00:00:00 2001 From: Joe Black Date: Wed, 12 Apr 2023 16:31:51 -0400 Subject: [PATCH 8/9] Sim POC --- .env | 3 + docs/Simulation.md | 159 ++++++ docs/usage.md | 300 +++++++++++ examples/decorators/provide_session.py | 57 ++ examples/repository/base.py | 60 ++- examples/repository/sqla.py | 48 +- examples/usrsrv/component/__init__.py | 23 + examples/usrsrv/component/app.py | 10 + examples/usrsrv/component/command.py | 43 ++ examples/usrsrv/component/entity.py | 39 ++ examples/usrsrv/component/event.py | 45 ++ examples/usrsrv/component/exception.py | 2 + ...5_ddd_component_unitofwork_18fd763a02a4.py | 0 examples/usrsrv/component/repository.py | 63 +++ examples/usrsrv/component/service.py | 68 +++ pyproject.toml | 7 +- setup.cfg | 1 + src/quart_sqlalchemy/__init__.py | 2 +- src/quart_sqlalchemy/bind.py | 189 +++++-- src/quart_sqlalchemy/config.py | 193 ++++--- src/quart_sqlalchemy/framework/cli.py | 69 ++- src/quart_sqlalchemy/framework/extension.py | 14 +- src/quart_sqlalchemy/model/__init__.py | 11 +- src/quart_sqlalchemy/model/columns.py | 14 +- src/quart_sqlalchemy/model/mixins.py | 93 +++- src/quart_sqlalchemy/model/model.py | 43 +- src/quart_sqlalchemy/session.py | 53 ++ src/quart_sqlalchemy/signals.py | 31 ++ src/quart_sqlalchemy/sim/app.py | 69 +-- src/quart_sqlalchemy/sim/auth.py | 62 +-- src/quart_sqlalchemy/sim/builder.py | 6 +- src/quart_sqlalchemy/sim/commands.py | 52 ++ src/quart_sqlalchemy/sim/config.py | 55 ++ src/quart_sqlalchemy/sim/container.py | 57 ++ src/quart_sqlalchemy/sim/db.py | 48 +- src/quart_sqlalchemy/sim/handle.py | 454 +++++----------- src/quart_sqlalchemy/sim/logic.py | 501 ++++++------------ src/quart_sqlalchemy/sim/main.py | 2 + src/quart_sqlalchemy/sim/model.py | 8 +- src/quart_sqlalchemy/sim/repo.py | 150 +++--- src/quart_sqlalchemy/sim/repo_adapter.py | 120 ++--- src/quart_sqlalchemy/sim/schema.py | 68 ++- src/quart_sqlalchemy/sim/signals.py | 1 + src/quart_sqlalchemy/sim/testing.py | 7 +- src/quart_sqlalchemy/sim/views/auth_user.py | 85 ++- src/quart_sqlalchemy/sim/views/auth_wallet.py | 62 ++- .../sim/views/magic_client.py | 87 ++- src/quart_sqlalchemy/sim/web3.py | 153 ++++++ src/quart_sqlalchemy/sqla.py | 294 ++++++++-- src/quart_sqlalchemy/testing/fake.py | 0 src/quart_sqlalchemy/testing/signals.py | 40 ++ src/quart_sqlalchemy/testing/transaction.py | 10 +- src/quart_sqlalchemy/types.py | 15 +- tests/base.py | 204 ++++++- tests/conftest.py | 49 -- tests/constants.py | 20 +- tests/integration/concurrency/__init__.py | 0 .../concurrency/with_for_update.py | 103 ++++ tests/integration/framework/smoke_test.py | 14 +- tests/integration/model/mixins_test.py | 133 ++++- tests/integration/model/model_test.py | 59 +++ tests/integration/retry_test.py | 4 +- workspace.code-workspace | 3 +- 63 files changed, 3351 insertions(+), 1284 deletions(-) create mode 100644 .env create mode 100644 docs/Simulation.md create mode 100644 docs/usage.md create mode 100644 examples/decorators/provide_session.py create mode 100644 examples/usrsrv/component/__init__.py create mode 100644 examples/usrsrv/component/app.py create mode 100644 examples/usrsrv/component/command.py create mode 100644 examples/usrsrv/component/entity.py create mode 100644 examples/usrsrv/component/event.py create mode 100644 examples/usrsrv/component/exception.py create mode 100644 examples/usrsrv/component/migrations/2020-04-15_ddd_component_unitofwork_18fd763a02a4.py create mode 100644 examples/usrsrv/component/repository.py create mode 100644 examples/usrsrv/component/service.py create mode 100644 src/quart_sqlalchemy/sim/commands.py create mode 100644 src/quart_sqlalchemy/sim/config.py create mode 100644 src/quart_sqlalchemy/sim/container.py create mode 100644 src/quart_sqlalchemy/sim/web3.py create mode 100644 src/quart_sqlalchemy/testing/fake.py create mode 100644 src/quart_sqlalchemy/testing/signals.py create mode 100644 tests/integration/concurrency/__init__.py create mode 100644 tests/integration/concurrency/with_for_update.py create mode 100644 tests/integration/model/model_test.py diff --git a/.env b/.env new file mode 100644 index 0000000..f9ea633 --- /dev/null +++ b/.env @@ -0,0 +1,3 @@ +WEB3_HTTPS_PROVIDER_URI=https://eth-mainnet.g.alchemy.com/v2/422IpViRAru0Uu1SANhySuOStpaIK3AG +ALCHEMY_API_KEY=xxx +QUART_APP=quart_sqlalchemy.sim.main:app diff --git a/docs/Simulation.md b/docs/Simulation.md new file mode 100644 index 0000000..656d8df --- /dev/null +++ b/docs/Simulation.md @@ -0,0 +1,159 @@ +# Simulation Docs + +# initialize database +```shell +quart db create +``` +``` +Initialized database schema for +``` + +# add first client to the database (Using CLI) +```shell +quart auth add-client +``` +``` +Created client 2VolejRejNmG with public_api_key: 5f794cf72d0cef2dd008be2c0b7a632b +``` + +Use the `public_api_key` returned for the value of the `X-Public-API-Key` header when making API requests. + + +# Create new auth_user via api +```shell +curl -X POST localhost:8081/api/auth_user/ \ + -H 'X-Public-API-Key: 5f794cf72d0cef2dd008be2c0b7a632b' \ + -H 'Content-Type: application/json' \ + --data '{"email": "joe2@joe.com"}' +``` +```json +{ + "data": { + "auth_user": { + "client_id": "2VolejRejNmG", + "current_session_token": "69ee9af5b9296a09f90be5b71c1dda38", + "date_verified": 1681344793, + "delegated_identity_pool_id": null, + "delegated_user_id": null, + "email": "joe2@joe.com", + "global_auth_user_id": null, + "id": "GWpmbk5ezJn4", + "is_admin": false, + "linked_primary_auth_user_id": null, + "phone_number": null, + "provenance": null, + "user_type": 2 + } + }, + "error_code": "", + "message": "", + "status": "" +} +``` + +Use the `current_session_token` returned for the value of `Authorization: Bearer {token}` header when making API Requests requiring a user. + +# get AuthUser corresponding to provided bearer session token +```shell +curl -X GET localhost:8081/api/auth_user/ \ + -H 'X-Public-API-Key: 5f794cf72d0cef2dd008be2c0b7a632b' \ + -H 'Authorization: Bearer 69ee9af5b9296a09f90be5b71c1dda38' \ + -H 'Content-Type: application/json' +``` +```json +{ + "data": { + "client_id": "2VolejRejNmG", + "current_session_token": "69ee9af5b9296a09f90be5b71c1dda38", + "date_verified": 1681344793, + "delegated_identity_pool_id": null, + "delegated_user_id": null, + "email": "joe2@joe.com", + "global_auth_user_id": null, + "id": "GWpmbk5ezJn4", + "is_admin": false, + "linked_primary_auth_user_id": null, + "phone_number": null, + "provenance": null, + "user_type": 2 + }, + "error_code": "", + "message": "", + "status": "" +} +``` + + +# AuthWallet Sync +```shell +curl -X POST localhost:8081/api/auth_wallet/sync \ + -H 'X-Public-API-Key: 5f794cf72d0cef2dd008be2c0b7a632b' \ + -H 'Authorization: Bearer 69ee9af5b9296a09f90be5b71c1dda38' \ + -H 'Content-Type: application/json' \ + --data '{"public_address": "xxx", "encrypted_private_address": "xxx", "wallet_type": "ETH"}' +``` +```json +{ + "data": { + "auth_user_id": "GWpmbk5ezJn4", + "encrypted_private_address": "xxx", + "public_address": "xxx", + "wallet_id": "GWpmbk5ezJn4", + "wallet_type": "ETH" + }, + "error_code": "", + "message": "", + "status": "" +} +``` + +# get magic client corresponding to provided public api key +```shell +curl -X GET localhost:8081/api/magic_client/ \ + -H 'X-Public-API-Key: 5f794cf72d0cef2dd008be2c0b7a632b' \ + -H 'Content-Type: application/json' +``` +```json +{ + "data": { + "app_name": "My App", + "connect_interop": null, + "global_audience_enabled": false, + "id": "2VolejRejNmG", + "is_signing_modal_enabled": false, + "public_api_key": "5f794cf72d0cef2dd008be2c0b7a632b", + "rate_limit_tier": null, + "secret_api_key": "c6ecbced505b35505751c862ed0fb10ffb623d24095019433e0d4d94e240e508" + }, + "error_code": "", + "message": "", + "status": "" +} +``` + +# Create new magic client +```shell +curl -X POST localhost:8081/api/magic_client/ \ + -H 'X-Public-API-Key: 5f794cf72d0cef2dd008be2c0b7a632b' \ + -H 'Content-Type: application/json' \ + --data '{"app_name": "New App"}' +``` +```json +{ + "data": { + "magic_client": { + "app_name": "New App", + "connect_interop": null, + "global_audience_enabled": false, + "id": "GWpmbk5ezJn4", + "is_signing_modal_enabled": false, + "public_api_key": "fb7e0466e2e09387b93af7da49bb1386", + "rate_limit_tier": null, + "secret_api_key": "2ac56a6068d0d4b2ce911ba08401c7bf4acdb03db957550c260bd317c6c49a76" + } + }, + "error_code": "", + "message": "", + "status": "" +} +``` \ No newline at end of file diff --git a/docs/usage.md b/docs/usage.md new file mode 100644 index 0000000..f0bfdbe --- /dev/null +++ b/docs/usage.md @@ -0,0 +1,300 @@ +# API + +## SQLAlchemy +### `quart_sqlalchemy.sqla.SQLAlchemy` + +### Conventions +This manager class keeps things very simple by using a few configuration conventions: + +* Configuration has been simplified down to base_class and binds. +* Everything related to ORM mapping, DeclarativeBase, registry, MetaData, etc should be configured by passing the a custom DeclarativeBase class as the base_class configuration parameter. +* Everything related to engine/session configuration should be configured by passing a dictionary mapping string names to BindConfigs as the `binds` configuration parameter. +* the bind named `default` is the canonical bind, and to be used unless something more specific has been requested + +### Configuration +BindConfig can be as simple as a dictionary containing a url key like so: +```python +bind_config = { + "default": {"url": "sqlite://"} +} +``` + +But most use cases will require more than just a connection url, and divide core/engine configuration from orm/session configuration which looks more like this: +```python +bind_config = { + "default": { + "engine": { + "url": "sqlite://" + }, + "session": { + "expire_on_commit": False + } + } +} +``` + +It helps to think of the bind configuration as being the options dictionary used to build the main core and orm factory objects. +* For SQLAlchemy core, the configuration under the key `engine` will be used by `sa.engine_from_config` to build the `sa.Engine` object which acts as a factory for `sa.Connection` objects. + ```python + engine = sa.engine_from_config(config.engine, prefix="") + ``` +* For SQLAlchemy orm, the configuration under the key `session` will be used to build the `sa.orm.sessionmaker` session factory which acts as a factory for `sa.orm.Session` objects. + ```python + session_factory = sa.orm.sessionmaker(bind=engine, **config.session) + ``` + +#### Usage Examples +SQLAlchemyConfig is to be passed to SQLAlchemy or QuartSQLAlchemy as the first parameter when initializing. + +```python +db = SQLAlchemy( + SQLAlchemyConfig( + binds=dict( + default=dict( + url="sqlite://" + ) + ) + ) +) +``` + +When nothing is provided to SQLAlchemyConfig directly, it is instantiated with the following defaults + +```python +db = SQLAlchemy(SQLAlchemyConfig()) +``` + +For `QuartSQLAlchemy` configuration can also be provided via Quart configuration. +```python +from quart_sqlalchemy.framework import QuartSQLAlchemy + +app = Quart(__name__) +app.config.from_mapping( + { + "SQLALCHEMY_BINDS": { + "default": { + "engine": {"url": "sqlite:///file:mem.db?mode=memory&cache=shared&uri=true"}, + "session": {"expire_on_commit": False}, + } + }, + "SQLALCHEMY_BASE_CLASS": Base, + } +) +db = QuartSQLAlchemy(app=app) +``` + + + + +A typical configuration containing engine and session config both: +```python +config = SQLAlchemyConfig( + binds=dict( + default=dict( + engine=dict( + url="sqlite://" + ), + session=dict( + expire_on_commit=False + ) + ) + ) +) +``` + +Async first configuration +```python +config = SQLAlchemyConfig( + binds=dict( + default=dict( + engine=dict( + url="sqlite+aiosqlite:///file:mem.db?mode=memory&cache=shared&uri=true" + ), + session=dict( + expire_on_commit=False + ) + ) + ) +) +``` + +More complex configuration having two additional binds based on default, one for a read-replica and the second having an async driver + +```python +config = { + "SQLALCHEMY_BINDS": { + "default": { + "engine": {"url": "sqlite:///file:mem.db?mode=memory&cache=shared&uri=true"}, + "session": {"expire_on_commit": False}, + }, + "read-replica": { + "engine": {"url": "sqlite:///file:mem.db?mode=memory&cache=shared&uri=true"}, + "session": {"expire_on_commit": False}, + "read_only": True, + }, + "async": { + "engine": {"url": "sqlite+aiosqlite:///file:mem.db?mode=memory&cache=shared&uri=true"}, + "session": {"expire_on_commit": False}, + }, + }, + "SQLALCHEMY_BASE_CLASS": Base, +} +``` + + + Once instantiated, operations targetting all of the binds, aka metadata, like + `metadata.create_all` should be called from this class. Operations specific to a bind + should be called from that bind. This class has a few ways to get a specific bind. + + * To get a Bind, you can call `.get_bind(name)` on this class. The default bind can be + referenced at `.bind`. + + * To define an ORM model using the Base class attached to this class, simply inherit + from `.Base` + + db = SQLAlchemy(SQLAlchemyConfig()) + + class User(db.Base): + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + name: Mapped[str] = sa.orm.mapped_column(default="Joe") + + db.create_all() + + * You can also decouple Base from SQLAlchemy with some dependency inversion: + from quart_sqlalchemy.model.mixins import DynamicArgsMixin, ReprMixin, TableNameMixin + + class Base(DynamicArgsMixin, ReprMixin, TableNameMixin): + __abstract__ = True + + + class User(db.Base): + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + name: Mapped[str] = sa.orm.mapped_column(default="Joe") + + db = SQLAlchemy(SQLAlchemyConfig(bind_class=Base)) + + db.create_all() + + + Declarative Mapping using registry based decorator: + + db = SQLAlchemy(SQLAlchemyConfig()) + + @db.registry.mapped + class User(db.Base): + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + name: Mapped[str] = sa.orm.mapped_column(default="Joe") + + db.create_all() + + + Declarative with Imperative Table (Hybrid Declarative): + + class User(db.Base): + __table__ = sa.Table( + "user", + db.metadata, + sa.Column("id", sa.Integer, primary_key=True, autoincrement=True), + sa.Column("name", sa.String, default="Joe"), + ) + + + Declarative using reflection to automatically build the table object: + + class User(db.Base): + __table__ = sa.Table( + "user", + db.metadata, + autoload_with=db.bind.engine, + ) + + + Declarative Dataclass Mapping: + + from quart_sqlalchemy.model import Base as Base_ + + class Base(sa.orm.MappedAsDataclass, Base_): + pass + + db = SQLAlchemy(SQLAlchemyConfig(base_class=Base)) + + class User(db.Base): + __tablename__ = "user" + + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + name: Mapped[str] = sa.orm.mapped_column(default="Joe") + + db.create_all() + + + Declarative Dataclass Mapping (using decorator): + + db = SQLAlchemy(SQLAlchemyConfig(base_class=Base)) + + @db.registry.mapped_as_dataclass + class User: + __tablename__ = "user" + + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + name: Mapped[str] = sa.orm.mapped_column(default="Joe") + + db.create_all() + + + Alternate Dataclass Provider Pattern: + + from pydantic.dataclasses import dataclass + from quart_sqlalchemy.model import Base as Base_ + + class Base(sa.orm.MappedAsDataclass, Base_, dataclass_callable=dataclass): + pass + + db = SQLAlchemy(SQLAlchemyConfig(base_class=Base)) + + class User(db.Base): + __tablename__ = "user" + + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + name: Mapped[str] = sa.orm.mapped_column(default="Joe") + + db.create_all() + + Imperative style Mapping + + db = SQLAlchemy(SQLAlchemyConfig(base_class=Base)) + + user_table = sa.Table( + "user", + db.metadata, + sa.Column("id", sa.Integer, primary_key=True, autoincrement=True), + sa.Column("name", sa.String, default="Joe"), + ) + + post_table = sa.Table( + "post", + db.metadata, + sa.Column("id", sa.Integer, primary_key=True, autoincrement=True), + sa.Column("title", sa.String, default="My post"), + sa.Column("user_id", sa.ForeignKey("user.id"), nullable=False), + ) + + class User: + pass + + class Post: + pass + + db.registry.map_imperatively( + User, + user_table, + properties={ + "posts": sa.orm.relationship(Post, back_populates="user") + } + ) + db.registry.map_imperatively( + Post, + post_table, + properties={ + "user": sa.orm.relationship(User, back_populates="posts", uselist=False) + } + ) \ No newline at end of file diff --git a/examples/decorators/provide_session.py b/examples/decorators/provide_session.py new file mode 100644 index 0000000..f67f2b0 --- /dev/null +++ b/examples/decorators/provide_session.py @@ -0,0 +1,57 @@ +import inspect +import typing as t +from contextlib import contextmanager +from functools import wraps + + +RT = t.TypeVar("RT") + + +@contextmanager +def create_session(bind): + """Contextmanager that will create and teardown a session.""" + session = bind.Session() + try: + yield session + session.commit() + except Exception: + session.rollback() + raise + finally: + session.close() + + +def provide_session(bind_name: str = "default"): + """ + Function decorator that provides a session if it isn't provided. + If you want to reuse a session or run the function as part of a + database transaction, you pass it to the function, if not this wrapper + will create one and close it for you. + """ + + def decorator(func: t.Callable[..., RT]) -> t.Callable[..., RT]: + from quart_sqlalchemy import Bind + + func_params = inspect.signature(func).parameters + try: + # func_params is an ordered dict -- this is the "recommended" way of getting the position + session_args_idx = tuple(func_params).index("session") + except ValueError: + raise ValueError(f"Function {func.__qualname__} has no `session` argument") from None + + # We don't need this anymore -- ensure we don't keep a reference to it by mistake + del func_params + + @wraps(func) + def wrapper(*args, **kwargs) -> RT: + if "session" in kwargs or session_args_idx < len(args): + return func(*args, **kwargs) + else: + bind = Bind.get_instance(bind_name) + + with create_session(bind) as session: + return func(*args, session=session, **kwargs) + + return wrapper + + return decorator diff --git a/examples/repository/base.py b/examples/repository/base.py index ec70e39..3ae0d2e 100644 --- a/examples/repository/base.py +++ b/examples/repository/base.py @@ -1,5 +1,6 @@ from __future__ import annotations +import operator import typing as t from abc import ABCMeta from abc import abstractmethod @@ -14,42 +15,49 @@ from quart_sqlalchemy.types import ColumnExpr from quart_sqlalchemy.types import EntityIdT from quart_sqlalchemy.types import EntityT +from quart_sqlalchemy.types import Operator from quart_sqlalchemy.types import ORMOption from quart_sqlalchemy.types import Selectable +from quart_sqlalchemy.types import SessionT sa = sqlalchemy -class AbstractRepository(t.Generic[EntityT, EntityIdT], metaclass=ABCMeta): +class AbstractRepository(t.Generic[EntityT, EntityIdT, SessionT], metaclass=ABCMeta): """A repository interface.""" - identity: t.Type[EntityIdT] + # entity: t.Type[EntityT] - # def __init__(self, model: t.Type[EntityT]): - # self.model = model + # def __init__(self, entity: t.Type[EntityT]): + # self.entity = entity @property - def model(self) -> EntityT: + def entity(self) -> EntityT: return self.__orig_class__.__args__[0] @abstractmethod - def insert(self, values: t.Dict[str, t.Any]) -> EntityT: + def insert(self, session: SessionT, values: t.Dict[str, t.Any]) -> EntityT: """Add `values` to the collection.""" @abstractmethod - def update(self, id_: EntityIdT, values: t.Dict[str, t.Any]) -> EntityT: + def update(self, session: SessionT, id_: EntityIdT, values: t.Dict[str, t.Any]) -> EntityT: """Update model with model_id using values.""" @abstractmethod def merge( - self, id_: EntityIdT, values: t.Dict[str, t.Any], for_update: bool = False + self, + session: SessionT, + id_: EntityIdT, + values: t.Dict[str, t.Any], + for_update: bool = False, ) -> EntityT: """Merge model with model_id using values.""" @abstractmethod def get( self, + session: SessionT, id_: EntityIdT, options: t.Sequence[ORMOption] = (), execution_options: t.Optional[t.Dict[str, t.Any]] = None, @@ -58,9 +66,28 @@ def get( ) -> t.Optional[EntityT]: """Get model with model_id.""" + @abstractmethod + def get_by_field( + self, + session: SessionT, + field: t.Union[ColumnExpr, str], + value: t.Any, + op: Operator = operator.eq, + order_by: t.Sequence[t.Union[ColumnExpr, str]] = (), + options: t.Sequence[ORMOption] = (), + execution_options: t.Optional[t.Dict[str, t.Any]] = None, + offset: t.Optional[int] = None, + limit: t.Optional[int] = None, + distinct: bool = False, + for_update: bool = False, + include_inactive: bool = False, + ) -> sa.ScalarResult[EntityT]: + """Select models where field is equal to value.""" + @abstractmethod def select( self, + session: SessionT, selectables: t.Sequence[Selectable] = (), conditions: t.Sequence[ColumnExpr] = (), group_by: t.Sequence[t.Union[ColumnExpr, str]] = (), @@ -77,12 +104,13 @@ def select( """Select models matching conditions.""" @abstractmethod - def delete(self, id_: EntityIdT) -> None: + def delete(self, session: SessionT, id_: EntityIdT) -> None: """Delete model with id_.""" @abstractmethod def exists( self, + session: SessionT, conditions: t.Sequence[ColumnExpr] = (), for_update: bool = False, include_inactive: bool = False, @@ -90,27 +118,31 @@ def exists( """Return the existence of an object matching conditions.""" @abstractmethod - def deactivate(self, id_: EntityIdT) -> EntityT: + def deactivate(self, session: SessionT, id_: EntityIdT) -> EntityT: """Soft-Delete model with id_.""" @abstractmethod - def reactivate(self, id_: EntityIdT) -> EntityT: + def reactivate(self, session: SessionT, id_: EntityIdT) -> EntityT: """Soft-Delete model with id_.""" -class AbstractBulkRepository(t.Generic[EntityT, EntityIdT], metaclass=ABCMeta): +class AbstractBulkRepository(t.Generic[EntityT, EntityIdT, SessionT], metaclass=ABCMeta): """A repository interface for bulk operations. Note: this interface circumvents ORM internals, breaking commonly expected behavior in order to gain performance benefits. Only use this class whenever absolutely necessary. """ - model: t.Type[EntityT] builder: StatementBuilder + @property + def entity(self) -> EntityT: + return self.__orig_class__.__args__[0] + @abstractmethod def bulk_insert( self, + session: SessionT, values: t.Sequence[t.Dict[str, t.Any]] = (), execution_options: t.Optional[t.Dict[str, t.Any]] = None, ) -> sa.Result[t.Any]: @@ -119,6 +151,7 @@ def bulk_insert( @abstractmethod def bulk_update( self, + session: SessionT, conditions: t.Sequence[ColumnExpr] = (), values: t.Optional[t.Dict[str, t.Any]] = None, execution_options: t.Optional[t.Dict[str, t.Any]] = None, @@ -128,6 +161,7 @@ def bulk_update( @abstractmethod def bulk_delete( self, + session: SessionT, conditions: t.Sequence[ColumnExpr] = (), execution_options: t.Optional[t.Dict[str, t.Any]] = None, ) -> sa.Result[t.Any]: diff --git a/examples/repository/sqla.py b/examples/repository/sqla.py index 417ddaf..5f77dae 100644 --- a/examples/repository/sqla.py +++ b/examples/repository/sqla.py @@ -1,5 +1,6 @@ from __future__ import annotations +import operator import typing as t import sqlalchemy @@ -15,6 +16,7 @@ from quart_sqlalchemy.types import ColumnExpr from quart_sqlalchemy.types import EntityIdT from quart_sqlalchemy.types import EntityT +from quart_sqlalchemy.types import Operator from quart_sqlalchemy.types import ORMOption from quart_sqlalchemy.types import Selectable from quart_sqlalchemy.types import SessionT @@ -25,8 +27,8 @@ class SQLAlchemyRepository( TableMetadataMixin, - AbstractRepository[EntityT, EntityIdT], - t.Generic[EntityT, EntityIdT], + AbstractRepository[EntityT, EntityIdT, SessionT], + t.Generic[EntityT, EntityIdT, SessionT], ): """A repository that uses SQLAlchemy to persist data. @@ -53,7 +55,7 @@ class SQLAlchemyRepository( session: sa.orm.Session builder: StatementBuilder - def __init__(self, session: sa.orm.Session, **kwargs): + def __init__(self, model: sa.orm.Session, **kwargs): super().__init__(**kwargs) self.session = session self.builder = StatementBuilder(None) @@ -125,6 +127,46 @@ def get( return self.session.scalars(statement, execution_options=execution_options).one_or_none() + def get_by_field( + self, + field: t.Union[ColumnExpr, str], + value: t.Any, + op: Operator = operator.eq, + order_by: t.Sequence[t.Union[ColumnExpr, str]] = (), + options: t.Sequence[ORMOption] = (), + execution_options: t.Optional[t.Dict[str, t.Any]] = None, + offset: t.Optional[int] = None, + limit: t.Optional[int] = None, + distinct: bool = False, + for_update: bool = False, + include_inactive: bool = False, + ) -> sa.ScalarResult[EntityT]: + """Select models where field is equal to value.""" + selectables = (self.model,) # type: ignore + + execution_options = execution_options or {} + if include_inactive: + execution_options.setdefault("include_inactive", include_inactive) + + if isinstance(field, str): + field = getattr(self.model, field) + + conditions = [t.cast(ColumnExpr, op(field, value))] + + statement = self.builder.complex_select( + selectables, + conditions=conditions, + order_by=order_by, + options=options, + execution_options=execution_options, + offset=offset, + limit=limit, + distinct=distinct, + for_update=for_update, + ) + + return self.session.scalars(statement) + def select( self, selectables: t.Sequence[Selectable] = (), diff --git a/examples/usrsrv/component/__init__.py b/examples/usrsrv/component/__init__.py new file mode 100644 index 0000000..21b27e9 --- /dev/null +++ b/examples/usrsrv/component/__init__.py @@ -0,0 +1,23 @@ +from . import commands +from . import events +from . import exceptions +from .app import handler +from .entity import EntityID +from .service import CommandHandler +from .service import Listener + + +handle = handler.handle +register = handler.register +unregister = handler.unregister + +__all__ = [ + "commands", + "events", + "exceptions", + "EntityID", + "CommandHandler", + "handle", + "register", + "unregister", +] diff --git a/examples/usrsrv/component/app.py b/examples/usrsrv/component/app.py new file mode 100644 index 0000000..b18baee --- /dev/null +++ b/examples/usrsrv/component/app.py @@ -0,0 +1,10 @@ +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +from .repository import ORMRepository +from .service import CommandHandler + + +some_engine = create_engine("sqlite:///") +Session = sessionmaker(bind=some_engine) +handler = CommandHandler(ORMRepository(Session())) diff --git a/examples/usrsrv/component/command.py b/examples/usrsrv/component/command.py new file mode 100644 index 0000000..47f59b0 --- /dev/null +++ b/examples/usrsrv/component/command.py @@ -0,0 +1,43 @@ +""" +Commands +======== +A command is always DTO and as specific, as it can be from a domain perspective. I aim to create +separate classes for commands so I can just dispatch handlers by command class. + +```python +@dataclass +class Create(Command): + command_id: CommandID = field(default_factory=uuid1) + timestamp: datetime = field(default_factory=datetime.utcnow) +``` +""" + +from abc import ABC +from dataclasses import dataclass, field +from datetime import datetime +from typing import Text +from uuid import UUID, uuid1 + +from .entity import EntityID + +CommandID = UUID + + +class Command(ABC): + entity_id: EntityID + command_id: CommandID + timestamp: datetime + + +@dataclass +class Create(Command): + command_id: CommandID = field(default_factory=uuid1) + timestamp: datetime = field(default_factory=datetime.utcnow) + + +@dataclass +class UpdateValue(Command): + entity_id: EntityID + value: Text + command_id: CommandID = field(default_factory=uuid1) + timestamp: datetime = field(default_factory=datetime.utcnow) diff --git a/examples/usrsrv/component/entity.py b/examples/usrsrv/component/entity.py new file mode 100644 index 0000000..a910ad2 --- /dev/null +++ b/examples/usrsrv/component/entity.py @@ -0,0 +1,39 @@ +from typing import NewType +from typing import Optional +from typing import Text +from uuid import UUID +from uuid import uuid1 + + +EntityID = NewType("EntityID", UUID) + + +class EntityDTO: + id: EntityID + value: Optional[Text] + + +class Entity: + id: EntityID + dto: EntityDTO + + class Event: + pass + + class Updated(Event): + pass + + def __init__(self, dto: EntityDTO) -> None: + self.id = dto.id + self.dto = dto + + @classmethod + def create(cls) -> "Entity": + dto = EntityDTO() + dto.id = EntityID(uuid1()) + dto.value = None + return Entity(dto) + + def update(self, value: Text) -> Updated: + self.dto.value = value + return self.Updated() diff --git a/examples/usrsrv/component/event.py b/examples/usrsrv/component/event.py new file mode 100644 index 0000000..4dd2895 --- /dev/null +++ b/examples/usrsrv/component/event.py @@ -0,0 +1,45 @@ +from dataclasses import dataclass +from dataclasses import field +from datetime import datetime +from functools import singledispatch +from uuid import UUID +from uuid import uuid1 + +from .command import Command +from .command import CommandID +from .entity import Entity +from .entity import EntityID + + +EventID = UUID + + +class Event: + command_id: CommandID + event_id: EventID = field(default_factory=uuid1) + timestamp: datetime = field(default_factory=datetime.utcnow) + + +@dataclass +class Created(Event): + command_id: CommandID + uow_id: EntityID + event_id: EventID = field(default_factory=uuid1) + timestamp: datetime = field(default_factory=datetime.utcnow) + + +@dataclass +class Updated(Event): + command_id: CommandID + event_id: EventID = field(default_factory=uuid1) + timestamp: datetime = field(default_factory=datetime.utcnow) + + +@singledispatch +def app_event(event: Entity.Event, command: Command) -> Event: + raise NotImplementedError + + +@app_event.register(Entity.Updated) +def _(event: Entity.Updated, command: Command) -> Updated: + return Updated(command.command_id) diff --git a/examples/usrsrv/component/exception.py b/examples/usrsrv/component/exception.py new file mode 100644 index 0000000..f88096f --- /dev/null +++ b/examples/usrsrv/component/exception.py @@ -0,0 +1,2 @@ +class NotFound(Exception): + pass diff --git a/examples/usrsrv/component/migrations/2020-04-15_ddd_component_unitofwork_18fd763a02a4.py b/examples/usrsrv/component/migrations/2020-04-15_ddd_component_unitofwork_18fd763a02a4.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/usrsrv/component/repository.py b/examples/usrsrv/component/repository.py new file mode 100644 index 0000000..44dffcd --- /dev/null +++ b/examples/usrsrv/component/repository.py @@ -0,0 +1,63 @@ +from sqlalchemy import Column +from sqlalchemy import Integer +from sqlalchemy import MetaData +from sqlalchemy import select +from sqlalchemy import String +from sqlalchemy import Table +from sqlalchemy.orm import registry +from sqlalchemy.orm import Session + +from . import EntityID +from .entity import Entity +from .entity import EntityDTO +from .exception import NotFound +from .service import Repository + + +metadata = MetaData() +mapper_registry = registry(metadata=metadata) + + +entities_table = Table( + "entities", + metadata, + Column("id", Integer, primary_key=True, autoincrement=True), + Column("uuid", String, unique=True, index=True), + Column("value", String, nullable=True), +) + +# EntityMapper = mapper( +# EntityDTO, +# entities_table, +# properties={ +# "id": entities_table.c.uuid, +# "value": entities_table.c.value, +# }, +# column_prefix="_db_column_", +# ) + +EntityMapper = mapper_registry.map_imperatively( + EntityDTO, + entities_table, + properties={ + "id": entities_table.c.uuid, + "value": entities_table.c.value, + }, + column_prefix="_db_column_", +) + + +class ORMRepository(Repository): + def __init__(self, session: Session): + self._session = session + self._query = select(EntityMapper) + + def get(self, entity_id: EntityID) -> Entity: + dto = self._session.scalars(self._query.filter_by(uuid=entity_id)).one_or_none() + if not dto: + raise NotFound(entity_id) + return Entity(dto) + + def save(self, entity: Entity) -> None: + self._session.add(entity.dto) + self._session.flush() diff --git a/examples/usrsrv/component/service.py b/examples/usrsrv/component/service.py new file mode 100644 index 0000000..9846736 --- /dev/null +++ b/examples/usrsrv/component/service.py @@ -0,0 +1,68 @@ +from abc import ABC +from abc import abstractmethod +from functools import singledispatch +from typing import Callable +from typing import List +from typing import Optional + +from .command import Command +from .command import Create +from .command import UpdateValue +from .entity import Entity +from .entity import EntityID +from .event import app_event +from .event import Created +from .event import Event + + +Listener = Callable[[Event], None] + + +class Repository(ABC): + @abstractmethod + def get(self, entity_id: EntityID) -> Entity: + raise NotImplementedError + + @abstractmethod + def save(self, entity: Entity) -> None: + raise NotImplementedError + + +class CommandHandler: + def __init__(self, repository: Repository) -> None: + self._repository = repository + self._listeners: List[Listener] = [] + super().__init__() + + def register(self, listener: Listener) -> None: + if listener not in self._listeners: + self._listeners.append(listener) + + def unregister(self, listener: Listener) -> None: + if listener in self._listeners: + self._listeners.remove(listener) + + @singledispatch + def handle(self, command: Command) -> Optional[Event]: + entity: Entity = self._repository.get(command.entity_id) + + event: Event = app_event(self._handle(command, entity), command) + for listener in self._listeners: + listener(event) + + self._repository.save(entity) + return event + + @handle.register(Create) + def create(self, command: Create) -> Event: + entity = Entity.create() + self._repository.save(entity) + return Created(command.command_id, entity.id) + + @singledispatch + def _handle(self, c: Command, u: Entity) -> Entity.Event: + raise NotImplementedError + + @_handle.register(UpdateValue) + def _(self, command: UpdateValue, entity: Entity) -> Entity.Event: + return entity.update(command.value) diff --git a/pyproject.toml b/pyproject.toml index eb2e1be..17e8c4b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,8 @@ dependencies = [ "pydantic", "tenacity", "sqlapagination", - "exceptiongroup" + "exceptiongroup", + "python-ulid" ] requires-python = ">=3.7" readme = "README.rst" @@ -31,7 +32,7 @@ build-backend = "setuptools.build_meta" [project.optional-dependencies] sim = [ - "quart-schema", "hashids" + "quart-schema", "hashids", "web3", "dependency-injector", ] tests = [ "pytest", @@ -121,7 +122,7 @@ ignore_missing_imports = true [tool.pylint.messages_control] max-line-length = 100 -disable = ["missing-docstring", "protected-access"] +disable = ["invalid-name", "missing-docstring", "protected-access"] [tool.flakeheaven] baseline = ".flakeheaven_baseline" diff --git a/setup.cfg b/setup.cfg index ee9c148..0c8f6eb 100644 --- a/setup.cfg +++ b/setup.cfg @@ -27,6 +27,7 @@ ignore = WPS463 allowed-domain-names = + db value val vals diff --git a/src/quart_sqlalchemy/__init__.py b/src/quart_sqlalchemy/__init__.py index b85cd50..28059e4 100644 --- a/src/quart_sqlalchemy/__init__.py +++ b/src/quart_sqlalchemy/__init__.py @@ -1,5 +1,5 @@ __version__ = "3.0.2" - +from . import util from .bind import AsyncBind from .bind import Bind from .bind import BindContext diff --git a/src/quart_sqlalchemy/bind.py b/src/quart_sqlalchemy/bind.py index 8a69eaa..dce08c2 100644 --- a/src/quart_sqlalchemy/bind.py +++ b/src/quart_sqlalchemy/bind.py @@ -1,8 +1,12 @@ from __future__ import annotations import os +import threading import typing as t +from contextlib import asynccontextmanager from contextlib import contextmanager +from contextlib import ExitStack +from weakref import WeakValueDictionary import sqlalchemy import sqlalchemy.event @@ -11,6 +15,7 @@ import sqlalchemy.ext.asyncio import sqlalchemy.orm import sqlalchemy.util +import typing_extensions as tx from . import signals from .config import BindConfig @@ -21,8 +26,16 @@ sa = sqlalchemy +SqlAMode = tx.Literal["orm", "core"] + + +class BindNotInitialized(RuntimeError): + """ "Bind not initialized yet.""" + class BindBase: + name: t.Optional[str] + url: sa.URL config: BindConfig metadata: sa.MetaData engine: sa.Engine @@ -30,66 +43,157 @@ class BindBase: def __init__( self, - config: BindConfig, - metadata: sa.MetaData, + name: t.Optional[str] = None, + url: t.Union[sa.URL, str] = "sqlite://", + config: t.Optional[BindConfig] = None, + metadata: t.Optional[sa.MetaData] = None, ): - self.config = config - self.metadata = metadata - - @property - def url(self) -> str: - if not hasattr(self, "engine"): - raise RuntimeError("Database not initialized yet. Call initialize() first.") - return str(self.engine.url) + self.name = name + self.url = sa.make_url(url) + self.config = config or BindConfig.default() + self.metadata = metadata or sa.MetaData() @property def is_async(self) -> bool: - if not hasattr(self, "engine"): - raise RuntimeError("Database not initialized yet. Call initialize() first.") - return self.engine.url.get_dialect().is_async + return self.url.get_dialect().is_async @property - def is_read_only(self): + def is_read_only(self) -> bool: return self.config.read_only + def __repr__(self) -> str: + parts = [type(self).__name__] + if self.name: + parts.append(self.name) + if self.url: + parts.append(str(self.url)) + if self.is_read_only: + parts.append("[read-only]") + + return f"<{' '.join(parts)}>" + class BindContext(BindBase): pass class Bind(BindBase): + lock: threading.Lock + _instances: WeakValueDictionary = WeakValueDictionary() + def __init__( self, - config: BindConfig, - metadata: sa.MetaData, + name: t.Optional[str] = None, + url: t.Union[sa.URL, str] = "sqlite://", + config: t.Optional[BindConfig] = None, + metadata: t.Optional[sa.MetaData] = None, initialize: bool = True, + track_instance: bool = False, ): - self.config = config - self.metadata = metadata + super().__init__(name, url, config, metadata) + self._initialization_lock = threading.Lock() + + if track_instance: + self._track_instance(name) if initialize: self.initialize() - def initialize(self): - if hasattr(self, "engine"): - self.engine.dispose() + self._session_stack = [] - self.engine = self.create_engine( - self.config.engine.dict(exclude_unset=True, exclude_none=True), - prefix="", - ) - self.Session = self.create_session_factory( - self.config.session.dict(exclude_unset=True, exclude_none=True), - ) + def initialize(self) -> tx.Self: + with self._initialization_lock: + if hasattr(self, "engine"): + self.engine.dispose() + + engine_config = self.config.engine.dict(exclude_unset=True, exclude_none=True) + engine_config.setdefault("url", self.url) + self.engine = self.create_engine(engine_config, prefix="") + + session_options = self.config.session.dict(exclude_unset=True, exclude_none=True) + self.Session = self.create_session_factory(session_options) return self + def _track_instance(self, name): + if name is None: + return + + if name in Bind._instances: + raise ValueError("Bind instance `{name}` already exists, use another name.") + else: + Bind._instances[name] = self + + @classmethod + def get_instance(cls, name: str = "default") -> Bind: + """Get the singleton instance having `name`. + + This enables some really cool patterns similar to how logging allows getting an already + initialized logger from anywhere without importing it directly. Features like this are + most useful when working in web frameworks like flask and quart that are more prone to + circular dependency issues. + + Example: + app/db.py: + from quart_sqlalchemy import Bind + + default = Bind("default", url="sqlite://") + + with default.Session() as session: + with session.begin(): + session.add(User()) + + + app/views/v1/user/login.py + from quart_sqlalchemy import Bind + + # get the same `default` bind already instantiated in app/db.py + default = Bind.get_instance("default") + + with default.Session() as session: + with session.begin(): + session.add(User()) + ... + """ + try: + return Bind._instances[name]() + except KeyError as err: + raise ValueError(f"Bind instance `{name}` does not exist.") from err + + @t.overload + @contextmanager + def transaction(self, mode: SqlAMode = "orm") -> t.Generator[sa.orm.Session, None, None]: + ... + + @t.overload + @contextmanager + def transaction(self, mode: SqlAMode = "core") -> t.Generator[sa.Connection, None, None]: + ... + + @contextmanager + def transaction( + self, mode: SqlAMode = "orm" + ) -> t.Generator[t.Union[sa.orm.Session, sa.Connection], None, None]: + if mode == "orm": + with self.Session() as session: + with session.begin(): + yield session + elif mode == "core": + with self.engine.connect() as connection: + with connection.begin(): + yield connection + else: + raise ValueError(f"Invalid transaction mode `{mode}`") + + def test_transaction(self, savepoint: bool = False) -> TestTransaction: + return TestTransaction(self, savepoint=savepoint) + @contextmanager def context( self, engine_execution_options: t.Optional[t.Dict[str, t.Any]] = None, session_execution__options: t.Optional[t.Dict[str, t.Any]] = None, ) -> t.Generator[BindContext, None, None]: - context = BindContext(self.config, self.metadata) + context = BindContext(f"{self.name}-context", self.url, self.config, self.metadata) context.engine = self.engine.execution_options(**engine_execution_options or {}) context.Session = self.create_session_factory(session_execution__options or {}) context.Session.configure(bind=context.engine) @@ -110,7 +214,7 @@ def context( ) def create_session_factory( - self, options: dict[str, t.Any] + self, options: t.Dict[str, t.Any] ) -> sa.orm.sessionmaker[sa.orm.Session]: signals.before_bind_session_factory_created.send(self, options=options) session_factory = sa.orm.sessionmaker(bind=self.engine, **options) @@ -125,9 +229,6 @@ def create_engine(self, config: t.Dict[str, t.Any], prefix: str = "") -> sa.Engi signals.after_bind_engine_created.send(self, config=config, prefix=prefix, engine=engine) return engine - def test_transaction(self, savepoint: bool = False): - return TestTransaction(self, savepoint=savepoint) - def _call_metadata(self, method: str): with self.engine.connect() as conn: with conn.begin(): @@ -142,14 +243,27 @@ def drop_all(self): def reflect(self): return self._call_metadata("reflect") - def __repr__(self) -> str: - return f"<{type(self).__name__} {self.engine.url}>" - class AsyncBind(Bind): engine: sa.ext.asyncio.AsyncEngine Session: sa.ext.asyncio.async_sessionmaker + @asynccontextmanager + async def transaction(self, mode: SqlAMode = "orm"): + if mode == "orm": + async with self.Session() as session: + async with session.begin(): + yield session + elif mode == "core": + async with self.engine.connect() as connection: + async with connection.begin(): + yield connection + else: + raise ValueError(f"Invalid transaction mode `{mode}`") + + def test_transaction(self, savepoint: bool = False): + return AsyncTestTransaction(self, savepoint=savepoint) + def create_session_factory( self, options: dict[str, t.Any] ) -> sa.ext.asyncio.async_sessionmaker[sa.ext.asyncio.AsyncSession]: @@ -182,9 +296,6 @@ def create_engine( signals.after_bind_engine_created.send(self, config=config, prefix=prefix, engine=engine) return engine - def test_transaction(self, savepoint: bool = False): - return AsyncTestTransaction(self, savepoint=savepoint) - async def _call_metadata(self, method: str): async with self.engine.connect() as conn: async with conn.begin(): diff --git a/src/quart_sqlalchemy/config.py b/src/quart_sqlalchemy/config.py index 5caa0ec..0190160 100644 --- a/src/quart_sqlalchemy/config.py +++ b/src/quart_sqlalchemy/config.py @@ -1,6 +1,5 @@ from __future__ import annotations -import json import os import types import typing as t @@ -11,6 +10,7 @@ import sqlalchemy.ext import sqlalchemy.ext.asyncio import sqlalchemy.orm +import sqlalchemy.sql.sqltypes import sqlalchemy.util import typing_extensions as tx from pydantic import BaseModel @@ -22,6 +22,8 @@ from .model import Base from .types import BoundParamStyle from .types import DMLStrategy +from .types import Empty +from .types import EmptyType from .types import SessionBind from .types import SessionBindKey from .types import SynchronizeSession @@ -63,9 +65,9 @@ class ConfigBase(BaseModel): class Config: arbitrary_types_allowed = True - @classmethod - def default(cls): - return cls() + @root_validator + def scrub_empty(cls, values): + return {key: val for key, val in values.items() if val not in [Empty, {}]} class CoreExecutionOptions(ConfigBase): @@ -73,15 +75,15 @@ class CoreExecutionOptions(ConfigBase): https://docs.sqlalchemy.org/en/20/core/connections.html#sqlalchemy.engine.Connection.execution_options """ - isolation_level: t.Optional[TransactionIsolationLevel] = None - compiled_cache: t.Optional[t.Dict[t.Any, Compiled]] = Field(default_factory=dict) - logging_token: t.Optional[str] = None - no_parameters: bool = False - stream_results: bool = False - max_row_buffer: int = 1000 - yield_per: t.Optional[int] = None - insertmanyvalues_page_size: int = 1000 - schema_translate_map: t.Optional[t.Dict[str, str]] = None + isolation_level: t.Union[TransactionIsolationLevel, EmptyType] = Empty + compiled_cache: t.Union[t.Dict[t.Any, Compiled], None, EmptyType] = Empty + logging_token: t.Union[str, None, EmptyType] = Empty + no_parameters: t.Union[bool, EmptyType] = Empty + stream_results: t.Union[bool, EmptyType] = Empty + max_row_buffer: t.Union[int, EmptyType] = Empty + yield_per: t.Union[int, None, EmptyType] = Empty + insertmanyvalues_page_size: t.Union[int, EmptyType] = Empty + schema_translate_map: t.Union[t.Dict[str, str], None, EmptyType] = Empty class ORMExecutionOptions(ConfigBase): @@ -89,14 +91,21 @@ class ORMExecutionOptions(ConfigBase): https://docs.sqlalchemy.org/en/20/orm/queryguide/api.html#orm-queryguide-execution-options """ - isolation_level: t.Optional[TransactionIsolationLevel] = None - stream_results: bool = False - yield_per: t.Optional[int] = None - populate_existing: bool = False - autoflush: bool = True - identity_token: t.Optional[str] = None - synchronize_session: SynchronizeSession = "auto" - dml_strategy: DMLStrategy = "auto" + isolation_level: t.Union[TransactionIsolationLevel, EmptyType] = Empty + stream_results: t.Union[bool, EmptyType] = Empty + yield_per: t.Union[int, None, EmptyType] = Empty + populate_existing: t.Union[bool, EmptyType] = Empty + autoflush: t.Union[bool, EmptyType] = Empty + identity_token: t.Union[str, None, EmptyType] = Empty + synchronize_session: t.Union[SynchronizeSession, None, EmptyType] = Empty + dml_strategy: t.Union[DMLStrategy, None, EmptyType] = Empty + + +# connect_args: +# mysql: +# connect_timeout: +# postgres: +# connect_timeout: class EngineConfig(ConfigBase): @@ -104,42 +113,55 @@ class EngineConfig(ConfigBase): https://docs.sqlalchemy.org/en/20/core/engines.html#sqlalchemy.create_engine """ - url: t.Union[sa.URL, str] = "sqlite://" - echo: bool = False - echo_pool: bool = False - connect_args: t.Dict[str, t.Any] = Field(default_factory=dict) + url: t.Union[sa.URL, str, EmptyType] = Empty + echo: t.Union[bool, EmptyType] = Empty + echo_pool: t.Union[bool, EmptyType] = Empty + connect_args: t.Union[t.Dict[str, t.Any], EmptyType] = Empty execution_options: CoreExecutionOptions = Field(default_factory=CoreExecutionOptions) - enable_from_linting: bool = True - hide_parameters: bool = False - insertmanyvalues_page_size: int = 1000 - isolation_level: t.Optional[TransactionIsolationLevel] = None - json_deserializer: t.Callable[[str], t.Any] = json.loads - json_serializer: t.Callable[[t.Any], str] = json.dumps - label_length: t.Optional[int] = None - logging_name: t.Optional[str] = None - max_identifier_length: t.Optional[int] = None - max_overflow: int = 10 - module: t.Optional[types.ModuleType] = None - paramstyle: t.Optional[BoundParamStyle] = None - pool: t.Optional[sa.Pool] = None - poolclass: t.Optional[t.Type[sa.Pool]] = None - pool_logging_name: t.Optional[str] = None - pool_pre_ping: bool = False - pool_size: int = 5 - pool_recycle: int = -1 - pool_reset_on_return: t.Optional[tx.Literal["values", "rollback"]] = None - pool_timeout: int = 40 - pool_use_lifo: bool = False - plugins: t.Sequence[str] = Field(default_factory=list) - query_cache_size: int = 500 - user_insertmanyvalues: bool = True + enable_from_linting: t.Union[bool, EmptyType] = Empty + hide_parameters: t.Union[bool, EmptyType] = Empty + insertmanyvalues_page_size: t.Union[int, EmptyType] = Empty + isolation_level: t.Union[TransactionIsolationLevel, EmptyType] = Empty + json_deserializer: t.Union[t.Callable[[str], t.Any], EmptyType] = Empty + json_serializer: t.Union[t.Callable[[t.Any], str], EmptyType] = Empty + label_length: t.Union[int, None, EmptyType] = Empty + logging_name: t.Union[str, None, EmptyType] = Empty + max_identifier_length: t.Union[int, None, EmptyType] = Empty + max_overflow: t.Union[int, EmptyType] = Empty + module: t.Union[types.ModuleType, None, EmptyType] = Empty + paramstyle: t.Union[BoundParamStyle, None, EmptyType] = Empty + pool: t.Union[sa.Pool, None, EmptyType] = Empty + poolclass: t.Union[t.Type[sa.Pool], None, EmptyType] = Empty + pool_logging_name: t.Union[str, None, EmptyType] = Empty + pool_pre_ping: t.Union[bool, EmptyType] = Empty + pool_size: t.Union[int, EmptyType] = Empty + pool_recycle: t.Union[int, EmptyType] = Empty + pool_reset_on_return: t.Union[tx.Literal["values", "rollback"], None, EmptyType] = Empty + pool_timeout: t.Union[int, EmptyType] = Empty + pool_use_lifo: t.Union[bool, EmptyType] = Empty + plugins: t.Union[t.Sequence[str], EmptyType] = Empty + query_cache_size: t.Union[int, EmptyType] = Empty + user_insertmanyvalues: t.Union[bool, EmptyType] = Empty - @classmethod - def default(cls): - return cls(url="sqlite://") + @root_validator + def scrub_execution_options(cls, values): + if "execution_options" in values: + execute_options = values["execution_options"].dict(exclude_defaults=True) + if execute_options: + values["execution_options"] = execute_options + return values + + @root_validator + def set_defaults(cls, values): + values.setdefault("url", "sqlite://") + return values @root_validator def apply_driver_defaults(cls, values): + # values["execution_options"] = values["execution_options"].dict(exclude_defaults=True) + # values = {key: val for key, val in values.items() if val not in [Empty, {}]} + # values.setdefault("url", "sqlite://") + url = sa.make_url(values["url"]) driver = url.drivername @@ -177,19 +199,31 @@ def apply_driver_defaults(cls, values): return values +class AsyncEngineConfig(EngineConfig): + @root_validator + def set_defaults(cls, values): + values.setdefault("url", "sqlite+aiosqlite://") + return values + + class SessionOptions(ConfigBase): """ https://docs.sqlalchemy.org/en/20/orm/session_api.html#sqlalchemy.orm.Session """ - autoflush: bool = True - autobegin: bool = True - expire_on_commit: bool = False - bind: t.Optional[SessionBind] = None - binds: t.Optional[t.Dict[SessionBindKey, SessionBind]] = None - twophase: bool = False - info: t.Optional[t.Dict[t.Any, t.Any]] = None - join_transaction_mode: JoinTransactionMode = "conditional_savepoint" + autoflush: t.Union[bool, EmptyType] = Empty + autobegin: t.Union[bool, EmptyType] = Empty + expire_on_commit: t.Union[bool, EmptyType] = Empty + bind: t.Union[SessionBind, None, EmptyType] = Empty + binds: t.Union[t.Dict[SessionBindKey, SessionBind], None, EmptyType] = Empty + twophase: t.Union[bool, EmptyType] = Empty + info: t.Union[t.Dict[t.Any, t.Any], None, EmptyType] = Empty + join_transaction_mode: t.Union[JoinTransactionMode, EmptyType] = Empty + + @root_validator + def set_defaults(cls, values): + values.setdefault("expire_on_commit", False) + return values class SessionmakerOptions(SessionOptions): @@ -197,7 +231,7 @@ class SessionmakerOptions(SessionOptions): https://docs.sqlalchemy.org/en/20/orm/session_api.html#sqlalchemy.orm.sessionmaker """ - class_: t.Type[sa.orm.Session] = sa.orm.Session + class_: t.Union[t.Type[sa.orm.Session], EmptyType] = Empty class AsyncSessionOptions(SessionOptions): @@ -205,7 +239,7 @@ class AsyncSessionOptions(SessionOptions): https://docs.sqlalchemy.org/en/20/orm/extensions/asyncio.html#sqlalchemy.ext.asyncio.AsyncSession """ - sync_session_class: t.Type[sa.orm.Session] = sa.orm.Session + sync_session_class: t.Union[t.Type[sa.orm.Session], EmptyType] = Empty class AsyncSessionmakerOptions(AsyncSessionOptions): @@ -213,13 +247,14 @@ class AsyncSessionmakerOptions(AsyncSessionOptions): https://docs.sqlalchemy.org/en/20/orm/extensions/asyncio.html#sqlalchemy.ext.asyncio.async_sessionmaker """ - class_: t.Type[sa.ext.asyncio.AsyncSession] = sa.ext.asyncio.AsyncSession + class_: t.Union[t.Type[sa.ext.asyncio.AsyncSession], EmptyType] = Empty class BindConfig(ConfigBase): read_only: bool = False - session: SessionmakerOptions = Field(default_factory=SessionmakerOptions.default) - engine: EngineConfig = Field(default_factory=EngineConfig.default) + session: SessionmakerOptions = Field(default_factory=SessionmakerOptions) + engine: EngineConfig = Field(default_factory=EngineConfig) + track_instance: bool = False @root_validator def validate_dialect(cls, values): @@ -227,30 +262,24 @@ def validate_dialect(cls, values): class AsyncBindConfig(BindConfig): - session: AsyncSessionmakerOptions = Field(default_factory=AsyncSessionmakerOptions.default) + session: AsyncSessionmakerOptions = Field(default_factory=AsyncSessionmakerOptions) + engine: AsyncEngineConfig = Field(default_factory=AsyncEngineConfig) @root_validator def validate_dialect(cls, values): return validate_dialect(cls, values, "async") -def default(): - dict(default=dict()) - - class SQLAlchemyConfig(ConfigBase): - class Meta: - web_config_field_map = { - "SQLALCHEMY_MODEL_CLASS": "model_class", - "SQLALCHEMY_BINDS": "binds", - } + base_class: t.Type[t.Any] = Base + binds: t.Dict[str, t.Union[BindConfig, AsyncBindConfig]] = Field(default_factory=dict) - model_class: t.Type[t.Any] = Base - binds: t.Dict[str, t.Union[AsyncBindConfig, BindConfig]] = Field( - default_factory=lambda: dict(default=BindConfig()) - ) + @root_validator + def set_default_bind(cls, values): + values.setdefault("binds", dict(default=BindConfig())) + return values @classmethod - def from_framework(cls, values: t.Dict[str, t.Any]): - key_map = cls.Meta.web_config_field_map - return cls(**{key_map.get(key, key): val for key, val in values.items()}) + def from_framework(cls, framework_config): + config = framework_config.get_namespace("SQLALCHEMY_") + return cls.parse_obj(config or {}) diff --git a/src/quart_sqlalchemy/framework/cli.py b/src/quart_sqlalchemy/framework/cli.py index 13ac0af..6f48455 100644 --- a/src/quart_sqlalchemy/framework/cli.py +++ b/src/quart_sqlalchemy/framework/cli.py @@ -1,28 +1,83 @@ import json import sys +import typing as t import urllib.parse import click -from quart import current_app from quart.cli import AppGroup +from quart.cli import pass_script_info +from quart.cli import ScriptInfo + +from quart_sqlalchemy import signals + + +if t.TYPE_CHECKING: + from quart_sqlalchemy.framework import QuartSQLAlchemy db_cli = AppGroup("db") +fixtures_cli = AppGroup("fixtures") -@db_cli.command("info", with_appcontext=True) +@db_cli.command("info") +@pass_script_info @click.option("--uri-only", is_flag=True, default=False, help="Only output the connection uri") -def db_info(uri_only=False): - db = current_app.extensions["sqlalchemy"].db - uri = urllib.parse.unquote(str(db.engine.url)) - db_info = dict(db.engine.url._asdict()) +def db_info(info: ScriptInfo, uri_only=False): + app = info.load_app() + db: "QuartSQLAlchemy" = app.extensions["sqlalchemy"] + uri = urllib.parse.unquote(str(db.bind.url)) + info = dict(db.bind.url._asdict()) if uri_only: click.echo(uri) sys.exit(0) click.echo("Database Connection Info") - click.echo(json.dumps(db_info, indent=2)) + click.echo(json.dumps(info, indent=2)) click.echo("\n") click.echo("Connection URI") click.echo(uri) + + +@db_cli.command("create") +@pass_script_info +def create(info: ScriptInfo) -> None: + app = info.load_app() + db: "QuartSQLAlchemy" = app.extensions["sqlalchemy"] + db.create_all() + + click.echo(f"Initialized database schema for {db}") + + +@db_cli.command("drop") +@pass_script_info +def drop(info: ScriptInfo) -> None: + app = info.load_app() + db: "QuartSQLAlchemy" = app.extensions["sqlalchemy"] + db.drop_all() + + click.echo(f"Dropped database schema for {db}") + + +@db_cli.command("recreate") +@pass_script_info +def recreate(info: ScriptInfo) -> None: + app = info.load_app() + db: "QuartSQLAlchemy" = app.extensions["sqlalchemy"] + db.drop_all() + db.create_all() + + click.echo(f"Recreated database schema for {db}") + + +@fixtures_cli.command("load") +@pass_script_info +def load(info: ScriptInfo) -> None: + app = info.load_app() + db: "QuartSQLAlchemy" = app.extensions["sqlalchemy"] + signals.framework_extension_load_fixtures.send(sender=db, app=app) + + click.echo(f"Loaded database fixtures for {db}") + + +db_cli.add_command(fixtures_cli) diff --git a/src/quart_sqlalchemy/framework/extension.py b/src/quart_sqlalchemy/framework/extension.py index 4d29d74..70f0737 100644 --- a/src/quart_sqlalchemy/framework/extension.py +++ b/src/quart_sqlalchemy/framework/extension.py @@ -11,10 +11,12 @@ class QuartSQLAlchemy(SQLAlchemy): def __init__( self, - config: SQLAlchemyConfig, + config: t.Optional[SQLAlchemyConfig] = None, app: t.Optional[Quart] = None, ): - super().__init__(config) + initialize = False if config is None else True + super().__init__(config, initialize=initialize) + if app is not None: self.init_app(app) @@ -24,15 +26,21 @@ def init_app(self, app: Quart) -> None: f"A {type(self).__name__} instance has already been registered on this app" ) + if self.config is None: + self.config = SQLAlchemyConfig.from_framework(app.config) + self.initialize() + signals.before_framework_extension_initialization.send(self, app=app) app.extensions["sqlalchemy"] = self @app.shell_context_processor def export_sqlalchemy_objects(): + nonlocal self + return dict( db=self, - **{m.class_.__name__: m.class_ for m in self.Model._sa_registry.mappers}, + **{m.class_.__name__: m.class_ for m in self.Base.registry.mappers}, ) app.cli.add_command(db_cli) diff --git a/src/quart_sqlalchemy/model/__init__.py b/src/quart_sqlalchemy/model/__init__.py index cf7a4ea..fce0fbf 100644 --- a/src/quart_sqlalchemy/model/__init__.py +++ b/src/quart_sqlalchemy/model/__init__.py @@ -1,7 +1,9 @@ -from .columns import CreatedTimestamp +from .columns import Created +from .columns import IntPK from .columns import Json -from .columns import PrimaryKey -from .columns import UpdatedTimestamp +from .columns import ULID +from .columns import Updated +from .columns import UUID from .custom_types import PydanticType from .custom_types import TZDateTime from .mixins import DynamicArgsMixin @@ -15,3 +17,6 @@ from .mixins import TimestampMixin from .mixins import VersionMixin from .model import Base +from .model import BaseMixins +from .model import default_metadata_naming_convention +from .model import default_type_annotation_map diff --git a/src/quart_sqlalchemy/model/columns.py b/src/quart_sqlalchemy/model/columns.py index c00d9eb..d1bfd94 100644 --- a/src/quart_sqlalchemy/model/columns.py +++ b/src/quart_sqlalchemy/model/columns.py @@ -2,6 +2,8 @@ import typing as t from datetime import datetime +from uuid import UUID +from uuid import uuid4 import sqlalchemy import sqlalchemy.event @@ -12,21 +14,24 @@ import sqlalchemy.util import sqlalchemy_utils import typing_extensions as tx +from ulid import ULID sa = sqlalchemy sau = sqlalchemy_utils +IntPK = tx.Annotated[int, sa.orm.mapped_column(primary_key=True, autoincrement=True)] +UUID = tx.Annotated[UUID, sa.orm.mapped_column(default=uuid4)] +ULID = tx.Annotated[ULID, sa.orm.mapped_column(default=ULID)] -PrimaryKey = tx.Annotated[int, sa.orm.mapped_column(sa.Identity(), primary_key=True)] -CreatedTimestamp = tx.Annotated[ +Created = tx.Annotated[ datetime, sa.orm.mapped_column( default=sa.func.now(), server_default=sa.FetchedValue(), ), ] -UpdatedTimestamp = tx.Annotated[ +Updated = tx.Annotated[ datetime, sa.orm.mapped_column( default=sa.func.now(), @@ -35,7 +40,8 @@ server_onupdate=sa.FetchedValue(), ), ] + Json = tx.Annotated[ t.Dict[t.Any, t.Any], - sa.orm.mapped_column(sau.JSONType, default_factory=dict), + sa.orm.mapped_column(sau.JSONType, default=dict), ] diff --git a/src/quart_sqlalchemy/model/mixins.py b/src/quart_sqlalchemy/model/mixins.py index 3dc3834..d08f446 100644 --- a/src/quart_sqlalchemy/model/mixins.py +++ b/src/quart_sqlalchemy/model/mixins.py @@ -10,6 +10,7 @@ import sqlalchemy.ext.asyncio import sqlalchemy.orm import sqlalchemy.util +import typing_extensions as tx from sqlalchemy.orm import Mapped from ..util import camel_to_snake_case @@ -18,14 +19,37 @@ sa = sqlalchemy +class ORMModel(tx.Protocol): + __table__: sa.Table + + +class SerializingModel(ORMModel): + __table__: sa.Table + + def to_dict( + self: ORMModel, + obj: t.Optional[t.Any] = None, + max_depth: int = 3, + _children_seen: t.Optional[set] = None, + _relations_seen: t.Optional[set] = None, + ) -> t.Dict[str, t.Any]: + ... + + class TableNameMixin: + __abstract__ = True + __table__: sa.Table + @sa.orm.declared_attr.directive - def __tablename__(cls) -> str: + def __tablename__(cls: t.Type[ORMModel]) -> str: return camel_to_snake_case(cls.__name__) class ReprMixin: - def __repr__(self) -> str: + __abstract__ = True + __table__: sa.Table + + def __repr__(self: ORMModel) -> str: state = sa.inspect(self) if state is None: return super().__repr__() @@ -41,7 +65,10 @@ def __repr__(self) -> str: class ComparableMixin: - def __eq__(self, other): + __abstract__ = True + __table__: sa.Table + + def __eq__(self: ORMModel, other: ORMModel) -> bool: if type(self).__name__ != type(other).__name__: return False @@ -55,37 +82,38 @@ def __eq__(self, other): class TotalOrderMixin: - def __lt__(self, other): - if type(self).__name__ != type(other).__name__: - return False + __abstract__ = True + __table__: sa.Table - for key, column in sa.inspect(type(self)).columns.items(): - if column.primary_key: - continue + def __lt__(self: ORMModel, other: ORMModel) -> bool: + if type(self).__name__ != type(other).__name__: + raise NotImplemented - if not (getattr(self, key) == getattr(other, key)): - return False - return True + primary_keys = sa.inspect(type(self)).primary_key + self_keys = [getattr(self, col.name) for col in primary_keys] + other_keys = [getattr(other, col.name) for col in primary_keys] + return self_keys < other_keys class SimpleDictMixin: __abstract__ = True __table__: sa.Table - def to_dict(self): + def to_dict(self) -> t.Dict[str, t.Any]: return {c.name: getattr(self, c.name) for c in self.__table__.columns} class RecursiveDictMixin: __abstract__ = True + __table__: sa.Table def to_dict( - self, + self: tx.Self, obj: t.Optional[t.Any] = None, - max_depth: int = 3, + max_depth: int = 1, _children_seen: t.Optional[set] = None, _relations_seen: t.Optional[set] = None, - ): + ) -> t.Dict[str, t.Any]: """Convert model to python dict, with recursion. Args: @@ -106,11 +134,7 @@ def to_dict( mapper = sa.inspect(obj).mapper columns = [column.key for column in mapper.columns] - get_key_value = ( - lambda c: (c, getattr(obj, c).isoformat()) - if isinstance(getattr(obj, c), datetime) - else (c, getattr(obj, c)) - ) + get_key_value = lambda c: (c, getattr(obj, c)) data = dict(map(get_key_value, columns)) if max_depth > 0: @@ -125,10 +149,12 @@ def to_dict( if relationship_children is not None: if relation.uselist: children = [] - for child in (c for c in relationship_children if c not in _children_seen): - _children_seen.add(child) + for child in ( + c for c in relationship_children if repr(c) not in _children_seen + ): + _children_seen.add(repr(child)) children.append( - self.model_to_dict( + self.to_dict( child, max_depth=max_depth - 1, _children_seen=_children_seen, @@ -137,7 +163,7 @@ def to_dict( ) data[name] = children else: - data[name] = self.model_to_dict( + data[name] = self.to_dict( relationship_children, max_depth=max_depth - 1, _children_seen=_children_seen, @@ -148,6 +174,9 @@ def to_dict( class IdentityMixin: + __abstract__ = True + __table__: sa.Table + id: Mapped[int] = sa.orm.mapped_column(sa.Identity(), primary_key=True, autoincrement=True) @@ -191,21 +220,29 @@ class User(db.Model, SoftDeleteMixin): """ __abstract__ = True + __table__: sa.Table is_active: Mapped[bool] = sa.orm.mapped_column(default=True) class TimestampMixin: __abstract__ = True + __table__: sa.Table - created_at: Mapped[datetime] = sa.orm.mapped_column(default=sa.func.now()) + created_at: Mapped[datetime] = sa.orm.mapped_column( + default=sa.func.now(), server_default=sa.FetchedValue() + ) updated_at: Mapped[datetime] = sa.orm.mapped_column( - default=sa.func.now(), onupdate=sa.func.now() + default=sa.func.now(), + onupdate=sa.func.now(), + server_default=sa.FetchedValue(), + server_onupdate=sa.FetchedValue(), ) class VersionMixin: __abstract__ = True + __table__: sa.Table version_id: Mapped[int] = sa.orm.mapped_column(nullable=False) @@ -222,6 +259,7 @@ class EagerDefaultsMixin: """ __abstract__ = True + __table__: sa.Table @sa.orm.declared_attr.directive def __mapper_args__(cls) -> dict[str, t.Any]: @@ -289,6 +327,7 @@ def accumulate_tuples_with_mapping(class_, attribute) -> t.Sequence[t.Any]: class DynamicArgsMixin: __abstract__ = True + __table__: sa.Table @sa.orm.declared_attr.directive def __mapper_args__(cls) -> t.Dict[str, t.Any]: diff --git a/src/quart_sqlalchemy/model/model.py b/src/quart_sqlalchemy/model/model.py index bbcdc21..313b6b3 100644 --- a/src/quart_sqlalchemy/model/model.py +++ b/src/quart_sqlalchemy/model/model.py @@ -1,6 +1,7 @@ from __future__ import annotations import enum +import uuid import sqlalchemy import sqlalchemy.event @@ -10,21 +11,49 @@ import sqlalchemy.orm import sqlalchemy.util import typing_extensions as tx +from sqlalchemy_utils import JSONType from .mixins import ComparableMixin from .mixins import DynamicArgsMixin +from .mixins import EagerDefaultsMixin +from .mixins import RecursiveDictMixin from .mixins import ReprMixin -from .mixins import SimpleDictMixin from .mixins import TableNameMixin +from .mixins import TotalOrderMixin sa = sqlalchemy - -class Base(DynamicArgsMixin, ReprMixin, SimpleDictMixin, ComparableMixin, TableNameMixin): +default_metadata_naming_convention = { + "ix": "ix_%(column_0_label)s", # INDEX + "uq": "uq_%(table_name)s_%(column_0_N_name)s", # UNIQUE + "ck": "ck_%(table_name)s_%(constraint_name)s", # CHECK + "fk": "fk_%(table_name)s_%(column_0_N_name)s_%(referred_table_name)s", # FOREIGN KEY + "pk": "pk_%(table_name)s", # PRIMARY KEY +} + +default_type_annotation_map = { + enum.Enum: sa.Enum(enum.Enum, native_enum=False, validate_strings=True), + tx.Literal: sa.Enum(enum.Enum, native_enum=False, validate_strings=True), + uuid.UUID: sa.Uuid, + dict: JSONType, +} + + +class BaseMixins( + DynamicArgsMixin, + EagerDefaultsMixin, + ReprMixin, + RecursiveDictMixin, + TotalOrderMixin, + ComparableMixin, + TableNameMixin, +): __abstract__ = True + __table__: sa.Table + - type_annotation_map = { - enum.Enum: sa.Enum(enum.Enum, native_enum=False, validate_strings=True), - tx.Literal: sa.Enum(enum.Enum, native_enum=False, validate_strings=True), - } +class Base(BaseMixins, sa.orm.DeclarativeBase): + __abstract__ = True + metadata = sa.MetaData(naming_convention=default_metadata_naming_convention) + type_annotation_map = default_type_annotation_map diff --git a/src/quart_sqlalchemy/session.py b/src/quart_sqlalchemy/session.py index 4833a5c..36db6c3 100644 --- a/src/quart_sqlalchemy/session.py +++ b/src/quart_sqlalchemy/session.py @@ -1,6 +1,9 @@ from __future__ import annotations import typing as t +from contextlib import contextmanager +from contextvars import ContextVar +from functools import wraps import sqlalchemy import sqlalchemy.exc @@ -18,6 +21,56 @@ sa = sqlalchemy +""" +Requirements: + * a global context var session + * a context manager that sets the session value and manages its lifetime + * a factory that will always return the current session value + * a decorator that will inject the current session value +""" + +_global_contextual_session = ContextVar("_global_contextual_session") + + +@contextmanager +def set_global_contextual_session(session, bind=None): + token = _global_contextual_session.set(session) + try: + yield + finally: + _global_contextual_session.reset(token) + + +def provide_global_contextual_session(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + session_in_args = any( + [isinstance(arg, (sa.orm.Session, sa.ext.asyncio.AsyncSession)) for arg in args] + ) + session_in_kwargs = "session" in kwargs + session_provided = session_in_args or session_in_kwargs + + if session_provided: + return func(self, *args, **kwargs) + else: + session = session_proxy() + + return func(self, session, *args, **kwargs) + + return wrapper + + +class SessionProxy: + def __call__(self) -> t.Union[sa.orm.Session, sa.ext.asyncio.AsyncSession]: + return _global_contextual_session.get() + + def __getattr__(self, name): + return getattr(self(), name) + + +session_proxy = SessionProxy() + + class Session(sa.orm.Session, t.Generic[EntityT, EntityIdT]): """A SQLAlchemy :class:`~sqlalchemy.orm.Session` class. diff --git a/src/quart_sqlalchemy/signals.py b/src/quart_sqlalchemy/signals.py index 3015b16..ecef762 100644 --- a/src/quart_sqlalchemy/signals.py +++ b/src/quart_sqlalchemy/signals.py @@ -113,3 +113,34 @@ def handle(sender: QuartSQLAlchemy, app: Quart): ... """, ) + + +framework_extension_load_fixtures = sync_signals.signal( + "quart-sqlalchemy.framework.extension.fixtures.load", + doc="""Fired to load fixtures into a fresh database. + + No default signal handlers exist for this signal as the logic is very application dependent. + This signal handler is typically triggered using the CLI: + + $ quart db fixtures load + + Example: + + @signals.framework_extension_load_fixtures.connect + def handle(sender: QuartSQLAlchemy, app: Quart): + db = sender.get_bind("default") + with db.Session() as session: + with session.begin(): + session.add_all( + [ + models.User(username="user1"), + models.User(username="user2"), + ] + ) + session.commit() + + Handler signature: + def handle(sender: QuartSQLAlchemy, app: Quart): + ... + """, +) diff --git a/src/quart_sqlalchemy/sim/app.py b/src/quart_sqlalchemy/sim/app.py index 9f12d22..c138f42 100644 --- a/src/quart_sqlalchemy/sim/app.py +++ b/src/quart_sqlalchemy/sim/app.py @@ -1,88 +1,41 @@ import logging import typing as t from copy import deepcopy -from functools import wraps -from quart import g from quart import Quart -from quart import request -from quart import Response -from quart.typing import ResponseReturnValue -from quart_schema import APIKeySecurityScheme -from quart_schema import HttpSecurityScheme from quart_schema import QuartSchema from werkzeug.utils import import_string +from .config import settings +from .container import Container + logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -BLUEPRINTS = ("quart_sqlalchemy.sim.views.api",) -EXTENSIONS = ( - "quart_sqlalchemy.sim.db.db", - "quart_sqlalchemy.sim.app.schema", - "quart_sqlalchemy.sim.auth.auth", -) - -DEFAULT_CONFIG = { - "QUART_AUTH_SECURITY_SCHEMES": { - "public-api-key": APIKeySecurityScheme(in_="header", name="X-Public-API-Key"), - "session-token-bearer": HttpSecurityScheme(scheme="bearer", bearer_format="opaque"), - }, - "REGISTER_BLUEPRINTS": ["quart_sqlalchemy.sim.views.api"], -} - - -schema = QuartSchema(security_schemes=DEFAULT_CONFIG["QUART_AUTH_SECURITY_SCHEMES"]) - - -def wrap_response(func: t.Callable) -> t.Callable: - @wraps(func) - async def decorator(result: ResponseReturnValue) -> Response: - # import pdb +schema = QuartSchema(security_schemes=settings.SECURITY_SCHEMES) - # pdb.set_trace() - return await func(result) - return decorator - - -def create_app( - override_config: t.Optional[t.Dict[str, t.Any]] = None, - extensions: t.Sequence[str] = EXTENSIONS, - blueprints: t.Sequence[str] = BLUEPRINTS, -): +def create_app(override_config: t.Optional[t.Dict[str, t.Any]] = None): override_config = override_config or {} - config = deepcopy(DEFAULT_CONFIG) + config = deepcopy(settings.dict()) config.update(override_config) app = Quart(__name__) app.config.from_mapping(config) + app.config.from_prefixed_env() - for path in extensions: + for path in app.config["LOAD_EXTENSIONS"]: extension = import_string(path) extension.init_app(app) - for path in blueprints: + for path in app.config["LOAD_BLUEPRINTS"]: bp = import_string(path) app.register_blueprint(bp) - @app.before_request - def set_ethereum_network(): - g.network = request.headers.get("X-Ethereum-Network", "GOERLI").upper() - - # app.make_response = wrap_response(app.make_response) + container = Container(app=app) + app.container = container return app - - -# @app.after_request -# async def add_json_response_envelope(response: Response) -> Response: -# if response.mimetype != "application/json": -# return response -# data = await response.get_json() -# payload = dict(status="ok", message="", data=data) -# response.set_data(json.dumps(payload)) -# return response diff --git a/src/quart_sqlalchemy/sim/auth.py b/src/quart_sqlalchemy/sim/auth.py index 80a908a..dece2ba 100644 --- a/src/quart_sqlalchemy/sim/auth.py +++ b/src/quart_sqlalchemy/sim/auth.py @@ -24,7 +24,9 @@ from werkzeug.exceptions import Forbidden from .model import AuthUser +from .model import EntityType from .model import MagicClient +from .model import Provenance from .schema import BaseSchema from .util import ObjectID @@ -195,38 +197,12 @@ def validate_security( yield scheme_credentials -# def convert_model_result(func: t.Callable) -> t.Callable: -# @wraps(func) -# async def decorator(result: ResponseReturnValue) -> Response: -# status_or_headers = None -# headers = None -# if isinstance(result, tuple): -# value, status_or_headers, headers = result + (None,) * (3 - len(result)) -# else: -# value = result - -# was_model = False -# if is_dataclass(value): -# dict_or_value = asdict(value) -# was_model = True -# elif isinstance(value, BaseModel): -# dict_or_value = value.dict(by_alias=True) -# was_model = True -# else: -# dict_or_value = value - -# if was_model: -# dict_or_value = camelize(dict_or_value) - -# return await func((dict_or_value, status_or_headers, headers)) - -# return decorator - - class QuartAuth: authenticator = RequestAuthenticator() - def __init__(self, app: t.Optional[Quart] = None): + def __init__(self, app: t.Optional[Quart] = None, bind_name: str = "default"): + self.bind_name = bind_name + if app is not None: self.init_app(app) @@ -238,6 +214,8 @@ def init_app(self, app: Quart): self.security_schemes = app.config.get("QUART_AUTH_SECURITY_SCHEMES", {}) app.cli.add_command(cli) + app.extensions["auth"] = self + def auth_endpoint_security(self): db = current_app.extensions.get("sqlalchemy") view_function = current_app.view_functions[request.endpoint] @@ -245,7 +223,8 @@ def auth_endpoint_security(self): if security_schemes is None: g.authorized_credentials = {} - with db.bind.Session() as session: + bind = db.get_bind(self.bind_name) + with bind.Session() as session: results = self.authenticator.enforce(security_schemes, session) authorized_credentials = {} for result in results: @@ -253,9 +232,17 @@ def auth_endpoint_security(self): g.authorized_credentials = authorized_credentials -from .model import EntityType -from .model import MagicClient -from .model import Provenance +class RequestCredentials: + def __init__(self, request): + self.request = request + + @property + def current_user(self): + return g.authorized_credentials.get("session-token-bearer") + + @property + def current_client(self): + return g.authorized_credentials.get("public-api-key") @cli.command("add-user") @@ -282,8 +269,9 @@ def auth_endpoint_security(self): def add_user(info: ScriptInfo, email: str, user_type: str, client_id: str) -> None: app = info.load_app() db = app.extensions.get("sqlalchemy") - - with db.bind.Session() as s: + auth = app.extensions.get("auth") + bind = db.get_bind(auth.bind_name) + with bind.Session() as s: with s.begin(): user = AuthUser( email=email, @@ -310,7 +298,9 @@ def add_user(info: ScriptInfo, email: str, user_type: str, client_id: str) -> No def add_client(info: ScriptInfo, name: str) -> None: app = info.load_app() db = app.extensions.get("sqlalchemy") - with db.bind.Session() as s: + auth = app.extensions.get("auth") + bind = db.get_bind(auth.bind_name) + with bind.Session() as s: with s.begin(): client = MagicClient(app_name=name, public_api_key=secrets.token_hex(16)) s.add(client) diff --git a/src/quart_sqlalchemy/sim/builder.py b/src/quart_sqlalchemy/sim/builder.py index ccaffeb..5aa96d6 100644 --- a/src/quart_sqlalchemy/sim/builder.py +++ b/src/quart_sqlalchemy/sim/builder.py @@ -19,12 +19,12 @@ class StatementBuilder(t.Generic[EntityT]): - model: t.Type[EntityT] + model: t.Optional[t.Type[EntityT]] - def __init__(self, model: t.Type[EntityT]): + def __init__(self, model: t.Optional[t.Type[EntityT]] = None): self.model = model - def complex_select( + def select( self, selectables: t.Sequence[Selectable] = (), conditions: t.Sequence[ColumnExpr] = (), diff --git a/src/quart_sqlalchemy/sim/commands.py b/src/quart_sqlalchemy/sim/commands.py new file mode 100644 index 0000000..c2a6567 --- /dev/null +++ b/src/quart_sqlalchemy/sim/commands.py @@ -0,0 +1,52 @@ +import asyncio +import sys + +import click +import IPython +from IPython.terminal.ipapp import load_default_config +from quart import current_app + + +def attach(app): + app.shell_context_processor(app_env) + app.cli.command( + with_appcontext=True, + context_settings=dict( + ignore_unknown_options=True, + ), + )(ishell) + + +def app_env(): + app = current_app + return dict(container=app.container) + + +@click.argument("ipython_args", nargs=-1, type=click.UNPROCESSED) +def ishell(ipython_args): + import nest_asyncio + + nest_asyncio.apply() + + config = load_default_config() + + asyncio.run(current_app.startup()) + + context = current_app.make_shell_context() + + config.TerminalInteractiveShell.banner1 = """Python %s on %s +IPython: %s +App: %s [%s] +""" % ( + sys.version, + sys.platform, + IPython.__version__, + current_app.import_name, + current_app.env, + ) + + IPython.start_ipython( + argv=ipython_args, + user_ns=context, + config=config, + ) diff --git a/src/quart_sqlalchemy/sim/config.py b/src/quart_sqlalchemy/sim/config.py new file mode 100644 index 0000000..d743ee7 --- /dev/null +++ b/src/quart_sqlalchemy/sim/config.py @@ -0,0 +1,55 @@ +import typing as t + +import sqlalchemy +from pydantic import BaseSettings +from pydantic import Field +from pydantic import PyObject +from quart_schema import APIKeySecurityScheme +from quart_schema import HttpSecurityScheme +from quart_schema.openapi import SecuritySchemeBase + +from quart_sqlalchemy import AsyncBindConfig +from quart_sqlalchemy import BindConfig +from quart_sqlalchemy.sim.db import MyBase + + +sa = sqlalchemy + + +class AppSettings(BaseSettings): + class Config: + env_file = ".env", ".secrets.env" + + LOAD_BLUEPRINTS: t.List[str] = Field( + default_factory=lambda: list(("quart_sqlalchemy.sim.views.api",)) + ) + LOAD_EXTENSIONS: t.List[str] = Field( + default_factory=lambda: list( + ( + "quart_sqlalchemy.sim.db.db", + "quart_sqlalchemy.sim.app.schema", + "quart_sqlalchemy.sim.auth.auth", + ) + ) + ) + SECURITY_SCHEMES: t.Dict[str, SecuritySchemeBase] = Field( + default_factory=lambda: { + "public-api-key": APIKeySecurityScheme(in_="header", name="X-Public-API-Key"), + "session-token-bearer": HttpSecurityScheme(scheme="bearer", bearer_format="opaque"), + } + ) + + SQLALCHEMY_BINDS: t.Dict[str, t.Union[AsyncBindConfig, BindConfig]] = Field( + default_factory=lambda: dict(default=BindConfig(engine=dict(url="sqlite:///app.db"))) + ) + SQLALCHEMY_BASE_CLASS: t.Type[t.Any] = Field(default=MyBase) + + WEB3_DEFAULT_CHAIN: str = Field(default="ethereum") + WEB3_DEFAULT_NETWORK: str = Field(default="goerli") + + WEB3_PROVIDER_CLASS: PyObject = Field("web3.providers.HTTPProvider", env="WEB3_PROVIDER_CLASS") + ALCHEMY_API_KEY: str = Field(env="ALCHEMY_API_KEY") + WEB3_HTTPS_PROVIDER_URI: str = Field(env="WEB3_HTTPS_PROVIDER_URI") + + +settings = AppSettings() diff --git a/src/quart_sqlalchemy/sim/container.py b/src/quart_sqlalchemy/sim/container.py new file mode 100644 index 0000000..91bb358 --- /dev/null +++ b/src/quart_sqlalchemy/sim/container.py @@ -0,0 +1,57 @@ +import typing as t + +import sqlalchemy.orm +from dependency_injector import containers +from dependency_injector import providers +from quart import request + +from quart_sqlalchemy.session import SessionProxy +from quart_sqlalchemy.sim.auth import RequestCredentials +from quart_sqlalchemy.sim.handle import AuthUserHandler +from quart_sqlalchemy.sim.handle import AuthWalletHandler +from quart_sqlalchemy.sim.handle import MagicClientHandler +from quart_sqlalchemy.sim.logic import LogicComponent + +from .config import AppSettings +from .web3 import Web3 +from .web3 import web3_node_factory + + +sa = sqlalchemy + + +def get_db_from_app(app): + return app.extensions["sqlalchemy"] + + +class Container(containers.DeclarativeContainer): + wiring_config = containers.WiringConfiguration( + modules=[ + "quart_sqlalchemy.sim.views", + "quart_sqlalchemy.sim.logic", + "quart_sqlalchemy.sim.handle", + "quart_sqlalchemy.sim.views.auth_wallet", + "quart_sqlalchemy.sim.views.auth_user", + "quart_sqlalchemy.sim.views.magic_client", + ] + ) + config = providers.Configuration(pydantic_settings=[AppSettings()]) + app = providers.Object() + db = providers.Singleton(get_db_from_app, app=app) + + session_factory = providers.Singleton(SessionProxy) + logic = providers.Singleton(LogicComponent) + + AuthUserHandler = providers.Singleton(AuthUserHandler) + MagicClientHandler = providers.Singleton(MagicClientHandler) + AuthWalletHandler = providers.Singleton(AuthWalletHandler) + + web3_node = providers.Singleton(web3_node_factory, config=config) + web3 = providers.Singleton( + Web3, + node=web3_node, + default_network=config.WEB3_DEFAULT_NETWORK, + default_chain=config.WEB3_DEFAULT_CHAIN, + ) + current_request = providers.Factory(lambda: request) + request_credentials = providers.Singleton(RequestCredentials, request=current_request) diff --git a/src/quart_sqlalchemy/sim/db.py b/src/quart_sqlalchemy/sim/db.py index e7677f6..9634815 100644 --- a/src/quart_sqlalchemy/sim/db.py +++ b/src/quart_sqlalchemy/sim/db.py @@ -1,22 +1,41 @@ import click -import sqlalchemy as sa -from quart import g -from quart import request +import sqlalchemy +import sqlalchemy.orm from quart.cli import AppGroup from quart.cli import pass_script_info from quart.cli import ScriptInfo from sqlalchemy.types import Integer from sqlalchemy.types import TypeDecorator -from quart_sqlalchemy import Base from quart_sqlalchemy import SQLAlchemyConfig from quart_sqlalchemy.framework import QuartSQLAlchemy +from quart_sqlalchemy.model import BaseMixins from quart_sqlalchemy.sim.util import ObjectID +sa = sqlalchemy cli = AppGroup("db-schema") +def init_fixtures(session): + """Initialize the database with some fixtures.""" + from quart_sqlalchemy.sim.model import AuthUser + from quart_sqlalchemy.sim.model import MagicClient + + client = MagicClient( + app_name="My App", + public_api_key="4700aed5ee9f76f7be6398cd4b00b586", + auth_users=[ + AuthUser( + email="joe@magic.link", + current_session_token="97ee741d53e11a490460927c8a2ce4a3", + ), + ], + ) + session.add(client) + session.flush() + + class ObjectIDType(TypeDecorator): """A custom database column type that converts integer value to our ObjectID. This allows us to pass around ObjectID type in the application for easy @@ -47,24 +66,11 @@ def process_result_value(self, value, dialect): return ObjectID(value) -class MyBase(Base): +class MyBase(BaseMixins, sa.orm.DeclarativeBase): + __abstract__ = True type_annotation_map = {ObjectID: ObjectIDType} -class MyQuartSQLAlchemy(QuartSQLAlchemy): - def init_app(self, app): - super().init_app(app) - - @app.before_request - def set_bind(): - if request.method in ["GET", "OPTIONS", "HEAD", "TRACE"]: - g.bind = self.get_bind("read-replica") - else: - g.bind = self.get_bind("default") - - app.cli.add_command(cli) - - @cli.command("load") @pass_script_info def schema_load(info: ScriptInfo) -> None: @@ -76,10 +82,10 @@ def schema_load(info: ScriptInfo) -> None: # sqlite:///file:mem.db?mode=memory&cache=shared&uri=true -db = MyQuartSQLAlchemy( +db = QuartSQLAlchemy( SQLAlchemyConfig.parse_obj( { - "model_class": MyBase, + "base_class": MyBase, "binds": { "default": { "engine": {"url": "sqlite:///file:sim.db?cache=shared&uri=true"}, diff --git a/src/quart_sqlalchemy/sim/handle.py b/src/quart_sqlalchemy/sim/handle.py index e079da4..47d0cf3 100644 --- a/src/quart_sqlalchemy/sim/handle.py +++ b/src/quart_sqlalchemy/sim/handle.py @@ -1,10 +1,17 @@ +from __future__ import annotations + import logging +import secrets import typing as t from datetime import datetime -from sqlalchemy.orm import Session +import sqlalchemy +from dependency_injector.wiring import Provide +from quart import Quart -from quart_sqlalchemy.sim.logic import LogicComponent as Logic +from quart_sqlalchemy.session import SessionProxy +from quart_sqlalchemy.sim import signals +from quart_sqlalchemy.sim.logic import LogicComponent from quart_sqlalchemy.sim.model import AuthUser from quart_sqlalchemy.sim.model import AuthWallet from quart_sqlalchemy.sim.model import EntityType @@ -12,11 +19,17 @@ from quart_sqlalchemy.sim.util import ObjectID +sa = sqlalchemy + logger = logging.getLogger(__name__) CLIENTS_PER_API_USER_LIMIT = 50 +def get_product_type_by_client_id(_): + return EntityType.MAGIC.value + + class MaxClientsExceeded(Exception): pass @@ -29,34 +42,21 @@ class InvalidSubstringError(AuthUserBaseError): pass -class APIKeySet(t.NamedTuple): - public_key: str - secret_key: str - - class HandlerBase: - logic: Logic - session: Session - """The base class for all handler classes. It provides handler with access - to our logic object. - """ - - def __init__(self, session: t.Optional[Session], logic: t.Optional[Logic] = None): - self.session = session - self.logic = logic or Logic() - - -def get_product_type_by_client_id(client_id): - return EntityType.MAGIC.value + logic: LogicComponent = Provide["logic"] + session_factory = SessionProxy() class MagicClientHandler(HandlerBase): + auth_user_handler: AuthUserHandler = Provide["AuthUserHandler"] + def add( self, - magic_api_user_id, - magic_team_id, app_name=None, - is_magic_connect_enabled=False, + rate_limit_tier=None, + connect_interop=None, + is_signing_modal_enabled=False, + global_audience_enabled=False, ): """Registers a new client. @@ -66,70 +66,21 @@ def add( Returns: A ``MagicClient``. """ - magic_clients_count = self.logic.MagicClientAPIUser.count_by_magic_api_user_id( - self.session, - magic_api_user_id, - ) - - if magic_clients_count >= CLIENTS_PER_API_USER_LIMIT: - raise MaxClientsExceeded() - - return self.add_client( - magic_api_user_id, - magic_team_id, - app_name, - is_magic_connect_enabled, - ) - - def get_by_public_api_key(self, public_api_key): - return self.logic.MagicClientAPIKey.get_by_public_api_key(self.session, public_api_key) - - def add_client( - self, - magic_api_user_id, - magic_team_id, - app_name=None, - is_magic_connect_enabled=False, - ): - live_api_key = APIKeySet(public_key="xxx", secret_key="yyy") - # with self.logic.begin(ro=False) as session: - return self.logic.MagicClient._add( - self.session, + return self.logic.MagicClient.add( + self.session_factory(), app_name=app_name, + rate_limit_tier=rate_limit_tier, + connect_interop=connect_interop, + is_signing_modal_enabled=is_signing_modal_enabled, + global_audience_enabled=global_audience_enabled, ) - # self.logic.MagicClientAPIKey._add( - # session, - # magic_client.id, - # live_api_key_pair=live_api_key, - # ) - # self.logic.MagicClientAPIUser._add( - # session, - # magic_api_user_id, - # magic_client.id, - # ) - - # self.logic.MagicClientAuthMethods._add( - # session, - # magic_client_id=magic_client.id, - # is_magic_connect_enabled=is_magic_connect_enabled, - # is_metamask_wallet_enabled=(True if is_magic_connect_enabled else False), - # is_wallet_connect_enabled=(True if is_magic_connect_enabled else False), - # is_coinbase_wallet_enabled=(True if is_magic_connect_enabled else False), - # ) - - # self.logic.MagicClientTeam._add(session, magic_client.id, magic_team_id) - - # return magic_client, live_api_key - - def get_magic_api_user_id_by_client_id(self, magic_client_id): - return self.logic.MagicClient.get_magic_api_user_id_by_client_id( - self.session, magic_client_id - ) + def get_by_public_api_key(self, public_api_key): + return self.logic.MagicClient.get_by_public_api_key(self.session_factory(), public_api_key) def get_by_id(self, magic_client_id): - return self.logic.MagicClient.get_by_id(self.session, magic_client_id) + return self.logic.MagicClient.get_by_id(self.session_factory(), magic_client_id) def update_app_name_by_id(self, magic_client_id, app_name): """ @@ -142,7 +93,7 @@ def update_app_name_by_id(self, magic_client_id, app_name): app_name if update was successful """ client = self.logic.MagicClient.update_by_id( - self.session, magic_client_id, app_name=app_name + self.session_factory(), magic_client_id, app_name=app_name ) if not client: @@ -151,7 +102,9 @@ def update_app_name_by_id(self, magic_client_id, app_name): return client.app_name def update_by_id(self, magic_client_id, **kwargs): - client = self.logic.MagicClient.update_by_id(self.session, magic_client_id, **kwargs) + client = self.logic.MagicClient.update_by_id( + self.session_factory(), magic_client_id, **kwargs + ) return client @@ -163,108 +116,38 @@ def set_inactive_by_id(self, magic_client_id): Returns: None """ - self.logic.MagicClient.update_by_id(self.session, magic_client_id, is_active=False) + self.logic.MagicClient.update_by_id( + self.session_factory(), magic_client_id, is_active=False + ) def get_users_for_client( self, magic_client_id, offset=None, limit=None, - include_count=False, ): """ Returns emails and signup timestamps for all auth users belonging to a given client """ - auth_user_handler = AuthUserHandler(session=self.session) - product_type = get_product_type_by_client_id(magic_client_id) - auth_users = auth_user_handler.get_by_client_id_and_user_type( - magic_client_id, - product_type, - offset=offset, - limit=limit, - ) - - # Here we blindly load from oauth users table because we only provide - # two login methods right now. If not email link then it is oauth. - # TODO(ajen#ch22926|2020-08-14): rely on the `login_method` column to - # deterministically load from correct source (oauth, webauthn, etc.). - # emails_from_oauth = OAuthUserHandler().get_emails_by_auth_user_ids( - # [auth_user.id for auth_user in auth_users if auth_user.email is None], - # ) - - data = { - "users": [ - dict(email=u.email or "none", signup_ts=int(datetime.timestamp(u.time_created))) - for u in auth_users - ] - } - - if include_count: - data["count"] = auth_user_handler.get_user_count_by_client_id_and_user_type( - magic_client_id, - product_type, - ) - - return data - - def get_users_for_client_v2( - self, - magic_client_id, - offset=None, - limit=None, - include_count=False, - ): - """ - Returns emails, signup timestamps, provenance and MFA enablement for all auth users - belonging to a given client. - """ - auth_user_handler = AuthUserHandler(session=self.session) product_type = get_product_type_by_client_id(magic_client_id) - auth_users = auth_user_handler.get_by_client_id_and_user_type( + auth_users = self.auth_user_handler.get_by_client_id_and_user_type( magic_client_id, product_type, offset=offset, limit=limit, ) - data = { + return { "users": [ dict(email=u.email or "none", signup_ts=int(datetime.timestamp(u.time_created))) for u in auth_users ] } - if include_count: - data["count"] = auth_user_handler.get_user_count_by_client_id_and_user_type( - magic_client_id, - product_type, - ) - - return data - - # def get_user_logins_for_client(self, magic_client_id, limit=None): - # logins = AuthUserLoginHandler().get_logins_by_magic_client_id( - # magic_client_id, - # limit=limit or 20, - # ) - # user_logins = get_user_logins_response(logins) - - # return sorted( - # user_logins, - # key=lambda x: x["login_ts"], - # reverse=True, - # )[:limit] - class AuthUserHandler(HandlerBase): - # auth_user_mfa_handler: AuthUserMfaHandler - - def __init__(self, *args, auth_user_mfa_handler=None, **kwargs): - super().__init__(*args, **kwargs) - # self.auth_user_mfa_handler = auth_user_mfa_handler or AuthUserMfaHandler() - def get_by_session_token(self, session_token): - return self.logic.AuthUser.get_by_session_token(self.session, session_token) + return self.logic.AuthUser.get_by_session_token(self.session_factory(), session_token) def get_or_create_by_email_and_client_id( self, @@ -272,41 +155,27 @@ def get_or_create_by_email_and_client_id( client_id, user_type=EntityType.MAGIC.value, ): - auth_user = self.logic.AuthUser.get_by_email_and_client_id( - self.session, - email, - client_id, - user_type=user_type, - ) - if not auth_user: - # try: - # email = enhanced_email_validation( - # email, - # source=MAGIC, - # # So we don't affect sign-up. - # silence_network_error=True, - # ) - # except ( - # EnhanceEmailValidationError, - # EnhanceEmailSuggestionError, - # ) as e: - # logger.warning( - # "Email Start Attempt.", - # exc_info=True, - # ) - # raise EnhancedEmailValidation(error_message=str(e)) from e - - auth_user = self.logic.AuthUser.add_by_email_and_client_id( - self.session, + session = self.session_factory() + with session.begin_nested(): + auth_user = self.logic.AuthUser.get_by_email_and_client_id( + session, + email, client_id, - email=email, user_type=user_type, + for_update=True, ) + if not auth_user: + auth_user = self.logic.AuthUser.add_by_email_and_client_id( + session, + client_id, + email=email, + user_type=user_type, + ) return auth_user def get_by_id_and_validate_exists(self, auth_user_id): """This function helps formalize how a non-existent auth user should be handled.""" - auth_user = self.logic.AuthUser.get_by_id(self.session, auth_user_id) + auth_user = self.logic.AuthUser.get_by_id(self.session_factory(), auth_user_id) if auth_user is None: raise RuntimeError('resource_name="auth_user"') return auth_user @@ -319,33 +188,33 @@ def create_verified_user( client_id, email, user_type=EntityType.FORTMATIC.value, + **kwargs, ): + # with self.session_factory() as session: # with self.logic.begin(ro=False) as session: - auid = self.logic.AuthUser._add_by_email_and_client_id( - self.session, - client_id, - email, - user_type=user_type, - ).id - auth_user = self.logic.AuthUser._update_by_id( - self.session, - auid, - date_verified=datetime.utcnow(), - ) - - return auth_user + session = self.session_factory() + with session.begin_nested(): + auid = self.logic.AuthUser.add_by_email_and_client_id( + session, + client_id, + email, + user_type=user_type, + **kwargs, + ).id - # def get_auth_user_from_public_address(self, public_address): - # wallet = self.logic.AuthWallet.get_by_public_address(public_address) + session.flush() - # if not wallet: - # return None + auth_user = self.logic.AuthUser.update_by_id( + session, + auid, + date_verified=datetime.utcnow(), + current_session_token=secrets.token_hex(16), + ) - # return self.logic.AuthUser.get_by_id(wallet.auth_user_id) + return auth_user - def get_by_id(self, auth_user_id, load_mfa_methods=False) -> AuthUser: - # join_list = ["mfa_methods"] if load_mfa_methods else None - return self.logic.AuthUser.get_by_id(self.session, auth_user_id) + def get_by_id(self, auth_user_id) -> AuthUser: + return self.logic.AuthUser.get_by_id(self.session_factory(), auth_user_id) def get_by_client_id_and_user_type( self, @@ -354,21 +223,13 @@ def get_by_client_id_and_user_type( offset=None, limit=None, ): - if user_type == EntityType.CONNECT.value: - return self.logic.AuthUser.get_by_client_id_for_connect( - self.session, - client_id, - offset=offset, - limit=limit, - ) - else: - return self.logic.AuthUser.get_by_client_id_and_user_type( - self.session, - client_id, - user_type, - offset=offset, - limit=limit, - ) + return self.logic.AuthUser.get_by_client_id_and_user_type( + self.session_factory(), + client_id, + user_type, + offset=offset, + limit=limit, + ) def get_by_client_ids_and_user_type( self, @@ -378,43 +239,32 @@ def get_by_client_ids_and_user_type( limit=None, ): return self.logic.AuthUser.get_by_client_ids_and_user_type( - self.session, + self.session_factory(), client_ids, user_type, offset=offset, limit=limit, ) - def get_user_count_by_client_id_and_user_type(self, client_id, user_type): - if user_type == EntityType.CONNECT.value: - return self.logic.AuthUser.get_user_count_by_client_id_for_connect( - self.session, - client_id, - ) - else: - return self.logic.AuthUser.get_user_count_by_client_id_and_user_type( - self.session, - client_id, - user_type, - ) - def exist_by_email_client_id_and_user_type(self, email, client_id, user_type): return self.logic.AuthUser.exist_by_email_and_client_id( - self.session, + self.session_factory(), email, client_id, user_type=user_type, ) def update_email_by_id(self, model_id, email): - return self.logic.AuthUser.update_by_id(self.session, model_id, email=email) + return self.logic.AuthUser.update_by_id(self.session_factory(), model_id, email=email) def update_phone_number_by_id(self, model_id, phone_number): - return self.logic.AuthUser.update_by_id(self.session, model_id, phone_number=phone_number) + return self.logic.AuthUser.update_by_id( + self.session_factory(), model_id, phone_number=phone_number + ) def get_by_email_client_id_and_user_type(self, email, client_id, user_type): return self.logic.AuthUser.get_by_email_and_client_id( - self.session, + self.session_factory(), email, client_id, user_type, @@ -422,28 +272,32 @@ def get_by_email_client_id_and_user_type(self, email, client_id, user_type): def mark_date_verified_by_id(self, model_id): return self.logic.AuthUser.update_by_id( - self.session, + self.session_factory(), model_id, date_verified=datetime.utcnow(), ) def set_role_by_email_magic_client_id(self, email, magic_client_id, role): + session = self.session_factory() auth_user = self.logic.AuthUser.get_by_email_and_client_id( - self.session, + session, email, magic_client_id, EntityType.MAGIC.value, + for_update=True, ) if not auth_user: auth_user = self.logic.AuthUser.add_by_email_and_client_id( - self.session, + session, magic_client_id, email, user_type=EntityType.MAGIC.value, ) - return self.logic.AuthUser.update_by_id(self.session, auth_user.id, **{role: True}) + session.flush() + + return self.logic.AuthUser.update_by_id(session, auth_user.id, **{role: True}) def search_by_client_id_and_substring( self, @@ -451,29 +305,18 @@ def search_by_client_id_and_substring( substring, offset=None, limit=10, - load_mfa_methods=False, ): - # join_list = ["mfa_methods"] if load_mfa_methods is True else None - if not isinstance(substring, str) or len(substring) < 3: raise InvalidSubstringError() auth_users = self.logic.AuthUser.get_by_client_id_with_substring_search( - self.session, + self.session_factory(), client_id, substring, offset=offset, limit=limit, - # join_list=join_list, ) - # mfa_enablements = self.auth_user_mfa_handler.is_active_batch( - # [auth_user.id for auth_user in auth_users], - # ) - # for auth_user in auth_users: - # if mfa_enablements[auth_user.id] is False: - # auth_user.mfa_methods = [] - return auth_users def is_magic_connect_enabled(self, auth_user_id=None, auth_user=None): @@ -486,13 +329,14 @@ def is_magic_connect_enabled(self, auth_user_id=None, auth_user=None): return auth_user.user_type == EntityType.CONNECT.value def mark_as_inactive(self, auth_user_id): - self.logic.AuthUser.update_by_id(self.session, auth_user_id, is_active=False) + self.logic.AuthUser.update_by_id(self.session_factory(), auth_user_id, is_active=False) def get_by_email_and_wallet_type_for_interop(self, email, wallet_type, network): """ Opinionated method for fetching AuthWallets by email address, wallet_type and network. """ return self.logic.AuthUser.get_by_email_for_interop( + self.session_factory(), email=email, wallet_type=wallet_type, network=network, @@ -505,37 +349,27 @@ def get_magic_connect_auth_user(self, auth_user_id): return auth_user -# @signals.auth_user_duplicate.connect -# def handle_duplicate_auth_users( -# current_app, -# original_auth_user_id, -# duplicate_auth_user_ids, -# auth_user_handler: t.Optional[AuthUserHandler] = None, -# ) -> None: -# logger.info(f"{len(duplicate_auth_user_ids)} dupe(s) found for {original_auth_user_id}") +@signals.auth_user_duplicate.connect +def handle_duplicate_auth_users( + app: Quart, + original_auth_user_id: ObjectID, + duplicate_auth_user_ids: t.Sequence[ObjectID], +) -> None: + logger.info(f"{len(duplicate_auth_user_ids)} dupe(s) found for {original_auth_user_id}") -# auth_user_handler = auth_user_handler or AuthUserHandler() - -# for dupe_id in duplicate_auth_user_ids: -# logger.info( -# f"marking auth_user_id {dupe_id} as inactive, in favor of original {original_auth_user_id}", -# ) -# auth_user_handler.mark_as_inactive(dupe_id) + for dupe_id in duplicate_auth_user_ids: + logger.info( + f"marking auth_user_id {dupe_id} as inactive, in favor of original {original_auth_user_id}", + ) + app.container.logic().AuthUser.update_by_id(dupe_id, is_active=False) class AuthWalletHandler(HandlerBase): - # account_linking_feature = LDFeatureFlag("is-account-linking-enabled", anonymous_user=True) - - def __init__(self, network, *args, wallet_type=WalletType.ETH, **kwargs): - super().__init__(*args, **kwargs) - self.wallet_network = network - self.wallet_type = wallet_type - def get_by_id(self, model_id): - return self.logic.AuthWallet.get_by_id(self.session, model_id) + return self.logic.AuthWallet.get_by_id(self.session_factory(), model_id) def get_by_public_address(self, public_address): - return self.logic.AuthWallet.get_by_public_address(self.session, public_address) + return self.logic.AuthWallet().get_by_public_address(self.session_factory(), public_address) def get_by_auth_user_id( self, @@ -544,25 +378,9 @@ def get_by_auth_user_id( wallet_type: t.Optional[WalletType] = None, **kwargs, ) -> t.List[AuthWallet]: - auth_user = self.logic.AuthUser.get_by_id( - self.session, - auth_user_id, - join_list=["linked_primary_auth_user"], - ) - - if auth_user.has_linked_primary_auth_user: - logger.info( - "Linked primary_auth_user found for wallet delegation", - extra=dict( - auth_user_id=auth_user.id, - delegated_to=auth_user.linked_primary_auth_user_id, - ), - ) - auth_user = auth_user.linked_primary_auth_user - return self.logic.AuthWallet.get_by_auth_user_id( - self.session, - auth_user.id, + self.session_factory(), + auth_user_id, network=network, wallet_type=wallet_type, **kwargs, @@ -574,20 +392,24 @@ def sync_auth_wallet( public_address, encrypted_private_address, wallet_management_type, + network: t.Optional[str] = None, + wallet_type: t.Optional[WalletType] = None, ): - existing_wallet = self.logic.AuthWallet.get_by_auth_user_id( - self.session, - auth_user_id, - ) - if existing_wallet: - raise RuntimeError("WalletExistsForNetworkAndWalletType") - - return self.logic.AuthWallet.add( - self.session, - public_address, - encrypted_private_address, - self.wallet_type, - self.wallet_network, - management_type=wallet_management_type, - auth_user_id=auth_user_id, - ) + session = self.session_factory() + with session.begin_nested(): + existing_wallet = self.logic.AuthWallet.get_by_auth_user_id( + session, + auth_user_id, + ) + if existing_wallet: + raise RuntimeError("WalletExistsForNetworkAndWalletType") + + return self.logic.AuthWallet.add( + session, + public_address, + encrypted_private_address, + wallet_type, + network, + management_type=wallet_management_type, + auth_user_id=auth_user_id, + ) diff --git a/src/quart_sqlalchemy/sim/logic.py b/src/quart_sqlalchemy/sim/logic.py index ceb6ffd..7b969bc 100644 --- a/src/quart_sqlalchemy/sim/logic.py +++ b/src/quart_sqlalchemy/sim/logic.py @@ -1,13 +1,14 @@ +import inspect import logging +import secrets import typing as t from datetime import datetime -from functools import wraps -from sqlalchemy import or_ -from sqlalchemy.orm import contains_eager -from sqlalchemy.orm import Session -from sqlalchemy.sql.expression import func +import sqlalchemy +import sqlalchemy.orm +from quart import current_app +from quart_sqlalchemy.session import provide_global_contextual_session from quart_sqlalchemy.sim import signals from quart_sqlalchemy.sim.model import AuthUser as auth_user_model from quart_sqlalchemy.sim.model import AuthWallet as auth_wallet_model @@ -19,89 +20,59 @@ from quart_sqlalchemy.sim.repo_adapter import RepositoryLegacyAdapter from quart_sqlalchemy.sim.util import ObjectID from quart_sqlalchemy.sim.util import one +from quart_sqlalchemy.types import EntityIdT +from quart_sqlalchemy.types import EntityT +from quart_sqlalchemy.types import SessionT logger = logging.getLogger(__name__) +sa = sqlalchemy class LogicMeta(type): - """This is metaclass provides registry pattern where all the available - logics will be accessible through any instantiated logic object. - - Note: - Don't use this metaclass at another places. This is only intended to be - used by LogicComponent. If you want your own registry, please create - your own. - """ + _ignore = {"LegacyLogicComponent"} def __init__(cls, name, bases, cls_dict): if not hasattr(cls, "_registry"): cls._registry = {} else: - cls._registry[name] = cls() - - super().__init__(name, bases, cls_dict) - - -class LogicComponent(metaclass=LogicMeta): - """This is the base class for any logic class. This overrides the getattr - method for registry lookup. + if cls.__name__ not in cls._ignore: + model = getattr(cls, "model", None) + if model is not None: + name = model.__name__ - Example: + cls._registry[name] = cls() - ``` - class TrollGoat(LogicComponent): - - def add(x): - print(x) - ``` - - Once you have a logic object, you can directly do something like: - - ``` - logic.TrollGoat.add('troll_goat') - ``` + super().__init__(name, bases, cls_dict) - Note: - You will have to explicitly import your newly created logic in - ``fortmatic.logic.__init__.py``. When the logic is imported, it is created - the first time; hence, it is then registered. If this is unclear to you, - read https://blog.ionelmc.ro/2015/02/09/understanding-python-metaclasses/ - It has all the info you need to understand. For example, everything in - python is an object :P. Enjoy. - """ +class LogicComponent(t.Generic[EntityT, EntityIdT, SessionT], metaclass=LogicMeta): def __dir__(self): return super().__dir__() + list(self._registry.keys()) - def __getattr__(self, logic_name): - if logic_name in self._registry: - return self._registry[logic_name] + def __getattr__(self, name): + if name in self._registry: + return self._registry[name] else: - raise AttributeError( - "{object_name} has no attribute '{logic_name}'".format( - object_name=self.__class__.__name__, - logic_name=logic_name, - ), - ) + raise AttributeError(f"{type(self).__name__} has no attribute '{name}'") -class MagicClient(LogicComponent): - def __init__(self): - # self._repository = SQLAlchemyRepository[magic_client_model, ObjectID](session) +class MagicClient(LogicComponent[magic_client_model, ObjectID, sa.orm.Session]): + model = magic_client_model + identity = ObjectID + _repository = RepositoryLegacyAdapter(model, identity) - self._repository = RepositoryLegacyAdapter(magic_client_model, ObjectID) - - def _add(self, session, app_name=None): + @provide_global_contextual_session + def add(self, session, app_name=None, **kwargs): + public_api_key = secrets.token_hex(16) return self._repository.add( session, app_name=app_name, + **kwargs, + public_api_key=public_api_key, ) - # add = with_db_session(ro=False)(_add) - add = _add - - # @with_db_session(ro=True) + @provide_global_contextual_session def get_by_id( self, session, @@ -116,7 +87,7 @@ def get_by_id( join_list=join_list, ) - # @with_db_session(ro=True) + @provide_global_contextual_session def get_by_public_api_key( self, session, @@ -130,34 +101,17 @@ def get_by_public_api_key( ) ) - # @with_db_session(ro=True) - # def get_magic_api_user_id_by_client_id(self, session, magic_client_id): - # client = self._repository.get_by_id( - # session, - # magic_client_id, - # allow_inactive=False, - # join_list=None, - # ) - - # if client is None: - # return None - - # if client.magic_client_api_user is None: - # return None - - # return client.magic_client_api_user.magic_api_user_id - - # @with_db_session(ro=False) + @provide_global_contextual_session def update_by_id(self, session, model_id, **update_params): modified_row = self._repository.update(session, model_id, **update_params) session.refresh(modified_row) return modified_row - # @with_db_session(ro=True) + @provide_global_contextual_session def yield_all_clients_by_chunk(self, session, chunk_size): yield from self._repository.yield_by_chunk(session, chunk_size) - # @with_db_session(ro=True) + @provide_global_contextual_session def yield_by_chunk(self, session, chunk_size, filters=None, join_list=None): yield from self._repository.yield_by_chunk( session, @@ -183,67 +137,17 @@ class MissingPhoneNumber(Exception): pass -class AuthUser(LogicComponent): - def __init__(self): - # self._repository = SQLRepository(auth_user_model) - self._repository = RepositoryLegacyAdapter(magic_client_model, ObjectID) - - # @with_db_session(ro=True) - def get_by_session_token( - self, - session, - session_token, - ): - return one( - self._repository.get_by( - session, - filters=[auth_user_model.current_session_token == session_token], - limit=1, - ) - ) - - def _get_or_add_by_phone_number_and_client_id( - self, - session, - client_id, - phone_number, - user_type=EntityType.FORTMATIC.value, - ): - if phone_number is None: - raise MissingPhoneNumber() - - row = self._get_by_phone_number_and_client_id( - session=session, - phone_number=phone_number, - client_id=client_id, - user_type=user_type, - ) - - if row: - return row - - row = self._repository.add( - session=session, - phone_number=phone_number, - client_id=client_id, - user_type=user_type, - provenance=Provenance.SMS, - ) - logger.info( - "New auth user (id: {}) created by phone number (client_id: {})".format( - row.id, - client_id, - ), - ) - - return row +class AuthUser(LogicComponent[auth_user_model, ObjectID, sa.orm.Session]): + model = auth_user_model + identity = ObjectID + _repository = RepositoryLegacyAdapter(model, identity) - # get_or_add_by_phone_number_and_client_id = with_db_session(ro=False)( - # _get_or_add_by_phone_number_and_client_id, - # ) - get_or_add_by_phone_number_and_client_id = _get_or_add_by_phone_number_and_client_id + @provide_global_contextual_session + def add(self, session, **kwargs) -> auth_user_model: + return self._repository.add(session, **kwargs) - def _add_by_email_and_client_id( + @provide_global_contextual_session + def add_by_email_and_client_id( self, session, client_id, @@ -254,7 +158,7 @@ def _add_by_email_and_client_id( if email is None: raise MissingEmail() - if self._exist_by_email_and_client_id( + if self.exist_by_email_and_client_id( session, email, client_id, @@ -284,10 +188,8 @@ def _add_by_email_and_client_id( return row - # add_by_email_and_client_id = with_db_session(ro=False)(_add_by_email_and_client_id) - add_by_email_and_client_id = _add_by_email_and_client_id - - def _add_by_client_id( + @provide_global_contextual_session + def add_by_client_id( self, session, client_id, @@ -305,7 +207,55 @@ def _add_by_client_id( date_verified=datetime.utcnow() if is_verified else None, ) logger.info( - "New auth user (id: {}) created by (client_id: {})".format( + "New auth user (id: {}) created by (client_id: {})".format(row.id, client_id), + ) + + return row + + @provide_global_contextual_session + def get_by_session_token( + self, + session, + session_token, + ): + return one( + self._repository.get_by( + session, + filters=[auth_user_model.current_session_token == session_token], + limit=1, + ) + ) + + @provide_global_contextual_session + def get_or_add_by_phone_number_and_client_id( + self, + session, + client_id, + phone_number, + user_type=EntityType.FORTMATIC.value, + ): + if phone_number is None: + raise MissingPhoneNumber() + + row = self.get_by_phone_number_and_client_id( + session=session, + phone_number=phone_number, + client_id=client_id, + user_type=user_type, + ) + + if row: + return row + + row = self._repository.add( + session=session, + phone_number=phone_number, + client_id=client_id, + user_type=user_type, + provenance=Provenance.SMS, + ) + logger.info( + "New auth user (id: {}) created by phone number (client_id: {})".format( row.id, client_id, ), @@ -313,29 +263,28 @@ def _add_by_client_id( return row - # add_by_client_id = with_db_session(ro=False)(_add_by_client_id) - add_by_client_id = _add_by_client_id - - def _get_by_active_identifier_and_client_id( + @provide_global_contextual_session + def get_by_active_identifier_and_client_id( self, session, identifier_field, identifier_value, client_id, user_type, - ) -> auth_user_model: + for_update=False, + ) -> t.Optional[auth_user_model]: """There should only be one active identifier where all the parameters match for a given client ID. In the case of multiple results, the subsequent entries / "dupes" will be marked as inactive.""" filters = [ identifier_field == identifier_value, auth_user_model.client_id == client_id, auth_user_model.user_type == user_type, - # auth_user_model.is_active == True, # noqa: E712 ] results = self._repository.get_by( session, filters=filters, order_by_clause=auth_user_model.id.asc(), + for_update=for_update, ) if not results: @@ -345,29 +294,33 @@ def _get_by_active_identifier_and_client_id( if duplicates: signals.auth_user_duplicate.send( + current_app, original_auth_user_id=original.id, duplicate_auth_user_ids=[dupe.id for dupe in duplicates], ) return original - # @with_db_session(ro=True) + @provide_global_contextual_session def get_by_email_and_client_id( self, session, email, client_id, user_type=EntityType.FORTMATIC.value, + for_update=False, ): - return self._get_by_active_identifier_and_client_id( + return self.get_by_active_identifier_and_client_id( session=session, identifier_field=auth_user_model.email, identifier_value=email, client_id=client_id, user_type=user_type, + for_update=for_update, ) - def _get_by_phone_number_and_client_id( + @provide_global_contextual_session + def get_by_phone_number_and_client_id( self, session, phone_number, @@ -377,7 +330,7 @@ def _get_by_phone_number_and_client_id( if phone_number is None: raise MissingPhoneNumber() - return self._get_by_active_identifier_and_client_id( + return self.get_by_active_identifier_and_client_id( session=session, identifier_field=auth_user_model.phone_number, identifier_value=phone_number, @@ -385,12 +338,8 @@ def _get_by_phone_number_and_client_id( user_type=user_type, ) - # get_by_phone_number_and_client_id = with_db_session(ro=True)( - # _get_by_phone_number_and_client_id, - # ) - get_by_phone_number_and_client_id = _get_by_phone_number_and_client_id - - def _exist_by_email_and_client_id( + @provide_global_contextual_session + def exist_by_email_and_client_id( self, session, email, @@ -408,10 +357,10 @@ def _exist_by_email_and_client_id( ), ) - # exist_by_email_and_client_id = with_db_session(ro=True)(_exist_by_email_and_client_id) - exist_by_email_and_client_id = _exist_by_email_and_client_id - - def _get_by_id(self, session, model_id, join_list=None, for_update=False) -> auth_user_model: + @provide_global_contextual_session + def get_by_id( + self, session, model_id, join_list=None, for_update=False + ) -> t.Optional[auth_user_model]: return self._repository.get_by_id( session, model_id, @@ -419,10 +368,8 @@ def _get_by_id(self, session, model_id, join_list=None, for_update=False) -> aut for_update=for_update, ) - get_by_id = _get_by_id - # get_by_id = with_db_session(ro=True)(_get_by_id) - - def _update_by_id(self, session, auth_user_id, **kwargs): + @provide_global_contextual_session + def update_by_id(self, session, auth_user_id, **kwargs): modified_user = self._repository.update(session, auth_user_id, **kwargs) if modified_user is None: @@ -430,111 +377,33 @@ def _update_by_id(self, session, auth_user_id, **kwargs): return modified_user - # update_by_id = with_db_session(ro=False)(_update_by_id) - update_by_id = _update_by_id - - # @with_db_session(ro=True) + @provide_global_contextual_session def get_user_count_by_client_id_and_user_type(self, session, client_id, user_type): query = ( session.query(auth_user_model) .filter( auth_user_model.client_id == client_id, auth_user_model.user_type == user_type, - # auth_user_model.is_active == True, # noqa: E712 auth_user_model.date_verified.is_not(None), ) - .statement.with_only_columns([func.count()]) + .statement.with_only_columns(sa.func.count()) .order_by(None) ) return session.execute(query).scalar() - def _get_by_client_id_and_global_auth_user(self, session, client_id, global_auth_user_id): + @provide_global_contextual_session + def get_by_client_id_and_global_auth_user(self, session, client_id, global_auth_user_id): return self._repository.get_by( session=session, filters=[ auth_user_model.client_id == client_id, auth_user_model.user_type == EntityType.CONNECT.value, - # auth_user_model.is_active == True, # noqa: E712 auth_user_model.global_auth_user_id == global_auth_user_id, ], ) - # get_by_client_id_and_global_auth_user = with_db_session(ro=True)( - # _get_by_client_id_and_global_auth_user, - # ) - get_by_client_id_and_global_auth_user = _get_by_client_id_and_global_auth_user - - # @with_db_session(ro=True) - # def get_by_client_id_for_connect( - # self, - # session, - # client_id, - # offset=None, - # limit=None, - # ): - # # TODO(thomas|2022-07-12): Determine where/if is the right place to split - # # connect/magic logic based on user type as part of https://app.shortcut.com/magic-labs/story/53323. - # # See https://github.com/fortmatic/fortmatic/pull/6173#discussion_r919529540. - # return ( - # session.query(auth_user_model) - # .join( - # identifier_model, - # auth_user_model.global_auth_user_id == identifier_model.global_auth_user_id, - # ) - # .filter( - # auth_user_model.client_id == client_id, - # auth_user_model.user_type == EntityType.CONNECT.value, - # auth_user_model.is_active == True, # noqa: E712, - # auth_user_model.provenance == Provenance.IDENTIFIER, - # or_( - # identifier_model.identifier_type.in_( - # GlobalAuthUserIdentifierType.get_public_address_enums(), - # ), - # identifier_model.date_verified != None, - # ), - # ) - # .order_by(auth_user_model.id.desc()) - # .limit(limit) - # .offset(offset) - # ).all() - - # @with_db_session(ro=True) - # def get_user_count_by_client_id_for_connect( - # self, - # session, - # client_id, - # ): - # # TODO(thomas|2022-07-12): Determine where/if is the right place to split - # # connect/magic logic based on user type as part of https://app.shortcut.com/magic-labs/story/53323. - # # See https://github.com/fortmatic/fortmatic/pull/6173#discussion_r919529540. - # query = ( - # session.query(auth_user_model) - # .join( - # identifier_model, - # auth_user_model.global_auth_user_id == identifier_model.global_auth_user_id, - # ) - # .filter( - # auth_user_model.client_id == client_id, - # auth_user_model.user_type == EntityType.CONNECT.value, - # auth_user_model.is_active == True, # noqa: E712, - # auth_user_model.provenance == Provenance.IDENTIFIER, - # or_( - # identifier_model.identifier_type.in_( - # GlobalAuthUserIdentifierType.get_public_address_enums(), - # ), - # identifier_model.date_verified != None, - # ), - # ) - # .statement.with_only_columns( - # [func.count(distinct(auth_user_model.global_auth_user_id))], - # ) - # .order_by(None) - # ) - - # return session.execute(query).scalar() - - # @with_db_session(ro=True) + @provide_global_contextual_session def get_by_client_id_and_user_type( self, session, @@ -543,7 +412,7 @@ def get_by_client_id_and_user_type( offset=None, limit=None, ): - return self._get_by_client_ids_and_user_type( + return self.get_by_client_ids_and_user_type( session, [client_id], user_type, @@ -551,7 +420,8 @@ def get_by_client_id_and_user_type( limit=limit, ) - def _get_by_client_ids_and_user_type( + @provide_global_contextual_session + def get_by_client_ids_and_user_type( self, session, client_ids, @@ -567,7 +437,6 @@ def _get_by_client_ids_and_user_type( filters=[ auth_user_model.client_id.in_(client_ids), auth_user_model.user_type == user_type, - # auth_user_model.is_active == True, # noqa: E712, auth_user_model.date_verified != None, ], offset=offset, @@ -575,12 +444,8 @@ def _get_by_client_ids_and_user_type( order_by_clause=auth_user_model.id.desc(), ) - # get_by_client_ids_and_user_type = with_db_session(ro=True)( - # _get_by_client_ids_and_user_type, - # ) - get_by_client_ids_and_user_type = _get_by_client_ids_and_user_type - - def _get_by_client_id_with_substring_search( + @provide_global_contextual_session + def get_by_client_id_with_substring_search( self, session, client_id, @@ -594,12 +459,12 @@ def _get_by_client_id_with_substring_search( filters=[ auth_user_model.client_id == client_id, auth_user_model.user_type == EntityType.MAGIC.value, - or_( + sa.or_( auth_user_model.provenance == Provenance.SMS, auth_user_model.provenance == Provenance.LINK, auth_user_model.provenance == None, # noqa: E711 ), - or_( + sa.or_( auth_user_model.phone_number.contains(substring), auth_user_model.email.contains(substring), ), @@ -610,12 +475,7 @@ def _get_by_client_id_with_substring_search( join_list=join_list, ) - # get_by_client_id_with_substring_search = with_db_session(ro=True)( - # _get_by_client_id_with_substring_search, - # ) - get_by_client_id_with_substring_search = _get_by_client_id_with_substring_search - - # @with_db_session(ro=True) + @provide_global_contextual_session def yield_by_chunk(self, session, chunk_size, filters=None, join_list=None): yield from self._repository.yield_by_chunk( session, @@ -624,7 +484,7 @@ def yield_by_chunk(self, session, chunk_size, filters=None, join_list=None): join_list=join_list, ) - # @with_db_session(ro=True) + @provide_global_contextual_session def get_by_emails_and_client_id( self, session, @@ -639,7 +499,8 @@ def get_by_emails_and_client_id( ], ) - def _get_by_email( + @provide_global_contextual_session + def get_by_email( self, session, email: str, @@ -657,16 +518,8 @@ def _get_by_email( join_list=join_list, ) - # get_by_email = with_db_session(ro=True)(_get_by_email) - get_by_email = _get_by_email - - def _add(self, session, **kwargs) -> ObjectID: - return self._repository.add(session, **kwargs).id - - # add = with_db_session(ro=False)(_add) - add = _add - - def _get_by_email_for_interop( + @provide_global_contextual_session + def get_by_email_for_interop( self, session, email: str, @@ -686,21 +539,14 @@ def _get_by_email_for_interop( auth_user_model.wallets.and_( auth_wallet_model.wallet_type == str(wallet_type) ).and_(auth_wallet_model.network == network) - # .and_(auth_wallet_model.is_active == 1), ) - .options(contains_eager(auth_user_model.wallets)) + .options(sa.orm.contains_eager(auth_user_model.wallets)) .join( auth_user_model.magic_client.and_( magic_client_model.connect_interop == ConnectInteropStatus.ENABLED, ), ) - .options(contains_eager(auth_user_model.magic_client)) - # TODO(magic-ravi#67899|2022-12-30): Uncomment to allow account-linked users to use interop - # .options( - # joinedload( - # auth_user_model.linked_primary_auth_user, - # ).joinedload("auth_wallets"), - # ) + .options(sa.orm.contains_eager(auth_user_model.magic_client)) .filter( auth_wallet_model.wallet_type == wallet_type, auth_wallet_model.network == network, @@ -708,20 +554,14 @@ def _get_by_email_for_interop( .filter( auth_user_model.email == email, auth_user_model.user_type == EntityType.MAGIC.value, - # auth_user_model.is_active == 1, - auth_user_model.linked_primary_auth_user_id == None, # noqa: E711 ) .populate_existing() ) return query.all() - # get_by_email_for_interop = with_db_session(ro=True)( - # _get_by_email_for_interop, - # ) - get_by_email_for_interop = _get_by_email_for_interop - - def _get_linked_users(self, session, primary_auth_user_id, join_list, no_op=False): + @provide_global_contextual_session + def get_linked_users(self, session, primary_auth_user_id, join_list, no_op=False): # TODO(magic-ravi#67899|2022-12-30): Re-enable account linked users for interop. Remove no_op flag. if no_op: return [] @@ -729,17 +569,13 @@ def _get_linked_users(self, session, primary_auth_user_id, join_list, no_op=Fals return self._repository.get_by( session, filters=[ - # auth_user_model.is_active == True, # noqa: E712 auth_user_model.user_type == EntityType.MAGIC.value, auth_user_model.linked_primary_auth_user_id == primary_auth_user_id, ], join_list=join_list, ) - # get_linked_users = with_db_session(ro=True)(_get_linked_users) - get_linked_users = _get_linked_users - - # @with_db_session(ro=True) + @provide_global_contextual_session def get_by_phone_number(self, session, phone_number): return self._repository.get_by( session, @@ -749,12 +585,13 @@ def get_by_phone_number(self, session, phone_number): ) -class AuthWallet(LogicComponent): - def __init__(self): - # self._repository = SQLAlchemyRepository[magic_client_model, ObjectID](session) - self._repository = RepositoryLegacyAdapter(auth_wallet_model, ObjectID) +class AuthWallet(LogicComponent[auth_wallet_model, ObjectID, sa.orm.Session]): + model = auth_wallet_model + identity = ObjectID + _repository = RepositoryLegacyAdapter(model, identity) - def _add( + @provide_global_contextual_session + def add( self, session, public_address, @@ -776,10 +613,7 @@ def _add( return new_row - # add = with_db_session(ro=False)(_add) - add = _add - - # @with_db_session(ro=True) + @provide_global_contextual_session def get_by_id(self, session, model_id, allow_inactive=False, join_list=None): return self._repository.get_by_id( session, @@ -788,24 +622,10 @@ def get_by_id(self, session, model_id, allow_inactive=False, join_list=None): join_list=join_list, ) - # @with_db_session(ro=True) + @provide_global_contextual_session def get_by_public_address(self, session, public_address, network=None, is_active=True): - """Public address is unique in our system. In any case, we should only - find one row for the given public address. - - Args: - session: A database session object. - public_address (str): A public address. - network (str): A network name. - is_active (boolean): A boolean value to denote if the query should - retrieve active or inactive rows. - - Returns: - A formatted row, either in presenter form or raw db row. - """ filters = [ auth_wallet_model.public_address == public_address, - # auth_wallet_model.is_active == is_active, ] if network: @@ -818,7 +638,7 @@ def get_by_public_address(self, session, public_address, network=None, is_active return one(row) - # @with_db_session(ro=True) + @provide_global_contextual_session def get_by_auth_user_id( self, session, @@ -828,23 +648,8 @@ def get_by_auth_user_id( is_active=True, join_list=None, ): - """Return all the associated wallets for the given user id. - - Args: - session: A database session object. - auth_user_id (ObjectID): A auth_user id. - network (str|None): A network name. - wallet_type (str|None): a wallet type like ETH or BTC - is_active (boolean): A boolean value to denote if the query should - retrieve active or inactive rows. - join_list (None|List): Table you wish to join. - - Returns: - An empty list or a list of wallets. - """ filters = [ auth_wallet_model.auth_user_id == auth_user_id, - # auth_wallet_model.is_active == is_active, ] if network: @@ -862,8 +667,6 @@ def get_by_auth_user_id( return rows - def _update_by_id(self, session, model_id, **kwargs): + @provide_global_contextual_session + def update_by_id(self, session, model_id, **kwargs): self._repository.update(session, model_id, **kwargs) - - # update_by_id = with_db_session(ro=False)(_update_by_id) - update_by_id = _update_by_id diff --git a/src/quart_sqlalchemy/sim/main.py b/src/quart_sqlalchemy/sim/main.py index 23d1e0c..8350c68 100644 --- a/src/quart_sqlalchemy/sim/main.py +++ b/src/quart_sqlalchemy/sim/main.py @@ -1,8 +1,10 @@ +from quart_sqlalchemy.sim import commands from quart_sqlalchemy.sim.app import create_app app = create_app() +commands.attach(app) if __name__ == "__main__": app.run(port=8081) diff --git a/src/quart_sqlalchemy/sim/model.py b/src/quart_sqlalchemy/sim/model.py index b0199ab..9df6fe9 100644 --- a/src/quart_sqlalchemy/sim/model.py +++ b/src/quart_sqlalchemy/sim/model.py @@ -11,7 +11,7 @@ from quart_sqlalchemy.model import SoftDeleteMixin from quart_sqlalchemy.model import TimestampMixin -from quart_sqlalchemy.sim.db import db +from quart_sqlalchemy.sim.db import MyBase from quart_sqlalchemy.sim.util import ObjectID @@ -71,7 +71,7 @@ class WalletType(StrEnum): HEDERA = "HEDERA" -class MagicClient(db.Model, SoftDeleteMixin, TimestampMixin): +class MagicClient(MyBase, SoftDeleteMixin, TimestampMixin): __tablename__ = "magic_client" id: Mapped[ObjectID] = sa.orm.mapped_column(primary_key=True, autoincrement=True) @@ -90,7 +90,7 @@ class MagicClient(db.Model, SoftDeleteMixin, TimestampMixin): ) -class AuthUser(db.Model, SoftDeleteMixin, TimestampMixin): +class AuthUser(MyBase, SoftDeleteMixin, TimestampMixin): __tablename__ = "auth_user" id: Mapped[ObjectID] = sa.orm.mapped_column(primary_key=True, autoincrement=True) @@ -145,7 +145,7 @@ def is_magic_connect_user(self): return self.global_auth_user_id is not None and self.user_type == EntityType.CONNECT.value -class AuthWallet(db.Model, SoftDeleteMixin, TimestampMixin): +class AuthWallet(MyBase, SoftDeleteMixin, TimestampMixin): __tablename__ = "auth_wallet" id: Mapped[ObjectID] = sa.orm.mapped_column(primary_key=True, autoincrement=True) diff --git a/src/quart_sqlalchemy/sim/repo.py b/src/quart_sqlalchemy/sim/repo.py index 5844b9d..959efb2 100644 --- a/src/quart_sqlalchemy/sim/repo.py +++ b/src/quart_sqlalchemy/sim/repo.py @@ -1,5 +1,6 @@ from __future__ import annotations +import operator import typing as t from abc import ABCMeta @@ -13,53 +14,35 @@ from quart_sqlalchemy.types import ColumnExpr from quart_sqlalchemy.types import EntityIdT from quart_sqlalchemy.types import EntityT +from quart_sqlalchemy.types import Operator from quart_sqlalchemy.types import ORMOption from quart_sqlalchemy.types import Selectable from quart_sqlalchemy.types import SessionT -# from abc import abstractmethod - - sa = sqlalchemy -class AbstractRepository(t.Generic[EntityT, EntityIdT], metaclass=ABCMeta): +class AbstractRepository(t.Generic[EntityT, EntityIdT, SessionT], metaclass=ABCMeta): """A repository interface.""" - # identity: t.Type[EntityIdT] - - # def __init__(self, model: t.Type[EntityT]): - # self.model = model - - @property - def model(self) -> t.Type[EntityT]: - return self.__orig_class__.__args__[0] # type: ignore + model: t.Type[EntityT] + identity: t.Type[EntityIdT] - @property - def identity(self) -> t.Type[EntityIdT]: - return self.__orig_class__.__args__[1] # type: ignore - -class AbstractBulkRepository(t.Generic[EntityT, EntityIdT], metaclass=ABCMeta): +class AbstractBulkRepository(t.Generic[EntityT, EntityIdT, SessionT], metaclass=ABCMeta): """A repository interface for bulk operations. Note: this interface circumvents ORM internals, breaking commonly expected behavior in order to gain performance benefits. Only use this class whenever absolutely necessary. """ - @property - def model(self) -> t.Type[EntityT]: - return self.__orig_class__.__args__[0] # type: ignore - - @property - def identity(self) -> t.Type[EntityIdT]: - return self.__orig_class__.__args__[1] # type: ignore + model: t.Type[EntityT] + identity: t.Type[EntityIdT] class SQLAlchemyRepository( - AbstractRepository[EntityT, EntityIdT], - t.Generic[EntityT, EntityIdT], + AbstractRepository[EntityT, EntityIdT, SessionT], t.Generic[EntityT, EntityIdT, SessionT] ): """A repository that uses SQLAlchemy to persist data. @@ -83,20 +66,21 @@ class SQLAlchemyRepository( """ - # session: sa.orm.Session builder: StatementBuilder - def __init__(self, **kwargs): - super().__init__(**kwargs) - # self.session = session - self.builder = StatementBuilder(None) + def __init__(self, model: t.Type[EntityT], identity: t.Type[EntityIdT]): + self.model = model + self.identity = identity + self.builder = StatementBuilder(self.model) def insert(self, session: sa.orm.Session, values: t.Dict[str, t.Any]) -> EntityT: """Insert a new model into the database.""" new = self.model(**values) + session.add(new) session.flush() session.refresh(new) + return new def update( @@ -109,8 +93,10 @@ def update( for field, value in values.items(): if getattr(obj, field) != value: setattr(obj, field, value) + session.flush() session.refresh(obj) + return obj def merge( @@ -123,9 +109,11 @@ def merge( """Merge model in session/db having id_ with values.""" session.get(self.model, id_) values.update(id=id_) + merged = session.merge(self.model(**values)) session.flush() session.refresh(merged, with_for_update=for_update) # type: ignore + return merged def get( @@ -151,19 +139,61 @@ def get( to satisfy the expected interface's return type: Optional[EntityT], one_or_none is called on the result before returning. """ + selectables = (self.model,) + execution_options = execution_options or {} if include_inactive: execution_options.setdefault("include_inactive", include_inactive) - statement = sa.select(self.model).where(self.model.id == id_).limit(1) # type: ignore + statement = self.builder.select( + selectables, # type: ignore + conditions=[self.model.id == id_], + options=options, + limit=1, + for_update=for_update, + ) - for option in options: - statement = statement.options(option) + return session.scalars(statement, execution_options=execution_options).one_or_none() - if for_update: - statement = statement.with_for_update() + def get_by_field( + self, + session: sa.orm.Session, + field: t.Union[ColumnExpr, str], + value: t.Any, + op: Operator = operator.eq, + order_by: t.Sequence[t.Union[ColumnExpr, str]] = (), + options: t.Sequence[ORMOption] = (), + execution_options: t.Optional[t.Dict[str, t.Any]] = None, + offset: t.Optional[int] = None, + limit: t.Optional[int] = None, + distinct: bool = False, + for_update: bool = False, + include_inactive: bool = False, + ) -> sa.ScalarResult[EntityT]: + """Select models where field is equal to value.""" + selectables = (self.model,) - return session.scalars(statement, execution_options=execution_options).one_or_none() + execution_options = execution_options or {} + if include_inactive: + execution_options.setdefault("include_inactive", include_inactive) + + if isinstance(field, str): + field = getattr(self.model, field) + + conditions = [t.cast(ColumnExpr, op(field, value))] + + statement = self.builder.select( + selectables, # type: ignore + conditions=conditions, + order_by=order_by, + options=options, + offset=offset, + limit=limit, + distinct=distinct, + for_update=for_update, + ) + + return session.scalars(statement, execution_options=execution_options) def select( self, @@ -193,7 +223,7 @@ def select( if yield_by_chunk: execution_options.setdefault("yield_per", yield_by_chunk) - statement = self.builder.complex_select( + statement = self.builder.select( selectables, conditions=conditions, group_by=group_by, @@ -214,10 +244,7 @@ def select( def delete( self, session: sa.orm.Session, id_: EntityIdT, include_inactive: bool = False ) -> None: - # if self.has_soft_delete: - # raise RuntimeError("Can't delete entity that uses soft-delete semantics.") - - entity = self.get(id_, include_inactive=include_inactive) + entity = self.get(session, id_, include_inactive=include_inactive) if not entity: raise RuntimeError(f"Entity with id {id_} not found.") @@ -225,16 +252,10 @@ def delete( session.flush() def deactivate(self, session: sa.orm.Session, id_: EntityIdT) -> EntityT: - # if not self.has_soft_delete: - # raise RuntimeError("Can't delete entity that uses soft-delete semantics.") - - return self.update(id_, dict(is_active=False)) + return self.update(session, id_, dict(is_active=False)) def reactivate(self, session: sa.orm.Session, id_: EntityIdT) -> EntityT: - # if not self.has_soft_delete: - # raise RuntimeError("Can't delete entity that uses soft-delete semantics.") - - return self.update(id_, dict(is_active=False)) + return self.update(session, id_, dict(is_active=False)) def exists( self, @@ -248,31 +269,36 @@ def exists( Note: This performs better than simply trying to select an object since there is no overhead in sending the selected object and deserializing it. """ - selectable = sa.sql.literal(True) + selectables = (sa.sql.literal(True),) execution_options = {} if include_inactive: execution_options.setdefault("include_inactive", include_inactive) - statement = sa.select(selectable).where(*conditions) # type: ignore - - if for_update: - statement = statement.with_for_update() + statement = self.builder.select( + selectables, + conditions=conditions, + limit=1, + for_update=for_update, + ) result = session.execute(statement, execution_options=execution_options).scalar() return bool(result) -class SQLAlchemyBulkRepository(AbstractBulkRepository, t.Generic[SessionT, EntityT, EntityIdT]): - def __init__(self, **kwargs: t.Any): +class SQLAlchemyBulkRepository( + AbstractBulkRepository[EntityT, EntityIdT, SessionT], t.Generic[EntityT, EntityIdT, SessionT] +): + builder: StatementBuilder + + def __init__(self, **kwargs): super().__init__(**kwargs) - self.builder = StatementBuilder(self.model) - # session = session + self.builder = StatementBuilder(None) def bulk_insert( self, - session: sa.orm.Session, + session: SessionT, values: t.Sequence[t.Dict[str, t.Any]] = (), execution_options: t.Optional[t.Dict[str, t.Any]] = None, ) -> sa.Result[t.Any]: @@ -281,7 +307,7 @@ def bulk_insert( def bulk_update( self, - session: sa.orm.Session, + session: SessionT, conditions: t.Sequence[ColumnExpr] = (), values: t.Optional[t.Dict[str, t.Any]] = None, execution_options: t.Optional[t.Dict[str, t.Any]] = None, @@ -291,7 +317,7 @@ def bulk_update( def bulk_delete( self, - session: sa.orm.Session, + session: SessionT, conditions: t.Sequence[ColumnExpr] = (), execution_options: t.Optional[t.Dict[str, t.Any]] = None, ) -> sa.Result[t.Any]: diff --git a/src/quart_sqlalchemy/sim/repo_adapter.py b/src/quart_sqlalchemy/sim/repo_adapter.py index 3d6b48e..8e0a43d 100644 --- a/src/quart_sqlalchemy/sim/repo_adapter.py +++ b/src/quart_sqlalchemy/sim/repo_adapter.py @@ -3,19 +3,14 @@ import sqlalchemy import sqlalchemy.orm from pydantic import BaseModel -from sqlalchemy import ScalarResult -from sqlalchemy.orm import selectinload -from sqlalchemy.orm import Session -from sqlalchemy.sql.expression import func -from sqlalchemy.sql.expression import label -from quart_sqlalchemy.model import Base from quart_sqlalchemy.sim.repo import SQLAlchemyRepository from quart_sqlalchemy.types import ColumnExpr from quart_sqlalchemy.types import EntityIdT from quart_sqlalchemy.types import EntityT from quart_sqlalchemy.types import ORMOption from quart_sqlalchemy.types import Selectable +from quart_sqlalchemy.types import SessionT sa = sqlalchemy @@ -39,21 +34,18 @@ class BaseUpdateSchema(BaseModelSchema): UpdateSchemaT = t.TypeVar("UpdateSchemaT", bound=BaseUpdateSchema) -class RepositoryLegacyAdapter(t.Generic[EntityT, EntityIdT]): - def __init__( - self, - model: t.Type[EntityT], - identity: t.Type[EntityIdT], - # session: Session, - ): +class RepositoryLegacyAdapter(t.Generic[EntityT, EntityIdT, SessionT]): + model: t.Type[EntityT] + identity: t.Type[EntityIdT] + + def __init__(self, model: t.Type[EntityT], identity: t.Type[EntityIdT]): self.model = model - self._identity = identity - # self._session = session - self.repo = SQLAlchemyRepository[model, identity]() + self.identity = identity + self._adapted = SQLAlchemyRepository(model, identity) def get_by( self, - session: t.Optional[Session] = None, + session: SessionT, filters=None, allow_inactive=False, join_list=None, @@ -72,10 +64,10 @@ def get_by( else: order_by_clause = () - return self.repo.select( + return self._adapted.select( session, conditions=filters, - options=[selectinload(getattr(self.model, attr)) for attr in join_list], + options=[sa.orm.selectinload(getattr(self.model, attr)) for attr in join_list], for_update=for_update, order_by=order_by_clause, offset=offset, @@ -85,7 +77,7 @@ def get_by( def get_by_id( self, - session=None, + session: SessionT, model_id=None, allow_inactive=False, join_list=None, @@ -94,30 +86,35 @@ def get_by_id( if model_id is None: raise ValueError("model_id is required") join_list = join_list or () - return self.repo.get( + return self._adapted.get( session, id_=model_id, - options=[selectinload(getattr(self.model, attr)) for attr in join_list], + options=[sa.orm.selectinload(getattr(self.model, attr)) for attr in join_list], for_update=for_update, include_inactive=allow_inactive, ) def one( - self, session=None, filters=None, join_list=None, for_update=False, include_inactive=False + self, + session: SessionT, + filters=None, + join_list=None, + for_update=False, + include_inactive=False, ) -> EntityT: filters = filters or () join_list = join_list or () - return self.repo.select( + return self._adapted.select( session, conditions=filters, - options=[selectinload(getattr(self.model, attr)) for attr in join_list], + options=[sa.orm.selectinload(getattr(self.model, attr)) for attr in join_list], for_update=for_update, include_inactive=include_inactive, ).one() def count_by( self, - session=None, + session: SessionT, filters=None, group_by=None, distinct_column=None, @@ -128,36 +125,36 @@ def count_by( group_by = group_by or () if distinct_column: - selectables = [label("count", func.count(func.distinct(distinct_column)))] + selectables = [sa.label("count", sa.func.count(sa.func.distinct(distinct_column)))] else: - selectables = [label("count", func.count(self.model.id))] + selectables = [sa.label("count", sa.func.count(self.model.id))] for group in group_by: selectables.append(group.expression) - result = self.repo.select(session, selectables, conditions=filters, group_by=group_by) + result = self._adapted.select(session, selectables, conditions=filters, group_by=group_by) return result.all() - def add(self, session=None, **kwargs) -> EntityT: - return self.repo.insert(session, kwargs) + def add(self, session: SessionT, **kwargs) -> EntityT: + return self._adapted.insert(session, kwargs) - def update(self, session=None, model_id=None, **kwargs) -> EntityT: - return self.repo.update(session, id_=model_id, values=kwargs) + def update(self, session: SessionT, model_id=None, **kwargs) -> EntityT: + return self._adapted.update(session, id_=model_id, values=kwargs) - def update_by(self, session=None, filters=None, **kwargs) -> EntityT: + def update_by(self, session: SessionT, filters=None, **kwargs) -> EntityT: if not filters: raise ValueError("Full table scans are prohibited. Please provide filters") - row = self.repo.select(session, conditions=filters, limit=2).one() - return self.repo.update(session, id_=row.id, values=kwargs) + row = self._adapted.select(session, conditions=filters, limit=2).one() + return self._adapted.update(session, id_=row.id, values=kwargs) - def delete_by_id(self, session=None, model_id=None) -> None: - self.repo.delete(session, id_=model_id, include_inactive=True) + def delete_by_id(self, session: SessionT, model_id=None) -> None: + self._adapted.delete(session, id_=model_id, include_inactive=True) - def delete_one_by(self, session=None, filters=None, optional=False) -> None: + def delete_one_by(self, session: SessionT, filters=None, optional=False) -> None: filters = filters or () - result = self.repo.select(session, conditions=filters, limit=1) + result = self._adapted.select(session, conditions=filters, limit=1) if optional: row = result.one_or_none() @@ -166,25 +163,25 @@ def delete_one_by(self, session=None, filters=None, optional=False) -> None: else: row = result.one() - self.repo.delete(session, id_=row.id) + self._adapted.delete(session, id_=row.id) - def exist(self, session=None, filters=None, allow_inactive=False) -> bool: + def exist(self, session: SessionT, filters=None, allow_inactive=False) -> bool: filters = filters or () - return self.repo.exists( + return self._adapted.exists( session, conditions=filters, include_inactive=allow_inactive, ) def yield_by_chunk( - self, session=None, chunk_size=100, join_list=None, filters=None, allow_inactive=False + self, session: SessionT, chunk_size=100, join_list=None, filters=None, allow_inactive=False ): filters = filters or () join_list = join_list or () - results = self.repo.select( + results = self._adapted.select( session, conditions=filters, - options=[selectinload(getattr(self.model, attr)) for attr in join_list], + options=[sa.orm.selectinload(getattr(self.model, attr)) for attr in join_list], include_inactive=allow_inactive, yield_by_chunk=chunk_size, ) @@ -192,10 +189,10 @@ def yield_by_chunk( yield result -class PydanticScalarResult(ScalarResult): +class PydanticScalarResult(sa.ScalarResult, t.Generic[ModelSchemaT]): pydantic_schema: t.Type[ModelSchemaT] - def __init__(self, scalar_result, pydantic_schema: t.Type[ModelSchemaT]): + def __init__(self, scalar_result: t.Any, pydantic_schema: t.Type[ModelSchemaT]): for attribute in scalar_result.__slots__: setattr(self, attribute, getattr(scalar_result, attribute)) self.pydantic_schema = pydantic_schema @@ -231,14 +228,15 @@ def partitions(self, *args, **kwargs): yield self._translate_many(partition) -class PydanticRepository(SQLAlchemyRepository, t.Generic[EntityT, EntityIdT, ModelSchemaT]): - @property - def schema(self) -> t.Type[ModelSchemaT]: - return self.__orig_class__.__args__[2] # type: ignore +class PydanticRepository( + SQLAlchemyRepository[EntityT, EntityIdT, SessionT], + t.Generic[EntityT, EntityIdT, SessionT, ModelSchemaT, CreateSchemaT, UpdateSchemaT], +): + model_schema: t.Type[ModelSchemaT] def insert( self, - session: sa.orm.Session, + session: SessionT, create_schema: CreateSchemaT, sqla_model=False, ): @@ -247,11 +245,11 @@ def insert( if sqla_model: return result - return self.schema.from_orm(result) + return self.model_schema.from_orm(result) def update( self, - session: sa.orm.Session, + session: SessionT, id_: EntityIdT, update_schema: UpdateSchemaT, sqla_model=False, @@ -267,13 +265,14 @@ def update( session.add(existing) session.flush() session.refresh(existing) + if sqla_model: return existing - return self.schema.from_orm(existing) + return self.model_schema.from_orm(existing) def get( self, - session: sa.orm.Session, + session: SessionT, id_: EntityIdT, options: t.Sequence[ORMOption] = (), execution_options: t.Optional[t.Dict[str, t.Any]] = None, @@ -294,11 +293,11 @@ def get( if sqla_model: return row - return self.schema.from_orm(row) + return self.model_schema.from_orm(row) def select( self, - session: sa.orm.Session, + session: SessionT, selectables: t.Sequence[Selectable] = (), conditions: t.Sequence[ColumnExpr] = (), group_by: t.Sequence[t.Union[ColumnExpr, str]] = (), @@ -328,6 +327,7 @@ def select( include_inactive, yield_by_chunk, ) + if sqla_model: return result - return PydanticScalarResult(result, self.schema) + return PydanticScalarResult[self.model_schema](result, self.model_schema) diff --git a/src/quart_sqlalchemy/sim/schema.py b/src/quart_sqlalchemy/sim/schema.py index e0304b7..e679d61 100644 --- a/src/quart_sqlalchemy/sim/schema.py +++ b/src/quart_sqlalchemy/sim/schema.py @@ -1,20 +1,34 @@ import typing as t from datetime import datetime +from enum import Enum from pydantic import BaseModel from pydantic import Field from pydantic import validator +from pydantic.generics import GenericModel +from .model import ConnectInteropStatus +from .model import EntityType +from .model import Provenance +from .model import WalletManagementType +from .model import WalletType from .util import ObjectID +DataT = t.TypeVar("DataT") + +json_encoders = { + ObjectID: lambda v: v.encode(), + datetime: lambda dt: int(dt.timestamp()), + Enum: lambda e: e.value, +} + + class BaseSchema(BaseModel): class Config: arbitrary_types_allowed = True - json_encoders = { - ObjectID: lambda v: v.encode(), - datetime: lambda dt: int(dt.timestamp()), - } + json_encoders = dict(json_encoders) + orm_mode = True @classmethod def _get_value(cls, v: t.Any, *args: t.Any, **kwargs: t.Any) -> t.Any: @@ -27,13 +41,17 @@ def _get_value(cls, v: t.Any, *args: t.Any, **kwargs: t.Any) -> t.Any: return super()._get_value(v, *args, **kwargs) -class ResponseWrapper(BaseSchema): +class ResponseWrapper(GenericModel, t.Generic[DataT]): """Generic response wrapper""" + class Config: + arbitrary_types_allowed = True + json_encoders = dict(json_encoders) + error_code: str = "" status: str = "" message: str = "" - data: t.Any = Field(default_factory=dict) + data: DataT = Field(default_factory=dict) @validator("status") def set_status_by_error_code(cls, v, values): @@ -41,3 +59,41 @@ def set_status_by_error_code(cls, v, values): if error_code: return "failed" return "ok" + + +class MagicClientSchema(BaseSchema): + id: ObjectID + app_name: str + rate_limit_tier: t.Optional[str] = None + connect_interop: t.Optional[ConnectInteropStatus] = None + is_signing_modal_enabled: bool + global_audience_enabled: bool + public_api_key: str + secret_api_key: str + + +class AuthUserSchema(BaseSchema): + id: ObjectID + client_id: ObjectID + email: str + phone_number: t.Optional[str] = None + user_type: EntityType = EntityType.MAGIC + provenance: t.Optional[Provenance] = None + date_verified: t.Optional[datetime] = None + is_admin: bool = False + linked_primary_auth_user_id: t.Optional[ObjectID] = None + global_auth_user_id: t.Optional[ObjectID] = None + delegated_user_id: t.Optional[str] = None + delegated_identity_pool_id: t.Optional[str] = None + current_session_token: t.Optional[str] = None + + +class AuthWalletSchema(BaseSchema): + id: ObjectID + auth_user_id: ObjectID + wallet_type: WalletType + management_type: WalletManagementType + public_address: str + encrypted_private_address: str + network: str + is_exported: bool diff --git a/src/quart_sqlalchemy/sim/signals.py b/src/quart_sqlalchemy/sim/signals.py index 1981cea..510be83 100644 --- a/src/quart_sqlalchemy/sim/signals.py +++ b/src/quart_sqlalchemy/sim/signals.py @@ -13,6 +13,7 @@ def handler( current_app: Quart, original_auth_user_id: ObjectID, duplicate_auth_user_ids: List[ObjectID], + session: sa.orm.Session, ) -> None: ... """, diff --git a/src/quart_sqlalchemy/sim/testing.py b/src/quart_sqlalchemy/sim/testing.py index 347b422..fd84547 100644 --- a/src/quart_sqlalchemy/sim/testing.py +++ b/src/quart_sqlalchemy/sim/testing.py @@ -1,13 +1,16 @@ from contextlib import contextmanager from quart import g +from quart import Quart from quart import signals +from quart_sqlalchemy import Bind + @contextmanager -def user_set(app, user): +def global_bind(app: Quart, bind: Bind): def handler(sender, **kwargs): - g.user = user + g.bind = bind with signals.appcontext_pushed.connected_to(handler, app): yield diff --git a/src/quart_sqlalchemy/sim/views/auth_user.py b/src/quart_sqlalchemy/sim/views/auth_user.py index ee456e6..6c8664b 100644 --- a/src/quart_sqlalchemy/sim/views/auth_user.py +++ b/src/quart_sqlalchemy/sim/views/auth_user.py @@ -1,7 +1,88 @@ import logging +import typing as t -from quart import Blueprint +import sqlalchemy.orm +from dependency_injector.wiring import inject +from dependency_injector.wiring import Provide +from quart_sqlalchemy.framework import QuartSQLAlchemy +from quart_sqlalchemy.session import set_global_contextual_session + +from ..auth import authorized_request +from ..auth import RequestCredentials +from ..container import Container +from ..handle import AuthUserHandler +from ..model import EntityType +from ..schema import AuthUserSchema +from ..schema import BaseSchema +from ..schema import ResponseWrapper +from .util import APIBlueprint + + +sa = sqlalchemy logger = logging.getLogger(__name__) -api = Blueprint("auth_user", __name__, url_prefix="auth_user") +api = APIBlueprint("auth_user", __name__, url_prefix="/auth_user") + + +class CreateAuthUserRequest(BaseSchema): + email: str + + +class CreateAuthUserResponse(BaseSchema): + auth_user: AuthUserSchema + + +@api.get( + "/", + authorizer=authorized_request( + [ + { + "public-api-key": [], + "session-token-bearer": [], + } + ], + ), +) +@inject +def get_auth_user( + auth_user_handler: AuthUserHandler = Provide["AuthUserHandler"], + db: QuartSQLAlchemy = Provide["db"], + credentials: RequestCredentials = Provide["request_credentials"], +) -> ResponseWrapper[AuthUserSchema]: + with db.bind.Session() as session: + with set_global_contextual_session(session): + auth_user = auth_user_handler.get_by_session_token(credentials.current_user.value) + + return ResponseWrapper[AuthUserSchema](data=AuthUserSchema.from_orm(auth_user)) + + +@api.post( + "/", + authorizer=authorized_request( + [ + { + "public-api-key": [], + } + ], + ), +) +@inject +def create_auth_user( + data: CreateAuthUserRequest, + auth_user_handler: AuthUserHandler = Provide["AuthUserHandler"], + db: QuartSQLAlchemy = Provide["db"], + credentials: RequestCredentials = Provide[Container.request_credentials], +) -> ResponseWrapper[CreateAuthUserResponse]: + with db.bind.Session() as session: + with session.begin(): + with set_global_contextual_session(session): + client = auth_user_handler.create_verified_user( + email=data.email, + client_id=credentials.current_client.subject.id, + user_type=EntityType.MAGIC.value, + ) + + return ResponseWrapper[CreateAuthUserResponse]( + data=dict(auth_user=AuthUserSchema.from_orm(client)) # type: ignore + ) diff --git a/src/quart_sqlalchemy/sim/views/auth_wallet.py b/src/quart_sqlalchemy/sim/views/auth_wallet.py index 9c6f052..79c3be0 100644 --- a/src/quart_sqlalchemy/sim/views/auth_wallet.py +++ b/src/quart_sqlalchemy/sim/views/auth_wallet.py @@ -1,23 +1,27 @@ import logging import typing as t +from dependency_injector.wiring import inject +from dependency_injector.wiring import Provide from quart import g -from quart.utils import run_sync -from quart_sqlalchemy.retry import retry_context -from quart_sqlalchemy.retry import RetryError +from quart_sqlalchemy.framework import QuartSQLAlchemy +from quart_sqlalchemy.session import set_global_contextual_session from ..auth import authorized_request +from ..auth import RequestCredentials from ..handle import AuthWalletHandler from ..model import WalletManagementType from ..model import WalletType from ..schema import BaseSchema +from ..schema import ResponseWrapper from ..util import ObjectID +from ..web3 import Web3 from .util import APIBlueprint logger = logging.getLogger(__name__) -api = APIBlueprint("auth_wallet", __name__, url_prefix="auth_wallet") +api = APIBlueprint("auth_wallet", __name__, url_prefix="/auth_wallet") @api.before_request @@ -56,28 +60,32 @@ class WalletSyncResponse(BaseSchema): ], ), ) -async def sync(data: WalletSyncRequest) -> WalletSyncResponse: - user_credential = g.authorized_credentials.get("session-token-bearer") +@inject +def sync( + data: WalletSyncRequest, + auth_wallet_handler: AuthWalletHandler = Provide["AuthWalletHandler"], + web3: Web3 = Provide["web3"], + db: QuartSQLAlchemy = Provide["db"], + credentials: RequestCredentials = Provide["request_credentials"], +) -> ResponseWrapper[WalletSyncResponse]: + with db.bind.Session() as session: + with session.begin(): + with set_global_contextual_session(session): + wallet = auth_wallet_handler.sync_auth_wallet( + credentials.current_user.subject.id, + data.public_address, + data.encrypted_private_address, + WalletManagementType.DELEGATED.value, + network=web3.network, + wallet_type=data.wallet_type, + ) - try: - for attempt in retry_context: - with attempt: - with g.bind.Session() as session: - wallet = AuthWalletHandler(g.network, session).sync_auth_wallet( - user_credential.subject.id, - data.public_address, - data.encrypted_private_address, - WalletManagementType.DELEGATED.value, - ) - except RetryError: - pass - except RuntimeError: - raise RuntimeError("Unsupported wallet type or network") - - return WalletSyncResponse( - wallet_id=wallet.id, - auth_user_id=wallet.auth_user_id, - wallet_type=wallet.wallet_type, - public_address=wallet.public_address, - encrypted_private_address=wallet.encrypted_private_address, + return ResponseWrapper[WalletSyncResponse]( + data=dict( + wallet_id=wallet.id, + auth_user_id=wallet.auth_user_id, + wallet_type=wallet.wallet_type, + public_address=wallet.public_address, + encrypted_private_address=wallet.encrypted_private_address, + ) # type: ignore ) diff --git a/src/quart_sqlalchemy/sim/views/magic_client.py b/src/quart_sqlalchemy/sim/views/magic_client.py index a1d28e1..09d889b 100644 --- a/src/quart_sqlalchemy/sim/views/magic_client.py +++ b/src/quart_sqlalchemy/sim/views/magic_client.py @@ -1,7 +1,90 @@ import logging +import typing as t -from quart import Blueprint +from dependency_injector.wiring import inject +from dependency_injector.wiring import Provide + +from quart_sqlalchemy.framework import QuartSQLAlchemy +from quart_sqlalchemy.session import set_global_contextual_session + +from ..auth import authorized_request +from ..auth import RequestCredentials +from ..handle import MagicClientHandler +from ..model import ConnectInteropStatus +from ..schema import BaseSchema +from ..schema import MagicClientSchema +from ..schema import ResponseWrapper +from .util import APIBlueprint logger = logging.getLogger(__name__) -api = Blueprint("magic_client", __name__, url_prefix="magic_client") +api = APIBlueprint("magic_client", __name__, url_prefix="/magic_client") + + +class CreateMagicClientRequest(BaseSchema): + app_name: str + rate_limit_tier: t.Optional[str] = None + connect_interop: t.Optional[ConnectInteropStatus] = None + is_signing_modal_enabled: bool = False + global_audience_enabled: bool = False + + +class CreateMagicClientResponse(BaseSchema): + magic_client: MagicClientSchema + + +@api.post( + "/", + authorizer=authorized_request( + [ + { + "public-api-key": [], + } + ], + ), +) +@inject +def create_magic_client( + data: CreateMagicClientRequest, + magic_client_handler: MagicClientHandler = Provide["MagicClientHandler"], + db: QuartSQLAlchemy = Provide["db"], +) -> ResponseWrapper[CreateMagicClientResponse]: + with db.bind.Session() as session: + with session.begin(): + with set_global_contextual_session(session): + client = magic_client_handler.add( + app_name=data.app_name, + rate_limit_tier=data.rate_limit_tier, + connect_interop=data.connect_interop, + is_signing_modal_enabled=data.is_signing_modal_enabled, + global_audience_enabled=data.global_audience_enabled, + ) + + return ResponseWrapper[CreateMagicClientResponse]( + data=dict(magic_client=MagicClientSchema.from_orm(client)) # type: ignore + ) + + +@api.get( + "/", + authorizer=authorized_request( + [ + { + "public-api-key": [], + } + ], + ), +) +@inject +def get_magic_client( + magic_client_handler: MagicClientHandler = Provide["MagicClientHandler"], + credentials: RequestCredentials = Provide["request_credentials"], + db: QuartSQLAlchemy = Provide["db"], +) -> ResponseWrapper[MagicClientSchema]: + with db.bind.Session() as session: + with set_global_contextual_session(session): + client = magic_client_handler.get_by_public_api_key(credentials.current_client.value) + + return ResponseWrapper[MagicClientSchema]( + data=MagicClientSchema.from_orm(client) # type: ignore + ) diff --git a/src/quart_sqlalchemy/sim/web3.py b/src/quart_sqlalchemy/sim/web3.py new file mode 100644 index 0000000..e1a088b --- /dev/null +++ b/src/quart_sqlalchemy/sim/web3.py @@ -0,0 +1,153 @@ +import typing as t +from decimal import Decimal + +import typing_extensions as tx +import web3.providers +from ens import ENS +from eth_typing import AnyAddress +from eth_typing import ChecksumAddress +from eth_typing import HexStr +from eth_typing import Primitives +from eth_typing.abi import TypeStr +from quart import request +from quart.ctx import has_request_context +from web3.eth import Eth +from web3.geth import Geth +from web3.main import BaseWeb3 +from web3.module import Module +from web3.net import Net +from web3.providers import BaseProvider +from web3.types import Wei + + +""" +generate new key address pairing + +```zsh +python -c "from web3 import Web3; w3 = Web3(); acc = w3.eth.account.create(); print(f'private key={w3.to_hex(acc.key)}, account={acc.address}')" +``` +""" + + +class Web3Node(tx.Protocol): + eth: Eth + net: Net + geth: Geth + provider: BaseProvider + ens: ENS + + def is_connected(self) -> bool: + ... + + @staticmethod + def to_bytes( + primitive: t.Optional[Primitives] = None, + hexstr: t.Optional[HexStr] = None, + text: t.Optional[str] = None, + ) -> bytes: + ... + + @staticmethod + def to_int( + primitive: t.Optional[Primitives] = None, + hexstr: t.Optional[HexStr] = None, + text: t.Optional[str] = None, + ) -> int: + ... + + @staticmethod + def to_hex( + primitive: t.Optional[Primitives] = None, + hexstr: t.Optional[HexStr] = None, + text: t.Optional[str] = None, + ) -> HexStr: + ... + + @staticmethod + def to_text( + primitive: t.Optional[Primitives] = None, + hexstr: t.Optional[HexStr] = None, + text: t.Optional[str] = None, + ) -> str: + ... + + @staticmethod + def to_json(obj: t.Dict[t.Any, t.Any]) -> str: + ... + + @staticmethod + def to_wei(number: t.Union[int, float, str, Decimal], unit: str) -> Wei: + ... + + @staticmethod + def from_wei(number: int, unit: str) -> t.Union[int, Decimal]: + ... + + @staticmethod + def is_address(value: t.Any) -> bool: + ... + + @staticmethod + def is_checksum_address(value: t.Any) -> bool: + ... + + @staticmethod + def to_checksum_address(value: t.Union[AnyAddress, str, bytes]) -> ChecksumAddress: + ... + + @property + def api(self) -> str: + ... + + @staticmethod + def keccak( + primitive: t.Optional[Primitives] = None, + text: t.Optional[str] = None, + hexstr: t.Optional[HexStr] = None, + ) -> bytes: + ... + + @classmethod + def normalize_values( + cls, _w3: BaseWeb3, abi_types: t.List[TypeStr], values: t.List[t.Any] + ) -> t.List[t.Any]: + ... + + @classmethod + def solidity_keccak(cls, abi_types: t.List[TypeStr], values: t.List[t.Any]) -> bytes: + ... + + def attach_modules( + self, modules: t.Optional[t.Dict[str, t.Union[t.Type[Module], t.Sequence[t.Any]]]] + ) -> None: + ... + + def is_encodable(self, _type: TypeStr, value: t.Any) -> bool: + ... + + +def web3_node_factory(config): + if config["WEB3_PROVIDER_CLASS"] is web3.providers.HTTPProvider: + provider = config["WEB3_PROVIDER_CLASS"](config["WEB3_HTTPS_PROVIDER_URI"]) + return web3.Web3(provider) + + +class Web3: + node: Web3Node + + def __init__(self, node: Web3Node, default_network: str, default_chain: str): + self.node = node + self.default_network = default_network + self.default_chain = default_chain + + @property + def chain(self) -> str: + if has_request_context(): + return request.headers.get("x-web3-chain", self.default_chain).upper() + return self.default_chain + + @property + def network(self) -> str: + if has_request_context(): + return request.headers.get("x-web3-network", self.default_network).upper() + return self.default_network diff --git a/src/quart_sqlalchemy/sqla.py b/src/quart_sqlalchemy/sqla.py index 56b8b60..b437625 100644 --- a/src/quart_sqlalchemy/sqla.py +++ b/src/quart_sqlalchemy/sqla.py @@ -10,55 +10,280 @@ import sqlalchemy.orm import sqlalchemy.util -from .bind import AsyncBind -from .bind import Bind -from .config import AsyncBindConfig -from .config import SQLAlchemyConfig +from quart_sqlalchemy.bind import AsyncBind +from quart_sqlalchemy.bind import Bind +from quart_sqlalchemy.config import AsyncBindConfig +from quart_sqlalchemy.config import SQLAlchemyConfig sa = sqlalchemy class SQLAlchemy: + """ + This manager class keeps things very simple by using a few configuration conventions. + + Configuration has been simplified down to base_class and binds. + + * Everything related to ORM mapping, DeclarativeBase, registry, MetaData, etc should be + configured by passing the a custom DeclarativeBase class as the base_class configuration + parameter. + + * Everything related to engine/session configuration should be configured by passing a + dictionary mapping string names to BindConfigs as the `binds` configuration parameter. + + BindConfig can be as simple as a dictionary containing a url key like so: + + bind_config = { + "default": {"url": "sqlite://"} + } + + But most use cases will require more than just a connection url, and divide core/engine + configuration from orm/session configuration which looks more like this: + + bind_config = { + "default": { + "engine": { + "url": "sqlite://" + }, + "session": { + "expire_on_commit": False + } + } + } + + Everything under `engine` will then be passed to `sqlalchemy.create_engine_from_config` and + everything under `session` will be passed to `sqlalchemy.orm.sessionmaker`. + + engine = sa.create_engine_from_config(bind_config.engine) + Session = sa.orm.sessionmaker(bind=engine, **bind_config.session) + + Config Examples: + + Simple URL: + db = SQLAlchemy( + SQLAlchemyConfig( + binds=dict( + default=dict( + url="sqlite://" + ) + ) + ) + ) + + Shortcut for the above: + db = SQLAlchemy(SQLAlchemyConfig()) + + More complex configuration for engine and session both: + db = SQLAlchemy( + SQLAlchemyConfig( + binds=dict( + default=dict( + engine=dict( + url="sqlite://" + ), + session=dict( + expire_on_commit=False + ) + ) + ) + ) + ) + + Once instantiated, operations targetting all of the binds, aka metadata, like + `metadata.create_all` should be called from this class. Operations specific to a bind + should be called from that bind. This class has a few ways to get a specific bind. + + * To get a Bind, you can call `.get_bind(name)` on this class. The default bind can be + referenced at `.bind`. + + * To define an ORM model using the Base class attached to this class, simply inherit + from `.Base` + + db = SQLAlchemy(SQLAlchemyConfig()) + + class User(db.Base): + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + name: Mapped[str] = sa.orm.mapped_column(default="Joe") + + db.create_all() + + * You can also decouple Base from SQLAlchemy with some dependency inversion: + from quart_sqlalchemy.model.mixins import DynamicArgsMixin, ReprMixin, TableNameMixin + + class Base(DynamicArgsMixin, ReprMixin, TableNameMixin): + __abstract__ = True + + + class User(db.Base): + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + name: Mapped[str] = sa.orm.mapped_column(default="Joe") + + db = SQLAlchemy(SQLAlchemyConfig(bind_class=Base)) + + db.create_all() + + + Declarative Mapping using registry based decorator: + + db = SQLAlchemy(SQLAlchemyConfig()) + + @db.registry.mapped + class User(db.Base): + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + name: Mapped[str] = sa.orm.mapped_column(default="Joe") + + db.create_all() + + + Declarative with Imperative Table (Hybrid Declarative): + + class User(db.Base): + __table__ = sa.Table( + "user", + db.metadata, + sa.Column("id", sa.Integer, primary_key=True, autoincrement=True), + sa.Column("name", sa.String, default="Joe"), + ) + + + Declarative using reflection to automatically build the table object: + + class User(db.Base): + __table__ = sa.Table( + "user", + db.metadata, + autoload_with=db.bind.engine, + ) + + + Declarative Dataclass Mapping: + + from quart_sqlalchemy.model import Base as Base_ + + class Base(sa.orm.MappedAsDataclass, Base_): + pass + + db = SQLAlchemy(SQLAlchemyConfig(base_class=Base)) + + class User(db.Base): + __tablename__ = "user" + + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + name: Mapped[str] = sa.orm.mapped_column(default="Joe") + + db.create_all() + + + Declarative Dataclass Mapping (using decorator): + + db = SQLAlchemy(SQLAlchemyConfig(base_class=Base)) + + @db.registry.mapped_as_dataclass + class User: + __tablename__ = "user" + + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + name: Mapped[str] = sa.orm.mapped_column(default="Joe") + + db.create_all() + + + Alternate Dataclass Provider Pattern: + + from pydantic.dataclasses import dataclass + from quart_sqlalchemy.model import Base as Base_ + + class Base(sa.orm.MappedAsDataclass, Base_, dataclass_callable=dataclass): + pass + + db = SQLAlchemy(SQLAlchemyConfig(base_class=Base)) + + class User(db.Base): + __tablename__ = "user" + + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + name: Mapped[str] = sa.orm.mapped_column(default="Joe") + + db.create_all() + + Imperative style Mapping + + db = SQLAlchemy(SQLAlchemyConfig(base_class=Base)) + + user_table = sa.Table( + "user", + db.metadata, + sa.Column("id", sa.Integer, primary_key=True, autoincrement=True), + sa.Column("name", sa.String, default="Joe"), + ) + + post_table = sa.Table( + "post", + db.metadata, + sa.Column("id", sa.Integer, primary_key=True, autoincrement=True), + sa.Column("title", sa.String, default="My post"), + sa.Column("user_id", sa.ForeignKey("user.id"), nullable=False), + ) + + class User: + pass + + class Post: + pass + + db.registry.map_imperatively( + User, + user_table, + properties={ + "posts": sa.orm.relationship(Post, back_populates="user") + } + ) + db.registry.map_imperatively( + Post, + post_table, + properties={ + "user": sa.orm.relationship(User, back_populates="posts", uselist=False) + } + ) + """ + config: SQLAlchemyConfig binds: t.Dict[str, t.Union[Bind, AsyncBind]] - Model: t.Type[sa.orm.DeclarativeBase] + Base: t.Type[sa.orm.DeclarativeBase] - def __init__(self, config: SQLAlchemyConfig, initialize: bool = True): + def __init__( + self, + config: t.Optional[SQLAlchemyConfig] = None, + initialize: bool = True, + ): self.config = config if initialize: self.initialize() - def initialize(self): - if issubclass(self.config.model_class, sa.orm.DeclarativeBase): - Model = self.config.model_class # type: ignore - else: - - class Model(self.config.model_class, sa.orm.DeclarativeBase): - pass + def initialize(self, config: t.Optional[SQLAlchemyConfig] = None): + if config is not None: + self.config = config + if self.config is None: + self.config = SQLAlchemyConfig.default() - type_annotation_map = {} - for base_class in Model.__mro__[::-1]: - if base_class is Model: - continue - base_map = getattr(base_class, "type_annotation_map", {}).copy() - type_annotation_map.update(base_map) + if issubclass(self.config.base_class, sa.orm.DeclarativeBase): + Base = self.config.base_class # type: ignore + else: + Base = type("Base", (self.config.base_class, sa.orm.DeclarativeBase), {}) - Model.registry.type_annotation_map.update(type_annotation_map) - self.Model = Model + self.Base = Base - self.binds = {} - for name, bind_config in self.config.binds.items(): - is_async = isinstance(bind_config, AsyncBindConfig) - if is_async: - self.binds[name] = AsyncBind(bind_config, self.metadata) - else: - self.binds[name] = Bind(bind_config, self.metadata) + if not hasattr(self, "binds"): + self.binds = {} + for name, bind_config in self.config.binds.items(): + is_async = isinstance(bind_config, AsyncBindConfig) + factory = AsyncBind if is_async else Bind + self.binds[name] = factory(name, bind_config.engine.url, bind_config, self.metadata) - @classmethod - def default(cls): - return cls(SQLAlchemyConfig()) + def get_bind(self, bind: str = "default"): + return self.binds[bind] @property def bind(self) -> Bind: @@ -66,10 +291,11 @@ def bind(self) -> Bind: @property def metadata(self) -> sa.MetaData: - return self.Model.metadata + return self.Base.metadata - def get_bind(self, bind: str = "default"): - return self.binds[bind] + @property + def registry(self) -> sa.orm.registry: + return self.Base.registry def create_all(self, bind: str = "default"): return self.binds[bind].create_all() diff --git a/src/quart_sqlalchemy/testing/fake.py b/src/quart_sqlalchemy/testing/fake.py new file mode 100644 index 0000000..e69de29 diff --git a/src/quart_sqlalchemy/testing/signals.py b/src/quart_sqlalchemy/testing/signals.py new file mode 100644 index 0000000..d9c1cb9 --- /dev/null +++ b/src/quart_sqlalchemy/testing/signals.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +import sqlalchemy +import sqlalchemy.orm +from blinker import Namespace +from quart.signals import AsyncNamespace + + +sa = sqlalchemy + +sync_signals = Namespace() +async_signals = AsyncNamespace() + + +load_test_fixtures = sync_signals.signal( + "quart-sqlalchemy.testing.fixtures.load.sync", + doc="""Fired to load test fixtures into a freshly instantiated test database. + + No default signal handlers exist for this signal as the logic is very application dependent. + + Example: + + @signals.framework_extension_load_fixtures.connect + def handle(sender: QuartSQLAlchemy, app: Quart): + bind = sender.get_bind("default") + with bind.Session() as session: + with session.begin(): + session.add_all( + [ + models.User(username="user1"), + models.User(username="user2"), + ] + ) + session.commit() + + Handler signature: + def handle(sender: QuartSQLAlchemy, app: Quart): + ... + """, +) diff --git a/src/quart_sqlalchemy/testing/transaction.py b/src/quart_sqlalchemy/testing/transaction.py index 507a963..d7ff873 100644 --- a/src/quart_sqlalchemy/testing/transaction.py +++ b/src/quart_sqlalchemy/testing/transaction.py @@ -23,13 +23,13 @@ def __init__(self, bind: "Bind", savepoint: bool = False): self.savepoint = savepoint self.bind = bind - def Session(self, **options): + def Session(self, **options: t.Any) -> sa.orm.Session: options.update(bind=self.connection) if self.savepoint: options.update(join_transaction_mode="create_savepoint") return self.bind.Session(**options) - def begin(self): + def open(self) -> None: self.connection = self.bind.engine.connect() self.trans = self.connection.begin() @@ -59,7 +59,7 @@ def close(self, exc: t.Optional[Exception] = None) -> None: ) def __enter__(self): - self.begin() + self.open() return self def __exit__(self, exc_type, exc_val, exc_tb): @@ -82,7 +82,7 @@ class AsyncTestTransaction(TestTransaction): def __init__(self, bind: "AsyncBind", savepoint: bool = False): super().__init__(bind, savepoint=savepoint) - async def begin(self): + async def open(self): self.connection = await self.bind.engine.connect() self.trans = await self.connection.begin() @@ -112,7 +112,7 @@ async def close(self, exc: t.Optional[Exception] = None) -> None: ) async def __aenter__(self): - await self.begin() + await self.open() return self async def __aexit__(self, exc_type, exc_val, exc_tb): diff --git a/src/quart_sqlalchemy/types.py b/src/quart_sqlalchemy/types.py index 2c5b96f..73e6c6f 100644 --- a/src/quart_sqlalchemy/types.py +++ b/src/quart_sqlalchemy/types.py @@ -7,6 +7,7 @@ import sqlalchemy.orm import sqlalchemy.sql import typing_extensions as tx +from sqlalchemy import SQLColumnExpression from sqlalchemy.orm.interfaces import ORMOption as _ORMOption from sqlalchemy.sql._typing import _ColumnExpressionArgument from sqlalchemy.sql._typing import _ColumnsClauseArgument @@ -15,11 +16,18 @@ sa = sqlalchemy + +class Empty: + pass + + +EmptyType = t.Type[Empty] + SessionT = t.TypeVar("SessionT", bound=sa.orm.Session) EntityT = t.TypeVar("EntityT", bound=sa.orm.DeclarativeBase) EntityIdT = t.TypeVar("EntityIdT", bound=t.Any) -ColumnExpr = _ColumnExpressionArgument +ColumnExpr = SQLColumnExpression Selectable = _ColumnsClauseArgument DMLTable = _DMLTableArgument ORMOption = _ORMOption @@ -40,3 +48,8 @@ SABind = t.Union[ sa.Engine, sa.Connection, sa.ext.asyncio.AsyncEngine, sa.ext.asyncio.AsyncConnection ] + + +class Operator(tx.Protocol): + def __call__(self, __a: object, __b: object) -> object: + ... diff --git a/tests/base.py b/tests/base.py index 1ff8756..8afa3fb 100644 --- a/tests/base.py +++ b/tests/base.py @@ -2,6 +2,7 @@ import random import typing as t +from copy import deepcopy from datetime import datetime import pytest @@ -10,8 +11,12 @@ from quart import Quart from sqlalchemy.orm import Mapped -from quart_sqlalchemy import SQLAlchemyConfig +from quart_sqlalchemy import Base from quart_sqlalchemy.framework import QuartSQLAlchemy +from quart_sqlalchemy.model.mixins import ComparableMixin +from quart_sqlalchemy.model.mixins import DynamicArgsMixin +from quart_sqlalchemy.model.mixins import EagerDefaultsMixin +from quart_sqlalchemy.model.mixins import TableNameMixin from . import constants @@ -21,27 +26,156 @@ class SimpleTestBase: @pytest.fixture(scope="class") - def app(self, request): + def Base(self) -> t.Type[t.Any]: + return Base + + @pytest.fixture(scope="class") + def app_config(self, Base): + config = deepcopy(constants.simple_config) + config.update(SQLALCHEMY_BASE_CLASS=Base) + return config + + @pytest.fixture(scope="class") + def app(self, app_config, request): app = Quart(request.module.__name__) - app.config.from_mapping({"TESTING": True}) + app.config.from_mapping(app_config) + app.config["TESTING"] = True return app @pytest.fixture(scope="class") - def sqlalchemy_config(self): - return SQLAlchemyConfig.parse_obj(constants.simple_mapping_config) + def db(self, app: Quart) -> t.Generator[QuartSQLAlchemy, None, None]: + db = QuartSQLAlchemy(app=app) - @pytest.fixture(scope="class") - def db(self, sqlalchemy_config, app: Quart) -> QuartSQLAlchemy: - return QuartSQLAlchemy(sqlalchemy_config, app) - # yield db - # db.drop_all() + yield db @pytest.fixture(scope="class") - def models(self, app: Quart, db: QuartSQLAlchemy) -> t.Mapping[str, t.Type[t.Any]]: - class Todo(db.Model): + def models( + self, app: Quart, db: QuartSQLAlchemy + ) -> t.Generator[t.Mapping[str, t.Type[t.Any]], None, None]: + class Todo(db.Base): + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + title: Mapped[str] = sa.orm.mapped_column(default="default") + user_id: Mapped[t.Optional[int]] = sa.orm.mapped_column(sa.ForeignKey("user.id")) + + user: Mapped[t.Optional["User"]] = sa.orm.relationship( + back_populates="todos", lazy="noload", uselist=False + ) + + class User(db.Base): id: Mapped[int] = sa.orm.mapped_column( - sa.Identity(), primary_key=True, autoincrement=True + primary_key=True, + autoincrement=True, ) + name: Mapped[str] = sa.orm.mapped_column(default="default") + + created_at: Mapped[datetime] = sa.orm.mapped_column( + default=sa.func.now(), + server_default=sa.FetchedValue(), + ) + + time_updated: Mapped[datetime] = sa.orm.mapped_column( + default=sa.func.now(), + onupdate=sa.func.now(), + server_default=sa.FetchedValue(), + server_onupdate=sa.FetchedValue(), + ) + + todos: Mapped[t.List[Todo]] = sa.orm.relationship(lazy="noload", back_populates="user") + + yield dict(todo=Todo, user=User) + # We need to cleanup these objects that like to retain state beyond the fixture scope lifecycle + Base.registry.dispose() + Base.metadata.clear() + + @pytest.fixture(scope="class", autouse=True) + def create_drop_all(self, db: QuartSQLAlchemy, models): + db.create_all() + yield + db.drop_all() + + @pytest.fixture(scope="class") + def Todo(self, models: t.Mapping[str, t.Type[t.Any]]) -> t.Type[sa.orm.DeclarativeBase]: + return models["todo"] + + @pytest.fixture(scope="class") + def User(self, models: t.Mapping[str, t.Type[t.Any]]) -> t.Type[sa.orm.DeclarativeBase]: + return models["user"] + + @pytest.fixture(scope="class") + def _user_fixtures(self, User: t.Type[t.Any], Todo: t.Type[t.Any]): + users = [] + for i in range(5): + user = User(name=f"user: {i}") + for j in range(random.randint(0, 6)): + todo = Todo(title=f"todo: {j}") + user.todos.append(todo) + users.append(user) + return users + + @pytest.fixture(scope="class") + def _add_fixtures( + self, db: QuartSQLAlchemy, User: t.Type[t.Any], Todo: t.Type[t.Any], _user_fixtures + ) -> None: + with db.bind.Session() as s: + with s.begin(): + s.add_all(_user_fixtures) + + @pytest.fixture(scope="class", autouse=True) + def db_fixtures( + self, db: QuartSQLAlchemy, User: t.Type[t.Any], Todo: t.Type[t.Any], _add_fixtures + ) -> t.Dict[t.Type[t.Any], t.Sequence[t.Any]]: + with db.bind.Session() as s: + users = s.scalars(sa.select(User).options(sa.orm.selectinload(User.todos))).all() + todos = s.scalars(sa.select(Todo).options(sa.orm.selectinload(Todo.user))).all() + + return {User: users, Todo: todos} + + +class MixinTestBase: + default_mixins = ( + DynamicArgsMixin, + EagerDefaultsMixin, + TableNameMixin, + ComparableMixin, + ) + extra_mixins = () + + @pytest.fixture(scope="class") + def Base(self) -> t.Type[t.Any]: + return type( + "Base", + tuple(self.extra_mixins + self.default_mixins), + {"__abstract__": True}, + ) + + @pytest.fixture(scope="class") + def app_config(self, Base): + config = deepcopy(constants.simple_config) + config.update(SQLALCHEMY_BASE_CLASS=Base) + return config + + @pytest.fixture(scope="class") + def app(self, app_config, request): + app = Quart(request.module.__name__) + app.config.from_mapping(app_config) + app.config["TESTING"] = True + return app + + @pytest.fixture(scope="class") + def db(self, app: Quart) -> t.Generator[QuartSQLAlchemy, None, None]: + db = QuartSQLAlchemy(app=app) + + yield db + + # It's very important to clear the class _instances dict before recreating binds with the same name. + # Bind._instances.clear() + + @pytest.fixture(scope="class") + def models( + self, app: Quart, db: QuartSQLAlchemy + ) -> t.Generator[t.Mapping[str, t.Type[t.Any]], None, None]: + class Todo(db.Base): + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) title: Mapped[str] = sa.orm.mapped_column(default="default") user_id: Mapped[t.Optional[int]] = sa.orm.mapped_column(sa.ForeignKey("user.id")) @@ -49,9 +183,8 @@ class Todo(db.Model): back_populates="todos", lazy="noload", uselist=False ) - class User(db.Model): + class User(db.Base): id: Mapped[int] = sa.orm.mapped_column( - sa.Identity(), primary_key=True, autoincrement=True, ) @@ -71,7 +204,10 @@ class User(db.Model): todos: Mapped[t.List[Todo]] = sa.orm.relationship(lazy="noload", back_populates="user") - return dict(todo=Todo, user=User) + yield dict(todo=Todo, user=User) + # We need to cleanup these objects that like to retain state beyond the fixture scope lifecycle + Base.registry.dispose() + Base.metadata.clear() @pytest.fixture(scope="class", autouse=True) def create_drop_all(self, db: QuartSQLAlchemy, models): @@ -119,8 +255,10 @@ def db_fixtures( class AsyncTestBase(SimpleTestBase): @pytest.fixture(scope="class") - def sqlalchemy_config(self): - return SQLAlchemyConfig.parse_obj(constants.async_mapping_config) + def app_config(self, Base): + config = deepcopy(constants.async_config) + config.update(SQLALCHEMY_BASE_CLASS=Base) + return config @pytest.fixture(scope="class", autouse=True) async def create_drop_all(self, db: QuartSQLAlchemy, models) -> t.AsyncGenerator[None, None]: @@ -151,5 +289,31 @@ async def db_fixtures( class ComplexTestBase(SimpleTestBase): @pytest.fixture(scope="class") - def sqlalchemy_config(self): - return SQLAlchemyConfig.parse_obj(constants.complex_mapping_config) + def app_config(self, Base): + config = deepcopy(constants.complex_config) + config.update(SQLALCHEMY_BASE_CLASS=Base) + return config + + +# class CustomMixinTestBase(SimpleTestBase): +# default_mixins = ( +# DynamicArgsMixin, +# EagerDefaultsMixin, +# TableNameMixin, +# ) +# additional_mixins = () + +# @pytest.fixture(scope="class") +# def Base(self) -> t.Type[t.Any]: +# return type( +# "Base", +# tuple(self.additional_mixins + self.default_mixins), +# {"__abstract__": True}, +# ) + +# @pytest.fixture(scope="class") +# def app_config(self, Base): +# config = deepcopy(constants.simple_config) +# config["SQLALCHEMY_BASE_CLASS"] = Base +# config["TESTING"] = True +# return config diff --git a/tests/conftest.py b/tests/conftest.py index eefd7d7..e69de29 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,49 +0,0 @@ -from __future__ import annotations - -import typing as t - -import pytest -import sqlalchemy -import sqlalchemy.orm -from quart import Quart -from sqlalchemy.orm import Mapped - -from quart_sqlalchemy import SQLAlchemyConfig -from quart_sqlalchemy.framework import QuartSQLAlchemy - -from . import constants - - -sa = sqlalchemy - - -@pytest.fixture(scope="session") -def app(request: pytest.FixtureRequest) -> Quart: - app = Quart(request.module.__name__) - app.config.from_mapping({"TESTING": True}) - return app - - -@pytest.fixture(scope="session") -def sqlalchemy_config(): - return SQLAlchemyConfig.parse_obj(constants.simple_mapping_config) - - -@pytest.fixture(scope="session") -def db(sqlalchemy_config, app: Quart) -> QuartSQLAlchemy: - return QuartSQLAlchemy(sqlalchemy_config, app) - - -@pytest.fixture(name="Todo", scope="session") -def _todo_fixture( - app: Quart, db: QuartSQLAlchemy -) -> t.Generator[t.Type[sa.orm.DeclarativeBase], None, None]: - class Todo(db.Model): - id: Mapped[int] = sa.orm.mapped_column(sa.Identity(), primary_key=True, autoincrement=True) - title: Mapped[str] = sa.orm.mapped_column(default="default") - - db.create_all() - - yield Todo - - db.drop_all() diff --git a/tests/constants.py b/tests/constants.py index 4fb8240..6a2277c 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -1,19 +1,18 @@ from quart_sqlalchemy import Base -simple_mapping_config = { - "model_class": Base, - "binds": { +simple_config = { + "SQLALCHEMY_BINDS": { "default": { "engine": {"url": "sqlite:///file:mem.db?mode=memory&cache=shared&uri=true"}, "session": {"expire_on_commit": False}, } }, + "SQLALCHEMY_BASE_CLASS": Base, } -complex_mapping_config = { - "model_class": Base, - "binds": { +complex_config = { + "SQLALCHEMY_BINDS": { "default": { "engine": {"url": "sqlite:///file:mem.db?mode=memory&cache=shared&uri=true"}, "session": {"expire_on_commit": False}, @@ -28,14 +27,15 @@ "session": {"expire_on_commit": False}, }, }, + "SQLALCHEMY_BASE_CLASS": Base, } -async_mapping_config = { - "model_class": Base, - "binds": { +async_config = { + "SQLALCHEMY_BINDS": { "default": { "engine": {"url": "sqlite+aiosqlite:///file:mem.db?mode=memory&cache=shared&uri=true"}, "session": {"expire_on_commit": False}, - } + }, }, + "SQLALCHEMY_BASE_CLASS": Base, } diff --git a/tests/integration/concurrency/__init__.py b/tests/integration/concurrency/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/concurrency/with_for_update.py b/tests/integration/concurrency/with_for_update.py new file mode 100644 index 0000000..a9fa552 --- /dev/null +++ b/tests/integration/concurrency/with_for_update.py @@ -0,0 +1,103 @@ +import logging +import threading +import time + +import pytest +import sqlalchemy +import sqlalchemy.orm + + +sa = sqlalchemy + +logging.basicConfig(level=logging.DEBUG) +logging.getLogger("sqlalchemy").setLevel(logging.INFO) + +log = logging.getLogger(__name__) + + +class Base(sa.orm.DeclarativeBase): + pass + + +class Thing(Base): + __tablename__ = "things" + + id = sa.Column(sa.Integer, primary_key=True) + status = sa.Column(sa.String) + + +@pytest.fixture(scope="module") +def engine(): + engine = sa.create_engine("sqlite:///") + # engine = sa.create_engine("postgresql+psycopg2://spikes:sesame@localhost/spikes") + Base.metadata.create_all(engine) + + yield engine + + Base.metadata.drop_all(engine) + + +@pytest.fixture(scope="module") +def connection(engine): + with engine.connect() as conn: + yield conn + + +@pytest.fixture +def db(connection): + transaction = connection.begin() + session = sa.orm.Session(bind=connection) + + # now we can even `.commit()` such session + yield session + + session.close() + transaction.rollback() + + +def test_select_for_update(engine): + # scoped_db = scoped_session(sessionmaker(bind=connection)) + scoped_db = sa.orm.scoped_session(sa.orm.sessionmaker(bind=engine)) + db = scoped_db() + db.add(Thing(status="old")) + db.commit() + + def first(event, sess_factory, status): + sess = sess_factory() + # thing = sess.query(Thing).get(1) + thing = sess.query(Thing).with_for_update().get(1) + event.set() # poke second thread + log.debug("Make him wait for a while") + time.sleep(0.263) + thing.status = status + sess.commit() + log.debug("Done!") + # it is always better to explicitly `.remove()` scoped sessions, but + # in this case it is not necessary because it will be garbage-collected + # sess_factory.remove() + + def second(event, sess_factory, status): + event.wait() # ensure we are called in the right moment + sess = sess_factory() + # thing = sess.query(Thing).get(1) + thing = sess.query(Thing).with_for_update().get(1) + thing.status = status + sess.commit() + + event = threading.Event() + th1 = threading.Thread(target=first, args=(event, scoped_db, "new")) + th2 = threading.Thread(target=second, args=(event, scoped_db, "brand_new")) + + th1.start() + th2.start() + + th1.join() + th2.join() + + # assert db.query(Thing).filter_by(id=1).one().status == 'new' + t = db.query(Thing).get(1) + # it is only mandatory to remove session here, seems like it is not + # garbage-collected becasue it is in `assert` statement (not sure about that) + scoped_db.remove() + + assert t.status == "brand_new" diff --git a/tests/integration/framework/smoke_test.py b/tests/integration/framework/smoke_test.py index 682d0e4..7154fd7 100644 --- a/tests/integration/framework/smoke_test.py +++ b/tests/integration/framework/smoke_test.py @@ -40,7 +40,7 @@ def test_simple_transactional_orm_flow(self, db: QuartSQLAlchemy, Todo: t.Any): def test_simple_transactional_core_flow(self, db: QuartSQLAlchemy, Todo: t.Any): with db.bind.engine.connect() as conn: with conn.begin(): - result = conn.execute(sa.insert(Todo)) + result = conn.execute(sa.insert(Todo).values(title="default")) insert_row = result.inserted_primary_key select_row = conn.execute(sa.select(Todo).where(Todo.id == insert_row.id)).one() @@ -56,15 +56,3 @@ def test_simple_transactional_core_flow(self, db: QuartSQLAlchemy, Todo: t.Any): with db.bind.engine.connect() as conn: with pytest.raises(sa.exc.NoResultFound): conn.execute(sa.select(Todo).where(Todo.id == insert_row.id)).one() - - def test_orm_models_comparable(self, db: QuartSQLAlchemy, Todo: t.Any): - with db.bind.Session() as s: - with s.begin(): - todo = Todo() - s.add(todo) - s.flush() - s.refresh(todo) - - with db.bind.Session() as s: - select_todo = s.scalars(sa.select(Todo).where(Todo.id == todo.id)).one() - assert todo == select_todo diff --git a/tests/integration/model/mixins_test.py b/tests/integration/model/mixins_test.py index a88b4d9..43928e6 100644 --- a/tests/integration/model/mixins_test.py +++ b/tests/integration/model/mixins_test.py @@ -8,21 +8,26 @@ from sqlalchemy.orm import Mapped from quart_sqlalchemy import SQLAlchemy -from quart_sqlalchemy.model import Base -from quart_sqlalchemy.model import SoftDeleteMixin +from quart_sqlalchemy.framework import QuartSQLAlchemy +from quart_sqlalchemy.model.mixins import ComparableMixin +from quart_sqlalchemy.model.mixins import RecursiveDictMixin +from quart_sqlalchemy.model.mixins import ReprMixin +from quart_sqlalchemy.model.mixins import SimpleDictMixin +from quart_sqlalchemy.model.mixins import SoftDeleteMixin +from quart_sqlalchemy.model.mixins import TotalOrderMixin -from ...base import SimpleTestBase +from ... import base sa = sqlalchemy -class TestSoftDeleteFeature(SimpleTestBase): - @pytest.fixture - def Post(self, db: SQLAlchemy, User: t.Type[t.Any]) -> t.Generator[t.Type[Base], None, None]: - class Post(SoftDeleteMixin, db.Model): - id: Mapped[int] = sa.orm.mapped_column(primary_key=True) - title: Mapped[str] = sa.orm.mapped_column() +class TestSoftDeleteFeature(base.MixinTestBase): + @pytest.fixture(scope="class") + def Post(self, db: SQLAlchemy, User: t.Type[t.Any]) -> t.Generator[t.Type[t.Any], None, None]: + class Post(SoftDeleteMixin, db.Base): + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + title: Mapped[str] = sa.orm.mapped_column(default="default") user_id: Mapped[t.Optional[int]] = sa.orm.mapped_column(sa.ForeignKey("user.id")) user: Mapped[t.Optional[User]] = sa.orm.relationship(backref="posts") @@ -53,3 +58,113 @@ def test_inactive_filtered(self, db: SQLAlchemy, Post: t.Type[t.Any]): assert select_post.id == post.id assert select_post.is_active is False + + +class TestComparableMixin(base.MixinTestBase): + extra_mixins = (TotalOrderMixin,) + + def test_orm_models_comparable(self, db: QuartSQLAlchemy, Todo: t.Any): + assert ComparableMixin in self.default_mixins + + with db.bind.Session() as s: + with s.begin(): + todos = [Todo() for _ in range(5)] + s.add_all(todos) + + with db.bind.Session() as s: + todos = s.scalars(sa.select(Todo).order_by(Todo.id)).all() + + todo1, todo2, *_ = todos + assert todo1 < todo2 + + +class TestReprMixin(base.MixinTestBase): + @pytest.fixture(scope="class") + def Post(self, db: SQLAlchemy, User: t.Type[t.Any]) -> t.Generator[t.Type[t.Any], None, None]: + class Post(ReprMixin, db.Base): + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + title: Mapped[str] = sa.orm.mapped_column(default="default") + user_id: Mapped[t.Optional[int]] = sa.orm.mapped_column(sa.ForeignKey("user.id")) + + user: Mapped[t.Optional[User]] = sa.orm.relationship(backref="posts") + + db.create_all() + yield Post + + def test_mixin_generates_repr(self, db: QuartSQLAlchemy, Post: t.Any): + with db.bind.Session() as s: + with s.begin(): + post = Post() + s.add(post) + s.flush() + s.refresh(post) + + assert repr(post) == f"<{type(post).__name__} {post.id}>" + + +class TestSimpleDictMixin(base.MixinTestBase): + extra_mixins = (SimpleDictMixin,) + + @pytest.fixture(scope="class") + def Post(self, db: SQLAlchemy, User: t.Type[t.Any]) -> t.Generator[t.Type[t.Any], None, None]: + class Post(db.Base): + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + title: Mapped[str] = sa.orm.mapped_column(default="default") + user_id: Mapped[t.Optional[int]] = sa.orm.mapped_column(sa.ForeignKey("user.id")) + + user: Mapped[t.Optional[User]] = sa.orm.relationship(backref="posts") + + db.create_all() + yield Post + + def test_mixin_converts_model_to_dict(self, db: QuartSQLAlchemy, Post: t.Any, User: t.Any): + with db.bind.Session() as s: + with s.begin(): + user = s.scalars(sa.select(User)).first() + post = Post(user=user) + s.add(post) + s.flush() + s.refresh(post.user) + + with db.bind.Session() as s: + with s.begin(): + user = s.scalars(sa.select(User).options(sa.orm.selectinload(User.posts))).first() + + data = user.to_dict() + + for field in data: + assert data[field] == getattr(user, field) + + +class TestRecursiveMixin(base.MixinTestBase): + extra_mixins = (RecursiveDictMixin,) + + @pytest.fixture(scope="class") + def Post(self, db: SQLAlchemy, User: t.Type[t.Any]) -> t.Generator[t.Type[t.Any], None, None]: + class Post(db.Base): + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + title: Mapped[str] = sa.orm.mapped_column(default="default") + user_id: Mapped[t.Optional[int]] = sa.orm.mapped_column(sa.ForeignKey("user.id")) + + user: Mapped[t.Optional[User]] = sa.orm.relationship(backref="posts") + + db.create_all() + yield Post + + def test_mixin_converts_model_to_dict(self, db: QuartSQLAlchemy, Post: t.Any, User: t.Any): + with db.bind.Session() as s: + with s.begin(): + user = s.scalars(sa.select(User)).first() + post = Post(user=user) + s.add(post) + s.flush() + s.refresh(post.user) + + with db.bind.Session() as s: + with s.begin(): + user = s.scalars(sa.select(User).options(sa.orm.selectinload(User.posts))).first() + + data = user.to_dict() + + for col in sa.inspect(user).mapper.columns: + assert data[col.name] == getattr(user, col.name) diff --git a/tests/integration/model/model_test.py b/tests/integration/model/model_test.py new file mode 100644 index 0000000..651d908 --- /dev/null +++ b/tests/integration/model/model_test.py @@ -0,0 +1,59 @@ +from datetime import datetime + +import pytest +import sqlalchemy +import sqlalchemy.event +import sqlalchemy.exc +import sqlalchemy.ext +import sqlalchemy.ext.asyncio +import sqlalchemy.orm +import sqlalchemy.util +from sqlalchemy.orm import Mapped + +from quart_sqlalchemy import Base +from quart_sqlalchemy import Bind +from quart_sqlalchemy import SQLAlchemy +from quart_sqlalchemy import SQLAlchemyConfig +from quart_sqlalchemy.model.model import BaseMixins + + +sa = sqlalchemy + + +class TestSQLAlchemyWithCustomModelClass: + def test_base_class_with_declarative_preserves_class_and_table_metadata(self): + """This is nice to have as it decouples quart and quart_sqlalchemy from the data + models themselves. + """ + + class User(Base): + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + + db = SQLAlchemy(SQLAlchemyConfig(base_class=Base)) + + db.create_all() + + with db.bind.Session() as s: + with s.begin(): + user = User() + s.add(user) + s.flush() + s.refresh(user) + + Base.registry.dispose() + Bind._instances.clear() + + def test_sqla_class_adds_declarative_base_when_missing_from_base_class(self): + db = SQLAlchemy(SQLAlchemyConfig(base_class=BaseMixins)) + + class User(db.Base): + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + + db.create_all() + + with db.bind.Session() as s: + with s.begin(): + user = User() + s.add(user) + s.flush() + s.refresh(user) diff --git a/tests/integration/retry_test.py b/tests/integration/retry_test.py index 4e4d60b..a4bb6a6 100644 --- a/tests/integration/retry_test.py +++ b/tests/integration/retry_test.py @@ -46,7 +46,7 @@ def test_retrying_session(self, db: SQLAlchemy, Todo: t.Type[t.Any], mocker): # s.commit() def test_retrying_session_class(self, db: SQLAlchemy, Todo: t.Type[t.Any], mocker): - class Unique(db.Model): + class Unique(db.Base): id: Mapped[int] = sa.orm.mapped_column( sa.Identity(), primary_key=True, autoincrement=True ) @@ -57,7 +57,7 @@ class Unique(db.Model): db.create_all() with retrying_session(db.bind) as s: - todo = Todo(title="hello") + todo = Unique(name="hello") s.add(todo) diff --git a/workspace.code-workspace b/workspace.code-workspace index 10a11a5..aa9e12e 100644 --- a/workspace.code-workspace +++ b/workspace.code-workspace @@ -13,6 +13,7 @@ "source.organizeImports": true } }, - "esbonio.sphinx.confDir": "" + "esbonio.sphinx.confDir": "", + "python.linting.pylintEnabled": false } } From bcc5a9b9f87d6234ea1a4cfcae95758c682fbc1a Mon Sep 17 00:00:00 2001 From: Joe Black Date: Wed, 12 Apr 2023 17:59:47 -0400 Subject: [PATCH 9/9] clean up --- src/quart_sqlalchemy/sim/handle.py | 83 ------------- src/quart_sqlalchemy/sim/logic.py | 188 ----------------------------- src/quart_sqlalchemy/sim/repo.py | 66 +++++++--- 3 files changed, 49 insertions(+), 288 deletions(-) diff --git a/src/quart_sqlalchemy/sim/handle.py b/src/quart_sqlalchemy/sim/handle.py index 47d0cf3..51ea99d 100644 --- a/src/quart_sqlalchemy/sim/handle.py +++ b/src/quart_sqlalchemy/sim/handle.py @@ -173,16 +173,6 @@ def get_or_create_by_email_and_client_id( ) return auth_user - def get_by_id_and_validate_exists(self, auth_user_id): - """This function helps formalize how a non-existent auth user should be handled.""" - auth_user = self.logic.AuthUser.get_by_id(self.session_factory(), auth_user_id) - if auth_user is None: - raise RuntimeError('resource_name="auth_user"') - return auth_user - - # This function is reserved for consolidating into a canonical user. Do not - # call this function under other circumstances as it will automatically set - # the user as verified. See ch-25343 for additional details. def create_verified_user( self, client_id, @@ -190,8 +180,6 @@ def create_verified_user( user_type=EntityType.FORTMATIC.value, **kwargs, ): - # with self.session_factory() as session: - # with self.logic.begin(ro=False) as session: session = self.session_factory() with session.begin_nested(): auid = self.logic.AuthUser.add_by_email_and_client_id( @@ -231,21 +219,6 @@ def get_by_client_id_and_user_type( limit=limit, ) - def get_by_client_ids_and_user_type( - self, - client_ids, - user_type, - offset=None, - limit=None, - ): - return self.logic.AuthUser.get_by_client_ids_and_user_type( - self.session_factory(), - client_ids, - user_type, - offset=offset, - limit=limit, - ) - def exist_by_email_client_id_and_user_type(self, email, client_id, user_type): return self.logic.AuthUser.exist_by_email_and_client_id( self.session_factory(), @@ -257,11 +230,6 @@ def exist_by_email_client_id_and_user_type(self, email, client_id, user_type): def update_email_by_id(self, model_id, email): return self.logic.AuthUser.update_by_id(self.session_factory(), model_id, email=email) - def update_phone_number_by_id(self, model_id, phone_number): - return self.logic.AuthUser.update_by_id( - self.session_factory(), model_id, phone_number=phone_number - ) - def get_by_email_client_id_and_user_type(self, email, client_id, user_type): return self.logic.AuthUser.get_by_email_and_client_id( self.session_factory(), @@ -299,55 +267,9 @@ def set_role_by_email_magic_client_id(self, email, magic_client_id, role): return self.logic.AuthUser.update_by_id(session, auth_user.id, **{role: True}) - def search_by_client_id_and_substring( - self, - client_id, - substring, - offset=None, - limit=10, - ): - if not isinstance(substring, str) or len(substring) < 3: - raise InvalidSubstringError() - - auth_users = self.logic.AuthUser.get_by_client_id_with_substring_search( - self.session_factory(), - client_id, - substring, - offset=offset, - limit=limit, - ) - - return auth_users - - def is_magic_connect_enabled(self, auth_user_id=None, auth_user=None): - if auth_user is None and auth_user_id is None: - raise Exception("At least one argument needed: auth_user_id or auth_user.") - - if auth_user is None: - auth_user = self.get_by_id(auth_user_id) - - return auth_user.user_type == EntityType.CONNECT.value - def mark_as_inactive(self, auth_user_id): self.logic.AuthUser.update_by_id(self.session_factory(), auth_user_id, is_active=False) - def get_by_email_and_wallet_type_for_interop(self, email, wallet_type, network): - """ - Opinionated method for fetching AuthWallets by email address, wallet_type and network. - """ - return self.logic.AuthUser.get_by_email_for_interop( - self.session_factory(), - email=email, - wallet_type=wallet_type, - network=network, - ) - - def get_magic_connect_auth_user(self, auth_user_id): - auth_user = self.get_by_id_and_validate_exists(auth_user_id) - if not auth_user.is_magic_connect_user: - raise RuntimeError("RequestForbidden") - return auth_user - @signals.auth_user_duplicate.connect def handle_duplicate_auth_users( @@ -355,12 +277,7 @@ def handle_duplicate_auth_users( original_auth_user_id: ObjectID, duplicate_auth_user_ids: t.Sequence[ObjectID], ) -> None: - logger.info(f"{len(duplicate_auth_user_ids)} dupe(s) found for {original_auth_user_id}") - for dupe_id in duplicate_auth_user_ids: - logger.info( - f"marking auth_user_id {dupe_id} as inactive, in favor of original {original_auth_user_id}", - ) app.container.logic().AuthUser.update_by_id(dupe_id, is_active=False) diff --git a/src/quart_sqlalchemy/sim/logic.py b/src/quart_sqlalchemy/sim/logic.py index 7b969bc..38c58a4 100644 --- a/src/quart_sqlalchemy/sim/logic.py +++ b/src/quart_sqlalchemy/sim/logic.py @@ -226,43 +226,6 @@ def get_by_session_token( ) ) - @provide_global_contextual_session - def get_or_add_by_phone_number_and_client_id( - self, - session, - client_id, - phone_number, - user_type=EntityType.FORTMATIC.value, - ): - if phone_number is None: - raise MissingPhoneNumber() - - row = self.get_by_phone_number_and_client_id( - session=session, - phone_number=phone_number, - client_id=client_id, - user_type=user_type, - ) - - if row: - return row - - row = self._repository.add( - session=session, - phone_number=phone_number, - client_id=client_id, - user_type=user_type, - provenance=Provenance.SMS, - ) - logger.info( - "New auth user (id: {}) created by phone number (client_id: {})".format( - row.id, - client_id, - ), - ) - - return row - @provide_global_contextual_session def get_by_active_identifier_and_client_id( self, @@ -319,25 +282,6 @@ def get_by_email_and_client_id( for_update=for_update, ) - @provide_global_contextual_session - def get_by_phone_number_and_client_id( - self, - session, - phone_number, - client_id, - user_type=EntityType.FORTMATIC.value, - ): - if phone_number is None: - raise MissingPhoneNumber() - - return self.get_by_active_identifier_and_client_id( - session=session, - identifier_field=auth_user_model.phone_number, - identifier_value=phone_number, - client_id=client_id, - user_type=user_type, - ) - @provide_global_contextual_session def exist_by_email_and_client_id( self, @@ -392,17 +336,6 @@ def get_user_count_by_client_id_and_user_type(self, session, client_id, user_typ return session.execute(query).scalar() - @provide_global_contextual_session - def get_by_client_id_and_global_auth_user(self, session, client_id, global_auth_user_id): - return self._repository.get_by( - session=session, - filters=[ - auth_user_model.client_id == client_id, - auth_user_model.user_type == EntityType.CONNECT.value, - auth_user_model.global_auth_user_id == global_auth_user_id, - ], - ) - @provide_global_contextual_session def get_by_client_id_and_user_type( self, @@ -420,61 +353,6 @@ def get_by_client_id_and_user_type( limit=limit, ) - @provide_global_contextual_session - def get_by_client_ids_and_user_type( - self, - session, - client_ids, - user_type, - offset=None, - limit=None, - ): - if not client_ids: - return [] - - return self._repository.get_by( - session, - filters=[ - auth_user_model.client_id.in_(client_ids), - auth_user_model.user_type == user_type, - auth_user_model.date_verified != None, - ], - offset=offset, - limit=limit, - order_by_clause=auth_user_model.id.desc(), - ) - - @provide_global_contextual_session - def get_by_client_id_with_substring_search( - self, - session, - client_id, - substring, - offset=None, - limit=10, - join_list=None, - ): - return self._repository.get_by( - session, - filters=[ - auth_user_model.client_id == client_id, - auth_user_model.user_type == EntityType.MAGIC.value, - sa.or_( - auth_user_model.provenance == Provenance.SMS, - auth_user_model.provenance == Provenance.LINK, - auth_user_model.provenance == None, # noqa: E711 - ), - sa.or_( - auth_user_model.phone_number.contains(substring), - auth_user_model.email.contains(substring), - ), - ], - offset=offset, - limit=limit, - order_by_clause=auth_user_model.id.desc(), - join_list=join_list, - ) - @provide_global_contextual_session def yield_by_chunk(self, session, chunk_size, filters=None, join_list=None): yield from self._repository.yield_by_chunk( @@ -518,72 +396,6 @@ def get_by_email( join_list=join_list, ) - @provide_global_contextual_session - def get_by_email_for_interop( - self, - session, - email: str, - wallet_type: WalletType, - network: str, - ) -> t.List[auth_user_model]: - """ - Custom method for searching for users eligible for interop. Unfortunately, this can't be done with the current - abstractions in our sql_repository, so this is a one-off bespoke method. - If we need to add more similar queries involving eager loading and multiple joins, we can add an abstraction - inside the repository. - """ - - query = ( - session.query(auth_user_model) - .join( - auth_user_model.wallets.and_( - auth_wallet_model.wallet_type == str(wallet_type) - ).and_(auth_wallet_model.network == network) - ) - .options(sa.orm.contains_eager(auth_user_model.wallets)) - .join( - auth_user_model.magic_client.and_( - magic_client_model.connect_interop == ConnectInteropStatus.ENABLED, - ), - ) - .options(sa.orm.contains_eager(auth_user_model.magic_client)) - .filter( - auth_wallet_model.wallet_type == wallet_type, - auth_wallet_model.network == network, - ) - .filter( - auth_user_model.email == email, - auth_user_model.user_type == EntityType.MAGIC.value, - ) - .populate_existing() - ) - - return query.all() - - @provide_global_contextual_session - def get_linked_users(self, session, primary_auth_user_id, join_list, no_op=False): - # TODO(magic-ravi#67899|2022-12-30): Re-enable account linked users for interop. Remove no_op flag. - if no_op: - return [] - else: - return self._repository.get_by( - session, - filters=[ - auth_user_model.user_type == EntityType.MAGIC.value, - auth_user_model.linked_primary_auth_user_id == primary_auth_user_id, - ], - join_list=join_list, - ) - - @provide_global_contextual_session - def get_by_phone_number(self, session, phone_number): - return self._repository.get_by( - session, - filters=[ - auth_user_model.phone_number == phone_number, - ], - ) - class AuthWallet(LogicComponent[auth_wallet_model, ObjectID, sa.orm.Session]): model = auth_wallet_model diff --git a/src/quart_sqlalchemy/sim/repo.py b/src/quart_sqlalchemy/sim/repo.py index 959efb2..618af04 100644 --- a/src/quart_sqlalchemy/sim/repo.py +++ b/src/quart_sqlalchemy/sim/repo.py @@ -29,6 +29,10 @@ class AbstractRepository(t.Generic[EntityT, EntityIdT, SessionT], metaclass=ABCM model: t.Type[EntityT] identity: t.Type[EntityIdT] + def __init__(self, model: t.Type[EntityT], identity: t.Type[EntityIdT]): + self.model = model + self.identity = identity + class AbstractBulkRepository(t.Generic[EntityT, EntityIdT, SessionT], metaclass=ABCMeta): """A repository interface for bulk operations. @@ -40,6 +44,10 @@ class AbstractBulkRepository(t.Generic[EntityT, EntityIdT, SessionT], metaclass= model: t.Type[EntityT] identity: t.Type[EntityIdT] + def __init__(self, model: t.Type[EntityT], identity: t.Type[EntityIdT]): + self.model = model + self.identity = identity + class SQLAlchemyRepository( AbstractRepository[EntityT, EntityIdT, SessionT], t.Generic[EntityT, EntityIdT, SessionT] @@ -68,11 +76,23 @@ class SQLAlchemyRepository( builder: StatementBuilder - def __init__(self, model: t.Type[EntityT], identity: t.Type[EntityIdT]): - self.model = model - self.identity = identity + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) self.builder = StatementBuilder(self.model) + def _build_execution_options( + self, + execution_options: t.Optional[t.Dict[str, t.Any]] = None, + include_inactive: bool = False, + yield_by_chunk: bool = False, + ): + execution_options = execution_options or {} + if include_inactive: + execution_options.setdefault("include_inactive", include_inactive) + if yield_by_chunk: + execution_options.setdefault("yield_per", yield_by_chunk) + return execution_options + def insert(self, session: sa.orm.Session, values: t.Dict[str, t.Any]) -> EntityT: """Insert a new model into the database.""" new = self.model(**values) @@ -141,9 +161,12 @@ def get( """ selectables = (self.model,) - execution_options = execution_options or {} - if include_inactive: - execution_options.setdefault("include_inactive", include_inactive) + execution_options = self._build_execution_options( + execution_options, include_inactive=include_inactive + ) + # execution_options = execution_options or {} + # if include_inactive: + # execution_options.setdefault("include_inactive", include_inactive) statement = self.builder.select( selectables, # type: ignore @@ -173,9 +196,12 @@ def get_by_field( """Select models where field is equal to value.""" selectables = (self.model,) - execution_options = execution_options or {} - if include_inactive: - execution_options.setdefault("include_inactive", include_inactive) + execution_options = self._build_execution_options( + execution_options, include_inactive=include_inactive + ) + # execution_options = execution_options or {} + # if include_inactive: + # execution_options.setdefault("include_inactive", include_inactive) if isinstance(field, str): field = getattr(self.model, field) @@ -217,11 +243,16 @@ def select( """ selectables = selectables or (self.model,) # type: ignore - execution_options = execution_options or {} - if include_inactive: - execution_options.setdefault("include_inactive", include_inactive) - if yield_by_chunk: - execution_options.setdefault("yield_per", yield_by_chunk) + execution_options = self._build_execution_options( + execution_options, + include_inactive=include_inactive, + yield_by_chunk=yield_by_chunk, + ) + # execution_options = execution_options or {} + # if include_inactive: + # execution_options.setdefault("include_inactive", include_inactive) + # if yield_by_chunk: + # execution_options.setdefault("yield_per", yield_by_chunk) statement = self.builder.select( selectables, @@ -271,9 +302,10 @@ def exists( """ selectables = (sa.sql.literal(True),) - execution_options = {} - if include_inactive: - execution_options.setdefault("include_inactive", include_inactive) + execution_options = self._build_execution_options(None, include_inactive=include_inactive) + # execution_options = {} + # if include_inactive: + # execution_options.setdefault("include_inactive", include_inactive) statement = self.builder.select( selectables,