Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ jobs:
run: alembic upgrade head && alembic check

- name: Upload coverage to Codecov
uses: codecov/codecov-action@v6
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6
with:
files: server/coverage.xml
flags: backend
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ jobs:
EOF

- name: Create release
uses: softprops/action-gh-release@v2
uses: softprops/action-gh-release@153bb8e04406b158c6c84fc1615b65b24149a1fe # v2
with:
tag_name: ${{ github.ref_name }}
name: "WrzDJ ${{ github.ref_name }}"
Expand Down
27 changes: 27 additions & 0 deletions server/alembic/versions/033_add_user_token_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""Add token_version column to users table.

SECURITY (CRIT-2): enables JWT revocation. Every JWT carries a `tv` claim
that must match the user's token_version. Bumping the version (via logout
or admin action) invalidates all outstanding tokens for that user.

Revision ID: 033
Revises: 032
"""

import sqlalchemy as sa

from alembic import op

revision = "033"
down_revision = "032"


def upgrade() -> None:
op.add_column(
"users",
sa.Column("token_version", sa.Integer(), nullable=False, server_default="0"),
)


def downgrade() -> None:
op.drop_column("users", "token_version")
19 changes: 18 additions & 1 deletion server/app/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,27 @@ def login(
if settings.is_lockout_enabled:
lockout_manager.record_success(client_ip, username)

access_token = create_access_token(data={"sub": user.username})
access_token = create_access_token(data={"sub": user.username, "tv": user.token_version})
return Token(access_token=access_token)


@router.post("/logout", response_model=StatusMessageResponse)
@limiter.limit("30/minute")
def logout(
request: Request,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> StatusMessageResponse:
"""Invalidate all outstanding JWTs for the current user.

SECURITY (CRIT-2): bumps token_version so every previously-issued JWT
for this user fails the version check in get_current_user.
"""
current_user.token_version += 1
db.commit()
return StatusMessageResponse(status="ok", message="Logged out")


@router.get("/me", response_model=UserOut)
@limiter.limit("60/minute")
def get_me(request: Request, current_user: User = Depends(get_current_user)) -> User:
Expand Down
3 changes: 3 additions & 0 deletions server/app/api/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ def get_current_user(db: Session = Depends(get_db), token: str = Depends(oauth2_
raise credentials_exception
if not user.is_active:
raise HTTPException(status_code=400, detail="Inactive user")
# CRIT-2: reject tokens whose version doesn't match the user's current version
if token_data.token_version != user.token_version:
raise credentials_exception
return user


Expand Down
19 changes: 19 additions & 0 deletions server/app/api/kiosk.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,19 @@ def _resolve_event_name(db: Session, event_code: str | None) -> str | None:
return event.name if event else None


def _assert_caller_owns_event(event: Event, user: User) -> None:
"""Enforce that the caller owns the target event (or is an admin).

SECURITY (CRIT-3, CRIT-4): before this check, any DJ could pair or
reassign a kiosk to an event owned by another DJ by supplying the
victim's event code. See docs/security/audit-2026-04-08.md.
"""
if user.role == "admin":
return
if event.created_by_user_id != user.id:
raise HTTPException(status_code=403, detail="You do not own this event")


# ── Public endpoints ───────────────────────────────────────────────────


Expand Down Expand Up @@ -134,6 +147,9 @@ def complete_kiosk_pairing(
if not event:
raise HTTPException(status_code=404, detail="Event not found")

# CRIT-3: caller must own the target event (or be admin)
_assert_caller_owns_event(event, current_user)

try:
complete_pairing(db, kiosk, event.code, current_user.id)
except ValueError as e:
Expand Down Expand Up @@ -204,6 +220,9 @@ def assign_kiosk(
if not event:
raise HTTPException(status_code=404, detail="Event not found")

# CRIT-4: caller must own the target event (or be admin)
_assert_caller_owns_event(event, current_user)

assign_kiosk_event(db, kiosk, event.code)
return KioskOut(
id=kiosk.id,
Expand Down
28 changes: 25 additions & 3 deletions server/app/api/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@
import logging
from typing import Any

from fastapi import APIRouter, Request
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy.orm import Session
from sse_starlette.sse import EventSourceResponse

from app.api.deps import get_db
from app.core.rate_limit import limiter
from app.services.event import EventLookupResult, get_event_by_code_with_status
from app.services.event_bus import get_event_bus

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -46,18 +50,36 @@ async def _event_generator(


@router.get("/events/{code}/stream")
async def event_stream(code: str, request: Request) -> EventSourceResponse:
@limiter.limit("10/minute")
async def event_stream(
code: str,
request: Request,
db: Session = Depends(get_db),
) -> EventSourceResponse:
"""Public SSE endpoint for real-time event updates.

SECURITY (CRIT-5): rate-limited and existence-checked. Before this fix,
the endpoint had no rate limit and no existence check, allowing
unauthenticated DoS (unlimited long-lived connections exhausting FDs)
and passive eavesdropping via 6-char event-code brute force.

Event types:
- request_created: New request submitted
- request_status_changed: Request status update
- now_playing_changed: Now-playing track update
- requests_bulk_update: Batch accept/reject
- bridge_status_changed: Bridge connect/disconnect
"""
event, result = get_event_by_code_with_status(db, code)
if result == EventLookupResult.NOT_FOUND:
raise HTTPException(status_code=404, detail="Event not found")
if result == EventLookupResult.ARCHIVED:
raise HTTPException(status_code=410, detail="Event has been archived")
if result == EventLookupResult.EXPIRED:
raise HTTPException(status_code=410, detail="Event has expired")

return EventSourceResponse(
_event_generator(request, code),
_event_generator(request, event.code),
media_type="text/event-stream",
headers={"X-Accel-Buffering": "no"},
)
8 changes: 7 additions & 1 deletion server/app/models/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from datetime import datetime
from enum import Enum

from sqlalchemy import Boolean, DateTime, String, Text
from sqlalchemy import Boolean, DateTime, Integer, String, Text
from sqlalchemy.orm import Mapped, mapped_column, relationship

from app.core.encryption import EncryptedText
Expand All @@ -27,6 +27,12 @@ class User(Base):
email: Mapped[str | None] = mapped_column(String(255), unique=True, nullable=True, index=True)
created_at: Mapped[datetime] = mapped_column(DateTime, default=utcnow)

# SECURITY (CRIT-2): JWT revocation — bumped on logout or admin force-revoke.
# Every JWT carries a `tv` claim that must match this value on decode.
token_version: Mapped[int] = mapped_column(
Integer, nullable=False, default=0, server_default="0"
)

# Tidal OAuth tokens (encrypted at rest via Fernet)
tidal_access_token: Mapped[str | None] = mapped_column(EncryptedText, nullable=True)
tidal_refresh_token: Mapped[str | None] = mapped_column(EncryptedText, nullable=True)
Expand Down
1 change: 1 addition & 0 deletions server/app/schemas/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ class Token(BaseModel):

class TokenData(BaseModel):
username: str | None = None
token_version: int = 0
14 changes: 11 additions & 3 deletions server/app/services/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@

settings = get_settings()

# SECURITY (CRIT-1): JWT algorithm is a security invariant and must NEVER be
# sourced from config. Hardcoding prevents an `alg=none` confusion attack via
# env-var manipulation. See docs/security/audit-2026-04-08.md CRIT-1.
_JWT_ALGORITHM = "HS256"


def verify_password(plain_password: str, hashed_password: str) -> bool:
return bcrypt.checkpw(plain_password.encode("utf-8"), hashed_password.encode("utf-8"))
Expand All @@ -26,17 +31,20 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None) -> s
else:
expire = datetime.now(UTC) + timedelta(minutes=settings.jwt_expire_minutes)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, settings.jwt_secret, algorithm=settings.jwt_algorithm)
encoded_jwt = jwt.encode(to_encode, settings.jwt_secret, algorithm=_JWT_ALGORITHM)
return encoded_jwt


def decode_token(token: str) -> TokenData | None:
try:
payload = jwt.decode(token, settings.jwt_secret, algorithms=[settings.jwt_algorithm])
payload = jwt.decode(token, settings.jwt_secret, algorithms=[_JWT_ALGORITHM])
username: str = payload.get("sub")
if username is None:
return None
return TokenData(username=username)
# CRIT-2: reject tokens without the tv claim (legacy pre-fix tokens)
if "tv" not in payload:
return None
return TokenData(username=username, token_version=payload["tv"])
except jwt.PyJWTError:
return None

Expand Down
51 changes: 51 additions & 0 deletions server/tests/test_auth_jwt_algorithm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""TDD guard for CRIT-1 — JWT algorithm must be hardcoded, not config-sourced.

An operator (or attacker with env write access) setting JWT_ALGORITHM=none
must NOT silently disable signature verification. The accepted-algorithm list
is a security invariant and must never come from config.
"""

import jwt

from app.core.config import get_settings
from app.services.auth import create_access_token, decode_token

settings = get_settings()


class TestJwtAlgorithmHardcoded:
"""CRIT-1 guard: decode only accepts HS256, regardless of settings."""

def test_decode_rejects_none_algorithm(self):
"""A token signed with alg=none must be rejected."""
unsigned = jwt.encode({"sub": "attacker"}, "", algorithm="none")
assert decode_token(unsigned) is None

def test_decode_rejects_hs512_token(self):
"""Only HS256 is accepted. An HS512 token (even forged with the
real secret) must be rejected — the algorithm whitelist is the
security boundary, not the secret."""
token = jwt.encode({"sub": "attacker"}, settings.jwt_secret, algorithm="HS512")
assert decode_token(token) is None

def test_encode_uses_hs256_regardless_of_setting(self, monkeypatch):
"""Even if settings.jwt_algorithm is mutated at runtime to an
insecure value, encode must still emit HS256."""
monkeypatch.setattr(settings, "jwt_algorithm", "none", raising=False)
token = create_access_token({"sub": "alice", "tv": 0})
header = jwt.get_unverified_header(token)
assert header["alg"] == "HS256"

def test_decode_accepts_valid_hs256(self):
"""Sanity: a legitimate HS256 token still decodes."""
token = create_access_token({"sub": "alice", "tv": 0})
td = decode_token(token)
assert td is not None
assert td.username == "alice"

def test_decode_rejects_none_alg_even_if_setting_mutated(self, monkeypatch):
"""Even if an attacker could flip settings.jwt_algorithm to 'none'
at runtime, the decode path must still reject unsigned tokens."""
monkeypatch.setattr(settings, "jwt_algorithm", "none", raising=False)
unsigned = jwt.encode({"sub": "attacker"}, "", algorithm="none")
assert decode_token(unsigned) is None
Loading
Loading