From e8667894df1b98aaecd3103ecd8969a383a8cf47 Mon Sep 17 00:00:00 2001 From: Joe Black Date: Thu, 30 Mar 2023 18:24:25 -0400 Subject: [PATCH 01/11] simulation --- src/quart_sqlalchemy/model/mixins.py | 2 +- src/quart_sqlalchemy/model/model.py | 3 +- src/quart_sqlalchemy/sim/__init__.py | 2 + src/quart_sqlalchemy/sim/app.py | 145 +++ src/quart_sqlalchemy/sim/builder.py | 96 ++ src/quart_sqlalchemy/sim/handle.py | 594 ++++++++++++ src/quart_sqlalchemy/sim/legacy.py | 404 ++++++++ src/quart_sqlalchemy/sim/logic.py | 883 ++++++++++++++++++ src/quart_sqlalchemy/sim/main.py | 9 + src/quart_sqlalchemy/sim/model.py | 163 ++++ src/quart_sqlalchemy/sim/repo.py | 285 ++++++ src/quart_sqlalchemy/sim/repo_adapter.py | 309 ++++++ src/quart_sqlalchemy/sim/schema.py | 14 + src/quart_sqlalchemy/sim/signals.py | 33 + src/quart_sqlalchemy/sim/util.py | 101 ++ src/quart_sqlalchemy/sim/views/__init__.py | 18 + src/quart_sqlalchemy/sim/views/auth_user.py | 7 + src/quart_sqlalchemy/sim/views/auth_wallet.py | 62 ++ src/quart_sqlalchemy/sim/views/decorator.py | 131 +++ .../sim/views/magic_client.py | 7 + 20 files changed, 3266 insertions(+), 2 deletions(-) create mode 100644 src/quart_sqlalchemy/sim/__init__.py create mode 100644 src/quart_sqlalchemy/sim/app.py create mode 100644 src/quart_sqlalchemy/sim/builder.py create mode 100644 src/quart_sqlalchemy/sim/handle.py create mode 100644 src/quart_sqlalchemy/sim/legacy.py create mode 100644 src/quart_sqlalchemy/sim/logic.py create mode 100644 src/quart_sqlalchemy/sim/main.py create mode 100644 src/quart_sqlalchemy/sim/model.py create mode 100644 src/quart_sqlalchemy/sim/repo.py create mode 100644 src/quart_sqlalchemy/sim/repo_adapter.py create mode 100644 src/quart_sqlalchemy/sim/schema.py create mode 100644 src/quart_sqlalchemy/sim/signals.py create mode 100644 src/quart_sqlalchemy/sim/util.py create mode 100644 src/quart_sqlalchemy/sim/views/__init__.py create mode 100644 src/quart_sqlalchemy/sim/views/auth_user.py create mode 100644 src/quart_sqlalchemy/sim/views/auth_wallet.py create mode 100644 src/quart_sqlalchemy/sim/views/decorator.py create mode 100644 src/quart_sqlalchemy/sim/views/magic_client.py diff --git a/src/quart_sqlalchemy/model/mixins.py b/src/quart_sqlalchemy/model/mixins.py index 55c0c3f..3dc3834 100644 --- a/src/quart_sqlalchemy/model/mixins.py +++ b/src/quart_sqlalchemy/model/mixins.py @@ -79,7 +79,7 @@ def to_dict(self): class RecursiveDictMixin: __abstract__ = True - def model_to_dict( + def to_dict( self, obj: t.Optional[t.Any] = None, max_depth: int = 3, diff --git a/src/quart_sqlalchemy/model/model.py b/src/quart_sqlalchemy/model/model.py index cd82457..bbcdc21 100644 --- a/src/quart_sqlalchemy/model/model.py +++ b/src/quart_sqlalchemy/model/model.py @@ -14,13 +14,14 @@ from .mixins import ComparableMixin from .mixins import DynamicArgsMixin from .mixins import ReprMixin +from .mixins import SimpleDictMixin from .mixins import TableNameMixin sa = sqlalchemy -class Base(DynamicArgsMixin, ReprMixin, ComparableMixin, TableNameMixin): +class Base(DynamicArgsMixin, ReprMixin, SimpleDictMixin, ComparableMixin, TableNameMixin): __abstract__ = True type_annotation_map = { diff --git a/src/quart_sqlalchemy/sim/__init__.py b/src/quart_sqlalchemy/sim/__init__.py new file mode 100644 index 0000000..6a3f395 --- /dev/null +++ b/src/quart_sqlalchemy/sim/__init__.py @@ -0,0 +1,2 @@ +from . import app +from . 
import model diff --git a/src/quart_sqlalchemy/sim/app.py b/src/quart_sqlalchemy/sim/app.py new file mode 100644 index 0000000..91b0ec8 --- /dev/null +++ b/src/quart_sqlalchemy/sim/app.py @@ -0,0 +1,145 @@ +import json +import logging +import re +import typing as t + +import sqlalchemy as sa +from pydantic import BaseModel +from quart import g +from quart import Quart +from quart import request +from quart import Request +from quart import Response +from quart_schema import QuartSchema + +from .. import Base +from .. import SQLAlchemyConfig +from ..framework import QuartSQLAlchemy +from .util import ObjectID + + +AUTHORIZATION_PATTERN = re.compile(r"Bearer (?P.+)") +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class MyBase(Base): + type_annotation_map = {ObjectID: sa.Integer} + + +app = Quart(__name__) +db = QuartSQLAlchemy( + SQLAlchemyConfig.parse_obj( + { + "model_class": MyBase, + "binds": { + "default": { + "engine": {"url": "sqlite:///file:mem.db?mode=memory&cache=shared&uri=true"}, + "session": {"expire_on_commit": False}, + }, + "read-replica": { + "engine": {"url": "sqlite:///file:mem.db?mode=memory&cache=shared&uri=true"}, + "session": {"expire_on_commit": False}, + "read_only": True, + }, + "async": { + "engine": { + "url": "sqlite+aiosqlite:///file:mem.db?mode=memory&cache=shared&uri=true" + }, + "session": {"expire_on_commit": False}, + }, + }, + } + ) +) +openapi = QuartSchema(app) + + +class RequestAuth(BaseModel): + client: t.Optional[t.Any] = None + user: t.Optional[t.Any] = None + + @property + def has_client(self): + return self.client is not None + + @property + def has_user(self): + return self.user is not None + + @property + def is_anonymous(self): + return all([self.has_client is False, self.has_user is False]) + + +def get_request_client(request: Request): + api_key = request.headers.get("X-Public-API-Key") + if not api_key: + return + + with g.bind.Session() as session: + try: + magic_client = g.h.MagicClient(session).get_by_public_api_key(api_key) + except ValueError: + return + else: + return magic_client + + +def get_request_user(request: Request): + auth_header = request.headers.get("Authorization") + + if not auth_header: + return + m = AUTHORIZATION_PATTERN.match(auth_header) + if m is None: + raise RuntimeError("invalid authorization header") + + auth_token = m.group("auth_token") + + with g.bind.Session() as session: + try: + auth_user = g.h.AuthUser(session).get_by_session_token(auth_token) + except ValueError: + return + else: + return auth_user + + +@app.before_request +def set_ethereum_network(): + g.request_network = request.headers.get("X-Fortmatic-Network", "GOERLI").upper() + + +@app.before_request +def set_bind_handlers_for_request(): + from quart_sqlalchemy.sim.handle import Handlers + + g.db = db + + method = request.method + if method in ["GET", "OPTIONS", "TRACE", "HEAD"]: + bind = "read-replica" + else: + bind = "default" + + g.bind = db.get_bind(bind) + g.h = Handlers(g.bind) + + +@app.before_request +def set_request_auth(): + g.auth = RequestAuth( + client=get_request_client(request), + user=get_request_user(request), + ) + + +@app.after_request +async def add_json_response_envelope(response: Response) -> Response: + if response.mimetype != "application/json": + return response + data = await response.get_json() + payload = dict(status="ok", message="", data=data) + response.set_data(json.dumps(payload)) + return response diff --git a/src/quart_sqlalchemy/sim/builder.py 
b/src/quart_sqlalchemy/sim/builder.py new file mode 100644 index 0000000..ccaffeb --- /dev/null +++ b/src/quart_sqlalchemy/sim/builder.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +import typing as t + +import sqlalchemy +import sqlalchemy.event +import sqlalchemy.exc +import sqlalchemy.orm +import sqlalchemy.sql +from sqlalchemy.orm.interfaces import ORMOption + +from quart_sqlalchemy.types import ColumnExpr +from quart_sqlalchemy.types import DMLTable +from quart_sqlalchemy.types import EntityT +from quart_sqlalchemy.types import Selectable + + +sa = sqlalchemy + + +class StatementBuilder(t.Generic[EntityT]): + model: t.Type[EntityT] + + def __init__(self, model: t.Type[EntityT]): + self.model = model + + def complex_select( + self, + selectables: t.Sequence[Selectable] = (), + conditions: t.Sequence[ColumnExpr] = (), + group_by: t.Sequence[t.Union[ColumnExpr, str]] = (), + order_by: t.Sequence[t.Union[ColumnExpr, str]] = (), + options: t.Sequence[ORMOption] = (), + execution_options: t.Optional[t.Dict[str, t.Any]] = None, + offset: t.Optional[int] = None, + limit: t.Optional[int] = None, + distinct: bool = False, + for_update: bool = False, + ) -> sa.Select: + statement = sa.select(*selectables or self.model).where(*conditions) + + if for_update: + statement = statement.with_for_update() + if offset: + statement = statement.offset(offset) + if limit: + statement = statement.limit(limit) + if group_by: + statement = statement.group_by(*group_by) + if order_by: + statement = statement.order_by(*order_by) + + for option in options: + for context in option.context: + for strategy in context.strategy: + if "joined" in strategy: + distinct = True + + statement = statement.options(option) + + if distinct: + statement = statement.distinct() + + if execution_options: + statement = statement.execution_options(**execution_options) + + return statement + + def insert( + self, + target: t.Optional[DMLTable] = None, + values: t.Optional[t.Dict[str, t.Any]] = None, + ) -> sa.Insert: + return sa.insert(target or self.model).values(**values or {}) + + def bulk_insert( + self, + target: t.Optional[DMLTable] = None, + values: t.Sequence[t.Dict[str, t.Any]] = (), + ) -> sa.Insert: + return sa.insert(target or self.model).values(*values) + + def bulk_update( + self, + target: t.Optional[DMLTable] = None, + conditions: t.Sequence[ColumnExpr] = (), + values: t.Optional[t.Dict[str, t.Any]] = None, + ) -> sa.Update: + return sa.update(target or self.model).where(*conditions).values(**values or {}) + + def bulk_delete( + self, + target: t.Optional[DMLTable] = None, + conditions: t.Sequence[ColumnExpr] = (), + ) -> sa.Delete: + return sa.delete(target or self.model).where(*conditions) diff --git a/src/quart_sqlalchemy/sim/handle.py b/src/quart_sqlalchemy/sim/handle.py new file mode 100644 index 0000000..7076309 --- /dev/null +++ b/src/quart_sqlalchemy/sim/handle.py @@ -0,0 +1,594 @@ +import logging +import typing as t +from datetime import datetime + +from sqlalchemy.orm import Session + +from quart_sqlalchemy import Bind + +from . 
import signals +from .logic import LogicComponent as Logic +from .model import AuthUser +from .model import AuthWallet +from .model import EntityType +from .model import WalletType +from .util import ObjectID + + +logger = logging.getLogger(__name__) + +CLIENTS_PER_API_USER_LIMIT = 50 + + +class MaxClientsExceeded(Exception): + pass + + +class AuthUserBaseError(Exception): + pass + + +class InvalidSubstringError(AuthUserBaseError): + pass + + +class APIKeySet(t.NamedTuple): + public_key: str + secret_key: str + + +def get_session_proxy(): + from .app import db + + return db.bind.Session() + + +class HandlerBase: + logic: Logic + session: Session + """The base class for all handler classes. It provides handler with access + to our logic object. + """ + + def __init__(self, session: t.Optional[Session], logic: t.Optional[Logic] = None): + self.session = session or get_session_proxy() + self.logic = logic or Logic() + + +def get_product_type_by_client_id(client_id): + return EntityType.MAGIC.value + + +class MagicClientHandler(HandlerBase): + def add( + self, + magic_api_user_id, + magic_team_id, + app_name=None, + is_magic_connect_enabled=False, + ): + """Registers a new client. + + Args: + is_magic_connect_enabled (boolean): if True, it will create a Magic Connect app. + + Returns: + A ``MagicClient``. + """ + magic_clients_count = self.logic.MagicClientAPIUser.count_by_magic_api_user_id( + magic_api_user_id, + ) + + if magic_clients_count >= CLIENTS_PER_API_USER_LIMIT: + raise MaxClientsExceeded() + + return self.add_client( + magic_api_user_id, + magic_team_id, + app_name, + is_magic_connect_enabled, + ) + + def get_by_public_api_key(self, public_api_key): + return self.logic.MagicClientAPIKey.get_by_public_api_key(public_api_key) + + def add_client( + self, + magic_api_user_id, + magic_team_id, + app_name=None, + is_magic_connect_enabled=False, + ): + live_api_key = APIKeySet(public_key="xxx", secret_key="yyy") + + with self.logic.begin(ro=False) as session: + magic_client = self.logic.MagicClient._add( + session, + app_name=app_name, + ) + # self.logic.MagicClientAPIKey._add( + # session, + # magic_client.id, + # live_api_key_pair=live_api_key, + # ) + # self.logic.MagicClientAPIUser._add( + # session, + # magic_api_user_id, + # magic_client.id, + # ) + + # self.logic.MagicClientAuthMethods._add( + # session, + # magic_client_id=magic_client.id, + # is_magic_connect_enabled=is_magic_connect_enabled, + # is_metamask_wallet_enabled=(True if is_magic_connect_enabled else False), + # is_wallet_connect_enabled=(True if is_magic_connect_enabled else False), + # is_coinbase_wallet_enabled=(True if is_magic_connect_enabled else False), + # ) + + # self.logic.MagicClientTeam._add(session, magic_client.id, magic_team_id) + + return magic_client, live_api_key + + def get_magic_api_user_id_by_client_id(self, magic_client_id): + return self.logic.MagicClient.get_magic_api_user_id_by_client_id(magic_client_id) + + def get_by_id(self, magic_client_id): + return self.logic.MagicClient.get_by_id(magic_client_id) + + def update_app_name_by_id(self, magic_client_id, app_name): + """ + Args: + magic_client_id (ObjectID|int|str): self explanatory. + app_name (str): Desired application name. 
+ + Returns: + None if `magic_client_id` doesn't exist in the db + app_name if update was successful + """ + client = self.logic.MagicClient.update_by_id(magic_client_id, app_name=app_name) + + if not client: + return None + + return client.app_name + + def update_by_id(self, magic_client_id, **kwargs): + client = self.logic.MagicClient.update_by_id(magic_client_id, **kwargs) + + return client + + def set_inactive_by_id(self, magic_client_id): + """ + Args: + magic_client_id (ObjectID|int|str): self explanatory. + + Returns: + None + """ + self.logic.MagicClient.update_by_id(magic_client_id, is_active=False) + + def get_users_for_client( + self, + magic_client_id, + offset=None, + limit=None, + include_count=False, + ): + """ + Returns emails and signup timestamps for all auth users belonging to a given client + """ + auth_user_handler = AuthUserHandler() + product_type = get_product_type_by_client_id(magic_client_id) + auth_users = auth_user_handler.get_by_client_id_and_user_type( + magic_client_id, + product_type, + offset=offset, + limit=limit, + ) + + # Here we blindly load from oauth users table because we only provide + # two login methods right now. If not email link then it is oauth. + # TODO(ajen#ch22926|2020-08-14): rely on the `login_method` column to + # deterministically load from correct source (oauth, webauthn, etc.). + # emails_from_oauth = OAuthUserHandler().get_emails_by_auth_user_ids( + # [auth_user.id for auth_user in auth_users if auth_user.email is None], + # ) + + data = { + "users": [ + dict(email=u.email or "none", signup_ts=int(datetime.timestamp(u.time_created))) + for u in auth_users + ] + } + + if include_count: + data["count"] = auth_user_handler.get_user_count_by_client_id_and_user_type( + magic_client_id, + product_type, + ) + + return data + + def get_users_for_client_v2( + self, + magic_client_id, + offset=None, + limit=None, + include_count=False, + ): + """ + Returns emails, signup timestamps, provenance and MFA enablement for all auth users + belonging to a given client. 
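+
+        Example of the returned payload shape (illustrative values; the "count"
+        key is only present when include_count is True):
+
+            {
+                "users": [
+                    {"email": "jane@example.com", "signup_ts": 1680200665},
+                ],
+                "count": 1,
+            }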
+ """ + auth_user_handler = AuthUserHandler() + product_type = get_product_type_by_client_id(magic_client_id) + auth_users = auth_user_handler.get_by_client_id_and_user_type( + magic_client_id, + product_type, + offset=offset, + limit=limit, + ) + + data = { + "users": [ + dict(email=u.email or "none", signup_ts=int(datetime.timestamp(u.time_created))) + for u in auth_users + ] + } + + if include_count: + data["count"] = auth_user_handler.get_user_count_by_client_id_and_user_type( + magic_client_id, + product_type, + ) + + return data + + # def get_user_logins_for_client(self, magic_client_id, limit=None): + # logins = AuthUserLoginHandler().get_logins_by_magic_client_id( + # magic_client_id, + # limit=limit or 20, + # ) + # user_logins = get_user_logins_response(logins) + + # return sorted( + # user_logins, + # key=lambda x: x["login_ts"], + # reverse=True, + # )[:limit] + + +class AuthUserHandler(HandlerBase): + # auth_user_mfa_handler: AuthUserMfaHandler + + def __init__(self, *args, auth_user_mfa_handler=None, **kwargs): + super().__init__(*args, **kwargs) + # self.auth_user_mfa_handler = auth_user_mfa_handler or AuthUserMfaHandler() + + def get_by_session_token(self, session_token): + return self.logic.AuthUser.get_by_session_token(session_token) + + def get_or_create_by_email_and_client_id( + self, + email, + client_id, + user_type=EntityType.MAGIC.value, + ): + auth_user = self.logic.AuthUser.get_by_email_and_client_id( + email, + client_id, + user_type=user_type, + ) + if not auth_user: + # try: + # email = enhanced_email_validation( + # email, + # source=MAGIC, + # # So we don't affect sign-up. + # silence_network_error=True, + # ) + # except ( + # EnhanceEmailValidationError, + # EnhanceEmailSuggestionError, + # ) as e: + # logger.warning( + # "Email Start Attempt.", + # exc_info=True, + # ) + # raise EnhancedEmailValidation(error_message=str(e)) from e + + auth_user = self.logic.AuthUser.add_by_email_and_client_id( + client_id, + email=email, + user_type=user_type, + ) + return auth_user + + def get_by_id_and_validate_exists(self, auth_user_id): + """This function helps formalize how a non-existent auth user should be handled.""" + auth_user = self.logic.AuthUser.get_by_id(auth_user_id) + if auth_user is None: + raise RuntimeError('resource_name="auth_user"') + return auth_user + + # This function is reserved for consolidating into a canonical user. Do not + # call this function under other circumstances as it will automatically set + # the user as verified. See ch-25343 for additional details. 
+ def create_verified_user( + self, + client_id, + email, + user_type=EntityType.FORTMATIC.value, + ): + with self.logic.begin(ro=False) as session: + auid = self.logic.AuthUser._add_by_email_and_client_id( + session, + client_id, + email, + user_type=user_type, + ).id + auth_user = self.logic.AuthUser._update_by_id( + session, + auid, + date_verified=datetime.utcnow(), + ) + + return auth_user + + # def get_auth_user_from_public_address(self, public_address): + # wallet = self.logic.AuthWallet.get_by_public_address(public_address) + + # if not wallet: + # return None + + # return self.logic.AuthUser.get_by_id(wallet.auth_user_id) + + def get_by_id(self, auth_user_id, load_mfa_methods=False) -> AuthUser: + # join_list = ["mfa_methods"] if load_mfa_methods else None + return self.logic.AuthUser.get_by_id(auth_user_id) + + def get_by_client_id_and_user_type( + self, + client_id, + user_type, + offset=None, + limit=None, + ): + if user_type == EntityType.CONNECT.value: + return self.logic.AuthUser.get_by_client_id_for_connect( + client_id, + offset=offset, + limit=limit, + ) + else: + return self.logic.AuthUser.get_by_client_id_and_user_type( + client_id, + user_type, + offset=offset, + limit=limit, + ) + + def get_by_client_ids_and_user_type( + self, + client_ids, + user_type, + offset=None, + limit=None, + ): + return self.logic.AuthUser.get_by_client_ids_and_user_type( + client_ids, + user_type, + offset=offset, + limit=limit, + ) + + def get_user_count_by_client_id_and_user_type(self, client_id, user_type): + if user_type == EntityType.CONNECT.value: + return self.logic.AuthUser.get_user_count_by_client_id_for_connect( + client_id, + ) + else: + return self.logic.AuthUser.get_user_count_by_client_id_and_user_type( + client_id, + user_type, + ) + + def exist_by_email_client_id_and_user_type(self, email, client_id, user_type): + return self.logic.AuthUser.exist_by_email_and_client_id( + email, + client_id, + user_type=user_type, + ) + + def update_email_by_id(self, model_id, email): + return self.logic.AuthUser.update_by_id(model_id, email=email) + + def update_phone_number_by_id(self, model_id, phone_number): + return self.logic.AuthUser.update_by_id(model_id, phone_number=phone_number) + + def get_by_email_client_id_and_user_type(self, email, client_id, user_type): + return self.logic.AuthUser.get_by_email_and_client_id( + email, + client_id, + user_type, + ) + + def mark_date_verified_by_id(self, model_id): + return self.logic.AuthUser.update_by_id( + model_id, + date_verified=datetime.utcnow(), + ) + + def set_role_by_email_magic_client_id(self, email, magic_client_id, role): + auth_user = self.logic.AuthUser.get_by_email_and_client_id( + email, + magic_client_id, + EntityType.MAGIC.value, + ) + + if not auth_user: + auth_user = self.logic.AuthUser.add_by_email_and_client_id( + magic_client_id, + email, + user_type=EntityType.MAGIC.value, + ) + + return self.logic.AuthUser.update_by_id(auth_user.id, **{role: True}) + + def search_by_client_id_and_substring( + self, + client_id, + substring, + offset=None, + limit=10, + load_mfa_methods=False, + ): + # join_list = ["mfa_methods"] if load_mfa_methods is True else None + + if not isinstance(substring, str) or len(substring) < 3: + raise InvalidSubstringError() + + auth_users = self.logic.AuthUser.get_by_client_id_with_substring_search( + client_id, + substring, + offset=offset, + limit=limit, + # join_list=join_list, + ) + + # mfa_enablements = self.auth_user_mfa_handler.is_active_batch( + # [auth_user.id for auth_user in auth_users], + # 
) + # for auth_user in auth_users: + # if mfa_enablements[auth_user.id] is False: + # auth_user.mfa_methods = [] + + return auth_users + + def is_magic_connect_enabled(self, auth_user_id=None, auth_user=None): + if auth_user is None and auth_user_id is None: + raise Exception("At least one argument needed: auth_user_id or auth_user.") + + if auth_user is None: + auth_user = self.get_by_id(auth_user_id) + + return auth_user.user_type == EntityType.CONNECT.value + + def mark_as_inactive(self, auth_user_id): + self.logic.AuthUser.update_by_id(auth_user_id, is_active=False) + + def get_by_email_and_wallet_type_for_interop(self, email, wallet_type, network): + """ + Opinionated method for fetching AuthWallets by email address, wallet_type and network. + """ + return self.logic.AuthUser.get_by_email_for_interop( + email=email, + wallet_type=wallet_type, + network=network, + ) + + def get_magic_connect_auth_user(self, auth_user_id): + auth_user = self.get_by_id_and_validate_exists(auth_user_id) + if not auth_user.is_magic_connect_user: + raise RuntimeError("RequestForbidden") + return auth_user + + +@signals.auth_user_duplicate.connect +def handle_duplicate_auth_users( + current_app, + original_auth_user_id, + duplicate_auth_user_ids, + auth_user_handler: t.Optional[AuthUserHandler] = None, +) -> None: + logger.info(f"{len(duplicate_auth_user_ids)} dupe(s) found for {original_auth_user_id}") + + auth_user_handler = auth_user_handler or AuthUserHandler() + + for dupe_id in duplicate_auth_user_ids: + logger.info( + f"marking auth_user_id {dupe_id} as inactive, in favor of original {original_auth_user_id}", + ) + auth_user_handler.mark_as_inactive(dupe_id) + + +class AuthWalletHandler(HandlerBase): + # account_linking_feature = LDFeatureFlag("is-account-linking-enabled", anonymous_user=True) + + def __init__(self, network, *args, wallet_type=WalletType.ETH, **kwargs): + super().__init__(*args, **kwargs) + self.wallet_network = network + self.wallet_type = wallet_type + + def get_by_id(self, model_id): + return self.logic.AuthWallet.get_by_id(model_id) + + def get_by_public_address(self, public_address): + return self.logic.AuthWallet.get_by_public_address(public_address) + + def get_by_auth_user_id( + self, + auth_user_id: ObjectID, + network: t.Optional[str] = None, + wallet_type: t.Optional[WalletType] = None, + **kwargs, + ) -> t.List[AuthWallet]: + auth_user = self.logic.AuthUser.get_by_id( + auth_user_id, + join_list=["linked_primary_auth_user"], + ) + + if auth_user.has_linked_primary_auth_user: + logger.info( + "Linked primary_auth_user found for wallet delegation", + extra=dict( + auth_user_id=auth_user.id, + delegated_to=auth_user.linked_primary_auth_user_id, + ), + ) + auth_user = auth_user.linked_primary_auth_user + + return self.logic.AuthWallet.get_by_auth_user_id( + auth_user.id, + network=network, + wallet_type=wallet_type, + **kwargs, + ) + + def sync_auth_wallet( + self, + auth_user_id, + public_address, + encrypted_private_address, + wallet_management_type, + ): + existing_wallet = self.logic.AuthWallet.get_by_auth_user_id( + auth_user_id, + ) + if existing_wallet: + raise RuntimeError("WalletExistsForNetworkAndWalletType") + + return self.logic.AuthWallet.add( + public_address, + encrypted_private_address, + self.wallet_type, + self.wallet_network, + management_type=wallet_management_type, + auth_user_id=auth_user_id, + ) + + +class Handlers: + bind: Bind + + def __init__(self, bind: Bind): + self.bind = bind + + def __getattr__(self, name): + handlers = { + 
cls.__name__.replace("Handler", ""): cls for cls in HandlerBase.__subclasses__() + } + if name in handlers: + return handlers[name] + raise AttributeError() diff --git a/src/quart_sqlalchemy/sim/legacy.py b/src/quart_sqlalchemy/sim/legacy.py new file mode 100644 index 0000000..5d4662f --- /dev/null +++ b/src/quart_sqlalchemy/sim/legacy.py @@ -0,0 +1,404 @@ +from sqlalchemy import exists +from sqlalchemy.orm import joinedload +from sqlalchemy.sql.expression import func +from sqlalchemy.sql.expression import label + + +def one(input_list): + (item,) = input_list + return item + + +class SQLRepository: + def __init__(self, model): + self._model = model + assert self._model is not None + + @property + def _has_is_active_field(self): + return bool(getattr(self._model, "is_active", None)) + + def get_by_id( + self, + session, + model_id, + allow_inactive=False, + join_list=None, + for_update=False, + ): + """SQL get interface to retrieve by model's id column. + + Args: + session: A database session object. + model_id: The id of the given model to be retrieved. + allow_inactive: Whether to include inactive or not. + join_list: A list of attributes to be joined in the same session for + given model. This is normally the attributes that have + relationship defined and referenced to other models. + for_update: Locks the table for update. + + Returns: + Data retrieved from the database for the model. + """ + query = session.query(self._model) + + if join_list: + for to_join in join_list: + query = query.options(joinedload(to_join)) + + if for_update: + query = query.with_for_update() + + row = query.get(model_id) + + if row is None: + return None + + if self._has_is_active_field and not row.is_active and not allow_inactive: + return None + + return row + + def get_by( + self, + session, + filters=None, + join_list=None, + order_by_clause=None, + for_update=False, + offset=None, + limit=None, + ): + """SQL get_by interface to retrieve model instances based on the given + filters. + + Args: + session: A database session object. + filters: A list of filters on the models. + join_list: A list of attributes to be joined in the same session for + given model. This is normally the attributes that have + relationship defined and referenced to other models. + order_by_clause: An order by clause. + for_update: Locks the table for update. + + Returns: + Modified rows. + + TODO(ajen#ch21549|2020-07-21): Filter out `is_active == False` row. This + will not be a trivial change as many places rely on this method and the + handlers/logics sometimes filter by in_active. Sometimes endpoints might + get affected. Proceed with caution. + """ + # If no filter is given, just return. Prevent table scan. + if filters is None: + return None + + query = session.query(self._model).filter(*filters).order_by(order_by_clause) + + if for_update: + query = query.with_for_update() + + if offset: + query = query.offset(offset) + + if limit: + query = query.limit(limit) + + # Prevent loading all the rows. + if limit == 0: + return [] + + if join_list: + for to_join in join_list: + query = query.options(joinedload(to_join)) + + return query.all() + + def count_by( + self, + session, + filters=None, + group_by=None, + distinct_column=None, + ): + """SQL count_by interface to retrieve model instance count based on the given + filters. + + Args: + session: A database session object. + filters (list): Required; a list of filters on the models. + group_by (list): A list of optional group by expressions. 
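+            distinct_column (sqlalchemy.Column): Optional; when provided, the
+                count is taken over distinct values of this column instead of
+                all matching rows.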
+ Returns: + A list of counts of rows. + Raises: + ValueError: Returns a value error, when no filters are provided + """ + # Prevent table scans + if filters is None: + raise ValueError("Full table scans are prohibited. Please provide filters") + + select = [label("count", func.count(self._model.id))] + + if distinct_column: + select = [label("count", func.count(func.distinct(distinct_column)))] + + if group_by: + for group in group_by: + select.append(group.expression) + + query = session.query(*select).filter(*filters) + + if group_by: + query = query.group_by(*group_by) + + return query.all() + + def sum_by( + self, + session, + column, + filters=None, + group_by=None, + ): + """SQL sum_by interface to retrieve aggregate sum of column values for given + filters. + + Args: + session: A database session object. + column (sqlalchemy.Column): Required; the column to sum by. + filters (list): Required; a list of filters to apply to the query + group_by (list): A list of optional group by expressions. + + Returns: + A scalar value representing the sum or None. + + Raises: + ValueError: Returns a value error, when no filters are provided + """ + + # Prevent table scans + if filters is None: + raise ValueError("Full table scans are prohibited. Please provide filters") + + query = session.query(func.sum(column)).filter(*filters) + + if group_by: + query = query.group_by(*group_by) + + return query.scalar() + + def one(self, session, filters=None, join_list=None, for_update=False): + """SQL filtering interface to retrieve the single model instance matching + filter criteria. + + If there are more than one instances, an exception is raised. + + Args: + session: A database session object. + filters: A list of filters on the models. + for_update: Locks the table for update. + + Returns: + A model instance: If one row is found in the db. + None: If no result is found. + """ + row = self.get_by(session, filters=filters, join_list=join_list, for_update=for_update) + + if not row: + return None + + return one(row) + + def update(self, session, model_id, **kwargs): + """SQL update interface to modify data in a given model instance. + + Args: + session: A database session object. + model_id: The id of the given model to be modified. + kwargs: Any fields defined on the models. + + Returns: + Modified rows. + + Note: + We use session.flush() here to move the changes from the application + to SQL database. However, those changes will be in the pending changes + state. Meaning, it is in the queue to be inserted but yet to be done + so until session.commit() is called, which has been taken care of + in our ``with_db_session`` decorator or ``LogicComponent.begin`` + contextmanager. + """ + modified_row = session.query(self._model).get(model_id) + if modified_row is None: + return None + + for key, value in kwargs.items(): + setattr(modified_row, key, value) + + # Flush out our changes to DB transaction buffer but don't commit it yet. + # This is useful in the case when we want to rollback atomically on multiple + # sql operations in the same transaction which may or may not have + # dependencies. + session.flush() + + return modified_row + + def update_by(self, session, filters=None, **kwargs): + """SQL update_by interface to modify data for a given list of filters. + The filters should be provided so it can narrow down to one row. + + Args: + session: A database session object. + filters: A list of filters on the models. + kwargs: Any fields defined on the models. + + Returns: + Modified row. 
+ + Raises: + sqlalchemy.orm.exc.NoResultFound - when no result is found. + sqlalchemy.orm.exc.MultipleResultsFound - when multiple result is found. + """ + # If no filter is given, just return. Prevent table scan. + if filters is None: + return None + + modified_row = session.query(self._model).filter(*filters).one() + for key, value in kwargs.items(): + setattr(modified_row, key, value) + + # Flush out our changes to DB transaction buffer but don't commit it yet. + # This is useful in the case when we want to rollback atomically on multiple + # sql operations in the same transaction which may or may not have + # dependencies. + session.flush() + + return modified_row + + def delete_one_by(self, session, filters=None, optional=False): + """SQL update_by interface to delete data for a given list of filters. + The filters should be provided so it can narrow down to one row. + + Note: Careful consideration should be had prior to using this function. + Always consider setting rows as inactive instead before choosing to use + this function. + + Args: + session: A database session object. + filters: A list of filters on the models. + optional: Whether deletion is optional; i.e. it's OK for the model not to exist + + Returns: + None. + + Raises: + sqlalchemy.orm.exc.NoResultFound - when no result is found and optional is False. + sqlalchemy.orm.exc.MultipleResultsFound - when multiple result is found. + """ + # If no filter is given, just return. Prevent table scan. + if filters is None: + return None + + if optional: + rows = session.query(self._model).filter(*filters).all() + + if not rows: + return None + + row = one(rows) + + else: + row = session.query(self._model).filter(*filters).one() + + session.delete(row) + + # Flush out our changes to DB transaction buffer but don't commit it yet. + # This is useful in the case when we want to rollback atomically on multiple + # sql operations in the same transaction which may or may not have + # dependencies. + session.flush() + + def delete_by_id(self, session, model_id): + return session.query(self._model).get(model_id).delete() + + def add(self, session, **kwargs): + """SQL add interface to insert data to the given model. + + Args: + session: A database session object. + kwargs: Any fields defined on the models. + + Returns: + Newly inserted rows. + + Note: + We use session.flush() here to move the changes from the application + to SQL database. However, those changes will be in the pending changes + state. Meaning, it is in the queue to be inserted but yet to be done + so until session.commit() is called, which has been taken care of + in our ``with_db_session`` decorator or ``LogicComponent.begin`` + contextmanager. + """ + new_row = self._model(**kwargs) + session.add(new_row) + + # Flush out our changes to DB transaction buffer but don't commit it yet. + # This is useful in the case when we want to rollback atomically on multiple + # sql operations in the same transaction which may or may not have + # dependencies. + session.flush() + + return new_row + + def exist(self, session, filters=None): + """SQL exist interface to check if any row exists at all for the given + filters. + + Args: + session: A database session object. + filters: A list of filters on the models. + + Returns: + A boolean. True if any row exists else False. 
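+
+        Example (illustrative sketch; ``user_model`` and its ``email`` column
+        are placeholder names, not part of this module):
+
+            repo = SQLRepository(user_model)
+            found = repo.exist(
+                session,
+                filters=[user_model.email == "jane@example.com"],
+            )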
+ """ + exist_query = exists() + + for query_filter in filters: + exist_query = exist_query.where(query_filter) + + return session.query(exist_query).scalar() + + def yield_by_chunk(self, session, chunk_size, join_list=None, filters=None): + """This yields a batch of the model objects for the given chunk_size. + + Args: + session: A database session object. + chunk_size (int): The size of the chunk. + filters: A list of filters on the model. + join_list: A list of attributes to be joined in the same session for + given model. This is normally the attributes that have + relationship defined and referenced to other models. + + Returns: + A batch for the given chunk size. + """ + query = session.query(self._model) + + if filters is not None: + query = query.filter(*filters) + + if join_list: + for to_join in join_list: + query = query.options(joinedload(to_join)) + + start = 0 + + while True: + stop = start + chunk_size + model_objs = query.slice(start, stop).all() + if not model_objs: + break + + yield model_objs + + start += chunk_size diff --git a/src/quart_sqlalchemy/sim/logic.py b/src/quart_sqlalchemy/sim/logic.py new file mode 100644 index 0000000..dee3990 --- /dev/null +++ b/src/quart_sqlalchemy/sim/logic.py @@ -0,0 +1,883 @@ +import logging +import typing as t +from datetime import datetime +from functools import wraps + +from pydantic import BaseModel +from pydantic import Field +from sqlalchemy import or_ +from sqlalchemy import ScalarResult +from sqlalchemy.orm import contains_eager +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import selectinload +from sqlalchemy.orm import Session +from sqlalchemy.sql.expression import func +from sqlalchemy.sql.expression import label + +from quart_sqlalchemy.model import Base +from quart_sqlalchemy.types import ColumnExpr +from quart_sqlalchemy.types import EntityIdT +from quart_sqlalchemy.types import EntityT +from quart_sqlalchemy.types import ORMOption +from quart_sqlalchemy.types import Selectable + +from . import signals +from .model import AuthUser as auth_user_model +from .model import AuthWallet as auth_wallet_model +from .model import ConnectInteropStatus +from .model import EntityType +from .model import MagicClient as magic_client_model +from .model import Provenance +from .model import WalletType +from .repo import SQLAlchemyRepository +from .repo_adapter import RepositoryLegacyAdapter +from .util import ObjectID +from .util import one + + +logger = logging.getLogger(__name__) + + +class LogicMeta(type): + """This is metaclass provides registry pattern where all the available + logics will be accessible through any instantiated logic object. + + Note: + Don't use this metaclass at another places. This is only intended to be + used by LogicComponent. If you want your own registry, please create + your own. + """ + + def __init__(cls, name, bases, cls_dict): + if not hasattr(cls, "_registry"): + cls._registry = {} + else: + cls._registry[name] = cls() + + super().__init__(name, bases, cls_dict) + + +class LogicComponent(metaclass=LogicMeta): + """This is the base class for any logic class. This overrides the getattr + method for registry lookup. + + Example: + + ``` + class TrollGoat(LogicComponent): + + def add(x): + print(x) + ``` + + Once you have a logic object, you can directly do something like: + + ``` + logic.TrollGoat.add('troll_goat') + ``` + + Note: + You will have to explicitly import your newly created logic in + ``fortmatic.logic.__init__.py``. 
When the logic is imported, it is created + the first time; hence, it is then registered. If this is unclear to you, + read https://blog.ionelmc.ro/2015/02/09/understanding-python-metaclasses/ + It has all the info you need to understand. For example, everything in + python is an object :P. Enjoy. + """ + + def __dir__(self): + return super().__dir__() + list(self._registry.keys()) + + def __getattr__(self, logic_name): + if logic_name in self._registry: + return self._registry[logic_name] + else: + raise AttributeError( + "{object_name} has no attribute '{logic_name}'".format( + object_name=self.__class__.__name__, + logic_name=logic_name, + ), + ) + + +def with_db_session( + ro: t.Optional[bool] = None, + is_stale_tolerant: t.Optional[bool] = None, +): + """Stub decorator to ease transition with legacy code""" + + def wrapper(func): + @wraps(func) + def inner_wrapper(self, *args, **kwargs): + session = None + return func(self, None, *args, **kwargs) + + return inner_wrapper + + return wrapper + + +class MagicClient(LogicComponent): + def __init__(self, session: Session): + # self._repository = SQLAlchemyRepository[magic_client_model, ObjectID](session) + + self._repository = RepositoryLegacyAdapter(magic_client_model, ObjectID, session) + + def _add(self, session, app_name=None): + return self._repository.add( + session, + app_name=app_name, + ) + + add = with_db_session(ro=False)(_add) + + @with_db_session(ro=True) + def get_by_id( + self, + session, + model_id, + allow_inactive=False, + join_list=None, + ) -> t.Optional[magic_client_model]: + return self._repository.get_by_id( + session, + model_id, + allow_inactive=allow_inactive, + join_list=join_list, + ) + + @with_db_session(ro=True) + def get_by_public_api_key( + self, + session, + public_api_key, + ): + return one( + self._repository.get_by( + session, + filters=[magic_client_model.public_api_key == public_api_key], + limit=1, + ) + ) + + # @with_db_session(ro=True) + # def get_magic_api_user_id_by_client_id(self, session, magic_client_id): + # client = self._repository.get_by_id( + # session, + # magic_client_id, + # allow_inactive=False, + # join_list=None, + # ) + + # if client is None: + # return None + + # if client.magic_client_api_user is None: + # return None + + # return client.magic_client_api_user.magic_api_user_id + + @with_db_session(ro=False) + def update_by_id(self, session, model_id, **update_params): + modified_row = self._repository.update(session, model_id, **update_params) + session.refresh(modified_row) + return modified_row + + @with_db_session(ro=True) + def yield_all_clients_by_chunk(self, session, chunk_size): + yield from self._repository.yield_by_chunk(session, chunk_size) + + @with_db_session(ro=True) + def yield_by_chunk(self, session, chunk_size, filters=None, join_list=None): + yield from self._repository.yield_by_chunk( + session, + chunk_size, + filters=filters, + join_list=join_list, + ) + + +class DuplicateAuthUser(Exception): + pass + + +class AuthUserDoesNotExist(Exception): + pass + + +class MissingEmail(Exception): + pass + + +class MissingPhoneNumber(Exception): + pass + + +class AuthUser(LogicComponent): + def __init__(self, session: Session): + # self._repository = SQLRepository(auth_user_model) + self._repository = RepositoryLegacyAdapter(magic_client_model, ObjectID, session) + + @with_db_session(ro=True) + def get_by_session_token( + self, + session, + session_token, + ): + return one( + self._repository.get_by( + session, + filters=[auth_user_model.current_session_token == 
session_token], + limit=1, + ) + ) + + def _get_or_add_by_phone_number_and_client_id( + self, + session, + client_id, + phone_number, + user_type=EntityType.FORTMATIC.value, + ): + if phone_number is None: + raise MissingPhoneNumber() + + row = self._get_by_phone_number_and_client_id( + session=session, + phone_number=phone_number, + client_id=client_id, + user_type=user_type, + ) + + if row: + return row + + row = self._repository.add( + session=session, + phone_number=phone_number, + client_id=client_id, + user_type=user_type, + provenance=Provenance.SMS, + ) + logger.info( + "New auth user (id: {}) created by phone number (client_id: {})".format( + row.id, + client_id, + ), + ) + + return row + + get_or_add_by_phone_number_and_client_id = with_db_session(ro=False)( + _get_or_add_by_phone_number_and_client_id, + ) + + def _add_by_email_and_client_id( + self, + session, + client_id, + email=None, + user_type=EntityType.FORTMATIC.value, + **kwargs, + ): + if email is None: + raise MissingEmail() + + if self._exist_by_email_and_client_id( + session, + email, + client_id, + user_type=user_type, + ): + logger.exception( + "User duplication for email: {} (client_id: {})".format( + email, + client_id, + ), + ) + raise DuplicateAuthUser() + + row = self._repository.add( + session, + email=email, + client_id=client_id, + user_type=user_type, + **kwargs, + ) + logger.info( + "New auth user (id: {}) created by email (client_id: {})".format( + row.id, + client_id, + ), + ) + + return row + + add_by_email_and_client_id = with_db_session(ro=False)(_add_by_email_and_client_id) + + def _add_by_client_id( + self, + session, + client_id, + user_type=EntityType.FORTMATIC.value, + provenance=None, + global_auth_user_id=None, + is_verified=False, + ): + row = self._repository.add( + session, + client_id=client_id, + user_type=user_type, + provenance=provenance, + global_auth_user_id=global_auth_user_id, + date_verified=datetime.utcnow() if is_verified else None, + ) + logger.info( + "New auth user (id: {}) created by (client_id: {})".format( + row.id, + client_id, + ), + ) + + return row + + add_by_client_id = with_db_session(ro=False)(_add_by_client_id) + + def _get_by_active_identifier_and_client_id( + self, + session, + identifier_field, + identifier_value, + client_id, + user_type, + ) -> auth_user_model: + """There should only be one active identifier where all the parameters match for a given client ID. 
In the case of multiple results, the subsequent entries / "dupes" will be marked as inactive.""" + filters = [ + identifier_field == identifier_value, + auth_user_model.client_id == client_id, + auth_user_model.user_type == user_type, + # auth_user_model.is_active == True, # noqa: E712 + ] + + results = self._repository.get_by( + session, + filters=filters, + order_by_clause=auth_user_model.id.asc(), + ) + + if not results: + return None + + original, *duplicates = results + + if duplicates: + signals.auth_user_duplicate.send( + original_auth_user_id=original.id, + duplicate_auth_user_ids=[dupe.id for dupe in duplicates], + ) + + return original + + @with_db_session(ro=True) + def get_by_email_and_client_id( + self, + session, + email, + client_id, + user_type=EntityType.FORTMATIC.value, + ): + return self._get_by_active_identifier_and_client_id( + session=session, + identifier_field=auth_user_model.email, + identifier_value=email, + client_id=client_id, + user_type=user_type, + ) + + def _get_by_phone_number_and_client_id( + self, + session, + phone_number, + client_id, + user_type=EntityType.FORTMATIC.value, + ): + if phone_number is None: + raise MissingPhoneNumber() + + return self._get_by_active_identifier_and_client_id( + session=session, + identifier_field=auth_user_model.phone_number, + identifier_value=phone_number, + client_id=client_id, + user_type=user_type, + ) + + get_by_phone_number_and_client_id = with_db_session(ro=True)( + _get_by_phone_number_and_client_id, + ) + + def _exist_by_email_and_client_id( + self, + session, + email, + client_id, + user_type=EntityType.FORTMATIC.value, + ): + return bool( + self._repository.exist( + session, + filters=[ + auth_user_model.email == email, + auth_user_model.client_id == client_id, + auth_user_model.user_type == user_type, + ], + ), + ) + + exist_by_email_and_client_id = with_db_session(ro=True)(_exist_by_email_and_client_id) + + def _get_by_id(self, session, model_id, join_list=None, for_update=False) -> auth_user_model: + return self._repository.get_by_id( + session, + model_id, + join_list=join_list, + for_update=for_update, + ) + + get_by_id = with_db_session(ro=True)(_get_by_id) + + def _update_by_id(self, session, auth_user_id, **kwargs): + modified_user = self._repository.update(session, auth_user_id, **kwargs) + + if modified_user is None: + raise AuthUserDoesNotExist() + + return modified_user + + update_by_id = with_db_session(ro=False)(_update_by_id) + + @with_db_session(ro=True) + def get_user_count_by_client_id_and_user_type(self, session, client_id, user_type): + query = ( + session.query(auth_user_model) + .filter( + auth_user_model.client_id == client_id, + auth_user_model.user_type == user_type, + # auth_user_model.is_active == True, # noqa: E712 + auth_user_model.date_verified.is_not(None), + ) + .statement.with_only_columns([func.count()]) + .order_by(None) + ) + + return session.execute(query).scalar() + + def _get_by_client_id_and_global_auth_user(self, session, client_id, global_auth_user_id): + return self._repository.get_by( + session=session, + filters=[ + auth_user_model.client_id == client_id, + auth_user_model.user_type == EntityType.CONNECT.value, + # auth_user_model.is_active == True, # noqa: E712 + auth_user_model.global_auth_user_id == global_auth_user_id, + ], + ) + + get_by_client_id_and_global_auth_user = with_db_session(ro=True)( + _get_by_client_id_and_global_auth_user, + ) + + # @with_db_session(ro=True) + # def get_by_client_id_for_connect( + # self, + # session, + # client_id, + # 
offset=None, + # limit=None, + # ): + # # TODO(thomas|2022-07-12): Determine where/if is the right place to split + # # connect/magic logic based on user type as part of https://app.shortcut.com/magic-labs/story/53323. + # # See https://github.com/fortmatic/fortmatic/pull/6173#discussion_r919529540. + # return ( + # session.query(auth_user_model) + # .join( + # identifier_model, + # auth_user_model.global_auth_user_id == identifier_model.global_auth_user_id, + # ) + # .filter( + # auth_user_model.client_id == client_id, + # auth_user_model.user_type == EntityType.CONNECT.value, + # auth_user_model.is_active == True, # noqa: E712, + # auth_user_model.provenance == Provenance.IDENTIFIER, + # or_( + # identifier_model.identifier_type.in_( + # GlobalAuthUserIdentifierType.get_public_address_enums(), + # ), + # identifier_model.date_verified != None, + # ), + # ) + # .order_by(auth_user_model.id.desc()) + # .limit(limit) + # .offset(offset) + # ).all() + + # @with_db_session(ro=True) + # def get_user_count_by_client_id_for_connect( + # self, + # session, + # client_id, + # ): + # # TODO(thomas|2022-07-12): Determine where/if is the right place to split + # # connect/magic logic based on user type as part of https://app.shortcut.com/magic-labs/story/53323. + # # See https://github.com/fortmatic/fortmatic/pull/6173#discussion_r919529540. + # query = ( + # session.query(auth_user_model) + # .join( + # identifier_model, + # auth_user_model.global_auth_user_id == identifier_model.global_auth_user_id, + # ) + # .filter( + # auth_user_model.client_id == client_id, + # auth_user_model.user_type == EntityType.CONNECT.value, + # auth_user_model.is_active == True, # noqa: E712, + # auth_user_model.provenance == Provenance.IDENTIFIER, + # or_( + # identifier_model.identifier_type.in_( + # GlobalAuthUserIdentifierType.get_public_address_enums(), + # ), + # identifier_model.date_verified != None, + # ), + # ) + # .statement.with_only_columns( + # [func.count(distinct(auth_user_model.global_auth_user_id))], + # ) + # .order_by(None) + # ) + + # return session.execute(query).scalar() + + @with_db_session(ro=True) + def get_by_client_id_and_user_type( + self, + session, + client_id, + user_type, + offset=None, + limit=None, + ): + return self._get_by_client_ids_and_user_type( + session, + [client_id], + user_type, + offset=offset, + limit=limit, + ) + + def _get_by_client_ids_and_user_type( + self, + session, + client_ids, + user_type, + offset=None, + limit=None, + ): + if not client_ids: + return [] + + return self._repository.get_by( + session, + filters=[ + auth_user_model.client_id.in_(client_ids), + auth_user_model.user_type == user_type, + # auth_user_model.is_active == True, # noqa: E712, + auth_user_model.date_verified != None, + ], + offset=offset, + limit=limit, + order_by_clause=auth_user_model.id.desc(), + ) + + get_by_client_ids_and_user_type = with_db_session(ro=True)( + _get_by_client_ids_and_user_type, + ) + + def _get_by_client_id_with_substring_search( + self, + session, + client_id, + substring, + offset=None, + limit=10, + join_list=None, + ): + return self._repository.get_by( + session, + filters=[ + auth_user_model.client_id == client_id, + auth_user_model.user_type == EntityType.MAGIC.value, + or_( + auth_user_model.provenance == Provenance.SMS, + auth_user_model.provenance == Provenance.LINK, + auth_user_model.provenance == None, # noqa: E711 + ), + or_( + auth_user_model.phone_number.contains(substring), + auth_user_model.email.contains(substring), + ), + ], + offset=offset, + 
limit=limit, + order_by_clause=auth_user_model.id.desc(), + join_list=join_list, + ) + + get_by_client_id_with_substring_search = with_db_session(ro=True)( + _get_by_client_id_with_substring_search, + ) + + @with_db_session(ro=True) + def yield_by_chunk(self, session, chunk_size, filters=None, join_list=None): + yield from self._repository.yield_by_chunk( + session, + chunk_size, + filters=filters, + join_list=join_list, + ) + + @with_db_session(ro=True) + def get_by_emails_and_client_id( + self, + session, + email_ids, + client_id, + ): + return self._repository.get_by( + session, + filters=[ + auth_user_model.email.in_(email_ids), + auth_user_model.client_id == client_id, + ], + ) + + def _get_by_email( + self, + session, + email: str, + join_list=None, + filters=None, + for_update: bool = False, + ) -> t.List[auth_user_model]: + filters = filters or [] + combined_filters = filters + [auth_user_model.email == email] + + return self._repository.get_by( + session, + filters=combined_filters, + for_update=for_update, + join_list=join_list, + ) + + get_by_email = with_db_session(ro=True)(_get_by_email) + + def _add(self, session, **kwargs) -> ObjectID: + return self._repository.add(session, **kwargs).id + + add = with_db_session(ro=False)(_add) + + def _get_by_email_for_interop( + self, + session, + email: str, + wallet_type: WalletType, + network: str, + ) -> List[auth_user_model]: + """ + Custom method for searching for users eligible for interop. Unfortunately, this can't be done with the current + abstractions in our sql_repository, so this is a one-off bespoke method. + If we need to add more similar queries involving eager loading and multiple joins, we can add an abstraction + inside the repository. + """ + + query = ( + session.query(auth_user_model) + .join( + auth_user_model.wallets.and_( + auth_wallet_model.wallet_type == str(wallet_type) + ).and_(auth_wallet_model.network == network) + # .and_(auth_wallet_model.is_active == 1), + ) + .options(contains_eager(auth_user_model.wallets)) + .join( + auth_user_model.magic_client.and_( + magic_client_model.connect_interop == ConnectInteropStatus.ENABLED, + ), + ) + .options(contains_eager(auth_user_model.magic_client)) + # TODO(magic-ravi#67899|2022-12-30): Uncomment to allow account-linked users to use interop + # .options( + # joinedload( + # auth_user_model.linked_primary_auth_user, + # ).joinedload("auth_wallets"), + # ) + .filter( + auth_wallet_model.wallet_type == wallet_type, + auth_wallet_model.network == network, + ) + .filter( + auth_user_model.email == email, + auth_user_model.user_type == EntityType.MAGIC.value, + # auth_user_model.is_active == 1, + auth_user_model.linked_primary_auth_user_id == None, # noqa: E711 + ) + .populate_existing() + ) + + return query.all() + + get_by_email_for_interop = with_db_session(ro=True)( + _get_by_email_for_interop, + ) + + def _get_linked_users(self, session, primary_auth_user_id, join_list, no_op=False): + # TODO(magic-ravi#67899|2022-12-30): Re-enable account linked users for interop. Remove no_op flag. 
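+        # When no_op is True the method intentionally short-circuits and
+        # reports no linked users instead of querying the repository.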
+ if no_op: + return [] + else: + return self._repository.get_by( + session, + filters=[ + # auth_user_model.is_active == True, # noqa: E712 + auth_user_model.user_type == EntityType.MAGIC.value, + auth_user_model.linked_primary_auth_user_id == primary_auth_user_id, + ], + join_list=join_list, + ) + + get_linked_users = with_db_session(ro=True)(_get_linked_users) + + @with_db_session(ro=True) + def get_by_phone_number(self, session, phone_number): + return self._repository.get_by( + session, + filters=[ + auth_user_model.phone_number == phone_number, + ], + ) + + +class AuthWallet(LogicComponent): + def __init__(self, session: Session): + # self._repository = SQLAlchemyRepository[magic_client_model, ObjectID](session) + self._repository = RepositoryLegacyAdapter(auth_wallet_model, ObjectID, session) + + def _add( + self, + session, + public_address, + encrypted_private_address, + wallet_type, + network, + management_type=None, + auth_user_id=None, + ): + new_row = self._repository.add( + session, + auth_user_id=auth_user_id, + public_address=public_address, + encrypted_private_address=encrypted_private_address, + wallet_type=wallet_type, + management_type=management_type, + network=network, + ) + + return new_row + + add = with_db_session(ro=False)(_add) + + @with_db_session(ro=True) + def get_by_id(self, session, model_id, allow_inactive=False, join_list=None): + return self._repository.get_by_id( + session, + model_id, + allow_inactive=allow_inactive, + join_list=join_list, + ) + + @with_db_session(ro=True) + def get_by_public_address(self, session, public_address, network=None, is_active=True): + """Public address is unique in our system. In any case, we should only + find one row for the given public address. + + Args: + session: A database session object. + public_address (str): A public address. + network (str): A network name. + is_active (boolean): A boolean value to denote if the query should + retrieve active or inactive rows. + + Returns: + A formatted row, either in presenter form or raw db row. + """ + filters = [ + auth_wallet_model.public_address == public_address, + # auth_wallet_model.is_active == is_active, + ] + + if network: + filters.append(auth_wallet_model.network == network) + + row = self._repository.get_by(session, filters=filters, allow_inactive=not is_active) + + if not row: + return None + + return one(row) + + @with_db_session(ro=True) + def get_by_auth_user_id( + self, + session, + auth_user_id, + network=None, + wallet_type=None, + is_active=True, + join_list=None, + ): + """Return all the associated wallets for the given user id. + + Args: + session: A database session object. + auth_user_id (ObjectID): A auth_user id. + network (str|None): A network name. + wallet_type (str|None): a wallet type like ETH or BTC + is_active (boolean): A boolean value to denote if the query should + retrieve active or inactive rows. + join_list (None|List): Table you wish to join. + + Returns: + An empty list or a list of wallets. 
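+
+        Example call as used from the handlers (illustrative; the wrapping
+        ``with_db_session`` decorator supplies the session argument and the
+        network value shown is a placeholder):
+
+            wallets = logic.AuthWallet.get_by_auth_user_id(
+                auth_user_id,
+                network="GOERLI",
+                wallet_type=WalletType.ETH,
+            )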
+ """ + filters = [ + auth_wallet_model.auth_user_id == auth_user_id, + # auth_wallet_model.is_active == is_active, + ] + + if network: + filters.append(auth_wallet_model.network == network) + + if wallet_type: + filters.append(auth_wallet_model.wallet_type == wallet_type) + + rows = self._repository.get_by( + session, filters=filters, join_list=join_list, allow_inactive=not is_active + ) + + if not rows: + return [] + + return rows + + def _update_by_id(self, session, model_id, **kwargs): + self._repository.update(session, model_id, **kwargs) + + update_by_id = with_db_session(ro=False)(_update_by_id) diff --git a/src/quart_sqlalchemy/sim/main.py b/src/quart_sqlalchemy/sim/main.py new file mode 100644 index 0000000..7bd73cd --- /dev/null +++ b/src/quart_sqlalchemy/sim/main.py @@ -0,0 +1,9 @@ +from .app import app +from .views import api + + +app.register_blueprint(api) + + +if __name__ == "__main__": + app.run(port=8080) diff --git a/src/quart_sqlalchemy/sim/model.py b/src/quart_sqlalchemy/sim/model.py new file mode 100644 index 0000000..efee04c --- /dev/null +++ b/src/quart_sqlalchemy/sim/model.py @@ -0,0 +1,163 @@ +import secrets +import typing as t +from datetime import datetime +from enum import Enum +from enum import IntEnum + +import sqlalchemy +import sqlalchemy.orm +from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy.orm import Mapped + +from quart_sqlalchemy.model import SoftDeleteMixin +from quart_sqlalchemy.model import TimestampMixin +from quart_sqlalchemy.sim.app import db +from quart_sqlalchemy.sim.util import ObjectID + + +sa = sqlalchemy + + +class StrEnum(str, Enum): + def __str__(self) -> str: + return str.__str__(self) + + +class ConnectInteropStatus(StrEnum): + ENABLED = "ENABLED" + DISABLED = "DISABLED" + + +class Provenance(Enum): + LINK = 1 + OAUTH = 2 + WEBAUTHN = 3 + SMS = 4 + IDENTIFIER = 5 + FEDERATED = 6 + + +class EntityType(Enum): + FORTMATIC = 1 + MAGIC = 2 + CONNECT = 3 + + +class WalletManagementType(IntEnum): + UNDELEGATED = 1 + DELEGATED = 2 + + +class WalletType(StrEnum): + ETH = "ETH" + HARMONY = "HARMONY" + ICON = "ICON" + FLOW = "FLOW" + TEZOS = "TEZOS" + ZILLIQA = "ZILLIQA" + POLKADOT = "POLKADOT" + SOLANA = "SOLANA" + AVAX = "AVAX" + ALGOD = "ALGOD" + COSMOS = "COSMOS" + CELO = "CELO" + BITCOIN = "BITCOIN" + NEAR = "NEAR" + HELIUM = "HELIUM" + CONFLUX = "CONFLUX" + TERRA = "TERRA" + TAQUITO = "TAQUITO" + ED = "ED" + HEDERA = "HEDERA" + + +class MagicClient(db.Model, SoftDeleteMixin, TimestampMixin): + __tablename__ = "magic_client" + + id: Mapped[ObjectID] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + app_name: Mapped[str] = sa.orm.mapped_column(default="my new app") + rate_limit_tier: Mapped[t.Optional[str]] + connect_interop: Mapped[t.Optional[ConnectInteropStatus]] + is_signing_modal_enabled: Mapped[bool] = sa.orm.mapped_column(default=False) + global_audience_enabled: Mapped[bool] = sa.orm.mapped_column(default=False) + + public_api_key: Mapped[str] = sa.orm.mapped_column(default_factory=secrets.token_hex) + secret_api_key: Mapped[str] = sa.orm.mapped_column(default_factory=secrets.token_hex) + + auth_users: Mapped[t.List["AuthUser"]] = sa.orm.relationship( + back_populates="magic_client", + primaryjoin="and_(foreign(AuthUser.client_id) == MagicClient.id, AuthUser.user_type != 1)", + ) + + +class AuthUser(db.Model, SoftDeleteMixin, TimestampMixin): + __tablename__ = "auth_user" + + id: Mapped[ObjectID] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + email: Mapped[t.Optional[str]] = 
sa.orm.mapped_column(index=True) + phone_number: Mapped[t.Optional[str]] = sa.orm.mapped_column(index=True) + user_type: Mapped[int] = sa.orm.mapped_column(default=EntityType.FORTMATIC.value) + date_verified: Mapped[t.Optional[datetime]] + provenance: Mapped[t.Optional[Provenance]] + is_admin: Mapped[bool] = sa.orm.mapped_column(default=False) + client_id: Mapped[ObjectID] + linked_primary_auth_user_id: Mapped[t.Optional[ObjectID]] + global_auth_user_id: Mapped[t.Optional[ObjectID]] + + delegated_user_id: Mapped[t.Optional[str]] + delegated_identity_pool_id: Mapped[t.Optional[str]] + + current_session_token: Mapped[t.Optional[str]] + + magic_client: Mapped[MagicClient] = sa.orm.relationship( + back_populates="auth_user", + uselist=False, + ) + linked_primary_auth_user = sa.orm.relationship( + "AuthUser", + remote_side=[id], + lazy="joined", + join_depth=1, + uselist=False, + ) + wallets: Mapped[t.List["AuthWallet"]] = sa.orm.relationship(back_populates="auth_user") + + @hybrid_property + def is_email_verified(self): + return self.email is not None and self.date_verified is not None + + @hybrid_property + def is_waiting_on_email_verification(self): + return self.email is not None and self.date_verified is None + + @hybrid_property + def is_new_signup(self): + return self.date_verified is None + + @hybrid_property + def has_linked_primary_auth_user(self): + return bool(self.linked_primary_auth_user_id) + + @hybrid_property + def is_magic_connect_user(self): + return self.global_auth_user_id is not None and self.user_type == EntityType.CONNECT.value + + +class AuthWallet(db.Model, SoftDeleteMixin, TimestampMixin): + __tablename__ = "auth_user" + + id: Mapped[ObjectID] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + auth_user_id: Mapped[ObjectID] = sa.orm.mapped_column(sa.ForeignKey("auth_user.id")) + wallet_type: Mapped[str] = sa.orm.mapped_column(default=WalletType.ETH.value) + management_type: Mapped[int] = sa.orm.mapped_column( + default=WalletManagementType.UNDELEGATED.value + ) + public_address: Mapped[t.Optional[str]] = sa.orm.mapped_column(index=True) + encrypted_private_address: Mapped[t.Optional[str]] + network: Mapped[str] + is_exported: Mapped[bool] = sa.orm.mapped_column(default=False) + + auth_user: Mapped[AuthUser] = sa.orm.relationship( + back_populates="auth_wallets", + uselist=False, + ) diff --git a/src/quart_sqlalchemy/sim/repo.py b/src/quart_sqlalchemy/sim/repo.py new file mode 100644 index 0000000..d257b5f --- /dev/null +++ b/src/quart_sqlalchemy/sim/repo.py @@ -0,0 +1,285 @@ +from __future__ import annotations + +import typing as t +from abc import ABCMeta +from abc import abstractmethod + +import sqlalchemy +import sqlalchemy.event +import sqlalchemy.exc +import sqlalchemy.orm +import sqlalchemy.sql + +from quart_sqlalchemy.types import ColumnExpr +from quart_sqlalchemy.types import EntityIdT +from quart_sqlalchemy.types import EntityT +from quart_sqlalchemy.types import ORMOption +from quart_sqlalchemy.types import Selectable +from quart_sqlalchemy.types import SessionT + +from .builder import StatementBuilder + + +sa = sqlalchemy + + +class AbstractRepository(t.Generic[EntityT, EntityIdT], metaclass=ABCMeta): + """A repository interface.""" + + # identity: t.Type[EntityIdT] + + # def __init__(self, model: t.Type[EntityT]): + # self.model = model + + @property + def model(self) -> t.Type[EntityT]: + return self.__orig_class__.__args__[0] # type: ignore + + @property + def identity(self) -> t.Type[EntityIdT]: + return self.__orig_class__.__args__[1] # 
type: ignore + + +class AbstractBulkRepository(t.Generic[EntityT, EntityIdT], metaclass=ABCMeta): + """A repository interface for bulk operations. + + Note: this interface circumvents ORM internals, breaking commonly expected behavior in order + to gain performance benefits. Only use this class whenever absolutely necessary. + """ + + @property + def model(self) -> t.Type[EntityT]: + return self.__orig_class__.__args__[0] # type: ignore + + @property + def identity(self) -> t.Type[EntityIdT]: + return self.__orig_class__.__args__[1] # type: ignore + + + +class SQLAlchemyRepository( + AbstractRepository[EntityT, EntityIdT], + t.Generic[EntityT, EntityIdT], +): + """A repository that uses SQLAlchemy to persist data. + + The biggest change with this repository is that for methods returning multiple results, we + return the sa.ScalarResult so that the caller has maximum flexibility in how it's consumed. + + As a result, when calling a method such as get_by, you then need to decide how to fetch the + result. + + Methods of fetching results: + - .all() to return a list of results + - .first() to return the first result + - .one() to return the first result or raise an exception if there are no results + - .one_or_none() to return the first result or None if there are no results + - .partitions(n) to return a results as a list of n-sized sublists + + Additionally, there are methods for transforming the results prior to fetching. + + Methods of transforming results: + - .unique() to apply unique filtering to the result + + """ + + session: sa.orm.Session + builder: StatementBuilder + + def __init__(self, session: sa.orm.Session, **kwargs): + super().__init__(**kwargs) + self.session = session + self.builder = StatementBuilder(None) + + def insert(self, values: t.Dict[str, t.Any]) -> EntityT: + """Insert a new model into the database.""" + new = self.model(**values) + self.session.add(new) + self.session.flush() + self.session.refresh(new) + return new + + def update(self, id_: EntityIdT, values: t.Dict[str, t.Any]) -> EntityT: + """Update existing model with new values.""" + obj = self.session.get(self.model, id_) + if obj is None: + raise ValueError(f"Object with id {id_} not found") + for field, value in values.items(): + if getattr(obj, field) != value: + setattr(obj, field, value) + self.session.flush() + self.session.refresh(obj) + return obj + + def merge( + self, id_: EntityIdT, values: t.Dict[str, t.Any], for_update: bool = False + ) -> EntityT: + """Merge model in session/db having id_ with values.""" + self.session.get(self.model, id_) + values.update(id=id_) + merged = self.session.merge(self.model(**values)) + self.session.flush() + self.session.refresh(merged, with_for_update=for_update) # type: ignore + return merged + + def get( + self, + id_: EntityIdT, + options: t.Sequence[ORMOption] = (), + execution_options: t.Optional[t.Dict[str, t.Any]] = None, + for_update: bool = False, + include_inactive: bool = False, + ) -> t.Optional[EntityT]: + """Get object identified by id_ from the database. + + Note: It's a common misconception that session.get(Model, id) is akin to a shortcut for + a select(Model).where(Model.id == id) like statement. However this is not the case. + + Session.get is actually used for looking an object up in the sessions identity map. When + present it will be returned directly, when not, a database lookup will be performed. + + For use cases where this is what you actually want, you can still access the original get + method on self.session. 
For most use cases, this behavior can introduce non-determinism
+        and because of that this method performs lookup using a select statement. Additionally,
+        to satisfy the expected interface's return type: Optional[EntityT], one_or_none is called
+        on the result before returning.
+        """
+        execution_options = execution_options or {}
+        if include_inactive:
+            execution_options.setdefault("include_inactive", include_inactive)
+
+        statement = sa.select(self.model).where(self.model.id == id_).limit(1)  # type: ignore
+
+        for option in options:
+            statement = statement.options(option)
+
+        if for_update:
+            statement = statement.with_for_update()
+
+        return self.session.scalars(statement, execution_options=execution_options).one_or_none()
+
+    def select(
+        self,
+        selectables: t.Sequence[Selectable] = (),
+        conditions: t.Sequence[ColumnExpr] = (),
+        group_by: t.Sequence[t.Union[ColumnExpr, str]] = (),
+        order_by: t.Sequence[t.Union[ColumnExpr, str]] = (),
+        options: t.Sequence[ORMOption] = (),
+        execution_options: t.Optional[t.Dict[str, t.Any]] = None,
+        offset: t.Optional[int] = None,
+        limit: t.Optional[int] = None,
+        distinct: bool = False,
+        for_update: bool = False,
+        include_inactive: bool = False,
+        yield_by_chunk: t.Optional[int] = None,
+    ) -> t.Union[sa.ScalarResult[EntityT], t.Iterator[t.Sequence[EntityT]]]:
+        """Select from the database.
+
+        Note: yield_by_chunk is not compatible with the subquery and joined loader strategies; use selectinload for eager loading.
+        """
+        selectables = selectables or (self.model,)  # type: ignore
+
+        execution_options = execution_options or {}
+        if include_inactive:
+            execution_options.setdefault("include_inactive", include_inactive)
+        if yield_by_chunk:
+            execution_options.setdefault("yield_per", yield_by_chunk)
+
+        statement = self.builder.complex_select(
+            selectables,
+            conditions=conditions,
+            group_by=group_by,
+            order_by=order_by,
+            options=options,
+            execution_options=execution_options,
+            offset=offset,
+            limit=limit,
+            distinct=distinct,
+            for_update=for_update,
+        )
+
+        results = self.session.scalars(statement)
+        if yield_by_chunk:
+            results = results.partitions()
+        return results
+
+    def delete(self, id_: EntityIdT, include_inactive: bool = False) -> None:
+        # if self.has_soft_delete:
+        #     raise RuntimeError("Can't delete entity that uses soft-delete semantics.")
+
+        entity = self.get(id_, include_inactive=include_inactive)
+        if not entity:
+            raise RuntimeError(f"Entity with id {id_} not found.")
+
+        self.session.delete(entity)
+        self.session.flush()
+
+    def deactivate(self, id_: EntityIdT) -> EntityT:
+        # if not self.has_soft_delete:
+        #     raise RuntimeError("Can't deactivate an entity that doesn't use soft-delete semantics.")
+
+        return self.update(id_, dict(is_active=False))
+
+    def reactivate(self, id_: EntityIdT) -> EntityT:
+        # if not self.has_soft_delete:
+        #     raise RuntimeError("Can't reactivate an entity that doesn't use soft-delete semantics.")
+
+        return self.update(id_, dict(is_active=True))
+
+    def exists(
+        self,
+        conditions: t.Sequence[ColumnExpr] = (),
+        for_update: bool = False,
+        include_inactive: bool = False,
+    ) -> bool:
+        """Return whether an object matching conditions exists.
+
+        Note: This performs better than simply trying to select an object since there is no
+        overhead in sending the selected object and deserializing it.
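+
+        Example (illustrative; ``AuthUser`` stands in for any mapped model):
+
+            found = repo.exists(conditions=[AuthUser.email == "user@example.com"])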
+ """ + selectable = sa.sql.literal(True) + + execution_options = {} + if include_inactive: + execution_options.setdefault("include_inactive", include_inactive) + + statement = sa.select(selectable).where(*conditions) # type: ignore + + if for_update: + statement = statement.with_for_update() + + result = self.session.execute(statement, execution_options=execution_options).scalar() + + return bool(result) + + +class SQLAlchemyBulkRepository(AbstractBulkRepository, t.Generic[SessionT, EntityT, EntityIdT]): + def __init__(self, session: SessionT, **kwargs: t.Any): + super().__init__(**kwargs) + self.builder = StatementBuilder(self.model) + self.session = session + + def bulk_insert( + self, + values: t.Sequence[t.Dict[str, t.Any]] = (), + execution_options: t.Optional[t.Dict[str, t.Any]] = None, + ) -> sa.Result[t.Any]: + statement = self.builder.bulk_insert(self.model, values) + return self.session.execute(statement, execution_options=execution_options or {}) + + def bulk_update( + self, + conditions: t.Sequence[ColumnExpr] = (), + values: t.Optional[t.Dict[str, t.Any]] = None, + execution_options: t.Optional[t.Dict[str, t.Any]] = None, + ) -> sa.Result[t.Any]: + statement = self.builder.bulk_update(self.model, conditions, values) + return self.session.execute(statement, execution_options=execution_options or {}) + + def bulk_delete( + self, + conditions: t.Sequence[ColumnExpr] = (), + execution_options: t.Optional[t.Dict[str, t.Any]] = None, + ) -> sa.Result[t.Any]: + statement = self.builder.bulk_delete(self.model, conditions) + return self.session.execute(statement, execution_options=execution_options or {}) diff --git a/src/quart_sqlalchemy/sim/repo_adapter.py b/src/quart_sqlalchemy/sim/repo_adapter.py new file mode 100644 index 0000000..bbd96b0 --- /dev/null +++ b/src/quart_sqlalchemy/sim/repo_adapter.py @@ -0,0 +1,309 @@ +import typing as t + +from pydantic import BaseModel +from sqlalchemy import ScalarResult +from sqlalchemy.orm import selectinload, Session +from sqlalchemy.sql.expression import func +from sqlalchemy.sql.expression import label +from quart_sqlalchemy.model import Base +from quart_sqlalchemy.types import ColumnExpr +from quart_sqlalchemy.types import EntityIdT +from quart_sqlalchemy.types import EntityT +from quart_sqlalchemy.types import ORMOption +from quart_sqlalchemy.types import Selectable + +from .repo import SQLAlchemyRepository + + +class BaseModelSchema(BaseModel): + class Config: + from_orm = True + + +class BaseCreateSchema(BaseModelSchema): + pass + + +class BaseUpdateSchema(BaseModelSchema): + pass + + +ModelSchemaT = t.TypeVar("ModelSchemaT", bound=BaseModelSchema) +CreateSchemaT = t.TypeVar("CreateSchemaT", bound=BaseCreateSchema) +UpdateSchemaT = t.TypeVar("UpdateSchemaT", bound=BaseUpdateSchema) + + +class RepositoryLegacyAdapter(t.Generic[EntityT, EntityIdT]): + def __init__(self, model: t.Type[EntityT], identity: t.Type[EntityIdT], session: Session,): + self.model = model + self._identity = identity + self._session = session + self.repo = SQLAlchemyRepository[model, identity](session) + + def get_by( + self, + session: t.Optional[Session] = None, + filters=None, + allow_inactive=False, + join_list=None, + order_by_clause=None, + for_update=False, + offset=None, + limit=None, + ) -> t.Sequence[EntityT]: + if filters is None: + raise ValueError("Full table scans are prohibited. 
Please provide filters")
+
+        join_list = join_list or ()
+
+        if order_by_clause is not None:
+            order_by_clause = (order_by_clause,)
+        else:
+            order_by_clause = ()
+
+        return self.repo.select(
+            conditions=filters,
+            options=[selectinload(getattr(self.model, attr)) for attr in join_list],
+            for_update=for_update,
+            order_by=order_by_clause,
+            offset=offset,
+            limit=limit,
+            include_inactive=allow_inactive,
+        ).all()
+
+    def get_by_id(
+        self,
+        session = None,
+        model_id = None,
+        allow_inactive=False,
+        join_list=None,
+        for_update=False,
+    ) -> t.Optional[EntityT]:
+        if model_id is None:
+            raise ValueError("model_id is required")
+        join_list = join_list or ()
+        return self.repo.get(
+            id_=model_id,
+            options=[selectinload(getattr(self.model, attr)) for attr in join_list],
+            for_update=for_update,
+            include_inactive=allow_inactive,
+        )
+
+    def one(self, session = None, filters=None, join_list=None, for_update=False, include_inactive=False) -> EntityT:
+        filters = filters or ()
+        join_list = join_list or ()
+        return self.repo.select(
+            conditions=filters,
+            options=[selectinload(getattr(self.model, attr)) for attr in join_list],
+            for_update=for_update,
+            include_inactive=include_inactive,
+        ).one()
+
+    def count_by(
+        self,
+        session = None,
+        filters=None,
+        group_by=None,
+        distinct_column=None,
+    ):
+        if filters is None:
+            raise ValueError("Full table scans are prohibited. Please provide filters")
+
+        group_by = group_by or ()
+
+        if distinct_column:
+            selectables = [label("count", func.count(func.distinct(distinct_column)))]
+        else:
+            selectables = [label("count", func.count(self.model.id))]
+
+        for group in group_by:
+            selectables.append(group.expression)
+
+        result = self.repo.select(selectables, conditions=filters, group_by=group_by)
+
+        return result.all()
+
+    def add(self, session = None, **kwargs) -> EntityT:
+        return self.repo.insert(kwargs)
+
+    def update(self, session = None, model_id = None, **kwargs) -> EntityT:
+        return self.repo.update(id_=model_id, values=kwargs)
+
+    def update_by(self, session = None, filters=None, **kwargs) -> EntityT:
+        if not filters:
+            raise ValueError("Full table scans are prohibited.
Please provide filters") + + row = self.repo.select(conditions=filters, limit=2).one() + return self.repo.update(id_=row.id, values=kwargs) + + def delete_by_id(self, session = None, model_id) -> None: + self.repo.delete(id_=model_id, include_inactive=True) + + def delete_one_by(self, session = None, filters=None, optional=False) -> None: + filters = filters or () + result = self.repo.select(conditions=filters, limit=1) + + if optional: + row = result.one_or_none() + if row is None: + return + else: + row = result.one() + + self.repo.delete(id_=row.id) + + def exist(self, session = None, filters=None, allow_inactive=False) -> bool: + filters = filters or () + return self.repo.exists( + conditions=filters, + include_inactive=allow_inactive, + ) + + def yield_by_chunk( + self, session = None, chunk_size = 100, join_list=None, filters=None, allow_inactive=False + ): + filters = filters or () + join_list = join_list or () + results = self.repo.select( + conditions=filters, + options=[selectinload(getattr(self.model, attr)) for attr in join_list], + include_inactive=allow_inactive, + yield_by_chunk=chunk_size, + ) + for result in results: + yield result + + +class PydanticScalarResult(ScalarResult): + pydantic_schema: t.Type[ModelSchemaT] + + def __init__(self, scalar_result, pydantic_schema: t.Type[ModelSchemaT]): + for attribute in scalar_result.__slots__: + setattr(self, attribute, getattr(scalar_result, attribute)) + self.pydantic_schema = pydantic_schema + + def _translate_many(self, rows): + return [self.pydantic_schema.from_orm(row) for row in rows] + + def _translate_one(self, row): + if row is None: + return + return self.pydantic_schema.from_orm(row) + + def all(self): + return self._translate_many(super().all()) + + def fetchall(self): + return self._translate_many(super().fetchall()) + + def fetchmany(self, *args, **kwargs): + return self._translate_many(super().fetchmany(*args, **kwargs)) + + def first(self): + return self._translate_one(super().first()) + + def one(self): + return self._translate_one(super().one()) + + def one_or_none(self): + return self._translate_one(super().one_or_none()) + + def partitions(self, *args, **kwargs): + for partition in super().partitions(*args, **kwargs): + yield self._translate_many(partition) + + +class PydanticRepository(SQLAlchemyRepository, t.Generic[EntityT, EntityIdT, ModelSchemaT]): + @property + def schema(self) -> t.Type[ModelSchemaT]: + return self.__orig_class__.__args__[2] # type: ignore + + def insert( + self, + create_schema: CreateSchemaT, + sqla_model=False, + ): + create_data = create_schema.dict() + result = super().insert(create_data) + + if sqla_model: + return result + return self.schema.from_orm(result) + + def update( + self, + id_: EntityIdT, + update_schema: UpdateSchemaT, + sqla_model=False, + ): + existing = self.session.query(self.model).get(id_) + if existing is None: + raise ValueError("Model not found") + + update_data = update_schema.dict(exclude_unset=True) + for key, value in update_data.items(): + setattr(existing, key, value) + + self.session.add(existing) + self.session.flush() + self.session.refresh(existing) + if sqla_model: + return existing + return self.schema.from_orm(existing) + + def get( + self, + id_: EntityIdT, + options: t.Sequence[ORMOption] = (), + execution_options: t.Optional[t.Dict[str, t.Any]] = None, + for_update: bool = False, + include_inactive: bool = False, + sqla_model: bool = False, + ): + row = super().get( + id_, + options, + execution_options, + for_update, + include_inactive, + ) 
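+        # ``row`` is the ORM instance returned by the base repository (or None);
+        # unless sqla_model=True it is converted to the pydantic schema below.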
+ if row is None: + return + + if sqla_model: + return row + return self.schema.from_orm(row) + + def select( + self, + selectables: t.Sequence[Selectable] = (), + conditions: t.Sequence[ColumnExpr] = (), + group_by: t.Sequence[t.Union[ColumnExpr, str]] = (), + order_by: t.Sequence[t.Union[ColumnExpr, str]] = (), + options: t.Sequence[ORMOption] = (), + execution_options: t.Optional[t.Dict[str, t.Any]] = None, + offset: t.Optional[int] = None, + limit: t.Optional[int] = None, + distinct: bool = False, + for_update: bool = False, + include_inactive: bool = False, + yield_by_chunk: t.Optional[int] = None, + sqla_model: bool = False, + ): + result = super().select( + selectables, + conditions, + group_by, + order_by, + options, + execution_options, + offset, + limit, + distinct, + for_update, + include_inactive, + yield_by_chunk, + ) + if sqla_model: + return result + return PydanticScalarResult(result, self.schema) diff --git a/src/quart_sqlalchemy/sim/schema.py b/src/quart_sqlalchemy/sim/schema.py new file mode 100644 index 0000000..f47225a --- /dev/null +++ b/src/quart_sqlalchemy/sim/schema.py @@ -0,0 +1,14 @@ +from datetime import datetime + +from pydantic import BaseModel + +from .util import ObjectID + + +class BaseSchema(BaseModel): + class Config: + arbitrary_types_allowed = True + json_encoders = { + ObjectID: lambda v: v.encode(), + datetime: lambda dt: int(dt.timestamp()), + } diff --git a/src/quart_sqlalchemy/sim/signals.py b/src/quart_sqlalchemy/sim/signals.py new file mode 100644 index 0000000..1981cea --- /dev/null +++ b/src/quart_sqlalchemy/sim/signals.py @@ -0,0 +1,33 @@ +from blinker import Namespace + + +# Synchronous signals +_sync = Namespace() + +auth_user_duplicate = _sync.signal( + "auth_user_duplicate", + doc="""Called on discovery of at least one duplicate auth user. + + Handlers should have the following signature: + def handler( + current_app: Quart, + original_auth_user_id: ObjectID, + duplicate_auth_user_ids: List[ObjectID], + ) -> None: + ... + """, +) + +keys_rolled = _sync.signal( + "keys_rolled", + doc="""Called after api keys are rolled. + + Handlers should have the following signature: + def handler( + app: Quart, + deactivated_keys: Dict[str, Any], + redis_client: Redis, + ) -> None: + ... 
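+
+    Example (illustrative) of wiring a handler to this blinker signal:
+
+        keys_rolled.connect(handler)
+        keys_rolled.send(app, deactivated_keys=deactivated_keys, redis_client=redis_client)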
+ """, +) diff --git a/src/quart_sqlalchemy/sim/util.py b/src/quart_sqlalchemy/sim/util.py new file mode 100644 index 0000000..463b350 --- /dev/null +++ b/src/quart_sqlalchemy/sim/util.py @@ -0,0 +1,101 @@ +import logging +import numbers + +from hashids import Hashids + + +logger = logging.getLogger(__name__) + + +class CryptographyError(Exception): + pass + + +class DecryptionError(CryptographyError): + pass + + +def one(input_list): + if len(input_list) != 1: + raise ValueError(f"Expected a list of length 1, got {len(input_list)}") + return input_list[0] + + +class ObjectID: + hashids = Hashids(min_length=12) + + def __init__(self, input_value): + if input_value is None: + raise ValueError("ObjectID cannot be None") + elif isinstance(input_value, ObjectID): + self._source_id = input_value._decoded_id + elif isinstance(input_value, str): + self._source_id = input_value + self._decode() + elif isinstance(input_value, numbers.Number): + try: + input_value = int(input_value) + except (ValueError, TypeError): + pass + + self._source_id = input_value + self._encode() + + @property + def _encoded_id(self): + return self._encode() + + @property + def _decoded_id(self): + return self._decode() + + def __eq__(self, other): + if isinstance(other, ObjectID): + return self._decoded_id == other._decoded_id and self._encoded_id == other._encoded_id + elif isinstance(other, int): + return self._decoded_id == other + elif isinstance(other, str): + return self._encoded_id == other + else: + return False + + def __lt__(self, other): + if isinstance(other, ObjectID): + return self._decoded_id < other._decoded_id + return False + + def __hash__(self): + return hash(tuple([self._encoded_id, self._decoded_id])) + + def __str__(self): + return "{encoded_id}".format(encoded_id=self._encoded_id) + + def __int__(self): + return self._decoded_id + + def __repr__(self): + return f"{type(self).__name__}({self._decoded_id})" + + def __json__(self): + return self.__str__() + + def _encode(self): + if isinstance(self._source_id, int): + return self.hashids.encode(self._source_id) + else: + return self._source_id + + def encode(self): + return self._encoded_id + + def _decode(self): + if isinstance(self._source_id, int): + return self._source_id + else: + return self.hashids.decode(self._source_id) + + def decode(self): + return self._decoded_id + + def decode_str(self): + return str(self._decoded_id) diff --git a/src/quart_sqlalchemy/sim/views/__init__.py b/src/quart_sqlalchemy/sim/views/__init__.py new file mode 100644 index 0000000..d110bc2 --- /dev/null +++ b/src/quart_sqlalchemy/sim/views/__init__.py @@ -0,0 +1,18 @@ +from quart import Blueprint +from quart import g + +from .auth_user import api as auth_user_api +from .auth_wallet import api as auth_wallet_api +from .magic_client import api as magic_client_api + + +api = Blueprint("api", __name__, url_prefix="api") + +api.register_blueprint(auth_user_api) +api.register_blueprint(auth_wallet_api) +api.register_blueprint(magic_client_api) + + +@api.before_request +def set_feature_owner(): + g.request_feature_owner = "magic" diff --git a/src/quart_sqlalchemy/sim/views/auth_user.py b/src/quart_sqlalchemy/sim/views/auth_user.py new file mode 100644 index 0000000..ee456e6 --- /dev/null +++ b/src/quart_sqlalchemy/sim/views/auth_user.py @@ -0,0 +1,7 @@ +import logging + +from quart import Blueprint + + +logger = logging.getLogger(__name__) +api = Blueprint("auth_user", __name__, url_prefix="auth_user") diff --git a/src/quart_sqlalchemy/sim/views/auth_wallet.py 
b/src/quart_sqlalchemy/sim/views/auth_wallet.py new file mode 100644 index 0000000..24fc9fc --- /dev/null +++ b/src/quart_sqlalchemy/sim/views/auth_wallet.py @@ -0,0 +1,62 @@ +import logging +import typing as t + +from quart import Blueprint +from quart import g +from quart.utils import run_sync +from quart_schema.validation import validate + +from ..model import WalletManagementType +from ..model import WalletType +from ..schema import BaseSchema +from ..util import ObjectID +from .decorator import authorized_request + + +logger = logging.getLogger(__name__) +api = Blueprint("auth_wallet", __name__, url_prefix="auth_wallet") + + +@api.before_request +def set_feature_owner(): + g.request_feature_owner = "wallet" + + +class WalletSyncRequest(BaseSchema): + public_address: str + encrypted_private_address: str + wallet_type: str + hd_path: t.Optional[str] = None + encrypted_seed_phrase: t.Optional[str] = None + + +class WalletSyncResponse(BaseSchema): + wallet_id: ObjectID + auth_user_id: ObjectID + wallet_type: WalletType + public_address: str + encrypted_private_address: str + + +@authorized_request(authenticate_client=True, authenticate_user=True) +@validate(request=WalletSyncRequest, responses={200: (WalletSyncResponse, None)}) +@api.route("/sync", methods=["POST"]) +async def sync_auth_user_wallet(data: WalletSyncRequest): + try: + with g.bind.Session() as session: + wallet = await run_sync(g.h.AuthWallet(session).sync_auth_wallet)( + g.auth.user.id, + data.public_address, + data.encrypted_private_address, + WalletManagementType.DELEGATED.value, + ) + except RuntimeError: + raise RuntimeError("Unsupported wallet type or network") + + return WalletSyncResponse( + wallet_id=wallet.id, + auth_user_id=wallet.auth_user_id, + wallet_type=wallet.wallet_type, + public_address=wallet.public_address, + encrypted_private_address=wallet.encrypted_private_address, + ) diff --git a/src/quart_sqlalchemy/sim/views/decorator.py b/src/quart_sqlalchemy/sim/views/decorator.py new file mode 100644 index 0000000..357a819 --- /dev/null +++ b/src/quart_sqlalchemy/sim/views/decorator.py @@ -0,0 +1,131 @@ +# import inspect +# import typing as t +# from dataclasses import asdict +# from dataclasses import is_dataclass +from functools import wraps + +# from pydantic import BaseModel +# from pydantic import ValidationError +# from pydantic.dataclasses import dataclass as pydantic_dataclass +# from pydantic.schema import model_schema +from quart import current_app +from quart import g + + +# from quart import request +# from quart import ResponseReturnValue as QuartResponseReturnValue +# from quart_schema.typing import Model +# from quart_schema.typing import PydanticModel +# from quart_schema.typing import ResponseReturnValue +# from quart_schema.validation import _convert_headers +# from quart_schema.validation import DataSource +# from quart_schema.validation import QUART_SCHEMA_RESPONSE_ATTRIBUTE +# from quart_schema.validation import ResponseHeadersValidationError +# from quart_schema.validation import ResponseSchemaValidationError +# from quart_schema.validation import validate_headers +# from quart_schema.validation import validate_querystring +# from quart_schema.validation import validate_request + + +def authorized_request(authenticate_client: bool = False, authenticate_user: bool = False): + def decorator(func): + @wraps(func) + async def wrapper(*args, **kwargs): + if authenticate_client: + if not g.auth.client: + raise RuntimeError("Unable to authenticate client") + 
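+                # Expose the authenticated client's id to the wrapped view as a
+                # keyword argument so the handler does not need to re-resolve it.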
kwargs.update(client_id=g.auth.client.id) + + if authenticate_user: + if not g.auth.user: + raise RuntimeError("Unable to authenticate user") + kwargs.update(user_id=g.auth.user.id) + + return await current_app.ensure_async(func)(*args, **kwargs) + + return wrapper + + return decorator + + +# def validate_response() -> t.Callable: +# def decorator( +# func: t.Callable[..., ResponseReturnValue] +# ) -> t.Callable[..., QuartResponseReturnValue]: +# undecorated = func +# while hasattr(undecorated, "__wrapped__"): +# undecorated = undecorated.__wrapped__ + +# signature = inspect.signature(undecorated) +# derived_schema = signature.return_annotation or dict + +# schemas = getattr(func, QUART_SCHEMA_RESPONSE_ATTRIBUTE, {}) +# schemas[200] = (derived_schema, None) +# setattr(func, QUART_SCHEMA_RESPONSE_ATTRIBUTE, schemas) + +# @wraps(func) +# async def wrapper(*args: t.Any, **kwargs: t.Any) -> t.Any: +# result = await current_app.ensure_async(func)(*args, **kwargs) + +# status_or_headers = None +# headers = None +# if isinstance(result, tuple): +# value, status_or_headers, headers = result + (None,) * (3 - len(result)) +# else: +# value = result + +# status = 200 +# if isinstance(status_or_headers, int): +# status = int(status_or_headers) + +# schemas = getattr(func, QUART_SCHEMA_RESPONSE_ATTRIBUTE, {200: dict}) +# model_class = schemas.get(status, dict) + +# try: +# if isinstance(value, dict): +# model_value = model_class(**value) +# elif type(value) == model_class: +# model_value = value +# elif is_dataclass(value): +# model_value = model_class(**asdict(value)) +# else: +# return result, status, headers + +# except ValidationError as error: +# raise ResponseHeadersValidationError(error) + +# headers_value = headers +# return model_value, status, headers_value + +# return wrapper + +# return decorator + + +# def validate( +# *, +# querystring: t.Optional[Model] = None, +# request: t.Optional[Model] = None, +# request_source: DataSource = DataSource.JSON, +# headers: t.Optional[Model] = None, +# responses: t.Dict[int, t.Tuple[Model, t.Optional[Model]]], +# ) -> t.Callable: +# """Validate the route. + +# This is a shorthand combination of of the validate_querystring, +# validate_request, validate_headers, and validate_response +# decorators. Please see the docstrings for those decorators. 
+# """ + +# def decorator(func: t.Callable) -> t.Callable: +# if querystring is not None: +# func = validate_querystring(querystring)(func) +# if request is not None: +# func = validate_request(request, source=request_source)(func) +# if headers is not None: +# func = validate_headers(headers)(func) +# for status, models in responses.items(): +# func = validate_response(models[0], status, models[1]) +# return func + +# return decorator diff --git a/src/quart_sqlalchemy/sim/views/magic_client.py b/src/quart_sqlalchemy/sim/views/magic_client.py new file mode 100644 index 0000000..a1d28e1 --- /dev/null +++ b/src/quart_sqlalchemy/sim/views/magic_client.py @@ -0,0 +1,7 @@ +import logging + +from quart import Blueprint + + +logger = logging.getLogger(__name__) +api = Blueprint("magic_client", __name__, url_prefix="magic_client") From fc9ae801a8cb5716894ab4ce0f1cd01b786b52ff Mon Sep 17 00:00:00 2001 From: Joe Black Date: Thu, 30 Mar 2023 18:34:08 -0400 Subject: [PATCH 02/11] clean --- src/quart_sqlalchemy/sim/handle.py | 8 +- src/quart_sqlalchemy/sim/legacy.py | 404 ----------------------------- 2 files changed, 1 insertion(+), 411 deletions(-) delete mode 100644 src/quart_sqlalchemy/sim/legacy.py diff --git a/src/quart_sqlalchemy/sim/handle.py b/src/quart_sqlalchemy/sim/handle.py index 7076309..6ed473c 100644 --- a/src/quart_sqlalchemy/sim/handle.py +++ b/src/quart_sqlalchemy/sim/handle.py @@ -37,12 +37,6 @@ class APIKeySet(t.NamedTuple): secret_key: str -def get_session_proxy(): - from .app import db - - return db.bind.Session() - - class HandlerBase: logic: Logic session: Session @@ -51,7 +45,7 @@ class HandlerBase: """ def __init__(self, session: t.Optional[Session], logic: t.Optional[Logic] = None): - self.session = session or get_session_proxy() + self.session = session self.logic = logic or Logic() diff --git a/src/quart_sqlalchemy/sim/legacy.py b/src/quart_sqlalchemy/sim/legacy.py deleted file mode 100644 index 5d4662f..0000000 --- a/src/quart_sqlalchemy/sim/legacy.py +++ /dev/null @@ -1,404 +0,0 @@ -from sqlalchemy import exists -from sqlalchemy.orm import joinedload -from sqlalchemy.sql.expression import func -from sqlalchemy.sql.expression import label - - -def one(input_list): - (item,) = input_list - return item - - -class SQLRepository: - def __init__(self, model): - self._model = model - assert self._model is not None - - @property - def _has_is_active_field(self): - return bool(getattr(self._model, "is_active", None)) - - def get_by_id( - self, - session, - model_id, - allow_inactive=False, - join_list=None, - for_update=False, - ): - """SQL get interface to retrieve by model's id column. - - Args: - session: A database session object. - model_id: The id of the given model to be retrieved. - allow_inactive: Whether to include inactive or not. - join_list: A list of attributes to be joined in the same session for - given model. This is normally the attributes that have - relationship defined and referenced to other models. - for_update: Locks the table for update. - - Returns: - Data retrieved from the database for the model. 
- """ - query = session.query(self._model) - - if join_list: - for to_join in join_list: - query = query.options(joinedload(to_join)) - - if for_update: - query = query.with_for_update() - - row = query.get(model_id) - - if row is None: - return None - - if self._has_is_active_field and not row.is_active and not allow_inactive: - return None - - return row - - def get_by( - self, - session, - filters=None, - join_list=None, - order_by_clause=None, - for_update=False, - offset=None, - limit=None, - ): - """SQL get_by interface to retrieve model instances based on the given - filters. - - Args: - session: A database session object. - filters: A list of filters on the models. - join_list: A list of attributes to be joined in the same session for - given model. This is normally the attributes that have - relationship defined and referenced to other models. - order_by_clause: An order by clause. - for_update: Locks the table for update. - - Returns: - Modified rows. - - TODO(ajen#ch21549|2020-07-21): Filter out `is_active == False` row. This - will not be a trivial change as many places rely on this method and the - handlers/logics sometimes filter by in_active. Sometimes endpoints might - get affected. Proceed with caution. - """ - # If no filter is given, just return. Prevent table scan. - if filters is None: - return None - - query = session.query(self._model).filter(*filters).order_by(order_by_clause) - - if for_update: - query = query.with_for_update() - - if offset: - query = query.offset(offset) - - if limit: - query = query.limit(limit) - - # Prevent loading all the rows. - if limit == 0: - return [] - - if join_list: - for to_join in join_list: - query = query.options(joinedload(to_join)) - - return query.all() - - def count_by( - self, - session, - filters=None, - group_by=None, - distinct_column=None, - ): - """SQL count_by interface to retrieve model instance count based on the given - filters. - - Args: - session: A database session object. - filters (list): Required; a list of filters on the models. - group_by (list): A list of optional group by expressions. - Returns: - A list of counts of rows. - Raises: - ValueError: Returns a value error, when no filters are provided - """ - # Prevent table scans - if filters is None: - raise ValueError("Full table scans are prohibited. Please provide filters") - - select = [label("count", func.count(self._model.id))] - - if distinct_column: - select = [label("count", func.count(func.distinct(distinct_column)))] - - if group_by: - for group in group_by: - select.append(group.expression) - - query = session.query(*select).filter(*filters) - - if group_by: - query = query.group_by(*group_by) - - return query.all() - - def sum_by( - self, - session, - column, - filters=None, - group_by=None, - ): - """SQL sum_by interface to retrieve aggregate sum of column values for given - filters. - - Args: - session: A database session object. - column (sqlalchemy.Column): Required; the column to sum by. - filters (list): Required; a list of filters to apply to the query - group_by (list): A list of optional group by expressions. - - Returns: - A scalar value representing the sum or None. - - Raises: - ValueError: Returns a value error, when no filters are provided - """ - - # Prevent table scans - if filters is None: - raise ValueError("Full table scans are prohibited. 
Please provide filters") - - query = session.query(func.sum(column)).filter(*filters) - - if group_by: - query = query.group_by(*group_by) - - return query.scalar() - - def one(self, session, filters=None, join_list=None, for_update=False): - """SQL filtering interface to retrieve the single model instance matching - filter criteria. - - If there are more than one instances, an exception is raised. - - Args: - session: A database session object. - filters: A list of filters on the models. - for_update: Locks the table for update. - - Returns: - A model instance: If one row is found in the db. - None: If no result is found. - """ - row = self.get_by(session, filters=filters, join_list=join_list, for_update=for_update) - - if not row: - return None - - return one(row) - - def update(self, session, model_id, **kwargs): - """SQL update interface to modify data in a given model instance. - - Args: - session: A database session object. - model_id: The id of the given model to be modified. - kwargs: Any fields defined on the models. - - Returns: - Modified rows. - - Note: - We use session.flush() here to move the changes from the application - to SQL database. However, those changes will be in the pending changes - state. Meaning, it is in the queue to be inserted but yet to be done - so until session.commit() is called, which has been taken care of - in our ``with_db_session`` decorator or ``LogicComponent.begin`` - contextmanager. - """ - modified_row = session.query(self._model).get(model_id) - if modified_row is None: - return None - - for key, value in kwargs.items(): - setattr(modified_row, key, value) - - # Flush out our changes to DB transaction buffer but don't commit it yet. - # This is useful in the case when we want to rollback atomically on multiple - # sql operations in the same transaction which may or may not have - # dependencies. - session.flush() - - return modified_row - - def update_by(self, session, filters=None, **kwargs): - """SQL update_by interface to modify data for a given list of filters. - The filters should be provided so it can narrow down to one row. - - Args: - session: A database session object. - filters: A list of filters on the models. - kwargs: Any fields defined on the models. - - Returns: - Modified row. - - Raises: - sqlalchemy.orm.exc.NoResultFound - when no result is found. - sqlalchemy.orm.exc.MultipleResultsFound - when multiple result is found. - """ - # If no filter is given, just return. Prevent table scan. - if filters is None: - return None - - modified_row = session.query(self._model).filter(*filters).one() - for key, value in kwargs.items(): - setattr(modified_row, key, value) - - # Flush out our changes to DB transaction buffer but don't commit it yet. - # This is useful in the case when we want to rollback atomically on multiple - # sql operations in the same transaction which may or may not have - # dependencies. - session.flush() - - return modified_row - - def delete_one_by(self, session, filters=None, optional=False): - """SQL update_by interface to delete data for a given list of filters. - The filters should be provided so it can narrow down to one row. - - Note: Careful consideration should be had prior to using this function. - Always consider setting rows as inactive instead before choosing to use - this function. - - Args: - session: A database session object. - filters: A list of filters on the models. - optional: Whether deletion is optional; i.e. it's OK for the model not to exist - - Returns: - None. 
- - Raises: - sqlalchemy.orm.exc.NoResultFound - when no result is found and optional is False. - sqlalchemy.orm.exc.MultipleResultsFound - when multiple result is found. - """ - # If no filter is given, just return. Prevent table scan. - if filters is None: - return None - - if optional: - rows = session.query(self._model).filter(*filters).all() - - if not rows: - return None - - row = one(rows) - - else: - row = session.query(self._model).filter(*filters).one() - - session.delete(row) - - # Flush out our changes to DB transaction buffer but don't commit it yet. - # This is useful in the case when we want to rollback atomically on multiple - # sql operations in the same transaction which may or may not have - # dependencies. - session.flush() - - def delete_by_id(self, session, model_id): - return session.query(self._model).get(model_id).delete() - - def add(self, session, **kwargs): - """SQL add interface to insert data to the given model. - - Args: - session: A database session object. - kwargs: Any fields defined on the models. - - Returns: - Newly inserted rows. - - Note: - We use session.flush() here to move the changes from the application - to SQL database. However, those changes will be in the pending changes - state. Meaning, it is in the queue to be inserted but yet to be done - so until session.commit() is called, which has been taken care of - in our ``with_db_session`` decorator or ``LogicComponent.begin`` - contextmanager. - """ - new_row = self._model(**kwargs) - session.add(new_row) - - # Flush out our changes to DB transaction buffer but don't commit it yet. - # This is useful in the case when we want to rollback atomically on multiple - # sql operations in the same transaction which may or may not have - # dependencies. - session.flush() - - return new_row - - def exist(self, session, filters=None): - """SQL exist interface to check if any row exists at all for the given - filters. - - Args: - session: A database session object. - filters: A list of filters on the models. - - Returns: - A boolean. True if any row exists else False. - """ - exist_query = exists() - - for query_filter in filters: - exist_query = exist_query.where(query_filter) - - return session.query(exist_query).scalar() - - def yield_by_chunk(self, session, chunk_size, join_list=None, filters=None): - """This yields a batch of the model objects for the given chunk_size. - - Args: - session: A database session object. - chunk_size (int): The size of the chunk. - filters: A list of filters on the model. - join_list: A list of attributes to be joined in the same session for - given model. This is normally the attributes that have - relationship defined and referenced to other models. - - Returns: - A batch for the given chunk size. 
- """ - query = session.query(self._model) - - if filters is not None: - query = query.filter(*filters) - - if join_list: - for to_join in join_list: - query = query.options(joinedload(to_join)) - - start = 0 - - while True: - stop = start + chunk_size - model_objs = query.slice(start, stop).all() - if not model_objs: - break - - yield model_objs - - start += chunk_size From d660a7fa136082c1649520a3ff94c425e296acab Mon Sep 17 00:00:00 2001 From: Joe Black Date: Thu, 30 Mar 2023 19:16:56 -0400 Subject: [PATCH 03/11] fix circular dependencies --- src/quart_sqlalchemy/sim/app.py | 8 ++--- src/quart_sqlalchemy/sim/handle.py | 15 +++++----- src/quart_sqlalchemy/sim/logic.py | 38 ++++++++---------------- src/quart_sqlalchemy/sim/main.py | 8 ++--- src/quart_sqlalchemy/sim/model.py | 6 ++-- src/quart_sqlalchemy/sim/repo.py | 4 +-- src/quart_sqlalchemy/sim/repo_adapter.py | 2 +- src/quart_sqlalchemy/sim/schema.py | 2 +- src/quart_sqlalchemy/sqla.py | 8 +++++ 9 files changed, 41 insertions(+), 50 deletions(-) diff --git a/src/quart_sqlalchemy/sim/app.py b/src/quart_sqlalchemy/sim/app.py index 91b0ec8..e462fa0 100644 --- a/src/quart_sqlalchemy/sim/app.py +++ b/src/quart_sqlalchemy/sim/app.py @@ -12,10 +12,10 @@ from quart import Response from quart_schema import QuartSchema -from .. import Base -from .. import SQLAlchemyConfig -from ..framework import QuartSQLAlchemy -from .util import ObjectID +from quart_sqlalchemy import Base +from quart_sqlalchemy import SQLAlchemyConfig +from quart_sqlalchemy.framework import QuartSQLAlchemy +from quart_sqlalchemy.sim.util import ObjectID AUTHORIZATION_PATTERN = re.compile(r"Bearer (?P.+)") diff --git a/src/quart_sqlalchemy/sim/handle.py b/src/quart_sqlalchemy/sim/handle.py index 6ed473c..c81b71a 100644 --- a/src/quart_sqlalchemy/sim/handle.py +++ b/src/quart_sqlalchemy/sim/handle.py @@ -5,14 +5,13 @@ from sqlalchemy.orm import Session from quart_sqlalchemy import Bind - -from . import signals -from .logic import LogicComponent as Logic -from .model import AuthUser -from .model import AuthWallet -from .model import EntityType -from .model import WalletType -from .util import ObjectID +from quart_sqlalchemy.sim import signals +from quart_sqlalchemy.sim.logic import LogicComponent as Logic +from quart_sqlalchemy.sim.model import AuthUser +from quart_sqlalchemy.sim.model import AuthWallet +from quart_sqlalchemy.sim.model import EntityType +from quart_sqlalchemy.sim.model import WalletType +from quart_sqlalchemy.sim.util import ObjectID logger = logging.getLogger(__name__) diff --git a/src/quart_sqlalchemy/sim/logic.py b/src/quart_sqlalchemy/sim/logic.py index dee3990..a13397d 100644 --- a/src/quart_sqlalchemy/sim/logic.py +++ b/src/quart_sqlalchemy/sim/logic.py @@ -3,36 +3,22 @@ from datetime import datetime from functools import wraps -from pydantic import BaseModel -from pydantic import Field from sqlalchemy import or_ -from sqlalchemy import ScalarResult from sqlalchemy.orm import contains_eager -from sqlalchemy.orm import DeclarativeBase -from sqlalchemy.orm import selectinload from sqlalchemy.orm import Session from sqlalchemy.sql.expression import func -from sqlalchemy.sql.expression import label - -from quart_sqlalchemy.model import Base -from quart_sqlalchemy.types import ColumnExpr -from quart_sqlalchemy.types import EntityIdT -from quart_sqlalchemy.types import EntityT -from quart_sqlalchemy.types import ORMOption -from quart_sqlalchemy.types import Selectable - -from . 
import signals -from .model import AuthUser as auth_user_model -from .model import AuthWallet as auth_wallet_model -from .model import ConnectInteropStatus -from .model import EntityType -from .model import MagicClient as magic_client_model -from .model import Provenance -from .model import WalletType -from .repo import SQLAlchemyRepository -from .repo_adapter import RepositoryLegacyAdapter -from .util import ObjectID -from .util import one + +from quart_sqlalchemy.sim import signals +from quart_sqlalchemy.sim.model import AuthUser as auth_user_model +from quart_sqlalchemy.sim.model import AuthWallet as auth_wallet_model +from quart_sqlalchemy.sim.model import ConnectInteropStatus +from quart_sqlalchemy.sim.model import EntityType +from quart_sqlalchemy.sim.model import MagicClient as magic_client_model +from quart_sqlalchemy.sim.model import Provenance +from quart_sqlalchemy.sim.model import WalletType +from quart_sqlalchemy.sim.repo_adapter import RepositoryLegacyAdapter +from quart_sqlalchemy.sim.util import ObjectID +from quart_sqlalchemy.sim.util import one logger = logging.getLogger(__name__) diff --git a/src/quart_sqlalchemy/sim/main.py b/src/quart_sqlalchemy/sim/main.py index 7bd73cd..2370b4f 100644 --- a/src/quart_sqlalchemy/sim/main.py +++ b/src/quart_sqlalchemy/sim/main.py @@ -1,9 +1,9 @@ -from .app import app -from .views import api +from quart_sqlalchemy.sim.app import app +from quart_sqlalchemy.sim.views import api -app.register_blueprint(api) +app.register_blueprint(api, url_prefix="/v1") if __name__ == "__main__": - app.run(port=8080) + app.run(port=8081) diff --git a/src/quart_sqlalchemy/sim/model.py b/src/quart_sqlalchemy/sim/model.py index efee04c..7a6566b 100644 --- a/src/quart_sqlalchemy/sim/model.py +++ b/src/quart_sqlalchemy/sim/model.py @@ -81,8 +81,8 @@ class MagicClient(db.Model, SoftDeleteMixin, TimestampMixin): is_signing_modal_enabled: Mapped[bool] = sa.orm.mapped_column(default=False) global_audience_enabled: Mapped[bool] = sa.orm.mapped_column(default=False) - public_api_key: Mapped[str] = sa.orm.mapped_column(default_factory=secrets.token_hex) - secret_api_key: Mapped[str] = sa.orm.mapped_column(default_factory=secrets.token_hex) + public_api_key: Mapped[str] = sa.orm.mapped_column(default=secrets.token_hex) + secret_api_key: Mapped[str] = sa.orm.mapped_column(default=secrets.token_hex) auth_users: Mapped[t.List["AuthUser"]] = sa.orm.relationship( back_populates="magic_client", @@ -144,7 +144,7 @@ def is_magic_connect_user(self): class AuthWallet(db.Model, SoftDeleteMixin, TimestampMixin): - __tablename__ = "auth_user" + __tablename__ = "auth_wallet" id: Mapped[ObjectID] = sa.orm.mapped_column(primary_key=True, autoincrement=True) auth_user_id: Mapped[ObjectID] = sa.orm.mapped_column(sa.ForeignKey("auth_user.id")) diff --git a/src/quart_sqlalchemy/sim/repo.py b/src/quart_sqlalchemy/sim/repo.py index d257b5f..9e41c70 100644 --- a/src/quart_sqlalchemy/sim/repo.py +++ b/src/quart_sqlalchemy/sim/repo.py @@ -10,6 +10,7 @@ import sqlalchemy.orm import sqlalchemy.sql +from quart_sqlalchemy.sim.builder import StatementBuilder from quart_sqlalchemy.types import ColumnExpr from quart_sqlalchemy.types import EntityIdT from quart_sqlalchemy.types import EntityT @@ -17,8 +18,6 @@ from quart_sqlalchemy.types import Selectable from quart_sqlalchemy.types import SessionT -from .builder import StatementBuilder - sa = sqlalchemy @@ -56,7 +55,6 @@ def identity(self) -> t.Type[EntityIdT]: return self.__orig_class__.__args__[1] # type: ignore - class SQLAlchemyRepository( 
AbstractRepository[EntityT, EntityIdT], t.Generic[EntityT, EntityIdT], diff --git a/src/quart_sqlalchemy/sim/repo_adapter.py b/src/quart_sqlalchemy/sim/repo_adapter.py index bbd96b0..89230b0 100644 --- a/src/quart_sqlalchemy/sim/repo_adapter.py +++ b/src/quart_sqlalchemy/sim/repo_adapter.py @@ -12,7 +12,7 @@ from quart_sqlalchemy.types import ORMOption from quart_sqlalchemy.types import Selectable -from .repo import SQLAlchemyRepository +from quart_sqlalchemy.sim.repo import SQLAlchemyRepository class BaseModelSchema(BaseModel): diff --git a/src/quart_sqlalchemy/sim/schema.py b/src/quart_sqlalchemy/sim/schema.py index f47225a..f2311ca 100644 --- a/src/quart_sqlalchemy/sim/schema.py +++ b/src/quart_sqlalchemy/sim/schema.py @@ -2,7 +2,7 @@ from pydantic import BaseModel -from .util import ObjectID +from quart_sqlalchemy.sim.util import ObjectID class BaseSchema(BaseModel): diff --git a/src/quart_sqlalchemy/sqla.py b/src/quart_sqlalchemy/sqla.py index 1df492a..56b8b60 100644 --- a/src/quart_sqlalchemy/sqla.py +++ b/src/quart_sqlalchemy/sqla.py @@ -38,6 +38,14 @@ def initialize(self): class Model(self.config.model_class, sa.orm.DeclarativeBase): pass + type_annotation_map = {} + for base_class in Model.__mro__[::-1]: + if base_class is Model: + continue + base_map = getattr(base_class, "type_annotation_map", {}).copy() + type_annotation_map.update(base_map) + + Model.registry.type_annotation_map.update(type_annotation_map) self.Model = Model self.binds = {} From b68fe9d0d3cd2532fceec04eec22eb2d3661b38b Mon Sep 17 00:00:00 2001 From: Joe Black Date: Fri, 31 Mar 2023 19:03:38 -0400 Subject: [PATCH 04/11] update sim --- src/quart_sqlalchemy/retry.py | 1 + src/quart_sqlalchemy/sim/app.py | 167 +++------ src/quart_sqlalchemy/sim/auth.py | 323 ++++++++++++++++++ src/quart_sqlalchemy/sim/db.py | 100 ++++++ src/quart_sqlalchemy/sim/handle.py | 182 +++++----- src/quart_sqlalchemy/sim/logic.py | 136 ++++---- src/quart_sqlalchemy/sim/main.py | 5 +- src/quart_sqlalchemy/sim/model.py | 12 +- src/quart_sqlalchemy/sim/repo.py | 78 +++-- src/quart_sqlalchemy/sim/repo_adapter.py | 82 +++-- src/quart_sqlalchemy/sim/schema.py | 31 +- src/quart_sqlalchemy/sim/testing.py | 13 + src/quart_sqlalchemy/sim/views/__init__.py | 4 +- src/quart_sqlalchemy/sim/views/auth_wallet.py | 53 ++- src/quart_sqlalchemy/sim/views/decorator.py | 131 ------- .../sim/views/util/__init__.py | 12 + .../sim/views/util/blueprint.py | 102 ++++++ .../sim/views/util/decorator.py | 120 +++++++ 18 files changed, 1066 insertions(+), 486 deletions(-) create mode 100644 src/quart_sqlalchemy/sim/auth.py create mode 100644 src/quart_sqlalchemy/sim/db.py create mode 100644 src/quart_sqlalchemy/sim/testing.py delete mode 100644 src/quart_sqlalchemy/sim/views/decorator.py create mode 100644 src/quart_sqlalchemy/sim/views/util/__init__.py create mode 100644 src/quart_sqlalchemy/sim/views/util/blueprint.py create mode 100644 src/quart_sqlalchemy/sim/views/util/decorator.py diff --git a/src/quart_sqlalchemy/retry.py b/src/quart_sqlalchemy/retry.py index 664dc66..8353dc3 100644 --- a/src/quart_sqlalchemy/retry.py +++ b/src/quart_sqlalchemy/retry.py @@ -99,6 +99,7 @@ async def add_user_post(db, user_id, post_values): import sqlalchemy.exc import sqlalchemy.orm import tenacity +from tenacity import RetryError sa = sqlalchemy diff --git a/src/quart_sqlalchemy/sim/app.py b/src/quart_sqlalchemy/sim/app.py index e462fa0..9f12d22 100644 --- a/src/quart_sqlalchemy/sim/app.py +++ b/src/quart_sqlalchemy/sim/app.py @@ -1,145 +1,88 @@ -import json import logging 
-import re import typing as t +from copy import deepcopy +from functools import wraps -import sqlalchemy as sa -from pydantic import BaseModel from quart import g from quart import Quart from quart import request -from quart import Request from quart import Response +from quart.typing import ResponseReturnValue +from quart_schema import APIKeySecurityScheme +from quart_schema import HttpSecurityScheme from quart_schema import QuartSchema +from werkzeug.utils import import_string -from quart_sqlalchemy import Base -from quart_sqlalchemy import SQLAlchemyConfig -from quart_sqlalchemy.framework import QuartSQLAlchemy -from quart_sqlalchemy.sim.util import ObjectID - -AUTHORIZATION_PATTERN = re.compile(r"Bearer (?P.+)") logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -class MyBase(Base): - type_annotation_map = {ObjectID: sa.Integer} - - -app = Quart(__name__) -db = QuartSQLAlchemy( - SQLAlchemyConfig.parse_obj( - { - "model_class": MyBase, - "binds": { - "default": { - "engine": {"url": "sqlite:///file:mem.db?mode=memory&cache=shared&uri=true"}, - "session": {"expire_on_commit": False}, - }, - "read-replica": { - "engine": {"url": "sqlite:///file:mem.db?mode=memory&cache=shared&uri=true"}, - "session": {"expire_on_commit": False}, - "read_only": True, - }, - "async": { - "engine": { - "url": "sqlite+aiosqlite:///file:mem.db?mode=memory&cache=shared&uri=true" - }, - "session": {"expire_on_commit": False}, - }, - }, - } - ) +BLUEPRINTS = ("quart_sqlalchemy.sim.views.api",) +EXTENSIONS = ( + "quart_sqlalchemy.sim.db.db", + "quart_sqlalchemy.sim.app.schema", + "quart_sqlalchemy.sim.auth.auth", ) -openapi = QuartSchema(app) - - -class RequestAuth(BaseModel): - client: t.Optional[t.Any] = None - user: t.Optional[t.Any] = None - - @property - def has_client(self): - return self.client is not None - - @property - def has_user(self): - return self.user is not None - - @property - def is_anonymous(self): - return all([self.has_client is False, self.has_user is False]) - -def get_request_client(request: Request): - api_key = request.headers.get("X-Public-API-Key") - if not api_key: - return +DEFAULT_CONFIG = { + "QUART_AUTH_SECURITY_SCHEMES": { + "public-api-key": APIKeySecurityScheme(in_="header", name="X-Public-API-Key"), + "session-token-bearer": HttpSecurityScheme(scheme="bearer", bearer_format="opaque"), + }, + "REGISTER_BLUEPRINTS": ["quart_sqlalchemy.sim.views.api"], +} - with g.bind.Session() as session: - try: - magic_client = g.h.MagicClient(session).get_by_public_api_key(api_key) - except ValueError: - return - else: - return magic_client +schema = QuartSchema(security_schemes=DEFAULT_CONFIG["QUART_AUTH_SECURITY_SCHEMES"]) -def get_request_user(request: Request): - auth_header = request.headers.get("Authorization") - if not auth_header: - return - m = AUTHORIZATION_PATTERN.match(auth_header) - if m is None: - raise RuntimeError("invalid authorization header") +def wrap_response(func: t.Callable) -> t.Callable: + @wraps(func) + async def decorator(result: ResponseReturnValue) -> Response: + # import pdb - auth_token = m.group("auth_token") + # pdb.set_trace() + return await func(result) - with g.bind.Session() as session: - try: - auth_user = g.h.AuthUser(session).get_by_session_token(auth_token) - except ValueError: - return - else: - return auth_user + return decorator -@app.before_request -def set_ethereum_network(): - g.request_network = request.headers.get("X-Fortmatic-Network", "GOERLI").upper() +def create_app( + override_config: t.Optional[t.Dict[str, 
t.Any]] = None, + extensions: t.Sequence[str] = EXTENSIONS, + blueprints: t.Sequence[str] = BLUEPRINTS, +): + override_config = override_config or {} + config = deepcopy(DEFAULT_CONFIG) + config.update(override_config) -@app.before_request -def set_bind_handlers_for_request(): - from quart_sqlalchemy.sim.handle import Handlers + app = Quart(__name__) + app.config.from_mapping(config) - g.db = db + for path in extensions: + extension = import_string(path) + extension.init_app(app) - method = request.method - if method in ["GET", "OPTIONS", "TRACE", "HEAD"]: - bind = "read-replica" - else: - bind = "default" + for path in blueprints: + bp = import_string(path) + app.register_blueprint(bp) - g.bind = db.get_bind(bind) - g.h = Handlers(g.bind) + @app.before_request + def set_ethereum_network(): + g.network = request.headers.get("X-Ethereum-Network", "GOERLI").upper() + # app.make_response = wrap_response(app.make_response) -@app.before_request -def set_request_auth(): - g.auth = RequestAuth( - client=get_request_client(request), - user=get_request_user(request), - ) + return app -@app.after_request -async def add_json_response_envelope(response: Response) -> Response: - if response.mimetype != "application/json": - return response - data = await response.get_json() - payload = dict(status="ok", message="", data=data) - response.set_data(json.dumps(payload)) - return response +# @app.after_request +# async def add_json_response_envelope(response: Response) -> Response: +# if response.mimetype != "application/json": +# return response +# data = await response.get_json() +# payload = dict(status="ok", message="", data=data) +# response.set_data(json.dumps(payload)) +# return response diff --git a/src/quart_sqlalchemy/sim/auth.py b/src/quart_sqlalchemy/sim/auth.py new file mode 100644 index 0000000..51c3962 --- /dev/null +++ b/src/quart_sqlalchemy/sim/auth.py @@ -0,0 +1,323 @@ +import logging +import re +import secrets +import typing as t + +import click +import sqlalchemy +import sqlalchemy.orm +import sqlalchemy.orm.exc +from quart import current_app +from quart import g +from quart import Quart +from quart import request +from quart import Request +from quart.cli import AppGroup +from quart.cli import pass_script_info +from quart.cli import ScriptInfo +from quart_schema.extension import QUART_SCHEMA_SECURITY_ATTRIBUTE +from quart_schema.extension import security_scheme +from quart_schema.openapi import APIKeySecurityScheme +from quart_schema.openapi import HttpSecurityScheme +from quart_schema.openapi import SecuritySchemeBase +from sqlalchemy.orm import Session +from werkzeug.exceptions import Forbidden + +from .model import AuthUser +from .model import MagicClient +from .schema import BaseSchema +from .util import ObjectID + + +sa = sqlalchemy + +cli = AppGroup("auth") + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def authorized_request(security_schemes: t.Sequence[t.Dict[str, t.List[t.Any]]]): + def decorator(func): + return security_scheme(security_schemes)(func) + + return decorator + + +class MyRequest(Request): + @property + def ip_addr(self): + return self.remote_addr + + @property + def locale(self): + return self.accept_languages.best_match(["en"]) or "en" + + @property + def redirect_url(self): + return self.args.get("redirect_url") or self.headers.get("x-redirect-url") + + +class ValidatorError(RuntimeError): + pass + + +class SubjectNotFound(ValidatorError): + pass + + +class CredentialNotFound(ValidatorError): + pass + + +class 
Credential(BaseSchema): + scheme: SecuritySchemeBase + value: t.Optional[str] = None + subject: t.Union[MagicClient, AuthUser] + + +class AuthenticationValidator: + name: str + scheme: SecuritySchemeBase + + def extract(self, request: Request) -> str: + ... + + def lookup(self, value: str, session: Session) -> t.Any: + ... + + def authenticate(self, request: Request) -> Credential: + ... + + +class PublicAPIKeyValidator(AuthenticationValidator): + name = "public-api-key" + scheme = APIKeySecurityScheme(in_="header", name="X-Public-API-Key") + + def extract(self, request: Request) -> str: + if self.scheme.in_ == "header": + return request.headers.get(self.scheme.name, None) + elif self.scheme.in_ == "cookie": + return request.cookies.get(self.scheme.name, None) + elif self.scheme.in_ == "query": + return request.args.get(self.scheme.name, None) + else: + raise ValueError(f"No token found for {self.scheme}") + + def lookup(self, value: str, session: Session) -> t.Any: + statement = sa.select(MagicClient).where(MagicClient.public_api_key == value).limit(1) + + try: + result = session.scalars(statement).one() + except sa.orm.exc.NoResultFound: + raise SubjectNotFound(f"No MagicClient found for public_api_key {value}") + + return result + + def authenticate(self, request: Request, session: Session) -> Credential: + value = self.extract(request) + if value is None: + raise CredentialNotFound() + subject = self.lookup(value, session) + return Credential(scheme=self.scheme, value=value, subject=subject) + + +class SessionTokenValidator(AuthenticationValidator): + name = "session-token-bearer" + scheme = HttpSecurityScheme(scheme="bearer", bearer_format="opaque") + + AUTHORIZATION_PATTERN = re.compile(r"Bearer (?P.+)") + + def extract(self, request: Request) -> str: + if self.scheme.scheme != "bearer": + return + + value = request.headers.get("authorization") + m = self.AUTHORIZATION_PATTERN.match(value) + if m is None: + raise ValueError("Bearer token failed validation") + + return m.group("token") + + def lookup(self, value: str, session: Session) -> t.Any: + statement = sa.select(AuthUser).where(AuthUser.current_session_token == value).limit(1) + + try: + result = session.scalars(statement).one() + except sa.orm.exc.NoResultFound: + raise SubjectNotFound(f"No AuthUser found for session_token {value}") + + return result + + def authenticate(self, request: Request, session: Session) -> Credential: + value = self.extract(request) + if value is None: + raise CredentialNotFound() + subject = self.lookup(value, session) + return Credential(scheme=self.scheme, value=value, subject=subject) + + +class RequestAuthenticator: + validators = [PublicAPIKeyValidator(), SessionTokenValidator()] + validator_scheme_map = {v.name: v for v in validators} + + def enforce(self, security_schemes: t.Sequence[t.Dict[str, t.List[t.Any]]], session: Session): + passed, failed = [], [] + for scheme_credential in self.validate_security(security_schemes, session): + if all(scheme_credential.values()): + passed.append(scheme_credential) + else: + failed.append(scheme_credential) + if passed: + return passed + raise Forbidden() + + def validate_security( + self, security_schemes: t.Sequence[t.Dict[str, t.List[t.Any]]], session: Session + ): + if not security_schemes: + return + + for scheme in security_schemes: + scheme_credentials = {} + for name, _ in scheme.items(): + validator = self.validator_scheme_map[name] + credential = None + try: + credential = validator.authenticate(request, session) + except ValidatorError: + pass 
+ except: + logger.exception(f"Unknown error while validating {name}") + raise + finally: + scheme_credentials[name] = credential + yield scheme_credentials + + +# def convert_model_result(func: t.Callable) -> t.Callable: +# @wraps(func) +# async def decorator(result: ResponseReturnValue) -> Response: +# status_or_headers = None +# headers = None +# if isinstance(result, tuple): +# value, status_or_headers, headers = result + (None,) * (3 - len(result)) +# else: +# value = result + +# was_model = False +# if is_dataclass(value): +# dict_or_value = asdict(value) +# was_model = True +# elif isinstance(value, BaseModel): +# dict_or_value = value.dict(by_alias=True) +# was_model = True +# else: +# dict_or_value = value + +# if was_model: +# dict_or_value = camelize(dict_or_value) + +# return await func((dict_or_value, status_or_headers, headers)) + +# return decorator + + +class QuartAuth: + authenticator = RequestAuthenticator() + + def __init__(self, app: t.Optional[Quart] = None): + if app is not None: + self.init_app(app) + + def init_app(self, app: Quart): + app.before_request(self.auth_endpoint_security) + + app.request_class = MyRequest + + self.security_schemes = app.config.get("QUART_AUTH_SECURITY_SCHEMES", {}) + app.cli.add_command(cli) + + def auth_endpoint_security(self): + db = current_app.extensions.get("sqlalchemy") + view_function = current_app.view_functions[request.endpoint] + security_schemes = getattr(view_function, QUART_SCHEMA_SECURITY_ATTRIBUTE, None) + if security_schemes is None: + g.authorized_credentials = {} + + with db.bind.Session() as session: + results = self.authenticator.enforce(security_schemes, session) + authorized_credentials = {} + for result in results: + authorized_credentials.update(result) + g.authorized_credentials = authorized_credentials + + +from .model import EntityType +from .model import MagicClient +from .model import Provenance + + +@cli.command("add-user") +@click.option( + "--email", + type=str, + default="default@none.com", + help="email", +) +@click.option( + "--user-type", + # type=click.Choice(list(EntityType.__members__)), + type=click.Choice(["FORTMATIC", "MAGIC", "CONNECT"]), + default="MAGIC", + help="user type", +) +@click.option( + "--client-id", + type=int, + required=True, + help="client id", +) +@pass_script_info +def add_user(info: ScriptInfo, email: str, user_type: str, client_id: int) -> None: + app = info.load_app() + db = app.extensions.get("sqlalchemy") + + with db.bind.Session() as s: + with s.begin(): + user = AuthUser( + email=email, + user_type=EntityType[user_type].value, + client_id=ObjectID(client_id), + provenance=Provenance.LINK, + current_session_token=secrets.token_hex(16), + ) + s.add(user) + s.flush() + s.refresh(user) + + click.echo(f"Created user {user.id} with session_token: {user.current_session_token}") + + +@cli.command("add-client") +@click.option( + "--name", + type=str, + default="My App", + help="app name", +) +@pass_script_info +def add_client(info: ScriptInfo, name: str) -> None: + app = info.load_app() + db = app.extensions.get("sqlalchemy") + with db.bind.Session() as s: + with s.begin(): + client = MagicClient(app_name=name, public_api_key=secrets.token_hex(16)) + s.add(client) + s.flush() + s.refresh(client) + + click.echo(f"Created client {client.id} with public_api_key: {client.public_api_key}") + + +auth = QuartAuth() diff --git a/src/quart_sqlalchemy/sim/db.py b/src/quart_sqlalchemy/sim/db.py new file mode 100644 index 0000000..e7677f6 --- /dev/null +++ b/src/quart_sqlalchemy/sim/db.py @@ 
-0,0 +1,100 @@ +import click +import sqlalchemy as sa +from quart import g +from quart import request +from quart.cli import AppGroup +from quart.cli import pass_script_info +from quart.cli import ScriptInfo +from sqlalchemy.types import Integer +from sqlalchemy.types import TypeDecorator + +from quart_sqlalchemy import Base +from quart_sqlalchemy import SQLAlchemyConfig +from quart_sqlalchemy.framework import QuartSQLAlchemy +from quart_sqlalchemy.sim.util import ObjectID + + +cli = AppGroup("db-schema") + + +class ObjectIDType(TypeDecorator): + """A custom database column type that converts integer value to our ObjectID. + This allows us to pass around ObjectID type in the application for easy + frontend encoding and database decoding on the integer value. + + Note: all id db column type should use this type for its column. + """ + + impl = Integer + cache_ok = False + + def process_bind_param(self, value, dialect): + """Data going into to the database will be transformed by this method. + See ``ObjectID`` for the design and rational for this. + """ + if value is None: + return None + + return ObjectID(value).decode() + + def process_result_value(self, value, dialect): + """Data going out from the database will be explicitly casted to the + ``ObjectID``. + """ + if value is None: + return None + + return ObjectID(value) + + +class MyBase(Base): + type_annotation_map = {ObjectID: ObjectIDType} + + +class MyQuartSQLAlchemy(QuartSQLAlchemy): + def init_app(self, app): + super().init_app(app) + + @app.before_request + def set_bind(): + if request.method in ["GET", "OPTIONS", "HEAD", "TRACE"]: + g.bind = self.get_bind("read-replica") + else: + g.bind = self.get_bind("default") + + app.cli.add_command(cli) + + +@cli.command("load") +@pass_script_info +def schema_load(info: ScriptInfo) -> None: + app = info.load_app() + db = app.extensions.get("sqlalchemy") + db.create_all() + + click.echo(f"Initialized database schema for {db}") + + +# sqlite:///file:mem.db?mode=memory&cache=shared&uri=true +db = MyQuartSQLAlchemy( + SQLAlchemyConfig.parse_obj( + { + "model_class": MyBase, + "binds": { + "default": { + "engine": {"url": "sqlite:///file:sim.db?cache=shared&uri=true"}, + "session": {"expire_on_commit": False}, + }, + "read-replica": { + "engine": {"url": "sqlite:///file:sim.db?cache=shared&uri=true"}, + "session": {"expire_on_commit": False}, + "read_only": True, + }, + "async": { + "engine": {"url": "sqlite+aiosqlite:///file:sim.db?cache=shared&uri=true"}, + "session": {"expire_on_commit": False}, + }, + }, + } + ) +) diff --git a/src/quart_sqlalchemy/sim/handle.py b/src/quart_sqlalchemy/sim/handle.py index c81b71a..e079da4 100644 --- a/src/quart_sqlalchemy/sim/handle.py +++ b/src/quart_sqlalchemy/sim/handle.py @@ -4,8 +4,6 @@ from sqlalchemy.orm import Session -from quart_sqlalchemy import Bind -from quart_sqlalchemy.sim import signals from quart_sqlalchemy.sim.logic import LogicComponent as Logic from quart_sqlalchemy.sim.model import AuthUser from quart_sqlalchemy.sim.model import AuthWallet @@ -69,6 +67,7 @@ def add( A ``MagicClient``. 
""" magic_clients_count = self.logic.MagicClientAPIUser.count_by_magic_api_user_id( + self.session, magic_api_user_id, ) @@ -83,7 +82,7 @@ def add( ) def get_by_public_api_key(self, public_api_key): - return self.logic.MagicClientAPIKey.get_by_public_api_key(public_api_key) + return self.logic.MagicClientAPIKey.get_by_public_api_key(self.session, public_api_key) def add_client( self, @@ -94,40 +93,43 @@ def add_client( ): live_api_key = APIKeySet(public_key="xxx", secret_key="yyy") - with self.logic.begin(ro=False) as session: - magic_client = self.logic.MagicClient._add( - session, - app_name=app_name, - ) - # self.logic.MagicClientAPIKey._add( - # session, - # magic_client.id, - # live_api_key_pair=live_api_key, - # ) - # self.logic.MagicClientAPIUser._add( - # session, - # magic_api_user_id, - # magic_client.id, - # ) - - # self.logic.MagicClientAuthMethods._add( - # session, - # magic_client_id=magic_client.id, - # is_magic_connect_enabled=is_magic_connect_enabled, - # is_metamask_wallet_enabled=(True if is_magic_connect_enabled else False), - # is_wallet_connect_enabled=(True if is_magic_connect_enabled else False), - # is_coinbase_wallet_enabled=(True if is_magic_connect_enabled else False), - # ) - - # self.logic.MagicClientTeam._add(session, magic_client.id, magic_team_id) - - return magic_client, live_api_key + # with self.logic.begin(ro=False) as session: + return self.logic.MagicClient._add( + self.session, + app_name=app_name, + ) + + # self.logic.MagicClientAPIKey._add( + # session, + # magic_client.id, + # live_api_key_pair=live_api_key, + # ) + # self.logic.MagicClientAPIUser._add( + # session, + # magic_api_user_id, + # magic_client.id, + # ) + + # self.logic.MagicClientAuthMethods._add( + # session, + # magic_client_id=magic_client.id, + # is_magic_connect_enabled=is_magic_connect_enabled, + # is_metamask_wallet_enabled=(True if is_magic_connect_enabled else False), + # is_wallet_connect_enabled=(True if is_magic_connect_enabled else False), + # is_coinbase_wallet_enabled=(True if is_magic_connect_enabled else False), + # ) + + # self.logic.MagicClientTeam._add(session, magic_client.id, magic_team_id) + + # return magic_client, live_api_key def get_magic_api_user_id_by_client_id(self, magic_client_id): - return self.logic.MagicClient.get_magic_api_user_id_by_client_id(magic_client_id) + return self.logic.MagicClient.get_magic_api_user_id_by_client_id( + self.session, magic_client_id + ) def get_by_id(self, magic_client_id): - return self.logic.MagicClient.get_by_id(magic_client_id) + return self.logic.MagicClient.get_by_id(self.session, magic_client_id) def update_app_name_by_id(self, magic_client_id, app_name): """ @@ -139,7 +141,9 @@ def update_app_name_by_id(self, magic_client_id, app_name): None if `magic_client_id` doesn't exist in the db app_name if update was successful """ - client = self.logic.MagicClient.update_by_id(magic_client_id, app_name=app_name) + client = self.logic.MagicClient.update_by_id( + self.session, magic_client_id, app_name=app_name + ) if not client: return None @@ -147,7 +151,7 @@ def update_app_name_by_id(self, magic_client_id, app_name): return client.app_name def update_by_id(self, magic_client_id, **kwargs): - client = self.logic.MagicClient.update_by_id(magic_client_id, **kwargs) + client = self.logic.MagicClient.update_by_id(self.session, magic_client_id, **kwargs) return client @@ -159,7 +163,7 @@ def set_inactive_by_id(self, magic_client_id): Returns: None """ - self.logic.MagicClient.update_by_id(magic_client_id, is_active=False) + 
self.logic.MagicClient.update_by_id(self.session, magic_client_id, is_active=False) def get_users_for_client( self, @@ -171,7 +175,7 @@ def get_users_for_client( """ Returns emails and signup timestamps for all auth users belonging to a given client """ - auth_user_handler = AuthUserHandler() + auth_user_handler = AuthUserHandler(session=self.session) product_type = get_product_type_by_client_id(magic_client_id) auth_users = auth_user_handler.get_by_client_id_and_user_type( magic_client_id, @@ -214,7 +218,7 @@ def get_users_for_client_v2( Returns emails, signup timestamps, provenance and MFA enablement for all auth users belonging to a given client. """ - auth_user_handler = AuthUserHandler() + auth_user_handler = AuthUserHandler(session=self.session) product_type = get_product_type_by_client_id(magic_client_id) auth_users = auth_user_handler.get_by_client_id_and_user_type( magic_client_id, @@ -260,7 +264,7 @@ def __init__(self, *args, auth_user_mfa_handler=None, **kwargs): # self.auth_user_mfa_handler = auth_user_mfa_handler or AuthUserMfaHandler() def get_by_session_token(self, session_token): - return self.logic.AuthUser.get_by_session_token(session_token) + return self.logic.AuthUser.get_by_session_token(self.session, session_token) def get_or_create_by_email_and_client_id( self, @@ -269,6 +273,7 @@ def get_or_create_by_email_and_client_id( user_type=EntityType.MAGIC.value, ): auth_user = self.logic.AuthUser.get_by_email_and_client_id( + self.session, email, client_id, user_type=user_type, @@ -292,6 +297,7 @@ def get_or_create_by_email_and_client_id( # raise EnhancedEmailValidation(error_message=str(e)) from e auth_user = self.logic.AuthUser.add_by_email_and_client_id( + self.session, client_id, email=email, user_type=user_type, @@ -300,7 +306,7 @@ def get_or_create_by_email_and_client_id( def get_by_id_and_validate_exists(self, auth_user_id): """This function helps formalize how a non-existent auth user should be handled.""" - auth_user = self.logic.AuthUser.get_by_id(auth_user_id) + auth_user = self.logic.AuthUser.get_by_id(self.session, auth_user_id) if auth_user is None: raise RuntimeError('resource_name="auth_user"') return auth_user @@ -314,18 +320,18 @@ def create_verified_user( email, user_type=EntityType.FORTMATIC.value, ): - with self.logic.begin(ro=False) as session: - auid = self.logic.AuthUser._add_by_email_and_client_id( - session, - client_id, - email, - user_type=user_type, - ).id - auth_user = self.logic.AuthUser._update_by_id( - session, - auid, - date_verified=datetime.utcnow(), - ) + # with self.logic.begin(ro=False) as session: + auid = self.logic.AuthUser._add_by_email_and_client_id( + self.session, + client_id, + email, + user_type=user_type, + ).id + auth_user = self.logic.AuthUser._update_by_id( + self.session, + auid, + date_verified=datetime.utcnow(), + ) return auth_user @@ -339,7 +345,7 @@ def create_verified_user( def get_by_id(self, auth_user_id, load_mfa_methods=False) -> AuthUser: # join_list = ["mfa_methods"] if load_mfa_methods else None - return self.logic.AuthUser.get_by_id(auth_user_id) + return self.logic.AuthUser.get_by_id(self.session, auth_user_id) def get_by_client_id_and_user_type( self, @@ -350,12 +356,14 @@ def get_by_client_id_and_user_type( ): if user_type == EntityType.CONNECT.value: return self.logic.AuthUser.get_by_client_id_for_connect( + self.session, client_id, offset=offset, limit=limit, ) else: return self.logic.AuthUser.get_by_client_id_and_user_type( + self.session, client_id, user_type, offset=offset, @@ -370,6 +378,7 @@ def 
get_by_client_ids_and_user_type( limit=None, ): return self.logic.AuthUser.get_by_client_ids_and_user_type( + self.session, client_ids, user_type, offset=offset, @@ -379,29 +388,33 @@ def get_by_client_ids_and_user_type( def get_user_count_by_client_id_and_user_type(self, client_id, user_type): if user_type == EntityType.CONNECT.value: return self.logic.AuthUser.get_user_count_by_client_id_for_connect( + self.session, client_id, ) else: return self.logic.AuthUser.get_user_count_by_client_id_and_user_type( + self.session, client_id, user_type, ) def exist_by_email_client_id_and_user_type(self, email, client_id, user_type): return self.logic.AuthUser.exist_by_email_and_client_id( + self.session, email, client_id, user_type=user_type, ) def update_email_by_id(self, model_id, email): - return self.logic.AuthUser.update_by_id(model_id, email=email) + return self.logic.AuthUser.update_by_id(self.session, model_id, email=email) def update_phone_number_by_id(self, model_id, phone_number): - return self.logic.AuthUser.update_by_id(model_id, phone_number=phone_number) + return self.logic.AuthUser.update_by_id(self.session, model_id, phone_number=phone_number) def get_by_email_client_id_and_user_type(self, email, client_id, user_type): return self.logic.AuthUser.get_by_email_and_client_id( + self.session, email, client_id, user_type, @@ -409,12 +422,14 @@ def get_by_email_client_id_and_user_type(self, email, client_id, user_type): def mark_date_verified_by_id(self, model_id): return self.logic.AuthUser.update_by_id( + self.session, model_id, date_verified=datetime.utcnow(), ) def set_role_by_email_magic_client_id(self, email, magic_client_id, role): auth_user = self.logic.AuthUser.get_by_email_and_client_id( + self.session, email, magic_client_id, EntityType.MAGIC.value, @@ -422,12 +437,13 @@ def set_role_by_email_magic_client_id(self, email, magic_client_id, role): if not auth_user: auth_user = self.logic.AuthUser.add_by_email_and_client_id( + self.session, magic_client_id, email, user_type=EntityType.MAGIC.value, ) - return self.logic.AuthUser.update_by_id(auth_user.id, **{role: True}) + return self.logic.AuthUser.update_by_id(self.session, auth_user.id, **{role: True}) def search_by_client_id_and_substring( self, @@ -443,6 +459,7 @@ def search_by_client_id_and_substring( raise InvalidSubstringError() auth_users = self.logic.AuthUser.get_by_client_id_with_substring_search( + self.session, client_id, substring, offset=offset, @@ -469,7 +486,7 @@ def is_magic_connect_enabled(self, auth_user_id=None, auth_user=None): return auth_user.user_type == EntityType.CONNECT.value def mark_as_inactive(self, auth_user_id): - self.logic.AuthUser.update_by_id(auth_user_id, is_active=False) + self.logic.AuthUser.update_by_id(self.session, auth_user_id, is_active=False) def get_by_email_and_wallet_type_for_interop(self, email, wallet_type, network): """ @@ -488,22 +505,22 @@ def get_magic_connect_auth_user(self, auth_user_id): return auth_user -@signals.auth_user_duplicate.connect -def handle_duplicate_auth_users( - current_app, - original_auth_user_id, - duplicate_auth_user_ids, - auth_user_handler: t.Optional[AuthUserHandler] = None, -) -> None: - logger.info(f"{len(duplicate_auth_user_ids)} dupe(s) found for {original_auth_user_id}") +# @signals.auth_user_duplicate.connect +# def handle_duplicate_auth_users( +# current_app, +# original_auth_user_id, +# duplicate_auth_user_ids, +# auth_user_handler: t.Optional[AuthUserHandler] = None, +# ) -> None: +# logger.info(f"{len(duplicate_auth_user_ids)} dupe(s) found for 
{original_auth_user_id}") - auth_user_handler = auth_user_handler or AuthUserHandler() +# auth_user_handler = auth_user_handler or AuthUserHandler() - for dupe_id in duplicate_auth_user_ids: - logger.info( - f"marking auth_user_id {dupe_id} as inactive, in favor of original {original_auth_user_id}", - ) - auth_user_handler.mark_as_inactive(dupe_id) +# for dupe_id in duplicate_auth_user_ids: +# logger.info( +# f"marking auth_user_id {dupe_id} as inactive, in favor of original {original_auth_user_id}", +# ) +# auth_user_handler.mark_as_inactive(dupe_id) class AuthWalletHandler(HandlerBase): @@ -515,10 +532,10 @@ def __init__(self, network, *args, wallet_type=WalletType.ETH, **kwargs): self.wallet_type = wallet_type def get_by_id(self, model_id): - return self.logic.AuthWallet.get_by_id(model_id) + return self.logic.AuthWallet.get_by_id(self.session, model_id) def get_by_public_address(self, public_address): - return self.logic.AuthWallet.get_by_public_address(public_address) + return self.logic.AuthWallet.get_by_public_address(self.session, public_address) def get_by_auth_user_id( self, @@ -528,6 +545,7 @@ def get_by_auth_user_id( **kwargs, ) -> t.List[AuthWallet]: auth_user = self.logic.AuthUser.get_by_id( + self.session, auth_user_id, join_list=["linked_primary_auth_user"], ) @@ -543,6 +561,7 @@ def get_by_auth_user_id( auth_user = auth_user.linked_primary_auth_user return self.logic.AuthWallet.get_by_auth_user_id( + self.session, auth_user.id, network=network, wallet_type=wallet_type, @@ -557,12 +576,14 @@ def sync_auth_wallet( wallet_management_type, ): existing_wallet = self.logic.AuthWallet.get_by_auth_user_id( + self.session, auth_user_id, ) if existing_wallet: raise RuntimeError("WalletExistsForNetworkAndWalletType") return self.logic.AuthWallet.add( + self.session, public_address, encrypted_private_address, self.wallet_type, @@ -570,18 +591,3 @@ def sync_auth_wallet( management_type=wallet_management_type, auth_user_id=auth_user_id, ) - - -class Handlers: - bind: Bind - - def __init__(self, bind: Bind): - self.bind = bind - - def __getattr__(self, name): - handlers = { - cls.__name__.replace("Handler", ""): cls for cls in HandlerBase.__subclasses__() - } - if name in handlers: - return handlers[name] - raise AttributeError() diff --git a/src/quart_sqlalchemy/sim/logic.py b/src/quart_sqlalchemy/sim/logic.py index a13397d..ceb6ffd 100644 --- a/src/quart_sqlalchemy/sim/logic.py +++ b/src/quart_sqlalchemy/sim/logic.py @@ -86,28 +86,11 @@ def __getattr__(self, logic_name): ) -def with_db_session( - ro: t.Optional[bool] = None, - is_stale_tolerant: t.Optional[bool] = None, -): - """Stub decorator to ease transition with legacy code""" - - def wrapper(func): - @wraps(func) - def inner_wrapper(self, *args, **kwargs): - session = None - return func(self, None, *args, **kwargs) - - return inner_wrapper - - return wrapper - - class MagicClient(LogicComponent): - def __init__(self, session: Session): + def __init__(self): # self._repository = SQLAlchemyRepository[magic_client_model, ObjectID](session) - self._repository = RepositoryLegacyAdapter(magic_client_model, ObjectID, session) + self._repository = RepositoryLegacyAdapter(magic_client_model, ObjectID) def _add(self, session, app_name=None): return self._repository.add( @@ -115,9 +98,10 @@ def _add(self, session, app_name=None): app_name=app_name, ) - add = with_db_session(ro=False)(_add) + # add = with_db_session(ro=False)(_add) + add = _add - @with_db_session(ro=True) + # @with_db_session(ro=True) def get_by_id( self, session, @@ -132,7 
+116,7 @@ def get_by_id( join_list=join_list, ) - @with_db_session(ro=True) + # @with_db_session(ro=True) def get_by_public_api_key( self, session, @@ -163,17 +147,17 @@ def get_by_public_api_key( # return client.magic_client_api_user.magic_api_user_id - @with_db_session(ro=False) + # @with_db_session(ro=False) def update_by_id(self, session, model_id, **update_params): modified_row = self._repository.update(session, model_id, **update_params) session.refresh(modified_row) return modified_row - @with_db_session(ro=True) + # @with_db_session(ro=True) def yield_all_clients_by_chunk(self, session, chunk_size): yield from self._repository.yield_by_chunk(session, chunk_size) - @with_db_session(ro=True) + # @with_db_session(ro=True) def yield_by_chunk(self, session, chunk_size, filters=None, join_list=None): yield from self._repository.yield_by_chunk( session, @@ -200,11 +184,11 @@ class MissingPhoneNumber(Exception): class AuthUser(LogicComponent): - def __init__(self, session: Session): + def __init__(self): # self._repository = SQLRepository(auth_user_model) - self._repository = RepositoryLegacyAdapter(magic_client_model, ObjectID, session) + self._repository = RepositoryLegacyAdapter(magic_client_model, ObjectID) - @with_db_session(ro=True) + # @with_db_session(ro=True) def get_by_session_token( self, session, @@ -254,9 +238,10 @@ def _get_or_add_by_phone_number_and_client_id( return row - get_or_add_by_phone_number_and_client_id = with_db_session(ro=False)( - _get_or_add_by_phone_number_and_client_id, - ) + # get_or_add_by_phone_number_and_client_id = with_db_session(ro=False)( + # _get_or_add_by_phone_number_and_client_id, + # ) + get_or_add_by_phone_number_and_client_id = _get_or_add_by_phone_number_and_client_id def _add_by_email_and_client_id( self, @@ -299,7 +284,8 @@ def _add_by_email_and_client_id( return row - add_by_email_and_client_id = with_db_session(ro=False)(_add_by_email_and_client_id) + # add_by_email_and_client_id = with_db_session(ro=False)(_add_by_email_and_client_id) + add_by_email_and_client_id = _add_by_email_and_client_id def _add_by_client_id( self, @@ -327,7 +313,8 @@ def _add_by_client_id( return row - add_by_client_id = with_db_session(ro=False)(_add_by_client_id) + # add_by_client_id = with_db_session(ro=False)(_add_by_client_id) + add_by_client_id = _add_by_client_id def _get_by_active_identifier_and_client_id( self, @@ -364,7 +351,7 @@ def _get_by_active_identifier_and_client_id( return original - @with_db_session(ro=True) + # @with_db_session(ro=True) def get_by_email_and_client_id( self, session, @@ -398,9 +385,10 @@ def _get_by_phone_number_and_client_id( user_type=user_type, ) - get_by_phone_number_and_client_id = with_db_session(ro=True)( - _get_by_phone_number_and_client_id, - ) + # get_by_phone_number_and_client_id = with_db_session(ro=True)( + # _get_by_phone_number_and_client_id, + # ) + get_by_phone_number_and_client_id = _get_by_phone_number_and_client_id def _exist_by_email_and_client_id( self, @@ -420,7 +408,8 @@ def _exist_by_email_and_client_id( ), ) - exist_by_email_and_client_id = with_db_session(ro=True)(_exist_by_email_and_client_id) + # exist_by_email_and_client_id = with_db_session(ro=True)(_exist_by_email_and_client_id) + exist_by_email_and_client_id = _exist_by_email_and_client_id def _get_by_id(self, session, model_id, join_list=None, for_update=False) -> auth_user_model: return self._repository.get_by_id( @@ -430,7 +419,8 @@ def _get_by_id(self, session, model_id, join_list=None, for_update=False) -> aut for_update=for_update, ) - 
get_by_id = with_db_session(ro=True)(_get_by_id) + get_by_id = _get_by_id + # get_by_id = with_db_session(ro=True)(_get_by_id) def _update_by_id(self, session, auth_user_id, **kwargs): modified_user = self._repository.update(session, auth_user_id, **kwargs) @@ -440,9 +430,10 @@ def _update_by_id(self, session, auth_user_id, **kwargs): return modified_user - update_by_id = with_db_session(ro=False)(_update_by_id) + # update_by_id = with_db_session(ro=False)(_update_by_id) + update_by_id = _update_by_id - @with_db_session(ro=True) + # @with_db_session(ro=True) def get_user_count_by_client_id_and_user_type(self, session, client_id, user_type): query = ( session.query(auth_user_model) @@ -469,9 +460,10 @@ def _get_by_client_id_and_global_auth_user(self, session, client_id, global_auth ], ) - get_by_client_id_and_global_auth_user = with_db_session(ro=True)( - _get_by_client_id_and_global_auth_user, - ) + # get_by_client_id_and_global_auth_user = with_db_session(ro=True)( + # _get_by_client_id_and_global_auth_user, + # ) + get_by_client_id_and_global_auth_user = _get_by_client_id_and_global_auth_user # @with_db_session(ro=True) # def get_by_client_id_for_connect( @@ -542,7 +534,7 @@ def _get_by_client_id_and_global_auth_user(self, session, client_id, global_auth # return session.execute(query).scalar() - @with_db_session(ro=True) + # @with_db_session(ro=True) def get_by_client_id_and_user_type( self, session, @@ -583,9 +575,10 @@ def _get_by_client_ids_and_user_type( order_by_clause=auth_user_model.id.desc(), ) - get_by_client_ids_and_user_type = with_db_session(ro=True)( - _get_by_client_ids_and_user_type, - ) + # get_by_client_ids_and_user_type = with_db_session(ro=True)( + # _get_by_client_ids_and_user_type, + # ) + get_by_client_ids_and_user_type = _get_by_client_ids_and_user_type def _get_by_client_id_with_substring_search( self, @@ -617,11 +610,12 @@ def _get_by_client_id_with_substring_search( join_list=join_list, ) - get_by_client_id_with_substring_search = with_db_session(ro=True)( - _get_by_client_id_with_substring_search, - ) + # get_by_client_id_with_substring_search = with_db_session(ro=True)( + # _get_by_client_id_with_substring_search, + # ) + get_by_client_id_with_substring_search = _get_by_client_id_with_substring_search - @with_db_session(ro=True) + # @with_db_session(ro=True) def yield_by_chunk(self, session, chunk_size, filters=None, join_list=None): yield from self._repository.yield_by_chunk( session, @@ -630,7 +624,7 @@ def yield_by_chunk(self, session, chunk_size, filters=None, join_list=None): join_list=join_list, ) - @with_db_session(ro=True) + # @with_db_session(ro=True) def get_by_emails_and_client_id( self, session, @@ -663,12 +657,14 @@ def _get_by_email( join_list=join_list, ) - get_by_email = with_db_session(ro=True)(_get_by_email) + # get_by_email = with_db_session(ro=True)(_get_by_email) + get_by_email = _get_by_email def _add(self, session, **kwargs) -> ObjectID: return self._repository.add(session, **kwargs).id - add = with_db_session(ro=False)(_add) + # add = with_db_session(ro=False)(_add) + add = _add def _get_by_email_for_interop( self, @@ -676,7 +672,7 @@ def _get_by_email_for_interop( email: str, wallet_type: WalletType, network: str, - ) -> List[auth_user_model]: + ) -> t.List[auth_user_model]: """ Custom method for searching for users eligible for interop. Unfortunately, this can't be done with the current abstractions in our sql_repository, so this is a one-off bespoke method. 
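The hunks in this file drop the legacy with_db_session decorator in favor of plain aliases (for example get_by_email_for_interop = _get_by_email_for_interop), so each logic method now receives its Session explicitly from the caller instead of having one injected. A minimal sketch of that calling pattern, assuming the handler owns the session; the class and method names below are illustrative rather than taken from the patch:

    import sqlalchemy as sa
    from sqlalchemy.orm import Session

    from quart_sqlalchemy.sim.model import AuthUser


    class AuthUserLogic:
        """Stateless logic component: no Session captured in __init__."""

        def get_by_email(self, session: Session, email: str):
            # The caller owns the transaction; this method only issues the query.
            return session.scalars(
                sa.select(AuthUser).where(AuthUser.email == email)
            ).one_or_none()


    # A handler opens the session once and threads it through every call:
    #     with bind.Session() as session:
    #         user = AuthUserLogic().get_by_email(session, "user@example.com")
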
@@ -720,9 +716,10 @@ def _get_by_email_for_interop( return query.all() - get_by_email_for_interop = with_db_session(ro=True)( - _get_by_email_for_interop, - ) + # get_by_email_for_interop = with_db_session(ro=True)( + # _get_by_email_for_interop, + # ) + get_by_email_for_interop = _get_by_email_for_interop def _get_linked_users(self, session, primary_auth_user_id, join_list, no_op=False): # TODO(magic-ravi#67899|2022-12-30): Re-enable account linked users for interop. Remove no_op flag. @@ -739,9 +736,10 @@ def _get_linked_users(self, session, primary_auth_user_id, join_list, no_op=Fals join_list=join_list, ) - get_linked_users = with_db_session(ro=True)(_get_linked_users) + # get_linked_users = with_db_session(ro=True)(_get_linked_users) + get_linked_users = _get_linked_users - @with_db_session(ro=True) + # @with_db_session(ro=True) def get_by_phone_number(self, session, phone_number): return self._repository.get_by( session, @@ -752,9 +750,9 @@ def get_by_phone_number(self, session, phone_number): class AuthWallet(LogicComponent): - def __init__(self, session: Session): + def __init__(self): # self._repository = SQLAlchemyRepository[magic_client_model, ObjectID](session) - self._repository = RepositoryLegacyAdapter(auth_wallet_model, ObjectID, session) + self._repository = RepositoryLegacyAdapter(auth_wallet_model, ObjectID) def _add( self, @@ -778,9 +776,10 @@ def _add( return new_row - add = with_db_session(ro=False)(_add) + # add = with_db_session(ro=False)(_add) + add = _add - @with_db_session(ro=True) + # @with_db_session(ro=True) def get_by_id(self, session, model_id, allow_inactive=False, join_list=None): return self._repository.get_by_id( session, @@ -789,7 +788,7 @@ def get_by_id(self, session, model_id, allow_inactive=False, join_list=None): join_list=join_list, ) - @with_db_session(ro=True) + # @with_db_session(ro=True) def get_by_public_address(self, session, public_address, network=None, is_active=True): """Public address is unique in our system. In any case, we should only find one row for the given public address. 
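The docstring above relies on public_address being unique, and the method collapses the result with one(row) from quart_sqlalchemy.sim.util. That helper's implementation is not part of this patch; a plausible reading of its contract, for illustration only:

    import typing as t

    T = t.TypeVar("T")


    def one(rows: t.Sequence[T]) -> T:
        # Hypothetical sketch: collapse a result set that must contain exactly
        # one row (public_address is unique), failing loudly otherwise.
        if len(rows) != 1:
            raise ValueError(f"expected exactly 1 row, got {len(rows)}")
        return rows[0]
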
@@ -819,7 +818,7 @@ def get_by_public_address(self, session, public_address, network=None, is_active return one(row) - @with_db_session(ro=True) + # @with_db_session(ro=True) def get_by_auth_user_id( self, session, @@ -866,4 +865,5 @@ def get_by_auth_user_id( def _update_by_id(self, session, model_id, **kwargs): self._repository.update(session, model_id, **kwargs) - update_by_id = with_db_session(ro=False)(_update_by_id) + # update_by_id = with_db_session(ro=False)(_update_by_id) + update_by_id = _update_by_id diff --git a/src/quart_sqlalchemy/sim/main.py b/src/quart_sqlalchemy/sim/main.py index 2370b4f..23d1e0c 100644 --- a/src/quart_sqlalchemy/sim/main.py +++ b/src/quart_sqlalchemy/sim/main.py @@ -1,8 +1,7 @@ -from quart_sqlalchemy.sim.app import app -from quart_sqlalchemy.sim.views import api +from quart_sqlalchemy.sim.app import create_app -app.register_blueprint(api, url_prefix="/v1") +app = create_app() if __name__ == "__main__": diff --git a/src/quart_sqlalchemy/sim/model.py b/src/quart_sqlalchemy/sim/model.py index 7a6566b..b0199ab 100644 --- a/src/quart_sqlalchemy/sim/model.py +++ b/src/quart_sqlalchemy/sim/model.py @@ -11,7 +11,7 @@ from quart_sqlalchemy.model import SoftDeleteMixin from quart_sqlalchemy.model import TimestampMixin -from quart_sqlalchemy.sim.app import db +from quart_sqlalchemy.sim.db import db from quart_sqlalchemy.sim.util import ObjectID @@ -100,8 +100,10 @@ class AuthUser(db.Model, SoftDeleteMixin, TimestampMixin): date_verified: Mapped[t.Optional[datetime]] provenance: Mapped[t.Optional[Provenance]] is_admin: Mapped[bool] = sa.orm.mapped_column(default=False) - client_id: Mapped[ObjectID] - linked_primary_auth_user_id: Mapped[t.Optional[ObjectID]] + client_id: Mapped[ObjectID] = sa.orm.mapped_column(sa.ForeignKey("magic_client.id")) + linked_primary_auth_user_id: Mapped[t.Optional[ObjectID]] = sa.orm.mapped_column( + sa.ForeignKey("auth_user.id"), default=None + ) global_auth_user_id: Mapped[t.Optional[ObjectID]] delegated_user_id: Mapped[t.Optional[str]] @@ -110,7 +112,7 @@ class AuthUser(db.Model, SoftDeleteMixin, TimestampMixin): current_session_token: Mapped[t.Optional[str]] magic_client: Mapped[MagicClient] = sa.orm.relationship( - back_populates="auth_user", + back_populates="auth_users", uselist=False, ) linked_primary_auth_user = sa.orm.relationship( @@ -158,6 +160,6 @@ class AuthWallet(db.Model, SoftDeleteMixin, TimestampMixin): is_exported: Mapped[bool] = sa.orm.mapped_column(default=False) auth_user: Mapped[AuthUser] = sa.orm.relationship( - back_populates="auth_wallets", + back_populates="wallets", uselist=False, ) diff --git a/src/quart_sqlalchemy/sim/repo.py b/src/quart_sqlalchemy/sim/repo.py index 9e41c70..5844b9d 100644 --- a/src/quart_sqlalchemy/sim/repo.py +++ b/src/quart_sqlalchemy/sim/repo.py @@ -2,7 +2,6 @@ import typing as t from abc import ABCMeta -from abc import abstractmethod import sqlalchemy import sqlalchemy.event @@ -19,6 +18,9 @@ from quart_sqlalchemy.types import SessionT +# from abc import abstractmethod + + sa = sqlalchemy @@ -81,47 +83,54 @@ class SQLAlchemyRepository( """ - session: sa.orm.Session + # session: sa.orm.Session builder: StatementBuilder - def __init__(self, session: sa.orm.Session, **kwargs): + def __init__(self, **kwargs): super().__init__(**kwargs) - self.session = session + # self.session = session self.builder = StatementBuilder(None) - def insert(self, values: t.Dict[str, t.Any]) -> EntityT: + def insert(self, session: sa.orm.Session, values: t.Dict[str, t.Any]) -> EntityT: """Insert a new model into the 
database.""" new = self.model(**values) - self.session.add(new) - self.session.flush() - self.session.refresh(new) + session.add(new) + session.flush() + session.refresh(new) return new - def update(self, id_: EntityIdT, values: t.Dict[str, t.Any]) -> EntityT: + def update( + self, session: sa.orm.Session, id_: EntityIdT, values: t.Dict[str, t.Any] + ) -> EntityT: """Update existing model with new values.""" - obj = self.session.get(self.model, id_) + obj = session.get(self.model, id_) if obj is None: raise ValueError(f"Object with id {id_} not found") for field, value in values.items(): if getattr(obj, field) != value: setattr(obj, field, value) - self.session.flush() - self.session.refresh(obj) + session.flush() + session.refresh(obj) return obj def merge( - self, id_: EntityIdT, values: t.Dict[str, t.Any], for_update: bool = False + self, + session: sa.orm.Session, + id_: EntityIdT, + values: t.Dict[str, t.Any], + for_update: bool = False, ) -> EntityT: """Merge model in session/db having id_ with values.""" - self.session.get(self.model, id_) + session.get(self.model, id_) values.update(id=id_) - merged = self.session.merge(self.model(**values)) - self.session.flush() - self.session.refresh(merged, with_for_update=for_update) # type: ignore + merged = session.merge(self.model(**values)) + session.flush() + session.refresh(merged, with_for_update=for_update) # type: ignore return merged def get( self, + session: sa.orm.Session, id_: EntityIdT, options: t.Sequence[ORMOption] = (), execution_options: t.Optional[t.Dict[str, t.Any]] = None, @@ -137,7 +146,7 @@ def get( present it will be returned directly, when not, a database lookup will be performed. For use cases where this is what you actually want, you can still access the original get - method on self.session. For most uses cases, this behavior can introduce non-determinism + method on session. For most uses cases, this behavior can introduce non-determinism and because of that this method performs lookup using a select statement. Additionally, to satisfy the expected interface's return type: Optional[EntityT], one_or_none is called on the result before returning. 
@@ -154,10 +163,11 @@ def get( if for_update: statement = statement.with_for_update() - return self.session.scalars(statement, execution_options=execution_options).one_or_none() + return session.scalars(statement, execution_options=execution_options).one_or_none() def select( self, + session: sa.orm.Session, selectables: t.Sequence[Selectable] = (), conditions: t.Sequence[ColumnExpr] = (), group_by: t.Sequence[t.Union[ColumnExpr, str]] = (), @@ -196,12 +206,14 @@ def select( for_update=for_update, ) - results = self.session.scalars(statement) + results = session.scalars(statement) if yield_by_chunk: results = results.partitions() return results - def delete(self, id_: EntityIdT, include_inactive: bool = False) -> None: + def delete( + self, session: sa.orm.Session, id_: EntityIdT, include_inactive: bool = False + ) -> None: # if self.has_soft_delete: # raise RuntimeError("Can't delete entity that uses soft-delete semantics.") @@ -209,16 +221,16 @@ def delete(self, id_: EntityIdT, include_inactive: bool = False) -> None: if not entity: raise RuntimeError(f"Entity with id {id_} not found.") - self.session.delete(entity) - self.session.flush() + session.delete(entity) + session.flush() - def deactivate(self, id_: EntityIdT) -> EntityT: + def deactivate(self, session: sa.orm.Session, id_: EntityIdT) -> EntityT: # if not self.has_soft_delete: # raise RuntimeError("Can't delete entity that uses soft-delete semantics.") return self.update(id_, dict(is_active=False)) - def reactivate(self, id_: EntityIdT) -> EntityT: + def reactivate(self, session: sa.orm.Session, id_: EntityIdT) -> EntityT: # if not self.has_soft_delete: # raise RuntimeError("Can't delete entity that uses soft-delete semantics.") @@ -226,6 +238,7 @@ def reactivate(self, id_: EntityIdT) -> EntityT: def exists( self, + session: sa.orm.Session, conditions: t.Sequence[ColumnExpr] = (), for_update: bool = False, include_inactive: bool = False, @@ -246,38 +259,41 @@ def exists( if for_update: statement = statement.with_for_update() - result = self.session.execute(statement, execution_options=execution_options).scalar() + result = session.execute(statement, execution_options=execution_options).scalar() return bool(result) class SQLAlchemyBulkRepository(AbstractBulkRepository, t.Generic[SessionT, EntityT, EntityIdT]): - def __init__(self, session: SessionT, **kwargs: t.Any): + def __init__(self, **kwargs: t.Any): super().__init__(**kwargs) self.builder = StatementBuilder(self.model) - self.session = session + # session = session def bulk_insert( self, + session: sa.orm.Session, values: t.Sequence[t.Dict[str, t.Any]] = (), execution_options: t.Optional[t.Dict[str, t.Any]] = None, ) -> sa.Result[t.Any]: statement = self.builder.bulk_insert(self.model, values) - return self.session.execute(statement, execution_options=execution_options or {}) + return session.execute(statement, execution_options=execution_options or {}) def bulk_update( self, + session: sa.orm.Session, conditions: t.Sequence[ColumnExpr] = (), values: t.Optional[t.Dict[str, t.Any]] = None, execution_options: t.Optional[t.Dict[str, t.Any]] = None, ) -> sa.Result[t.Any]: statement = self.builder.bulk_update(self.model, conditions, values) - return self.session.execute(statement, execution_options=execution_options or {}) + return session.execute(statement, execution_options=execution_options or {}) def bulk_delete( self, + session: sa.orm.Session, conditions: t.Sequence[ColumnExpr] = (), execution_options: t.Optional[t.Dict[str, t.Any]] = None, ) -> sa.Result[t.Any]: statement 
= self.builder.bulk_delete(self.model, conditions) - return self.session.execute(statement, execution_options=execution_options or {}) + return session.execute(statement, execution_options=execution_options or {}) diff --git a/src/quart_sqlalchemy/sim/repo_adapter.py b/src/quart_sqlalchemy/sim/repo_adapter.py index 89230b0..3d6b48e 100644 --- a/src/quart_sqlalchemy/sim/repo_adapter.py +++ b/src/quart_sqlalchemy/sim/repo_adapter.py @@ -1,18 +1,24 @@ import typing as t +import sqlalchemy +import sqlalchemy.orm from pydantic import BaseModel from sqlalchemy import ScalarResult -from sqlalchemy.orm import selectinload, Session +from sqlalchemy.orm import selectinload +from sqlalchemy.orm import Session from sqlalchemy.sql.expression import func from sqlalchemy.sql.expression import label + from quart_sqlalchemy.model import Base +from quart_sqlalchemy.sim.repo import SQLAlchemyRepository from quart_sqlalchemy.types import ColumnExpr from quart_sqlalchemy.types import EntityIdT from quart_sqlalchemy.types import EntityT from quart_sqlalchemy.types import ORMOption from quart_sqlalchemy.types import Selectable -from quart_sqlalchemy.sim.repo import SQLAlchemyRepository + +sa = sqlalchemy class BaseModelSchema(BaseModel): @@ -34,11 +40,16 @@ class BaseUpdateSchema(BaseModelSchema): class RepositoryLegacyAdapter(t.Generic[EntityT, EntityIdT]): - def __init__(self, model: t.Type[EntityT], identity: t.Type[EntityIdT], session: Session,): + def __init__( + self, + model: t.Type[EntityT], + identity: t.Type[EntityIdT], + # session: Session, + ): self.model = model self._identity = identity - self._session = session - self.repo = SQLAlchemyRepository[model, identity](session) + # self._session = session + self.repo = SQLAlchemyRepository[model, identity]() def get_by( self, @@ -62,6 +73,7 @@ def get_by( order_by_clause = () return self.repo.select( + session, conditions=filters, options=[selectinload(getattr(self.model, attr)) for attr in join_list], for_update=for_update, @@ -73,8 +85,8 @@ def get_by( def get_by_id( self, - session = None, - model_id = None, + session=None, + model_id=None, allow_inactive=False, join_list=None, for_update=False, @@ -83,16 +95,20 @@ def get_by_id( raise ValueError("model_id is required") join_list = join_list or () return self.repo.get( + session, id_=model_id, options=[selectinload(getattr(self.model, attr)) for attr in join_list], for_update=for_update, include_inactive=allow_inactive, ) - def one(self, session = None, filters=None, join_list=None, for_update=False, include_inactive=False) -> EntityT: + def one( + self, session=None, filters=None, join_list=None, for_update=False, include_inactive=False + ) -> EntityT: filters = filters or () join_list = join_list or () return self.repo.select( + session, conditions=filters, options=[selectinload(getattr(self.model, attr)) for attr in join_list], for_update=for_update, @@ -101,7 +117,7 @@ def one(self, session = None, filters=None, join_list=None, for_update=False, in def count_by( self, - session = None, + session=None, filters=None, group_by=None, distinct_column=None, @@ -119,29 +135,29 @@ def count_by( for group in group_by: selectables.append(group.expression) - result = self.repo.select(selectables, conditions=filters, group_by=group_by) + result = self.repo.select(session, selectables, conditions=filters, group_by=group_by) return result.all() - def add(self, session = None, **kwargs) -> EntityT: - return self.repo.insert(kwargs) + def add(self, session=None, **kwargs) -> EntityT: + return 
self.repo.insert(session, kwargs) - def update(self, session = None, model_id, **kwargs) -> EntityT: - return self.repo.update(id_=model_id, values=kwargs) + def update(self, session=None, model_id=None, **kwargs) -> EntityT: + return self.repo.update(session, id_=model_id, values=kwargs) - def update_by(self, session = None, filters=None, **kwargs) -> EntityT: + def update_by(self, session=None, filters=None, **kwargs) -> EntityT: if not filters: raise ValueError("Full table scans are prohibited. Please provide filters") - row = self.repo.select(conditions=filters, limit=2).one() - return self.repo.update(id_=row.id, values=kwargs) + row = self.repo.select(session, conditions=filters, limit=2).one() + return self.repo.update(session, id_=row.id, values=kwargs) - def delete_by_id(self, session = None, model_id) -> None: - self.repo.delete(id_=model_id, include_inactive=True) + def delete_by_id(self, session=None, model_id=None) -> None: + self.repo.delete(session, id_=model_id, include_inactive=True) - def delete_one_by(self, session = None, filters=None, optional=False) -> None: + def delete_one_by(self, session=None, filters=None, optional=False) -> None: filters = filters or () - result = self.repo.select(conditions=filters, limit=1) + result = self.repo.select(session, conditions=filters, limit=1) if optional: row = result.one_or_none() @@ -150,21 +166,23 @@ def delete_one_by(self, session = None, filters=None, optional=False) -> None: else: row = result.one() - self.repo.delete(id_=row.id) + self.repo.delete(session, id_=row.id) - def exist(self, session = None, filters=None, allow_inactive=False) -> bool: + def exist(self, session=None, filters=None, allow_inactive=False) -> bool: filters = filters or () return self.repo.exists( + session, conditions=filters, include_inactive=allow_inactive, ) def yield_by_chunk( - self, session = None, chunk_size = 100, join_list=None, filters=None, allow_inactive=False + self, session=None, chunk_size=100, join_list=None, filters=None, allow_inactive=False ): filters = filters or () join_list = join_list or () results = self.repo.select( + session, conditions=filters, options=[selectinload(getattr(self.model, attr)) for attr in join_list], include_inactive=allow_inactive, @@ -220,11 +238,12 @@ def schema(self) -> t.Type[ModelSchemaT]: def insert( self, + session: sa.orm.Session, create_schema: CreateSchemaT, sqla_model=False, ): create_data = create_schema.dict() - result = super().insert(create_data) + result = super().insert(session, create_data) if sqla_model: return result @@ -232,11 +251,12 @@ def insert( def update( self, + session: sa.orm.Session, id_: EntityIdT, update_schema: UpdateSchemaT, sqla_model=False, ): - existing = self.session.query(self.model).get(id_) + existing = session.query(self.model).get(id_) if existing is None: raise ValueError("Model not found") @@ -244,15 +264,16 @@ def update( for key, value in update_data.items(): setattr(existing, key, value) - self.session.add(existing) - self.session.flush() - self.session.refresh(existing) + session.add(existing) + session.flush() + session.refresh(existing) if sqla_model: return existing return self.schema.from_orm(existing) def get( self, + session: sa.orm.Session, id_: EntityIdT, options: t.Sequence[ORMOption] = (), execution_options: t.Optional[t.Dict[str, t.Any]] = None, @@ -261,6 +282,7 @@ def get( sqla_model: bool = False, ): row = super().get( + session, id_, options, execution_options, @@ -276,6 +298,7 @@ def get( def select( self, + session: sa.orm.Session, 
selectables: t.Sequence[Selectable] = (), conditions: t.Sequence[ColumnExpr] = (), group_by: t.Sequence[t.Union[ColumnExpr, str]] = (), @@ -291,6 +314,7 @@ def select( sqla_model: bool = False, ): result = super().select( + session, selectables, conditions, group_by, diff --git a/src/quart_sqlalchemy/sim/schema.py b/src/quart_sqlalchemy/sim/schema.py index f2311ca..e0304b7 100644 --- a/src/quart_sqlalchemy/sim/schema.py +++ b/src/quart_sqlalchemy/sim/schema.py @@ -1,8 +1,11 @@ +import typing as t from datetime import datetime from pydantic import BaseModel +from pydantic import Field +from pydantic import validator -from quart_sqlalchemy.sim.util import ObjectID +from .util import ObjectID class BaseSchema(BaseModel): @@ -12,3 +15,29 @@ class Config: ObjectID: lambda v: v.encode(), datetime: lambda dt: int(dt.timestamp()), } + + @classmethod + def _get_value(cls, v: t.Any, *args: t.Any, **kwargs: t.Any) -> t.Any: + if hasattr(v, "__serialize__"): + return v.__serialize__() + for type_, converter in cls.__config__.json_encoders.items(): + if isinstance(v, type_): + return converter(v) + + return super()._get_value(v, *args, **kwargs) + + +class ResponseWrapper(BaseSchema): + """Generic response wrapper""" + + error_code: str = "" + status: str = "" + message: str = "" + data: t.Any = Field(default_factory=dict) + + @validator("status") + def set_status_by_error_code(cls, v, values): + error_code = values.get("error_code") + if error_code: + return "failed" + return "ok" diff --git a/src/quart_sqlalchemy/sim/testing.py b/src/quart_sqlalchemy/sim/testing.py new file mode 100644 index 0000000..347b422 --- /dev/null +++ b/src/quart_sqlalchemy/sim/testing.py @@ -0,0 +1,13 @@ +from contextlib import contextmanager + +from quart import g +from quart import signals + + +@contextmanager +def user_set(app, user): + def handler(sender, **kwargs): + g.user = user + + with signals.appcontext_pushed.connected_to(handler, app): + yield diff --git a/src/quart_sqlalchemy/sim/views/__init__.py b/src/quart_sqlalchemy/sim/views/__init__.py index d110bc2..656f265 100644 --- a/src/quart_sqlalchemy/sim/views/__init__.py +++ b/src/quart_sqlalchemy/sim/views/__init__.py @@ -6,7 +6,7 @@ from .magic_client import api as magic_client_api -api = Blueprint("api", __name__, url_prefix="api") +api = Blueprint("api", __name__, url_prefix="/api") api.register_blueprint(auth_user_api) api.register_blueprint(auth_wallet_api) @@ -15,4 +15,4 @@ @api.before_request def set_feature_owner(): - g.request_feature_owner = "magic" + g.request_feature_owner = "auth-team" diff --git a/src/quart_sqlalchemy/sim/views/auth_wallet.py b/src/quart_sqlalchemy/sim/views/auth_wallet.py index 24fc9fc..9c6f052 100644 --- a/src/quart_sqlalchemy/sim/views/auth_wallet.py +++ b/src/quart_sqlalchemy/sim/views/auth_wallet.py @@ -1,25 +1,28 @@ import logging import typing as t -from quart import Blueprint from quart import g from quart.utils import run_sync -from quart_schema.validation import validate +from quart_sqlalchemy.retry import retry_context +from quart_sqlalchemy.retry import RetryError + +from ..auth import authorized_request +from ..handle import AuthWalletHandler from ..model import WalletManagementType from ..model import WalletType from ..schema import BaseSchema from ..util import ObjectID -from .decorator import authorized_request +from .util import APIBlueprint logger = logging.getLogger(__name__) -api = Blueprint("auth_wallet", __name__, url_prefix="auth_wallet") +api = APIBlueprint("auth_wallet", __name__, url_prefix="auth_wallet") 
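+# NOTE: APIBlueprint (defined in quart_sqlalchemy/sim/views/util/blueprint.py, added later in
+# this patch) inspects each view function's signature: a `data` parameter annotated with a
+# BaseSchema subclass gets request-body validation, `query_args`/`headers` get querystring and
+# header validation, the return annotation drives response validation, and any `authorizer`
+# callable passed to its .get()/.post()/... helpers wraps the view with authorization checks.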
@api.before_request def set_feature_owner(): - g.request_feature_owner = "wallet" + g.request_feature_owner = "wallet-team" class WalletSyncRequest(BaseSchema): @@ -38,18 +41,36 @@ class WalletSyncResponse(BaseSchema): encrypted_private_address: str -@authorized_request(authenticate_client=True, authenticate_user=True) -@validate(request=WalletSyncRequest, responses={200: (WalletSyncResponse, None)}) -@api.route("/sync", methods=["POST"]) -async def sync_auth_user_wallet(data: WalletSyncRequest): +@api.post( + "/sync", + authorizer=authorized_request( + [ + # We use the OpenAPI security scheme metadata to know which kind of authorization to enforce. + # + # Together in the same dict implies logical AND requirement so both public-api-key and + # session-token will be enforced + { + "public-api-key": [], + "session-token-bearer": [], + } + ], + ), +) +async def sync(data: WalletSyncRequest) -> WalletSyncResponse: + user_credential = g.authorized_credentials.get("session-token-bearer") + try: - with g.bind.Session() as session: - wallet = await run_sync(g.h.AuthWallet(session).sync_auth_wallet)( - g.auth.user.id, - data.public_address, - data.encrypted_private_address, - WalletManagementType.DELEGATED.value, - ) + for attempt in retry_context: + with attempt: + with g.bind.Session() as session: + wallet = AuthWalletHandler(g.network, session).sync_auth_wallet( + user_credential.subject.id, + data.public_address, + data.encrypted_private_address, + WalletManagementType.DELEGATED.value, + ) + except RetryError: + pass except RuntimeError: raise RuntimeError("Unsupported wallet type or network") diff --git a/src/quart_sqlalchemy/sim/views/decorator.py b/src/quart_sqlalchemy/sim/views/decorator.py deleted file mode 100644 index 357a819..0000000 --- a/src/quart_sqlalchemy/sim/views/decorator.py +++ /dev/null @@ -1,131 +0,0 @@ -# import inspect -# import typing as t -# from dataclasses import asdict -# from dataclasses import is_dataclass -from functools import wraps - -# from pydantic import BaseModel -# from pydantic import ValidationError -# from pydantic.dataclasses import dataclass as pydantic_dataclass -# from pydantic.schema import model_schema -from quart import current_app -from quart import g - - -# from quart import request -# from quart import ResponseReturnValue as QuartResponseReturnValue -# from quart_schema.typing import Model -# from quart_schema.typing import PydanticModel -# from quart_schema.typing import ResponseReturnValue -# from quart_schema.validation import _convert_headers -# from quart_schema.validation import DataSource -# from quart_schema.validation import QUART_SCHEMA_RESPONSE_ATTRIBUTE -# from quart_schema.validation import ResponseHeadersValidationError -# from quart_schema.validation import ResponseSchemaValidationError -# from quart_schema.validation import validate_headers -# from quart_schema.validation import validate_querystring -# from quart_schema.validation import validate_request - - -def authorized_request(authenticate_client: bool = False, authenticate_user: bool = False): - def decorator(func): - @wraps(func) - async def wrapper(*args, **kwargs): - if authenticate_client: - if not g.auth.client: - raise RuntimeError("Unable to authenticate client") - kwargs.update(client_id=g.auth.client.id) - - if authenticate_user: - if not g.auth.user: - raise RuntimeError("Unable to authenticate user") - kwargs.update(user_id=g.auth.user.id) - - return await current_app.ensure_async(func)(*args, **kwargs) - - return wrapper - - return decorator - - -# def 
validate_response() -> t.Callable: -# def decorator( -# func: t.Callable[..., ResponseReturnValue] -# ) -> t.Callable[..., QuartResponseReturnValue]: -# undecorated = func -# while hasattr(undecorated, "__wrapped__"): -# undecorated = undecorated.__wrapped__ - -# signature = inspect.signature(undecorated) -# derived_schema = signature.return_annotation or dict - -# schemas = getattr(func, QUART_SCHEMA_RESPONSE_ATTRIBUTE, {}) -# schemas[200] = (derived_schema, None) -# setattr(func, QUART_SCHEMA_RESPONSE_ATTRIBUTE, schemas) - -# @wraps(func) -# async def wrapper(*args: t.Any, **kwargs: t.Any) -> t.Any: -# result = await current_app.ensure_async(func)(*args, **kwargs) - -# status_or_headers = None -# headers = None -# if isinstance(result, tuple): -# value, status_or_headers, headers = result + (None,) * (3 - len(result)) -# else: -# value = result - -# status = 200 -# if isinstance(status_or_headers, int): -# status = int(status_or_headers) - -# schemas = getattr(func, QUART_SCHEMA_RESPONSE_ATTRIBUTE, {200: dict}) -# model_class = schemas.get(status, dict) - -# try: -# if isinstance(value, dict): -# model_value = model_class(**value) -# elif type(value) == model_class: -# model_value = value -# elif is_dataclass(value): -# model_value = model_class(**asdict(value)) -# else: -# return result, status, headers - -# except ValidationError as error: -# raise ResponseHeadersValidationError(error) - -# headers_value = headers -# return model_value, status, headers_value - -# return wrapper - -# return decorator - - -# def validate( -# *, -# querystring: t.Optional[Model] = None, -# request: t.Optional[Model] = None, -# request_source: DataSource = DataSource.JSON, -# headers: t.Optional[Model] = None, -# responses: t.Dict[int, t.Tuple[Model, t.Optional[Model]]], -# ) -> t.Callable: -# """Validate the route. - -# This is a shorthand combination of of the validate_querystring, -# validate_request, validate_headers, and validate_response -# decorators. Please see the docstrings for those decorators. 
-# """ - -# def decorator(func: t.Callable) -> t.Callable: -# if querystring is not None: -# func = validate_querystring(querystring)(func) -# if request is not None: -# func = validate_request(request, source=request_source)(func) -# if headers is not None: -# func = validate_headers(headers)(func) -# for status, models in responses.items(): -# func = validate_response(models[0], status, models[1]) -# return func - -# return decorator diff --git a/src/quart_sqlalchemy/sim/views/util/__init__.py b/src/quart_sqlalchemy/sim/views/util/__init__.py new file mode 100644 index 0000000..e677b31 --- /dev/null +++ b/src/quart_sqlalchemy/sim/views/util/__init__.py @@ -0,0 +1,12 @@ +from .blueprint import APIBlueprint +from .decorator import inject_request +from .decorator import validate_request +from .decorator import validate_response + + +__all__ = ( + "APIBlueprint", + "inject_request", + "validate_request", + "validate_response", +) diff --git a/src/quart_sqlalchemy/sim/views/util/blueprint.py b/src/quart_sqlalchemy/sim/views/util/blueprint.py new file mode 100644 index 0000000..35927ec --- /dev/null +++ b/src/quart_sqlalchemy/sim/views/util/blueprint.py @@ -0,0 +1,102 @@ +import inspect +import typing as t + +from quart import Blueprint +from quart import Request +from quart_schema.validation import validate_headers +from quart_schema.validation import validate_querystring + +from ...schema import BaseSchema +from .decorator import inject_request +from .decorator import validate_request +from .decorator import validate_response + + +class APIBlueprint(Blueprint): + def _endpoint( + self, + uri: str, + methods: t.Optional[t.Sequence[str]] = ("GET",), + authorizer: t.Optional[t.Callable] = None, + **route_kwargs, + ): + def decorator(func): + sig = inspect.signature(func) + + param_annotation_map = { + name: param.annotation for name, param in sig.parameters.items() + } + has_request_schema = "data" in sig.parameters and issubclass( + param_annotation_map["data"], + BaseSchema, + ) + has_query_schema = "query_args" in sig.parameters and issubclass( + param_annotation_map["query_args"], + BaseSchema, + ) + has_headers_schema = "headers" in sig.parameters and issubclass( + param_annotation_map["headers"], + BaseSchema, + ) + + has_response_schema = isinstance(sig.return_annotation, BaseSchema) + + should_inject_request, request_param_name = False, None + for name in param_annotation_map: + if isinstance(param_annotation_map[name], Request): + should_inject_request, request_param_name = True, name + break + + decorated = func + + if should_inject_request: + decorated = inject_request(request_param_name)(decorated) + + if has_query_schema: + decorated = validate_querystring(param_annotation_map["query_args"])(decorated) + + if has_headers_schema: + decorated = validate_headers(param_annotation_map["headers"])(decorated) + + if has_request_schema: + decorated = validate_request(param_annotation_map["data"])(decorated) + + if has_response_schema: + decorated = validate_response(sig.return_annotation)(decorated) + + if authorizer: + decorated = authorizer(decorated) + + return self.route(uri, t.cast(t.List[str], methods), **route_kwargs)(decorated) + + return decorator + + def get(self, *args, **kwargs): + if "methods" in kwargs: + del kwargs["methods"] + + return self._endpoint(*args, methods=["GET"], **kwargs) + + def post(self, *args, **kwargs): + if "methods" in kwargs: + del kwargs["methods"] + + return self._endpoint(*args, methods=["POST"], **kwargs) + + def put(self, *args, **kwargs): + if 
"methods" in kwargs: + del kwargs["methods"] + + return self._endpoint(*args, methods=["PUT"], **kwargs) + + def patch(self, *args, **kwargs): + if "methods" in kwargs: + del kwargs["methods"] + + return self._endpoint(*args, methods=["PATCH"], **kwargs) + + def delete(self, *args, **kwargs): + if "methods" in kwargs: + del kwargs["methods"] + + return self._endpoint(*args, methods=["DELETE"], **kwargs) diff --git a/src/quart_sqlalchemy/sim/views/util/decorator.py b/src/quart_sqlalchemy/sim/views/util/decorator.py new file mode 100644 index 0000000..55323d2 --- /dev/null +++ b/src/quart_sqlalchemy/sim/views/util/decorator.py @@ -0,0 +1,120 @@ +import typing as t +from dataclasses import asdict +from dataclasses import is_dataclass +from functools import wraps + +from humps import camelize +from humps import decamelize +from pydantic import BaseModel +from pydantic import ValidationError +from quart import current_app +from quart import request +from quart import Response +from quart_schema.typing import Model +from quart_schema.typing import ResponseReturnValue +from quart_schema.validation import QUART_SCHEMA_REQUEST_ATTRIBUTE +from quart_schema.validation import QUART_SCHEMA_RESPONSE_ATTRIBUTE +from quart_schema.validation import RequestSchemaValidationError +from quart_schema.validation import ResponseSchemaValidationError + + +def convert_model_result(func: t.Callable) -> t.Callable: + @wraps(func) + async def decorator(result: ResponseReturnValue) -> Response: + status_or_headers = None + headers = None + if isinstance(result, tuple): + value, status_or_headers, headers = result + (None,) * (3 - len(result)) + else: + value = result + + was_model = False + if is_dataclass(value): + dict_or_value = asdict(value) + was_model = True + elif isinstance(value, BaseModel): + dict_or_value = value.dict(by_alias=True) + was_model = True + else: + dict_or_value = value + + if was_model: + dict_or_value = camelize(dict_or_value) + + return await func((dict_or_value, status_or_headers, headers)) + + return decorator + + +def validate_request(model_class: Model) -> t.Callable: + def decorator(func: t.Callable) -> t.Callable: + setattr(func, QUART_SCHEMA_REQUEST_ATTRIBUTE, (model_class, None)) + + @wraps(func) + async def wrapper(*args, **kwargs): + data = await request.get_json() + data = decamelize(data) + + try: + model = model_class(**data) + except (TypeError, ValidationError) as error: + raise RequestSchemaValidationError(error) + else: + return await current_app.ensure_async(func)(*args, data=model, **kwargs) + + return wrapper + + return decorator + + +def validate_response(model_class: Model, status_code: int = 200) -> t.Callable: + def decorator(func): + schemas = getattr(func, QUART_SCHEMA_RESPONSE_ATTRIBUTE, {}) + schemas[status_code] = (model_class, None) + setattr(func, QUART_SCHEMA_RESPONSE_ATTRIBUTE, schemas) + + @wraps(func) + async def wrapper(*args, **kwargs): + result = await current_app.ensure_async(func)(*args, **kwargs) + + status_or_headers = None + headers = None + if isinstance(result, tuple): + value, status_or_headers, headers = result + (None,) * (3 - len(result)) + else: + value = result + + status = 200 + if isinstance(status_or_headers, int): + status = int(status_or_headers) + + if status == status_code: + try: + if isinstance(value, dict): + model_value = model_class(**value) + elif type(value) == model_class: + model_value = value + else: + raise ResponseSchemaValidationError() + except ValidationError as error: + raise ResponseSchemaValidationError(error) + + 
return model_value, status, headers + else: + return result + + return wrapper + + return decorator + + +def inject_request(key: str): + def decorator(func): + @wraps(func) + async def wrapper(*args, **kwargs): + kwargs[key] = request._get_current_object() + return await current_app.ensure_async(func)(*args, **kwargs) + + return wrapper + + return decorator From bc9f565699f2cc746b4e18e8da91c9973a289591 Mon Sep 17 00:00:00 2001 From: Joe Black Date: Fri, 31 Mar 2023 19:06:50 -0400 Subject: [PATCH 05/11] fix --- src/quart_sqlalchemy/sim/auth.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/quart_sqlalchemy/sim/auth.py b/src/quart_sqlalchemy/sim/auth.py index 51c3962..80a908a 100644 --- a/src/quart_sqlalchemy/sim/auth.py +++ b/src/quart_sqlalchemy/sim/auth.py @@ -274,12 +274,12 @@ def auth_endpoint_security(self): ) @click.option( "--client-id", - type=int, + type=str, required=True, help="client id", ) @pass_script_info -def add_user(info: ScriptInfo, email: str, user_type: str, client_id: int) -> None: +def add_user(info: ScriptInfo, email: str, user_type: str, client_id: str) -> None: app = info.load_app() db = app.extensions.get("sqlalchemy") From 51ccb5cd412d7bf97dd13fdf6a83de965f1ccc1e Mon Sep 17 00:00:00 2001 From: Joe Black Date: Fri, 31 Mar 2023 19:14:40 -0400 Subject: [PATCH 06/11] fix --- src/quart_sqlalchemy/sim/util.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/quart_sqlalchemy/sim/util.py b/src/quart_sqlalchemy/sim/util.py index 463b350..d4d33d1 100644 --- a/src/quart_sqlalchemy/sim/util.py +++ b/src/quart_sqlalchemy/sim/util.py @@ -30,8 +30,7 @@ def __init__(self, input_value): elif isinstance(input_value, ObjectID): self._source_id = input_value._decoded_id elif isinstance(input_value, str): - self._source_id = input_value - self._decode() + self._source_id = self._decode(input_value) elif isinstance(input_value, numbers.Number): try: input_value = int(input_value) @@ -47,7 +46,7 @@ def _encoded_id(self): @property def _decoded_id(self): - return self._decode() + return self._source_id def __eq__(self, other): if isinstance(other, ObjectID): @@ -88,11 +87,11 @@ def _encode(self): def encode(self): return self._encoded_id - def _decode(self): - if isinstance(self._source_id, int): - return self._source_id + def _decode(self, value): + if isinstance(value, int): + return value else: - return self.hashids.decode(self._source_id) + return self.hashids.decode(value)[0] def decode(self): return self._decoded_id From c7e453dc1571530da00bf93961955b9373419147 Mon Sep 17 00:00:00 2001 From: Joe Black Date: Fri, 31 Mar 2023 19:41:41 -0400 Subject: [PATCH 07/11] added sim extra deps --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 8e2348c..eb2e1be 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,9 @@ build-backend = "setuptools.build_meta" [project.optional-dependencies] +sim = [ + "quart-schema", "hashids" +] tests = [ "pytest", # "pytest-asyncio~=0.20.3", From 21e2c2447a77daa9805ceaaa0a5692996fdd73df Mon Sep 17 00:00:00 2001 From: Joe Black Date: Wed, 12 Apr 2023 16:31:51 -0400 Subject: [PATCH 08/11] Sim POC --- .env | 3 + docs/Simulation.md | 159 ++++++ docs/usage.md | 300 +++++++++++ examples/decorators/provide_session.py | 57 ++ examples/repository/base.py | 60 ++- examples/repository/sqla.py | 48 +- examples/usrsrv/component/__init__.py | 23 + examples/usrsrv/component/app.py | 10 + examples/usrsrv/component/command.py | 43 ++ 
examples/usrsrv/component/entity.py | 39 ++ examples/usrsrv/component/event.py | 45 ++ examples/usrsrv/component/exception.py | 2 + ...5_ddd_component_unitofwork_18fd763a02a4.py | 0 examples/usrsrv/component/repository.py | 63 +++ examples/usrsrv/component/service.py | 68 +++ pyproject.toml | 7 +- setup.cfg | 1 + src/quart_sqlalchemy/__init__.py | 2 +- src/quart_sqlalchemy/bind.py | 189 +++++-- src/quart_sqlalchemy/config.py | 193 ++++--- src/quart_sqlalchemy/framework/cli.py | 69 ++- src/quart_sqlalchemy/framework/extension.py | 14 +- src/quart_sqlalchemy/model/__init__.py | 11 +- src/quart_sqlalchemy/model/columns.py | 14 +- src/quart_sqlalchemy/model/mixins.py | 93 +++- src/quart_sqlalchemy/model/model.py | 43 +- src/quart_sqlalchemy/session.py | 53 ++ src/quart_sqlalchemy/signals.py | 31 ++ src/quart_sqlalchemy/sim/app.py | 69 +-- src/quart_sqlalchemy/sim/auth.py | 62 +-- src/quart_sqlalchemy/sim/builder.py | 6 +- src/quart_sqlalchemy/sim/commands.py | 52 ++ src/quart_sqlalchemy/sim/config.py | 55 ++ src/quart_sqlalchemy/sim/container.py | 57 ++ src/quart_sqlalchemy/sim/db.py | 48 +- src/quart_sqlalchemy/sim/handle.py | 454 +++++----------- src/quart_sqlalchemy/sim/logic.py | 501 ++++++------------ src/quart_sqlalchemy/sim/main.py | 2 + src/quart_sqlalchemy/sim/model.py | 8 +- src/quart_sqlalchemy/sim/repo.py | 150 +++--- src/quart_sqlalchemy/sim/repo_adapter.py | 120 ++--- src/quart_sqlalchemy/sim/schema.py | 68 ++- src/quart_sqlalchemy/sim/signals.py | 1 + src/quart_sqlalchemy/sim/testing.py | 7 +- src/quart_sqlalchemy/sim/views/auth_user.py | 85 ++- src/quart_sqlalchemy/sim/views/auth_wallet.py | 62 ++- .../sim/views/magic_client.py | 87 ++- src/quart_sqlalchemy/sim/web3.py | 153 ++++++ src/quart_sqlalchemy/sqla.py | 294 ++++++++-- src/quart_sqlalchemy/testing/fake.py | 0 src/quart_sqlalchemy/testing/signals.py | 40 ++ src/quart_sqlalchemy/testing/transaction.py | 10 +- src/quart_sqlalchemy/types.py | 15 +- tests/base.py | 204 ++++++- tests/conftest.py | 49 -- tests/constants.py | 20 +- tests/integration/concurrency/__init__.py | 0 .../concurrency/with_for_update.py | 103 ++++ tests/integration/framework/smoke_test.py | 14 +- tests/integration/model/mixins_test.py | 133 ++++- tests/integration/model/model_test.py | 59 +++ tests/integration/retry_test.py | 4 +- workspace.code-workspace | 3 +- 63 files changed, 3351 insertions(+), 1284 deletions(-) create mode 100644 .env create mode 100644 docs/Simulation.md create mode 100644 docs/usage.md create mode 100644 examples/decorators/provide_session.py create mode 100644 examples/usrsrv/component/__init__.py create mode 100644 examples/usrsrv/component/app.py create mode 100644 examples/usrsrv/component/command.py create mode 100644 examples/usrsrv/component/entity.py create mode 100644 examples/usrsrv/component/event.py create mode 100644 examples/usrsrv/component/exception.py create mode 100644 examples/usrsrv/component/migrations/2020-04-15_ddd_component_unitofwork_18fd763a02a4.py create mode 100644 examples/usrsrv/component/repository.py create mode 100644 examples/usrsrv/component/service.py create mode 100644 src/quart_sqlalchemy/sim/commands.py create mode 100644 src/quart_sqlalchemy/sim/config.py create mode 100644 src/quart_sqlalchemy/sim/container.py create mode 100644 src/quart_sqlalchemy/sim/web3.py create mode 100644 src/quart_sqlalchemy/testing/fake.py create mode 100644 src/quart_sqlalchemy/testing/signals.py create mode 100644 tests/integration/concurrency/__init__.py create mode 100644 
tests/integration/concurrency/with_for_update.py create mode 100644 tests/integration/model/model_test.py diff --git a/.env b/.env new file mode 100644 index 0000000..f9ea633 --- /dev/null +++ b/.env @@ -0,0 +1,3 @@ +WEB3_HTTPS_PROVIDER_URI=https://eth-mainnet.g.alchemy.com/v2/422IpViRAru0Uu1SANhySuOStpaIK3AG +ALCHEMY_API_KEY=xxx +QUART_APP=quart_sqlalchemy.sim.main:app diff --git a/docs/Simulation.md b/docs/Simulation.md new file mode 100644 index 0000000..656d8df --- /dev/null +++ b/docs/Simulation.md @@ -0,0 +1,159 @@ +# Simulation Docs + +# initialize database +```shell +quart db create +``` +``` +Initialized database schema for +``` + +# add first client to the database (Using CLI) +```shell +quart auth add-client +``` +``` +Created client 2VolejRejNmG with public_api_key: 5f794cf72d0cef2dd008be2c0b7a632b +``` + +Use the `public_api_key` returned for the value of the `X-Public-API-Key` header when making API requests. + + +# Create new auth_user via api +```shell +curl -X POST localhost:8081/api/auth_user/ \ + -H 'X-Public-API-Key: 5f794cf72d0cef2dd008be2c0b7a632b' \ + -H 'Content-Type: application/json' \ + --data '{"email": "joe2@joe.com"}' +``` +```json +{ + "data": { + "auth_user": { + "client_id": "2VolejRejNmG", + "current_session_token": "69ee9af5b9296a09f90be5b71c1dda38", + "date_verified": 1681344793, + "delegated_identity_pool_id": null, + "delegated_user_id": null, + "email": "joe2@joe.com", + "global_auth_user_id": null, + "id": "GWpmbk5ezJn4", + "is_admin": false, + "linked_primary_auth_user_id": null, + "phone_number": null, + "provenance": null, + "user_type": 2 + } + }, + "error_code": "", + "message": "", + "status": "" +} +``` + +Use the `current_session_token` returned for the value of `Authorization: Bearer {token}` header when making API Requests requiring a user. 
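+The curl calls in the following sections can also be scripted from Python. The sketch below uses
+the third-party `requests` library (not a dependency of this project) together with the example
+key/token values shown above; substitute the values returned in your own environment.
+
+```python
+import requests
+
+BASE_URL = "http://localhost:8081/api"
+HEADERS = {
+    # public_api_key returned by `quart auth add-client`
+    "X-Public-API-Key": "5f794cf72d0cef2dd008be2c0b7a632b",
+    # current_session_token returned when the auth_user was created
+    "Authorization": "Bearer 69ee9af5b9296a09f90be5b71c1dda38",
+    "Content-Type": "application/json",
+}
+
+# Fetch the AuthUser associated with the bearer session token
+response = requests.get(f"{BASE_URL}/auth_user/", headers=HEADERS)
+response.raise_for_status()
+print(response.json()["data"])
+```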
+ +# get AuthUser corresponding to provided bearer session token +```shell +curl -X GET localhost:8081/api/auth_user/ \ + -H 'X-Public-API-Key: 5f794cf72d0cef2dd008be2c0b7a632b' \ + -H 'Authorization: Bearer 69ee9af5b9296a09f90be5b71c1dda38' \ + -H 'Content-Type: application/json' +``` +```json +{ + "data": { + "client_id": "2VolejRejNmG", + "current_session_token": "69ee9af5b9296a09f90be5b71c1dda38", + "date_verified": 1681344793, + "delegated_identity_pool_id": null, + "delegated_user_id": null, + "email": "joe2@joe.com", + "global_auth_user_id": null, + "id": "GWpmbk5ezJn4", + "is_admin": false, + "linked_primary_auth_user_id": null, + "phone_number": null, + "provenance": null, + "user_type": 2 + }, + "error_code": "", + "message": "", + "status": "" +} +``` + + +# AuthWallet Sync +```shell +curl -X POST localhost:8081/api/auth_wallet/sync \ + -H 'X-Public-API-Key: 5f794cf72d0cef2dd008be2c0b7a632b' \ + -H 'Authorization: Bearer 69ee9af5b9296a09f90be5b71c1dda38' \ + -H 'Content-Type: application/json' \ + --data '{"public_address": "xxx", "encrypted_private_address": "xxx", "wallet_type": "ETH"}' +``` +```json +{ + "data": { + "auth_user_id": "GWpmbk5ezJn4", + "encrypted_private_address": "xxx", + "public_address": "xxx", + "wallet_id": "GWpmbk5ezJn4", + "wallet_type": "ETH" + }, + "error_code": "", + "message": "", + "status": "" +} +``` + +# get magic client corresponding to provided public api key +```shell +curl -X GET localhost:8081/api/magic_client/ \ + -H 'X-Public-API-Key: 5f794cf72d0cef2dd008be2c0b7a632b' \ + -H 'Content-Type: application/json' +``` +```json +{ + "data": { + "app_name": "My App", + "connect_interop": null, + "global_audience_enabled": false, + "id": "2VolejRejNmG", + "is_signing_modal_enabled": false, + "public_api_key": "5f794cf72d0cef2dd008be2c0b7a632b", + "rate_limit_tier": null, + "secret_api_key": "c6ecbced505b35505751c862ed0fb10ffb623d24095019433e0d4d94e240e508" + }, + "error_code": "", + "message": "", + "status": "" +} +``` + +# Create new magic client +```shell +curl -X POST localhost:8081/api/magic_client/ \ + -H 'X-Public-API-Key: 5f794cf72d0cef2dd008be2c0b7a632b' \ + -H 'Content-Type: application/json' \ + --data '{"app_name": "New App"}' +``` +```json +{ + "data": { + "magic_client": { + "app_name": "New App", + "connect_interop": null, + "global_audience_enabled": false, + "id": "GWpmbk5ezJn4", + "is_signing_modal_enabled": false, + "public_api_key": "fb7e0466e2e09387b93af7da49bb1386", + "rate_limit_tier": null, + "secret_api_key": "2ac56a6068d0d4b2ce911ba08401c7bf4acdb03db957550c260bd317c6c49a76" + } + }, + "error_code": "", + "message": "", + "status": "" +} +``` \ No newline at end of file diff --git a/docs/usage.md b/docs/usage.md new file mode 100644 index 0000000..f0bfdbe --- /dev/null +++ b/docs/usage.md @@ -0,0 +1,300 @@ +# API + +## SQLAlchemy +### `quart_sqlalchemy.sqla.SQLAlchemy` + +### Conventions +This manager class keeps things very simple by using a few configuration conventions: + +* Configuration has been simplified down to base_class and binds. +* Everything related to ORM mapping, DeclarativeBase, registry, MetaData, etc should be configured by passing the a custom DeclarativeBase class as the base_class configuration parameter. +* Everything related to engine/session configuration should be configured by passing a dictionary mapping string names to BindConfigs as the `binds` configuration parameter. 
+* the bind named `default` is the canonical bind, and to be used unless something more specific has been requested + +### Configuration +BindConfig can be as simple as a dictionary containing a url key like so: +```python +bind_config = { + "default": {"url": "sqlite://"} +} +``` + +But most use cases will require more than just a connection url, and divide core/engine configuration from orm/session configuration which looks more like this: +```python +bind_config = { + "default": { + "engine": { + "url": "sqlite://" + }, + "session": { + "expire_on_commit": False + } + } +} +``` + +It helps to think of the bind configuration as being the options dictionary used to build the main core and orm factory objects. +* For SQLAlchemy core, the configuration under the key `engine` will be used by `sa.engine_from_config` to build the `sa.Engine` object which acts as a factory for `sa.Connection` objects. + ```python + engine = sa.engine_from_config(config.engine, prefix="") + ``` +* For SQLAlchemy orm, the configuration under the key `session` will be used to build the `sa.orm.sessionmaker` session factory which acts as a factory for `sa.orm.Session` objects. + ```python + session_factory = sa.orm.sessionmaker(bind=engine, **config.session) + ``` + +#### Usage Examples +SQLAlchemyConfig is to be passed to SQLAlchemy or QuartSQLAlchemy as the first parameter when initializing. + +```python +db = SQLAlchemy( + SQLAlchemyConfig( + binds=dict( + default=dict( + url="sqlite://" + ) + ) + ) +) +``` + +When nothing is provided to SQLAlchemyConfig directly, it is instantiated with the following defaults + +```python +db = SQLAlchemy(SQLAlchemyConfig()) +``` + +For `QuartSQLAlchemy` configuration can also be provided via Quart configuration. +```python +from quart_sqlalchemy.framework import QuartSQLAlchemy + +app = Quart(__name__) +app.config.from_mapping( + { + "SQLALCHEMY_BINDS": { + "default": { + "engine": {"url": "sqlite:///file:mem.db?mode=memory&cache=shared&uri=true"}, + "session": {"expire_on_commit": False}, + } + }, + "SQLALCHEMY_BASE_CLASS": Base, + } +) +db = QuartSQLAlchemy(app=app) +``` + + + + +A typical configuration containing engine and session config both: +```python +config = SQLAlchemyConfig( + binds=dict( + default=dict( + engine=dict( + url="sqlite://" + ), + session=dict( + expire_on_commit=False + ) + ) + ) +) +``` + +Async first configuration +```python +config = SQLAlchemyConfig( + binds=dict( + default=dict( + engine=dict( + url="sqlite+aiosqlite:///file:mem.db?mode=memory&cache=shared&uri=true" + ), + session=dict( + expire_on_commit=False + ) + ) + ) +) +``` + +More complex configuration having two additional binds based on default, one for a read-replica and the second having an async driver + +```python +config = { + "SQLALCHEMY_BINDS": { + "default": { + "engine": {"url": "sqlite:///file:mem.db?mode=memory&cache=shared&uri=true"}, + "session": {"expire_on_commit": False}, + }, + "read-replica": { + "engine": {"url": "sqlite:///file:mem.db?mode=memory&cache=shared&uri=true"}, + "session": {"expire_on_commit": False}, + "read_only": True, + }, + "async": { + "engine": {"url": "sqlite+aiosqlite:///file:mem.db?mode=memory&cache=shared&uri=true"}, + "session": {"expire_on_commit": False}, + }, + }, + "SQLALCHEMY_BASE_CLASS": Base, +} +``` + + + Once instantiated, operations targetting all of the binds, aka metadata, like + `metadata.create_all` should be called from this class. Operations specific to a bind + should be called from that bind. 
This class has a few ways to get a specific bind. + + * To get a Bind, you can call `.get_bind(name)` on this class. The default bind can be + referenced at `.bind`. + + * To define an ORM model using the Base class attached to this class, simply inherit + from `.Base` + + db = SQLAlchemy(SQLAlchemyConfig()) + + class User(db.Base): + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + name: Mapped[str] = sa.orm.mapped_column(default="Joe") + + db.create_all() + + * You can also decouple Base from SQLAlchemy with some dependency inversion: + from quart_sqlalchemy.model.mixins import DynamicArgsMixin, ReprMixin, TableNameMixin + + class Base(DynamicArgsMixin, ReprMixin, TableNameMixin): + __abstract__ = True + + + class User(db.Base): + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + name: Mapped[str] = sa.orm.mapped_column(default="Joe") + + db = SQLAlchemy(SQLAlchemyConfig(bind_class=Base)) + + db.create_all() + + + Declarative Mapping using registry based decorator: + + db = SQLAlchemy(SQLAlchemyConfig()) + + @db.registry.mapped + class User(db.Base): + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + name: Mapped[str] = sa.orm.mapped_column(default="Joe") + + db.create_all() + + + Declarative with Imperative Table (Hybrid Declarative): + + class User(db.Base): + __table__ = sa.Table( + "user", + db.metadata, + sa.Column("id", sa.Integer, primary_key=True, autoincrement=True), + sa.Column("name", sa.String, default="Joe"), + ) + + + Declarative using reflection to automatically build the table object: + + class User(db.Base): + __table__ = sa.Table( + "user", + db.metadata, + autoload_with=db.bind.engine, + ) + + + Declarative Dataclass Mapping: + + from quart_sqlalchemy.model import Base as Base_ + + class Base(sa.orm.MappedAsDataclass, Base_): + pass + + db = SQLAlchemy(SQLAlchemyConfig(base_class=Base)) + + class User(db.Base): + __tablename__ = "user" + + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + name: Mapped[str] = sa.orm.mapped_column(default="Joe") + + db.create_all() + + + Declarative Dataclass Mapping (using decorator): + + db = SQLAlchemy(SQLAlchemyConfig(base_class=Base)) + + @db.registry.mapped_as_dataclass + class User: + __tablename__ = "user" + + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + name: Mapped[str] = sa.orm.mapped_column(default="Joe") + + db.create_all() + + + Alternate Dataclass Provider Pattern: + + from pydantic.dataclasses import dataclass + from quart_sqlalchemy.model import Base as Base_ + + class Base(sa.orm.MappedAsDataclass, Base_, dataclass_callable=dataclass): + pass + + db = SQLAlchemy(SQLAlchemyConfig(base_class=Base)) + + class User(db.Base): + __tablename__ = "user" + + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + name: Mapped[str] = sa.orm.mapped_column(default="Joe") + + db.create_all() + + Imperative style Mapping + + db = SQLAlchemy(SQLAlchemyConfig(base_class=Base)) + + user_table = sa.Table( + "user", + db.metadata, + sa.Column("id", sa.Integer, primary_key=True, autoincrement=True), + sa.Column("name", sa.String, default="Joe"), + ) + + post_table = sa.Table( + "post", + db.metadata, + sa.Column("id", sa.Integer, primary_key=True, autoincrement=True), + sa.Column("title", sa.String, default="My post"), + sa.Column("user_id", sa.ForeignKey("user.id"), nullable=False), + ) + + class User: + pass + + class Post: + pass + + db.registry.map_imperatively( + 
User, + user_table, + properties={ + "posts": sa.orm.relationship(Post, back_populates="user") + } + ) + db.registry.map_imperatively( + Post, + post_table, + properties={ + "user": sa.orm.relationship(User, back_populates="posts", uselist=False) + } + ) \ No newline at end of file diff --git a/examples/decorators/provide_session.py b/examples/decorators/provide_session.py new file mode 100644 index 0000000..f67f2b0 --- /dev/null +++ b/examples/decorators/provide_session.py @@ -0,0 +1,57 @@ +import inspect +import typing as t +from contextlib import contextmanager +from functools import wraps + + +RT = t.TypeVar("RT") + + +@contextmanager +def create_session(bind): + """Contextmanager that will create and teardown a session.""" + session = bind.Session() + try: + yield session + session.commit() + except Exception: + session.rollback() + raise + finally: + session.close() + + +def provide_session(bind_name: str = "default"): + """ + Function decorator that provides a session if it isn't provided. + If you want to reuse a session or run the function as part of a + database transaction, you pass it to the function, if not this wrapper + will create one and close it for you. + """ + + def decorator(func: t.Callable[..., RT]) -> t.Callable[..., RT]: + from quart_sqlalchemy import Bind + + func_params = inspect.signature(func).parameters + try: + # func_params is an ordered dict -- this is the "recommended" way of getting the position + session_args_idx = tuple(func_params).index("session") + except ValueError: + raise ValueError(f"Function {func.__qualname__} has no `session` argument") from None + + # We don't need this anymore -- ensure we don't keep a reference to it by mistake + del func_params + + @wraps(func) + def wrapper(*args, **kwargs) -> RT: + if "session" in kwargs or session_args_idx < len(args): + return func(*args, **kwargs) + else: + bind = Bind.get_instance(bind_name) + + with create_session(bind) as session: + return func(*args, session=session, **kwargs) + + return wrapper + + return decorator diff --git a/examples/repository/base.py b/examples/repository/base.py index ec70e39..3ae0d2e 100644 --- a/examples/repository/base.py +++ b/examples/repository/base.py @@ -1,5 +1,6 @@ from __future__ import annotations +import operator import typing as t from abc import ABCMeta from abc import abstractmethod @@ -14,42 +15,49 @@ from quart_sqlalchemy.types import ColumnExpr from quart_sqlalchemy.types import EntityIdT from quart_sqlalchemy.types import EntityT +from quart_sqlalchemy.types import Operator from quart_sqlalchemy.types import ORMOption from quart_sqlalchemy.types import Selectable +from quart_sqlalchemy.types import SessionT sa = sqlalchemy -class AbstractRepository(t.Generic[EntityT, EntityIdT], metaclass=ABCMeta): +class AbstractRepository(t.Generic[EntityT, EntityIdT, SessionT], metaclass=ABCMeta): """A repository interface.""" - identity: t.Type[EntityIdT] + # entity: t.Type[EntityT] - # def __init__(self, model: t.Type[EntityT]): - # self.model = model + # def __init__(self, entity: t.Type[EntityT]): + # self.entity = entity @property - def model(self) -> EntityT: + def entity(self) -> EntityT: return self.__orig_class__.__args__[0] @abstractmethod - def insert(self, values: t.Dict[str, t.Any]) -> EntityT: + def insert(self, session: SessionT, values: t.Dict[str, t.Any]) -> EntityT: """Add `values` to the collection.""" @abstractmethod - def update(self, id_: EntityIdT, values: t.Dict[str, t.Any]) -> EntityT: + def update(self, session: SessionT, id_: EntityIdT, 
values: t.Dict[str, t.Any]) -> EntityT: """Update model with model_id using values.""" @abstractmethod def merge( - self, id_: EntityIdT, values: t.Dict[str, t.Any], for_update: bool = False + self, + session: SessionT, + id_: EntityIdT, + values: t.Dict[str, t.Any], + for_update: bool = False, ) -> EntityT: """Merge model with model_id using values.""" @abstractmethod def get( self, + session: SessionT, id_: EntityIdT, options: t.Sequence[ORMOption] = (), execution_options: t.Optional[t.Dict[str, t.Any]] = None, @@ -58,9 +66,28 @@ def get( ) -> t.Optional[EntityT]: """Get model with model_id.""" + @abstractmethod + def get_by_field( + self, + session: SessionT, + field: t.Union[ColumnExpr, str], + value: t.Any, + op: Operator = operator.eq, + order_by: t.Sequence[t.Union[ColumnExpr, str]] = (), + options: t.Sequence[ORMOption] = (), + execution_options: t.Optional[t.Dict[str, t.Any]] = None, + offset: t.Optional[int] = None, + limit: t.Optional[int] = None, + distinct: bool = False, + for_update: bool = False, + include_inactive: bool = False, + ) -> sa.ScalarResult[EntityT]: + """Select models where field is equal to value.""" + @abstractmethod def select( self, + session: SessionT, selectables: t.Sequence[Selectable] = (), conditions: t.Sequence[ColumnExpr] = (), group_by: t.Sequence[t.Union[ColumnExpr, str]] = (), @@ -77,12 +104,13 @@ def select( """Select models matching conditions.""" @abstractmethod - def delete(self, id_: EntityIdT) -> None: + def delete(self, session: SessionT, id_: EntityIdT) -> None: """Delete model with id_.""" @abstractmethod def exists( self, + session: SessionT, conditions: t.Sequence[ColumnExpr] = (), for_update: bool = False, include_inactive: bool = False, @@ -90,27 +118,31 @@ def exists( """Return the existence of an object matching conditions.""" @abstractmethod - def deactivate(self, id_: EntityIdT) -> EntityT: + def deactivate(self, session: SessionT, id_: EntityIdT) -> EntityT: """Soft-Delete model with id_.""" @abstractmethod - def reactivate(self, id_: EntityIdT) -> EntityT: + def reactivate(self, session: SessionT, id_: EntityIdT) -> EntityT: """Soft-Delete model with id_.""" -class AbstractBulkRepository(t.Generic[EntityT, EntityIdT], metaclass=ABCMeta): +class AbstractBulkRepository(t.Generic[EntityT, EntityIdT, SessionT], metaclass=ABCMeta): """A repository interface for bulk operations. Note: this interface circumvents ORM internals, breaking commonly expected behavior in order to gain performance benefits. Only use this class whenever absolutely necessary. 
""" - model: t.Type[EntityT] builder: StatementBuilder + @property + def entity(self) -> EntityT: + return self.__orig_class__.__args__[0] + @abstractmethod def bulk_insert( self, + session: SessionT, values: t.Sequence[t.Dict[str, t.Any]] = (), execution_options: t.Optional[t.Dict[str, t.Any]] = None, ) -> sa.Result[t.Any]: @@ -119,6 +151,7 @@ def bulk_insert( @abstractmethod def bulk_update( self, + session: SessionT, conditions: t.Sequence[ColumnExpr] = (), values: t.Optional[t.Dict[str, t.Any]] = None, execution_options: t.Optional[t.Dict[str, t.Any]] = None, @@ -128,6 +161,7 @@ def bulk_update( @abstractmethod def bulk_delete( self, + session: SessionT, conditions: t.Sequence[ColumnExpr] = (), execution_options: t.Optional[t.Dict[str, t.Any]] = None, ) -> sa.Result[t.Any]: diff --git a/examples/repository/sqla.py b/examples/repository/sqla.py index 417ddaf..5f77dae 100644 --- a/examples/repository/sqla.py +++ b/examples/repository/sqla.py @@ -1,5 +1,6 @@ from __future__ import annotations +import operator import typing as t import sqlalchemy @@ -15,6 +16,7 @@ from quart_sqlalchemy.types import ColumnExpr from quart_sqlalchemy.types import EntityIdT from quart_sqlalchemy.types import EntityT +from quart_sqlalchemy.types import Operator from quart_sqlalchemy.types import ORMOption from quart_sqlalchemy.types import Selectable from quart_sqlalchemy.types import SessionT @@ -25,8 +27,8 @@ class SQLAlchemyRepository( TableMetadataMixin, - AbstractRepository[EntityT, EntityIdT], - t.Generic[EntityT, EntityIdT], + AbstractRepository[EntityT, EntityIdT, SessionT], + t.Generic[EntityT, EntityIdT, SessionT], ): """A repository that uses SQLAlchemy to persist data. @@ -53,7 +55,7 @@ class SQLAlchemyRepository( session: sa.orm.Session builder: StatementBuilder - def __init__(self, session: sa.orm.Session, **kwargs): + def __init__(self, model: sa.orm.Session, **kwargs): super().__init__(**kwargs) self.session = session self.builder = StatementBuilder(None) @@ -125,6 +127,46 @@ def get( return self.session.scalars(statement, execution_options=execution_options).one_or_none() + def get_by_field( + self, + field: t.Union[ColumnExpr, str], + value: t.Any, + op: Operator = operator.eq, + order_by: t.Sequence[t.Union[ColumnExpr, str]] = (), + options: t.Sequence[ORMOption] = (), + execution_options: t.Optional[t.Dict[str, t.Any]] = None, + offset: t.Optional[int] = None, + limit: t.Optional[int] = None, + distinct: bool = False, + for_update: bool = False, + include_inactive: bool = False, + ) -> sa.ScalarResult[EntityT]: + """Select models where field is equal to value.""" + selectables = (self.model,) # type: ignore + + execution_options = execution_options or {} + if include_inactive: + execution_options.setdefault("include_inactive", include_inactive) + + if isinstance(field, str): + field = getattr(self.model, field) + + conditions = [t.cast(ColumnExpr, op(field, value))] + + statement = self.builder.complex_select( + selectables, + conditions=conditions, + order_by=order_by, + options=options, + execution_options=execution_options, + offset=offset, + limit=limit, + distinct=distinct, + for_update=for_update, + ) + + return self.session.scalars(statement) + def select( self, selectables: t.Sequence[Selectable] = (), diff --git a/examples/usrsrv/component/__init__.py b/examples/usrsrv/component/__init__.py new file mode 100644 index 0000000..21b27e9 --- /dev/null +++ b/examples/usrsrv/component/__init__.py @@ -0,0 +1,23 @@ +from . import commands +from . import events +from . 
import exceptions +from .app import handler +from .entity import EntityID +from .service import CommandHandler +from .service import Listener + + +handle = handler.handle +register = handler.register +unregister = handler.unregister + +__all__ = [ + "commands", + "events", + "exceptions", + "EntityID", + "CommandHandler", + "handle", + "register", + "unregister", +] diff --git a/examples/usrsrv/component/app.py b/examples/usrsrv/component/app.py new file mode 100644 index 0000000..b18baee --- /dev/null +++ b/examples/usrsrv/component/app.py @@ -0,0 +1,10 @@ +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +from .repository import ORMRepository +from .service import CommandHandler + + +some_engine = create_engine("sqlite:///") +Session = sessionmaker(bind=some_engine) +handler = CommandHandler(ORMRepository(Session())) diff --git a/examples/usrsrv/component/command.py b/examples/usrsrv/component/command.py new file mode 100644 index 0000000..47f59b0 --- /dev/null +++ b/examples/usrsrv/component/command.py @@ -0,0 +1,43 @@ +""" +Commands +======== +A command is always DTO and as specific, as it can be from a domain perspective. I aim to create +separate classes for commands so I can just dispatch handlers by command class. + +```python +@dataclass +class Create(Command): + command_id: CommandID = field(default_factory=uuid1) + timestamp: datetime = field(default_factory=datetime.utcnow) +``` +""" + +from abc import ABC +from dataclasses import dataclass, field +from datetime import datetime +from typing import Text +from uuid import UUID, uuid1 + +from .entity import EntityID + +CommandID = UUID + + +class Command(ABC): + entity_id: EntityID + command_id: CommandID + timestamp: datetime + + +@dataclass +class Create(Command): + command_id: CommandID = field(default_factory=uuid1) + timestamp: datetime = field(default_factory=datetime.utcnow) + + +@dataclass +class UpdateValue(Command): + entity_id: EntityID + value: Text + command_id: CommandID = field(default_factory=uuid1) + timestamp: datetime = field(default_factory=datetime.utcnow) diff --git a/examples/usrsrv/component/entity.py b/examples/usrsrv/component/entity.py new file mode 100644 index 0000000..a910ad2 --- /dev/null +++ b/examples/usrsrv/component/entity.py @@ -0,0 +1,39 @@ +from typing import NewType +from typing import Optional +from typing import Text +from uuid import UUID +from uuid import uuid1 + + +EntityID = NewType("EntityID", UUID) + + +class EntityDTO: + id: EntityID + value: Optional[Text] + + +class Entity: + id: EntityID + dto: EntityDTO + + class Event: + pass + + class Updated(Event): + pass + + def __init__(self, dto: EntityDTO) -> None: + self.id = dto.id + self.dto = dto + + @classmethod + def create(cls) -> "Entity": + dto = EntityDTO() + dto.id = EntityID(uuid1()) + dto.value = None + return Entity(dto) + + def update(self, value: Text) -> Updated: + self.dto.value = value + return self.Updated() diff --git a/examples/usrsrv/component/event.py b/examples/usrsrv/component/event.py new file mode 100644 index 0000000..4dd2895 --- /dev/null +++ b/examples/usrsrv/component/event.py @@ -0,0 +1,45 @@ +from dataclasses import dataclass +from dataclasses import field +from datetime import datetime +from functools import singledispatch +from uuid import UUID +from uuid import uuid1 + +from .command import Command +from .command import CommandID +from .entity import Entity +from .entity import EntityID + + +EventID = UUID + + +class Event: + command_id: CommandID + event_id: EventID = 
field(default_factory=uuid1) + timestamp: datetime = field(default_factory=datetime.utcnow) + + +@dataclass +class Created(Event): + command_id: CommandID + uow_id: EntityID + event_id: EventID = field(default_factory=uuid1) + timestamp: datetime = field(default_factory=datetime.utcnow) + + +@dataclass +class Updated(Event): + command_id: CommandID + event_id: EventID = field(default_factory=uuid1) + timestamp: datetime = field(default_factory=datetime.utcnow) + + +@singledispatch +def app_event(event: Entity.Event, command: Command) -> Event: + raise NotImplementedError + + +@app_event.register(Entity.Updated) +def _(event: Entity.Updated, command: Command) -> Updated: + return Updated(command.command_id) diff --git a/examples/usrsrv/component/exception.py b/examples/usrsrv/component/exception.py new file mode 100644 index 0000000..f88096f --- /dev/null +++ b/examples/usrsrv/component/exception.py @@ -0,0 +1,2 @@ +class NotFound(Exception): + pass diff --git a/examples/usrsrv/component/migrations/2020-04-15_ddd_component_unitofwork_18fd763a02a4.py b/examples/usrsrv/component/migrations/2020-04-15_ddd_component_unitofwork_18fd763a02a4.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/usrsrv/component/repository.py b/examples/usrsrv/component/repository.py new file mode 100644 index 0000000..44dffcd --- /dev/null +++ b/examples/usrsrv/component/repository.py @@ -0,0 +1,63 @@ +from sqlalchemy import Column +from sqlalchemy import Integer +from sqlalchemy import MetaData +from sqlalchemy import select +from sqlalchemy import String +from sqlalchemy import Table +from sqlalchemy.orm import registry +from sqlalchemy.orm import Session + +from . import EntityID +from .entity import Entity +from .entity import EntityDTO +from .exception import NotFound +from .service import Repository + + +metadata = MetaData() +mapper_registry = registry(metadata=metadata) + + +entities_table = Table( + "entities", + metadata, + Column("id", Integer, primary_key=True, autoincrement=True), + Column("uuid", String, unique=True, index=True), + Column("value", String, nullable=True), +) + +# EntityMapper = mapper( +# EntityDTO, +# entities_table, +# properties={ +# "id": entities_table.c.uuid, +# "value": entities_table.c.value, +# }, +# column_prefix="_db_column_", +# ) + +EntityMapper = mapper_registry.map_imperatively( + EntityDTO, + entities_table, + properties={ + "id": entities_table.c.uuid, + "value": entities_table.c.value, + }, + column_prefix="_db_column_", +) + + +class ORMRepository(Repository): + def __init__(self, session: Session): + self._session = session + self._query = select(EntityMapper) + + def get(self, entity_id: EntityID) -> Entity: + dto = self._session.scalars(self._query.filter_by(uuid=entity_id)).one_or_none() + if not dto: + raise NotFound(entity_id) + return Entity(dto) + + def save(self, entity: Entity) -> None: + self._session.add(entity.dto) + self._session.flush() diff --git a/examples/usrsrv/component/service.py b/examples/usrsrv/component/service.py new file mode 100644 index 0000000..9846736 --- /dev/null +++ b/examples/usrsrv/component/service.py @@ -0,0 +1,68 @@ +from abc import ABC +from abc import abstractmethod +from functools import singledispatch +from typing import Callable +from typing import List +from typing import Optional + +from .command import Command +from .command import Create +from .command import UpdateValue +from .entity import Entity +from .entity import EntityID +from .event import app_event +from .event import Created +from .event 
import Event + + +Listener = Callable[[Event], None] + + +class Repository(ABC): + @abstractmethod + def get(self, entity_id: EntityID) -> Entity: + raise NotImplementedError + + @abstractmethod + def save(self, entity: Entity) -> None: + raise NotImplementedError + + +class CommandHandler: + def __init__(self, repository: Repository) -> None: + self._repository = repository + self._listeners: List[Listener] = [] + super().__init__() + + def register(self, listener: Listener) -> None: + if listener not in self._listeners: + self._listeners.append(listener) + + def unregister(self, listener: Listener) -> None: + if listener in self._listeners: + self._listeners.remove(listener) + + @singledispatch + def handle(self, command: Command) -> Optional[Event]: + entity: Entity = self._repository.get(command.entity_id) + + event: Event = app_event(self._handle(command, entity), command) + for listener in self._listeners: + listener(event) + + self._repository.save(entity) + return event + + @handle.register(Create) + def create(self, command: Create) -> Event: + entity = Entity.create() + self._repository.save(entity) + return Created(command.command_id, entity.id) + + @singledispatch + def _handle(self, c: Command, u: Entity) -> Entity.Event: + raise NotImplementedError + + @_handle.register(UpdateValue) + def _(self, command: UpdateValue, entity: Entity) -> Entity.Event: + return entity.update(command.value) diff --git a/pyproject.toml b/pyproject.toml index eb2e1be..17e8c4b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,8 @@ dependencies = [ "pydantic", "tenacity", "sqlapagination", - "exceptiongroup" + "exceptiongroup", + "python-ulid" ] requires-python = ">=3.7" readme = "README.rst" @@ -31,7 +32,7 @@ build-backend = "setuptools.build_meta" [project.optional-dependencies] sim = [ - "quart-schema", "hashids" + "quart-schema", "hashids", "web3", "dependency-injector", ] tests = [ "pytest", @@ -121,7 +122,7 @@ ignore_missing_imports = true [tool.pylint.messages_control] max-line-length = 100 -disable = ["missing-docstring", "protected-access"] +disable = ["invalid-name", "missing-docstring", "protected-access"] [tool.flakeheaven] baseline = ".flakeheaven_baseline" diff --git a/setup.cfg b/setup.cfg index ee9c148..0c8f6eb 100644 --- a/setup.cfg +++ b/setup.cfg @@ -27,6 +27,7 @@ ignore = WPS463 allowed-domain-names = + db value val vals diff --git a/src/quart_sqlalchemy/__init__.py b/src/quart_sqlalchemy/__init__.py index b85cd50..28059e4 100644 --- a/src/quart_sqlalchemy/__init__.py +++ b/src/quart_sqlalchemy/__init__.py @@ -1,5 +1,5 @@ __version__ = "3.0.2" - +from . import util from .bind import AsyncBind from .bind import Bind from .bind import BindContext diff --git a/src/quart_sqlalchemy/bind.py b/src/quart_sqlalchemy/bind.py index 8a69eaa..dce08c2 100644 --- a/src/quart_sqlalchemy/bind.py +++ b/src/quart_sqlalchemy/bind.py @@ -1,8 +1,12 @@ from __future__ import annotations import os +import threading import typing as t +from contextlib import asynccontextmanager from contextlib import contextmanager +from contextlib import ExitStack +from weakref import WeakValueDictionary import sqlalchemy import sqlalchemy.event @@ -11,6 +15,7 @@ import sqlalchemy.ext.asyncio import sqlalchemy.orm import sqlalchemy.util +import typing_extensions as tx from . 
import signals from .config import BindConfig @@ -21,8 +26,16 @@ sa = sqlalchemy +SqlAMode = tx.Literal["orm", "core"] + + +class BindNotInitialized(RuntimeError): + """ "Bind not initialized yet.""" + class BindBase: + name: t.Optional[str] + url: sa.URL config: BindConfig metadata: sa.MetaData engine: sa.Engine @@ -30,66 +43,157 @@ class BindBase: def __init__( self, - config: BindConfig, - metadata: sa.MetaData, + name: t.Optional[str] = None, + url: t.Union[sa.URL, str] = "sqlite://", + config: t.Optional[BindConfig] = None, + metadata: t.Optional[sa.MetaData] = None, ): - self.config = config - self.metadata = metadata - - @property - def url(self) -> str: - if not hasattr(self, "engine"): - raise RuntimeError("Database not initialized yet. Call initialize() first.") - return str(self.engine.url) + self.name = name + self.url = sa.make_url(url) + self.config = config or BindConfig.default() + self.metadata = metadata or sa.MetaData() @property def is_async(self) -> bool: - if not hasattr(self, "engine"): - raise RuntimeError("Database not initialized yet. Call initialize() first.") - return self.engine.url.get_dialect().is_async + return self.url.get_dialect().is_async @property - def is_read_only(self): + def is_read_only(self) -> bool: return self.config.read_only + def __repr__(self) -> str: + parts = [type(self).__name__] + if self.name: + parts.append(self.name) + if self.url: + parts.append(str(self.url)) + if self.is_read_only: + parts.append("[read-only]") + + return f"<{' '.join(parts)}>" + class BindContext(BindBase): pass class Bind(BindBase): + lock: threading.Lock + _instances: WeakValueDictionary = WeakValueDictionary() + def __init__( self, - config: BindConfig, - metadata: sa.MetaData, + name: t.Optional[str] = None, + url: t.Union[sa.URL, str] = "sqlite://", + config: t.Optional[BindConfig] = None, + metadata: t.Optional[sa.MetaData] = None, initialize: bool = True, + track_instance: bool = False, ): - self.config = config - self.metadata = metadata + super().__init__(name, url, config, metadata) + self._initialization_lock = threading.Lock() + + if track_instance: + self._track_instance(name) if initialize: self.initialize() - def initialize(self): - if hasattr(self, "engine"): - self.engine.dispose() + self._session_stack = [] - self.engine = self.create_engine( - self.config.engine.dict(exclude_unset=True, exclude_none=True), - prefix="", - ) - self.Session = self.create_session_factory( - self.config.session.dict(exclude_unset=True, exclude_none=True), - ) + def initialize(self) -> tx.Self: + with self._initialization_lock: + if hasattr(self, "engine"): + self.engine.dispose() + + engine_config = self.config.engine.dict(exclude_unset=True, exclude_none=True) + engine_config.setdefault("url", self.url) + self.engine = self.create_engine(engine_config, prefix="") + + session_options = self.config.session.dict(exclude_unset=True, exclude_none=True) + self.Session = self.create_session_factory(session_options) return self + def _track_instance(self, name): + if name is None: + return + + if name in Bind._instances: + raise ValueError("Bind instance `{name}` already exists, use another name.") + else: + Bind._instances[name] = self + + @classmethod + def get_instance(cls, name: str = "default") -> Bind: + """Get the singleton instance having `name`. + + This enables some really cool patterns similar to how logging allows getting an already + initialized logger from anywhere without importing it directly. 
Features like this are + most useful when working in web frameworks like flask and quart that are more prone to + circular dependency issues. + + Example: + app/db.py: + from quart_sqlalchemy import Bind + + default = Bind("default", url="sqlite://") + + with default.Session() as session: + with session.begin(): + session.add(User()) + + + app/views/v1/user/login.py + from quart_sqlalchemy import Bind + + # get the same `default` bind already instantiated in app/db.py + default = Bind.get_instance("default") + + with default.Session() as session: + with session.begin(): + session.add(User()) + ... + """ + try: + return Bind._instances[name]() + except KeyError as err: + raise ValueError(f"Bind instance `{name}` does not exist.") from err + + @t.overload + @contextmanager + def transaction(self, mode: SqlAMode = "orm") -> t.Generator[sa.orm.Session, None, None]: + ... + + @t.overload + @contextmanager + def transaction(self, mode: SqlAMode = "core") -> t.Generator[sa.Connection, None, None]: + ... + + @contextmanager + def transaction( + self, mode: SqlAMode = "orm" + ) -> t.Generator[t.Union[sa.orm.Session, sa.Connection], None, None]: + if mode == "orm": + with self.Session() as session: + with session.begin(): + yield session + elif mode == "core": + with self.engine.connect() as connection: + with connection.begin(): + yield connection + else: + raise ValueError(f"Invalid transaction mode `{mode}`") + + def test_transaction(self, savepoint: bool = False) -> TestTransaction: + return TestTransaction(self, savepoint=savepoint) + @contextmanager def context( self, engine_execution_options: t.Optional[t.Dict[str, t.Any]] = None, session_execution__options: t.Optional[t.Dict[str, t.Any]] = None, ) -> t.Generator[BindContext, None, None]: - context = BindContext(self.config, self.metadata) + context = BindContext(f"{self.name}-context", self.url, self.config, self.metadata) context.engine = self.engine.execution_options(**engine_execution_options or {}) context.Session = self.create_session_factory(session_execution__options or {}) context.Session.configure(bind=context.engine) @@ -110,7 +214,7 @@ def context( ) def create_session_factory( - self, options: dict[str, t.Any] + self, options: t.Dict[str, t.Any] ) -> sa.orm.sessionmaker[sa.orm.Session]: signals.before_bind_session_factory_created.send(self, options=options) session_factory = sa.orm.sessionmaker(bind=self.engine, **options) @@ -125,9 +229,6 @@ def create_engine(self, config: t.Dict[str, t.Any], prefix: str = "") -> sa.Engi signals.after_bind_engine_created.send(self, config=config, prefix=prefix, engine=engine) return engine - def test_transaction(self, savepoint: bool = False): - return TestTransaction(self, savepoint=savepoint) - def _call_metadata(self, method: str): with self.engine.connect() as conn: with conn.begin(): @@ -142,14 +243,27 @@ def drop_all(self): def reflect(self): return self._call_metadata("reflect") - def __repr__(self) -> str: - return f"<{type(self).__name__} {self.engine.url}>" - class AsyncBind(Bind): engine: sa.ext.asyncio.AsyncEngine Session: sa.ext.asyncio.async_sessionmaker + @asynccontextmanager + async def transaction(self, mode: SqlAMode = "orm"): + if mode == "orm": + async with self.Session() as session: + async with session.begin(): + yield session + elif mode == "core": + async with self.engine.connect() as connection: + async with connection.begin(): + yield connection + else: + raise ValueError(f"Invalid transaction mode `{mode}`") + + def test_transaction(self, savepoint: bool = False): + 
return AsyncTestTransaction(self, savepoint=savepoint) + def create_session_factory( self, options: dict[str, t.Any] ) -> sa.ext.asyncio.async_sessionmaker[sa.ext.asyncio.AsyncSession]: @@ -182,9 +296,6 @@ def create_engine( signals.after_bind_engine_created.send(self, config=config, prefix=prefix, engine=engine) return engine - def test_transaction(self, savepoint: bool = False): - return AsyncTestTransaction(self, savepoint=savepoint) - async def _call_metadata(self, method: str): async with self.engine.connect() as conn: async with conn.begin(): diff --git a/src/quart_sqlalchemy/config.py b/src/quart_sqlalchemy/config.py index 5caa0ec..0190160 100644 --- a/src/quart_sqlalchemy/config.py +++ b/src/quart_sqlalchemy/config.py @@ -1,6 +1,5 @@ from __future__ import annotations -import json import os import types import typing as t @@ -11,6 +10,7 @@ import sqlalchemy.ext import sqlalchemy.ext.asyncio import sqlalchemy.orm +import sqlalchemy.sql.sqltypes import sqlalchemy.util import typing_extensions as tx from pydantic import BaseModel @@ -22,6 +22,8 @@ from .model import Base from .types import BoundParamStyle from .types import DMLStrategy +from .types import Empty +from .types import EmptyType from .types import SessionBind from .types import SessionBindKey from .types import SynchronizeSession @@ -63,9 +65,9 @@ class ConfigBase(BaseModel): class Config: arbitrary_types_allowed = True - @classmethod - def default(cls): - return cls() + @root_validator + def scrub_empty(cls, values): + return {key: val for key, val in values.items() if val not in [Empty, {}]} class CoreExecutionOptions(ConfigBase): @@ -73,15 +75,15 @@ class CoreExecutionOptions(ConfigBase): https://docs.sqlalchemy.org/en/20/core/connections.html#sqlalchemy.engine.Connection.execution_options """ - isolation_level: t.Optional[TransactionIsolationLevel] = None - compiled_cache: t.Optional[t.Dict[t.Any, Compiled]] = Field(default_factory=dict) - logging_token: t.Optional[str] = None - no_parameters: bool = False - stream_results: bool = False - max_row_buffer: int = 1000 - yield_per: t.Optional[int] = None - insertmanyvalues_page_size: int = 1000 - schema_translate_map: t.Optional[t.Dict[str, str]] = None + isolation_level: t.Union[TransactionIsolationLevel, EmptyType] = Empty + compiled_cache: t.Union[t.Dict[t.Any, Compiled], None, EmptyType] = Empty + logging_token: t.Union[str, None, EmptyType] = Empty + no_parameters: t.Union[bool, EmptyType] = Empty + stream_results: t.Union[bool, EmptyType] = Empty + max_row_buffer: t.Union[int, EmptyType] = Empty + yield_per: t.Union[int, None, EmptyType] = Empty + insertmanyvalues_page_size: t.Union[int, EmptyType] = Empty + schema_translate_map: t.Union[t.Dict[str, str], None, EmptyType] = Empty class ORMExecutionOptions(ConfigBase): @@ -89,14 +91,21 @@ class ORMExecutionOptions(ConfigBase): https://docs.sqlalchemy.org/en/20/orm/queryguide/api.html#orm-queryguide-execution-options """ - isolation_level: t.Optional[TransactionIsolationLevel] = None - stream_results: bool = False - yield_per: t.Optional[int] = None - populate_existing: bool = False - autoflush: bool = True - identity_token: t.Optional[str] = None - synchronize_session: SynchronizeSession = "auto" - dml_strategy: DMLStrategy = "auto" + isolation_level: t.Union[TransactionIsolationLevel, EmptyType] = Empty + stream_results: t.Union[bool, EmptyType] = Empty + yield_per: t.Union[int, None, EmptyType] = Empty + populate_existing: t.Union[bool, EmptyType] = Empty + autoflush: t.Union[bool, EmptyType] = Empty + 
identity_token: t.Union[str, None, EmptyType] = Empty + synchronize_session: t.Union[SynchronizeSession, None, EmptyType] = Empty + dml_strategy: t.Union[DMLStrategy, None, EmptyType] = Empty + + +# connect_args: +# mysql: +# connect_timeout: +# postgres: +# connect_timeout: class EngineConfig(ConfigBase): @@ -104,42 +113,55 @@ class EngineConfig(ConfigBase): https://docs.sqlalchemy.org/en/20/core/engines.html#sqlalchemy.create_engine """ - url: t.Union[sa.URL, str] = "sqlite://" - echo: bool = False - echo_pool: bool = False - connect_args: t.Dict[str, t.Any] = Field(default_factory=dict) + url: t.Union[sa.URL, str, EmptyType] = Empty + echo: t.Union[bool, EmptyType] = Empty + echo_pool: t.Union[bool, EmptyType] = Empty + connect_args: t.Union[t.Dict[str, t.Any], EmptyType] = Empty execution_options: CoreExecutionOptions = Field(default_factory=CoreExecutionOptions) - enable_from_linting: bool = True - hide_parameters: bool = False - insertmanyvalues_page_size: int = 1000 - isolation_level: t.Optional[TransactionIsolationLevel] = None - json_deserializer: t.Callable[[str], t.Any] = json.loads - json_serializer: t.Callable[[t.Any], str] = json.dumps - label_length: t.Optional[int] = None - logging_name: t.Optional[str] = None - max_identifier_length: t.Optional[int] = None - max_overflow: int = 10 - module: t.Optional[types.ModuleType] = None - paramstyle: t.Optional[BoundParamStyle] = None - pool: t.Optional[sa.Pool] = None - poolclass: t.Optional[t.Type[sa.Pool]] = None - pool_logging_name: t.Optional[str] = None - pool_pre_ping: bool = False - pool_size: int = 5 - pool_recycle: int = -1 - pool_reset_on_return: t.Optional[tx.Literal["values", "rollback"]] = None - pool_timeout: int = 40 - pool_use_lifo: bool = False - plugins: t.Sequence[str] = Field(default_factory=list) - query_cache_size: int = 500 - user_insertmanyvalues: bool = True + enable_from_linting: t.Union[bool, EmptyType] = Empty + hide_parameters: t.Union[bool, EmptyType] = Empty + insertmanyvalues_page_size: t.Union[int, EmptyType] = Empty + isolation_level: t.Union[TransactionIsolationLevel, EmptyType] = Empty + json_deserializer: t.Union[t.Callable[[str], t.Any], EmptyType] = Empty + json_serializer: t.Union[t.Callable[[t.Any], str], EmptyType] = Empty + label_length: t.Union[int, None, EmptyType] = Empty + logging_name: t.Union[str, None, EmptyType] = Empty + max_identifier_length: t.Union[int, None, EmptyType] = Empty + max_overflow: t.Union[int, EmptyType] = Empty + module: t.Union[types.ModuleType, None, EmptyType] = Empty + paramstyle: t.Union[BoundParamStyle, None, EmptyType] = Empty + pool: t.Union[sa.Pool, None, EmptyType] = Empty + poolclass: t.Union[t.Type[sa.Pool], None, EmptyType] = Empty + pool_logging_name: t.Union[str, None, EmptyType] = Empty + pool_pre_ping: t.Union[bool, EmptyType] = Empty + pool_size: t.Union[int, EmptyType] = Empty + pool_recycle: t.Union[int, EmptyType] = Empty + pool_reset_on_return: t.Union[tx.Literal["values", "rollback"], None, EmptyType] = Empty + pool_timeout: t.Union[int, EmptyType] = Empty + pool_use_lifo: t.Union[bool, EmptyType] = Empty + plugins: t.Union[t.Sequence[str], EmptyType] = Empty + query_cache_size: t.Union[int, EmptyType] = Empty + user_insertmanyvalues: t.Union[bool, EmptyType] = Empty - @classmethod - def default(cls): - return cls(url="sqlite://") + @root_validator + def scrub_execution_options(cls, values): + if "execution_options" in values: + execute_options = values["execution_options"].dict(exclude_defaults=True) + if execute_options: + 
values["execution_options"] = execute_options + return values + + @root_validator + def set_defaults(cls, values): + values.setdefault("url", "sqlite://") + return values @root_validator def apply_driver_defaults(cls, values): + # values["execution_options"] = values["execution_options"].dict(exclude_defaults=True) + # values = {key: val for key, val in values.items() if val not in [Empty, {}]} + # values.setdefault("url", "sqlite://") + url = sa.make_url(values["url"]) driver = url.drivername @@ -177,19 +199,31 @@ def apply_driver_defaults(cls, values): return values +class AsyncEngineConfig(EngineConfig): + @root_validator + def set_defaults(cls, values): + values.setdefault("url", "sqlite+aiosqlite://") + return values + + class SessionOptions(ConfigBase): """ https://docs.sqlalchemy.org/en/20/orm/session_api.html#sqlalchemy.orm.Session """ - autoflush: bool = True - autobegin: bool = True - expire_on_commit: bool = False - bind: t.Optional[SessionBind] = None - binds: t.Optional[t.Dict[SessionBindKey, SessionBind]] = None - twophase: bool = False - info: t.Optional[t.Dict[t.Any, t.Any]] = None - join_transaction_mode: JoinTransactionMode = "conditional_savepoint" + autoflush: t.Union[bool, EmptyType] = Empty + autobegin: t.Union[bool, EmptyType] = Empty + expire_on_commit: t.Union[bool, EmptyType] = Empty + bind: t.Union[SessionBind, None, EmptyType] = Empty + binds: t.Union[t.Dict[SessionBindKey, SessionBind], None, EmptyType] = Empty + twophase: t.Union[bool, EmptyType] = Empty + info: t.Union[t.Dict[t.Any, t.Any], None, EmptyType] = Empty + join_transaction_mode: t.Union[JoinTransactionMode, EmptyType] = Empty + + @root_validator + def set_defaults(cls, values): + values.setdefault("expire_on_commit", False) + return values class SessionmakerOptions(SessionOptions): @@ -197,7 +231,7 @@ class SessionmakerOptions(SessionOptions): https://docs.sqlalchemy.org/en/20/orm/session_api.html#sqlalchemy.orm.sessionmaker """ - class_: t.Type[sa.orm.Session] = sa.orm.Session + class_: t.Union[t.Type[sa.orm.Session], EmptyType] = Empty class AsyncSessionOptions(SessionOptions): @@ -205,7 +239,7 @@ class AsyncSessionOptions(SessionOptions): https://docs.sqlalchemy.org/en/20/orm/extensions/asyncio.html#sqlalchemy.ext.asyncio.AsyncSession """ - sync_session_class: t.Type[sa.orm.Session] = sa.orm.Session + sync_session_class: t.Union[t.Type[sa.orm.Session], EmptyType] = Empty class AsyncSessionmakerOptions(AsyncSessionOptions): @@ -213,13 +247,14 @@ class AsyncSessionmakerOptions(AsyncSessionOptions): https://docs.sqlalchemy.org/en/20/orm/extensions/asyncio.html#sqlalchemy.ext.asyncio.async_sessionmaker """ - class_: t.Type[sa.ext.asyncio.AsyncSession] = sa.ext.asyncio.AsyncSession + class_: t.Union[t.Type[sa.ext.asyncio.AsyncSession], EmptyType] = Empty class BindConfig(ConfigBase): read_only: bool = False - session: SessionmakerOptions = Field(default_factory=SessionmakerOptions.default) - engine: EngineConfig = Field(default_factory=EngineConfig.default) + session: SessionmakerOptions = Field(default_factory=SessionmakerOptions) + engine: EngineConfig = Field(default_factory=EngineConfig) + track_instance: bool = False @root_validator def validate_dialect(cls, values): @@ -227,30 +262,24 @@ def validate_dialect(cls, values): class AsyncBindConfig(BindConfig): - session: AsyncSessionmakerOptions = Field(default_factory=AsyncSessionmakerOptions.default) + session: AsyncSessionmakerOptions = Field(default_factory=AsyncSessionmakerOptions) + engine: AsyncEngineConfig = 
Field(default_factory=AsyncEngineConfig) @root_validator def validate_dialect(cls, values): return validate_dialect(cls, values, "async") -def default(): - dict(default=dict()) - - class SQLAlchemyConfig(ConfigBase): - class Meta: - web_config_field_map = { - "SQLALCHEMY_MODEL_CLASS": "model_class", - "SQLALCHEMY_BINDS": "binds", - } + base_class: t.Type[t.Any] = Base + binds: t.Dict[str, t.Union[BindConfig, AsyncBindConfig]] = Field(default_factory=dict) - model_class: t.Type[t.Any] = Base - binds: t.Dict[str, t.Union[AsyncBindConfig, BindConfig]] = Field( - default_factory=lambda: dict(default=BindConfig()) - ) + @root_validator + def set_default_bind(cls, values): + values.setdefault("binds", dict(default=BindConfig())) + return values @classmethod - def from_framework(cls, values: t.Dict[str, t.Any]): - key_map = cls.Meta.web_config_field_map - return cls(**{key_map.get(key, key): val for key, val in values.items()}) + def from_framework(cls, framework_config): + config = framework_config.get_namespace("SQLALCHEMY_") + return cls.parse_obj(config or {}) diff --git a/src/quart_sqlalchemy/framework/cli.py b/src/quart_sqlalchemy/framework/cli.py index 13ac0af..6f48455 100644 --- a/src/quart_sqlalchemy/framework/cli.py +++ b/src/quart_sqlalchemy/framework/cli.py @@ -1,28 +1,83 @@ import json import sys +import typing as t import urllib.parse import click -from quart import current_app from quart.cli import AppGroup +from quart.cli import pass_script_info +from quart.cli import ScriptInfo + +from quart_sqlalchemy import signals + + +if t.TYPE_CHECKING: + from quart_sqlalchemy.framework import QuartSQLAlchemy db_cli = AppGroup("db") +fixtures_cli = AppGroup("fixtures") -@db_cli.command("info", with_appcontext=True) +@db_cli.command("info") +@pass_script_info @click.option("--uri-only", is_flag=True, default=False, help="Only output the connection uri") -def db_info(uri_only=False): - db = current_app.extensions["sqlalchemy"].db - uri = urllib.parse.unquote(str(db.engine.url)) - db_info = dict(db.engine.url._asdict()) +def db_info(info: ScriptInfo, uri_only=False): + app = info.load_app() + db: "QuartSQLAlchemy" = app.extensions["sqlalchemy"] + uri = urllib.parse.unquote(str(db.bind.url)) + info = dict(db.bind.url._asdict()) if uri_only: click.echo(uri) sys.exit(0) click.echo("Database Connection Info") - click.echo(json.dumps(db_info, indent=2)) + click.echo(json.dumps(info, indent=2)) click.echo("\n") click.echo("Connection URI") click.echo(uri) + + +@db_cli.command("create") +@pass_script_info +def create(info: ScriptInfo) -> None: + app = info.load_app() + db: "QuartSQLAlchemy" = app.extensions["sqlalchemy"] + db.create_all() + + click.echo(f"Initialized database schema for {db}") + + +@db_cli.command("drop") +@pass_script_info +def drop(info: ScriptInfo) -> None: + app = info.load_app() + db: "QuartSQLAlchemy" = app.extensions["sqlalchemy"] + db.drop_all() + + click.echo(f"Dropped database schema for {db}") + + +@db_cli.command("recreate") +@pass_script_info +def recreate(info: ScriptInfo) -> None: + app = info.load_app() + db: "QuartSQLAlchemy" = app.extensions["sqlalchemy"] + db.drop_all() + db.create_all() + + click.echo(f"Recreated database schema for {db}") + + +@fixtures_cli.command("load") +@pass_script_info +def load(info: ScriptInfo) -> None: + app = info.load_app() + db: "QuartSQLAlchemy" = app.extensions["sqlalchemy"] + signals.framework_extension_load_fixtures.send(sender=db, app=app) + + click.echo(f"Loaded database fixtures for {db}") + + +db_cli.add_command(fixtures_cli) 
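Taken together, the config changes above let the framework extension be configured entirely from the Quart app config namespace. A minimal sketch, assuming a throwaway app module (the app name and bind URL here are illustrative, not prescribed by the patch):

    from quart import Quart
    from quart_sqlalchemy.framework import QuartSQLAlchemy

    app = Quart(__name__)
    # Keys in the SQLALCHEMY_ namespace are picked up by
    # SQLAlchemyConfig.from_framework(app.config) via get_namespace("SQLALCHEMY_").
    app.config["SQLALCHEMY_BINDS"] = {
        "default": {"engine": {"url": "sqlite:///file:sim.db?cache=shared&uri=true"}},
    }

    # No explicit SQLAlchemyConfig passed: init_app() builds one from app.config
    # and registers the `quart db` CLI group (create/drop/recreate, fixtures load).
    db = QuartSQLAlchemy(app=app)
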
diff --git a/src/quart_sqlalchemy/framework/extension.py b/src/quart_sqlalchemy/framework/extension.py index 4d29d74..70f0737 100644 --- a/src/quart_sqlalchemy/framework/extension.py +++ b/src/quart_sqlalchemy/framework/extension.py @@ -11,10 +11,12 @@ class QuartSQLAlchemy(SQLAlchemy): def __init__( self, - config: SQLAlchemyConfig, + config: t.Optional[SQLAlchemyConfig] = None, app: t.Optional[Quart] = None, ): - super().__init__(config) + initialize = False if config is None else True + super().__init__(config, initialize=initialize) + if app is not None: self.init_app(app) @@ -24,15 +26,21 @@ def init_app(self, app: Quart) -> None: f"A {type(self).__name__} instance has already been registered on this app" ) + if self.config is None: + self.config = SQLAlchemyConfig.from_framework(app.config) + self.initialize() + signals.before_framework_extension_initialization.send(self, app=app) app.extensions["sqlalchemy"] = self @app.shell_context_processor def export_sqlalchemy_objects(): + nonlocal self + return dict( db=self, - **{m.class_.__name__: m.class_ for m in self.Model._sa_registry.mappers}, + **{m.class_.__name__: m.class_ for m in self.Base.registry.mappers}, ) app.cli.add_command(db_cli) diff --git a/src/quart_sqlalchemy/model/__init__.py b/src/quart_sqlalchemy/model/__init__.py index cf7a4ea..fce0fbf 100644 --- a/src/quart_sqlalchemy/model/__init__.py +++ b/src/quart_sqlalchemy/model/__init__.py @@ -1,7 +1,9 @@ -from .columns import CreatedTimestamp +from .columns import Created +from .columns import IntPK from .columns import Json -from .columns import PrimaryKey -from .columns import UpdatedTimestamp +from .columns import ULID +from .columns import Updated +from .columns import UUID from .custom_types import PydanticType from .custom_types import TZDateTime from .mixins import DynamicArgsMixin @@ -15,3 +17,6 @@ from .mixins import TimestampMixin from .mixins import VersionMixin from .model import Base +from .model import BaseMixins +from .model import default_metadata_naming_convention +from .model import default_type_annotation_map diff --git a/src/quart_sqlalchemy/model/columns.py b/src/quart_sqlalchemy/model/columns.py index c00d9eb..d1bfd94 100644 --- a/src/quart_sqlalchemy/model/columns.py +++ b/src/quart_sqlalchemy/model/columns.py @@ -2,6 +2,8 @@ import typing as t from datetime import datetime +from uuid import UUID +from uuid import uuid4 import sqlalchemy import sqlalchemy.event @@ -12,21 +14,24 @@ import sqlalchemy.util import sqlalchemy_utils import typing_extensions as tx +from ulid import ULID sa = sqlalchemy sau = sqlalchemy_utils +IntPK = tx.Annotated[int, sa.orm.mapped_column(primary_key=True, autoincrement=True)] +UUID = tx.Annotated[UUID, sa.orm.mapped_column(default=uuid4)] +ULID = tx.Annotated[ULID, sa.orm.mapped_column(default=ULID)] -PrimaryKey = tx.Annotated[int, sa.orm.mapped_column(sa.Identity(), primary_key=True)] -CreatedTimestamp = tx.Annotated[ +Created = tx.Annotated[ datetime, sa.orm.mapped_column( default=sa.func.now(), server_default=sa.FetchedValue(), ), ] -UpdatedTimestamp = tx.Annotated[ +Updated = tx.Annotated[ datetime, sa.orm.mapped_column( default=sa.func.now(), @@ -35,7 +40,8 @@ server_onupdate=sa.FetchedValue(), ), ] + Json = tx.Annotated[ t.Dict[t.Any, t.Any], - sa.orm.mapped_column(sau.JSONType, default_factory=dict), + sa.orm.mapped_column(sau.JSONType, default=dict), ] diff --git a/src/quart_sqlalchemy/model/mixins.py b/src/quart_sqlalchemy/model/mixins.py index 3dc3834..d08f446 100644 --- a/src/quart_sqlalchemy/model/mixins.py +++ 
b/src/quart_sqlalchemy/model/mixins.py @@ -10,6 +10,7 @@ import sqlalchemy.ext.asyncio import sqlalchemy.orm import sqlalchemy.util +import typing_extensions as tx from sqlalchemy.orm import Mapped from ..util import camel_to_snake_case @@ -18,14 +19,37 @@ sa = sqlalchemy +class ORMModel(tx.Protocol): + __table__: sa.Table + + +class SerializingModel(ORMModel): + __table__: sa.Table + + def to_dict( + self: ORMModel, + obj: t.Optional[t.Any] = None, + max_depth: int = 3, + _children_seen: t.Optional[set] = None, + _relations_seen: t.Optional[set] = None, + ) -> t.Dict[str, t.Any]: + ... + + class TableNameMixin: + __abstract__ = True + __table__: sa.Table + @sa.orm.declared_attr.directive - def __tablename__(cls) -> str: + def __tablename__(cls: t.Type[ORMModel]) -> str: return camel_to_snake_case(cls.__name__) class ReprMixin: - def __repr__(self) -> str: + __abstract__ = True + __table__: sa.Table + + def __repr__(self: ORMModel) -> str: state = sa.inspect(self) if state is None: return super().__repr__() @@ -41,7 +65,10 @@ def __repr__(self) -> str: class ComparableMixin: - def __eq__(self, other): + __abstract__ = True + __table__: sa.Table + + def __eq__(self: ORMModel, other: ORMModel) -> bool: if type(self).__name__ != type(other).__name__: return False @@ -55,37 +82,38 @@ def __eq__(self, other): class TotalOrderMixin: - def __lt__(self, other): - if type(self).__name__ != type(other).__name__: - return False + __abstract__ = True + __table__: sa.Table - for key, column in sa.inspect(type(self)).columns.items(): - if column.primary_key: - continue + def __lt__(self: ORMModel, other: ORMModel) -> bool: + if type(self).__name__ != type(other).__name__: + raise NotImplemented - if not (getattr(self, key) == getattr(other, key)): - return False - return True + primary_keys = sa.inspect(type(self)).primary_key + self_keys = [getattr(self, col.name) for col in primary_keys] + other_keys = [getattr(other, col.name) for col in primary_keys] + return self_keys < other_keys class SimpleDictMixin: __abstract__ = True __table__: sa.Table - def to_dict(self): + def to_dict(self) -> t.Dict[str, t.Any]: return {c.name: getattr(self, c.name) for c in self.__table__.columns} class RecursiveDictMixin: __abstract__ = True + __table__: sa.Table def to_dict( - self, + self: tx.Self, obj: t.Optional[t.Any] = None, - max_depth: int = 3, + max_depth: int = 1, _children_seen: t.Optional[set] = None, _relations_seen: t.Optional[set] = None, - ): + ) -> t.Dict[str, t.Any]: """Convert model to python dict, with recursion. 
Args: @@ -106,11 +134,7 @@ def to_dict( mapper = sa.inspect(obj).mapper columns = [column.key for column in mapper.columns] - get_key_value = ( - lambda c: (c, getattr(obj, c).isoformat()) - if isinstance(getattr(obj, c), datetime) - else (c, getattr(obj, c)) - ) + get_key_value = lambda c: (c, getattr(obj, c)) data = dict(map(get_key_value, columns)) if max_depth > 0: @@ -125,10 +149,12 @@ def to_dict( if relationship_children is not None: if relation.uselist: children = [] - for child in (c for c in relationship_children if c not in _children_seen): - _children_seen.add(child) + for child in ( + c for c in relationship_children if repr(c) not in _children_seen + ): + _children_seen.add(repr(child)) children.append( - self.model_to_dict( + self.to_dict( child, max_depth=max_depth - 1, _children_seen=_children_seen, @@ -137,7 +163,7 @@ def to_dict( ) data[name] = children else: - data[name] = self.model_to_dict( + data[name] = self.to_dict( relationship_children, max_depth=max_depth - 1, _children_seen=_children_seen, @@ -148,6 +174,9 @@ def to_dict( class IdentityMixin: + __abstract__ = True + __table__: sa.Table + id: Mapped[int] = sa.orm.mapped_column(sa.Identity(), primary_key=True, autoincrement=True) @@ -191,21 +220,29 @@ class User(db.Model, SoftDeleteMixin): """ __abstract__ = True + __table__: sa.Table is_active: Mapped[bool] = sa.orm.mapped_column(default=True) class TimestampMixin: __abstract__ = True + __table__: sa.Table - created_at: Mapped[datetime] = sa.orm.mapped_column(default=sa.func.now()) + created_at: Mapped[datetime] = sa.orm.mapped_column( + default=sa.func.now(), server_default=sa.FetchedValue() + ) updated_at: Mapped[datetime] = sa.orm.mapped_column( - default=sa.func.now(), onupdate=sa.func.now() + default=sa.func.now(), + onupdate=sa.func.now(), + server_default=sa.FetchedValue(), + server_onupdate=sa.FetchedValue(), ) class VersionMixin: __abstract__ = True + __table__: sa.Table version_id: Mapped[int] = sa.orm.mapped_column(nullable=False) @@ -222,6 +259,7 @@ class EagerDefaultsMixin: """ __abstract__ = True + __table__: sa.Table @sa.orm.declared_attr.directive def __mapper_args__(cls) -> dict[str, t.Any]: @@ -289,6 +327,7 @@ def accumulate_tuples_with_mapping(class_, attribute) -> t.Sequence[t.Any]: class DynamicArgsMixin: __abstract__ = True + __table__: sa.Table @sa.orm.declared_attr.directive def __mapper_args__(cls) -> t.Dict[str, t.Any]: diff --git a/src/quart_sqlalchemy/model/model.py b/src/quart_sqlalchemy/model/model.py index bbcdc21..313b6b3 100644 --- a/src/quart_sqlalchemy/model/model.py +++ b/src/quart_sqlalchemy/model/model.py @@ -1,6 +1,7 @@ from __future__ import annotations import enum +import uuid import sqlalchemy import sqlalchemy.event @@ -10,21 +11,49 @@ import sqlalchemy.orm import sqlalchemy.util import typing_extensions as tx +from sqlalchemy_utils import JSONType from .mixins import ComparableMixin from .mixins import DynamicArgsMixin +from .mixins import EagerDefaultsMixin +from .mixins import RecursiveDictMixin from .mixins import ReprMixin -from .mixins import SimpleDictMixin from .mixins import TableNameMixin +from .mixins import TotalOrderMixin sa = sqlalchemy - -class Base(DynamicArgsMixin, ReprMixin, SimpleDictMixin, ComparableMixin, TableNameMixin): +default_metadata_naming_convention = { + "ix": "ix_%(column_0_label)s", # INDEX + "uq": "uq_%(table_name)s_%(column_0_N_name)s", # UNIQUE + "ck": "ck_%(table_name)s_%(constraint_name)s", # CHECK + "fk": "fk_%(table_name)s_%(column_0_N_name)s_%(referred_table_name)s", # FOREIGN KEY 
+ "pk": "pk_%(table_name)s", # PRIMARY KEY +} + +default_type_annotation_map = { + enum.Enum: sa.Enum(enum.Enum, native_enum=False, validate_strings=True), + tx.Literal: sa.Enum(enum.Enum, native_enum=False, validate_strings=True), + uuid.UUID: sa.Uuid, + dict: JSONType, +} + + +class BaseMixins( + DynamicArgsMixin, + EagerDefaultsMixin, + ReprMixin, + RecursiveDictMixin, + TotalOrderMixin, + ComparableMixin, + TableNameMixin, +): __abstract__ = True + __table__: sa.Table + - type_annotation_map = { - enum.Enum: sa.Enum(enum.Enum, native_enum=False, validate_strings=True), - tx.Literal: sa.Enum(enum.Enum, native_enum=False, validate_strings=True), - } +class Base(BaseMixins, sa.orm.DeclarativeBase): + __abstract__ = True + metadata = sa.MetaData(naming_convention=default_metadata_naming_convention) + type_annotation_map = default_type_annotation_map diff --git a/src/quart_sqlalchemy/session.py b/src/quart_sqlalchemy/session.py index 4833a5c..36db6c3 100644 --- a/src/quart_sqlalchemy/session.py +++ b/src/quart_sqlalchemy/session.py @@ -1,6 +1,9 @@ from __future__ import annotations import typing as t +from contextlib import contextmanager +from contextvars import ContextVar +from functools import wraps import sqlalchemy import sqlalchemy.exc @@ -18,6 +21,56 @@ sa = sqlalchemy +""" +Requirements: + * a global context var session + * a context manager that sets the session value and manages its lifetime + * a factory that will always return the current session value + * a decorator that will inject the current session value +""" + +_global_contextual_session = ContextVar("_global_contextual_session") + + +@contextmanager +def set_global_contextual_session(session, bind=None): + token = _global_contextual_session.set(session) + try: + yield + finally: + _global_contextual_session.reset(token) + + +def provide_global_contextual_session(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + session_in_args = any( + [isinstance(arg, (sa.orm.Session, sa.ext.asyncio.AsyncSession)) for arg in args] + ) + session_in_kwargs = "session" in kwargs + session_provided = session_in_args or session_in_kwargs + + if session_provided: + return func(self, *args, **kwargs) + else: + session = session_proxy() + + return func(self, session, *args, **kwargs) + + return wrapper + + +class SessionProxy: + def __call__(self) -> t.Union[sa.orm.Session, sa.ext.asyncio.AsyncSession]: + return _global_contextual_session.get() + + def __getattr__(self, name): + return getattr(self(), name) + + +session_proxy = SessionProxy() + + class Session(sa.orm.Session, t.Generic[EntityT, EntityIdT]): """A SQLAlchemy :class:`~sqlalchemy.orm.Session` class. diff --git a/src/quart_sqlalchemy/signals.py b/src/quart_sqlalchemy/signals.py index 3015b16..ecef762 100644 --- a/src/quart_sqlalchemy/signals.py +++ b/src/quart_sqlalchemy/signals.py @@ -113,3 +113,34 @@ def handle(sender: QuartSQLAlchemy, app: Quart): ... """, ) + + +framework_extension_load_fixtures = sync_signals.signal( + "quart-sqlalchemy.framework.extension.fixtures.load", + doc="""Fired to load fixtures into a fresh database. + + No default signal handlers exist for this signal as the logic is very application dependent. 
+ This signal handler is typically triggered using the CLI: + + $ quart db fixtures load + + Example: + + @signals.framework_extension_load_fixtures.connect + def handle(sender: QuartSQLAlchemy, app: Quart): + db = sender.get_bind("default") + with db.Session() as session: + with session.begin(): + session.add_all( + [ + models.User(username="user1"), + models.User(username="user2"), + ] + ) + session.commit() + + Handler signature: + def handle(sender: QuartSQLAlchemy, app: Quart): + ... + """, +) diff --git a/src/quart_sqlalchemy/sim/app.py b/src/quart_sqlalchemy/sim/app.py index 9f12d22..c138f42 100644 --- a/src/quart_sqlalchemy/sim/app.py +++ b/src/quart_sqlalchemy/sim/app.py @@ -1,88 +1,41 @@ import logging import typing as t from copy import deepcopy -from functools import wraps -from quart import g from quart import Quart -from quart import request -from quart import Response -from quart.typing import ResponseReturnValue -from quart_schema import APIKeySecurityScheme -from quart_schema import HttpSecurityScheme from quart_schema import QuartSchema from werkzeug.utils import import_string +from .config import settings +from .container import Container + logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -BLUEPRINTS = ("quart_sqlalchemy.sim.views.api",) -EXTENSIONS = ( - "quart_sqlalchemy.sim.db.db", - "quart_sqlalchemy.sim.app.schema", - "quart_sqlalchemy.sim.auth.auth", -) - -DEFAULT_CONFIG = { - "QUART_AUTH_SECURITY_SCHEMES": { - "public-api-key": APIKeySecurityScheme(in_="header", name="X-Public-API-Key"), - "session-token-bearer": HttpSecurityScheme(scheme="bearer", bearer_format="opaque"), - }, - "REGISTER_BLUEPRINTS": ["quart_sqlalchemy.sim.views.api"], -} - - -schema = QuartSchema(security_schemes=DEFAULT_CONFIG["QUART_AUTH_SECURITY_SCHEMES"]) - - -def wrap_response(func: t.Callable) -> t.Callable: - @wraps(func) - async def decorator(result: ResponseReturnValue) -> Response: - # import pdb +schema = QuartSchema(security_schemes=settings.SECURITY_SCHEMES) - # pdb.set_trace() - return await func(result) - return decorator - - -def create_app( - override_config: t.Optional[t.Dict[str, t.Any]] = None, - extensions: t.Sequence[str] = EXTENSIONS, - blueprints: t.Sequence[str] = BLUEPRINTS, -): +def create_app(override_config: t.Optional[t.Dict[str, t.Any]] = None): override_config = override_config or {} - config = deepcopy(DEFAULT_CONFIG) + config = deepcopy(settings.dict()) config.update(override_config) app = Quart(__name__) app.config.from_mapping(config) + app.config.from_prefixed_env() - for path in extensions: + for path in app.config["LOAD_EXTENSIONS"]: extension = import_string(path) extension.init_app(app) - for path in blueprints: + for path in app.config["LOAD_BLUEPRINTS"]: bp = import_string(path) app.register_blueprint(bp) - @app.before_request - def set_ethereum_network(): - g.network = request.headers.get("X-Ethereum-Network", "GOERLI").upper() - - # app.make_response = wrap_response(app.make_response) + container = Container(app=app) + app.container = container return app - - -# @app.after_request -# async def add_json_response_envelope(response: Response) -> Response: -# if response.mimetype != "application/json": -# return response -# data = await response.get_json() -# payload = dict(status="ok", message="", data=data) -# response.set_data(json.dumps(payload)) -# return response diff --git a/src/quart_sqlalchemy/sim/auth.py b/src/quart_sqlalchemy/sim/auth.py index 80a908a..dece2ba 100644 --- a/src/quart_sqlalchemy/sim/auth.py +++ 
b/src/quart_sqlalchemy/sim/auth.py @@ -24,7 +24,9 @@ from werkzeug.exceptions import Forbidden from .model import AuthUser +from .model import EntityType from .model import MagicClient +from .model import Provenance from .schema import BaseSchema from .util import ObjectID @@ -195,38 +197,12 @@ def validate_security( yield scheme_credentials -# def convert_model_result(func: t.Callable) -> t.Callable: -# @wraps(func) -# async def decorator(result: ResponseReturnValue) -> Response: -# status_or_headers = None -# headers = None -# if isinstance(result, tuple): -# value, status_or_headers, headers = result + (None,) * (3 - len(result)) -# else: -# value = result - -# was_model = False -# if is_dataclass(value): -# dict_or_value = asdict(value) -# was_model = True -# elif isinstance(value, BaseModel): -# dict_or_value = value.dict(by_alias=True) -# was_model = True -# else: -# dict_or_value = value - -# if was_model: -# dict_or_value = camelize(dict_or_value) - -# return await func((dict_or_value, status_or_headers, headers)) - -# return decorator - - class QuartAuth: authenticator = RequestAuthenticator() - def __init__(self, app: t.Optional[Quart] = None): + def __init__(self, app: t.Optional[Quart] = None, bind_name: str = "default"): + self.bind_name = bind_name + if app is not None: self.init_app(app) @@ -238,6 +214,8 @@ def init_app(self, app: Quart): self.security_schemes = app.config.get("QUART_AUTH_SECURITY_SCHEMES", {}) app.cli.add_command(cli) + app.extensions["auth"] = self + def auth_endpoint_security(self): db = current_app.extensions.get("sqlalchemy") view_function = current_app.view_functions[request.endpoint] @@ -245,7 +223,8 @@ def auth_endpoint_security(self): if security_schemes is None: g.authorized_credentials = {} - with db.bind.Session() as session: + bind = db.get_bind(self.bind_name) + with bind.Session() as session: results = self.authenticator.enforce(security_schemes, session) authorized_credentials = {} for result in results: @@ -253,9 +232,17 @@ def auth_endpoint_security(self): g.authorized_credentials = authorized_credentials -from .model import EntityType -from .model import MagicClient -from .model import Provenance +class RequestCredentials: + def __init__(self, request): + self.request = request + + @property + def current_user(self): + return g.authorized_credentials.get("session-token-bearer") + + @property + def current_client(self): + return g.authorized_credentials.get("public-api-key") @cli.command("add-user") @@ -282,8 +269,9 @@ def auth_endpoint_security(self): def add_user(info: ScriptInfo, email: str, user_type: str, client_id: str) -> None: app = info.load_app() db = app.extensions.get("sqlalchemy") - - with db.bind.Session() as s: + auth = app.extensions.get("auth") + bind = db.get_bind(auth.bind_name) + with bind.Session() as s: with s.begin(): user = AuthUser( email=email, @@ -310,7 +298,9 @@ def add_user(info: ScriptInfo, email: str, user_type: str, client_id: str) -> No def add_client(info: ScriptInfo, name: str) -> None: app = info.load_app() db = app.extensions.get("sqlalchemy") - with db.bind.Session() as s: + auth = app.extensions.get("auth") + bind = db.get_bind(auth.bind_name) + with bind.Session() as s: with s.begin(): client = MagicClient(app_name=name, public_api_key=secrets.token_hex(16)) s.add(client) diff --git a/src/quart_sqlalchemy/sim/builder.py b/src/quart_sqlalchemy/sim/builder.py index ccaffeb..5aa96d6 100644 --- a/src/quart_sqlalchemy/sim/builder.py +++ b/src/quart_sqlalchemy/sim/builder.py @@ -19,12 +19,12 @@ class 
StatementBuilder(t.Generic[EntityT]): - model: t.Type[EntityT] + model: t.Optional[t.Type[EntityT]] - def __init__(self, model: t.Type[EntityT]): + def __init__(self, model: t.Optional[t.Type[EntityT]] = None): self.model = model - def complex_select( + def select( self, selectables: t.Sequence[Selectable] = (), conditions: t.Sequence[ColumnExpr] = (), diff --git a/src/quart_sqlalchemy/sim/commands.py b/src/quart_sqlalchemy/sim/commands.py new file mode 100644 index 0000000..c2a6567 --- /dev/null +++ b/src/quart_sqlalchemy/sim/commands.py @@ -0,0 +1,52 @@ +import asyncio +import sys + +import click +import IPython +from IPython.terminal.ipapp import load_default_config +from quart import current_app + + +def attach(app): + app.shell_context_processor(app_env) + app.cli.command( + with_appcontext=True, + context_settings=dict( + ignore_unknown_options=True, + ), + )(ishell) + + +def app_env(): + app = current_app + return dict(container=app.container) + + +@click.argument("ipython_args", nargs=-1, type=click.UNPROCESSED) +def ishell(ipython_args): + import nest_asyncio + + nest_asyncio.apply() + + config = load_default_config() + + asyncio.run(current_app.startup()) + + context = current_app.make_shell_context() + + config.TerminalInteractiveShell.banner1 = """Python %s on %s +IPython: %s +App: %s [%s] +""" % ( + sys.version, + sys.platform, + IPython.__version__, + current_app.import_name, + current_app.env, + ) + + IPython.start_ipython( + argv=ipython_args, + user_ns=context, + config=config, + ) diff --git a/src/quart_sqlalchemy/sim/config.py b/src/quart_sqlalchemy/sim/config.py new file mode 100644 index 0000000..d743ee7 --- /dev/null +++ b/src/quart_sqlalchemy/sim/config.py @@ -0,0 +1,55 @@ +import typing as t + +import sqlalchemy +from pydantic import BaseSettings +from pydantic import Field +from pydantic import PyObject +from quart_schema import APIKeySecurityScheme +from quart_schema import HttpSecurityScheme +from quart_schema.openapi import SecuritySchemeBase + +from quart_sqlalchemy import AsyncBindConfig +from quart_sqlalchemy import BindConfig +from quart_sqlalchemy.sim.db import MyBase + + +sa = sqlalchemy + + +class AppSettings(BaseSettings): + class Config: + env_file = ".env", ".secrets.env" + + LOAD_BLUEPRINTS: t.List[str] = Field( + default_factory=lambda: list(("quart_sqlalchemy.sim.views.api",)) + ) + LOAD_EXTENSIONS: t.List[str] = Field( + default_factory=lambda: list( + ( + "quart_sqlalchemy.sim.db.db", + "quart_sqlalchemy.sim.app.schema", + "quart_sqlalchemy.sim.auth.auth", + ) + ) + ) + SECURITY_SCHEMES: t.Dict[str, SecuritySchemeBase] = Field( + default_factory=lambda: { + "public-api-key": APIKeySecurityScheme(in_="header", name="X-Public-API-Key"), + "session-token-bearer": HttpSecurityScheme(scheme="bearer", bearer_format="opaque"), + } + ) + + SQLALCHEMY_BINDS: t.Dict[str, t.Union[AsyncBindConfig, BindConfig]] = Field( + default_factory=lambda: dict(default=BindConfig(engine=dict(url="sqlite:///app.db"))) + ) + SQLALCHEMY_BASE_CLASS: t.Type[t.Any] = Field(default=MyBase) + + WEB3_DEFAULT_CHAIN: str = Field(default="ethereum") + WEB3_DEFAULT_NETWORK: str = Field(default="goerli") + + WEB3_PROVIDER_CLASS: PyObject = Field("web3.providers.HTTPProvider", env="WEB3_PROVIDER_CLASS") + ALCHEMY_API_KEY: str = Field(env="ALCHEMY_API_KEY") + WEB3_HTTPS_PROVIDER_URI: str = Field(env="WEB3_HTTPS_PROVIDER_URI") + + +settings = AppSettings() diff --git a/src/quart_sqlalchemy/sim/container.py b/src/quart_sqlalchemy/sim/container.py new file mode 100644 index 
0000000..91bb358 --- /dev/null +++ b/src/quart_sqlalchemy/sim/container.py @@ -0,0 +1,57 @@ +import typing as t + +import sqlalchemy.orm +from dependency_injector import containers +from dependency_injector import providers +from quart import request + +from quart_sqlalchemy.session import SessionProxy +from quart_sqlalchemy.sim.auth import RequestCredentials +from quart_sqlalchemy.sim.handle import AuthUserHandler +from quart_sqlalchemy.sim.handle import AuthWalletHandler +from quart_sqlalchemy.sim.handle import MagicClientHandler +from quart_sqlalchemy.sim.logic import LogicComponent + +from .config import AppSettings +from .web3 import Web3 +from .web3 import web3_node_factory + + +sa = sqlalchemy + + +def get_db_from_app(app): + return app.extensions["sqlalchemy"] + + +class Container(containers.DeclarativeContainer): + wiring_config = containers.WiringConfiguration( + modules=[ + "quart_sqlalchemy.sim.views", + "quart_sqlalchemy.sim.logic", + "quart_sqlalchemy.sim.handle", + "quart_sqlalchemy.sim.views.auth_wallet", + "quart_sqlalchemy.sim.views.auth_user", + "quart_sqlalchemy.sim.views.magic_client", + ] + ) + config = providers.Configuration(pydantic_settings=[AppSettings()]) + app = providers.Object() + db = providers.Singleton(get_db_from_app, app=app) + + session_factory = providers.Singleton(SessionProxy) + logic = providers.Singleton(LogicComponent) + + AuthUserHandler = providers.Singleton(AuthUserHandler) + MagicClientHandler = providers.Singleton(MagicClientHandler) + AuthWalletHandler = providers.Singleton(AuthWalletHandler) + + web3_node = providers.Singleton(web3_node_factory, config=config) + web3 = providers.Singleton( + Web3, + node=web3_node, + default_network=config.WEB3_DEFAULT_NETWORK, + default_chain=config.WEB3_DEFAULT_CHAIN, + ) + current_request = providers.Factory(lambda: request) + request_credentials = providers.Singleton(RequestCredentials, request=current_request) diff --git a/src/quart_sqlalchemy/sim/db.py b/src/quart_sqlalchemy/sim/db.py index e7677f6..9634815 100644 --- a/src/quart_sqlalchemy/sim/db.py +++ b/src/quart_sqlalchemy/sim/db.py @@ -1,22 +1,41 @@ import click -import sqlalchemy as sa -from quart import g -from quart import request +import sqlalchemy +import sqlalchemy.orm from quart.cli import AppGroup from quart.cli import pass_script_info from quart.cli import ScriptInfo from sqlalchemy.types import Integer from sqlalchemy.types import TypeDecorator -from quart_sqlalchemy import Base from quart_sqlalchemy import SQLAlchemyConfig from quart_sqlalchemy.framework import QuartSQLAlchemy +from quart_sqlalchemy.model import BaseMixins from quart_sqlalchemy.sim.util import ObjectID +sa = sqlalchemy cli = AppGroup("db-schema") +def init_fixtures(session): + """Initialize the database with some fixtures.""" + from quart_sqlalchemy.sim.model import AuthUser + from quart_sqlalchemy.sim.model import MagicClient + + client = MagicClient( + app_name="My App", + public_api_key="4700aed5ee9f76f7be6398cd4b00b586", + auth_users=[ + AuthUser( + email="joe@magic.link", + current_session_token="97ee741d53e11a490460927c8a2ce4a3", + ), + ], + ) + session.add(client) + session.flush() + + class ObjectIDType(TypeDecorator): """A custom database column type that converts integer value to our ObjectID. 
This allows us to pass around ObjectID type in the application for easy @@ -47,24 +66,11 @@ def process_result_value(self, value, dialect): return ObjectID(value) -class MyBase(Base): +class MyBase(BaseMixins, sa.orm.DeclarativeBase): + __abstract__ = True type_annotation_map = {ObjectID: ObjectIDType} -class MyQuartSQLAlchemy(QuartSQLAlchemy): - def init_app(self, app): - super().init_app(app) - - @app.before_request - def set_bind(): - if request.method in ["GET", "OPTIONS", "HEAD", "TRACE"]: - g.bind = self.get_bind("read-replica") - else: - g.bind = self.get_bind("default") - - app.cli.add_command(cli) - - @cli.command("load") @pass_script_info def schema_load(info: ScriptInfo) -> None: @@ -76,10 +82,10 @@ def schema_load(info: ScriptInfo) -> None: # sqlite:///file:mem.db?mode=memory&cache=shared&uri=true -db = MyQuartSQLAlchemy( +db = QuartSQLAlchemy( SQLAlchemyConfig.parse_obj( { - "model_class": MyBase, + "base_class": MyBase, "binds": { "default": { "engine": {"url": "sqlite:///file:sim.db?cache=shared&uri=true"}, diff --git a/src/quart_sqlalchemy/sim/handle.py b/src/quart_sqlalchemy/sim/handle.py index e079da4..47d0cf3 100644 --- a/src/quart_sqlalchemy/sim/handle.py +++ b/src/quart_sqlalchemy/sim/handle.py @@ -1,10 +1,17 @@ +from __future__ import annotations + import logging +import secrets import typing as t from datetime import datetime -from sqlalchemy.orm import Session +import sqlalchemy +from dependency_injector.wiring import Provide +from quart import Quart -from quart_sqlalchemy.sim.logic import LogicComponent as Logic +from quart_sqlalchemy.session import SessionProxy +from quart_sqlalchemy.sim import signals +from quart_sqlalchemy.sim.logic import LogicComponent from quart_sqlalchemy.sim.model import AuthUser from quart_sqlalchemy.sim.model import AuthWallet from quart_sqlalchemy.sim.model import EntityType @@ -12,11 +19,17 @@ from quart_sqlalchemy.sim.util import ObjectID +sa = sqlalchemy + logger = logging.getLogger(__name__) CLIENTS_PER_API_USER_LIMIT = 50 +def get_product_type_by_client_id(_): + return EntityType.MAGIC.value + + class MaxClientsExceeded(Exception): pass @@ -29,34 +42,21 @@ class InvalidSubstringError(AuthUserBaseError): pass -class APIKeySet(t.NamedTuple): - public_key: str - secret_key: str - - class HandlerBase: - logic: Logic - session: Session - """The base class for all handler classes. It provides handler with access - to our logic object. - """ - - def __init__(self, session: t.Optional[Session], logic: t.Optional[Logic] = None): - self.session = session - self.logic = logic or Logic() - - -def get_product_type_by_client_id(client_id): - return EntityType.MAGIC.value + logic: LogicComponent = Provide["logic"] + session_factory = SessionProxy() class MagicClientHandler(HandlerBase): + auth_user_handler: AuthUserHandler = Provide["AuthUserHandler"] + def add( self, - magic_api_user_id, - magic_team_id, app_name=None, - is_magic_connect_enabled=False, + rate_limit_tier=None, + connect_interop=None, + is_signing_modal_enabled=False, + global_audience_enabled=False, ): """Registers a new client. @@ -66,70 +66,21 @@ def add( Returns: A ``MagicClient``. 
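Example (illustrative sketch; assumes the dependency-injector container has wired ``logic`` and that a contextual session is active for the request):

            # sketch only: requires Container wiring plus an active
            # contextual session (see set_global_contextual_session).
            handler = MagicClientHandler()
            client = handler.add(app_name="My App", is_signing_modal_enabled=True)
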
""" - magic_clients_count = self.logic.MagicClientAPIUser.count_by_magic_api_user_id( - self.session, - magic_api_user_id, - ) - - if magic_clients_count >= CLIENTS_PER_API_USER_LIMIT: - raise MaxClientsExceeded() - - return self.add_client( - magic_api_user_id, - magic_team_id, - app_name, - is_magic_connect_enabled, - ) - - def get_by_public_api_key(self, public_api_key): - return self.logic.MagicClientAPIKey.get_by_public_api_key(self.session, public_api_key) - - def add_client( - self, - magic_api_user_id, - magic_team_id, - app_name=None, - is_magic_connect_enabled=False, - ): - live_api_key = APIKeySet(public_key="xxx", secret_key="yyy") - # with self.logic.begin(ro=False) as session: - return self.logic.MagicClient._add( - self.session, + return self.logic.MagicClient.add( + self.session_factory(), app_name=app_name, + rate_limit_tier=rate_limit_tier, + connect_interop=connect_interop, + is_signing_modal_enabled=is_signing_modal_enabled, + global_audience_enabled=global_audience_enabled, ) - # self.logic.MagicClientAPIKey._add( - # session, - # magic_client.id, - # live_api_key_pair=live_api_key, - # ) - # self.logic.MagicClientAPIUser._add( - # session, - # magic_api_user_id, - # magic_client.id, - # ) - - # self.logic.MagicClientAuthMethods._add( - # session, - # magic_client_id=magic_client.id, - # is_magic_connect_enabled=is_magic_connect_enabled, - # is_metamask_wallet_enabled=(True if is_magic_connect_enabled else False), - # is_wallet_connect_enabled=(True if is_magic_connect_enabled else False), - # is_coinbase_wallet_enabled=(True if is_magic_connect_enabled else False), - # ) - - # self.logic.MagicClientTeam._add(session, magic_client.id, magic_team_id) - - # return magic_client, live_api_key - - def get_magic_api_user_id_by_client_id(self, magic_client_id): - return self.logic.MagicClient.get_magic_api_user_id_by_client_id( - self.session, magic_client_id - ) + def get_by_public_api_key(self, public_api_key): + return self.logic.MagicClient.get_by_public_api_key(self.session_factory(), public_api_key) def get_by_id(self, magic_client_id): - return self.logic.MagicClient.get_by_id(self.session, magic_client_id) + return self.logic.MagicClient.get_by_id(self.session_factory(), magic_client_id) def update_app_name_by_id(self, magic_client_id, app_name): """ @@ -142,7 +93,7 @@ def update_app_name_by_id(self, magic_client_id, app_name): app_name if update was successful """ client = self.logic.MagicClient.update_by_id( - self.session, magic_client_id, app_name=app_name + self.session_factory(), magic_client_id, app_name=app_name ) if not client: @@ -151,7 +102,9 @@ def update_app_name_by_id(self, magic_client_id, app_name): return client.app_name def update_by_id(self, magic_client_id, **kwargs): - client = self.logic.MagicClient.update_by_id(self.session, magic_client_id, **kwargs) + client = self.logic.MagicClient.update_by_id( + self.session_factory(), magic_client_id, **kwargs + ) return client @@ -163,108 +116,38 @@ def set_inactive_by_id(self, magic_client_id): Returns: None """ - self.logic.MagicClient.update_by_id(self.session, magic_client_id, is_active=False) + self.logic.MagicClient.update_by_id( + self.session_factory(), magic_client_id, is_active=False + ) def get_users_for_client( self, magic_client_id, offset=None, limit=None, - include_count=False, ): """ Returns emails and signup timestamps for all auth users belonging to a given client """ - auth_user_handler = AuthUserHandler(session=self.session) - product_type = 
get_product_type_by_client_id(magic_client_id) - auth_users = auth_user_handler.get_by_client_id_and_user_type( - magic_client_id, - product_type, - offset=offset, - limit=limit, - ) - - # Here we blindly load from oauth users table because we only provide - # two login methods right now. If not email link then it is oauth. - # TODO(ajen#ch22926|2020-08-14): rely on the `login_method` column to - # deterministically load from correct source (oauth, webauthn, etc.). - # emails_from_oauth = OAuthUserHandler().get_emails_by_auth_user_ids( - # [auth_user.id for auth_user in auth_users if auth_user.email is None], - # ) - - data = { - "users": [ - dict(email=u.email or "none", signup_ts=int(datetime.timestamp(u.time_created))) - for u in auth_users - ] - } - - if include_count: - data["count"] = auth_user_handler.get_user_count_by_client_id_and_user_type( - magic_client_id, - product_type, - ) - - return data - - def get_users_for_client_v2( - self, - magic_client_id, - offset=None, - limit=None, - include_count=False, - ): - """ - Returns emails, signup timestamps, provenance and MFA enablement for all auth users - belonging to a given client. - """ - auth_user_handler = AuthUserHandler(session=self.session) product_type = get_product_type_by_client_id(magic_client_id) - auth_users = auth_user_handler.get_by_client_id_and_user_type( + auth_users = self.auth_user_handler.get_by_client_id_and_user_type( magic_client_id, product_type, offset=offset, limit=limit, ) - data = { + return { "users": [ dict(email=u.email or "none", signup_ts=int(datetime.timestamp(u.time_created))) for u in auth_users ] } - if include_count: - data["count"] = auth_user_handler.get_user_count_by_client_id_and_user_type( - magic_client_id, - product_type, - ) - - return data - - # def get_user_logins_for_client(self, magic_client_id, limit=None): - # logins = AuthUserLoginHandler().get_logins_by_magic_client_id( - # magic_client_id, - # limit=limit or 20, - # ) - # user_logins = get_user_logins_response(logins) - - # return sorted( - # user_logins, - # key=lambda x: x["login_ts"], - # reverse=True, - # )[:limit] - class AuthUserHandler(HandlerBase): - # auth_user_mfa_handler: AuthUserMfaHandler - - def __init__(self, *args, auth_user_mfa_handler=None, **kwargs): - super().__init__(*args, **kwargs) - # self.auth_user_mfa_handler = auth_user_mfa_handler or AuthUserMfaHandler() - def get_by_session_token(self, session_token): - return self.logic.AuthUser.get_by_session_token(self.session, session_token) + return self.logic.AuthUser.get_by_session_token(self.session_factory(), session_token) def get_or_create_by_email_and_client_id( self, @@ -272,41 +155,27 @@ def get_or_create_by_email_and_client_id( client_id, user_type=EntityType.MAGIC.value, ): - auth_user = self.logic.AuthUser.get_by_email_and_client_id( - self.session, - email, - client_id, - user_type=user_type, - ) - if not auth_user: - # try: - # email = enhanced_email_validation( - # email, - # source=MAGIC, - # # So we don't affect sign-up. 
- # silence_network_error=True, - # ) - # except ( - # EnhanceEmailValidationError, - # EnhanceEmailSuggestionError, - # ) as e: - # logger.warning( - # "Email Start Attempt.", - # exc_info=True, - # ) - # raise EnhancedEmailValidation(error_message=str(e)) from e - - auth_user = self.logic.AuthUser.add_by_email_and_client_id( - self.session, + session = self.session_factory() + with session.begin_nested(): + auth_user = self.logic.AuthUser.get_by_email_and_client_id( + session, + email, client_id, - email=email, user_type=user_type, + for_update=True, ) + if not auth_user: + auth_user = self.logic.AuthUser.add_by_email_and_client_id( + session, + client_id, + email=email, + user_type=user_type, + ) return auth_user def get_by_id_and_validate_exists(self, auth_user_id): """This function helps formalize how a non-existent auth user should be handled.""" - auth_user = self.logic.AuthUser.get_by_id(self.session, auth_user_id) + auth_user = self.logic.AuthUser.get_by_id(self.session_factory(), auth_user_id) if auth_user is None: raise RuntimeError('resource_name="auth_user"') return auth_user @@ -319,33 +188,33 @@ def create_verified_user( client_id, email, user_type=EntityType.FORTMATIC.value, + **kwargs, ): + # with self.session_factory() as session: # with self.logic.begin(ro=False) as session: - auid = self.logic.AuthUser._add_by_email_and_client_id( - self.session, - client_id, - email, - user_type=user_type, - ).id - auth_user = self.logic.AuthUser._update_by_id( - self.session, - auid, - date_verified=datetime.utcnow(), - ) - - return auth_user + session = self.session_factory() + with session.begin_nested(): + auid = self.logic.AuthUser.add_by_email_and_client_id( + session, + client_id, + email, + user_type=user_type, + **kwargs, + ).id - # def get_auth_user_from_public_address(self, public_address): - # wallet = self.logic.AuthWallet.get_by_public_address(public_address) + session.flush() - # if not wallet: - # return None + auth_user = self.logic.AuthUser.update_by_id( + session, + auid, + date_verified=datetime.utcnow(), + current_session_token=secrets.token_hex(16), + ) - # return self.logic.AuthUser.get_by_id(wallet.auth_user_id) + return auth_user - def get_by_id(self, auth_user_id, load_mfa_methods=False) -> AuthUser: - # join_list = ["mfa_methods"] if load_mfa_methods else None - return self.logic.AuthUser.get_by_id(self.session, auth_user_id) + def get_by_id(self, auth_user_id) -> AuthUser: + return self.logic.AuthUser.get_by_id(self.session_factory(), auth_user_id) def get_by_client_id_and_user_type( self, @@ -354,21 +223,13 @@ def get_by_client_id_and_user_type( offset=None, limit=None, ): - if user_type == EntityType.CONNECT.value: - return self.logic.AuthUser.get_by_client_id_for_connect( - self.session, - client_id, - offset=offset, - limit=limit, - ) - else: - return self.logic.AuthUser.get_by_client_id_and_user_type( - self.session, - client_id, - user_type, - offset=offset, - limit=limit, - ) + return self.logic.AuthUser.get_by_client_id_and_user_type( + self.session_factory(), + client_id, + user_type, + offset=offset, + limit=limit, + ) def get_by_client_ids_and_user_type( self, @@ -378,43 +239,32 @@ def get_by_client_ids_and_user_type( limit=None, ): return self.logic.AuthUser.get_by_client_ids_and_user_type( - self.session, + self.session_factory(), client_ids, user_type, offset=offset, limit=limit, ) - def get_user_count_by_client_id_and_user_type(self, client_id, user_type): - if user_type == EntityType.CONNECT.value: - return 
self.logic.AuthUser.get_user_count_by_client_id_for_connect( - self.session, - client_id, - ) - else: - return self.logic.AuthUser.get_user_count_by_client_id_and_user_type( - self.session, - client_id, - user_type, - ) - def exist_by_email_client_id_and_user_type(self, email, client_id, user_type): return self.logic.AuthUser.exist_by_email_and_client_id( - self.session, + self.session_factory(), email, client_id, user_type=user_type, ) def update_email_by_id(self, model_id, email): - return self.logic.AuthUser.update_by_id(self.session, model_id, email=email) + return self.logic.AuthUser.update_by_id(self.session_factory(), model_id, email=email) def update_phone_number_by_id(self, model_id, phone_number): - return self.logic.AuthUser.update_by_id(self.session, model_id, phone_number=phone_number) + return self.logic.AuthUser.update_by_id( + self.session_factory(), model_id, phone_number=phone_number + ) def get_by_email_client_id_and_user_type(self, email, client_id, user_type): return self.logic.AuthUser.get_by_email_and_client_id( - self.session, + self.session_factory(), email, client_id, user_type, @@ -422,28 +272,32 @@ def get_by_email_client_id_and_user_type(self, email, client_id, user_type): def mark_date_verified_by_id(self, model_id): return self.logic.AuthUser.update_by_id( - self.session, + self.session_factory(), model_id, date_verified=datetime.utcnow(), ) def set_role_by_email_magic_client_id(self, email, magic_client_id, role): + session = self.session_factory() auth_user = self.logic.AuthUser.get_by_email_and_client_id( - self.session, + session, email, magic_client_id, EntityType.MAGIC.value, + for_update=True, ) if not auth_user: auth_user = self.logic.AuthUser.add_by_email_and_client_id( - self.session, + session, magic_client_id, email, user_type=EntityType.MAGIC.value, ) - return self.logic.AuthUser.update_by_id(self.session, auth_user.id, **{role: True}) + session.flush() + + return self.logic.AuthUser.update_by_id(session, auth_user.id, **{role: True}) def search_by_client_id_and_substring( self, @@ -451,29 +305,18 @@ def search_by_client_id_and_substring( substring, offset=None, limit=10, - load_mfa_methods=False, ): - # join_list = ["mfa_methods"] if load_mfa_methods is True else None - if not isinstance(substring, str) or len(substring) < 3: raise InvalidSubstringError() auth_users = self.logic.AuthUser.get_by_client_id_with_substring_search( - self.session, + self.session_factory(), client_id, substring, offset=offset, limit=limit, - # join_list=join_list, ) - # mfa_enablements = self.auth_user_mfa_handler.is_active_batch( - # [auth_user.id for auth_user in auth_users], - # ) - # for auth_user in auth_users: - # if mfa_enablements[auth_user.id] is False: - # auth_user.mfa_methods = [] - return auth_users def is_magic_connect_enabled(self, auth_user_id=None, auth_user=None): @@ -486,13 +329,14 @@ def is_magic_connect_enabled(self, auth_user_id=None, auth_user=None): return auth_user.user_type == EntityType.CONNECT.value def mark_as_inactive(self, auth_user_id): - self.logic.AuthUser.update_by_id(self.session, auth_user_id, is_active=False) + self.logic.AuthUser.update_by_id(self.session_factory(), auth_user_id, is_active=False) def get_by_email_and_wallet_type_for_interop(self, email, wallet_type, network): """ Opinionated method for fetching AuthWallets by email address, wallet_type and network. 
""" return self.logic.AuthUser.get_by_email_for_interop( + self.session_factory(), email=email, wallet_type=wallet_type, network=network, @@ -505,37 +349,27 @@ def get_magic_connect_auth_user(self, auth_user_id): return auth_user -# @signals.auth_user_duplicate.connect -# def handle_duplicate_auth_users( -# current_app, -# original_auth_user_id, -# duplicate_auth_user_ids, -# auth_user_handler: t.Optional[AuthUserHandler] = None, -# ) -> None: -# logger.info(f"{len(duplicate_auth_user_ids)} dupe(s) found for {original_auth_user_id}") +@signals.auth_user_duplicate.connect +def handle_duplicate_auth_users( + app: Quart, + original_auth_user_id: ObjectID, + duplicate_auth_user_ids: t.Sequence[ObjectID], +) -> None: + logger.info(f"{len(duplicate_auth_user_ids)} dupe(s) found for {original_auth_user_id}") -# auth_user_handler = auth_user_handler or AuthUserHandler() - -# for dupe_id in duplicate_auth_user_ids: -# logger.info( -# f"marking auth_user_id {dupe_id} as inactive, in favor of original {original_auth_user_id}", -# ) -# auth_user_handler.mark_as_inactive(dupe_id) + for dupe_id in duplicate_auth_user_ids: + logger.info( + f"marking auth_user_id {dupe_id} as inactive, in favor of original {original_auth_user_id}", + ) + app.container.logic().AuthUser.update_by_id(dupe_id, is_active=False) class AuthWalletHandler(HandlerBase): - # account_linking_feature = LDFeatureFlag("is-account-linking-enabled", anonymous_user=True) - - def __init__(self, network, *args, wallet_type=WalletType.ETH, **kwargs): - super().__init__(*args, **kwargs) - self.wallet_network = network - self.wallet_type = wallet_type - def get_by_id(self, model_id): - return self.logic.AuthWallet.get_by_id(self.session, model_id) + return self.logic.AuthWallet.get_by_id(self.session_factory(), model_id) def get_by_public_address(self, public_address): - return self.logic.AuthWallet.get_by_public_address(self.session, public_address) + return self.logic.AuthWallet().get_by_public_address(self.session_factory(), public_address) def get_by_auth_user_id( self, @@ -544,25 +378,9 @@ def get_by_auth_user_id( wallet_type: t.Optional[WalletType] = None, **kwargs, ) -> t.List[AuthWallet]: - auth_user = self.logic.AuthUser.get_by_id( - self.session, - auth_user_id, - join_list=["linked_primary_auth_user"], - ) - - if auth_user.has_linked_primary_auth_user: - logger.info( - "Linked primary_auth_user found for wallet delegation", - extra=dict( - auth_user_id=auth_user.id, - delegated_to=auth_user.linked_primary_auth_user_id, - ), - ) - auth_user = auth_user.linked_primary_auth_user - return self.logic.AuthWallet.get_by_auth_user_id( - self.session, - auth_user.id, + self.session_factory(), + auth_user_id, network=network, wallet_type=wallet_type, **kwargs, @@ -574,20 +392,24 @@ def sync_auth_wallet( public_address, encrypted_private_address, wallet_management_type, + network: t.Optional[str] = None, + wallet_type: t.Optional[WalletType] = None, ): - existing_wallet = self.logic.AuthWallet.get_by_auth_user_id( - self.session, - auth_user_id, - ) - if existing_wallet: - raise RuntimeError("WalletExistsForNetworkAndWalletType") - - return self.logic.AuthWallet.add( - self.session, - public_address, - encrypted_private_address, - self.wallet_type, - self.wallet_network, - management_type=wallet_management_type, - auth_user_id=auth_user_id, - ) + session = self.session_factory() + with session.begin_nested(): + existing_wallet = self.logic.AuthWallet.get_by_auth_user_id( + session, + auth_user_id, + ) + if existing_wallet: + raise 
RuntimeError("WalletExistsForNetworkAndWalletType") + + return self.logic.AuthWallet.add( + session, + public_address, + encrypted_private_address, + wallet_type, + network, + management_type=wallet_management_type, + auth_user_id=auth_user_id, + ) diff --git a/src/quart_sqlalchemy/sim/logic.py b/src/quart_sqlalchemy/sim/logic.py index ceb6ffd..7b969bc 100644 --- a/src/quart_sqlalchemy/sim/logic.py +++ b/src/quart_sqlalchemy/sim/logic.py @@ -1,13 +1,14 @@ +import inspect import logging +import secrets import typing as t from datetime import datetime -from functools import wraps -from sqlalchemy import or_ -from sqlalchemy.orm import contains_eager -from sqlalchemy.orm import Session -from sqlalchemy.sql.expression import func +import sqlalchemy +import sqlalchemy.orm +from quart import current_app +from quart_sqlalchemy.session import provide_global_contextual_session from quart_sqlalchemy.sim import signals from quart_sqlalchemy.sim.model import AuthUser as auth_user_model from quart_sqlalchemy.sim.model import AuthWallet as auth_wallet_model @@ -19,89 +20,59 @@ from quart_sqlalchemy.sim.repo_adapter import RepositoryLegacyAdapter from quart_sqlalchemy.sim.util import ObjectID from quart_sqlalchemy.sim.util import one +from quart_sqlalchemy.types import EntityIdT +from quart_sqlalchemy.types import EntityT +from quart_sqlalchemy.types import SessionT logger = logging.getLogger(__name__) +sa = sqlalchemy class LogicMeta(type): - """This is metaclass provides registry pattern where all the available - logics will be accessible through any instantiated logic object. - - Note: - Don't use this metaclass at another places. This is only intended to be - used by LogicComponent. If you want your own registry, please create - your own. - """ + _ignore = {"LegacyLogicComponent"} def __init__(cls, name, bases, cls_dict): if not hasattr(cls, "_registry"): cls._registry = {} else: - cls._registry[name] = cls() - - super().__init__(name, bases, cls_dict) - - -class LogicComponent(metaclass=LogicMeta): - """This is the base class for any logic class. This overrides the getattr - method for registry lookup. + if cls.__name__ not in cls._ignore: + model = getattr(cls, "model", None) + if model is not None: + name = model.__name__ - Example: + cls._registry[name] = cls() - ``` - class TrollGoat(LogicComponent): - - def add(x): - print(x) - ``` - - Once you have a logic object, you can directly do something like: - - ``` - logic.TrollGoat.add('troll_goat') - ``` + super().__init__(name, bases, cls_dict) - Note: - You will have to explicitly import your newly created logic in - ``fortmatic.logic.__init__.py``. When the logic is imported, it is created - the first time; hence, it is then registered. If this is unclear to you, - read https://blog.ionelmc.ro/2015/02/09/understanding-python-metaclasses/ - It has all the info you need to understand. For example, everything in - python is an object :P. Enjoy. 
- """ +class LogicComponent(t.Generic[EntityT, EntityIdT, SessionT], metaclass=LogicMeta): def __dir__(self): return super().__dir__() + list(self._registry.keys()) - def __getattr__(self, logic_name): - if logic_name in self._registry: - return self._registry[logic_name] + def __getattr__(self, name): + if name in self._registry: + return self._registry[name] else: - raise AttributeError( - "{object_name} has no attribute '{logic_name}'".format( - object_name=self.__class__.__name__, - logic_name=logic_name, - ), - ) + raise AttributeError(f"{type(self).__name__} has no attribute '{name}'") -class MagicClient(LogicComponent): - def __init__(self): - # self._repository = SQLAlchemyRepository[magic_client_model, ObjectID](session) +class MagicClient(LogicComponent[magic_client_model, ObjectID, sa.orm.Session]): + model = magic_client_model + identity = ObjectID + _repository = RepositoryLegacyAdapter(model, identity) - self._repository = RepositoryLegacyAdapter(magic_client_model, ObjectID) - - def _add(self, session, app_name=None): + @provide_global_contextual_session + def add(self, session, app_name=None, **kwargs): + public_api_key = secrets.token_hex(16) return self._repository.add( session, app_name=app_name, + **kwargs, + public_api_key=public_api_key, ) - # add = with_db_session(ro=False)(_add) - add = _add - - # @with_db_session(ro=True) + @provide_global_contextual_session def get_by_id( self, session, @@ -116,7 +87,7 @@ def get_by_id( join_list=join_list, ) - # @with_db_session(ro=True) + @provide_global_contextual_session def get_by_public_api_key( self, session, @@ -130,34 +101,17 @@ def get_by_public_api_key( ) ) - # @with_db_session(ro=True) - # def get_magic_api_user_id_by_client_id(self, session, magic_client_id): - # client = self._repository.get_by_id( - # session, - # magic_client_id, - # allow_inactive=False, - # join_list=None, - # ) - - # if client is None: - # return None - - # if client.magic_client_api_user is None: - # return None - - # return client.magic_client_api_user.magic_api_user_id - - # @with_db_session(ro=False) + @provide_global_contextual_session def update_by_id(self, session, model_id, **update_params): modified_row = self._repository.update(session, model_id, **update_params) session.refresh(modified_row) return modified_row - # @with_db_session(ro=True) + @provide_global_contextual_session def yield_all_clients_by_chunk(self, session, chunk_size): yield from self._repository.yield_by_chunk(session, chunk_size) - # @with_db_session(ro=True) + @provide_global_contextual_session def yield_by_chunk(self, session, chunk_size, filters=None, join_list=None): yield from self._repository.yield_by_chunk( session, @@ -183,67 +137,17 @@ class MissingPhoneNumber(Exception): pass -class AuthUser(LogicComponent): - def __init__(self): - # self._repository = SQLRepository(auth_user_model) - self._repository = RepositoryLegacyAdapter(magic_client_model, ObjectID) - - # @with_db_session(ro=True) - def get_by_session_token( - self, - session, - session_token, - ): - return one( - self._repository.get_by( - session, - filters=[auth_user_model.current_session_token == session_token], - limit=1, - ) - ) - - def _get_or_add_by_phone_number_and_client_id( - self, - session, - client_id, - phone_number, - user_type=EntityType.FORTMATIC.value, - ): - if phone_number is None: - raise MissingPhoneNumber() - - row = self._get_by_phone_number_and_client_id( - session=session, - phone_number=phone_number, - client_id=client_id, - user_type=user_type, - ) - - if row: - return 
row - - row = self._repository.add( - session=session, - phone_number=phone_number, - client_id=client_id, - user_type=user_type, - provenance=Provenance.SMS, - ) - logger.info( - "New auth user (id: {}) created by phone number (client_id: {})".format( - row.id, - client_id, - ), - ) - - return row +class AuthUser(LogicComponent[auth_user_model, ObjectID, sa.orm.Session]): + model = auth_user_model + identity = ObjectID + _repository = RepositoryLegacyAdapter(model, identity) - # get_or_add_by_phone_number_and_client_id = with_db_session(ro=False)( - # _get_or_add_by_phone_number_and_client_id, - # ) - get_or_add_by_phone_number_and_client_id = _get_or_add_by_phone_number_and_client_id + @provide_global_contextual_session + def add(self, session, **kwargs) -> auth_user_model: + return self._repository.add(session, **kwargs) - def _add_by_email_and_client_id( + @provide_global_contextual_session + def add_by_email_and_client_id( self, session, client_id, @@ -254,7 +158,7 @@ def _add_by_email_and_client_id( if email is None: raise MissingEmail() - if self._exist_by_email_and_client_id( + if self.exist_by_email_and_client_id( session, email, client_id, @@ -284,10 +188,8 @@ def _add_by_email_and_client_id( return row - # add_by_email_and_client_id = with_db_session(ro=False)(_add_by_email_and_client_id) - add_by_email_and_client_id = _add_by_email_and_client_id - - def _add_by_client_id( + @provide_global_contextual_session + def add_by_client_id( self, session, client_id, @@ -305,7 +207,55 @@ def _add_by_client_id( date_verified=datetime.utcnow() if is_verified else None, ) logger.info( - "New auth user (id: {}) created by (client_id: {})".format( + "New auth user (id: {}) created by (client_id: {})".format(row.id, client_id), + ) + + return row + + @provide_global_contextual_session + def get_by_session_token( + self, + session, + session_token, + ): + return one( + self._repository.get_by( + session, + filters=[auth_user_model.current_session_token == session_token], + limit=1, + ) + ) + + @provide_global_contextual_session + def get_or_add_by_phone_number_and_client_id( + self, + session, + client_id, + phone_number, + user_type=EntityType.FORTMATIC.value, + ): + if phone_number is None: + raise MissingPhoneNumber() + + row = self.get_by_phone_number_and_client_id( + session=session, + phone_number=phone_number, + client_id=client_id, + user_type=user_type, + ) + + if row: + return row + + row = self._repository.add( + session=session, + phone_number=phone_number, + client_id=client_id, + user_type=user_type, + provenance=Provenance.SMS, + ) + logger.info( + "New auth user (id: {}) created by phone number (client_id: {})".format( row.id, client_id, ), @@ -313,29 +263,28 @@ def _add_by_client_id( return row - # add_by_client_id = with_db_session(ro=False)(_add_by_client_id) - add_by_client_id = _add_by_client_id - - def _get_by_active_identifier_and_client_id( + @provide_global_contextual_session + def get_by_active_identifier_and_client_id( self, session, identifier_field, identifier_value, client_id, user_type, - ) -> auth_user_model: + for_update=False, + ) -> t.Optional[auth_user_model]: """There should only be one active identifier where all the parameters match for a given client ID. 
In the case of multiple results, the subsequent entries / "dupes" will be marked as inactive.""" filters = [ identifier_field == identifier_value, auth_user_model.client_id == client_id, auth_user_model.user_type == user_type, - # auth_user_model.is_active == True, # noqa: E712 ] results = self._repository.get_by( session, filters=filters, order_by_clause=auth_user_model.id.asc(), + for_update=for_update, ) if not results: @@ -345,29 +294,33 @@ def _get_by_active_identifier_and_client_id( if duplicates: signals.auth_user_duplicate.send( + current_app, original_auth_user_id=original.id, duplicate_auth_user_ids=[dupe.id for dupe in duplicates], ) return original - # @with_db_session(ro=True) + @provide_global_contextual_session def get_by_email_and_client_id( self, session, email, client_id, user_type=EntityType.FORTMATIC.value, + for_update=False, ): - return self._get_by_active_identifier_and_client_id( + return self.get_by_active_identifier_and_client_id( session=session, identifier_field=auth_user_model.email, identifier_value=email, client_id=client_id, user_type=user_type, + for_update=for_update, ) - def _get_by_phone_number_and_client_id( + @provide_global_contextual_session + def get_by_phone_number_and_client_id( self, session, phone_number, @@ -377,7 +330,7 @@ def _get_by_phone_number_and_client_id( if phone_number is None: raise MissingPhoneNumber() - return self._get_by_active_identifier_and_client_id( + return self.get_by_active_identifier_and_client_id( session=session, identifier_field=auth_user_model.phone_number, identifier_value=phone_number, @@ -385,12 +338,8 @@ def _get_by_phone_number_and_client_id( user_type=user_type, ) - # get_by_phone_number_and_client_id = with_db_session(ro=True)( - # _get_by_phone_number_and_client_id, - # ) - get_by_phone_number_and_client_id = _get_by_phone_number_and_client_id - - def _exist_by_email_and_client_id( + @provide_global_contextual_session + def exist_by_email_and_client_id( self, session, email, @@ -408,10 +357,10 @@ def _exist_by_email_and_client_id( ), ) - # exist_by_email_and_client_id = with_db_session(ro=True)(_exist_by_email_and_client_id) - exist_by_email_and_client_id = _exist_by_email_and_client_id - - def _get_by_id(self, session, model_id, join_list=None, for_update=False) -> auth_user_model: + @provide_global_contextual_session + def get_by_id( + self, session, model_id, join_list=None, for_update=False + ) -> t.Optional[auth_user_model]: return self._repository.get_by_id( session, model_id, @@ -419,10 +368,8 @@ def _get_by_id(self, session, model_id, join_list=None, for_update=False) -> aut for_update=for_update, ) - get_by_id = _get_by_id - # get_by_id = with_db_session(ro=True)(_get_by_id) - - def _update_by_id(self, session, auth_user_id, **kwargs): + @provide_global_contextual_session + def update_by_id(self, session, auth_user_id, **kwargs): modified_user = self._repository.update(session, auth_user_id, **kwargs) if modified_user is None: @@ -430,111 +377,33 @@ def _update_by_id(self, session, auth_user_id, **kwargs): return modified_user - # update_by_id = with_db_session(ro=False)(_update_by_id) - update_by_id = _update_by_id - - # @with_db_session(ro=True) + @provide_global_contextual_session def get_user_count_by_client_id_and_user_type(self, session, client_id, user_type): query = ( session.query(auth_user_model) .filter( auth_user_model.client_id == client_id, auth_user_model.user_type == user_type, - # auth_user_model.is_active == True, # noqa: E712 auth_user_model.date_verified.is_not(None), ) - 
.statement.with_only_columns([func.count()]) + .statement.with_only_columns(sa.func.count()) .order_by(None) ) return session.execute(query).scalar() - def _get_by_client_id_and_global_auth_user(self, session, client_id, global_auth_user_id): + @provide_global_contextual_session + def get_by_client_id_and_global_auth_user(self, session, client_id, global_auth_user_id): return self._repository.get_by( session=session, filters=[ auth_user_model.client_id == client_id, auth_user_model.user_type == EntityType.CONNECT.value, - # auth_user_model.is_active == True, # noqa: E712 auth_user_model.global_auth_user_id == global_auth_user_id, ], ) - # get_by_client_id_and_global_auth_user = with_db_session(ro=True)( - # _get_by_client_id_and_global_auth_user, - # ) - get_by_client_id_and_global_auth_user = _get_by_client_id_and_global_auth_user - - # @with_db_session(ro=True) - # def get_by_client_id_for_connect( - # self, - # session, - # client_id, - # offset=None, - # limit=None, - # ): - # # TODO(thomas|2022-07-12): Determine where/if is the right place to split - # # connect/magic logic based on user type as part of https://app.shortcut.com/magic-labs/story/53323. - # # See https://github.com/fortmatic/fortmatic/pull/6173#discussion_r919529540. - # return ( - # session.query(auth_user_model) - # .join( - # identifier_model, - # auth_user_model.global_auth_user_id == identifier_model.global_auth_user_id, - # ) - # .filter( - # auth_user_model.client_id == client_id, - # auth_user_model.user_type == EntityType.CONNECT.value, - # auth_user_model.is_active == True, # noqa: E712, - # auth_user_model.provenance == Provenance.IDENTIFIER, - # or_( - # identifier_model.identifier_type.in_( - # GlobalAuthUserIdentifierType.get_public_address_enums(), - # ), - # identifier_model.date_verified != None, - # ), - # ) - # .order_by(auth_user_model.id.desc()) - # .limit(limit) - # .offset(offset) - # ).all() - - # @with_db_session(ro=True) - # def get_user_count_by_client_id_for_connect( - # self, - # session, - # client_id, - # ): - # # TODO(thomas|2022-07-12): Determine where/if is the right place to split - # # connect/magic logic based on user type as part of https://app.shortcut.com/magic-labs/story/53323. - # # See https://github.com/fortmatic/fortmatic/pull/6173#discussion_r919529540. 
- # query = ( - # session.query(auth_user_model) - # .join( - # identifier_model, - # auth_user_model.global_auth_user_id == identifier_model.global_auth_user_id, - # ) - # .filter( - # auth_user_model.client_id == client_id, - # auth_user_model.user_type == EntityType.CONNECT.value, - # auth_user_model.is_active == True, # noqa: E712, - # auth_user_model.provenance == Provenance.IDENTIFIER, - # or_( - # identifier_model.identifier_type.in_( - # GlobalAuthUserIdentifierType.get_public_address_enums(), - # ), - # identifier_model.date_verified != None, - # ), - # ) - # .statement.with_only_columns( - # [func.count(distinct(auth_user_model.global_auth_user_id))], - # ) - # .order_by(None) - # ) - - # return session.execute(query).scalar() - - # @with_db_session(ro=True) + @provide_global_contextual_session def get_by_client_id_and_user_type( self, session, @@ -543,7 +412,7 @@ def get_by_client_id_and_user_type( offset=None, limit=None, ): - return self._get_by_client_ids_and_user_type( + return self.get_by_client_ids_and_user_type( session, [client_id], user_type, @@ -551,7 +420,8 @@ def get_by_client_id_and_user_type( limit=limit, ) - def _get_by_client_ids_and_user_type( + @provide_global_contextual_session + def get_by_client_ids_and_user_type( self, session, client_ids, @@ -567,7 +437,6 @@ def _get_by_client_ids_and_user_type( filters=[ auth_user_model.client_id.in_(client_ids), auth_user_model.user_type == user_type, - # auth_user_model.is_active == True, # noqa: E712, auth_user_model.date_verified != None, ], offset=offset, @@ -575,12 +444,8 @@ def _get_by_client_ids_and_user_type( order_by_clause=auth_user_model.id.desc(), ) - # get_by_client_ids_and_user_type = with_db_session(ro=True)( - # _get_by_client_ids_and_user_type, - # ) - get_by_client_ids_and_user_type = _get_by_client_ids_and_user_type - - def _get_by_client_id_with_substring_search( + @provide_global_contextual_session + def get_by_client_id_with_substring_search( self, session, client_id, @@ -594,12 +459,12 @@ def _get_by_client_id_with_substring_search( filters=[ auth_user_model.client_id == client_id, auth_user_model.user_type == EntityType.MAGIC.value, - or_( + sa.or_( auth_user_model.provenance == Provenance.SMS, auth_user_model.provenance == Provenance.LINK, auth_user_model.provenance == None, # noqa: E711 ), - or_( + sa.or_( auth_user_model.phone_number.contains(substring), auth_user_model.email.contains(substring), ), @@ -610,12 +475,7 @@ def _get_by_client_id_with_substring_search( join_list=join_list, ) - # get_by_client_id_with_substring_search = with_db_session(ro=True)( - # _get_by_client_id_with_substring_search, - # ) - get_by_client_id_with_substring_search = _get_by_client_id_with_substring_search - - # @with_db_session(ro=True) + @provide_global_contextual_session def yield_by_chunk(self, session, chunk_size, filters=None, join_list=None): yield from self._repository.yield_by_chunk( session, @@ -624,7 +484,7 @@ def yield_by_chunk(self, session, chunk_size, filters=None, join_list=None): join_list=join_list, ) - # @with_db_session(ro=True) + @provide_global_contextual_session def get_by_emails_and_client_id( self, session, @@ -639,7 +499,8 @@ def get_by_emails_and_client_id( ], ) - def _get_by_email( + @provide_global_contextual_session + def get_by_email( self, session, email: str, @@ -657,16 +518,8 @@ def _get_by_email( join_list=join_list, ) - # get_by_email = with_db_session(ro=True)(_get_by_email) - get_by_email = _get_by_email - - def _add(self, session, **kwargs) -> ObjectID: - return 
self._repository.add(session, **kwargs).id - - # add = with_db_session(ro=False)(_add) - add = _add - - def _get_by_email_for_interop( + @provide_global_contextual_session + def get_by_email_for_interop( self, session, email: str, @@ -686,21 +539,14 @@ def _get_by_email_for_interop( auth_user_model.wallets.and_( auth_wallet_model.wallet_type == str(wallet_type) ).and_(auth_wallet_model.network == network) - # .and_(auth_wallet_model.is_active == 1), ) - .options(contains_eager(auth_user_model.wallets)) + .options(sa.orm.contains_eager(auth_user_model.wallets)) .join( auth_user_model.magic_client.and_( magic_client_model.connect_interop == ConnectInteropStatus.ENABLED, ), ) - .options(contains_eager(auth_user_model.magic_client)) - # TODO(magic-ravi#67899|2022-12-30): Uncomment to allow account-linked users to use interop - # .options( - # joinedload( - # auth_user_model.linked_primary_auth_user, - # ).joinedload("auth_wallets"), - # ) + .options(sa.orm.contains_eager(auth_user_model.magic_client)) .filter( auth_wallet_model.wallet_type == wallet_type, auth_wallet_model.network == network, @@ -708,20 +554,14 @@ def _get_by_email_for_interop( .filter( auth_user_model.email == email, auth_user_model.user_type == EntityType.MAGIC.value, - # auth_user_model.is_active == 1, - auth_user_model.linked_primary_auth_user_id == None, # noqa: E711 ) .populate_existing() ) return query.all() - # get_by_email_for_interop = with_db_session(ro=True)( - # _get_by_email_for_interop, - # ) - get_by_email_for_interop = _get_by_email_for_interop - - def _get_linked_users(self, session, primary_auth_user_id, join_list, no_op=False): + @provide_global_contextual_session + def get_linked_users(self, session, primary_auth_user_id, join_list, no_op=False): # TODO(magic-ravi#67899|2022-12-30): Re-enable account linked users for interop. Remove no_op flag. 
if no_op: return [] @@ -729,17 +569,13 @@ def _get_linked_users(self, session, primary_auth_user_id, join_list, no_op=Fals return self._repository.get_by( session, filters=[ - # auth_user_model.is_active == True, # noqa: E712 auth_user_model.user_type == EntityType.MAGIC.value, auth_user_model.linked_primary_auth_user_id == primary_auth_user_id, ], join_list=join_list, ) - # get_linked_users = with_db_session(ro=True)(_get_linked_users) - get_linked_users = _get_linked_users - - # @with_db_session(ro=True) + @provide_global_contextual_session def get_by_phone_number(self, session, phone_number): return self._repository.get_by( session, @@ -749,12 +585,13 @@ def get_by_phone_number(self, session, phone_number): ) -class AuthWallet(LogicComponent): - def __init__(self): - # self._repository = SQLAlchemyRepository[magic_client_model, ObjectID](session) - self._repository = RepositoryLegacyAdapter(auth_wallet_model, ObjectID) +class AuthWallet(LogicComponent[auth_wallet_model, ObjectID, sa.orm.Session]): + model = auth_wallet_model + identity = ObjectID + _repository = RepositoryLegacyAdapter(model, identity) - def _add( + @provide_global_contextual_session + def add( self, session, public_address, @@ -776,10 +613,7 @@ def _add( return new_row - # add = with_db_session(ro=False)(_add) - add = _add - - # @with_db_session(ro=True) + @provide_global_contextual_session def get_by_id(self, session, model_id, allow_inactive=False, join_list=None): return self._repository.get_by_id( session, @@ -788,24 +622,10 @@ def get_by_id(self, session, model_id, allow_inactive=False, join_list=None): join_list=join_list, ) - # @with_db_session(ro=True) + @provide_global_contextual_session def get_by_public_address(self, session, public_address, network=None, is_active=True): - """Public address is unique in our system. In any case, we should only - find one row for the given public address. - - Args: - session: A database session object. - public_address (str): A public address. - network (str): A network name. - is_active (boolean): A boolean value to denote if the query should - retrieve active or inactive rows. - - Returns: - A formatted row, either in presenter form or raw db row. - """ filters = [ auth_wallet_model.public_address == public_address, - # auth_wallet_model.is_active == is_active, ] if network: @@ -818,7 +638,7 @@ def get_by_public_address(self, session, public_address, network=None, is_active return one(row) - # @with_db_session(ro=True) + @provide_global_contextual_session def get_by_auth_user_id( self, session, @@ -828,23 +648,8 @@ def get_by_auth_user_id( is_active=True, join_list=None, ): - """Return all the associated wallets for the given user id. - - Args: - session: A database session object. - auth_user_id (ObjectID): A auth_user id. - network (str|None): A network name. - wallet_type (str|None): a wallet type like ETH or BTC - is_active (boolean): A boolean value to denote if the query should - retrieve active or inactive rows. - join_list (None|List): Table you wish to join. - - Returns: - An empty list or a list of wallets. 
- """ filters = [ auth_wallet_model.auth_user_id == auth_user_id, - # auth_wallet_model.is_active == is_active, ] if network: @@ -862,8 +667,6 @@ def get_by_auth_user_id( return rows - def _update_by_id(self, session, model_id, **kwargs): + @provide_global_contextual_session + def update_by_id(self, session, model_id, **kwargs): self._repository.update(session, model_id, **kwargs) - - # update_by_id = with_db_session(ro=False)(_update_by_id) - update_by_id = _update_by_id diff --git a/src/quart_sqlalchemy/sim/main.py b/src/quart_sqlalchemy/sim/main.py index 23d1e0c..8350c68 100644 --- a/src/quart_sqlalchemy/sim/main.py +++ b/src/quart_sqlalchemy/sim/main.py @@ -1,8 +1,10 @@ +from quart_sqlalchemy.sim import commands from quart_sqlalchemy.sim.app import create_app app = create_app() +commands.attach(app) if __name__ == "__main__": app.run(port=8081) diff --git a/src/quart_sqlalchemy/sim/model.py b/src/quart_sqlalchemy/sim/model.py index b0199ab..9df6fe9 100644 --- a/src/quart_sqlalchemy/sim/model.py +++ b/src/quart_sqlalchemy/sim/model.py @@ -11,7 +11,7 @@ from quart_sqlalchemy.model import SoftDeleteMixin from quart_sqlalchemy.model import TimestampMixin -from quart_sqlalchemy.sim.db import db +from quart_sqlalchemy.sim.db import MyBase from quart_sqlalchemy.sim.util import ObjectID @@ -71,7 +71,7 @@ class WalletType(StrEnum): HEDERA = "HEDERA" -class MagicClient(db.Model, SoftDeleteMixin, TimestampMixin): +class MagicClient(MyBase, SoftDeleteMixin, TimestampMixin): __tablename__ = "magic_client" id: Mapped[ObjectID] = sa.orm.mapped_column(primary_key=True, autoincrement=True) @@ -90,7 +90,7 @@ class MagicClient(db.Model, SoftDeleteMixin, TimestampMixin): ) -class AuthUser(db.Model, SoftDeleteMixin, TimestampMixin): +class AuthUser(MyBase, SoftDeleteMixin, TimestampMixin): __tablename__ = "auth_user" id: Mapped[ObjectID] = sa.orm.mapped_column(primary_key=True, autoincrement=True) @@ -145,7 +145,7 @@ def is_magic_connect_user(self): return self.global_auth_user_id is not None and self.user_type == EntityType.CONNECT.value -class AuthWallet(db.Model, SoftDeleteMixin, TimestampMixin): +class AuthWallet(MyBase, SoftDeleteMixin, TimestampMixin): __tablename__ = "auth_wallet" id: Mapped[ObjectID] = sa.orm.mapped_column(primary_key=True, autoincrement=True) diff --git a/src/quart_sqlalchemy/sim/repo.py b/src/quart_sqlalchemy/sim/repo.py index 5844b9d..959efb2 100644 --- a/src/quart_sqlalchemy/sim/repo.py +++ b/src/quart_sqlalchemy/sim/repo.py @@ -1,5 +1,6 @@ from __future__ import annotations +import operator import typing as t from abc import ABCMeta @@ -13,53 +14,35 @@ from quart_sqlalchemy.types import ColumnExpr from quart_sqlalchemy.types import EntityIdT from quart_sqlalchemy.types import EntityT +from quart_sqlalchemy.types import Operator from quart_sqlalchemy.types import ORMOption from quart_sqlalchemy.types import Selectable from quart_sqlalchemy.types import SessionT -# from abc import abstractmethod - - sa = sqlalchemy -class AbstractRepository(t.Generic[EntityT, EntityIdT], metaclass=ABCMeta): +class AbstractRepository(t.Generic[EntityT, EntityIdT, SessionT], metaclass=ABCMeta): """A repository interface.""" - # identity: t.Type[EntityIdT] - - # def __init__(self, model: t.Type[EntityT]): - # self.model = model - - @property - def model(self) -> t.Type[EntityT]: - return self.__orig_class__.__args__[0] # type: ignore + model: t.Type[EntityT] + identity: t.Type[EntityIdT] - @property - def identity(self) -> t.Type[EntityIdT]: - return self.__orig_class__.__args__[1] # type: ignore - 
-class AbstractBulkRepository(t.Generic[EntityT, EntityIdT], metaclass=ABCMeta): +class AbstractBulkRepository(t.Generic[EntityT, EntityIdT, SessionT], metaclass=ABCMeta): """A repository interface for bulk operations. Note: this interface circumvents ORM internals, breaking commonly expected behavior in order to gain performance benefits. Only use this class whenever absolutely necessary. """ - @property - def model(self) -> t.Type[EntityT]: - return self.__orig_class__.__args__[0] # type: ignore - - @property - def identity(self) -> t.Type[EntityIdT]: - return self.__orig_class__.__args__[1] # type: ignore + model: t.Type[EntityT] + identity: t.Type[EntityIdT] class SQLAlchemyRepository( - AbstractRepository[EntityT, EntityIdT], - t.Generic[EntityT, EntityIdT], + AbstractRepository[EntityT, EntityIdT, SessionT], t.Generic[EntityT, EntityIdT, SessionT] ): """A repository that uses SQLAlchemy to persist data. @@ -83,20 +66,21 @@ class SQLAlchemyRepository( """ - # session: sa.orm.Session builder: StatementBuilder - def __init__(self, **kwargs): - super().__init__(**kwargs) - # self.session = session - self.builder = StatementBuilder(None) + def __init__(self, model: t.Type[EntityT], identity: t.Type[EntityIdT]): + self.model = model + self.identity = identity + self.builder = StatementBuilder(self.model) def insert(self, session: sa.orm.Session, values: t.Dict[str, t.Any]) -> EntityT: """Insert a new model into the database.""" new = self.model(**values) + session.add(new) session.flush() session.refresh(new) + return new def update( @@ -109,8 +93,10 @@ def update( for field, value in values.items(): if getattr(obj, field) != value: setattr(obj, field, value) + session.flush() session.refresh(obj) + return obj def merge( @@ -123,9 +109,11 @@ def merge( """Merge model in session/db having id_ with values.""" session.get(self.model, id_) values.update(id=id_) + merged = session.merge(self.model(**values)) session.flush() session.refresh(merged, with_for_update=for_update) # type: ignore + return merged def get( @@ -151,19 +139,61 @@ def get( to satisfy the expected interface's return type: Optional[EntityT], one_or_none is called on the result before returning. 
""" + selectables = (self.model,) + execution_options = execution_options or {} if include_inactive: execution_options.setdefault("include_inactive", include_inactive) - statement = sa.select(self.model).where(self.model.id == id_).limit(1) # type: ignore + statement = self.builder.select( + selectables, # type: ignore + conditions=[self.model.id == id_], + options=options, + limit=1, + for_update=for_update, + ) - for option in options: - statement = statement.options(option) + return session.scalars(statement, execution_options=execution_options).one_or_none() - if for_update: - statement = statement.with_for_update() + def get_by_field( + self, + session: sa.orm.Session, + field: t.Union[ColumnExpr, str], + value: t.Any, + op: Operator = operator.eq, + order_by: t.Sequence[t.Union[ColumnExpr, str]] = (), + options: t.Sequence[ORMOption] = (), + execution_options: t.Optional[t.Dict[str, t.Any]] = None, + offset: t.Optional[int] = None, + limit: t.Optional[int] = None, + distinct: bool = False, + for_update: bool = False, + include_inactive: bool = False, + ) -> sa.ScalarResult[EntityT]: + """Select models where field is equal to value.""" + selectables = (self.model,) - return session.scalars(statement, execution_options=execution_options).one_or_none() + execution_options = execution_options or {} + if include_inactive: + execution_options.setdefault("include_inactive", include_inactive) + + if isinstance(field, str): + field = getattr(self.model, field) + + conditions = [t.cast(ColumnExpr, op(field, value))] + + statement = self.builder.select( + selectables, # type: ignore + conditions=conditions, + order_by=order_by, + options=options, + offset=offset, + limit=limit, + distinct=distinct, + for_update=for_update, + ) + + return session.scalars(statement, execution_options=execution_options) def select( self, @@ -193,7 +223,7 @@ def select( if yield_by_chunk: execution_options.setdefault("yield_per", yield_by_chunk) - statement = self.builder.complex_select( + statement = self.builder.select( selectables, conditions=conditions, group_by=group_by, @@ -214,10 +244,7 @@ def select( def delete( self, session: sa.orm.Session, id_: EntityIdT, include_inactive: bool = False ) -> None: - # if self.has_soft_delete: - # raise RuntimeError("Can't delete entity that uses soft-delete semantics.") - - entity = self.get(id_, include_inactive=include_inactive) + entity = self.get(session, id_, include_inactive=include_inactive) if not entity: raise RuntimeError(f"Entity with id {id_} not found.") @@ -225,16 +252,10 @@ def delete( session.flush() def deactivate(self, session: sa.orm.Session, id_: EntityIdT) -> EntityT: - # if not self.has_soft_delete: - # raise RuntimeError("Can't delete entity that uses soft-delete semantics.") - - return self.update(id_, dict(is_active=False)) + return self.update(session, id_, dict(is_active=False)) def reactivate(self, session: sa.orm.Session, id_: EntityIdT) -> EntityT: - # if not self.has_soft_delete: - # raise RuntimeError("Can't delete entity that uses soft-delete semantics.") - - return self.update(id_, dict(is_active=False)) + return self.update(session, id_, dict(is_active=False)) def exists( self, @@ -248,31 +269,36 @@ def exists( Note: This performs better than simply trying to select an object since there is no overhead in sending the selected object and deserializing it. 
""" - selectable = sa.sql.literal(True) + selectables = (sa.sql.literal(True),) execution_options = {} if include_inactive: execution_options.setdefault("include_inactive", include_inactive) - statement = sa.select(selectable).where(*conditions) # type: ignore - - if for_update: - statement = statement.with_for_update() + statement = self.builder.select( + selectables, + conditions=conditions, + limit=1, + for_update=for_update, + ) result = session.execute(statement, execution_options=execution_options).scalar() return bool(result) -class SQLAlchemyBulkRepository(AbstractBulkRepository, t.Generic[SessionT, EntityT, EntityIdT]): - def __init__(self, **kwargs: t.Any): +class SQLAlchemyBulkRepository( + AbstractBulkRepository[EntityT, EntityIdT, SessionT], t.Generic[EntityT, EntityIdT, SessionT] +): + builder: StatementBuilder + + def __init__(self, **kwargs): super().__init__(**kwargs) - self.builder = StatementBuilder(self.model) - # session = session + self.builder = StatementBuilder(None) def bulk_insert( self, - session: sa.orm.Session, + session: SessionT, values: t.Sequence[t.Dict[str, t.Any]] = (), execution_options: t.Optional[t.Dict[str, t.Any]] = None, ) -> sa.Result[t.Any]: @@ -281,7 +307,7 @@ def bulk_insert( def bulk_update( self, - session: sa.orm.Session, + session: SessionT, conditions: t.Sequence[ColumnExpr] = (), values: t.Optional[t.Dict[str, t.Any]] = None, execution_options: t.Optional[t.Dict[str, t.Any]] = None, @@ -291,7 +317,7 @@ def bulk_update( def bulk_delete( self, - session: sa.orm.Session, + session: SessionT, conditions: t.Sequence[ColumnExpr] = (), execution_options: t.Optional[t.Dict[str, t.Any]] = None, ) -> sa.Result[t.Any]: diff --git a/src/quart_sqlalchemy/sim/repo_adapter.py b/src/quart_sqlalchemy/sim/repo_adapter.py index 3d6b48e..8e0a43d 100644 --- a/src/quart_sqlalchemy/sim/repo_adapter.py +++ b/src/quart_sqlalchemy/sim/repo_adapter.py @@ -3,19 +3,14 @@ import sqlalchemy import sqlalchemy.orm from pydantic import BaseModel -from sqlalchemy import ScalarResult -from sqlalchemy.orm import selectinload -from sqlalchemy.orm import Session -from sqlalchemy.sql.expression import func -from sqlalchemy.sql.expression import label -from quart_sqlalchemy.model import Base from quart_sqlalchemy.sim.repo import SQLAlchemyRepository from quart_sqlalchemy.types import ColumnExpr from quart_sqlalchemy.types import EntityIdT from quart_sqlalchemy.types import EntityT from quart_sqlalchemy.types import ORMOption from quart_sqlalchemy.types import Selectable +from quart_sqlalchemy.types import SessionT sa = sqlalchemy @@ -39,21 +34,18 @@ class BaseUpdateSchema(BaseModelSchema): UpdateSchemaT = t.TypeVar("UpdateSchemaT", bound=BaseUpdateSchema) -class RepositoryLegacyAdapter(t.Generic[EntityT, EntityIdT]): - def __init__( - self, - model: t.Type[EntityT], - identity: t.Type[EntityIdT], - # session: Session, - ): +class RepositoryLegacyAdapter(t.Generic[EntityT, EntityIdT, SessionT]): + model: t.Type[EntityT] + identity: t.Type[EntityIdT] + + def __init__(self, model: t.Type[EntityT], identity: t.Type[EntityIdT]): self.model = model - self._identity = identity - # self._session = session - self.repo = SQLAlchemyRepository[model, identity]() + self.identity = identity + self._adapted = SQLAlchemyRepository(model, identity) def get_by( self, - session: t.Optional[Session] = None, + session: SessionT, filters=None, allow_inactive=False, join_list=None, @@ -72,10 +64,10 @@ def get_by( else: order_by_clause = () - return self.repo.select( + return self._adapted.select( 
session, conditions=filters, - options=[selectinload(getattr(self.model, attr)) for attr in join_list], + options=[sa.orm.selectinload(getattr(self.model, attr)) for attr in join_list], for_update=for_update, order_by=order_by_clause, offset=offset, @@ -85,7 +77,7 @@ def get_by( def get_by_id( self, - session=None, + session: SessionT, model_id=None, allow_inactive=False, join_list=None, @@ -94,30 +86,35 @@ def get_by_id( if model_id is None: raise ValueError("model_id is required") join_list = join_list or () - return self.repo.get( + return self._adapted.get( session, id_=model_id, - options=[selectinload(getattr(self.model, attr)) for attr in join_list], + options=[sa.orm.selectinload(getattr(self.model, attr)) for attr in join_list], for_update=for_update, include_inactive=allow_inactive, ) def one( - self, session=None, filters=None, join_list=None, for_update=False, include_inactive=False + self, + session: SessionT, + filters=None, + join_list=None, + for_update=False, + include_inactive=False, ) -> EntityT: filters = filters or () join_list = join_list or () - return self.repo.select( + return self._adapted.select( session, conditions=filters, - options=[selectinload(getattr(self.model, attr)) for attr in join_list], + options=[sa.orm.selectinload(getattr(self.model, attr)) for attr in join_list], for_update=for_update, include_inactive=include_inactive, ).one() def count_by( self, - session=None, + session: SessionT, filters=None, group_by=None, distinct_column=None, @@ -128,36 +125,36 @@ def count_by( group_by = group_by or () if distinct_column: - selectables = [label("count", func.count(func.distinct(distinct_column)))] + selectables = [sa.label("count", sa.func.count(sa.func.distinct(distinct_column)))] else: - selectables = [label("count", func.count(self.model.id))] + selectables = [sa.label("count", sa.func.count(self.model.id))] for group in group_by: selectables.append(group.expression) - result = self.repo.select(session, selectables, conditions=filters, group_by=group_by) + result = self._adapted.select(session, selectables, conditions=filters, group_by=group_by) return result.all() - def add(self, session=None, **kwargs) -> EntityT: - return self.repo.insert(session, kwargs) + def add(self, session: SessionT, **kwargs) -> EntityT: + return self._adapted.insert(session, kwargs) - def update(self, session=None, model_id=None, **kwargs) -> EntityT: - return self.repo.update(session, id_=model_id, values=kwargs) + def update(self, session: SessionT, model_id=None, **kwargs) -> EntityT: + return self._adapted.update(session, id_=model_id, values=kwargs) - def update_by(self, session=None, filters=None, **kwargs) -> EntityT: + def update_by(self, session: SessionT, filters=None, **kwargs) -> EntityT: if not filters: raise ValueError("Full table scans are prohibited. 
Please provide filters") - row = self.repo.select(session, conditions=filters, limit=2).one() - return self.repo.update(session, id_=row.id, values=kwargs) + row = self._adapted.select(session, conditions=filters, limit=2).one() + return self._adapted.update(session, id_=row.id, values=kwargs) - def delete_by_id(self, session=None, model_id=None) -> None: - self.repo.delete(session, id_=model_id, include_inactive=True) + def delete_by_id(self, session: SessionT, model_id=None) -> None: + self._adapted.delete(session, id_=model_id, include_inactive=True) - def delete_one_by(self, session=None, filters=None, optional=False) -> None: + def delete_one_by(self, session: SessionT, filters=None, optional=False) -> None: filters = filters or () - result = self.repo.select(session, conditions=filters, limit=1) + result = self._adapted.select(session, conditions=filters, limit=1) if optional: row = result.one_or_none() @@ -166,25 +163,25 @@ def delete_one_by(self, session=None, filters=None, optional=False) -> None: else: row = result.one() - self.repo.delete(session, id_=row.id) + self._adapted.delete(session, id_=row.id) - def exist(self, session=None, filters=None, allow_inactive=False) -> bool: + def exist(self, session: SessionT, filters=None, allow_inactive=False) -> bool: filters = filters or () - return self.repo.exists( + return self._adapted.exists( session, conditions=filters, include_inactive=allow_inactive, ) def yield_by_chunk( - self, session=None, chunk_size=100, join_list=None, filters=None, allow_inactive=False + self, session: SessionT, chunk_size=100, join_list=None, filters=None, allow_inactive=False ): filters = filters or () join_list = join_list or () - results = self.repo.select( + results = self._adapted.select( session, conditions=filters, - options=[selectinload(getattr(self.model, attr)) for attr in join_list], + options=[sa.orm.selectinload(getattr(self.model, attr)) for attr in join_list], include_inactive=allow_inactive, yield_by_chunk=chunk_size, ) @@ -192,10 +189,10 @@ def yield_by_chunk( yield result -class PydanticScalarResult(ScalarResult): +class PydanticScalarResult(sa.ScalarResult, t.Generic[ModelSchemaT]): pydantic_schema: t.Type[ModelSchemaT] - def __init__(self, scalar_result, pydantic_schema: t.Type[ModelSchemaT]): + def __init__(self, scalar_result: t.Any, pydantic_schema: t.Type[ModelSchemaT]): for attribute in scalar_result.__slots__: setattr(self, attribute, getattr(scalar_result, attribute)) self.pydantic_schema = pydantic_schema @@ -231,14 +228,15 @@ def partitions(self, *args, **kwargs): yield self._translate_many(partition) -class PydanticRepository(SQLAlchemyRepository, t.Generic[EntityT, EntityIdT, ModelSchemaT]): - @property - def schema(self) -> t.Type[ModelSchemaT]: - return self.__orig_class__.__args__[2] # type: ignore +class PydanticRepository( + SQLAlchemyRepository[EntityT, EntityIdT, SessionT], + t.Generic[EntityT, EntityIdT, SessionT, ModelSchemaT, CreateSchemaT, UpdateSchemaT], +): + model_schema: t.Type[ModelSchemaT] def insert( self, - session: sa.orm.Session, + session: SessionT, create_schema: CreateSchemaT, sqla_model=False, ): @@ -247,11 +245,11 @@ def insert( if sqla_model: return result - return self.schema.from_orm(result) + return self.model_schema.from_orm(result) def update( self, - session: sa.orm.Session, + session: SessionT, id_: EntityIdT, update_schema: UpdateSchemaT, sqla_model=False, @@ -267,13 +265,14 @@ def update( session.add(existing) session.flush() session.refresh(existing) + if sqla_model: return existing - 
return self.schema.from_orm(existing) + return self.model_schema.from_orm(existing) def get( self, - session: sa.orm.Session, + session: SessionT, id_: EntityIdT, options: t.Sequence[ORMOption] = (), execution_options: t.Optional[t.Dict[str, t.Any]] = None, @@ -294,11 +293,11 @@ def get( if sqla_model: return row - return self.schema.from_orm(row) + return self.model_schema.from_orm(row) def select( self, - session: sa.orm.Session, + session: SessionT, selectables: t.Sequence[Selectable] = (), conditions: t.Sequence[ColumnExpr] = (), group_by: t.Sequence[t.Union[ColumnExpr, str]] = (), @@ -328,6 +327,7 @@ def select( include_inactive, yield_by_chunk, ) + if sqla_model: return result - return PydanticScalarResult(result, self.schema) + return PydanticScalarResult[self.model_schema](result, self.model_schema) diff --git a/src/quart_sqlalchemy/sim/schema.py b/src/quart_sqlalchemy/sim/schema.py index e0304b7..e679d61 100644 --- a/src/quart_sqlalchemy/sim/schema.py +++ b/src/quart_sqlalchemy/sim/schema.py @@ -1,20 +1,34 @@ import typing as t from datetime import datetime +from enum import Enum from pydantic import BaseModel from pydantic import Field from pydantic import validator +from pydantic.generics import GenericModel +from .model import ConnectInteropStatus +from .model import EntityType +from .model import Provenance +from .model import WalletManagementType +from .model import WalletType from .util import ObjectID +DataT = t.TypeVar("DataT") + +json_encoders = { + ObjectID: lambda v: v.encode(), + datetime: lambda dt: int(dt.timestamp()), + Enum: lambda e: e.value, +} + + class BaseSchema(BaseModel): class Config: arbitrary_types_allowed = True - json_encoders = { - ObjectID: lambda v: v.encode(), - datetime: lambda dt: int(dt.timestamp()), - } + json_encoders = dict(json_encoders) + orm_mode = True @classmethod def _get_value(cls, v: t.Any, *args: t.Any, **kwargs: t.Any) -> t.Any: @@ -27,13 +41,17 @@ def _get_value(cls, v: t.Any, *args: t.Any, **kwargs: t.Any) -> t.Any: return super()._get_value(v, *args, **kwargs) -class ResponseWrapper(BaseSchema): +class ResponseWrapper(GenericModel, t.Generic[DataT]): """Generic response wrapper""" + class Config: + arbitrary_types_allowed = True + json_encoders = dict(json_encoders) + error_code: str = "" status: str = "" message: str = "" - data: t.Any = Field(default_factory=dict) + data: DataT = Field(default_factory=dict) @validator("status") def set_status_by_error_code(cls, v, values): @@ -41,3 +59,41 @@ def set_status_by_error_code(cls, v, values): if error_code: return "failed" return "ok" + + +class MagicClientSchema(BaseSchema): + id: ObjectID + app_name: str + rate_limit_tier: t.Optional[str] = None + connect_interop: t.Optional[ConnectInteropStatus] = None + is_signing_modal_enabled: bool + global_audience_enabled: bool + public_api_key: str + secret_api_key: str + + +class AuthUserSchema(BaseSchema): + id: ObjectID + client_id: ObjectID + email: str + phone_number: t.Optional[str] = None + user_type: EntityType = EntityType.MAGIC + provenance: t.Optional[Provenance] = None + date_verified: t.Optional[datetime] = None + is_admin: bool = False + linked_primary_auth_user_id: t.Optional[ObjectID] = None + global_auth_user_id: t.Optional[ObjectID] = None + delegated_user_id: t.Optional[str] = None + delegated_identity_pool_id: t.Optional[str] = None + current_session_token: t.Optional[str] = None + + +class AuthWalletSchema(BaseSchema): + id: ObjectID + auth_user_id: ObjectID + wallet_type: WalletType + management_type: 
WalletManagementType + public_address: str + encrypted_private_address: str + network: str + is_exported: bool diff --git a/src/quart_sqlalchemy/sim/signals.py b/src/quart_sqlalchemy/sim/signals.py index 1981cea..510be83 100644 --- a/src/quart_sqlalchemy/sim/signals.py +++ b/src/quart_sqlalchemy/sim/signals.py @@ -13,6 +13,7 @@ def handler( current_app: Quart, original_auth_user_id: ObjectID, duplicate_auth_user_ids: List[ObjectID], + session: sa.orm.Session, ) -> None: ... """, diff --git a/src/quart_sqlalchemy/sim/testing.py b/src/quart_sqlalchemy/sim/testing.py index 347b422..fd84547 100644 --- a/src/quart_sqlalchemy/sim/testing.py +++ b/src/quart_sqlalchemy/sim/testing.py @@ -1,13 +1,16 @@ from contextlib import contextmanager from quart import g +from quart import Quart from quart import signals +from quart_sqlalchemy import Bind + @contextmanager -def user_set(app, user): +def global_bind(app: Quart, bind: Bind): def handler(sender, **kwargs): - g.user = user + g.bind = bind with signals.appcontext_pushed.connected_to(handler, app): yield diff --git a/src/quart_sqlalchemy/sim/views/auth_user.py b/src/quart_sqlalchemy/sim/views/auth_user.py index ee456e6..6c8664b 100644 --- a/src/quart_sqlalchemy/sim/views/auth_user.py +++ b/src/quart_sqlalchemy/sim/views/auth_user.py @@ -1,7 +1,88 @@ import logging +import typing as t -from quart import Blueprint +import sqlalchemy.orm +from dependency_injector.wiring import inject +from dependency_injector.wiring import Provide +from quart_sqlalchemy.framework import QuartSQLAlchemy +from quart_sqlalchemy.session import set_global_contextual_session + +from ..auth import authorized_request +from ..auth import RequestCredentials +from ..container import Container +from ..handle import AuthUserHandler +from ..model import EntityType +from ..schema import AuthUserSchema +from ..schema import BaseSchema +from ..schema import ResponseWrapper +from .util import APIBlueprint + + +sa = sqlalchemy logger = logging.getLogger(__name__) -api = Blueprint("auth_user", __name__, url_prefix="auth_user") +api = APIBlueprint("auth_user", __name__, url_prefix="/auth_user") + + +class CreateAuthUserRequest(BaseSchema): + email: str + + +class CreateAuthUserResponse(BaseSchema): + auth_user: AuthUserSchema + + +@api.get( + "/", + authorizer=authorized_request( + [ + { + "public-api-key": [], + "session-token-bearer": [], + } + ], + ), +) +@inject +def get_auth_user( + auth_user_handler: AuthUserHandler = Provide["AuthUserHandler"], + db: QuartSQLAlchemy = Provide["db"], + credentials: RequestCredentials = Provide["request_credentials"], +) -> ResponseWrapper[AuthUserSchema]: + with db.bind.Session() as session: + with set_global_contextual_session(session): + auth_user = auth_user_handler.get_by_session_token(credentials.current_user.value) + + return ResponseWrapper[AuthUserSchema](data=AuthUserSchema.from_orm(auth_user)) + + +@api.post( + "/", + authorizer=authorized_request( + [ + { + "public-api-key": [], + } + ], + ), +) +@inject +def create_auth_user( + data: CreateAuthUserRequest, + auth_user_handler: AuthUserHandler = Provide["AuthUserHandler"], + db: QuartSQLAlchemy = Provide["db"], + credentials: RequestCredentials = Provide[Container.request_credentials], +) -> ResponseWrapper[CreateAuthUserResponse]: + with db.bind.Session() as session: + with session.begin(): + with set_global_contextual_session(session): + client = auth_user_handler.create_verified_user( + email=data.email, + client_id=credentials.current_client.subject.id, + 
user_type=EntityType.MAGIC.value, + ) + + return ResponseWrapper[CreateAuthUserResponse]( + data=dict(auth_user=AuthUserSchema.from_orm(client)) # type: ignore + ) diff --git a/src/quart_sqlalchemy/sim/views/auth_wallet.py b/src/quart_sqlalchemy/sim/views/auth_wallet.py index 9c6f052..79c3be0 100644 --- a/src/quart_sqlalchemy/sim/views/auth_wallet.py +++ b/src/quart_sqlalchemy/sim/views/auth_wallet.py @@ -1,23 +1,27 @@ import logging import typing as t +from dependency_injector.wiring import inject +from dependency_injector.wiring import Provide from quart import g -from quart.utils import run_sync -from quart_sqlalchemy.retry import retry_context -from quart_sqlalchemy.retry import RetryError +from quart_sqlalchemy.framework import QuartSQLAlchemy +from quart_sqlalchemy.session import set_global_contextual_session from ..auth import authorized_request +from ..auth import RequestCredentials from ..handle import AuthWalletHandler from ..model import WalletManagementType from ..model import WalletType from ..schema import BaseSchema +from ..schema import ResponseWrapper from ..util import ObjectID +from ..web3 import Web3 from .util import APIBlueprint logger = logging.getLogger(__name__) -api = APIBlueprint("auth_wallet", __name__, url_prefix="auth_wallet") +api = APIBlueprint("auth_wallet", __name__, url_prefix="/auth_wallet") @api.before_request @@ -56,28 +60,32 @@ class WalletSyncResponse(BaseSchema): ], ), ) -async def sync(data: WalletSyncRequest) -> WalletSyncResponse: - user_credential = g.authorized_credentials.get("session-token-bearer") +@inject +def sync( + data: WalletSyncRequest, + auth_wallet_handler: AuthWalletHandler = Provide["AuthWalletHandler"], + web3: Web3 = Provide["web3"], + db: QuartSQLAlchemy = Provide["db"], + credentials: RequestCredentials = Provide["request_credentials"], +) -> ResponseWrapper[WalletSyncResponse]: + with db.bind.Session() as session: + with session.begin(): + with set_global_contextual_session(session): + wallet = auth_wallet_handler.sync_auth_wallet( + credentials.current_user.subject.id, + data.public_address, + data.encrypted_private_address, + WalletManagementType.DELEGATED.value, + network=web3.network, + wallet_type=data.wallet_type, + ) - try: - for attempt in retry_context: - with attempt: - with g.bind.Session() as session: - wallet = AuthWalletHandler(g.network, session).sync_auth_wallet( - user_credential.subject.id, - data.public_address, - data.encrypted_private_address, - WalletManagementType.DELEGATED.value, - ) - except RetryError: - pass - except RuntimeError: - raise RuntimeError("Unsupported wallet type or network") - - return WalletSyncResponse( - wallet_id=wallet.id, - auth_user_id=wallet.auth_user_id, - wallet_type=wallet.wallet_type, - public_address=wallet.public_address, - encrypted_private_address=wallet.encrypted_private_address, + return ResponseWrapper[WalletSyncResponse]( + data=dict( + wallet_id=wallet.id, + auth_user_id=wallet.auth_user_id, + wallet_type=wallet.wallet_type, + public_address=wallet.public_address, + encrypted_private_address=wallet.encrypted_private_address, + ) # type: ignore ) diff --git a/src/quart_sqlalchemy/sim/views/magic_client.py b/src/quart_sqlalchemy/sim/views/magic_client.py index a1d28e1..09d889b 100644 --- a/src/quart_sqlalchemy/sim/views/magic_client.py +++ b/src/quart_sqlalchemy/sim/views/magic_client.py @@ -1,7 +1,90 @@ import logging +import typing as t -from quart import Blueprint +from dependency_injector.wiring import inject +from dependency_injector.wiring import Provide + 
+from quart_sqlalchemy.framework import QuartSQLAlchemy +from quart_sqlalchemy.session import set_global_contextual_session + +from ..auth import authorized_request +from ..auth import RequestCredentials +from ..handle import MagicClientHandler +from ..model import ConnectInteropStatus +from ..schema import BaseSchema +from ..schema import MagicClientSchema +from ..schema import ResponseWrapper +from .util import APIBlueprint logger = logging.getLogger(__name__) -api = Blueprint("magic_client", __name__, url_prefix="magic_client") +api = APIBlueprint("magic_client", __name__, url_prefix="/magic_client") + + +class CreateMagicClientRequest(BaseSchema): + app_name: str + rate_limit_tier: t.Optional[str] = None + connect_interop: t.Optional[ConnectInteropStatus] = None + is_signing_modal_enabled: bool = False + global_audience_enabled: bool = False + + +class CreateMagicClientResponse(BaseSchema): + magic_client: MagicClientSchema + + +@api.post( + "/", + authorizer=authorized_request( + [ + { + "public-api-key": [], + } + ], + ), +) +@inject +def create_magic_client( + data: CreateMagicClientRequest, + magic_client_handler: MagicClientHandler = Provide["MagicClientHandler"], + db: QuartSQLAlchemy = Provide["db"], +) -> ResponseWrapper[CreateMagicClientResponse]: + with db.bind.Session() as session: + with session.begin(): + with set_global_contextual_session(session): + client = magic_client_handler.add( + app_name=data.app_name, + rate_limit_tier=data.rate_limit_tier, + connect_interop=data.connect_interop, + is_signing_modal_enabled=data.is_signing_modal_enabled, + global_audience_enabled=data.global_audience_enabled, + ) + + return ResponseWrapper[CreateMagicClientResponse]( + data=dict(magic_client=MagicClientSchema.from_orm(client)) # type: ignore + ) + + +@api.get( + "/", + authorizer=authorized_request( + [ + { + "public-api-key": [], + } + ], + ), +) +@inject +def get_magic_client( + magic_client_handler: MagicClientHandler = Provide["MagicClientHandler"], + credentials: RequestCredentials = Provide["request_credentials"], + db: QuartSQLAlchemy = Provide["db"], +) -> ResponseWrapper[MagicClientSchema]: + with db.bind.Session() as session: + with set_global_contextual_session(session): + client = magic_client_handler.get_by_public_api_key(credentials.current_client.value) + + return ResponseWrapper[MagicClientSchema]( + data=MagicClientSchema.from_orm(client) # type: ignore + ) diff --git a/src/quart_sqlalchemy/sim/web3.py b/src/quart_sqlalchemy/sim/web3.py new file mode 100644 index 0000000..e1a088b --- /dev/null +++ b/src/quart_sqlalchemy/sim/web3.py @@ -0,0 +1,153 @@ +import typing as t +from decimal import Decimal + +import typing_extensions as tx +import web3.providers +from ens import ENS +from eth_typing import AnyAddress +from eth_typing import ChecksumAddress +from eth_typing import HexStr +from eth_typing import Primitives +from eth_typing.abi import TypeStr +from quart import request +from quart.ctx import has_request_context +from web3.eth import Eth +from web3.geth import Geth +from web3.main import BaseWeb3 +from web3.module import Module +from web3.net import Net +from web3.providers import BaseProvider +from web3.types import Wei + + +""" +generate new key address pairing + +```zsh +python -c "from web3 import Web3; w3 = Web3(); acc = w3.eth.account.create(); print(f'private key={w3.to_hex(acc.key)}, account={acc.address}')" +``` +""" + + +class Web3Node(tx.Protocol): + eth: Eth + net: Net + geth: Geth + provider: BaseProvider + ens: ENS + + def is_connected(self) -> 
bool: + ... + + @staticmethod + def to_bytes( + primitive: t.Optional[Primitives] = None, + hexstr: t.Optional[HexStr] = None, + text: t.Optional[str] = None, + ) -> bytes: + ... + + @staticmethod + def to_int( + primitive: t.Optional[Primitives] = None, + hexstr: t.Optional[HexStr] = None, + text: t.Optional[str] = None, + ) -> int: + ... + + @staticmethod + def to_hex( + primitive: t.Optional[Primitives] = None, + hexstr: t.Optional[HexStr] = None, + text: t.Optional[str] = None, + ) -> HexStr: + ... + + @staticmethod + def to_text( + primitive: t.Optional[Primitives] = None, + hexstr: t.Optional[HexStr] = None, + text: t.Optional[str] = None, + ) -> str: + ... + + @staticmethod + def to_json(obj: t.Dict[t.Any, t.Any]) -> str: + ... + + @staticmethod + def to_wei(number: t.Union[int, float, str, Decimal], unit: str) -> Wei: + ... + + @staticmethod + def from_wei(number: int, unit: str) -> t.Union[int, Decimal]: + ... + + @staticmethod + def is_address(value: t.Any) -> bool: + ... + + @staticmethod + def is_checksum_address(value: t.Any) -> bool: + ... + + @staticmethod + def to_checksum_address(value: t.Union[AnyAddress, str, bytes]) -> ChecksumAddress: + ... + + @property + def api(self) -> str: + ... + + @staticmethod + def keccak( + primitive: t.Optional[Primitives] = None, + text: t.Optional[str] = None, + hexstr: t.Optional[HexStr] = None, + ) -> bytes: + ... + + @classmethod + def normalize_values( + cls, _w3: BaseWeb3, abi_types: t.List[TypeStr], values: t.List[t.Any] + ) -> t.List[t.Any]: + ... + + @classmethod + def solidity_keccak(cls, abi_types: t.List[TypeStr], values: t.List[t.Any]) -> bytes: + ... + + def attach_modules( + self, modules: t.Optional[t.Dict[str, t.Union[t.Type[Module], t.Sequence[t.Any]]]] + ) -> None: + ... + + def is_encodable(self, _type: TypeStr, value: t.Any) -> bool: + ... + + +def web3_node_factory(config): + if config["WEB3_PROVIDER_CLASS"] is web3.providers.HTTPProvider: + provider = config["WEB3_PROVIDER_CLASS"](config["WEB3_HTTPS_PROVIDER_URI"]) + return web3.Web3(provider) + + +class Web3: + node: Web3Node + + def __init__(self, node: Web3Node, default_network: str, default_chain: str): + self.node = node + self.default_network = default_network + self.default_chain = default_chain + + @property + def chain(self) -> str: + if has_request_context(): + return request.headers.get("x-web3-chain", self.default_chain).upper() + return self.default_chain + + @property + def network(self) -> str: + if has_request_context(): + return request.headers.get("x-web3-network", self.default_network).upper() + return self.default_network diff --git a/src/quart_sqlalchemy/sqla.py b/src/quart_sqlalchemy/sqla.py index 56b8b60..b437625 100644 --- a/src/quart_sqlalchemy/sqla.py +++ b/src/quart_sqlalchemy/sqla.py @@ -10,55 +10,280 @@ import sqlalchemy.orm import sqlalchemy.util -from .bind import AsyncBind -from .bind import Bind -from .config import AsyncBindConfig -from .config import SQLAlchemyConfig +from quart_sqlalchemy.bind import AsyncBind +from quart_sqlalchemy.bind import Bind +from quart_sqlalchemy.config import AsyncBindConfig +from quart_sqlalchemy.config import SQLAlchemyConfig sa = sqlalchemy class SQLAlchemy: + """ + This manager class keeps things very simple by using a few configuration conventions. + + Configuration has been simplified down to base_class and binds. 
+ + * Everything related to ORM mapping, DeclarativeBase, registry, MetaData, etc should be + configured by passing the a custom DeclarativeBase class as the base_class configuration + parameter. + + * Everything related to engine/session configuration should be configured by passing a + dictionary mapping string names to BindConfigs as the `binds` configuration parameter. + + BindConfig can be as simple as a dictionary containing a url key like so: + + bind_config = { + "default": {"url": "sqlite://"} + } + + But most use cases will require more than just a connection url, and divide core/engine + configuration from orm/session configuration which looks more like this: + + bind_config = { + "default": { + "engine": { + "url": "sqlite://" + }, + "session": { + "expire_on_commit": False + } + } + } + + Everything under `engine` will then be passed to `sqlalchemy.create_engine_from_config` and + everything under `session` will be passed to `sqlalchemy.orm.sessionmaker`. + + engine = sa.create_engine_from_config(bind_config.engine) + Session = sa.orm.sessionmaker(bind=engine, **bind_config.session) + + Config Examples: + + Simple URL: + db = SQLAlchemy( + SQLAlchemyConfig( + binds=dict( + default=dict( + url="sqlite://" + ) + ) + ) + ) + + Shortcut for the above: + db = SQLAlchemy(SQLAlchemyConfig()) + + More complex configuration for engine and session both: + db = SQLAlchemy( + SQLAlchemyConfig( + binds=dict( + default=dict( + engine=dict( + url="sqlite://" + ), + session=dict( + expire_on_commit=False + ) + ) + ) + ) + ) + + Once instantiated, operations targetting all of the binds, aka metadata, like + `metadata.create_all` should be called from this class. Operations specific to a bind + should be called from that bind. This class has a few ways to get a specific bind. + + * To get a Bind, you can call `.get_bind(name)` on this class. The default bind can be + referenced at `.bind`. 
+ + * To define an ORM model using the Base class attached to this class, simply inherit + from `.Base` + + db = SQLAlchemy(SQLAlchemyConfig()) + + class User(db.Base): + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + name: Mapped[str] = sa.orm.mapped_column(default="Joe") + + db.create_all() + + * You can also decouple Base from SQLAlchemy with some dependency inversion: + from quart_sqlalchemy.model.mixins import DynamicArgsMixin, ReprMixin, TableNameMixin + + class Base(DynamicArgsMixin, ReprMixin, TableNameMixin): + __abstract__ = True + + + class User(db.Base): + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + name: Mapped[str] = sa.orm.mapped_column(default="Joe") + + db = SQLAlchemy(SQLAlchemyConfig(bind_class=Base)) + + db.create_all() + + + Declarative Mapping using registry based decorator: + + db = SQLAlchemy(SQLAlchemyConfig()) + + @db.registry.mapped + class User(db.Base): + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + name: Mapped[str] = sa.orm.mapped_column(default="Joe") + + db.create_all() + + + Declarative with Imperative Table (Hybrid Declarative): + + class User(db.Base): + __table__ = sa.Table( + "user", + db.metadata, + sa.Column("id", sa.Integer, primary_key=True, autoincrement=True), + sa.Column("name", sa.String, default="Joe"), + ) + + + Declarative using reflection to automatically build the table object: + + class User(db.Base): + __table__ = sa.Table( + "user", + db.metadata, + autoload_with=db.bind.engine, + ) + + + Declarative Dataclass Mapping: + + from quart_sqlalchemy.model import Base as Base_ + + class Base(sa.orm.MappedAsDataclass, Base_): + pass + + db = SQLAlchemy(SQLAlchemyConfig(base_class=Base)) + + class User(db.Base): + __tablename__ = "user" + + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + name: Mapped[str] = sa.orm.mapped_column(default="Joe") + + db.create_all() + + + Declarative Dataclass Mapping (using decorator): + + db = SQLAlchemy(SQLAlchemyConfig(base_class=Base)) + + @db.registry.mapped_as_dataclass + class User: + __tablename__ = "user" + + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + name: Mapped[str] = sa.orm.mapped_column(default="Joe") + + db.create_all() + + + Alternate Dataclass Provider Pattern: + + from pydantic.dataclasses import dataclass + from quart_sqlalchemy.model import Base as Base_ + + class Base(sa.orm.MappedAsDataclass, Base_, dataclass_callable=dataclass): + pass + + db = SQLAlchemy(SQLAlchemyConfig(base_class=Base)) + + class User(db.Base): + __tablename__ = "user" + + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + name: Mapped[str] = sa.orm.mapped_column(default="Joe") + + db.create_all() + + Imperative style Mapping + + db = SQLAlchemy(SQLAlchemyConfig(base_class=Base)) + + user_table = sa.Table( + "user", + db.metadata, + sa.Column("id", sa.Integer, primary_key=True, autoincrement=True), + sa.Column("name", sa.String, default="Joe"), + ) + + post_table = sa.Table( + "post", + db.metadata, + sa.Column("id", sa.Integer, primary_key=True, autoincrement=True), + sa.Column("title", sa.String, default="My post"), + sa.Column("user_id", sa.ForeignKey("user.id"), nullable=False), + ) + + class User: + pass + + class Post: + pass + + db.registry.map_imperatively( + User, + user_table, + properties={ + "posts": sa.orm.relationship(Post, back_populates="user") + } + ) + db.registry.map_imperatively( + Post, + post_table, + 
properties={ + "user": sa.orm.relationship(User, back_populates="posts", uselist=False) + } + ) + """ + config: SQLAlchemyConfig binds: t.Dict[str, t.Union[Bind, AsyncBind]] - Model: t.Type[sa.orm.DeclarativeBase] + Base: t.Type[sa.orm.DeclarativeBase] - def __init__(self, config: SQLAlchemyConfig, initialize: bool = True): + def __init__( + self, + config: t.Optional[SQLAlchemyConfig] = None, + initialize: bool = True, + ): self.config = config if initialize: self.initialize() - def initialize(self): - if issubclass(self.config.model_class, sa.orm.DeclarativeBase): - Model = self.config.model_class # type: ignore - else: - - class Model(self.config.model_class, sa.orm.DeclarativeBase): - pass + def initialize(self, config: t.Optional[SQLAlchemyConfig] = None): + if config is not None: + self.config = config + if self.config is None: + self.config = SQLAlchemyConfig.default() - type_annotation_map = {} - for base_class in Model.__mro__[::-1]: - if base_class is Model: - continue - base_map = getattr(base_class, "type_annotation_map", {}).copy() - type_annotation_map.update(base_map) + if issubclass(self.config.base_class, sa.orm.DeclarativeBase): + Base = self.config.base_class # type: ignore + else: + Base = type("Base", (self.config.base_class, sa.orm.DeclarativeBase), {}) - Model.registry.type_annotation_map.update(type_annotation_map) - self.Model = Model + self.Base = Base - self.binds = {} - for name, bind_config in self.config.binds.items(): - is_async = isinstance(bind_config, AsyncBindConfig) - if is_async: - self.binds[name] = AsyncBind(bind_config, self.metadata) - else: - self.binds[name] = Bind(bind_config, self.metadata) + if not hasattr(self, "binds"): + self.binds = {} + for name, bind_config in self.config.binds.items(): + is_async = isinstance(bind_config, AsyncBindConfig) + factory = AsyncBind if is_async else Bind + self.binds[name] = factory(name, bind_config.engine.url, bind_config, self.metadata) - @classmethod - def default(cls): - return cls(SQLAlchemyConfig()) + def get_bind(self, bind: str = "default"): + return self.binds[bind] @property def bind(self) -> Bind: @@ -66,10 +291,11 @@ def bind(self) -> Bind: @property def metadata(self) -> sa.MetaData: - return self.Model.metadata + return self.Base.metadata - def get_bind(self, bind: str = "default"): - return self.binds[bind] + @property + def registry(self) -> sa.orm.registry: + return self.Base.registry def create_all(self, bind: str = "default"): return self.binds[bind].create_all() diff --git a/src/quart_sqlalchemy/testing/fake.py b/src/quart_sqlalchemy/testing/fake.py new file mode 100644 index 0000000..e69de29 diff --git a/src/quart_sqlalchemy/testing/signals.py b/src/quart_sqlalchemy/testing/signals.py new file mode 100644 index 0000000..d9c1cb9 --- /dev/null +++ b/src/quart_sqlalchemy/testing/signals.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +import sqlalchemy +import sqlalchemy.orm +from blinker import Namespace +from quart.signals import AsyncNamespace + + +sa = sqlalchemy + +sync_signals = Namespace() +async_signals = AsyncNamespace() + + +load_test_fixtures = sync_signals.signal( + "quart-sqlalchemy.testing.fixtures.load.sync", + doc="""Fired to load test fixtures into a freshly instantiated test database. + + No default signal handlers exist for this signal as the logic is very application dependent. 
+ + Example: + + @signals.framework_extension_load_fixtures.connect + def handle(sender: QuartSQLAlchemy, app: Quart): + bind = sender.get_bind("default") + with bind.Session() as session: + with session.begin(): + session.add_all( + [ + models.User(username="user1"), + models.User(username="user2"), + ] + ) + session.commit() + + Handler signature: + def handle(sender: QuartSQLAlchemy, app: Quart): + ... + """, +) diff --git a/src/quart_sqlalchemy/testing/transaction.py b/src/quart_sqlalchemy/testing/transaction.py index 507a963..d7ff873 100644 --- a/src/quart_sqlalchemy/testing/transaction.py +++ b/src/quart_sqlalchemy/testing/transaction.py @@ -23,13 +23,13 @@ def __init__(self, bind: "Bind", savepoint: bool = False): self.savepoint = savepoint self.bind = bind - def Session(self, **options): + def Session(self, **options: t.Any) -> sa.orm.Session: options.update(bind=self.connection) if self.savepoint: options.update(join_transaction_mode="create_savepoint") return self.bind.Session(**options) - def begin(self): + def open(self) -> None: self.connection = self.bind.engine.connect() self.trans = self.connection.begin() @@ -59,7 +59,7 @@ def close(self, exc: t.Optional[Exception] = None) -> None: ) def __enter__(self): - self.begin() + self.open() return self def __exit__(self, exc_type, exc_val, exc_tb): @@ -82,7 +82,7 @@ class AsyncTestTransaction(TestTransaction): def __init__(self, bind: "AsyncBind", savepoint: bool = False): super().__init__(bind, savepoint=savepoint) - async def begin(self): + async def open(self): self.connection = await self.bind.engine.connect() self.trans = await self.connection.begin() @@ -112,7 +112,7 @@ async def close(self, exc: t.Optional[Exception] = None) -> None: ) async def __aenter__(self): - await self.begin() + await self.open() return self async def __aexit__(self, exc_type, exc_val, exc_tb): diff --git a/src/quart_sqlalchemy/types.py b/src/quart_sqlalchemy/types.py index 2c5b96f..73e6c6f 100644 --- a/src/quart_sqlalchemy/types.py +++ b/src/quart_sqlalchemy/types.py @@ -7,6 +7,7 @@ import sqlalchemy.orm import sqlalchemy.sql import typing_extensions as tx +from sqlalchemy import SQLColumnExpression from sqlalchemy.orm.interfaces import ORMOption as _ORMOption from sqlalchemy.sql._typing import _ColumnExpressionArgument from sqlalchemy.sql._typing import _ColumnsClauseArgument @@ -15,11 +16,18 @@ sa = sqlalchemy + +class Empty: + pass + + +EmptyType = t.Type[Empty] + SessionT = t.TypeVar("SessionT", bound=sa.orm.Session) EntityT = t.TypeVar("EntityT", bound=sa.orm.DeclarativeBase) EntityIdT = t.TypeVar("EntityIdT", bound=t.Any) -ColumnExpr = _ColumnExpressionArgument +ColumnExpr = SQLColumnExpression Selectable = _ColumnsClauseArgument DMLTable = _DMLTableArgument ORMOption = _ORMOption @@ -40,3 +48,8 @@ SABind = t.Union[ sa.Engine, sa.Connection, sa.ext.asyncio.AsyncEngine, sa.ext.asyncio.AsyncConnection ] + + +class Operator(tx.Protocol): + def __call__(self, __a: object, __b: object) -> object: + ... 
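Editor's note on the testing/transaction.py hunk above: the `begin()` -> `open()` rename only affects callers that drive `TestTransaction` manually; used as a context manager the flow is unchanged, because `__enter__`/`__aenter__` now call `open()`. A minimal usage sketch follows. It assumes the class is importable from `quart_sqlalchemy.testing.transaction`, that `db` is a `QuartSQLAlchemy` fixture like the ones in `tests/base.py`, and that `close()` rolls the outer transaction back; none of that is spelled out in this hunk, so treat it as illustrative only.

```python
import sqlalchemy as sa

from quart_sqlalchemy.testing.transaction import TestTransaction  # assumed import path


def test_insert_is_rolled_back(db, User):
    # Baseline row count before the test transaction.
    with db.bind.Session() as s:
        before = s.scalar(sa.select(sa.func.count()).select_from(User))

    # __enter__ calls open(): engine.connect() plus an outer connection.begin().
    with TestTransaction(db.bind, savepoint=True) as txn:
        # Session() binds the ORM session to that external connection; with
        # savepoint=True it joins via join_transaction_mode="create_savepoint",
        # so inner commits become savepoints instead of real commits.
        with txn.Session() as s:
            with s.begin():
                s.add(User(name="temp"))

    # __exit__ calls close(), which (by assumption here) rolls the outer
    # transaction back, so the inserted row never becomes visible.
    with db.bind.Session() as s:
        after = s.scalar(sa.select(sa.func.count()).select_from(User))
    assert after == before
```

The `savepoint=True` path is what lets application code keep calling `commit()` inside tests without persisting anything, since those commits only release savepoints on the wrapped connection.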
diff --git a/tests/base.py b/tests/base.py index 1ff8756..8afa3fb 100644 --- a/tests/base.py +++ b/tests/base.py @@ -2,6 +2,7 @@ import random import typing as t +from copy import deepcopy from datetime import datetime import pytest @@ -10,8 +11,12 @@ from quart import Quart from sqlalchemy.orm import Mapped -from quart_sqlalchemy import SQLAlchemyConfig +from quart_sqlalchemy import Base from quart_sqlalchemy.framework import QuartSQLAlchemy +from quart_sqlalchemy.model.mixins import ComparableMixin +from quart_sqlalchemy.model.mixins import DynamicArgsMixin +from quart_sqlalchemy.model.mixins import EagerDefaultsMixin +from quart_sqlalchemy.model.mixins import TableNameMixin from . import constants @@ -21,27 +26,156 @@ class SimpleTestBase: @pytest.fixture(scope="class") - def app(self, request): + def Base(self) -> t.Type[t.Any]: + return Base + + @pytest.fixture(scope="class") + def app_config(self, Base): + config = deepcopy(constants.simple_config) + config.update(SQLALCHEMY_BASE_CLASS=Base) + return config + + @pytest.fixture(scope="class") + def app(self, app_config, request): app = Quart(request.module.__name__) - app.config.from_mapping({"TESTING": True}) + app.config.from_mapping(app_config) + app.config["TESTING"] = True return app @pytest.fixture(scope="class") - def sqlalchemy_config(self): - return SQLAlchemyConfig.parse_obj(constants.simple_mapping_config) + def db(self, app: Quart) -> t.Generator[QuartSQLAlchemy, None, None]: + db = QuartSQLAlchemy(app=app) - @pytest.fixture(scope="class") - def db(self, sqlalchemy_config, app: Quart) -> QuartSQLAlchemy: - return QuartSQLAlchemy(sqlalchemy_config, app) - # yield db - # db.drop_all() + yield db @pytest.fixture(scope="class") - def models(self, app: Quart, db: QuartSQLAlchemy) -> t.Mapping[str, t.Type[t.Any]]: - class Todo(db.Model): + def models( + self, app: Quart, db: QuartSQLAlchemy + ) -> t.Generator[t.Mapping[str, t.Type[t.Any]], None, None]: + class Todo(db.Base): + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + title: Mapped[str] = sa.orm.mapped_column(default="default") + user_id: Mapped[t.Optional[int]] = sa.orm.mapped_column(sa.ForeignKey("user.id")) + + user: Mapped[t.Optional["User"]] = sa.orm.relationship( + back_populates="todos", lazy="noload", uselist=False + ) + + class User(db.Base): id: Mapped[int] = sa.orm.mapped_column( - sa.Identity(), primary_key=True, autoincrement=True + primary_key=True, + autoincrement=True, ) + name: Mapped[str] = sa.orm.mapped_column(default="default") + + created_at: Mapped[datetime] = sa.orm.mapped_column( + default=sa.func.now(), + server_default=sa.FetchedValue(), + ) + + time_updated: Mapped[datetime] = sa.orm.mapped_column( + default=sa.func.now(), + onupdate=sa.func.now(), + server_default=sa.FetchedValue(), + server_onupdate=sa.FetchedValue(), + ) + + todos: Mapped[t.List[Todo]] = sa.orm.relationship(lazy="noload", back_populates="user") + + yield dict(todo=Todo, user=User) + # We need to cleanup these objects that like to retain state beyond the fixture scope lifecycle + Base.registry.dispose() + Base.metadata.clear() + + @pytest.fixture(scope="class", autouse=True) + def create_drop_all(self, db: QuartSQLAlchemy, models): + db.create_all() + yield + db.drop_all() + + @pytest.fixture(scope="class") + def Todo(self, models: t.Mapping[str, t.Type[t.Any]]) -> t.Type[sa.orm.DeclarativeBase]: + return models["todo"] + + @pytest.fixture(scope="class") + def User(self, models: t.Mapping[str, t.Type[t.Any]]) -> t.Type[sa.orm.DeclarativeBase]: + 
return models["user"] + + @pytest.fixture(scope="class") + def _user_fixtures(self, User: t.Type[t.Any], Todo: t.Type[t.Any]): + users = [] + for i in range(5): + user = User(name=f"user: {i}") + for j in range(random.randint(0, 6)): + todo = Todo(title=f"todo: {j}") + user.todos.append(todo) + users.append(user) + return users + + @pytest.fixture(scope="class") + def _add_fixtures( + self, db: QuartSQLAlchemy, User: t.Type[t.Any], Todo: t.Type[t.Any], _user_fixtures + ) -> None: + with db.bind.Session() as s: + with s.begin(): + s.add_all(_user_fixtures) + + @pytest.fixture(scope="class", autouse=True) + def db_fixtures( + self, db: QuartSQLAlchemy, User: t.Type[t.Any], Todo: t.Type[t.Any], _add_fixtures + ) -> t.Dict[t.Type[t.Any], t.Sequence[t.Any]]: + with db.bind.Session() as s: + users = s.scalars(sa.select(User).options(sa.orm.selectinload(User.todos))).all() + todos = s.scalars(sa.select(Todo).options(sa.orm.selectinload(Todo.user))).all() + + return {User: users, Todo: todos} + + +class MixinTestBase: + default_mixins = ( + DynamicArgsMixin, + EagerDefaultsMixin, + TableNameMixin, + ComparableMixin, + ) + extra_mixins = () + + @pytest.fixture(scope="class") + def Base(self) -> t.Type[t.Any]: + return type( + "Base", + tuple(self.extra_mixins + self.default_mixins), + {"__abstract__": True}, + ) + + @pytest.fixture(scope="class") + def app_config(self, Base): + config = deepcopy(constants.simple_config) + config.update(SQLALCHEMY_BASE_CLASS=Base) + return config + + @pytest.fixture(scope="class") + def app(self, app_config, request): + app = Quart(request.module.__name__) + app.config.from_mapping(app_config) + app.config["TESTING"] = True + return app + + @pytest.fixture(scope="class") + def db(self, app: Quart) -> t.Generator[QuartSQLAlchemy, None, None]: + db = QuartSQLAlchemy(app=app) + + yield db + + # It's very important to clear the class _instances dict before recreating binds with the same name. 
+ # Bind._instances.clear() + + @pytest.fixture(scope="class") + def models( + self, app: Quart, db: QuartSQLAlchemy + ) -> t.Generator[t.Mapping[str, t.Type[t.Any]], None, None]: + class Todo(db.Base): + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) title: Mapped[str] = sa.orm.mapped_column(default="default") user_id: Mapped[t.Optional[int]] = sa.orm.mapped_column(sa.ForeignKey("user.id")) @@ -49,9 +183,8 @@ class Todo(db.Model): back_populates="todos", lazy="noload", uselist=False ) - class User(db.Model): + class User(db.Base): id: Mapped[int] = sa.orm.mapped_column( - sa.Identity(), primary_key=True, autoincrement=True, ) @@ -71,7 +204,10 @@ class User(db.Model): todos: Mapped[t.List[Todo]] = sa.orm.relationship(lazy="noload", back_populates="user") - return dict(todo=Todo, user=User) + yield dict(todo=Todo, user=User) + # We need to cleanup these objects that like to retain state beyond the fixture scope lifecycle + Base.registry.dispose() + Base.metadata.clear() @pytest.fixture(scope="class", autouse=True) def create_drop_all(self, db: QuartSQLAlchemy, models): @@ -119,8 +255,10 @@ def db_fixtures( class AsyncTestBase(SimpleTestBase): @pytest.fixture(scope="class") - def sqlalchemy_config(self): - return SQLAlchemyConfig.parse_obj(constants.async_mapping_config) + def app_config(self, Base): + config = deepcopy(constants.async_config) + config.update(SQLALCHEMY_BASE_CLASS=Base) + return config @pytest.fixture(scope="class", autouse=True) async def create_drop_all(self, db: QuartSQLAlchemy, models) -> t.AsyncGenerator[None, None]: @@ -151,5 +289,31 @@ async def db_fixtures( class ComplexTestBase(SimpleTestBase): @pytest.fixture(scope="class") - def sqlalchemy_config(self): - return SQLAlchemyConfig.parse_obj(constants.complex_mapping_config) + def app_config(self, Base): + config = deepcopy(constants.complex_config) + config.update(SQLALCHEMY_BASE_CLASS=Base) + return config + + +# class CustomMixinTestBase(SimpleTestBase): +# default_mixins = ( +# DynamicArgsMixin, +# EagerDefaultsMixin, +# TableNameMixin, +# ) +# additional_mixins = () + +# @pytest.fixture(scope="class") +# def Base(self) -> t.Type[t.Any]: +# return type( +# "Base", +# tuple(self.additional_mixins + self.default_mixins), +# {"__abstract__": True}, +# ) + +# @pytest.fixture(scope="class") +# def app_config(self, Base): +# config = deepcopy(constants.simple_config) +# config["SQLALCHEMY_BASE_CLASS"] = Base +# config["TESTING"] = True +# return config diff --git a/tests/conftest.py b/tests/conftest.py index eefd7d7..e69de29 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,49 +0,0 @@ -from __future__ import annotations - -import typing as t - -import pytest -import sqlalchemy -import sqlalchemy.orm -from quart import Quart -from sqlalchemy.orm import Mapped - -from quart_sqlalchemy import SQLAlchemyConfig -from quart_sqlalchemy.framework import QuartSQLAlchemy - -from . 
import constants - - -sa = sqlalchemy - - -@pytest.fixture(scope="session") -def app(request: pytest.FixtureRequest) -> Quart: - app = Quart(request.module.__name__) - app.config.from_mapping({"TESTING": True}) - return app - - -@pytest.fixture(scope="session") -def sqlalchemy_config(): - return SQLAlchemyConfig.parse_obj(constants.simple_mapping_config) - - -@pytest.fixture(scope="session") -def db(sqlalchemy_config, app: Quart) -> QuartSQLAlchemy: - return QuartSQLAlchemy(sqlalchemy_config, app) - - -@pytest.fixture(name="Todo", scope="session") -def _todo_fixture( - app: Quart, db: QuartSQLAlchemy -) -> t.Generator[t.Type[sa.orm.DeclarativeBase], None, None]: - class Todo(db.Model): - id: Mapped[int] = sa.orm.mapped_column(sa.Identity(), primary_key=True, autoincrement=True) - title: Mapped[str] = sa.orm.mapped_column(default="default") - - db.create_all() - - yield Todo - - db.drop_all() diff --git a/tests/constants.py b/tests/constants.py index 4fb8240..6a2277c 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -1,19 +1,18 @@ from quart_sqlalchemy import Base -simple_mapping_config = { - "model_class": Base, - "binds": { +simple_config = { + "SQLALCHEMY_BINDS": { "default": { "engine": {"url": "sqlite:///file:mem.db?mode=memory&cache=shared&uri=true"}, "session": {"expire_on_commit": False}, } }, + "SQLALCHEMY_BASE_CLASS": Base, } -complex_mapping_config = { - "model_class": Base, - "binds": { +complex_config = { + "SQLALCHEMY_BINDS": { "default": { "engine": {"url": "sqlite:///file:mem.db?mode=memory&cache=shared&uri=true"}, "session": {"expire_on_commit": False}, @@ -28,14 +27,15 @@ "session": {"expire_on_commit": False}, }, }, + "SQLALCHEMY_BASE_CLASS": Base, } -async_mapping_config = { - "model_class": Base, - "binds": { +async_config = { + "SQLALCHEMY_BINDS": { "default": { "engine": {"url": "sqlite+aiosqlite:///file:mem.db?mode=memory&cache=shared&uri=true"}, "session": {"expire_on_commit": False}, - } + }, }, + "SQLALCHEMY_BASE_CLASS": Base, } diff --git a/tests/integration/concurrency/__init__.py b/tests/integration/concurrency/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/concurrency/with_for_update.py b/tests/integration/concurrency/with_for_update.py new file mode 100644 index 0000000..a9fa552 --- /dev/null +++ b/tests/integration/concurrency/with_for_update.py @@ -0,0 +1,103 @@ +import logging +import threading +import time + +import pytest +import sqlalchemy +import sqlalchemy.orm + + +sa = sqlalchemy + +logging.basicConfig(level=logging.DEBUG) +logging.getLogger("sqlalchemy").setLevel(logging.INFO) + +log = logging.getLogger(__name__) + + +class Base(sa.orm.DeclarativeBase): + pass + + +class Thing(Base): + __tablename__ = "things" + + id = sa.Column(sa.Integer, primary_key=True) + status = sa.Column(sa.String) + + +@pytest.fixture(scope="module") +def engine(): + engine = sa.create_engine("sqlite:///") + # engine = sa.create_engine("postgresql+psycopg2://spikes:sesame@localhost/spikes") + Base.metadata.create_all(engine) + + yield engine + + Base.metadata.drop_all(engine) + + +@pytest.fixture(scope="module") +def connection(engine): + with engine.connect() as conn: + yield conn + + +@pytest.fixture +def db(connection): + transaction = connection.begin() + session = sa.orm.Session(bind=connection) + + # now we can even `.commit()` such session + yield session + + session.close() + transaction.rollback() + + +def test_select_for_update(engine): + # scoped_db = scoped_session(sessionmaker(bind=connection)) + scoped_db = 
sa.orm.scoped_session(sa.orm.sessionmaker(bind=engine)) + db = scoped_db() + db.add(Thing(status="old")) + db.commit() + + def first(event, sess_factory, status): + sess = sess_factory() + # thing = sess.query(Thing).get(1) + thing = sess.query(Thing).with_for_update().get(1) + event.set() # poke second thread + log.debug("Make him wait for a while") + time.sleep(0.263) + thing.status = status + sess.commit() + log.debug("Done!") + # it is always better to explicitly `.remove()` scoped sessions, but + # in this case it is not necessary because it will be garbage-collected + # sess_factory.remove() + + def second(event, sess_factory, status): + event.wait() # ensure we are called in the right moment + sess = sess_factory() + # thing = sess.query(Thing).get(1) + thing = sess.query(Thing).with_for_update().get(1) + thing.status = status + sess.commit() + + event = threading.Event() + th1 = threading.Thread(target=first, args=(event, scoped_db, "new")) + th2 = threading.Thread(target=second, args=(event, scoped_db, "brand_new")) + + th1.start() + th2.start() + + th1.join() + th2.join() + + # assert db.query(Thing).filter_by(id=1).one().status == 'new' + t = db.query(Thing).get(1) + # it is only mandatory to remove session here, seems like it is not + # garbage-collected becasue it is in `assert` statement (not sure about that) + scoped_db.remove() + + assert t.status == "brand_new" diff --git a/tests/integration/framework/smoke_test.py b/tests/integration/framework/smoke_test.py index 682d0e4..7154fd7 100644 --- a/tests/integration/framework/smoke_test.py +++ b/tests/integration/framework/smoke_test.py @@ -40,7 +40,7 @@ def test_simple_transactional_orm_flow(self, db: QuartSQLAlchemy, Todo: t.Any): def test_simple_transactional_core_flow(self, db: QuartSQLAlchemy, Todo: t.Any): with db.bind.engine.connect() as conn: with conn.begin(): - result = conn.execute(sa.insert(Todo)) + result = conn.execute(sa.insert(Todo).values(title="default")) insert_row = result.inserted_primary_key select_row = conn.execute(sa.select(Todo).where(Todo.id == insert_row.id)).one() @@ -56,15 +56,3 @@ def test_simple_transactional_core_flow(self, db: QuartSQLAlchemy, Todo: t.Any): with db.bind.engine.connect() as conn: with pytest.raises(sa.exc.NoResultFound): conn.execute(sa.select(Todo).where(Todo.id == insert_row.id)).one() - - def test_orm_models_comparable(self, db: QuartSQLAlchemy, Todo: t.Any): - with db.bind.Session() as s: - with s.begin(): - todo = Todo() - s.add(todo) - s.flush() - s.refresh(todo) - - with db.bind.Session() as s: - select_todo = s.scalars(sa.select(Todo).where(Todo.id == todo.id)).one() - assert todo == select_todo diff --git a/tests/integration/model/mixins_test.py b/tests/integration/model/mixins_test.py index a88b4d9..43928e6 100644 --- a/tests/integration/model/mixins_test.py +++ b/tests/integration/model/mixins_test.py @@ -8,21 +8,26 @@ from sqlalchemy.orm import Mapped from quart_sqlalchemy import SQLAlchemy -from quart_sqlalchemy.model import Base -from quart_sqlalchemy.model import SoftDeleteMixin +from quart_sqlalchemy.framework import QuartSQLAlchemy +from quart_sqlalchemy.model.mixins import ComparableMixin +from quart_sqlalchemy.model.mixins import RecursiveDictMixin +from quart_sqlalchemy.model.mixins import ReprMixin +from quart_sqlalchemy.model.mixins import SimpleDictMixin +from quart_sqlalchemy.model.mixins import SoftDeleteMixin +from quart_sqlalchemy.model.mixins import TotalOrderMixin -from ...base import SimpleTestBase +from ... 
import base sa = sqlalchemy -class TestSoftDeleteFeature(SimpleTestBase): - @pytest.fixture - def Post(self, db: SQLAlchemy, User: t.Type[t.Any]) -> t.Generator[t.Type[Base], None, None]: - class Post(SoftDeleteMixin, db.Model): - id: Mapped[int] = sa.orm.mapped_column(primary_key=True) - title: Mapped[str] = sa.orm.mapped_column() +class TestSoftDeleteFeature(base.MixinTestBase): + @pytest.fixture(scope="class") + def Post(self, db: SQLAlchemy, User: t.Type[t.Any]) -> t.Generator[t.Type[t.Any], None, None]: + class Post(SoftDeleteMixin, db.Base): + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + title: Mapped[str] = sa.orm.mapped_column(default="default") user_id: Mapped[t.Optional[int]] = sa.orm.mapped_column(sa.ForeignKey("user.id")) user: Mapped[t.Optional[User]] = sa.orm.relationship(backref="posts") @@ -53,3 +58,113 @@ def test_inactive_filtered(self, db: SQLAlchemy, Post: t.Type[t.Any]): assert select_post.id == post.id assert select_post.is_active is False + + +class TestComparableMixin(base.MixinTestBase): + extra_mixins = (TotalOrderMixin,) + + def test_orm_models_comparable(self, db: QuartSQLAlchemy, Todo: t.Any): + assert ComparableMixin in self.default_mixins + + with db.bind.Session() as s: + with s.begin(): + todos = [Todo() for _ in range(5)] + s.add_all(todos) + + with db.bind.Session() as s: + todos = s.scalars(sa.select(Todo).order_by(Todo.id)).all() + + todo1, todo2, *_ = todos + assert todo1 < todo2 + + +class TestReprMixin(base.MixinTestBase): + @pytest.fixture(scope="class") + def Post(self, db: SQLAlchemy, User: t.Type[t.Any]) -> t.Generator[t.Type[t.Any], None, None]: + class Post(ReprMixin, db.Base): + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + title: Mapped[str] = sa.orm.mapped_column(default="default") + user_id: Mapped[t.Optional[int]] = sa.orm.mapped_column(sa.ForeignKey("user.id")) + + user: Mapped[t.Optional[User]] = sa.orm.relationship(backref="posts") + + db.create_all() + yield Post + + def test_mixin_generates_repr(self, db: QuartSQLAlchemy, Post: t.Any): + with db.bind.Session() as s: + with s.begin(): + post = Post() + s.add(post) + s.flush() + s.refresh(post) + + assert repr(post) == f"<{type(post).__name__} {post.id}>" + + +class TestSimpleDictMixin(base.MixinTestBase): + extra_mixins = (SimpleDictMixin,) + + @pytest.fixture(scope="class") + def Post(self, db: SQLAlchemy, User: t.Type[t.Any]) -> t.Generator[t.Type[t.Any], None, None]: + class Post(db.Base): + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + title: Mapped[str] = sa.orm.mapped_column(default="default") + user_id: Mapped[t.Optional[int]] = sa.orm.mapped_column(sa.ForeignKey("user.id")) + + user: Mapped[t.Optional[User]] = sa.orm.relationship(backref="posts") + + db.create_all() + yield Post + + def test_mixin_converts_model_to_dict(self, db: QuartSQLAlchemy, Post: t.Any, User: t.Any): + with db.bind.Session() as s: + with s.begin(): + user = s.scalars(sa.select(User)).first() + post = Post(user=user) + s.add(post) + s.flush() + s.refresh(post.user) + + with db.bind.Session() as s: + with s.begin(): + user = s.scalars(sa.select(User).options(sa.orm.selectinload(User.posts))).first() + + data = user.to_dict() + + for field in data: + assert data[field] == getattr(user, field) + + +class TestRecursiveMixin(base.MixinTestBase): + extra_mixins = (RecursiveDictMixin,) + + @pytest.fixture(scope="class") + def Post(self, db: SQLAlchemy, User: t.Type[t.Any]) -> t.Generator[t.Type[t.Any], None, None]: 
+ class Post(db.Base): + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + title: Mapped[str] = sa.orm.mapped_column(default="default") + user_id: Mapped[t.Optional[int]] = sa.orm.mapped_column(sa.ForeignKey("user.id")) + + user: Mapped[t.Optional[User]] = sa.orm.relationship(backref="posts") + + db.create_all() + yield Post + + def test_mixin_converts_model_to_dict(self, db: QuartSQLAlchemy, Post: t.Any, User: t.Any): + with db.bind.Session() as s: + with s.begin(): + user = s.scalars(sa.select(User)).first() + post = Post(user=user) + s.add(post) + s.flush() + s.refresh(post.user) + + with db.bind.Session() as s: + with s.begin(): + user = s.scalars(sa.select(User).options(sa.orm.selectinload(User.posts))).first() + + data = user.to_dict() + + for col in sa.inspect(user).mapper.columns: + assert data[col.name] == getattr(user, col.name) diff --git a/tests/integration/model/model_test.py b/tests/integration/model/model_test.py new file mode 100644 index 0000000..651d908 --- /dev/null +++ b/tests/integration/model/model_test.py @@ -0,0 +1,59 @@ +from datetime import datetime + +import pytest +import sqlalchemy +import sqlalchemy.event +import sqlalchemy.exc +import sqlalchemy.ext +import sqlalchemy.ext.asyncio +import sqlalchemy.orm +import sqlalchemy.util +from sqlalchemy.orm import Mapped + +from quart_sqlalchemy import Base +from quart_sqlalchemy import Bind +from quart_sqlalchemy import SQLAlchemy +from quart_sqlalchemy import SQLAlchemyConfig +from quart_sqlalchemy.model.model import BaseMixins + + +sa = sqlalchemy + + +class TestSQLAlchemyWithCustomModelClass: + def test_base_class_with_declarative_preserves_class_and_table_metadata(self): + """This is nice to have as it decouples quart and quart_sqlalchemy from the data + models themselves. 
+ """ + + class User(Base): + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + + db = SQLAlchemy(SQLAlchemyConfig(base_class=Base)) + + db.create_all() + + with db.bind.Session() as s: + with s.begin(): + user = User() + s.add(user) + s.flush() + s.refresh(user) + + Base.registry.dispose() + Bind._instances.clear() + + def test_sqla_class_adds_declarative_base_when_missing_from_base_class(self): + db = SQLAlchemy(SQLAlchemyConfig(base_class=BaseMixins)) + + class User(db.Base): + id: Mapped[int] = sa.orm.mapped_column(primary_key=True, autoincrement=True) + + db.create_all() + + with db.bind.Session() as s: + with s.begin(): + user = User() + s.add(user) + s.flush() + s.refresh(user) diff --git a/tests/integration/retry_test.py b/tests/integration/retry_test.py index 4e4d60b..a4bb6a6 100644 --- a/tests/integration/retry_test.py +++ b/tests/integration/retry_test.py @@ -46,7 +46,7 @@ def test_retrying_session(self, db: SQLAlchemy, Todo: t.Type[t.Any], mocker): # s.commit() def test_retrying_session_class(self, db: SQLAlchemy, Todo: t.Type[t.Any], mocker): - class Unique(db.Model): + class Unique(db.Base): id: Mapped[int] = sa.orm.mapped_column( sa.Identity(), primary_key=True, autoincrement=True ) @@ -57,7 +57,7 @@ class Unique(db.Model): db.create_all() with retrying_session(db.bind) as s: - todo = Todo(title="hello") + todo = Unique(name="hello") s.add(todo) diff --git a/workspace.code-workspace b/workspace.code-workspace index 10a11a5..aa9e12e 100644 --- a/workspace.code-workspace +++ b/workspace.code-workspace @@ -13,6 +13,7 @@ "source.organizeImports": true } }, - "esbonio.sphinx.confDir": "" + "esbonio.sphinx.confDir": "", + "python.linting.pylintEnabled": false } } From bcc5a9b9f87d6234ea1a4cfcae95758c682fbc1a Mon Sep 17 00:00:00 2001 From: Joe Black Date: Wed, 12 Apr 2023 17:59:47 -0400 Subject: [PATCH 09/11] clean up --- src/quart_sqlalchemy/sim/handle.py | 83 ------------- src/quart_sqlalchemy/sim/logic.py | 188 ----------------------------- src/quart_sqlalchemy/sim/repo.py | 66 +++++++--- 3 files changed, 49 insertions(+), 288 deletions(-) diff --git a/src/quart_sqlalchemy/sim/handle.py b/src/quart_sqlalchemy/sim/handle.py index 47d0cf3..51ea99d 100644 --- a/src/quart_sqlalchemy/sim/handle.py +++ b/src/quart_sqlalchemy/sim/handle.py @@ -173,16 +173,6 @@ def get_or_create_by_email_and_client_id( ) return auth_user - def get_by_id_and_validate_exists(self, auth_user_id): - """This function helps formalize how a non-existent auth user should be handled.""" - auth_user = self.logic.AuthUser.get_by_id(self.session_factory(), auth_user_id) - if auth_user is None: - raise RuntimeError('resource_name="auth_user"') - return auth_user - - # This function is reserved for consolidating into a canonical user. Do not - # call this function under other circumstances as it will automatically set - # the user as verified. See ch-25343 for additional details. 
def create_verified_user( self, client_id, @@ -190,8 +180,6 @@ def create_verified_user( user_type=EntityType.FORTMATIC.value, **kwargs, ): - # with self.session_factory() as session: - # with self.logic.begin(ro=False) as session: session = self.session_factory() with session.begin_nested(): auid = self.logic.AuthUser.add_by_email_and_client_id( @@ -231,21 +219,6 @@ def get_by_client_id_and_user_type( limit=limit, ) - def get_by_client_ids_and_user_type( - self, - client_ids, - user_type, - offset=None, - limit=None, - ): - return self.logic.AuthUser.get_by_client_ids_and_user_type( - self.session_factory(), - client_ids, - user_type, - offset=offset, - limit=limit, - ) - def exist_by_email_client_id_and_user_type(self, email, client_id, user_type): return self.logic.AuthUser.exist_by_email_and_client_id( self.session_factory(), @@ -257,11 +230,6 @@ def exist_by_email_client_id_and_user_type(self, email, client_id, user_type): def update_email_by_id(self, model_id, email): return self.logic.AuthUser.update_by_id(self.session_factory(), model_id, email=email) - def update_phone_number_by_id(self, model_id, phone_number): - return self.logic.AuthUser.update_by_id( - self.session_factory(), model_id, phone_number=phone_number - ) - def get_by_email_client_id_and_user_type(self, email, client_id, user_type): return self.logic.AuthUser.get_by_email_and_client_id( self.session_factory(), @@ -299,55 +267,9 @@ def set_role_by_email_magic_client_id(self, email, magic_client_id, role): return self.logic.AuthUser.update_by_id(session, auth_user.id, **{role: True}) - def search_by_client_id_and_substring( - self, - client_id, - substring, - offset=None, - limit=10, - ): - if not isinstance(substring, str) or len(substring) < 3: - raise InvalidSubstringError() - - auth_users = self.logic.AuthUser.get_by_client_id_with_substring_search( - self.session_factory(), - client_id, - substring, - offset=offset, - limit=limit, - ) - - return auth_users - - def is_magic_connect_enabled(self, auth_user_id=None, auth_user=None): - if auth_user is None and auth_user_id is None: - raise Exception("At least one argument needed: auth_user_id or auth_user.") - - if auth_user is None: - auth_user = self.get_by_id(auth_user_id) - - return auth_user.user_type == EntityType.CONNECT.value - def mark_as_inactive(self, auth_user_id): self.logic.AuthUser.update_by_id(self.session_factory(), auth_user_id, is_active=False) - def get_by_email_and_wallet_type_for_interop(self, email, wallet_type, network): - """ - Opinionated method for fetching AuthWallets by email address, wallet_type and network. 
- """ - return self.logic.AuthUser.get_by_email_for_interop( - self.session_factory(), - email=email, - wallet_type=wallet_type, - network=network, - ) - - def get_magic_connect_auth_user(self, auth_user_id): - auth_user = self.get_by_id_and_validate_exists(auth_user_id) - if not auth_user.is_magic_connect_user: - raise RuntimeError("RequestForbidden") - return auth_user - @signals.auth_user_duplicate.connect def handle_duplicate_auth_users( @@ -355,12 +277,7 @@ def handle_duplicate_auth_users( original_auth_user_id: ObjectID, duplicate_auth_user_ids: t.Sequence[ObjectID], ) -> None: - logger.info(f"{len(duplicate_auth_user_ids)} dupe(s) found for {original_auth_user_id}") - for dupe_id in duplicate_auth_user_ids: - logger.info( - f"marking auth_user_id {dupe_id} as inactive, in favor of original {original_auth_user_id}", - ) app.container.logic().AuthUser.update_by_id(dupe_id, is_active=False) diff --git a/src/quart_sqlalchemy/sim/logic.py b/src/quart_sqlalchemy/sim/logic.py index 7b969bc..38c58a4 100644 --- a/src/quart_sqlalchemy/sim/logic.py +++ b/src/quart_sqlalchemy/sim/logic.py @@ -226,43 +226,6 @@ def get_by_session_token( ) ) - @provide_global_contextual_session - def get_or_add_by_phone_number_and_client_id( - self, - session, - client_id, - phone_number, - user_type=EntityType.FORTMATIC.value, - ): - if phone_number is None: - raise MissingPhoneNumber() - - row = self.get_by_phone_number_and_client_id( - session=session, - phone_number=phone_number, - client_id=client_id, - user_type=user_type, - ) - - if row: - return row - - row = self._repository.add( - session=session, - phone_number=phone_number, - client_id=client_id, - user_type=user_type, - provenance=Provenance.SMS, - ) - logger.info( - "New auth user (id: {}) created by phone number (client_id: {})".format( - row.id, - client_id, - ), - ) - - return row - @provide_global_contextual_session def get_by_active_identifier_and_client_id( self, @@ -319,25 +282,6 @@ def get_by_email_and_client_id( for_update=for_update, ) - @provide_global_contextual_session - def get_by_phone_number_and_client_id( - self, - session, - phone_number, - client_id, - user_type=EntityType.FORTMATIC.value, - ): - if phone_number is None: - raise MissingPhoneNumber() - - return self.get_by_active_identifier_and_client_id( - session=session, - identifier_field=auth_user_model.phone_number, - identifier_value=phone_number, - client_id=client_id, - user_type=user_type, - ) - @provide_global_contextual_session def exist_by_email_and_client_id( self, @@ -392,17 +336,6 @@ def get_user_count_by_client_id_and_user_type(self, session, client_id, user_typ return session.execute(query).scalar() - @provide_global_contextual_session - def get_by_client_id_and_global_auth_user(self, session, client_id, global_auth_user_id): - return self._repository.get_by( - session=session, - filters=[ - auth_user_model.client_id == client_id, - auth_user_model.user_type == EntityType.CONNECT.value, - auth_user_model.global_auth_user_id == global_auth_user_id, - ], - ) - @provide_global_contextual_session def get_by_client_id_and_user_type( self, @@ -420,61 +353,6 @@ def get_by_client_id_and_user_type( limit=limit, ) - @provide_global_contextual_session - def get_by_client_ids_and_user_type( - self, - session, - client_ids, - user_type, - offset=None, - limit=None, - ): - if not client_ids: - return [] - - return self._repository.get_by( - session, - filters=[ - auth_user_model.client_id.in_(client_ids), - auth_user_model.user_type == user_type, - auth_user_model.date_verified 
!= None, - ], - offset=offset, - limit=limit, - order_by_clause=auth_user_model.id.desc(), - ) - - @provide_global_contextual_session - def get_by_client_id_with_substring_search( - self, - session, - client_id, - substring, - offset=None, - limit=10, - join_list=None, - ): - return self._repository.get_by( - session, - filters=[ - auth_user_model.client_id == client_id, - auth_user_model.user_type == EntityType.MAGIC.value, - sa.or_( - auth_user_model.provenance == Provenance.SMS, - auth_user_model.provenance == Provenance.LINK, - auth_user_model.provenance == None, # noqa: E711 - ), - sa.or_( - auth_user_model.phone_number.contains(substring), - auth_user_model.email.contains(substring), - ), - ], - offset=offset, - limit=limit, - order_by_clause=auth_user_model.id.desc(), - join_list=join_list, - ) - @provide_global_contextual_session def yield_by_chunk(self, session, chunk_size, filters=None, join_list=None): yield from self._repository.yield_by_chunk( @@ -518,72 +396,6 @@ def get_by_email( join_list=join_list, ) - @provide_global_contextual_session - def get_by_email_for_interop( - self, - session, - email: str, - wallet_type: WalletType, - network: str, - ) -> t.List[auth_user_model]: - """ - Custom method for searching for users eligible for interop. Unfortunately, this can't be done with the current - abstractions in our sql_repository, so this is a one-off bespoke method. - If we need to add more similar queries involving eager loading and multiple joins, we can add an abstraction - inside the repository. - """ - - query = ( - session.query(auth_user_model) - .join( - auth_user_model.wallets.and_( - auth_wallet_model.wallet_type == str(wallet_type) - ).and_(auth_wallet_model.network == network) - ) - .options(sa.orm.contains_eager(auth_user_model.wallets)) - .join( - auth_user_model.magic_client.and_( - magic_client_model.connect_interop == ConnectInteropStatus.ENABLED, - ), - ) - .options(sa.orm.contains_eager(auth_user_model.magic_client)) - .filter( - auth_wallet_model.wallet_type == wallet_type, - auth_wallet_model.network == network, - ) - .filter( - auth_user_model.email == email, - auth_user_model.user_type == EntityType.MAGIC.value, - ) - .populate_existing() - ) - - return query.all() - - @provide_global_contextual_session - def get_linked_users(self, session, primary_auth_user_id, join_list, no_op=False): - # TODO(magic-ravi#67899|2022-12-30): Re-enable account linked users for interop. Remove no_op flag. 
- if no_op: - return [] - else: - return self._repository.get_by( - session, - filters=[ - auth_user_model.user_type == EntityType.MAGIC.value, - auth_user_model.linked_primary_auth_user_id == primary_auth_user_id, - ], - join_list=join_list, - ) - - @provide_global_contextual_session - def get_by_phone_number(self, session, phone_number): - return self._repository.get_by( - session, - filters=[ - auth_user_model.phone_number == phone_number, - ], - ) - class AuthWallet(LogicComponent[auth_wallet_model, ObjectID, sa.orm.Session]): model = auth_wallet_model diff --git a/src/quart_sqlalchemy/sim/repo.py b/src/quart_sqlalchemy/sim/repo.py index 959efb2..618af04 100644 --- a/src/quart_sqlalchemy/sim/repo.py +++ b/src/quart_sqlalchemy/sim/repo.py @@ -29,6 +29,10 @@ class AbstractRepository(t.Generic[EntityT, EntityIdT, SessionT], metaclass=ABCM model: t.Type[EntityT] identity: t.Type[EntityIdT] + def __init__(self, model: t.Type[EntityT], identity: t.Type[EntityIdT]): + self.model = model + self.identity = identity + class AbstractBulkRepository(t.Generic[EntityT, EntityIdT, SessionT], metaclass=ABCMeta): """A repository interface for bulk operations. @@ -40,6 +44,10 @@ class AbstractBulkRepository(t.Generic[EntityT, EntityIdT, SessionT], metaclass= model: t.Type[EntityT] identity: t.Type[EntityIdT] + def __init__(self, model: t.Type[EntityT], identity: t.Type[EntityIdT]): + self.model = model + self.identity = identity + class SQLAlchemyRepository( AbstractRepository[EntityT, EntityIdT, SessionT], t.Generic[EntityT, EntityIdT, SessionT] @@ -68,11 +76,23 @@ class SQLAlchemyRepository( builder: StatementBuilder - def __init__(self, model: t.Type[EntityT], identity: t.Type[EntityIdT]): - self.model = model - self.identity = identity + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) self.builder = StatementBuilder(self.model) + def _build_execution_options( + self, + execution_options: t.Optional[t.Dict[str, t.Any]] = None, + include_inactive: bool = False, + yield_by_chunk: bool = False, + ): + execution_options = execution_options or {} + if include_inactive: + execution_options.setdefault("include_inactive", include_inactive) + if yield_by_chunk: + execution_options.setdefault("yield_per", yield_by_chunk) + return execution_options + def insert(self, session: sa.orm.Session, values: t.Dict[str, t.Any]) -> EntityT: """Insert a new model into the database.""" new = self.model(**values) @@ -141,9 +161,12 @@ def get( """ selectables = (self.model,) - execution_options = execution_options or {} - if include_inactive: - execution_options.setdefault("include_inactive", include_inactive) + execution_options = self._build_execution_options( + execution_options, include_inactive=include_inactive + ) + # execution_options = execution_options or {} + # if include_inactive: + # execution_options.setdefault("include_inactive", include_inactive) statement = self.builder.select( selectables, # type: ignore @@ -173,9 +196,12 @@ def get_by_field( """Select models where field is equal to value.""" selectables = (self.model,) - execution_options = execution_options or {} - if include_inactive: - execution_options.setdefault("include_inactive", include_inactive) + execution_options = self._build_execution_options( + execution_options, include_inactive=include_inactive + ) + # execution_options = execution_options or {} + # if include_inactive: + # execution_options.setdefault("include_inactive", include_inactive) if isinstance(field, str): field = getattr(self.model, field) @@ -217,11 
+243,16 @@ def select( """ selectables = selectables or (self.model,) # type: ignore - execution_options = execution_options or {} - if include_inactive: - execution_options.setdefault("include_inactive", include_inactive) - if yield_by_chunk: - execution_options.setdefault("yield_per", yield_by_chunk) + execution_options = self._build_execution_options( + execution_options, + include_inactive=include_inactive, + yield_by_chunk=yield_by_chunk, + ) + # execution_options = execution_options or {} + # if include_inactive: + # execution_options.setdefault("include_inactive", include_inactive) + # if yield_by_chunk: + # execution_options.setdefault("yield_per", yield_by_chunk) statement = self.builder.select( selectables, @@ -271,9 +302,10 @@ def exists( """ selectables = (sa.sql.literal(True),) - execution_options = {} - if include_inactive: - execution_options.setdefault("include_inactive", include_inactive) + execution_options = self._build_execution_options(None, include_inactive=include_inactive) + # execution_options = {} + # if include_inactive: + # execution_options.setdefault("include_inactive", include_inactive) statement = self.builder.select( selectables, From 671edeb377b6b94ee23952f9dad862846e7cc0d5 Mon Sep 17 00:00:00 2001 From: Joe Black Date: Wed, 12 Apr 2023 18:14:50 -0400 Subject: [PATCH 10/11] delete global session --- src/quart_sqlalchemy/sim/handle.py | 115 +++++++++++------- src/quart_sqlalchemy/sim/logic.py | 4 - src/quart_sqlalchemy/sim/views/auth_user.py | 16 ++- src/quart_sqlalchemy/sim/views/auth_wallet.py | 19 ++- .../sim/views/magic_client.py | 22 ++-- 5 files changed, 98 insertions(+), 78 deletions(-) diff --git a/src/quart_sqlalchemy/sim/handle.py b/src/quart_sqlalchemy/sim/handle.py index 51ea99d..bd3bc12 100644 --- a/src/quart_sqlalchemy/sim/handle.py +++ b/src/quart_sqlalchemy/sim/handle.py @@ -6,10 +6,11 @@ from datetime import datetime import sqlalchemy +import sqlalchemy.orm from dependency_injector.wiring import Provide from quart import Quart -from quart_sqlalchemy.session import SessionProxy +from quart_sqlalchemy.session import provide_global_contextual_session from quart_sqlalchemy.sim import signals from quart_sqlalchemy.sim.logic import LogicComponent from quart_sqlalchemy.sim.model import AuthUser @@ -44,14 +45,15 @@ class InvalidSubstringError(AuthUserBaseError): class HandlerBase: logic: LogicComponent = Provide["logic"] - session_factory = SessionProxy() class MagicClientHandler(HandlerBase): auth_user_handler: AuthUserHandler = Provide["AuthUserHandler"] + @provide_global_contextual_session def add( self, + session: sa.orm.Session, app_name=None, rate_limit_tier=None, connect_interop=None, @@ -68,7 +70,7 @@ def add( """ return self.logic.MagicClient.add( - self.session_factory(), + session, app_name=app_name, rate_limit_tier=rate_limit_tier, connect_interop=connect_interop, @@ -76,13 +78,16 @@ def add( global_audience_enabled=global_audience_enabled, ) - def get_by_public_api_key(self, public_api_key): - return self.logic.MagicClient.get_by_public_api_key(self.session_factory(), public_api_key) + @provide_global_contextual_session + def get_by_public_api_key(self, session: sa.orm.Session, public_api_key): + return self.logic.MagicClient.get_by_public_api_key(session, public_api_key) - def get_by_id(self, magic_client_id): - return self.logic.MagicClient.get_by_id(self.session_factory(), magic_client_id) + @provide_global_contextual_session + def get_by_id(self, session: sa.orm.Session, magic_client_id): + return 
self.logic.MagicClient.get_by_id(session, magic_client_id) - def update_app_name_by_id(self, magic_client_id, app_name): + @provide_global_contextual_session + def update_app_name_by_id(self, session: sa.orm.Session, magic_client_id, app_name): """ Args: magic_client_id (ObjectID|int|str): self explanatory. @@ -92,23 +97,21 @@ def update_app_name_by_id(self, magic_client_id, app_name): None if `magic_client_id` doesn't exist in the db app_name if update was successful """ - client = self.logic.MagicClient.update_by_id( - self.session_factory(), magic_client_id, app_name=app_name - ) + client = self.logic.MagicClient.update_by_id(session, magic_client_id, app_name=app_name) if not client: return None return client.app_name - def update_by_id(self, magic_client_id, **kwargs): - client = self.logic.MagicClient.update_by_id( - self.session_factory(), magic_client_id, **kwargs - ) + @provide_global_contextual_session + def update_by_id(self, session: sa.orm.Session, magic_client_id, **kwargs): + client = self.logic.MagicClient.update_by_id(session, magic_client_id, **kwargs) return client - def set_inactive_by_id(self, magic_client_id): + @provide_global_contextual_session + def set_inactive_by_id(self, session: sa.orm.Session, magic_client_id): """ Args: magic_client_id (ObjectID|int|str): self explanatory. @@ -116,12 +119,12 @@ def set_inactive_by_id(self, magic_client_id): Returns: None """ - self.logic.MagicClient.update_by_id( - self.session_factory(), magic_client_id, is_active=False - ) + self.logic.MagicClient.update_by_id(session, magic_client_id, is_active=False) + @provide_global_contextual_session def get_users_for_client( self, + session: sa.orm.Session, magic_client_id, offset=None, limit=None, @@ -131,6 +134,7 @@ def get_users_for_client( """ product_type = get_product_type_by_client_id(magic_client_id) auth_users = self.auth_user_handler.get_by_client_id_and_user_type( + session, magic_client_id, product_type, offset=offset, @@ -146,16 +150,18 @@ def get_users_for_client( class AuthUserHandler(HandlerBase): - def get_by_session_token(self, session_token): - return self.logic.AuthUser.get_by_session_token(self.session_factory(), session_token) + @provide_global_contextual_session + def get_by_session_token(self, session: sa.orm.Session, session_token): + return self.logic.AuthUser.get_by_session_token(session, session_token) + @provide_global_contextual_session def get_or_create_by_email_and_client_id( self, + session: sa.orm.Session, email, client_id, user_type=EntityType.MAGIC.value, ): - session = self.session_factory() with session.begin_nested(): auth_user = self.logic.AuthUser.get_by_email_and_client_id( session, @@ -173,14 +179,15 @@ def get_or_create_by_email_and_client_id( ) return auth_user + @provide_global_contextual_session def create_verified_user( self, + session: sa.orm.Session, client_id, email, user_type=EntityType.FORTMATIC.value, **kwargs, ): - session = self.session_factory() with session.begin_nested(): auid = self.logic.AuthUser.add_by_email_and_client_id( session, @@ -201,52 +208,66 @@ def create_verified_user( return auth_user - def get_by_id(self, auth_user_id) -> AuthUser: - return self.logic.AuthUser.get_by_id(self.session_factory(), auth_user_id) + @provide_global_contextual_session + def get_by_id(self, session: sa.orm.Session, auth_user_id) -> AuthUser: + return self.logic.AuthUser.get_by_id(session, auth_user_id) + @provide_global_contextual_session def get_by_client_id_and_user_type( self, + session: sa.orm.Session, client_id, user_type, 
offset=None, limit=None, ): return self.logic.AuthUser.get_by_client_id_and_user_type( - self.session_factory(), + session, client_id, user_type, offset=offset, limit=limit, ) - def exist_by_email_client_id_and_user_type(self, email, client_id, user_type): + @provide_global_contextual_session + def exist_by_email_client_id_and_user_type( + self, session: sa.orm.Session, email, client_id, user_type + ): return self.logic.AuthUser.exist_by_email_and_client_id( - self.session_factory(), + session, email, client_id, user_type=user_type, ) - def update_email_by_id(self, model_id, email): - return self.logic.AuthUser.update_by_id(self.session_factory(), model_id, email=email) + @provide_global_contextual_session + def update_email_by_id(self, session: sa.orm.Session, model_id, email): + return self.logic.AuthUser.update_by_id(session, model_id, email=email) - def get_by_email_client_id_and_user_type(self, email, client_id, user_type): + @provide_global_contextual_session + def get_by_email_client_id_and_user_type( + self, session: sa.orm.Session, email, client_id, user_type + ): return self.logic.AuthUser.get_by_email_and_client_id( - self.session_factory(), + session, email, client_id, user_type, ) - def mark_date_verified_by_id(self, model_id): + @provide_global_contextual_session + def mark_date_verified_by_id(self, session: sa.orm.Session, model_id): return self.logic.AuthUser.update_by_id( - self.session_factory(), + session, model_id, date_verified=datetime.utcnow(), ) - def set_role_by_email_magic_client_id(self, email, magic_client_id, role): - session = self.session_factory() + @provide_global_contextual_session + def set_role_by_email_magic_client_id( + self, session: sa.orm.Session, email, magic_client_id, role + ): auth_user = self.logic.AuthUser.get_by_email_and_client_id( session, email, @@ -267,8 +288,9 @@ def set_role_by_email_magic_client_id(self, email, magic_client_id, role): return self.logic.AuthUser.update_by_id(session, auth_user.id, **{role: True}) - def mark_as_inactive(self, auth_user_id): - self.logic.AuthUser.update_by_id(self.session_factory(), auth_user_id, is_active=False) + @provide_global_contextual_session + def mark_as_inactive(self, session: sa.orm.Session, auth_user_id): + self.logic.AuthUser.update_by_id(session, auth_user_id, is_active=False) @signals.auth_user_duplicate.connect @@ -282,29 +304,35 @@ def handle_duplicate_auth_users( class AuthWalletHandler(HandlerBase): - def get_by_id(self, model_id): - return self.logic.AuthWallet.get_by_id(self.session_factory(), model_id) + @provide_global_contextual_session + def get_by_id(self, session: sa.orm.Session, model_id): + return self.logic.AuthWallet.get_by_id(session, model_id) - def get_by_public_address(self, public_address): - return self.logic.AuthWallet().get_by_public_address(self.session_factory(), public_address) + @provide_global_contextual_session + def get_by_public_address(self, session: sa.orm.Session, public_address): + return self.logic.AuthWallet().get_by_public_address(session, public_address) + @provide_global_contextual_session def get_by_auth_user_id( self, + session: sa.orm.Session, auth_user_id: ObjectID, network: t.Optional[str] = None, wallet_type: t.Optional[WalletType] = None, **kwargs, ) -> t.List[AuthWallet]: return self.logic.AuthWallet.get_by_auth_user_id( - self.session_factory(), + session, auth_user_id, network=network, wallet_type=wallet_type, **kwargs, ) + @provide_global_contextual_session def sync_auth_wallet( self, + session: sa.orm.Session,
auth_user_id, public_address, encrypted_private_address, @@ -312,7 +340,6 @@ def sync_auth_wallet( network: t.Optional[str] = None, wallet_type: t.Optional[WalletType] = None, ): - session = self.session_factory() with session.begin_nested(): existing_wallet = self.logic.AuthWallet.get_by_auth_user_id( session, diff --git a/src/quart_sqlalchemy/sim/logic.py b/src/quart_sqlalchemy/sim/logic.py index 38c58a4..626adbf 100644 --- a/src/quart_sqlalchemy/sim/logic.py +++ b/src/quart_sqlalchemy/sim/logic.py @@ -1,4 +1,3 @@ -import inspect import logging import secrets import typing as t @@ -12,11 +11,8 @@ from quart_sqlalchemy.sim import signals from quart_sqlalchemy.sim.model import AuthUser as auth_user_model from quart_sqlalchemy.sim.model import AuthWallet as auth_wallet_model -from quart_sqlalchemy.sim.model import ConnectInteropStatus from quart_sqlalchemy.sim.model import EntityType from quart_sqlalchemy.sim.model import MagicClient as magic_client_model -from quart_sqlalchemy.sim.model import Provenance -from quart_sqlalchemy.sim.model import WalletType from quart_sqlalchemy.sim.repo_adapter import RepositoryLegacyAdapter from quart_sqlalchemy.sim.util import ObjectID from quart_sqlalchemy.sim.util import one diff --git a/src/quart_sqlalchemy/sim/views/auth_user.py b/src/quart_sqlalchemy/sim/views/auth_user.py index 6c8664b..cf1d366 100644 --- a/src/quart_sqlalchemy/sim/views/auth_user.py +++ b/src/quart_sqlalchemy/sim/views/auth_user.py @@ -6,7 +6,6 @@ from dependency_injector.wiring import Provide from quart_sqlalchemy.framework import QuartSQLAlchemy -from quart_sqlalchemy.session import set_global_contextual_session from ..auth import authorized_request from ..auth import RequestCredentials @@ -51,8 +50,7 @@ def get_auth_user( credentials: RequestCredentials = Provide["request_credentials"], ) -> ResponseWrapper[AuthUserSchema]: with db.bind.Session() as session: - with set_global_contextual_session(session): - auth_user = auth_user_handler.get_by_session_token(credentials.current_user.value) + auth_user = auth_user_handler.get_by_session_token(session, credentials.current_user.value) return ResponseWrapper[AuthUserSchema](data=AuthUserSchema.from_orm(auth_user)) @@ -76,12 +74,12 @@ def create_auth_user( ) -> ResponseWrapper[CreateAuthUserResponse]: with db.bind.Session() as session: with session.begin(): - with set_global_contextual_session(session): - client = auth_user_handler.create_verified_user( - email=data.email, - client_id=credentials.current_client.subject.id, - user_type=EntityType.MAGIC.value, - ) + client = auth_user_handler.create_verified_user( + session, + email=data.email, + client_id=credentials.current_client.subject.id, + user_type=EntityType.MAGIC.value, + ) return ResponseWrapper[CreateAuthUserResponse]( data=dict(auth_user=AuthUserSchema.from_orm(client)) # type: ignore diff --git a/src/quart_sqlalchemy/sim/views/auth_wallet.py b/src/quart_sqlalchemy/sim/views/auth_wallet.py index 79c3be0..9275999 100644 --- a/src/quart_sqlalchemy/sim/views/auth_wallet.py +++ b/src/quart_sqlalchemy/sim/views/auth_wallet.py @@ -6,7 +6,6 @@ from quart import g from quart_sqlalchemy.framework import QuartSQLAlchemy -from quart_sqlalchemy.session import set_global_contextual_session from ..auth import authorized_request from ..auth import RequestCredentials @@ -70,15 +69,15 @@ def sync( ) -> ResponseWrapper[WalletSyncResponse]: with db.bind.Session() as session: with session.begin(): - with set_global_contextual_session(session): - wallet = auth_wallet_handler.sync_auth_wallet( - 
credentials.current_user.subject.id, - data.public_address, - data.encrypted_private_address, - WalletManagementType.DELEGATED.value, - network=web3.network, - wallet_type=data.wallet_type, - ) + wallet = auth_wallet_handler.sync_auth_wallet( + session, + credentials.current_user.subject.id, + data.public_address, + data.encrypted_private_address, + WalletManagementType.DELEGATED.value, + network=web3.network, + wallet_type=data.wallet_type, + ) return ResponseWrapper[WalletSyncResponse]( data=dict( diff --git a/src/quart_sqlalchemy/sim/views/magic_client.py b/src/quart_sqlalchemy/sim/views/magic_client.py index 09d889b..dc506a2 100644 --- a/src/quart_sqlalchemy/sim/views/magic_client.py +++ b/src/quart_sqlalchemy/sim/views/magic_client.py @@ -5,7 +5,6 @@ from dependency_injector.wiring import Provide from quart_sqlalchemy.framework import QuartSQLAlchemy -from quart_sqlalchemy.session import set_global_contextual_session from ..auth import authorized_request from ..auth import RequestCredentials @@ -51,14 +50,14 @@ def create_magic_client( ) -> ResponseWrapper[CreateMagicClientResponse]: with db.bind.Session() as session: with session.begin(): - with set_global_contextual_session(session): - client = magic_client_handler.add( - app_name=data.app_name, - rate_limit_tier=data.rate_limit_tier, - connect_interop=data.connect_interop, - is_signing_modal_enabled=data.is_signing_modal_enabled, - global_audience_enabled=data.global_audience_enabled, - ) + client = magic_client_handler.add( + session, + app_name=data.app_name, + rate_limit_tier=data.rate_limit_tier, + connect_interop=data.connect_interop, + is_signing_modal_enabled=data.is_signing_modal_enabled, + global_audience_enabled=data.global_audience_enabled, + ) return ResponseWrapper[CreateMagicClientResponse]( data=dict(magic_client=MagicClientSchema.from_orm(client)) # type: ignore @@ -82,8 +81,9 @@ def get_magic_client( db: QuartSQLAlchemy = Provide["db"], ) -> ResponseWrapper[MagicClientSchema]: with db.bind.Session() as session: - with set_global_contextual_session(session): - client = magic_client_handler.get_by_public_api_key(credentials.current_client.value) + client = magic_client_handler.get_by_public_api_key( + session, credentials.current_client.value + ) return ResponseWrapper[MagicClientSchema]( data=MagicClientSchema.from_orm(client) # type: ignore From 38cf5a1adde49979672685a66c110e6e455f2ad7 Mon Sep 17 00:00:00 2001 From: Joe Black Date: Sun, 28 Sep 2025 19:18:12 -0400 Subject: [PATCH 11/11] Update project configuration and dependencies: remove .flake8 and setup.cfg in favor of consolidated settings in pyproject.toml, and adjust .editorconfig for improved code style. Bump Python to 3.12 and the package version to 4.0.0. Simplify signal handling in signals.py and add encryption/decryption helpers in util.py.
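A rough round-trip sketch of the new ID helpers added to src/quart_sqlalchemy/util.py below (assuming the module is importable as quart_sqlalchemy.util; the integer value is purely illustrative):

    from quart_sqlalchemy.util import decrypt_id, encrypt_id

    # Speck-encrypt the integer ID, append a Reed-Solomon checksum, then base62-encode it.
    opaque = encrypt_id(12345)
    # Decode and verify the checksum, then decrypt back to the original integer;
    # a corrupted token is expected to raise ValueError("Invalid checksum").
    original = decrypt_id(opaque)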
--- .editorconfig | 25 +++- .flake8 | 19 --- .gitignore | 2 +- .python-version | 2 +- pyproject.toml | 221 ++++++++++++++++---------------- setup.cfg | 73 ----------- src/quart_sqlalchemy/signals.py | 23 ++-- src/quart_sqlalchemy/util.py | 43 ++++++- tox.ini | 23 +--- workspace.code-workspace | 24 ++-- 10 files changed, 205 insertions(+), 250 deletions(-) delete mode 100644 .flake8 delete mode 100644 setup.cfg diff --git a/.editorconfig b/.editorconfig index e32c802..1f877e4 100644 --- a/.editorconfig +++ b/.editorconfig @@ -1,5 +1,11 @@ +# EditorConfig is awesome: https://editorconfig.org + +# top-most EditorConfig file root = true +max_line_length = 100 + +# Unix-style newlines with a newline ending every file [*] indent_style = space indent_size = 4 @@ -7,7 +13,22 @@ insert_final_newline = true trim_trailing_whitespace = true end_of_line = lf charset = utf-8 -max_line_length = 88 -[*.{yml,yaml,json,js,css,html}] +[*.md] +trim_trailing_whitespace = false + + +# 2 space indentation +[*.{yml,yaml,css,html}] indent_size = 2 + + +[*.py] +indent_style = space +indent_size = 4 + +# Tab indentation (no size specified) +[Makefile] +indent_style = tab +tab_width = 4 + diff --git a/.flake8 b/.flake8 deleted file mode 100644 index 45a33e1..0000000 --- a/.flake8 +++ /dev/null @@ -1,19 +0,0 @@ -[flake8] -# B = bugbear -# E = pycodestyle errors -# F = flake8 pyflakes -# W = pycodestyle warnings -# B9 = bugbear opinions -# ISC = implicit-str-concat -select = B, E, F, W, B9, ISC -ignore = - # slice notation whitespace, invalid - E203 - # line length, handled by bugbear B950 - E501 - # bare except, handled by bugbear B001 - E722 - # bin op line break, invalid - W503 -# up to 88 allowed by bugbear B950 -max-line-length = 88 diff --git a/.gitignore b/.gitignore index e231971..dae162d 100644 --- a/.gitignore +++ b/.gitignore @@ -21,4 +21,4 @@ pytest-plugin-work old examples/two/ *.sqlite -*.db \ No newline at end of file +*.db diff --git a/.python-version b/.python-version index 214b521..8531a3b 100644 --- a/.python-version +++ b/.python-version @@ -1 +1 @@ -3.7.13 +3.12.2 diff --git a/pyproject.toml b/pyproject.toml index 17e8c4b..c0e5624 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "quart-sqlalchemy" -version = "3.0.2" +version = "4.0.0" description = "SQLAlchemy for humans, with framework adapter for Quart." 
authors = [ {name = "Joe Black", email = "me@joeblack.nyc"}, @@ -11,13 +11,17 @@ dependencies = [ "SQLAlchemy-Utils", "anyio", "aiosqlite", + "bases", + "blinker", + "cryptography", "pydantic", "tenacity", + "reedsolo", "sqlapagination", - "exceptiongroup", + "simonspeckciphers", "python-ulid" ] -requires-python = ">=3.7" +requires-python = ">=3.12" readme = "README.rst" license = {text = "MIT"} @@ -29,6 +33,8 @@ license = {text = "MIT"} requires = ["setuptools", "setuptools-scm"] build-backend = "setuptools.build_meta" +[tool.setuptools.packages.find] +where = ["src"] [project.optional-dependencies] sim = [ @@ -37,7 +43,6 @@ sim = [ tests = [ "pytest", # "pytest-asyncio~=0.20.3", - # "pytest-asyncio @ https://github.com/joeblackwaslike/pytest-asyncio/releases/download/v0.20.4.dev42/pytest_asyncio-0.20.4.dev42-py3-none-any.whl", "pytest-asyncio @ https://github.com/joeblackwaslike/pytest-asyncio/releases/download/v0.20.4.dev43/pytest_asyncio-0.20.4.dev43-py3-none-any.whl", "pytest-mock", "pytest-cov", @@ -46,12 +51,11 @@ tests = [ dev = [ "pre-commit", "tox", - "tox-pdm", "mypy", - "wemake-python-styleguide", - "IPython", - "black", + "ipython", + "ruff", ] + docs = [ "sphinx", "pallets-sphinx-themes", @@ -59,22 +63,10 @@ docs = [ "sphinxcontrib-log-cabinet", ] -[tool.pdm.build] -source-includes = [ - "docs/", - # "examples/", - "tests/", - "CHANGES.rst", - "pdm.lock", - "tox.ini", -] -excludes = [ - "docs/_build", -] - [tool.pytest.ini_options] addopts = "-rsx --tb=short --loop-scope session" testpaths = ["tests"] +pythonpath = ["src"] filterwarnings = ["error"] asyncio_mode = "auto" py311_task = true @@ -85,23 +77,104 @@ branch = true source = ["src", "tests"] [tool.coverage.paths] -source = ["src", "*/site-packages"] - -[tool.isort] -profile = "black" -src_paths = ["src", "tests", "examples"] -force_single_line = true -use_parentheses = true -atomic = true -lines_after_imports = 2 -line_length = 100 -order_by_type = false -known_first_party = ["quart_sqlalchemy", "tests"] +source = ["src"] + + +[tool.ruff] +# Ruff config: https://docs.astral.sh/ruff/settings +target-version = "py312" +line-length = 100 +fix = true + +[tool.ruff.lint] +select = [ + "A", # flake8-builtins + "B", # flake8-bugbear + "C4", # flake8-comprehensions + "C90", # mccabe + "COM", # flake8-commas + "D", # pydocstyle + "DTZ", # flake8-datetimez + "E", # pycodestyle + "ERA", # flake8-eradicate + "EXE", # flake8-executable + "F", # pyflakes + "FBT", # flake8-boolean-trap + "FLY", # flynt + "FURB", # refurb + "G", # flake8-logging-format + "I", # isort + "ICN", # flake8-import-conventions + "ISC", # flake8-implicit-str-concat + "LOG", # flake8-logging + "N", # pep8-naming + "PERF", # perflint + "PIE", # flake8-pie + "PL", # pylint + "PT", # flake8-pytest-style + "PTH", # flake8-use-pathlib + "Q", # flake8-quotes + "RET", # flake8-return + "RSE", # flake8-raise + "RUF", # ruff + "S", # flake8-bandit + "SIM", # flake8-simplify + "SLF", # flake8-self + "SLOT", # flake8-slots + "T100", # flake8-debugger + "TRY", # tryceratops + "UP", # pyupgrade + "W", # pycodestyle + "YTT", # flake8-2020 +] +ignore = [ + "A005", # allow to shadow stdlib and builtin module names + "COM812", # trailing comma, conflicts with `ruff format` + # Different doc rules that we don't really care about: + "D100", + "D104", + "D106", + "D203", + "D212", + "D401", + "D404", + "D405", + "E501", # line too long + "E731", # do not assign lambda + "ISC001", # implicit string concat conflicts with `ruff format` + "ISC003", # prefer explicit string concat over
implicit concat + "PLR09", # we have our own complexity rules + "PLR2004", # do not report magic numbers + "PLR6301", # do not require classmethod / staticmethod when self not used + "TRY003", # long exception messages from `tryceratops` +] + +# Plugin configs: +flake8-import-conventions.banned-from = [ "ast" ] +flake8-quotes.inline-quotes = "double" +mccabe.max-complexity = 6 +pydocstyle.convention = "google" + +[tool.ruff.lint.per-file-ignores] +"tests/*.py" = [ + "S101", # asserts + "S105", # hardcoded passwords + "S404", # subprocess calls are for tests + "S603", # do not require `shell=True` + "S607", # partial executable paths + "D103", # docstrings on public functions +] + +[tool.ruff.format] +preview = true +quote-style = "double" +indent-style = "space" +docstring-code-format = false [tool.mypy] -python_version = "3.7" +python_version = "3.12" plugins = ["pydantic.mypy", "sqlalchemy.ext.mypy.plugin"] -files = ["src/quart_sqlalchemy", "tests"] +files = ["src", "tests"] show_error_codes = true pretty = true strict = true @@ -120,81 +193,3 @@ module = [ ] ignore_missing_imports = true -[tool.pylint.messages_control] -max-line-length = 100 -disable = ["invalid-name", "missing-docstring", "protected-access"] - -[tool.flakeheaven] -baseline = ".flakeheaven_baseline" -exclude = ["W503"] - -[tool.flakeheaven.plugins] -"flake8-*" = ["+*"] -"flake8-docstrings" = [ - "+*", - "-D100", - "-D101", - "-D102", - "-D103", - "-D106", - "-D107", - "-D401", -] -"flake8-quotes" = [ - "+*", - "-Q000", -] -"flake8-isort" = [ - "+*", - "-I001", - "-I003", - "-I005", -] -"flake8-bandit" = [ - "+*", - "-S101", -] -"mccabe" = ["+*"] -"pycodestyle" = ["+*"] -"pyflakes" = [ - "+*", -] -"wemake-python-styleguide" = [ - "+*", - "-WPS110", - "-WPS111", - "-WPS115", - "-WPS118", - "-WPS120", # allow variables with trailing underscore - "-WPS201", - "-WPS204", - "-WPS210", - "-WPS211", - "-WPS214", - "-WPS221", - "-WPS224", - "-WPS225", # allow multiple except in try block - "-WPS226", - "-WPS230", - "-WPS231", - "-WPS232", - "-WPS238", # allow multiple raises in function - "-WPS305", # allow f-strings - "-WPS306", - "-WPS326", - "-WPS337", # allow multi-line conditionals - "-WPS338", - "-WPS420", # allow pass keyword - "-WPS429", - "-WPS430", # allow nested functions - "-WPS431", - "-WPS432", - "-WPS433", - "-WPS435", - "-WPS437", - "-WPS463", # Unsure what it means "Found a getter without a return value" - "-WPS473", - "-WPS503", - "-WPS505", - "-WPS604", # allow pass inside 'class' body -] diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 0c8f6eb..0000000 --- a/setup.cfg +++ /dev/null @@ -1,73 +0,0 @@ -# Config goes in pyproject.toml unless a tool doesn't support that. 
- -[flake8] -# B = bugbear -# E = pycodestyle errors -# F = flake8 pyflakes -# W = pycodestyle warnings -# B9 = bugbear opinions -# ISC = implicit-str-concat - -show-source = true -max-line-length = 100 -min-name-length = 2 -max-name-length = 20 -max-methods = 12 - -nested-classes-whitelist = - Meta - Params - Config - Defaults - -ignore = - # allow f strings - WPS305 - WPS430 - WPS463 - -allowed-domain-names = - db - value - val - vals - values - result - results - -exclude = - .git - .github - .mypy_cache - .pytest_cache - __pycache__ - __pypackages__ - venv - .venv - artwork - build - dist - docs - examples - old - -extend-select = - # bugbear - B - # bugbear opinions - B9 - # implicit str concat - ISC - -extend-ignore = - # slice notation whitespace, invalid - E203 - # line length, handled by bugbear B950 - E501 - # bare except, handled by bugbear B001 - E722 - # zip with strict=, requires python >= 3.10 - B905 - # string formatting opinion, B028 renamed to B907 - B028 - B907 diff --git a/src/quart_sqlalchemy/signals.py b/src/quart_sqlalchemy/signals.py index ecef762..cc52336 100644 --- a/src/quart_sqlalchemy/signals.py +++ b/src/quart_sqlalchemy/signals.py @@ -3,16 +3,13 @@ import sqlalchemy import sqlalchemy.orm from blinker import Namespace -from quart.signals import AsyncNamespace - sa = sqlalchemy -sync_signals = Namespace() -async_signals = AsyncNamespace() +signals = Namespace() -before_bind_engine_created = sync_signals.signal( +before_bind_engine_created = signals.signal( "quart-sqlalchemy.bind.engine.created.before", doc="""Called before a bind creates an engine. @@ -25,7 +22,7 @@ def handler( ... """, ) -after_bind_engine_created = sync_signals.signal( +after_bind_engine_created = signals.signal( "quart-sqlalchemy.bind.engine.created.after", doc="""Called after a bind creates an engine. @@ -40,7 +37,7 @@ def handler( """, ) -before_bind_session_factory_created = sync_signals.signal( +before_bind_session_factory_created = signals.signal( "quart-sqlalchemy.bind.session_factory.created.before", doc="""Called before a bind creates a session_factory. @@ -49,7 +46,7 @@ def handler(sender: t.Union[Bind, AsyncBind], options: Dict[str, Any]) -> None: ... """, ) -after_bind_session_factory_created = sync_signals.signal( +after_bind_session_factory_created = signals.signal( "quart-sqlalchemy.bind.session_factory.created.after", doc="""Called after a bind creates a session_factory. @@ -64,7 +61,7 @@ def handler( ) -bind_context_entered = sync_signals.signal( +bind_context_entered = signals.signal( "quart-sqlalchemy.bind.context.entered", doc="""Called when a bind context is entered. @@ -79,7 +76,7 @@ def handler( """, ) -bind_context_exited = sync_signals.signal( +bind_context_exited = signals.signal( "quart-sqlalchemy.bind.context.exited", doc="""Called when a bind context is exited. @@ -95,7 +92,7 @@ def handler( ) -before_framework_extension_initialization = sync_signals.signal( +before_framework_extension_initialization = signals.signal( "quart-sqlalchemy.framework.extension.initialization.before", doc="""Fired before SQLAlchemy.init_app(app) is called. @@ -104,7 +101,7 @@ def handle(sender: QuartSQLAlchemy, app: Quart): ... """, ) -after_framework_extension_initialization = sync_signals.signal( +after_framework_extension_initialization = signals.signal( "quart-sqlalchemy.framework.extension.initialization.after", doc="""Fired after SQLAlchemy.init_app(app) is called. 
@@ -115,7 +112,7 @@ def handle(sender: QuartSQLAlchemy, app: Quart): ) -framework_extension_load_fixtures = sync_signals.signal( +framework_extension_load_fixtures = signals.signal( "quart-sqlalchemy.framework.extension.fixtures.load", doc="""Fired to load fixtures into a fresh database. diff --git a/src/quart_sqlalchemy/util.py b/src/quart_sqlalchemy/util.py index 9544bb7..8cdf512 100644 --- a/src/quart_sqlalchemy/util.py +++ b/src/quart_sqlalchemy/util.py @@ -11,6 +11,11 @@ import sqlalchemy.ext.asyncio import sqlalchemy.orm import sqlalchemy.util +from bases import alphabet +from bases import encoding +from reedsolo import ReedSolomonError +from reedsolo import RSCodec +from speck import SpeckCipher sa = sqlalchemy @@ -18,6 +23,38 @@ T = t.TypeVar("T") +alphabet.register(base62=alphabet.string_alphabet.StringAlphabet(alphabet.base64.chars[:-2])) +b62 = alphabet.get("base62") +b62enc = encoding.make(b62, kind="simple-enc", case_sensitive=True, name="base62") + +speck = SpeckCipher(0x123456789ABCDEF00FEDCBA987654321) +rsc = RSCodec(10) + + +def encrypt_id(id: int) -> str: + cipher_text = speck.encrypt(id) + cipher_bytes = bytearray(cipher_text.to_bytes(length=16, byteorder="big")) + checksummed = rsc.encode(cipher_bytes) + encoded = b62enc.encode(checksummed) + return encoded + + +def decrypt_id(id: str) -> int: + decoded = b62enc.decode(id) + try: + decoded = rsc.decode(decoded) + except ReedSolomonError: + raise ValueError("Invalid checksum") + cipher_text = int.from_bytes(decoded, byteorder="big") + return speck.decrypt(cipher_text) + + +# id = 9999999999 +# cipher_text = speck.encrypt(id) +# cipher_bytes = bytearray(cipher_text.to_bytes(length=16, byteorder="big")) +# checksummed = rsc.encode(cipher_bytes) +# encoded = b62enc.encode(checksummed) + class lazy_property(t.Generic[T]): """Lazily-evaluated property decorator. @@ -71,7 +108,11 @@ def sqlachanges(sa_object): Returns the changes made to this object so far this session, in {'propertyname': [listofvalues] } format. 
""" attrs = sa.inspect(sa_object).attrs - return {a.key: list(reversed(a.history.sum())) for a in attrs if len(a.history.sum()) > 1} + return { + a.key: list(reversed(a.history.sum())) + for a in attrs + if len(a.history.sum()) > 1 + } def camel_to_snake_case(name: str) -> str: diff --git a/tox.ini b/tox.ini index 5efed8a..afc8567 100644 --- a/tox.ini +++ b/tox.ini @@ -1,9 +1,7 @@ [tox] isolated_build = true envlist = - py3{11,10,9,8,7} - pypy3{9,8,7} - py310-lowest + py3{11,10,9} style typing docs @@ -12,20 +10,11 @@ isolated_build = true [testenv] groups = tests +allowlist_externals = uv deps = lowest: flask==2.2 lowest: sqlalchemy==1.4.18 -commands = pytest -v --tb=short --basetemp={envtmpdir} {posargs} - -[testenv:style] -groups = pre-commit -skip_install = true -commands = poetry run pre-commit run --all-files --show-diff-on-failure - -[testenv:typing] -groups = mypy -commands = mypy - -[testenv:docs] -groups = docs -commands = sphinx-build -W -b html -d {envtmpdir}/doctrees docs {envtmpdir}/html +commands = + uv sync --python {envpython} + uv pip install -e .[tests] + uv run python -m pytest -v --tb=short --doctest-modules tests --basetemp={envtmpdir} {posargs} diff --git a/workspace.code-workspace b/workspace.code-workspace index aa9e12e..e8fe344 100644 --- a/workspace.code-workspace +++ b/workspace.code-workspace @@ -5,15 +5,19 @@ } ], "settings": { - "python.terminal.activateEnvInCurrentTerminal": true, - "python.formatting.provider": "black", - "[python]": { - "editor.formatOnSave": true, - "editor.codeActionsOnSave": { - "source.organizeImports": true - } - }, - "esbonio.sphinx.confDir": "", - "python.linting.pylintEnabled": false + "[python]": { + "editor.formatOnSave": true, + "editor.codeActionsOnSave": { + "source.fixAll": "explicit", + "source.organizeImports": "explicit" + }, + "editor.defaultFormatter": "charliermarsh.ruff" + }, + "python.testing.pytestEnabled": true, + "notebook.formatOnSave.enabled": true, + "notebook.codeActionsOnSave": { + "notebook.source.fixAll": "explicit", + "notebook.source.organizeImports": "explicit" + } } }