Skip to content

Commit 2ff5062

Browse files
committed
refactor code to add typing and fix linting issues
Signed-off-by: Grant Ramsay <seapagan@gmail.com>
1 parent d39565b commit 2ff5062

File tree

3 files changed

+50
-24
lines changed

3 files changed

+50
-24
lines changed

db.py

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,57 @@
11
"""Set up the database connection and session.""" ""
2-
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
3-
from sqlalchemy.orm import declarative_base
2+
from collections.abc import AsyncGenerator
3+
from typing import Any
4+
5+
from sqlalchemy import MetaData
6+
from sqlalchemy.ext.asyncio import (
7+
AsyncSession,
8+
async_sessionmaker,
9+
create_async_engine,
10+
)
11+
from sqlalchemy.orm import DeclarativeBase
412

513
DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost/postgres"
6-
# DATABASE_URL = "sqlite+aiosqlite:///./test.db"
14+
# DATABASE_URL = "sqlite+aiosqlite:///./test.db" # noqa: ERA001
715
# Note that (as far as I can tell from the docs and searching) there is no need
816
# to add 'check_same_thread=False' to the sqlite connection string, as
917
# SQLAlchemy version 1.4+ will automatically add it for you when using SQLite.
1018

11-
engine = create_async_engine(DATABASE_URL, echo=False)
12-
Base = declarative_base()
13-
async_session = async_sessionmaker(engine, expire_on_commit=False)
19+
20+
class Base(DeclarativeBase):
21+
"""Base class for SQLAlchemy models.
22+
23+
All other models should inherit from this class.
24+
"""
25+
26+
metadata = MetaData(
27+
naming_convention={
28+
"ix": "ix_%(column_0_label)s",
29+
"uq": "uq_%(table_name)s_%(column_0_name)s",
30+
"ck": "ck_%(table_name)s_%(constraint_name)s",
31+
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
32+
"pk": "pk_%(table_name)s",
33+
}
34+
)
35+
36+
37+
async_engine = create_async_engine(DATABASE_URL, echo=False)
38+
async_session = async_sessionmaker(async_engine, expire_on_commit=False)
1439

1540

16-
async def get_db():
41+
async def get_db() -> AsyncGenerator[AsyncSession, Any]:
1742
"""Get a database session.
1843
1944
To be used for dependency injection.
2045
"""
21-
async with async_session() as session:
22-
async with session.begin():
23-
yield session
46+
async with async_session() as session, session.begin():
47+
yield session
2448

2549

26-
async def init_models():
50+
async def init_models() -> None:
2751
"""Create tables if they don't already exist.
2852
2953
In a real-life example we would use Alembic to manage migrations.
3054
"""
31-
async with engine.begin() as conn:
32-
# await conn.run_sync(Base.metadata.drop_all)
55+
async with async_engine.begin() as conn:
56+
# await conn.run_sync(Base.metadata.drop_all) # noqa: ERA001
3357
await conn.run_sync(Base.metadata.create_all)

main.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
"""An example of using FastAPI with Async SQLAlchemy 2."""
2+
from collections.abc import AsyncGenerator, Sequence
23
from contextlib import asynccontextmanager
4+
from typing import Any
35

46
import uvicorn
5-
from fastapi import Depends, FastAPI
6-
from sqlalchemy import select
7-
87
from db import get_db, init_models
8+
from fastapi import Depends, FastAPI
99
from models import User
10+
from sqlalchemy import select
11+
from sqlalchemy.ext.asyncio import AsyncSession
1012

1113

1214
@asynccontextmanager
13-
async def lifespan(app: FastAPI):
15+
async def lifespan(app: FastAPI) -> AsyncGenerator[Any, None]: # noqa: ARG001
1416
"""Run tasks before and after the server starts."""
1517
await init_models()
1618
yield
@@ -20,25 +22,26 @@ async def lifespan(app: FastAPI):
2022

2123

2224
@app.get("/")
23-
async def root():
25+
async def root() -> dict[str, str]:
2426
"""Root endpoint."""
2527
return {"message": "Test API for FastAPI and Async SQLAlchemy ."}
2628

2729

2830
@app.post("/users/")
29-
async def create_user(name: str, email: str, session=Depends(get_db)):
31+
async def create_user(
32+
name: str, email: str, session: AsyncSession = Depends(get_db)
33+
) -> User:
3034
"""Add a user."""
3135
user = User(name=name, email=email)
3236
session.add(user)
3337
return user
3438

3539

3640
@app.get("/users/")
37-
async def get_users(session=Depends(get_db)):
41+
async def get_users(session: AsyncSession = Depends(get_db)) -> Sequence[User]:
3842
"""Get all users."""
3943
result = await session.execute(select(User))
40-
users = result.scalars().all()
41-
return users
44+
return result.scalars().all()
4245

4346

4447
if __name__ == "__main__":

models.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Define Models used in this example."""
2-
from sqlalchemy import Column, Integer, String
3-
42
from db import Base
3+
from sqlalchemy import Column, Integer, String
54

65

76
class User(Base):

0 commit comments

Comments
 (0)