Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ migrate: clean-api ## Run database migrations; or specify a revision: `make
migrate-new: clean-api ## Autogenerate a new database migration: `make migrate-new ARGS='Description here'`
$(DOCKER_COMPOSE) run --rm -u root -w /code --entrypoint alembic api revision --autogenerate -m "$(ARGS)"

poetry-add: clean-api ## Add a poetry dependency: `make poetry-add ARGS='pytest --group dev'`
$(DOCKER_COMPOSE) run --rm -e STANDALONE=true --no-deps -u root -w /code --entrypoint poetry api add $(ARGS)
poetry-%: clean-api ## Run arbitrary poetry actions with support for optional ARGS; e.g. `make poetry-lock`
$(DOCKER_COMPOSE) run --rm -e STANDALONE=true --no-deps -u root -w /code --entrypoint poetry api $* $(ARGS)

# This ensures that even if they pass in an empty value, we default to parsing the "api" folder
ifndef FILEPATH
Expand Down
13 changes: 5 additions & 8 deletions api/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ class SomeModel(db.AlchemyBase):

from sqlalchemy import (
BigInteger,
Binary,
Boolean,
Column,
Date,
Expand Down Expand Up @@ -73,8 +72,6 @@ class SomeModel(db.AlchemyBase):
within_group,
)
from sqlalchemy.dialects.postgresql import JSONB, TIMESTAMP, UUID
from sqlalchemy.engine import RowProxy
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import (
Query,
Expand All @@ -83,6 +80,7 @@ class SomeModel(db.AlchemyBase):
backref,
contains_eager,
joinedload,
registry,
relationship,
sessionmaker,
)
Expand All @@ -106,7 +104,6 @@ class SomeModel(db.AlchemyBase):
BigInteger,
Integer,
SmallInteger,
Binary,
LargeBinary,
Boolean,
Date,
Expand Down Expand Up @@ -171,7 +168,6 @@ class SomeModel(db.AlchemyBase):
Table,
UniqueConstraint,
Query,
RowProxy,
hybrid_property,
# ORM
flag_modified,
Expand All @@ -181,8 +177,8 @@ class SomeModel(db.AlchemyBase):
)

# Setup base engine and session class
engine = create_engine(settings.postgres_url, echo=settings.debug)
SessionLocal = sessionmaker(bind=engine)
engine = create_engine(settings.postgres_url, echo=settings.debug, future=True)
SessionLocal = sessionmaker(engine)

# Setup our declarative base
meta = MetaData(
Expand All @@ -194,6 +190,7 @@ class SomeModel(db.AlchemyBase):
"pk": "pk_%(table_name)s",
}
)
AlchemyBase = declarative_base(metadata=meta)
mapper_registry = registry(metadata=meta)
AlchemyBase = mapper_registry.generate_base()

UTCTimestamp = TIMESTAMP(timezone=True)
12 changes: 6 additions & 6 deletions api/depends.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from fastapi import Depends
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from sqlalchemy import select

from .db import Session, SessionLocal
from .environment import settings
Expand Down Expand Up @@ -58,12 +59,11 @@ def get_current_user(
if user_badge is None or jwt_hex is None:
raise CredentialsException()
jwt_id = uuid.UUID(hex=jwt_hex)
current_user = session.query(User).filter(User.badge == user_badge).first()
revoked_session = (
session.query(UserRevokedToken)
.filter(UserRevokedToken.revoked_uuid == jwt_id)
.first()
)
stmt = select(User).where(User.badge == user_badge)
current_user = session.execute(stmt).scalar_one_or_none()

stmt = select(UserRevokedToken).where(UserRevokedToken.revoked_uuid == jwt_id)
revoked_session = session.execute(stmt).scalar_one_or_none()
if revoked_session or current_user is None:
raise CredentialsException()
if current_user.is_banned:
Expand Down
2 changes: 1 addition & 1 deletion api/models/card.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def type_weight(self):
@type_weight.expression
def type_weight(cls):
return db.case(
[
*[
(cls.card_type == value, index)
for index, value in enumerate(CARD_TYPE_ORDER)
],
Expand Down
9 changes: 5 additions & 4 deletions api/services/card.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import re

from sqlalchemy import select

from api import db
from api.models.card import Card, CardConjuration
from api.models.release import Release
Expand Down Expand Up @@ -190,11 +192,10 @@ def create_card(
text,
):
conjuration_stubs.add(stubify(match.group(1)))
existing_conjurations = (
session.query(Card.id, Card.stub, Card.name)
.filter(Card.stub.in_(conjuration_stubs), Card.is_legacy.is_(False))
.all()
stmt = select(Card.id, Card.stub, Card.name).where(
Card.stub.in_(conjuration_stubs), Card.is_legacy.is_(False)
)
existing_conjurations = session.execute(stmt).all()
existing_stubs = set(x.stub for x in existing_conjurations)
missing_conjurations = conjuration_stubs.symmetric_difference(existing_stubs)
if missing_conjurations:
Expand Down
105 changes: 55 additions & 50 deletions api/services/deck.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from collections import defaultdict
from operator import itemgetter

from sqlalchemy import select
from sqlalchemy.sql import Select
from starlette.requests import Request

from api import db
Expand Down Expand Up @@ -74,30 +76,28 @@ def create_or_update_deck(
# Tracks if dice or cards changed, as this necessitates resetting the export flag
needs_new_export = False
if deck_id:
deck = (
session.query(Deck)
stmt = (
select(Deck)
.options(
db.joinedload("cards"),
db.joinedload("dice"),
db.joinedload("selected_cards"),
db.joinedload(Deck.cards),
db.joinedload(Deck.dice),
db.joinedload(Deck.selected_cards),
)
.get(deck_id)
.where(Deck.id == deck_id)
)
deck = session.execute(stmt).unique().scalar_one()
deck.title = title
deck.description = description
deck.phoenixborn_id = phoenixborn.id
deck.modified = now
if deck.is_red_rains != is_red_rains:
if (
session.query(Deck)
.filter(
Deck.source_id == deck_id,
Deck.is_snapshot.is_(True),
Deck.is_public.is_(True),
Deck.is_deleted.is_(False),
)
.count()
):
stmt = select(db.func.count(Deck.id)).where(
Deck.source_id == deck_id,
Deck.is_snapshot.is_(True),
Deck.is_public.is_(True),
Deck.is_deleted.is_(False),
)
if session.execute(stmt).scalar():
raise RedRainsConversionFailed()
deck.is_red_rains = is_red_rains
else:
Expand Down Expand Up @@ -141,16 +141,16 @@ def create_or_update_deck(
if tutor_map:
card_stubs.update(tutor_map.keys())
card_stubs.update(tutor_map.values())
minimal_cards = (
session.query(Card.id, Card.stub, Card.name, Card.card_type, Card.phoenixborn)
stmt = (
select(Card.id, Card.stub, Card.name, Card.card_type, Card.phoenixborn)
.join(Card.release)
.filter(
.where(
Card.stub.in_(card_stubs),
Card.is_legacy.is_(False),
Release.is_public == True,
)
.all()
)
minimal_cards = session.execute(stmt).all()
for card in minimal_cards:
# Minimal cards could include bogus cards thanks to first_five list and similar, so fall
# back to zero to ensure this is something with a count
Expand Down Expand Up @@ -307,8 +307,7 @@ def create_snapshot_for_deck(
return snapshot


def get_decks_query(
session: db.Session,
def get_decks_stmt(
show_legacy=False,
show_red_rains=False,
is_public=False,
Expand All @@ -319,17 +318,17 @@ def get_decks_query(
cards: list[str] | None = None,
players: list[str] | None = None,
show_preconstructed=False,
) -> db.Query:
query = session.query(Deck).filter(
):
stmt = select(Deck).where(
Deck.is_legacy.is_(show_legacy),
Deck.is_deleted.is_(False),
Deck.is_red_rains.is_(show_red_rains),
)
if show_preconstructed:
query = query.filter(Deck.is_preconstructed.is_(True))
stmt = stmt.where(Deck.is_preconstructed.is_(True))
if is_public:
deck_comp = db.aliased(Deck)
query = query.outerjoin(
stmt = stmt.outerjoin(
deck_comp,
db.and_(
Deck.source_id == deck_comp.source_id,
Expand All @@ -341,40 +340,40 @@ def get_decks_query(
db.and_(Deck.created == deck_comp.created, Deck.id < deck_comp.id),
),
),
).filter(
).where(
deck_comp.id.is_(None), Deck.is_snapshot.is_(True), Deck.is_public.is_(True)
)
else:
query = query.filter(Deck.is_snapshot.is_(False))
stmt = stmt.where(Deck.is_snapshot.is_(False))
if q and q.strip():
query = query.filter(
stmt = stmt.where(
db.func.to_tsvector("english", db.cast(Deck.title, db.Text)).match(
to_prefixed_tsquery(q)
)
)
# Filter by Phoenixborn stubs (this is always an OR comparison between Phoenixborn)
if phoenixborn:
query = query.join(Card, Card.id == Deck.phoenixborn_id).filter(
stmt = stmt.join(Card, Card.id == Deck.phoenixborn_id).where(
Card.stub.in_(phoenixborn)
)
# Filter by cards (this is always an OR comparison between cards)
if cards:
card_table = db.aliased(Card)
query = (
query.join(DeckCard, DeckCard.deck_id == Deck.id)
stmt = (
stmt.join(DeckCard, DeckCard.deck_id == Deck.id)
.join(card_table, card_table.id == DeckCard.card_id)
.filter(card_table.stub.in_(cards))
.where(card_table.stub.in_(cards))
)
# Filter by player badge, and always ensure that we eagerly load the user object
if players:
query = (
query.join(User, User.id == Deck.user_id)
.filter(User.badge.in_(players))
stmt = (
stmt.join(User, User.id == Deck.user_id)
.where(User.badge.in_(players))
.options(db.contains_eager(Deck.user))
)
else:
query = query.options(db.joinedload(Deck.user))
return query.order_by(getattr(Deck.created, order)())
stmt = stmt.options(db.joinedload(Deck.user))
return stmt.order_by(getattr(Deck.created, order)())


def add_conjurations(card_id_to_conjuration_mapping, root_card_id, conjuration_set):
Expand All @@ -399,12 +398,12 @@ def add_conjurations(card_id_to_conjuration_mapping, root_card_id, conjuration_s

def get_conjuration_mapping(session: db.Session, card_ids: set[int]) -> dict:
"""Gathers top-level conjurations into a mapping keyed off the root card ID"""
conjuration_results = (
session.query(Card, CardConjuration.card_id.label("root_card"))
stmt = (
select(Card, CardConjuration.card_id.label("root_card"))
.join(CardConjuration, Card.id == CardConjuration.conjuration_id)
.filter(CardConjuration.card_id.in_(card_ids))
.all()
.where(CardConjuration.card_id.in_(card_ids))
)
conjuration_results = session.execute(stmt).all()
card_id_to_conjurations = defaultdict(list)
for result in conjuration_results:
card_id_to_conjurations[result.root_card].append(result.Card)
Expand Down Expand Up @@ -507,7 +506,7 @@ def generate_deck_dict(


def paginate_deck_listing(
query: db.Query,
stmt: Select,
session: db.Session,
request: Request,
paging: PaginationOptions,
Expand All @@ -516,7 +515,7 @@ def paginate_deck_listing(
"""Generates a paginated deck listing using as few queries as possible."""
# Gather our paginated results
output = paginated_results_for_query(
query=query, paging=paging, url=str(request.url)
session=session, stmt=stmt, paging=paging, url=str(request.url)
)
# Parse through the decks so that we can load their cards en masse with a single query
deck_ids = set()
Expand All @@ -526,12 +525,14 @@ def paginate_deck_listing(
# Ensure we lookup our Phoenixborn cards
needed_cards.add(deck_row.phoenixborn_id)
# Fetch and collate our dice information for all decks
deck_dice = session.query(DeckDie).filter(DeckDie.deck_id.in_(deck_ids)).all()
deckdie_stmt = select(DeckDie).where(DeckDie.deck_id.in_(deck_ids))
deck_dice = session.execute(deckdie_stmt).scalars().all()
deck_id_to_dice = defaultdict(list)
for deck_die in deck_dice:
deck_id_to_dice[deck_die.deck_id].append(deck_die)
# Now that we have all our basic deck information, look up the cards and quantities they include
deck_cards = session.query(DeckCard).filter(DeckCard.deck_id.in_(deck_ids)).all()
deckcard_stmt = select(DeckCard).where(DeckCard.deck_id.in_(deck_ids))
deck_cards = session.execute(deckcard_stmt).scalars().all()
deck_id_to_deck_cards = defaultdict(list)
for deck_card in deck_cards:
needed_cards.add(deck_card.card_id)
Expand All @@ -541,7 +542,8 @@ def paginate_deck_listing(
session=session, card_ids=needed_cards
)
# Now that we have root-level conjurations, we can gather all our cards and setup our decks
cards = session.query(Card).filter(Card.id.in_(needed_cards)).all()
card_stmt = select(Card).where(Card.id.in_(needed_cards))
cards = session.execute(card_stmt).scalars().all()
card_id_to_card = {x.id: x for x in cards}
deck_output = []
for deck in output["results"]:
Expand All @@ -568,16 +570,19 @@ def deck_to_dict(
"""Converts a Deck object into an output dict using as few queries as possible."""
needed_cards = set()
needed_cards.add(deck.phoenixborn_id)
deck_cards = session.query(DeckCard).filter(DeckCard.deck_id == deck.id).all()
stmt = select(DeckCard).where(DeckCard.deck_id == deck.id)
deck_cards = session.execute(stmt).scalars().all()
for deck_card in deck_cards:
needed_cards.add(deck_card.card_id)
deck_dice = session.query(DeckDie).filter(DeckDie.deck_id == deck.id).all()
stmt = select(DeckDie).where(DeckDie.deck_id == deck.id)
deck_dice = session.execute(stmt).scalars().all()
# And finally we need to fetch all top-level conjurations
card_id_to_conjurations = get_conjuration_mapping(
session=session, card_ids=needed_cards
)
# Now that we have root-level conjurations, we can gather all our cards and generate deck output
cards = session.query(Card).filter(Card.id.in_(needed_cards)).all()
stmt = select(Card).where(Card.id.in_(needed_cards))
cards = session.execute(stmt).scalars().all()
card_id_to_card = {x.id: x for x in cards}
deck_dict = generate_deck_dict(
deck=deck,
Expand Down
Loading