|
1 | 1 | """Set up the database connection and session.""" "" |
2 | | -from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine |
3 | | -from sqlalchemy.orm import declarative_base |
| 2 | +from collections.abc import AsyncGenerator |
| 3 | +from typing import Any |
| 4 | + |
| 5 | +from sqlalchemy import MetaData |
| 6 | +from sqlalchemy.ext.asyncio import ( |
| 7 | + AsyncSession, |
| 8 | + async_sessionmaker, |
| 9 | + create_async_engine, |
| 10 | +) |
| 11 | +from sqlalchemy.orm import DeclarativeBase |
4 | 12 |
|
5 | 13 | DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost/postgres" |
6 | | -# DATABASE_URL = "sqlite+aiosqlite:///./test.db" |
| 14 | +# DATABASE_URL = "sqlite+aiosqlite:///./test.db" # noqa: ERA001 |
7 | 15 | # Note that (as far as I can tell from the docs and searching) there is no need |
8 | 16 | # to add 'check_same_thread=False' to the sqlite connection string, as |
9 | 17 | # SQLAlchemy version 1.4+ will automatically add it for you when using SQLite. |
10 | 18 |
|
11 | | -engine = create_async_engine(DATABASE_URL, echo=False) |
12 | | -Base = declarative_base() |
13 | | -async_session = async_sessionmaker(engine, expire_on_commit=False) |
| 19 | + |
| 20 | +class Base(DeclarativeBase): |
| 21 | + """Base class for SQLAlchemy models. |
| 22 | +
|
| 23 | + All other models should inherit from this class. |
| 24 | + """ |
| 25 | + |
| 26 | + metadata = MetaData( |
| 27 | + naming_convention={ |
| 28 | + "ix": "ix_%(column_0_label)s", |
| 29 | + "uq": "uq_%(table_name)s_%(column_0_name)s", |
| 30 | + "ck": "ck_%(table_name)s_%(constraint_name)s", |
| 31 | + "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", |
| 32 | + "pk": "pk_%(table_name)s", |
| 33 | + } |
| 34 | + ) |
| 35 | + |
| 36 | + |
| 37 | +async_engine = create_async_engine(DATABASE_URL, echo=False) |
| 38 | +async_session = async_sessionmaker(async_engine, expire_on_commit=False) |
14 | 39 |
|
15 | 40 |
|
16 | | -async def get_db(): |
| 41 | +async def get_db() -> AsyncGenerator[AsyncSession, Any]: |
17 | 42 | """Get a database session. |
18 | 43 |
|
19 | 44 | To be used for dependency injection. |
20 | 45 | """ |
21 | | - async with async_session() as session: |
22 | | - async with session.begin(): |
23 | | - yield session |
| 46 | + async with async_session() as session, session.begin(): |
| 47 | + yield session |
24 | 48 |
|
25 | 49 |
|
26 | | -async def init_models(): |
| 50 | +async def init_models() -> None: |
27 | 51 | """Create tables if they don't already exist. |
28 | 52 |
|
29 | 53 | In a real-life example we would use Alembic to manage migrations. |
30 | 54 | """ |
31 | | - async with engine.begin() as conn: |
32 | | - # await conn.run_sync(Base.metadata.drop_all) |
| 55 | + async with async_engine.begin() as conn: |
| 56 | + # await conn.run_sync(Base.metadata.drop_all) # noqa: ERA001 |
33 | 57 | await conn.run_sync(Base.metadata.create_all) |
0 commit comments