Skip to content
Draft

authn #440

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions elroy/db/db_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ class User(SQLModel, table=True):
__table_args__ = {"extend_existing": True}
id: Optional[int] = Field(default=None, primary_key=True)
token: str = Field(..., description="The unique token for the user")
email: Optional[str] = Field(None, description="User email address", unique=True)
password_hash: Optional[str] = Field(None, description="Hashed password for authentication")
created_at: datetime = Field(default_factory=utc_now, nullable=False)
updated_at: datetime = Field(default_factory=utc_now, nullable=False) # noqa F841

Expand Down
35 changes: 35 additions & 0 deletions elroy/db/postgres/alembic/versions/add_user_auth_fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""add user authentication fields

Revision ID: add_user_auth_fields
Revises: b360a1f1b06e
Create Date: 2025-08-03 12:00:00.000000

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op
from sqlmodel.sql.sqltypes import AutoString

# revision identifiers, used by Alembic.
revision: str = "add_user_auth_fields"
down_revision: Union[str, None] = "b360a1f1b06e"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# Add email and password_hash columns
op.add_column("user", sa.Column("email", AutoString(), nullable=True))
op.add_column("user", sa.Column("password_hash", AutoString(), nullable=True))

# Create unique index on email
op.create_index("ix_user_email", "user", ["email"], unique=True)


def downgrade() -> None:
# Drop the columns
op.drop_index("ix_user_email", "user")
op.drop_column("user", "password_hash")
op.drop_column("user", "email")
35 changes: 35 additions & 0 deletions elroy/db/sqlite/alembic/versions/add_user_auth_fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""add user authentication fields

Revision ID: add_user_auth_fields
Revises: f880962b9187
Create Date: 2025-08-03 12:00:00.000000

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op
from sqlmodel.sql.sqltypes import AutoString

# revision identifiers, used by Alembic.
revision: str = "add_user_auth_fields"
down_revision: Union[str, None] = "f880962b9187"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# Add email and password_hash columns
op.add_column("user", sa.Column("email", AutoString(), nullable=True))
op.add_column("user", sa.Column("password_hash", AutoString(), nullable=True))

# Create unique index on email
op.create_index("ix_user_email", "user", ["email"], unique=True)


def downgrade() -> None:
# Drop the columns
op.drop_index("ix_user_email", "user")
op.drop_column("user", "password_hash")
op.drop_column("user", "email")
8 changes: 8 additions & 0 deletions elroy/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import asyncio
import secrets
import string
import threading
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
Expand Down Expand Up @@ -27,6 +29,12 @@ def run_async(thread_pool: ThreadPoolExecutor, coro):
return thread_pool.submit(asyncio.run, coro).result()


def generate_random_string(length: int = 32) -> str:
"""Generate a cryptographically secure random string."""
alphabet = string.ascii_letters + string.digits
return "".join(secrets.choice(alphabet) for _ in range(length))


def is_blank(input: Optional[str]) -> bool:
assert isinstance(input, (str, type(None)))
return not input or not input.strip()
Expand Down
59 changes: 59 additions & 0 deletions elroy/web_api/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import os
from datetime import datetime, timedelta, timezone
from typing import Optional

import jwt
from passlib.context import CryptContext

from ..db.db_models import User

# Password hashing configuration
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# JWT configuration
SECRET_KEY = os.getenv("JWT_SECRET_KEY", "your-secret-key-change-in-production")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30


def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verify a password against its hash."""
return pwd_context.verify(plain_password, hashed_password)


def get_password_hash(password: str) -> str:
"""Hash a password."""
return pwd_context.hash(password)


def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
"""Create a JWT access token."""
to_encode = data.copy()
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt


def verify_token(token: str) -> Optional[dict]:
"""Verify and decode a JWT token."""
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
return payload
except jwt.ExpiredSignatureError:
return None
except jwt.JWTError:
return None


def authenticate_user(email: str, password: str, db_session) -> Optional[User]:
"""Authenticate a user by email and password."""
user = db_session.query(User).filter(User.email == email).first()
if not user or not user.password_hash:
return None
if not verify_password(password, user.password_hash):
return None
return user
17 changes: 12 additions & 5 deletions elroy/web_api/main.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
from typing import List

from fastapi import FastAPI
from fastapi import Depends, FastAPI
from pydantic import BaseModel

from elroy.api import Elroy
from elroy.repository.memories.models import MemoryResponse

from ..db.db_models import User
from .routes.auth import get_current_user
from .routes.auth import router as auth_router

app = FastAPI(title="Elroy API", version="1.0.0", log_level="info")

# Include authentication routes
app.include_router(auth_router)

# Style note: do not catch and reraise errors, outside of specific error handling, let regular errors propagate.


Expand Down Expand Up @@ -39,7 +46,7 @@ class ApiResponse(BaseModel):


@app.get("/get_current_messages", response_model=List[MessageResponse])
async def get_current_messages():
async def get_current_messages(current_user: User = Depends(get_current_user)):
"""Return a list of current messages in the conversation context."""
elroy = Elroy()
elroy.ctx
Expand All @@ -52,14 +59,14 @@ async def get_current_messages():


@app.post("/create_augmented_memory", response_model=ApiResponse)
async def create_augmented_memory(request: MemoryRequest):
async def create_augmented_memory(request: MemoryRequest, current_user: User = Depends(get_current_user)):
elroy = Elroy()
result = elroy.create_augmented_memory(request.text)
return ApiResponse(result=result)


@app.get("/get_current_memories", response_model=List[MemoryResponse])
async def get_current_memories():
async def get_current_memories(current_user: User = Depends(get_current_user)):
"""Return a list of memories for the current user."""
elroy = Elroy()
elroy.ctx
Expand All @@ -72,7 +79,7 @@ async def get_current_memories():


@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
async def chat(request: ChatRequest, current_user: User = Depends(get_current_user)):
"""Process a user message and return the updated conversation."""
elroy = Elroy()
elroy.message(request.message)
Expand Down
132 changes: 132 additions & 0 deletions elroy/web_api/routes/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
from datetime import timedelta

from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel, EmailStr

from ...core.ctx import ElroyContext
from ...db.db_models import User
from ...utils.utils import generate_random_string
from ..auth import (
ACCESS_TOKEN_EXPIRE_MINUTES,
authenticate_user,
create_access_token,
get_password_hash,
verify_token,
)

router = APIRouter(prefix="/auth", tags=["authentication"])
security = HTTPBearer()


class UserRegistration(BaseModel):
email: EmailStr
password: str


class UserLogin(BaseModel):
email: EmailStr
password: str


class Token(BaseModel):
access_token: str
token_type: str


class UserResponse(BaseModel):
id: int
email: str
token: str


def get_db_session():
"""Get database session dependency."""
ctx = ElroyContext()
return ctx.db.get_session()


def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)) -> User:
"""Get current authenticated user from JWT token."""
db_session = get_db_session()

payload = verify_token(credentials.credentials)
if payload is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)

email: str = payload.get("sub")
if email is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)

user = db_session.query(User).filter(User.email == email).first()
if user is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)

return user


@router.post("/register", response_model=UserResponse)
async def register(user_data: UserRegistration):
"""Register a new user."""
db_session = get_db_session()

# Check if user already exists
existing_user = db_session.query(User).filter(User.email == user_data.email).first()
if existing_user:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered")

# Create new user
user = User(
email=user_data.email,
password_hash=get_password_hash(user_data.password),
token=generate_random_string(32), # Generate unique token
)

db_session.add(user)
db_session.commit()
db_session.refresh(user)

return UserResponse(id=user.id, email=user.email, token=user.token)


@router.post("/login", response_model=Token)
async def login(user_data: UserLogin):
"""Login user and return JWT token."""
db_session = get_db_session()

user = authenticate_user(user_data.email, user_data.password, db_session)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect email or password",
headers={"WWW-Authenticate": "Bearer"},
)

access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = create_access_token(data={"sub": user.email}, expires_delta=access_token_expires)

return {"access_token": access_token, "token_type": "bearer"}


@router.post("/logout")
async def logout(current_user: User = Depends(get_current_user)):
"""Logout user (token invalidation handled client-side)."""
return {"message": "Successfully logged out"}


@router.get("/me", response_model=UserResponse)
async def get_current_user_info(current_user: User = Depends(get_current_user)):
"""Get current user information."""
return UserResponse(id=current_user.id, email=current_user.email, token=current_user.token)
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ dependencies = [
"apscheduler>=3.11.0",
"fastapi>=0.104.0",
"uvicorn>=0.24.0",
"PyJWT>=2.8.0",
"passlib[bcrypt]>=1.7.4",
"python-multipart>=0.0.6",
]

[project.optional-dependencies]
Expand Down
Loading