Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ repos:
- id: trailing-whitespace
- id: requirements-txt-fixer
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.9.7
rev: v0.11.0
hooks:
- id: ruff
args:
Expand Down
2 changes: 1 addition & 1 deletion sql_db_utils/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.0.0"
__version__ = "1.0.1"
155 changes: 57 additions & 98 deletions sql_db_utils/asyncio/declarative_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,18 @@ class DeclarativeUtils:
"""

async def __new__(
cls, raw_database: str, project_id: str, session_manager: SQLSessionManager, schema: str, raw_db: bool = False
cls, raw_database: str, tenant_id: str, session_manager: SQLSessionManager, schema: str, raw_db: bool = False
) -> None:
obj = super().__new__(cls)
obj.__init__(raw_database, project_id, session_manager, schema, raw_db)
obj.__init__(raw_database, tenant_id, session_manager, schema, raw_db)
await obj._get_declarative_module()
return obj

def __init__(
self, raw_database: str, project_id: str, session_manager: SQLSessionManager, schema: str, raw_db: bool = False
self, raw_database: str, tenant_id: str, session_manager: SQLSessionManager, schema: str, raw_db: bool = False
) -> None:
self.raw_database: str = raw_database
self.project_id: str = project_id
self.tenant_id: str = tenant_id
self.session_manager: SQLSessionManager = session_manager
self.raw_db = raw_db
self.schema = schema
Expand All @@ -43,29 +43,27 @@ async def _pre_check(self):
await self._get_declarative_module()

async def _prepare_declarative_file(self, refresh: bool = False):
declarative_project_directory = PathConfig.DECLARATIVES_PATH / self.project_id
if not declarative_project_directory.exists():
declarative_project_directory.mkdir(parents=True)
project_init_file = declarative_project_directory / "__init__.py"
if not project_init_file.exists():
with open(project_init_file, "w") as f:
declarative_tenant_directory = PathConfig.DECLARATIVES_PATH / self.tenant_id
if not declarative_tenant_directory.exists():
declarative_tenant_directory.mkdir(parents=True)
tenant_init_file = declarative_tenant_directory / "__init__.py"
if not tenant_init_file.exists():
with open(tenant_init_file, "w") as f:
f.write("")
declarative_file = declarative_project_directory / f"async_{self.raw_database}_{self.schema}.py"
declarative_file = declarative_tenant_directory / f"async_{self.raw_database}_{self.schema}.py"
if declarative_file.exists() and ModuleConfig.DEFER_GEN_REFRESH and not refresh:
return f"{self.project_id}.async_{self.raw_database}_{self.schema}"
return f"{self.tenant_id}.async_{self.raw_database}_{self.schema}"
try:
logging.debug(f"Attempting to create declarative file: {declarative_file}")
from sql_db_utils.asyncio.codegen import UTDeclarativeGenerator

session = await self.session_manager.get_session(
self.raw_database, None if self.raw_db else self.project_id
)
session = await self.session_manager.get_session(self.raw_database, None if self.raw_db else self.tenant_id)
meta = MetaData()
async with session.bind.begin() as conn:
await conn.run_sync(meta.reflect, schema=self.schema)
with open(declarative_file, "w", encoding="utf-8") as f:
generator = UTDeclarativeGenerator(
raw_database=self.raw_database if self.raw_db else f"{self.project_id}__{self.raw_database}",
raw_database=self.raw_database if self.raw_db else f"{self.tenant_id}__{self.raw_database}",
metadata=meta,
bind=session.bind,
options=set(),
Expand All @@ -79,7 +77,7 @@ async def _prepare_declarative_file(self, refresh: bool = False):
except Exception as e:
logging.error(f"Error creating declarative file: {e}")
return False
return f"{self.project_id}.async_{self.raw_database}_{self.schema}"
return f"{self.tenant_id}.async_{self.raw_database}_{self.schema}"

async def _get_declarative_module(self): # NOSONAR
if declarative_module_path := await self._prepare_declarative_file():
Expand All @@ -97,8 +95,20 @@ async def _get_declarative_module(self): # NOSONAR
try:
import asyncio

loop = asyncio.get_event_loop()
loop.stop()
logging.warning("Emergency shutdown required - gracefully canceling tasks")
loop = asyncio.get_running_loop()
tasks = [t for t in asyncio.all_tasks(loop) if t is not asyncio.current_task()]
logging.debug(f"Canceling {len(tasks)} pending tasks")

for task in tasks:
task.cancel()

# Wait for all tasks to complete with cancellation
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)

logging.info("Tasks gracefully canceled, exiting")
sys.exit(1)
except ImportError:
logging.error("Not asyncio module, stopping using sys.exit")
sys.exit(1)
Comment thread
faizanazim11 marked this conversation as resolved.
Expand Down Expand Up @@ -169,105 +179,54 @@ def get_declarative_utils_factory(
self,
raw_database: str,
session_manager: SQLSessionManager,
security_enabled: bool = True,
):
if security_enabled:
try:
from ut_security_util import MetaInfoSchema

async def get_declarative_utils(
meta: MetaInfoSchema,
schema: Annotated[str, Query] = PostgresConfig.PG_DEFAULT_SCHEMA,
) -> DeclarativeUtils:
global declarative_utils
if declarative_util := declarative_utils.get(f"{raw_database}_{meta.project_id}_{schema}"):
await declarative_util._pre_check()
return declarative_util
else:
declarative_util = await DeclarativeUtils(
raw_database, meta.project_id, session_manager, schema
)
declarative_utils[f"{raw_database}_{meta.project_id}_{schema}"] = declarative_util
return declarative_util

return get_declarative_utils
except ImportError:
logging.error("ut_security_util not installed, please install it to use security features")
raise
else:

async def get_declarative_utils(
project_id: Annotated[str, Cookie], schema: Annotated[str, Query] = PostgresConfig.PG_DEFAULT_SCHEMA
) -> DeclarativeUtils:
global declarative_utils
if declarative_util := declarative_utils.get(f"{raw_database}_{project_id}_{schema}"):
await declarative_util._pre_check()
return declarative_util
else:
declarative_util = await DeclarativeUtils(raw_database, project_id, session_manager, schema)
declarative_utils[f"{raw_database}_{project_id}_{schema}"] = declarative_util
return declarative_util

return get_declarative_utils
async def get_declarative_utils(
tenant_id: Annotated[str, Cookie], schema: Annotated[str, Query] = PostgresConfig.PG_DEFAULT_SCHEMA
) -> DeclarativeUtils:
global declarative_utils
if declarative_util := declarative_utils.get(f"{raw_database}_{tenant_id}_{schema}"):
await declarative_util._pre_check()
return declarative_util
else:
declarative_util = await DeclarativeUtils(raw_database, tenant_id, session_manager, schema)
declarative_utils[f"{raw_database}_{tenant_id}_{schema}"] = declarative_util
return declarative_util

return get_declarative_utils

def get_schema_mandated_declarative_utils_factory(
self,
raw_database: str,
session_manager: SQLSessionManager,
schema: str,
security_enabled: bool = True,
):
if security_enabled:
try:
from ut_security_util import MetaInfoSchema

async def get_declarative_utils(
meta: MetaInfoSchema,
) -> DeclarativeUtils:
global declarative_utils
if declarative_util := declarative_utils.get(f"{raw_database}_{meta.project_id}_{schema}"):
await declarative_util._pre_check()
return declarative_util
else:
declarative_util = await DeclarativeUtils(
raw_database, meta.project_id, session_manager, schema
)
declarative_utils[f"{raw_database}_{meta.project_id}_{schema}"] = declarative_util
return declarative_util

return get_declarative_utils
except ImportError:
logging.error("ut_security_util not installed, please install it to use security features")
raise
else:

async def get_declarative_utils(project_id: Annotated[str, Cookie]) -> DeclarativeUtils:
global declarative_utils
if declarative_util := declarative_utils.get(f"{raw_database}_{project_id}_{schema}"):
await declarative_util._pre_check()
return declarative_util
else:
declarative_util = await DeclarativeUtils(raw_database, project_id, session_manager, schema)
declarative_utils[f"{raw_database}_{project_id}_{schema}"] = declarative_util
return declarative_util

return get_declarative_utils
async def get_declarative_utils(tenant_id: Annotated[str, Cookie]) -> DeclarativeUtils:
global declarative_utils
if declarative_util := declarative_utils.get(f"{raw_database}_{tenant_id}_{schema}"):
await declarative_util._pre_check()
return declarative_util
else:
declarative_util = await DeclarativeUtils(raw_database, tenant_id, session_manager, schema)
declarative_utils[f"{raw_database}_{tenant_id}_{schema}"] = declarative_util
return declarative_util

return get_declarative_utils

async def get_declarative_utils(
self,
raw_database: str,
project_id: str,
tenant_id: str,
session_manager: SQLSessionManager,
schema: str = PostgresConfig.PG_DEFAULT_SCHEMA,
raw_db: bool = False,
) -> DeclarativeUtils:
global declarative_utils
if declarative_util := declarative_utils.get(f"{raw_database}_{project_id}_{schema}"):
if declarative_util := declarative_utils.get(f"{raw_database}_{tenant_id}_{schema}"):
await declarative_util._pre_check()
return declarative_util
else:
declarative_util = await DeclarativeUtils(raw_database, project_id, session_manager, schema, raw_db)
declarative_utils[f"{raw_database}_{project_id}_{schema}"] = declarative_util
declarative_util = await DeclarativeUtils(raw_database, tenant_id, session_manager, schema, raw_db)
declarative_utils[f"{raw_database}_{tenant_id}_{schema}"] = declarative_util
return declarative_util


Expand Down
54 changes: 18 additions & 36 deletions sql_db_utils/asyncio/session_management.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import logging
from typing import Any, AsyncGenerator, Union
from typing import Annotated, Any, AsyncGenerator, Union

from redis import Redis
from sqlalchemy import Engine, MetaData, NullPool, text
from sqlalchemy.exc import OperationalError
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
Expand All @@ -14,19 +13,18 @@


class SQLSessionManager:
def __init__(self, redis_project_db: Union[Redis, None] = None, database_uri: Union[str, None] = None) -> None:
__slots__ = ("_db_engines", "database_uri")

def __init__(self, database_uri: Union[str, None] = None) -> None:
self._db_engines = {}
if not redis_project_db:
from sql_db_utils.redis_connections import project_db as redis_project_db
self.redis_project_source_db = redis_project_db
self.database_uri = database_uri or PostgresConfig.POSTGRES_URI

def __del__(self) -> None:
for engine in self._db_engines.values():
engine.dispose()

def _get_fully_qualified_db(self, database: str, project_id: Union[str, None] = None) -> str:
return f"{project_id}__{database}" if project_id else database
def _get_fully_qualified_db(self, database: str, tenant_id: Union[str, None] = None) -> str:
return f"{tenant_id}__{database}" if tenant_id else database

async def _ensure_engine_connection(self, _engine_obj: Engine):
for _ in range(PostgresConfig.PG_MAX_RETRY):
Expand All @@ -43,9 +41,9 @@ async def _ensure_engine_connection(self, _engine_obj: Engine):
logging.error("Server connection failed")

async def _get_engine(
self, database: str, project_id: Union[str, None] = None, metadata: Union[MetaData, None] = None
self, database: str, tenant_id: Union[str, None] = None, metadata: Union[MetaData, None] = None
) -> AsyncSession:
qualified_db_name = self._get_fully_qualified_db(database=database, project_id=project_id)
qualified_db_name = self._get_fully_qualified_db(database=database, tenant_id=tenant_id)
if not (engine := self._db_engines.get(qualified_db_name)):
logging.debug(f"Creating engine for database: {qualified_db_name}")
if PostgresConfig.PG_ENABLE_POOLING:
Expand Down Expand Up @@ -89,47 +87,31 @@ async def _get_engine(
async def get_session(
self,
database: str,
project_id: Union[str, None] = None,
tenant_id: Union[str, None] = None,
metadata: Union[MetaData, None] = None,
retrying: bool = False,
) -> AsyncSession:
if PostgresConfig.PG_RETRY_QUERY or retrying:
return AsyncSession(
bind=self._get_engine(database=database, project_id=project_id, metadata=metadata),
bind=self._get_engine(database=database, tenant_id=tenant_id, metadata=metadata),
future=True,
query_cls=RetryingQuery,
)
return AsyncSession(
bind=await self._get_engine(database=database, project_id=project_id, metadata=metadata),
bind=await self._get_engine(database=database, tenant_id=tenant_id, metadata=metadata),
expire_on_commit=False,
future=True,
)

async def get_engine_obj(
self, database: str, project_id: Union[str, None] = None, metadata: Union[MetaData, None] = None
self, database: str, tenant_id: Union[str, None] = None, metadata: Union[MetaData, None] = None
) -> AsyncEngine:
return await self._get_engine(database=database, project_id=project_id, metadata=metadata)

def get_db_factory(
self, database: str, security_enabled: bool = True, retrying: bool = False
) -> AsyncGenerator[AsyncSession, Any]:
if security_enabled:
try:
from ut_security_util import MetaInfoSchema

async def get_db(meta: MetaInfoSchema):
yield await self.get_session(database=database, project_id=meta.project_id, retrying=retrying)
return await self._get_engine(database=database, tenant_id=tenant_id, metadata=metadata)

return get_db
except ImportError:
logging.error("ut_security_util not installed, please install it to use security features")
raise
else:
from fastapi import Request
def get_db_factory(self, database: str, retrying: bool = False) -> AsyncGenerator[AsyncSession, Any]:
from fastapi import Cookie

async def get_db(request: Request):
cookies = request.cookies
project_id = cookies.get("project_id")
yield await self.get_session(database=database, project_id=project_id, retrying=retrying)
async def get_db(tenant_id: Annotated[str, Cookie]):
yield await self.get_session(database=database, tenant_id=tenant_id, retrying=retrying)

return get_db
return get_db
8 changes: 1 addition & 7 deletions sql_db_utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,6 @@ def validate_my_field(cls, value):
return value


class _RedisConfig(BaseSettings):
REDIS_URI: str
REDIS_PROJECT_TAGS_DB: int = 18


class _BasePathConf(BaseSettings):
BASE_PATH: str = "/code/data"

Expand All @@ -64,7 +59,6 @@ def validate_paths(self) -> Self:

ModuleConfig = _ModuleConfig()
PostgresConfig = _PostgresConfig()
RedisConfig = _RedisConfig()
PathConfig = _PathConf()

__all__ = ["ModuleConfig", "PostgresConfig", "RedisConfig", "PathConfig"]
__all__ = ["ModuleConfig", "PostgresConfig", "PathConfig"]
6 changes: 2 additions & 4 deletions sql_db_utils/constants.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from enum import StrEnum

ENUMPARENT = (StrEnum,)


class QueryType(*ENUMPARENT):
class QueryType(StrEnum):
"""
An enumeration representing the different types of queries that can be executed.

Expand All @@ -18,7 +16,7 @@ class QueryType(*ENUMPARENT):
POLAR = "polars"


class AGGridDateTrim(*ENUMPARENT):
class AGGridDateTrim(StrEnum):
"""
An enumeration representing the different date trimming options

Expand Down
Loading