diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9fad6a8..09b4e67 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,7 @@ repos: - id: trailing-whitespace - id: requirements-txt-fixer - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.9.7 + rev: v0.11.0 hooks: - id: ruff args: diff --git a/sql_db_utils/__version__.py b/sql_db_utils/__version__.py index 5becc17..5c4105c 100644 --- a/sql_db_utils/__version__.py +++ b/sql_db_utils/__version__.py @@ -1 +1 @@ -__version__ = "1.0.0" +__version__ = "1.0.1" diff --git a/sql_db_utils/asyncio/declarative_utils.py b/sql_db_utils/asyncio/declarative_utils.py index 7477f46..a63cae2 100644 --- a/sql_db_utils/asyncio/declarative_utils.py +++ b/sql_db_utils/asyncio/declarative_utils.py @@ -21,18 +21,18 @@ class DeclarativeUtils: """ async def __new__( - cls, raw_database: str, project_id: str, session_manager: SQLSessionManager, schema: str, raw_db: bool = False + cls, raw_database: str, tenant_id: str, session_manager: SQLSessionManager, schema: str, raw_db: bool = False ) -> None: obj = super().__new__(cls) - obj.__init__(raw_database, project_id, session_manager, schema, raw_db) + obj.__init__(raw_database, tenant_id, session_manager, schema, raw_db) await obj._get_declarative_module() return obj def __init__( - self, raw_database: str, project_id: str, session_manager: SQLSessionManager, schema: str, raw_db: bool = False + self, raw_database: str, tenant_id: str, session_manager: SQLSessionManager, schema: str, raw_db: bool = False ) -> None: self.raw_database: str = raw_database - self.project_id: str = project_id + self.tenant_id: str = tenant_id self.session_manager: SQLSessionManager = session_manager self.raw_db = raw_db self.schema = schema @@ -43,29 +43,27 @@ async def _pre_check(self): await self._get_declarative_module() async def _prepare_declarative_file(self, refresh: bool = False): - declarative_project_directory = PathConfig.DECLARATIVES_PATH / self.project_id - if not declarative_project_directory.exists(): - declarative_project_directory.mkdir(parents=True) - project_init_file = declarative_project_directory / "__init__.py" - if not project_init_file.exists(): - with open(project_init_file, "w") as f: + declarative_tenant_directory = PathConfig.DECLARATIVES_PATH / self.tenant_id + if not declarative_tenant_directory.exists(): + declarative_tenant_directory.mkdir(parents=True) + tenant_init_file = declarative_tenant_directory / "__init__.py" + if not tenant_init_file.exists(): + with open(tenant_init_file, "w") as f: f.write("") - declarative_file = declarative_project_directory / f"async_{self.raw_database}_{self.schema}.py" + declarative_file = declarative_tenant_directory / f"async_{self.raw_database}_{self.schema}.py" if declarative_file.exists() and ModuleConfig.DEFER_GEN_REFRESH and not refresh: - return f"{self.project_id}.async_{self.raw_database}_{self.schema}" + return f"{self.tenant_id}.async_{self.raw_database}_{self.schema}" try: logging.debug(f"Attempting to create declarative file: {declarative_file}") from sql_db_utils.asyncio.codegen import UTDeclarativeGenerator - session = await self.session_manager.get_session( - self.raw_database, None if self.raw_db else self.project_id - ) + session = await self.session_manager.get_session(self.raw_database, None if self.raw_db else self.tenant_id) meta = MetaData() async with session.bind.begin() as conn: await conn.run_sync(meta.reflect, schema=self.schema) with open(declarative_file, "w", encoding="utf-8") as f: generator = UTDeclarativeGenerator( - raw_database=self.raw_database if self.raw_db else f"{self.project_id}__{self.raw_database}", + raw_database=self.raw_database if self.raw_db else f"{self.tenant_id}__{self.raw_database}", metadata=meta, bind=session.bind, options=set(), @@ -79,7 +77,7 @@ async def _prepare_declarative_file(self, refresh: bool = False): except Exception as e: logging.error(f"Error creating declarative file: {e}") return False - return f"{self.project_id}.async_{self.raw_database}_{self.schema}" + return f"{self.tenant_id}.async_{self.raw_database}_{self.schema}" async def _get_declarative_module(self): # NOSONAR if declarative_module_path := await self._prepare_declarative_file(): @@ -97,8 +95,20 @@ async def _get_declarative_module(self): # NOSONAR try: import asyncio - loop = asyncio.get_event_loop() - loop.stop() + logging.warning("Emergency shutdown required - gracefully canceling tasks") + loop = asyncio.get_running_loop() + tasks = [t for t in asyncio.all_tasks(loop) if t is not asyncio.current_task()] + logging.debug(f"Canceling {len(tasks)} pending tasks") + + for task in tasks: + task.cancel() + + # Wait for all tasks to complete with cancellation + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) + + logging.info("Tasks gracefully canceled, exiting") + sys.exit(1) except ImportError: logging.error("Not asyncio module, stopping using sys.exit") sys.exit(1) @@ -169,105 +179,54 @@ def get_declarative_utils_factory( self, raw_database: str, session_manager: SQLSessionManager, - security_enabled: bool = True, ): - if security_enabled: - try: - from ut_security_util import MetaInfoSchema - - async def get_declarative_utils( - meta: MetaInfoSchema, - schema: Annotated[str, Query] = PostgresConfig.PG_DEFAULT_SCHEMA, - ) -> DeclarativeUtils: - global declarative_utils - if declarative_util := declarative_utils.get(f"{raw_database}_{meta.project_id}_{schema}"): - await declarative_util._pre_check() - return declarative_util - else: - declarative_util = await DeclarativeUtils( - raw_database, meta.project_id, session_manager, schema - ) - declarative_utils[f"{raw_database}_{meta.project_id}_{schema}"] = declarative_util - return declarative_util - - return get_declarative_utils - except ImportError: - logging.error("ut_security_util not installed, please install it to use security features") - raise - else: - - async def get_declarative_utils( - project_id: Annotated[str, Cookie], schema: Annotated[str, Query] = PostgresConfig.PG_DEFAULT_SCHEMA - ) -> DeclarativeUtils: - global declarative_utils - if declarative_util := declarative_utils.get(f"{raw_database}_{project_id}_{schema}"): - await declarative_util._pre_check() - return declarative_util - else: - declarative_util = await DeclarativeUtils(raw_database, project_id, session_manager, schema) - declarative_utils[f"{raw_database}_{project_id}_{schema}"] = declarative_util - return declarative_util - - return get_declarative_utils + async def get_declarative_utils( + tenant_id: Annotated[str, Cookie], schema: Annotated[str, Query] = PostgresConfig.PG_DEFAULT_SCHEMA + ) -> DeclarativeUtils: + global declarative_utils + if declarative_util := declarative_utils.get(f"{raw_database}_{tenant_id}_{schema}"): + await declarative_util._pre_check() + return declarative_util + else: + declarative_util = await DeclarativeUtils(raw_database, tenant_id, session_manager, schema) + declarative_utils[f"{raw_database}_{tenant_id}_{schema}"] = declarative_util + return declarative_util + + return get_declarative_utils def get_schema_mandated_declarative_utils_factory( self, raw_database: str, session_manager: SQLSessionManager, schema: str, - security_enabled: bool = True, ): - if security_enabled: - try: - from ut_security_util import MetaInfoSchema - - async def get_declarative_utils( - meta: MetaInfoSchema, - ) -> DeclarativeUtils: - global declarative_utils - if declarative_util := declarative_utils.get(f"{raw_database}_{meta.project_id}_{schema}"): - await declarative_util._pre_check() - return declarative_util - else: - declarative_util = await DeclarativeUtils( - raw_database, meta.project_id, session_manager, schema - ) - declarative_utils[f"{raw_database}_{meta.project_id}_{schema}"] = declarative_util - return declarative_util - - return get_declarative_utils - except ImportError: - logging.error("ut_security_util not installed, please install it to use security features") - raise - else: - - async def get_declarative_utils(project_id: Annotated[str, Cookie]) -> DeclarativeUtils: - global declarative_utils - if declarative_util := declarative_utils.get(f"{raw_database}_{project_id}_{schema}"): - await declarative_util._pre_check() - return declarative_util - else: - declarative_util = await DeclarativeUtils(raw_database, project_id, session_manager, schema) - declarative_utils[f"{raw_database}_{project_id}_{schema}"] = declarative_util - return declarative_util - - return get_declarative_utils + async def get_declarative_utils(tenant_id: Annotated[str, Cookie]) -> DeclarativeUtils: + global declarative_utils + if declarative_util := declarative_utils.get(f"{raw_database}_{tenant_id}_{schema}"): + await declarative_util._pre_check() + return declarative_util + else: + declarative_util = await DeclarativeUtils(raw_database, tenant_id, session_manager, schema) + declarative_utils[f"{raw_database}_{tenant_id}_{schema}"] = declarative_util + return declarative_util + + return get_declarative_utils async def get_declarative_utils( self, raw_database: str, - project_id: str, + tenant_id: str, session_manager: SQLSessionManager, schema: str = PostgresConfig.PG_DEFAULT_SCHEMA, raw_db: bool = False, ) -> DeclarativeUtils: global declarative_utils - if declarative_util := declarative_utils.get(f"{raw_database}_{project_id}_{schema}"): + if declarative_util := declarative_utils.get(f"{raw_database}_{tenant_id}_{schema}"): await declarative_util._pre_check() return declarative_util else: - declarative_util = await DeclarativeUtils(raw_database, project_id, session_manager, schema, raw_db) - declarative_utils[f"{raw_database}_{project_id}_{schema}"] = declarative_util + declarative_util = await DeclarativeUtils(raw_database, tenant_id, session_manager, schema, raw_db) + declarative_utils[f"{raw_database}_{tenant_id}_{schema}"] = declarative_util return declarative_util diff --git a/sql_db_utils/asyncio/session_management.py b/sql_db_utils/asyncio/session_management.py index 3132cba..6154a8f 100644 --- a/sql_db_utils/asyncio/session_management.py +++ b/sql_db_utils/asyncio/session_management.py @@ -1,7 +1,6 @@ import logging -from typing import Any, AsyncGenerator, Union +from typing import Annotated, Any, AsyncGenerator, Union -from redis import Redis from sqlalchemy import Engine, MetaData, NullPool, text from sqlalchemy.exc import OperationalError from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine @@ -14,19 +13,18 @@ class SQLSessionManager: - def __init__(self, redis_project_db: Union[Redis, None] = None, database_uri: Union[str, None] = None) -> None: + __slots__ = ("_db_engines", "database_uri") + + def __init__(self, database_uri: Union[str, None] = None) -> None: self._db_engines = {} - if not redis_project_db: - from sql_db_utils.redis_connections import project_db as redis_project_db - self.redis_project_source_db = redis_project_db self.database_uri = database_uri or PostgresConfig.POSTGRES_URI def __del__(self) -> None: for engine in self._db_engines.values(): engine.dispose() - def _get_fully_qualified_db(self, database: str, project_id: Union[str, None] = None) -> str: - return f"{project_id}__{database}" if project_id else database + def _get_fully_qualified_db(self, database: str, tenant_id: Union[str, None] = None) -> str: + return f"{tenant_id}__{database}" if tenant_id else database async def _ensure_engine_connection(self, _engine_obj: Engine): for _ in range(PostgresConfig.PG_MAX_RETRY): @@ -43,9 +41,9 @@ async def _ensure_engine_connection(self, _engine_obj: Engine): logging.error("Server connection failed") async def _get_engine( - self, database: str, project_id: Union[str, None] = None, metadata: Union[MetaData, None] = None + self, database: str, tenant_id: Union[str, None] = None, metadata: Union[MetaData, None] = None ) -> AsyncSession: - qualified_db_name = self._get_fully_qualified_db(database=database, project_id=project_id) + qualified_db_name = self._get_fully_qualified_db(database=database, tenant_id=tenant_id) if not (engine := self._db_engines.get(qualified_db_name)): logging.debug(f"Creating engine for database: {qualified_db_name}") if PostgresConfig.PG_ENABLE_POOLING: @@ -89,47 +87,31 @@ async def _get_engine( async def get_session( self, database: str, - project_id: Union[str, None] = None, + tenant_id: Union[str, None] = None, metadata: Union[MetaData, None] = None, retrying: bool = False, ) -> AsyncSession: if PostgresConfig.PG_RETRY_QUERY or retrying: return AsyncSession( - bind=self._get_engine(database=database, project_id=project_id, metadata=metadata), + bind=self._get_engine(database=database, tenant_id=tenant_id, metadata=metadata), future=True, query_cls=RetryingQuery, ) return AsyncSession( - bind=await self._get_engine(database=database, project_id=project_id, metadata=metadata), + bind=await self._get_engine(database=database, tenant_id=tenant_id, metadata=metadata), expire_on_commit=False, future=True, ) async def get_engine_obj( - self, database: str, project_id: Union[str, None] = None, metadata: Union[MetaData, None] = None + self, database: str, tenant_id: Union[str, None] = None, metadata: Union[MetaData, None] = None ) -> AsyncEngine: - return await self._get_engine(database=database, project_id=project_id, metadata=metadata) - - def get_db_factory( - self, database: str, security_enabled: bool = True, retrying: bool = False - ) -> AsyncGenerator[AsyncSession, Any]: - if security_enabled: - try: - from ut_security_util import MetaInfoSchema - - async def get_db(meta: MetaInfoSchema): - yield await self.get_session(database=database, project_id=meta.project_id, retrying=retrying) + return await self._get_engine(database=database, tenant_id=tenant_id, metadata=metadata) - return get_db - except ImportError: - logging.error("ut_security_util not installed, please install it to use security features") - raise - else: - from fastapi import Request + def get_db_factory(self, database: str, retrying: bool = False) -> AsyncGenerator[AsyncSession, Any]: + from fastapi import Cookie - async def get_db(request: Request): - cookies = request.cookies - project_id = cookies.get("project_id") - yield await self.get_session(database=database, project_id=project_id, retrying=retrying) + async def get_db(tenant_id: Annotated[str, Cookie]): + yield await self.get_session(database=database, tenant_id=tenant_id, retrying=retrying) - return get_db + return get_db diff --git a/sql_db_utils/config.py b/sql_db_utils/config.py index 8a8347b..aa02e81 100644 --- a/sql_db_utils/config.py +++ b/sql_db_utils/config.py @@ -42,11 +42,6 @@ def validate_my_field(cls, value): return value -class _RedisConfig(BaseSettings): - REDIS_URI: str - REDIS_PROJECT_TAGS_DB: int = 18 - - class _BasePathConf(BaseSettings): BASE_PATH: str = "/code/data" @@ -64,7 +59,6 @@ def validate_paths(self) -> Self: ModuleConfig = _ModuleConfig() PostgresConfig = _PostgresConfig() -RedisConfig = _RedisConfig() PathConfig = _PathConf() -__all__ = ["ModuleConfig", "PostgresConfig", "RedisConfig", "PathConfig"] +__all__ = ["ModuleConfig", "PostgresConfig", "PathConfig"] diff --git a/sql_db_utils/constants.py b/sql_db_utils/constants.py index 86fd307..0114dcf 100644 --- a/sql_db_utils/constants.py +++ b/sql_db_utils/constants.py @@ -1,9 +1,7 @@ from enum import StrEnum -ENUMPARENT = (StrEnum,) - -class QueryType(*ENUMPARENT): +class QueryType(StrEnum): """ An enumeration representing the different types of queries that can be executed. @@ -18,7 +16,7 @@ class QueryType(*ENUMPARENT): POLAR = "polars" -class AGGridDateTrim(*ENUMPARENT): +class AGGridDateTrim(StrEnum): """ An enumeration representing the different date trimming options diff --git a/sql_db_utils/declarative_utils.py b/sql_db_utils/declarative_utils.py index ea04fe1..d806c64 100644 --- a/sql_db_utils/declarative_utils.py +++ b/sql_db_utils/declarative_utils.py @@ -21,10 +21,10 @@ class DeclarativeUtils: """ def __init__( - self, raw_database: str, project_id: str, session_manager: SQLSessionManager, schema: str, raw_db: bool = False + self, raw_database: str, tenant_id: str, session_manager: SQLSessionManager, schema: str, raw_db: bool = False ) -> None: self.raw_database: str = raw_database - self.project_id: str = project_id + self.tenant_id: str = tenant_id self.session_manager: SQLSessionManager = session_manager self.declarative_module = None self.raw_db = raw_db @@ -36,26 +36,26 @@ def _pre_check(self): self._get_declarative_module() def _prepare_declarative_file(self, refresh: bool = False): - declarative_project_directory = PathConfig.DECLARATIVES_PATH / self.project_id - if not declarative_project_directory.exists(): - declarative_project_directory.mkdir(parents=True) - project_init_file = declarative_project_directory / "__init__.py" - if not project_init_file.exists(): - with open(project_init_file, "w") as f: + declarative_tenant_directory = PathConfig.DECLARATIVES_PATH / self.tenant_id + if not declarative_tenant_directory.exists(): + declarative_tenant_directory.mkdir(parents=True) + tenant_init_file = declarative_tenant_directory / "__init__.py" + if not tenant_init_file.exists(): + with open(tenant_init_file, "w") as f: f.write("") - declarative_file = declarative_project_directory / f"{self.raw_database}_{self.schema}.py" + declarative_file = declarative_tenant_directory / f"{self.raw_database}_{self.schema}.py" if declarative_file.exists() and ModuleConfig.DEFER_GEN_REFRESH and not refresh: - return f"{self.project_id}.{self.raw_database}_{self.schema}" + return f"{self.tenant_id}.{self.raw_database}_{self.schema}" try: logging.debug(f"Attempting to create declarative file: {declarative_file}") from sql_db_utils.codegen import UTDeclarativeGenerator - session = self.session_manager.get_session(self.raw_database, None if self.raw_db else self.project_id) + session = self.session_manager.get_session(self.raw_database, None if self.raw_db else self.tenant_id) meta = MetaData() meta.reflect(bind=session.bind, schema=self.schema) with open(declarative_file, "w", encoding="utf-8") as f: generator = UTDeclarativeGenerator( - raw_database=self.raw_database if self.raw_db else f"{self.project_id}__{self.raw_database}", + raw_database=self.raw_database if self.raw_db else f"{self.tenant_id}__{self.raw_database}", metadata=meta, bind=session.bind, options=set(), @@ -69,7 +69,7 @@ def _prepare_declarative_file(self, refresh: bool = False): except Exception as e: logging.error(f"Error creating declarative file: {e}") return False - return f"{self.project_id}.{self.raw_database}_{self.schema}" + return f"{self.tenant_id}.{self.raw_database}_{self.schema}" def _get_declarative_module(self): # NOSONAR if declarative_module_path := self._prepare_declarative_file(): @@ -83,9 +83,29 @@ def _get_declarative_module(self): # NOSONAR logging.debug("Module import failed due to module creation, exiting module for force restart") try: import asyncio + import concurrent.futures + import time - loop = asyncio.get_event_loop() - loop.stop() + loop = asyncio.get_running_loop() + tasks = [t for t in asyncio.all_tasks(loop) if t is not asyncio.current_task()] + + for task in tasks: + task.cancel() + + # Use run_coroutine_threadsafe since we're in a sync function + future = asyncio.run_coroutine_threadsafe( + asyncio.gather(*tasks, return_exceptions=True), loop + ) + + # Wait for tasks to be cancelled (with timeout) + try: + future.result(timeout=5) + except concurrent.futures.TimeoutError: + logging.warning("Timeout while waiting for tasks to cancel") + + # Give the loop a moment to process the cancellations + time.sleep(0.1) + loop.call_soon_threadsafe(loop.stop) except ImportError: logging.error("Not asyncio module, stopping using sys.exit") sys.exit(1) @@ -151,104 +171,56 @@ def __init__(self) -> None: declarative_utils = {} sys.path.append(str(PathConfig.DECLARATIVES_PATH)) - def get_declarative_utils_factory( - self, raw_database: str, session_manager: SQLSessionManager, security_enabled: bool = True - ): - if security_enabled: - try: - from ut_security_util import MetaInfoSchema - - def get_declarative_utils( - meta: MetaInfoSchema, - schema: Annotated[str, Query] = PostgresConfig.PG_DEFAULT_SCHEMA, - ) -> DeclarativeUtils: - global declarative_utils - if declarative_util := declarative_utils.get(f"{raw_database}_{meta.project_id}_{schema}"): - declarative_util._pre_check() - return declarative_util - else: - declarative_util = DeclarativeUtils(raw_database, meta.project_id, session_manager, schema) - declarative_utils[f"{raw_database}_{meta.project_id}_{schema}"] = declarative_util - return declarative_util - - return get_declarative_utils - except ImportError: - logging.error("ut_security_util not installed, please install it to use security features") - raise - else: - - async def get_declarative_utils( - project_id: Annotated[str, Cookie], schema: Annotated[str, Query] = PostgresConfig.PG_DEFAULT_SCHEMA - ): - global declarative_utils - if declarative_util := declarative_utils.get(f"{raw_database}_{project_id}_{schema}"): - declarative_util._pre_check() - return declarative_util - else: - declarative_util = DeclarativeUtils(raw_database, project_id, session_manager, schema) - declarative_utils[f"{raw_database}_{project_id}_{schema}"] = declarative_util - return declarative_util - - return get_declarative_utils + def get_declarative_utils_factory(self, raw_database: str, session_manager: SQLSessionManager): + async def get_declarative_utils( + tenant_id: Annotated[str, Cookie], schema: Annotated[str, Query] = PostgresConfig.PG_DEFAULT_SCHEMA + ): + global declarative_utils + if declarative_util := declarative_utils.get(f"{raw_database}_{tenant_id}_{schema}"): + declarative_util._pre_check() + return declarative_util + else: + declarative_util = DeclarativeUtils(raw_database, tenant_id, session_manager, schema) + declarative_utils[f"{raw_database}_{tenant_id}_{schema}"] = declarative_util + return declarative_util + + return get_declarative_utils def get_schema_mandated_declarative_utils_factory( self, raw_database: str, session_manager: SQLSessionManager, schema: str, - security_enabled: bool = True, ): - if security_enabled: - try: - from ut_security_util import MetaInfoSchema - - def get_declarative_utils( - meta: MetaInfoSchema, - ) -> DeclarativeUtils: - global declarative_utils - if declarative_util := declarative_utils.get(f"{raw_database}_{meta.project_id}_{schema}"): - declarative_util._pre_check() - return declarative_util - else: - declarative_util = DeclarativeUtils(raw_database, meta.project_id, session_manager, schema) - declarative_utils[f"{raw_database}_{meta.project_id}_{schema}"] = declarative_util - return declarative_util - - return get_declarative_utils - except ImportError: - logging.error("ut_security_util not installed, please install it to use security features") - raise - else: - - async def get_declarative_utils( - project_id: Annotated[str, Cookie], - ): - global declarative_utils - if declarative_util := declarative_utils.get(f"{raw_database}_{project_id}_{schema}"): - declarative_util._pre_check() - return declarative_util - else: - declarative_util = DeclarativeUtils(raw_database, project_id, session_manager, schema) - declarative_utils[f"{raw_database}_{project_id}_{schema}"] = declarative_util - return declarative_util - - return get_declarative_utils + async def get_declarative_utils( + tenant_id: Annotated[str, Cookie], + ): + global declarative_utils + if declarative_util := declarative_utils.get(f"{raw_database}_{tenant_id}_{schema}"): + declarative_util._pre_check() + return declarative_util + else: + declarative_util = DeclarativeUtils(raw_database, tenant_id, session_manager, schema) + declarative_utils[f"{raw_database}_{tenant_id}_{schema}"] = declarative_util + return declarative_util + + return get_declarative_utils def get_declarative_utils( self, raw_database: str, session_manager: SQLSessionManager, - project_id: str, + tenant_id: str, schema: str = PostgresConfig.PG_DEFAULT_SCHEMA, raw_db: bool = False, ) -> DeclarativeUtils: global declarative_utils - if declarative_util := declarative_utils.get(f"{raw_database}_{project_id}_{schema}"): + if declarative_util := declarative_utils.get(f"{raw_database}_{tenant_id}_{schema}"): declarative_util._pre_check() return declarative_util else: - declarative_util = DeclarativeUtils(raw_database, project_id, session_manager, schema, raw_db=raw_db) - declarative_utils[f"{raw_database}_{project_id}_{schema}"] = declarative_util + declarative_util = DeclarativeUtils(raw_database, tenant_id, session_manager, schema, raw_db=raw_db) + declarative_utils[f"{raw_database}_{tenant_id}_{schema}"] = declarative_util return declarative_util diff --git a/sql_db_utils/session_management.py b/sql_db_utils/session_management.py index bf8be88..1b75988 100644 --- a/sql_db_utils/session_management.py +++ b/sql_db_utils/session_management.py @@ -1,7 +1,6 @@ import logging -from typing import Callable, Union +from typing import Annotated, Callable, Union -from redis import Redis from sqlalchemy import Engine, MetaData, NullPool, create_engine, text from sqlalchemy.exc import OperationalError from sqlalchemy.orm import Session @@ -14,15 +13,14 @@ class SQLSessionManager: - def __init__(self, redis_project_db: Union[Redis, None] = None, database_uri: Union[str, None] = None) -> None: + __slots__ = ("_db_engines", "database_uri") + + def __init__(self, database_uri: Union[str, None] = None) -> None: self._db_engines = {} - if not redis_project_db: - from sql_db_utils.redis_connections import project_db as redis_project_db - self.redis_project_source_db = redis_project_db self.database_uri = database_uri or PostgresConfig.POSTGRES_URI - def _get_fully_qualified_db(self, database: str, project_id: Union[str, None] = None) -> str: - return f"{project_id}__{database}" if project_id else database + def _get_fully_qualified_db(self, database: str, tenant_id: Union[str, None] = None) -> str: + return f"{tenant_id}__{database}" if tenant_id else database def _ensure_engine_connection(self, _engine_obj: Engine): for _ in range(PostgresConfig.PG_MAX_RETRY): @@ -39,9 +37,9 @@ def _ensure_engine_connection(self, _engine_obj: Engine): logging.error("Server connection failed") def _get_engine( - self, database: str, project_id: Union[str, None] = None, metadata: Union[MetaData, None] = None + self, database: str, tenant_id: Union[str, None] = None, metadata: Union[MetaData, None] = None ) -> Engine: - qualified_db_name = self._get_fully_qualified_db(database=database, project_id=project_id) + qualified_db_name = self._get_fully_qualified_db(database=database, tenant_id=tenant_id) if not (engine := self._db_engines.get(qualified_db_name)): logging.debug(f"Creating engine for database: {qualified_db_name}") if PostgresConfig.PG_ENABLE_POOLING: @@ -85,44 +83,30 @@ def _get_engine( def get_session( self, database: str, - project_id: Union[str, None] = None, + tenant_id: Union[str, None] = None, metadata: Union[MetaData, None] = None, retrying: bool = False, ) -> Session: if PostgresConfig.PG_RETRY_QUERY or retrying: return Session( - bind=self._get_engine(database=database, project_id=project_id, metadata=metadata), + bind=self._get_engine(database=database, tenant_id=tenant_id, metadata=metadata), future=True, query_cls=RetryingQuery, ) return Session( - bind=self._get_engine(database=database, project_id=project_id, metadata=metadata), + bind=self._get_engine(database=database, tenant_id=tenant_id, metadata=metadata), future=True, ) def get_engine_obj( - self, database: str, project_id: Union[str, None] = None, metadata: Union[MetaData, None] = None + self, database: str, tenant_id: Union[str, None] = None, metadata: Union[MetaData, None] = None ) -> Engine: - return self._get_engine(database=database, project_id=project_id, metadata=metadata) - - def get_db_factory(self, database: str, security_enabled: bool = True, retrying: bool = False) -> Callable: - if security_enabled: - try: - from ut_security_util import MetaInfoSchema - - def get_db(meta: MetaInfoSchema): - yield self.get_session(database=database, project_id=meta.project_id, retrying=retrying) + return self._get_engine(database=database, tenant_id=tenant_id, metadata=metadata) - return get_db - except ImportError: - logging.error("ut_security_util not installed, please install it to use security features") - raise - else: - from fastapi import Request + def get_db_factory(self, database: str, retrying: bool = False) -> Callable: + from fastapi import Cookie - async def get_db(request: Request): - cookies = request.cookies - project_id = cookies.get("project_id") - yield self.get_session(database=database, project_id=project_id, retrying=retrying) + async def get_db(tenant_id: Annotated[str, Cookie]): + yield self.get_session(database=database, tenant_id=tenant_id, retrying=retrying) - return get_db + return get_db