From e7aae734e187970a2b3b957d3d9790a59677a935 Mon Sep 17 00:00:00 2001 From: Michael Sitarzewski Date: Tue, 17 Feb 2026 17:45:35 -0600 Subject: [PATCH 1/4] =?UTF-8?q?v0.5.0=20=E2=80=94=20"It=20Scales"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Multi-user auth (JWT + RBAC), PostgreSQL support (asyncpg + connection pooling), Perplexity provider, Prometheus metrics, extended health checks, backup/restore CLI, per-user rate limiting, Playwright E2E tests, load tests, production deployment docs. 1354 Python tests + 12 load tests, ruff clean. Fix: create_all only for in-memory SQLite, file-based DBs use alembic exclusively. Co-Authored-By: Claude Opus 4.6 --- alembic/versions/005_v05_users.py | 76 +++ docs/guides/authentication.md | 297 ++++++++++ docs/guides/monitoring.md | 270 +++++++++ docs/guides/production-deployment.md | 314 ++++++++++ memory-bank/activeContext.md | 89 ++- memory-bank/decisions.md | 14 + memory-bank/progress.md | 33 +- memory-bank/quick-start.md | 23 +- memory-bank/roadmap.md | 97 ++- memory-bank/toc.md | 6 +- mkdocs.yml | 3 + pyproject.toml | 8 +- src/duh/__init__.py | 2 +- src/duh/api/app.py | 15 +- src/duh/api/auth.py | 190 ++++++ src/duh/api/health.py | 63 ++ src/duh/api/metrics.py | 227 +++++++ src/duh/api/middleware.py | 38 +- src/duh/api/rbac.py | 59 ++ src/duh/cli/app.py | 260 +++++++- src/duh/config/schema.py | 15 + src/duh/memory/backup.py | 267 +++++++++ src/duh/memory/models.py | 38 ++ src/duh/providers/manager.py | 81 ++- src/duh/providers/perplexity.py | 300 ++++++++++ tests/load/__init__.py | 0 tests/load/conftest.py | 1 + tests/load/test_load.py | 417 +++++++++++++ tests/unit/test_auth.py | 342 +++++++++++ tests/unit/test_backup.py | 315 ++++++++++ tests/unit/test_cli.py | 2 +- tests/unit/test_connection_pool.py | 298 ++++++++++ tests/unit/test_health.py | 182 ++++++ tests/unit/test_metrics.py | 177 ++++++ tests/unit/test_models.py | 1 + tests/unit/test_multi_user_integration.py | 686 
++++++++++++++++++++++ tests/unit/test_postgresql_config.py | 156 +++++ tests/unit/test_providers_perplexity.py | 421 +++++++++++++ tests/unit/test_rate_limiting.py | 415 +++++++++++++ tests/unit/test_rbac.py | 221 +++++++ tests/unit/test_restore.py | 455 ++++++++++++++ tests/unit/test_smoke.py | 4 +- tests/unit/test_user_model.py | 157 +++++ uv.lock | 126 +++- web/e2e/consensus.spec.ts | 56 ++ web/e2e/decision-space.spec.ts | 26 + web/e2e/navigation.spec.ts | 66 +++ web/package-lock.json | 64 ++ web/package.json | 2 + web/playwright.config.ts | 17 + 50 files changed, 7278 insertions(+), 114 deletions(-) create mode 100644 alembic/versions/005_v05_users.py create mode 100644 docs/guides/authentication.md create mode 100644 docs/guides/monitoring.md create mode 100644 docs/guides/production-deployment.md create mode 100644 src/duh/api/auth.py create mode 100644 src/duh/api/health.py create mode 100644 src/duh/api/metrics.py create mode 100644 src/duh/api/rbac.py create mode 100644 src/duh/memory/backup.py create mode 100644 src/duh/providers/perplexity.py create mode 100644 tests/load/__init__.py create mode 100644 tests/load/conftest.py create mode 100644 tests/load/test_load.py create mode 100644 tests/unit/test_auth.py create mode 100644 tests/unit/test_backup.py create mode 100644 tests/unit/test_connection_pool.py create mode 100644 tests/unit/test_health.py create mode 100644 tests/unit/test_metrics.py create mode 100644 tests/unit/test_multi_user_integration.py create mode 100644 tests/unit/test_postgresql_config.py create mode 100644 tests/unit/test_providers_perplexity.py create mode 100644 tests/unit/test_rate_limiting.py create mode 100644 tests/unit/test_rbac.py create mode 100644 tests/unit/test_restore.py create mode 100644 tests/unit/test_user_model.py create mode 100644 web/e2e/consensus.spec.ts create mode 100644 web/e2e/decision-space.spec.ts create mode 100644 web/e2e/navigation.spec.ts create mode 100644 web/playwright.config.ts diff --git 
a/alembic/versions/005_v05_users.py b/alembic/versions/005_v05_users.py new file mode 100644 index 0000000..055cb11 --- /dev/null +++ b/alembic/versions/005_v05_users.py @@ -0,0 +1,76 @@ +"""v0.5 users table and user_id foreign keys. + +Revision ID: 005 +Revises: 004 +Create Date: 2026-02-17 +""" + +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op + +revision: str = "005" +down_revision: str = "004" +branch_labels: tuple[str, ...] | None = None +depends_on: str | None = None + + +def upgrade() -> None: + op.create_table( + "users", + sa.Column("id", sa.String(36), primary_key=True), + sa.Column("email", sa.String(255), nullable=False, unique=True), + sa.Column("password_hash", sa.String(128), nullable=False), + sa.Column("display_name", sa.String(100), nullable=False), + sa.Column("role", sa.String(20), nullable=False, server_default="contributor"), + sa.Column( + "is_active", sa.Boolean(), nullable=False, server_default=sa.true() + ), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + ) + op.create_index("ix_users_email", "users", ["email"], unique=True) + + # Add user_id column to threads (batch mode for SQLite compatibility) + with op.batch_alter_table("threads") as batch_op: + batch_op.add_column(sa.Column("user_id", sa.String(36), nullable=True)) + batch_op.create_index("ix_threads_user_id", ["user_id"]) + batch_op.create_foreign_key( + "fk_threads_user_id", "users", ["user_id"], ["id"] + ) + + # Add user_id column to decisions + with op.batch_alter_table("decisions") as batch_op: + batch_op.add_column(sa.Column("user_id", sa.String(36), nullable=True)) + batch_op.create_index("ix_decisions_user_id", ["user_id"]) + batch_op.create_foreign_key( + "fk_decisions_user_id", "users", ["user_id"], ["id"] + ) + + # Add user_id column to api_keys + with op.batch_alter_table("api_keys") as batch_op: + batch_op.add_column(sa.Column("user_id", sa.String(36), 
nullable=True)) + batch_op.create_index("ix_api_keys_user_id", ["user_id"]) + batch_op.create_foreign_key( + "fk_api_keys_user_id", "users", ["user_id"], ["id"] + ) + + +def downgrade() -> None: + with op.batch_alter_table("api_keys") as batch_op: + batch_op.drop_index("ix_api_keys_user_id") + batch_op.drop_constraint("fk_api_keys_user_id", type_="foreignkey") + batch_op.drop_column("user_id") + + with op.batch_alter_table("decisions") as batch_op: + batch_op.drop_index("ix_decisions_user_id") + batch_op.drop_constraint("fk_decisions_user_id", type_="foreignkey") + batch_op.drop_column("user_id") + + with op.batch_alter_table("threads") as batch_op: + batch_op.drop_index("ix_threads_user_id") + batch_op.drop_constraint("fk_threads_user_id", type_="foreignkey") + batch_op.drop_column("user_id") + + op.drop_table("users") diff --git a/docs/guides/authentication.md b/docs/guides/authentication.md new file mode 100644 index 0000000..35b8321 --- /dev/null +++ b/docs/guides/authentication.md @@ -0,0 +1,297 @@ +# Authentication + +Manage users, API keys, JWT tokens, and role-based access control. + +## Overview + +duh supports two authentication methods: + +1. **JWT Bearer tokens** -- for users who log in with email and password +2. **API keys** -- for programmatic access via the `X-API-Key` header + +Both methods are checked by the API middleware. If no API keys exist in the database and no JWT is provided, the API runs in open mode (no authentication required). + +## User management + +### Register a user (API) + +```bash +curl -X POST http://localhost:8080/api/auth/register \ + -H "Content-Type: application/json" \ + -d '{ + "email": "alice@example.com", + "password": "strong-password-here", + "display_name": "Alice" + }' +``` + +Response: + +```json +{ + "access_token": "eyJhbGciOiJIUzI1NiIs...", + "token_type": "bearer", + "user_id": "a1b2c3d4-...", + "role": "contributor" +} +``` + +New users are assigned the `contributor` role by default. 
Registration can be disabled in config after your initial users are created. + +!!! warning "Disable registration in production" + After creating your admin user, set `registration_enabled = false` in `config.toml` to prevent unauthorized signups. + +### Create a user (CLI) + +The CLI lets you create users with a specific role, including admin: + +```bash +duh user-create \ + --email admin@example.com \ + --password 'strong-password' \ + --name "Admin User" \ + --role admin +``` + +Available roles: `admin`, `contributor`, `viewer`. + +### List users (CLI) + +```bash +duh user-list +``` + +Output: + +``` + a1b2c3d4 admin@example.com Admin User role=admin active + e5f6a7b8 alice@example.com Alice role=contributor active +``` + +### Log in + +```bash +curl -X POST http://localhost:8080/api/auth/login \ + -H "Content-Type: application/json" \ + -d '{ + "email": "alice@example.com", + "password": "strong-password-here" + }' +``` + +Response: + +```json +{ + "access_token": "eyJhbGciOiJIUzI1NiIs...", + "token_type": "bearer", + "user_id": "a1b2c3d4-...", + "role": "contributor" +} +``` + +### Get current user + +```bash +curl http://localhost:8080/api/auth/me \ + -H "Authorization: Bearer eyJhbGciOiJIUzI1NiIs..." +``` + +Response: + +```json +{ + "id": "a1b2c3d4-...", + "email": "alice@example.com", + "display_name": "Alice", + "role": "contributor", + "is_active": true +} +``` + +## JWT tokens + +### Using tokens + +Include the token in the `Authorization` header: + +```bash +curl http://localhost:8080/api/threads \ + -H "Authorization: Bearer eyJhbGciOiJIUzI1NiIs..." +``` + +### Token details + +- **Algorithm**: HS256 +- **Payload**: `sub` (user ID), `exp` (expiry), `iat` (issued at) +- **Default expiry**: 24 hours (configurable) + +Tokens are validated on every request by the API key middleware. An expired or invalid token returns HTTP 401. + +### Token expiry configuration + +```toml +[auth] +token_expiry_hours = 24 +``` + +Set a shorter expiry for higher security. 
Users will need to call `/api/auth/login` again after the token expires. + +## API keys + +API keys provide a simpler authentication method for scripts and integrations. They are passed via the `X-API-Key` header. + +### Using API keys + +```bash +curl http://localhost:8080/api/threads \ + -H "X-API-Key: duh_abc123..." +``` + +### How API keys work + +- Keys are stored as SHA-256 hashes in the database (the raw key is never stored) +- Keys can be revoked by setting a `revoked_at` timestamp +- Keys can optionally be linked to a user via `user_id` + +!!! note "API key CLI" + API key management is available through the database. A dedicated `duh key create` CLI command is planned for a future release. + +### Exempt paths + +The following paths do not require authentication: + +| Path | Purpose | +|------|---------| +| `/api/health` | Basic health check | +| `/api/health/detailed` | Detailed health check | +| `/api/metrics` | Prometheus metrics | +| `/api/auth/register` | User registration | +| `/api/auth/login` | User login | +| `/docs` | OpenAPI documentation | +| `/openapi.json` | OpenAPI spec | +| `/redoc` | ReDoc documentation | +| `/api/share/*` | Shared content | + +All other `/api/` and `/ws/` paths require either a JWT token or API key. + +## Roles and RBAC + +duh uses a hierarchical role system: **admin > contributor > viewer**. + +### Role permissions + +| Capability | Viewer | Contributor | Admin | +|-----------|--------|-------------|-------| +| Read threads and decisions | Yes | Yes | Yes | +| Create consensus queries | No | Yes | Yes | +| Create threads | No | Yes | Yes | +| Manage users | No | No | Yes | +| Full API access | No | No | Yes | + +### How RBAC works + +Endpoints use the `require_role` dependency to enforce minimum role levels: + +- `require_viewer` -- any authenticated user +- `require_contributor` -- contributors and admins +- `require_admin` -- admins only + +A user with a higher role automatically passes lower role checks. 
For example, an admin can access all contributor endpoints. + +### Example: role-protected requests + +**As a viewer** (read-only access): + +```bash +# List threads -- works for viewers +curl http://localhost:8080/api/threads \ + -H "Authorization: Bearer $VIEWER_TOKEN" + +# Create a query -- fails with 403 +curl -X POST http://localhost:8080/api/ask \ + -H "Authorization: Bearer $VIEWER_TOKEN" \ + -H "Content-Type: application/json" \ + -d '{"question": "test"}' +# {"detail": "Requires contributor role"} +``` + +**As a contributor** (create and view): + +```bash +# Create a consensus query -- works for contributors +curl -X POST http://localhost:8080/api/ask \ + -H "Authorization: Bearer $CONTRIBUTOR_TOKEN" \ + -H "Content-Type: application/json" \ + -d '{"question": "What are the trade-offs of microservices?"}' +``` + +**As an admin** (full access): + +```bash +# Get current user -- an admin token passes every role check +curl http://localhost:8080/api/auth/me \ + -H "Authorization: Bearer $ADMIN_TOKEN" +``` + +## Configuration + +All authentication settings live in the `[auth]` section of `config.toml`: + +```toml +[auth] +jwt_secret = "" # REQUIRED in production -- set via env or config +token_expiry_hours = 24 # how long JWT tokens remain valid +registration_enabled = true # set to false after creating your admin user +``` + +### Environment variable override + +Set `DUH_JWT_SECRET` as an environment variable instead of putting it in the config file: + +```bash +export DUH_JWT_SECRET=$(openssl rand -hex 32) +``` + +!!! warning "Never commit your JWT secret" + Use environment variables or a secrets manager for the JWT secret. Never check it into version control. + +## Rate limiting + +Rate limits apply per identity. The middleware identifies callers in this priority order: + +1. **User ID** (from JWT token) +2. **API key ID** (from `X-API-Key` header) +3. 
**IP address** (fallback) + +Configure rate limits in `config.toml`: + +```toml +[api] +rate_limit = 60 # requests per minute +rate_limit_window = 60 # window in seconds +``` + +When the limit is exceeded, the API returns HTTP 429 with a `Retry-After` header. + +Every response includes rate limit headers: + +``` +X-RateLimit-Limit: 60 +X-RateLimit-Remaining: 57 +X-RateLimit-Key: user:a1b2c3d4-... +``` + +## Security recommendations + +1. **Generate a strong JWT secret**: `openssl rand -hex 32` +2. **Disable registration** after creating your first admin user +3. **Use HTTPS** -- never expose the API over plain HTTP +4. **Rotate API keys** periodically and revoke unused ones +5. **Set short token expiry** for high-security environments +6. **Restrict CORS origins** to your actual domain + +## Next steps + +- [Production Deployment](production-deployment.md) -- Full deployment guide with PostgreSQL, Docker, nginx +- [Monitoring](monitoring.md) -- Health checks, metrics, alerting diff --git a/docs/guides/monitoring.md b/docs/guides/monitoring.md new file mode 100644 index 0000000..bb9b2a6 --- /dev/null +++ b/docs/guides/monitoring.md @@ -0,0 +1,270 @@ +# Monitoring + +Monitor duh with Prometheus metrics, health checks, and alerting. + +## Health checks + +### Basic health + +```bash +curl http://localhost:8080/api/health +``` + +```json +{"status": "ok"} +``` + +This endpoint returns immediately and does not check dependencies. Use it for load balancer liveness probes. + +### Detailed health + +```bash +curl http://localhost:8080/api/health/detailed +``` + +```json +{ + "status": "ok", + "version": "0.5.0", + "uptime_seconds": 3621.4, + "components": { + "database": {"status": "ok"}, + "providers": { + "anthropic": {"status": "ok"}, + "openai": {"status": "ok"}, + "google": {"status": "unhealthy"} + } + } +} +``` + +The `status` field is `"ok"` when all components are healthy, or `"degraded"` when the database is unreachable or all providers are unhealthy. 
Individual provider failures do not degrade the overall status unless every provider is down. + +!!! tip "Use detailed health for readiness probes" + Point your Kubernetes readiness probe or Docker healthcheck at `/api/health/detailed` to catch database connectivity issues. + +Both health endpoints are exempt from API key authentication, so they work without credentials. + +## Prometheus metrics + +Metrics are available in Prometheus text format at: + +```bash +curl http://localhost:8080/api/metrics +``` + +This endpoint is also exempt from API key authentication. + +### Available metrics + +#### Counters + +| Metric | Labels | Description | +|--------|--------|-------------| +| `duh_requests_total` | `method`, `path`, `status` | Total HTTP requests | +| `duh_consensus_runs_total` | -- | Total consensus runs completed | +| `duh_tokens_total` | `provider`, `direction` | Total tokens consumed (`direction` is `input` or `output`) | +| `duh_errors_total` | `type` | Total errors by error type | + +#### Histograms + +| Metric | Buckets | Description | +|--------|---------|-------------| +| `duh_request_duration_seconds` | 5ms -- 10s | HTTP request duration | +| `duh_consensus_duration_seconds` | 5ms -- 10s | Consensus run duration | + +Default histogram buckets: `0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0` + +#### Gauges + +| Metric | Description | +|--------|-------------| +| `duh_active_connections` | Current active connections | +| `duh_provider_health` | Provider health status (1 = healthy, 0 = unhealthy) | + +### Prometheus scrape config + +Add duh as a target in `prometheus.yml`: + +```yaml +scrape_configs: + - job_name: "duh" + scrape_interval: 15s + metrics_path: /api/metrics + static_configs: + - targets: ["localhost:8080"] +``` + +## Grafana dashboard + +### Key queries + +Use these PromQL queries to build a Grafana dashboard: + +**Request rate** (requests per second): + +```promql +rate(duh_requests_total[5m]) +``` + +**Error rate** 
(percentage): + +```promql +100 * sum(rate(duh_errors_total[5m])) / sum(rate(duh_requests_total[5m])) +``` + +**Request latency (p95)**: + +```promql +histogram_quantile(0.95, rate(duh_request_duration_seconds_bucket[5m])) +``` + +**Consensus duration (p50 and p95)**: + +```promql +histogram_quantile(0.50, rate(duh_consensus_duration_seconds_bucket[5m])) +histogram_quantile(0.95, rate(duh_consensus_duration_seconds_bucket[5m])) +``` + +**Token consumption by provider**: + +```promql +rate(duh_tokens_total[1h]) +``` + +**Active connections**: + +```promql +duh_active_connections +``` + +### Suggested dashboard panels + +| Panel | Query | Visualization | +|-------|-------|---------------| +| Request Rate | `rate(duh_requests_total[5m])` | Time series | +| Error Rate (%) | `100 * sum(rate(duh_errors_total[5m])) / sum(rate(duh_requests_total[5m]))` | Time series with threshold | +| Latency p95 | `histogram_quantile(0.95, rate(duh_request_duration_seconds_bucket[5m]))` | Time series | +| Consensus Duration | `histogram_quantile(0.95, rate(duh_consensus_duration_seconds_bucket[5m]))` | Time series | +| Tokens by Provider | `sum by (provider) (rate(duh_tokens_total[1h]))` | Stacked bar | +| Active Connections | `duh_active_connections` | Stat | +| Health Status | Custom based on `/api/health/detailed` | Status map | + +## Alerting + +### Suggested alert rules + +Add these rules to your Prometheus alerting config or Grafana alert manager. + +**High error rate** -- fires when errors exceed 1% of requests: + +```yaml +groups: + - name: duh + rules: + - alert: DuhHighErrorRate + expr: > + sum(rate(duh_errors_total[5m])) + / sum(rate(duh_requests_total[5m])) + > 0.01 + for: 5m + labels: + severity: warning + annotations: + summary: "duh error rate above 1%" + description: "Error rate is {{ $value | humanizePercentage }} over the last 5 minutes." 
+ + - alert: DuhHighLatency + expr: > + histogram_quantile(0.95, rate(duh_request_duration_seconds_bucket[5m])) + > 5 + for: 5m + labels: + severity: warning + annotations: + summary: "duh p95 latency above 5 seconds" + + - alert: DuhAllProvidersDown + expr: max(duh_provider_health) == 0 + for: 2m + labels: + severity: critical + annotations: + summary: "All duh providers are unhealthy" + description: "No LLM providers are responding. Consensus queries will fail." + + - alert: DuhHealthDegraded + expr: > + probe_success{job="duh-health"} == 0 + for: 1m + labels: + severity: critical + annotations: + summary: "duh health check failing" +``` + +!!! note "Provider latency" + Consensus runs call multiple LLM providers sequentially. A p95 latency of 5--15 seconds is normal. Set your latency alert threshold accordingly. + +## Log configuration + +Configure logging in `config.toml`: + +```toml +[logging] +level = "INFO" # DEBUG, INFO, WARNING, ERROR, CRITICAL +file = "" # empty = stdout, or a file path like "/var/log/duh/duh.log" +structured = false # set to true for JSON log output +``` + +### Recommended production settings + +```toml +[logging] +level = "INFO" +file = "/var/log/duh/duh.log" +structured = true +``` + +Structured (JSON) logging makes it easier to parse logs with tools like Loki, Elasticsearch, or CloudWatch. + +### Log rotation + +If logging to a file, set up logrotate: + +``` +/var/log/duh/duh.log { + daily + rotate 14 + compress + delaycompress + missingok + notifempty + copytruncate +} +``` + +## Rate limit monitoring + +duh includes rate limit headers on every response: + +| Header | Description | +|--------|-------------| +| `X-RateLimit-Limit` | Configured requests per window | +| `X-RateLimit-Remaining` | Requests remaining in current window | +| `X-RateLimit-Key` | Identity being rate-limited (`user:`, `api_key:`, or `ip:`) | + +When the limit is exceeded, duh returns HTTP 429 with a `Retry-After` header. 
+ +Rate limits are configured in `config.toml`: + +```toml +[api] +rate_limit = 60 # requests per minute per key +rate_limit_window = 60 # window in seconds +``` + +## Next steps + +- [Production Deployment](production-deployment.md) -- Full deployment guide +- [Authentication](authentication.md) -- User management and RBAC diff --git a/docs/guides/production-deployment.md b/docs/guides/production-deployment.md new file mode 100644 index 0000000..1da989e --- /dev/null +++ b/docs/guides/production-deployment.md @@ -0,0 +1,314 @@ +# Production Deployment + +Run duh with PostgreSQL, HTTPS, backups, and proper security for a team or organization. + +## Prerequisites + +- Linux server (Ubuntu 22.04+ or similar) +- [Docker](https://docs.docker.com/get-docker/) with Compose v2 +- A domain name (for HTTPS) +- API keys for at least one provider + +## PostgreSQL setup + +### Create the database + +```bash +sudo -u postgres psql +``` + +```sql +CREATE USER duh WITH PASSWORD 'your-secure-password'; +CREATE DATABASE duh OWNER duh; +GRANT ALL PRIVILEGES ON DATABASE duh TO duh; +\q +``` + +### Connection string + +duh uses SQLAlchemy with asyncpg. Set the database URL in your config or environment: + +``` +postgresql+asyncpg://duh:your-secure-password@localhost:5432/duh +``` + +!!! warning "Use asyncpg driver" + The connection string **must** include `+asyncpg`. Plain `postgresql://` will not work with duh's async database layer. 
+ +### Connection pool configuration + +Add pool settings to `config.toml` for production workloads: + +```toml +[database] +url = "postgresql+asyncpg://duh:your-secure-password@localhost:5432/duh" +pool_size = 5 +max_overflow = 10 +pool_timeout = 30 +pool_recycle = 3600 +``` + +| Setting | Default | Description | +|---------|---------|-------------| +| `pool_size` | `5` | Number of persistent connections in the pool | +| `max_overflow` | `10` | Extra connections allowed beyond `pool_size` under load | +| `pool_timeout` | `30` | Seconds to wait for a connection before raising an error | +| `pool_recycle` | `3600` | Seconds before a connection is recycled (prevents stale connections) | + +duh also enables `pool_pre_ping` automatically for PostgreSQL, which verifies connections are alive before use. + +## Environment variables + +Set these on the host or in a `.env` file: + +```bash +# LLM provider keys (set at least one) +export ANTHROPIC_API_KEY=sk-ant-... +export OPENAI_API_KEY=sk-... +export GOOGLE_API_KEY=AI... +export MISTRAL_API_KEY=... +export PERPLEXITY_API_KEY=pplx-... + +# Authentication +export DUH_JWT_SECRET=$(openssl rand -hex 32) + +# Database +export DUH_DATABASE_URL=postgresql+asyncpg://duh:your-secure-password@db:5432/duh +``` + +!!! tip "Generate a strong JWT secret" + Use `openssl rand -hex 32` to generate a 64-character hex string. The secret must be at least 32 characters for production use. + +## Docker production config + +Create a `docker-compose.prod.yml`: + +```yaml +services: + db: + image: postgres:16-alpine + restart: unless-stopped + volumes: + - pgdata:/var/lib/postgresql/data + environment: + POSTGRES_USER: duh + POSTGRES_PASSWORD: ${POSTGRES_PASSWORD} + POSTGRES_DB: duh + healthcheck: + test: ["CMD-SHELL", "pg_isready -U duh"] + interval: 10s + timeout: 5s + retries: 5 + + duh: + build: . 
+ restart: unless-stopped + ports: + - "127.0.0.1:8080:8080" + depends_on: + db: + condition: service_healthy + environment: + - DUH_DATABASE_URL=postgresql+asyncpg://duh:${POSTGRES_PASSWORD}@db:5432/duh + - DUH_JWT_SECRET=${DUH_JWT_SECRET} + - ANTHROPIC_API_KEY + - OPENAI_API_KEY + - GOOGLE_API_KEY + - MISTRAL_API_KEY + - PERPLEXITY_API_KEY + volumes: + - ./config.toml:/app/config.toml:ro + command: ["serve", "--host", "0.0.0.0", "--port", "8080"] + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8080/api/health"] + interval: 30s + timeout: 10s + retries: 3 + +volumes: + pgdata: +``` + +Create a `.env` file alongside the compose file: + +```bash +POSTGRES_PASSWORD=your-secure-password +DUH_JWT_SECRET=your-64-char-hex-secret +ANTHROPIC_API_KEY=sk-ant-... +OPENAI_API_KEY=sk-... +``` + +Start the stack: + +```bash +docker compose -f docker-compose.prod.yml up -d +``` + +!!! warning "Don't commit .env" + Add `.env` to `.gitignore`. Never put secrets in `docker-compose.yml` directly. + +## Running migrations + +After the database is running, apply schema migrations with Alembic: + +```bash +# Inside the container +docker compose -f docker-compose.prod.yml exec duh alembic upgrade head + +# Or from the host (if duh is installed locally) +DUH_DATABASE_URL=postgresql+asyncpg://duh:password@localhost:5432/duh alembic upgrade head +``` + +Run migrations every time you upgrade duh to a new version. + +## Reverse proxy + +Use nginx to terminate HTTPS in front of duh. 
Install nginx and create a site config: + +```nginx +server { + listen 443 ssl http2; + server_name duh.example.com; + + ssl_certificate /etc/letsencrypt/live/duh.example.com/fullchain.pem; + ssl_certificate_key /etc/letsencrypt/live/duh.example.com/privkey.pem; + + # Security headers + add_header X-Frame-Options DENY; + add_header X-Content-Type-Options nosniff; + add_header X-XSS-Protection "1; mode=block"; + add_header Strict-Transport-Security "max-age=63072000" always; + + location / { + proxy_pass http://127.0.0.1:8080; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + proxy_set_header X-Forwarded-Proto $scheme; + } + + # WebSocket support + location /ws/ { + proxy_pass http://127.0.0.1:8080; + proxy_http_version 1.1; + proxy_set_header Upgrade $http_upgrade; + proxy_set_header Connection "upgrade"; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + } +} + +server { + listen 80; + server_name duh.example.com; + return 301 https://$host$request_uri; +} +``` + +Get a free certificate with [certbot](https://certbot.eff.org/): + +```bash +sudo certbot --nginx -d duh.example.com +``` + +## Backup strategy + +### Manual backups + +```bash +# JSON backup (works with both SQLite and PostgreSQL) +duh backup /backups/duh-$(date +%Y%m%d).json --format json + +# SQLite-only: file copy +duh backup /backups/duh-$(date +%Y%m%d).db +``` + +### Restore from backup + +```bash +# Restore (replaces existing data) +duh restore /backups/duh-20260217.json + +# Merge into existing data (skip conflicts) +duh restore /backups/duh-20260217.json --merge +``` + +### Scheduled backups with cron + +Add a daily backup job: + +```bash +crontab -e +``` + +```cron +# Daily duh backup at 2:00 AM, keep last 30 days +0 2 * * * /usr/local/bin/duh backup /backups/duh-$(date +\%Y\%m\%d).json --format json 2>&1 | logger -t duh-backup +0 3 * * * find /backups -name "duh-*.json" -mtime +30 
-delete +``` + +For Docker deployments, run the backup inside the container: + +```bash +0 2 * * * docker compose -f /opt/duh/docker-compose.prod.yml exec -T duh duh backup /data/backup-$(date +\%Y\%m\%d).json --format json +``` + +!!! tip "Test your restores" + A backup you have never restored is a backup that does not work. Periodically test `duh restore` against a staging database. + +## Security checklist + +Before going live, verify each item: + +- [ ] **JWT secret**: At least 32 characters, generated with `openssl rand -hex 32` +- [ ] **Disable registration**: After creating your admin user, set `registration_enabled = false` in `config.toml`: + + ```toml + [auth] + jwt_secret = "your-secret-here" + token_expiry_hours = 24 + registration_enabled = false + ``` + +- [ ] **API key management**: Create API keys for programmatic access and distribute them securely. Revoke keys that are no longer needed. +- [ ] **Rate limiting**: Configure rate limits in `config.toml` to prevent abuse: + + ```toml + [api] + rate_limit = 60 # requests per minute per key + rate_limit_window = 60 # window in seconds + ``` + +- [ ] **HTTPS only**: Never expose the API over plain HTTP. Use nginx or a load balancer for TLS termination. +- [ ] **Firewall**: Only expose ports 80 and 443 publicly. Bind duh to `127.0.0.1:8080` so it is only reachable through the reverse proxy. +- [ ] **Database credentials**: Use a dedicated database user with a strong password. Do not use the PostgreSQL superuser. +- [ ] **Environment variables**: Never hardcode secrets in config files that are committed to version control. 
+- [ ] **CORS origins**: Restrict `cors_origins` to your actual domain: + + ```toml + [api] + cors_origins = ["https://duh.example.com"] + ``` + +## Create your first admin user + +After deployment, create an admin user via the CLI: + +```bash +duh user-create --email admin@example.com --password 'strong-password' --name Admin --role admin +``` + +Then disable registration: + +```toml +[auth] +registration_enabled = false +``` + +Restart the service for the config change to take effect. + +## Next steps + +- [Authentication](authentication.md) -- User management, JWT tokens, RBAC +- [Monitoring](monitoring.md) -- Prometheus metrics, health checks, alerting +- [Docker](docker.md) -- Development Docker setup diff --git a/memory-bank/activeContext.md b/memory-bank/activeContext.md index e8c0718..08d8193 100644 --- a/memory-bank/activeContext.md +++ b/memory-bank/activeContext.md @@ -1,65 +1,50 @@ # Active Context **Last Updated**: 2026-02-17 -**Current Phase**: v0.4 COMPLETE — "It Has a Face" -**Next Action**: Commit v0.4 changes, merge to main, begin v0.5 planning. +**Current Phase**: v0.5 COMPLETE — "It Scales" +**Next Action**: Merge `v0.5.0` branch to main, create PR. Then begin v1.0.0 planning. --- ## Current State -- **v0.4 COMPLETE + post-v0.4 polish.** React frontend with 3D Decision Space, real-time WebSocket streaming, thread browser, preferences. Markdown rendering + light/dark mode added post-v0.4. -- **5 providers shipping**: Anthropic (3 models), OpenAI (3 models), Google (4 models), Mistral (4 models) — 14 total. -- **1318 Python tests + 117 Vitest tests** (1435 total), ruff clean, mypy strict clean. -- **50 Python source files + 66 frontend source files** (116 total). +- **v0.5 COMPLETE on branch `v0.5.0`.** All 18 tasks done. Ready to merge to main. +- **6 providers shipping**: Anthropic (3 models), OpenAI (3 models), Google (4 models), Mistral (4 models), Perplexity (3 models) — 17 total. 
+- **1354 Python unit/load tests + 117 Vitest tests** (1471 total), ruff clean. +- **~60 Python source files + 66 frontend source files** (~126 total). - REST API, WebSocket streaming, MCP server, Python client library, web UI all built. -- CLI commands: `duh ask`, `duh recall`, `duh threads`, `duh show`, `duh models`, `duh cost`, `duh serve`, `duh mcp`, `duh batch`, `duh export`, `duh feedback`. +- Multi-user auth (JWT + RBAC), PostgreSQL support, Prometheus metrics, backup/restore, Playwright E2E. +- CLI commands: `duh ask`, `duh recall`, `duh threads`, `duh show`, `duh models`, `duh cost`, `duh serve`, `duh mcp`, `duh batch`, `duh export`, `duh feedback`, `duh backup`, `duh restore`, `duh user-create`, `duh user-list`. +- Docs: production-deployment.md, monitoring.md, authentication.md added. - MkDocs docs site: https://msitarzewski.github.io/duh/ - GitHub repo: https://github.com/msitarzewski/duh -- Branch: `v0.3.0` (v0.4 changes uncommitted on top) -## v0.4 Summary - -### Frontend (web/) -- React 19 + Vite 6 + Tailwind 4 + TypeScript -- Three.js 3D Decision Space (R3F + drei, lazy-loaded, 873KB chunk) -- Zustand stores (consensus, threads, decision-space, preferences) -- Glassmorphism design system with 22 CSS custom properties (dark/light mode via `prefers-color-scheme`) -- Markdown rendering: react-markdown + remark-gfm + rehype-highlight + mermaid (lazy-loaded) -- Pages: Consensus, Threads, Thread Detail, Decision Space, Preferences, Share -- WebSocket-driven real-time consensus streaming -- Mobile-responsive with 2D SVG scatter fallback for Decision Space -- Page transitions, micro-interactions, ConfidenceMeter animation -- 117 Vitest tests (5 test files) - -### Backend additions -- `GET /api/decisions/space` — filtered decision data for 3D visualization -- `GET /api/share/{token}` — public share link (no auth) -- FastAPI static file serving with SPA fallback for web UI -- `duh serve` logs "Web UI: http://host:port" when dist/ exists - -### Documentation 
-- `docs/web-ui.md` — full web UI reference -- `docs/web-quickstart.md` — getting started guide -- Updated mkdocs.yml nav and docs/index.md - -### Docker -- Multi-stage build with Node.js 22 frontend stage -- Default CMD: `serve --host 0.0.0.0 --port 8080` -- EXPOSE 8080 - -## v0.4 Architecture (Decided) - -- **React embedded in FastAPI** — Vite builds to `web/dist/`, FastAPI mounts as static files with SPA fallback -- **Three.js code-split** — Scene3D lazy-loaded via React.lazy (873KB chunk) -- **Mermaid code-split** — lazy `import('mermaid')` only when ```mermaid blocks exist (498KB chunk) -- **Main bundle** — 617KB (includes react-markdown + highlight.js + remark-gfm) -- **Zustand for state** — 4 stores (consensus, threads, decision-space, preferences), preferences persisted via localStorage -- **CSS-only animations** — no framer-motion or JS animation libraries -- **Light/dark mode** — `prefers-color-scheme` media query + `.theme-dark`/`.theme-light` manual override classes -- **Theme system** — 22 CSS custom properties in `duh-theme.css`, `.duh-prose` typography in `animations.css` -- **WebSocket events** — phase-level streaming (propose_start, propose_content, challenge_start, etc.) -- **API proxy in dev** — Vite proxies /api and /ws to :8080 for development +## v0.5 Delivered + +**Theme**: Production hardening, multi-user, enterprise readiness. +**18 tasks across 7 phases** — all complete. 
+ +### What Shipped +- User accounts + JWT auth + RBAC (admin/contributor/viewer) — `api/auth.py`, `api/rbac.py`, `models.py:User` +- PostgreSQL support (asyncpg) with connection pooling (`pool_pre_ping`, compound indexes) +- Perplexity provider adapter (6th provider, search-grounded) — `providers/perplexity.py` +- Prometheus metrics (`/api/metrics`) + extended health checks (`/api/health/detailed`) +- Backup/restore CLI (`duh backup`, `duh restore`) with SQLite copy + JSON export/import +- Playwright E2E browser tests (`web/e2e/`) +- Per-user + per-provider rate limiting (middleware keys by user_id > api_key > IP) +- Production deployment documentation (3 new guides) +- 26 multi-user integration tests + 12 load tests (latency, concurrency, rate limiting) +- Alembic migration `005_v05_users.py` (users table, user_id FKs on threads/decisions/api_keys) + +### New Source Files (v0.5) +- `src/duh/api/auth.py` — JWT authentication endpoints +- `src/duh/api/rbac.py` — Role-based access control +- `src/duh/api/metrics.py` — Prometheus metrics endpoint +- `src/duh/api/health.py` — Extended health checks +- `src/duh/memory/backup.py` — Backup/restore utilities +- `src/duh/providers/perplexity.py` — Perplexity provider adapter +- `alembic/versions/005_v05_users.py` — User migration +- `docs/guides/production-deployment.md`, `authentication.md`, `monitoring.md` ## Open Questions (Still Unresolved) @@ -68,5 +53,5 @@ - Vector search solution for SQLite (sqlite-vss vs ChromaDB vs FAISS) — v1.0 decision - Client library packaging: monorepo `client/` dir vs separate repo? - MCP server transport: stdio vs SSE vs streamable HTTP? 
-- Hosted demo economics (try.duh.dev) — deferred -- Playwright E2E tests — deferred to v0.5 +- Hosted demo economics (try.duh.dev) — deferred to post-1.0 +- A2A protocol — deferred to post-1.0 diff --git a/memory-bank/decisions.md b/memory-bank/decisions.md index b0e6752..9d49d9d 100644 --- a/memory-bank/decisions.md +++ b/memory-bank/decisions.md @@ -310,3 +310,17 @@ - Bundling mermaid eagerly (bloats main bundle from 278KB to 1.1MB) **Consequences**: Main bundle: 617KB (up from 278KB — react-markdown + highlight.js needed on all pages). Mermaid: 498KB lazy chunk only when mermaid blocks exist. Full GFM support (tables, task lists, strikethrough). Code syntax highlighting in 180+ languages. 5 components updated to use `` for LLM content. **References**: `web/src/components/shared/Markdown.tsx`, used in ConsensusComplete, PhaseCard, TurnCard, DissentBanner + +--- + +## 2026-02-17: create_all Only for In-Memory SQLite + +**Status**: Approved +**Context**: `_create_db()` in `cli/app.py` called `Base.metadata.create_all()` unconditionally. This conflicts with alembic migrations for file-based SQLite and PostgreSQL: `create_all` creates tables from current models (bypassing alembic version tracking) but cannot add columns to existing tables. When the v0.5 migration added `user_id` to `threads`, `decisions`, and `api_keys`, the `users` table was already created by `create_all` but the FK columns were missing — causing `OperationalError: no such column: api_keys.user_id` at runtime. +**Decision**: Only call `create_all` when the database URL contains `:memory:` (in-memory SQLite used by tests and dev). File-based SQLite and PostgreSQL rely exclusively on alembic migrations for schema management. 
+**Alternatives**: +- Keep `create_all` with `checkfirst=True` (default) — doesn't help, `create_all` can't alter existing tables +- Run alembic migrations programmatically at startup — adds complexity, conflates app startup with migration +- Remove `create_all` entirely — breaks in-memory test fixtures that don't run alembic +**Consequences**: Tests continue to work (in-memory SQLite still uses `create_all`). Production databases must run `alembic upgrade head` after code updates. This was already the expected workflow but is now enforced. +**References**: `src/duh/cli/app.py:101-104` diff --git a/memory-bank/progress.md b/memory-bank/progress.md index 9930bfe..2766f65 100644 --- a/memory-bank/progress.md +++ b/memory-bank/progress.md @@ -4,9 +4,27 @@ --- -## Current State: v0.4 COMPLETE — Web UI with 3D Decision Space - -### v0.4 Additions +## Current State: v0.5 COMPLETE — Production Hardening & Multi-User + +### v0.5 Additions + +- User accounts: User ORM model, JWT auth (bcrypt + PyJWT), RBAC (admin/contributor/viewer) +- PostgreSQL support: asyncpg driver, configurable connection pooling (pool_size, max_overflow, pool_pre_ping) +- Perplexity provider: 6th cloud provider (sonar, sonar-pro, sonar-deep-research), citation parsing +- Prometheus metrics: `/api/metrics` endpoint with counters, histograms, gauges (no external deps) +- Extended health checks: `/api/health/detailed` with DB connectivity, provider health, uptime, version +- Backup/restore: `duh backup` (SQLite copy or JSON export), `duh restore` (with `--merge` mode) +- Per-user rate limiting: middleware keys by user_id > api_key > IP, per-provider RPM limits in config +- Compound indexes: `(thread_id, created_at)` on decisions, `(category, genus)` on decisions, `(turn_id, role)` on contributions +- Playwright E2E tests: navigation, consensus form, decision space, preferences +- 26 multi-user integration tests: user isolation, admin access, registration flow, RBAC, JWT validation, deactivation +- 12 
load tests: p50/p95/p99 latency, concurrent requests (10/50/100), rate limiting under load, sustained throughput +- Alembic migration `005_v05_users.py`: users table, nullable user_id FK on threads/decisions/api_keys +- 3 new docs: production-deployment.md, authentication.md, monitoring.md +- Version 0.5.0 across pyproject.toml, __init__.py, api/app.py +- 1354 Python tests + 117 Vitest tests (1471 total), ruff clean + +### v0.4 Additions (Previously Shipped) - React 19 + Vite 6 + Tailwind 4 + TypeScript frontend (66 source files) - 3D Decision Space: Three.js point cloud (R3F + drei), lazy-loaded, code-split (873KB) @@ -20,7 +38,6 @@ - Backend: /api/decisions/space endpoint, /api/share/{token}, static file serving + SPA fallback - Docker: multi-stage build with Node.js 22 frontend stage - Docs: web-ui.md, web-quickstart.md, updated mkdocs.yml -- Version 0.4.0 across pyproject.toml, __init__.py, api/app.py ### v0.3 Additions (Previously Shipped) @@ -125,3 +142,11 @@ Phase 0 benchmark framework — fully functional, pilot-tested on 5 questions. 
| 2026-02-17 | v0.4 MkDocs documentation (web-ui.md, web-quickstart.md) | Done | | 2026-02-17 | v0.4 Version bump to 0.4.0 | Done | | 2026-02-17 | v0.4.0 — "It Has a Face" | **Complete** | +| 2026-02-17 | v0.5 T1-T3 (Phase 1: DB & Multi-User) — User model + migration, JWT auth, RBAC | Done | +| 2026-02-17 | v0.5 T4-T5 (Phase 2: PostgreSQL) — asyncpg support, connection pooling + indexes | Done | +| 2026-02-17 | v0.5 T6-T8 (Phase 3: Rate Limiting & Monitoring) — per-user rate limits, Prometheus metrics, health checks | Done | +| 2026-02-17 | v0.5 T9 (Phase 4: Perplexity) — Perplexity provider adapter (sonar, sonar-pro, sonar-deep-research) | Done | +| 2026-02-17 | v0.5 T10-T11 (Phase 5: Backup/Restore) — backup CLI, restore CLI with merge mode | Done | +| 2026-02-17 | v0.5 T12-T13 (Phase 6: Playwright) — E2E setup + core flows, extended tests | Done | +| 2026-02-17 | v0.5 T14-T18 (Phase 7: Ship) — multi-user integration tests, load tests, docs, migration finalized, version bump | Done | +| 2026-02-17 | v0.5.0 — "It Scales" | **Complete** | diff --git a/memory-bank/quick-start.md b/memory-bank/quick-start.md index 9e10405..5f1eabd 100644 --- a/memory-bank/quick-start.md +++ b/memory-bank/quick-start.md @@ -6,19 +6,19 @@ ## Where We Are -**v0.4 COMPLETE** — "It Has a Face". Web UI with 3D Decision Space, real-time streaming, thread browser. +**v0.5 COMPLETE** — "It Scales". Multi-user auth, PostgreSQL, production hardening. 
-- 1318 Python tests + 117 Vitest tests (1435 total), 50 Python + 66 frontend source files -- 5 providers (Anthropic, OpenAI, Google, Mistral, local via Ollama) — 14 models -- Version 0.4.0, branch `v0.3.0` (v0.4 changes uncommitted on top) +- 1354 Python tests + 12 load tests + 117 Vitest tests (1483 total), ~60 Python + 66 frontend source files +- 6 providers (Anthropic, OpenAI, Google, Mistral, Perplexity, local via Ollama) — 17 models +- Version 0.5.0, branch `v0.5.0`, ready to merge to main - MkDocs docs live at https://msitarzewski.github.io/duh/ - GitHub repo: https://github.com/msitarzewski/duh ## Starting a Session Load these files: -1. `activeContext.md` — current state, v0.4 summary, open questions -2. `roadmap.md:330+` — future version specs +1. `activeContext.md` — current state, v0.5 complete, open questions +2. `roadmap.md:513+` — v1.0 spec (next version) 3. `techContext.md` — tech stack + all decided patterns (Python + frontend) 4. `decisions.md` — 18 ADRs, all foundational + v0.2 + v0.3 + v0.4 decisions @@ -42,6 +42,10 @@ duh serve --reload # dev mode with hot reload duh mcp # start MCP server duh batch questions.txt # batch mode duh export --format json # export thread +duh backup /path/to/backup.json # backup database +duh restore /path/to/backup.json # restore database +duh user-create # create user account +duh user-list # list users # Backend Development uv sync # install deps @@ -72,14 +76,17 @@ docker compose up # full stack on :8080 | Decomposition | `src/duh/consensus/decompose.py`, `src/duh/consensus/scheduler.py`, `src/duh/consensus/synthesis.py` | | Tools | `src/duh/tools/base.py`, `src/duh/tools/registry.py`, `src/duh/tools/augmented_send.py` | | Tool impls | `src/duh/tools/web_search.py`, `src/duh/tools/code_exec.py`, `src/duh/tools/file_read.py` | -| Providers | `src/duh/providers/base.py`, `anthropic.py`, `openai.py`, `google.py`, `mistral.py` | +| Providers | `src/duh/providers/base.py`, `anthropic.py`, `openai.py`, `google.py`, 
`mistral.py`, `perplexity.py` | | Memory | `src/duh/memory/models.py`, `repository.py`, `context.py`, `summary.py` | | Config | `src/duh/config/schema.py`, `src/duh/config/loader.py` | | Core | `src/duh/core/errors.py`, `src/duh/core/retry.py` | | REST API | `src/duh/api/app.py`, `src/duh/api/middleware.py`, `src/duh/api/routes/` | +| Auth/RBAC | `src/duh/api/auth.py`, `src/duh/api/rbac.py` | +| Monitoring | `src/duh/api/metrics.py`, `src/duh/api/health.py` | +| Backup | `src/duh/memory/backup.py` | | MCP Server | `src/duh/mcp/server.py` | | Client | `client/src/duh_client/client.py` | -| Migrations | `alembic/versions/001_v01_baseline.py` through `004_v03_api_keys.py` | +| Migrations | `alembic/versions/001_v01_baseline.py` through `005_v05_users.py` | | Frontend theme | `web/src/theme/duh-theme.css` (22 CSS vars, dark/light), `web/src/theme/animations.css` (keyframes + `.duh-prose`) | | Markdown | `web/src/components/shared/Markdown.tsx` (react-markdown + highlight.js + mermaid lazy) | | Frontend API | `web/src/api/client.ts`, `web/src/api/websocket.ts`, `web/src/api/types.ts` | diff --git a/memory-bank/roadmap.md b/memory-bank/roadmap.md index 1248999..6db0a06 100644 --- a/memory-bank/roadmap.md +++ b/memory-bank/roadmap.md @@ -1,6 +1,6 @@ # duh Roadmap -**Version**: 1.4 +**Version**: 1.5 **Date**: 2026-02-17 **Status**: Draft for review **Synthesized from**: Product Strategy, Systems Architecture, Devil's Advocate Review, Competitive Research Analysis @@ -108,7 +108,7 @@ The devil's advocate correctly identified this as the existential risk: "If this | **0.2.0** | It Thinks Deeper | Task decomposition, outcome tracking | 4-6 days | **COMPLETE** | | **0.3.0** | It's Accessible | REST API, MCP server, Python client | 4-6 days | **COMPLETE** | | **0.4.0** | It Has a Face | Web UI with real-time consensus display | 6-10 days | **COMPLETE** | -| **0.5.0** | It Scales | Multi-user, PostgreSQL, production hardening | 4-6 days | | +| **0.5.0** | It Scales | Multi-user, 
PostgreSQL, production hardening | 4-6 days | **COMPLETE** | | **1.0.0** | duh. | Stable APIs, documentation, security audit | 5-8 days | | **Total AI-time**: ~30-46 days of autonomous execution (not calendar days — depends on session frequency and human review cadence) @@ -398,7 +398,7 @@ All -> 24 (Docker) -> 25 (docs) - [x] Decision Space is interactive: click nodes, filter by taxonomy, animate timeline - [x] Share links work without authentication (read-only) - [x] `docker compose up` serves web UI + API -- [ ] `try.duh.dev` live and rate-limited (deferred) +- [ ] `try.duh.dev` live and rate-limited (deferred to post-1.0) #### Tasks @@ -419,9 +419,9 @@ All -> 24 (Docker) -> 25 (docs) 15. Docker multi-stage with Node.js 22 frontend build ~~DONE~~ 16. Docker Compose update ~~DONE~~ 17. 117 Vitest unit tests (5 test files) ~~DONE~~ -18. Playwright E2E tests (deferred to v0.5) +18. Playwright E2E tests (deferred to v0.5) ~~DEFERRED~~ 19. MkDocs documentation (web-ui.md, web-quickstart.md) ~~DONE~~ -20. Hosted demo (try.duh.dev) (deferred) +20. Hosted demo (try.duh.dev) (deferred to post-1.0) ~~DEFERRED~~ 21. Version bump to 0.4.0 ~~DONE~~ > **v0.4.0 shipped 2026-02-17.** 1318 Python tests + 117 Vitest tests (1435 total), 50 Python + 66 frontend source files. React 19 + Vite 6 + Tailwind 4 + Three.js web UI with 3D Decision Space, real-time WebSocket streaming, thread browser, preferences. Hosted demo deferred. 
@@ -442,33 +442,73 @@ All -> 24 (Docker) -> 25 (docs) - **Rate limiting per user and per provider** - **Health checks, Prometheus metrics endpoint** - **Backup/restore utilities** -- **Cohere adapter**: Fifth cloud provider -- **A2A protocol support** (agent-to-agent) +- **Perplexity adapter**: Fifth cloud provider (search-grounded reasoning) +- **Playwright E2E tests**: Deferred from v0.4, full browser testing for web UI #### Acceptance Criteria -- [ ] Multi-user authentication works (local accounts) -- [ ] PostgreSQL deployment documented and tested -- [ ] Performance: consensus overhead < 500ms beyond model latency -- [ ] Metrics and monitoring operational -- [ ] Backup/restore tested and documented +- [x] Multi-user authentication works (local accounts) +- [x] PostgreSQL deployment documented and tested +- [x] Performance: consensus overhead < 500ms beyond model latency +- [x] Metrics and monitoring operational +- [x] Backup/restore tested and documented +- [x] Playwright E2E tests cover core web UI flows -#### Tasks +#### Tasks (7 Phases, 18 Tasks) + +**Phase 1: Database & Multi-User Foundation (T1-T3)** + +1. **User model + migration (`005_v05_users.py`)** ~~DONE~~: `User` ORM model (id, email, password_hash, display_name, role [admin/contributor/viewer], created_at, is_active). Add nullable `user_id` FK to Thread, Decision, APIKey. Extend `src/duh/memory/models.py`. Unit tests: model creation, relationships, constraints. +2. **Authentication system (JWT)** ~~DONE~~: `src/duh/api/auth.py` — bcrypt password hashing, JWT creation/validation. Endpoints: `POST /api/auth/register`, `POST /api/auth/login`, `GET /api/auth/me`. Extend `APIKeyMiddleware` to accept `Authorization: Bearer <token>`. CLI: `duh user-create`, `duh user-list`. Config: `AuthConfig` in schema.py (jwt_secret, token_expiry, registration_enabled). Unit tests: hashing, JWT, endpoints, middleware. +3. 
**Role-based access control** ~~DONE~~: `src/duh/api/rbac.py` — FastAPI dependency checking `user.role` against required permission. Roles: admin (full), contributor (create/view), viewer (read-only). Admin-only: user management, backup/restore, API key management. Unit tests: role checks, permission denied. + +**Phase 2: PostgreSQL & Performance (T4-T5)** + +4. **PostgreSQL support + async driver** ~~DONE~~: Add `asyncpg` dependency. `DatabaseConfig`: pool_size, max_overflow, pool_timeout. Test all migrations against PostgreSQL (Docker). Alembic env.py: handle aiosqlite + asyncpg. Integration tests: full CRUD against PostgreSQL. +5. **Connection pooling + query optimization** ~~DONE~~: Pool config (pool_size=5, max_overflow=10 for pg, NullPool for sqlite). Add compound indexes for common query patterns. Review N+1 queries — ensure selectinload coverage. Unit tests: pool config, index checks. + +**Phase 3: Rate Limiting & Monitoring (T6-T8)** + +6. **Per-user + per-provider rate limiting** ~~DONE~~: Extend `RateLimitMiddleware` to key by user_id. Add provider-level rpm in `ProviderManager`. Config: per-provider `rate_limit` in `ProviderConfig`. Unit tests. +7. **Prometheus metrics endpoint** ~~DONE~~: `src/duh/api/metrics.py` — `/api/metrics` (Prometheus text format). Counters: requests_total, consensus_runs_total, tokens_total, errors_total. Histograms: request_duration_seconds, consensus_duration_seconds. Gauges: active_connections, provider_health. Lightweight in-process counters (no prometheus_client dep). Unit tests. +8. **Extended health checks** ~~DONE~~: Extend `/api/health` with db connectivity, provider health, uptime, version. Add `GET /api/health/detailed`. Unit tests: healthy/unhealthy scenarios. + +**Phase 4: Perplexity Provider (T9)** + +9. **Perplexity provider adapter** ~~DONE~~: `src/duh/providers/perplexity.py` — OpenAI-compatible API pattern. Models: sonar, sonar-pro, sonar-deep-research. 
Parse `citations` from responses into contribution metadata. Add `perplexity` default in DuhConfig.providers, `PERPLEXITY_API_KEY` env var. Unit tests: send, stream, health, citations, errors (mocked). + +**Phase 5: Backup/Restore (T10-T11)** + +10. **Backup CLI** ~~DONE~~: `duh backup <path>` — SQLite: copy file; PostgreSQL: JSON export via SQLAlchemy. Portable JSON format. `--format sqlite|json`. Unit tests: backup, export, round-trip. +11. **Restore CLI** ~~DONE~~: `duh restore <path>` — detect format, validate schema version, restore. `--merge` for additive restore. Unit tests: restore JSON, restore SQLite, merge mode. + +**Phase 6: Playwright E2E Tests (T12-T13)** + +12. **Playwright setup + core flows** ~~DONE~~: `web/e2e/`, playwright.config.ts. Fixtures: `duh serve` with test DB + mock providers. Tests (~15-20): consensus flow, thread browser, navigation/routing. +13. **Playwright extended tests** ~~DONE~~ (~10-15): Decision Space (render, filters, 2D fallback), Preferences (persistence), share links, error states. + +**Phase 7: Integration, Docs, Ship (T14-T18)** + +14. **Multi-user integration tests** ~~DONE~~: User isolation (A can't see B's threads), admin/viewer permissions, per-user rate limiting, PostgreSQL multi-user round-trip. +15. **Load testing** ~~DONE~~: httpx + asyncio concurrent requests. Measure p50/p95/p99 latency, error rate. Target: consensus overhead < 500ms beyond model latency. Document results. +16. **Documentation** ~~DONE~~: `docs/production-deployment.md` (PostgreSQL, env vars, Docker prod), `docs/monitoring.md` (Prometheus, health, alerting), `docs/authentication.md` (users, JWT, API keys, RBAC). Update index.md + mkdocs.yml. +17. **Migration finalization** ~~DONE~~: Ensure `005_v05_users.py` handles SQLite + PostgreSQL. Test upgrade path v0.4 → v0.5. +18. **Version bump to 0.5.0** ~~DONE~~: pyproject.toml, `__init__.py`, api/app.py. Update memory bank. 
+ +#### Task Dependencies + +``` +T1 (User model) → T2 (Auth) → T3 (RBAC) +T1 → T4 (PostgreSQL) → T5 (Pooling) +T2 → T6 (Rate limiting per user) +T7 (Metrics), T8 (Health), T9 (Perplexity) — independent, parallelizable +T10 (Backup) → T11 (Restore) — independent of auth +T1-T3 done → T12-T13 (Playwright needs auth flows) +T1-T11 done → T14 (Integration tests) +All → T15-T18 (Ship) +``` -1. User model, authentication (session-based or JWT) -2. Role-based access control middleware -3. PostgreSQL deployment guide and testing -4. Connection pooling and query optimization -5. Per-user and per-provider rate limiting -6. Prometheus metrics exporter -7. Health check endpoints -8. Backup/restore CLI commands -9. Cohere provider adapter -10. A2A protocol integration -11. Unit tests for auth, RBAC, rate limiting, backup/restore, metrics, Cohere adapter, A2A protocol -12. Integration tests: multi-user isolation, PostgreSQL round-trip, rate limit enforcement, backup/restore cycle -13. Load testing -14. Documentation: production deployment guide, monitoring guide +> **v0.5.0 shipped 2026-02-17.** 1354 Python tests + 117 Vitest tests (1471 total), ~60 Python source files + 66 frontend. 6 providers (Anthropic, OpenAI, Google, Mistral, Perplexity, Ollama). User accounts with JWT auth + RBAC, PostgreSQL support, Prometheus metrics, backup/restore, Playwright E2E, per-user rate limiting, production deployment docs. All 18 tasks delivered. 
--- @@ -505,6 +545,8 @@ These features were deliberately deferred from 1.0 per the devil's advocate's ch | Feature | Description | AI-Time | |---------|-------------|---------| +| **A2A Protocol Support** | Agent-to-agent protocol integration for inter-agent consensus | 3-5 days | +| **Hosted Demo** | `try.duh.dev` — free, rate-limited, pre-configured demo instance | 2-3 days | | **Federated Knowledge Sharing** | Navigator protocol, peer-to-peer decision sharing, privacy controls, trust signals | 10-15 days | | **Browsable Knowledge Base** | Web interface over accumulated decisions, extends 3D Decision Space with full-text search and quality indicators | 6-10 days | | **Fact-Checking Mode** | Structured claim decomposition, multi-model verification, citation tracking | 5-8 days | @@ -868,6 +910,7 @@ Y = yes, N = no, ~ = partial, * = research only, not product | 2026-02-16 | 1.2 | Added tool-augmented consensus (v0.2) and decision taxonomy (v0.2). Added 3D Decision Space visualization (v0.4). Updated competitive gap analysis. | | 2026-02-16 | 1.3 | v0.2.0 complete. All features shipped. Updated acceptance criteria and task status. Added Status column to overview table. | | 2026-02-17 | 1.4 | v0.3.0 and v0.4.0 complete. All features shipped. Updated acceptance criteria, task status, and completion notes for both versions. | +| 2026-02-17 | 1.5 | v0.5.0 complete. All 18 tasks shipped. 6 providers, multi-user auth, PostgreSQL, metrics, backup/restore, Playwright E2E, load tests, production docs. | --- diff --git a/memory-bank/toc.md b/memory-bank/toc.md index d98df3b..7a9f592 100644 --- a/memory-bank/toc.md +++ b/memory-bank/toc.md @@ -4,17 +4,17 @@ - [projectbrief.md](./projectbrief.md) — Vision, tenets, architecture, build sequence - [techContext.md](./techContext.md) — Tech stack decisions with rationale (Python, Docker, SQLAlchemy, frontend, tools, etc.) 
- [decisions.md](./decisions.md) — Architectural decisions with context, alternatives, and consequences (18 ADRs) -- [activeContext.md](./activeContext.md) — Current state, v0.4 complete, ready for commit/merge +- [activeContext.md](./activeContext.md) — Current state, v0.5 complete, ready to merge to main - [progress.md](./progress.md) — Milestone tracking, what's built, what's next - [competitive-landscape.md](./competitive-landscape.md) — Research on existing tools, frameworks, and academic work -- [quick-start.md](./quick-start.md) — Session entry point, v0.4 complete, key file references +- [quick-start.md](./quick-start.md) — Session entry point, v0.5 complete, key file references - [v03-build-status.md](./v03-build-status.md) — v0.3 build status, all 17 tasks complete ## Context Files - [setup.md](./setup.md) — Original GPT-5.2 conversation that sparked the project ## Roadmap -- [roadmap.md](./roadmap.md) — Full versioned roadmap v1.4 (Phase 0 through 1.0, AI-time estimates, testing mandate, self-building milestone) +- [roadmap.md](./roadmap.md) — Full versioned roadmap v1.5 (Phase 0 through 1.0, AI-time estimates, testing mandate, self-building milestone) ## Agent Team Analyses (Supporting Material) - [tmp-product-strategy.md](./tmp-product-strategy.md) — Product strategist: releases, go-to-market, pricing diff --git a/mkdocs.yml b/mkdocs.yml index fb3601c..9443a9c 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -82,6 +82,9 @@ nav: - MCP Server: mcp-server.md - Guides: - Docker: guides/docker.md + - Production Deployment: guides/production-deployment.md + - Authentication: guides/authentication.md + - Monitoring: guides/monitoring.md - Local Models: guides/local-models.md - Batch Mode: batch-mode.md - Export: export.md diff --git a/pyproject.toml b/pyproject.toml index f378798..72849e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "duh" -version = "0.4.0" +version = "0.5.0" description = "Multi-model consensus engine — 
because one LLM opinion isn't enough" requires-python = ">=3.11" dependencies = [ @@ -20,6 +20,9 @@ dependencies = [ "uvicorn[standard]>=0.30", "mcp>=1.0", "httpx>=0.27", + "bcrypt>=4.0", + "pyjwt>=2.8", + "asyncpg>=0.29", ] [project.scripts] @@ -55,6 +58,9 @@ addopts = [ "--strict-config", "-ra", ] +markers = [ + "load: load/stress tests (deselect with '-m \"not load\"')", +] [tool.coverage.run] source = ["duh"] diff --git a/src/duh/__init__.py b/src/duh/__init__.py index 6fa640d..7eb2af6 100644 --- a/src/duh/__init__.py +++ b/src/duh/__init__.py @@ -1,3 +1,3 @@ """duh — Multi-model consensus engine.""" -__version__ = "0.4.0" +__version__ = "0.5.0" diff --git a/src/duh/api/app.py b/src/duh/api/app.py index c13bfdc..1549fd2 100644 --- a/src/duh/api/app.py +++ b/src/duh/api/app.py @@ -41,7 +41,7 @@ def create_app(config: DuhConfig | None = None) -> FastAPI: app = FastAPI( title="duh", description="Multi-model consensus engine API", - version="0.4.0", + version="0.5.0", lifespan=lifespan, ) app.state.config = config @@ -70,11 +70,6 @@ def create_app(config: DuhConfig | None = None) -> FastAPI: # API key auth (added last — runs first) app.add_middleware(APIKeyMiddleware) - # Health endpoint - @app.get("/api/health") - async def health() -> dict[str, str]: - return {"status": "ok"} - # Routes from duh.api.routes.ask import router as ask_router from duh.api.routes.crud import router as crud_router @@ -86,6 +81,14 @@ async def health() -> dict[str, str]: app.include_router(threads_router) app.include_router(ws_router) + from duh.api.auth import router as auth_router + from duh.api.health import router as health_router + from duh.api.metrics import router as metrics_router + + app.include_router(auth_router) + app.include_router(health_router) + app.include_router(metrics_router) + # ── Static file serving for web UI ── _mount_frontend(app) diff --git a/src/duh/api/auth.py b/src/duh/api/auth.py new file mode 100644 index 0000000..870e177 --- /dev/null +++ 
b/src/duh/api/auth.py @@ -0,0 +1,190 @@ +"""JWT authentication for the duh API.""" + +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from typing import Any + +import bcrypt +import jwt +from fastapi import APIRouter, Depends, HTTPException, Request +from pydantic import BaseModel + +router = APIRouter(prefix="/api/auth", tags=["auth"]) + +# --- Password hashing --- + + +def hash_password(password: str) -> str: + """Hash password with bcrypt.""" + return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode() + + +def verify_password(password: str, password_hash: str) -> bool: + """Verify password against hash.""" + return bcrypt.checkpw(password.encode(), password_hash.encode()) + + +# --- JWT --- + + +def create_token(user_id: str, secret: str, expiry_hours: int = 24) -> str: + """Create a JWT token.""" + payload = { + "sub": user_id, + "exp": datetime.now(UTC) + timedelta(hours=expiry_hours), + "iat": datetime.now(UTC), + } + return jwt.encode(payload, secret, algorithm="HS256") + + +def decode_token(token: str, secret: str) -> dict[str, Any]: + """Decode and validate a JWT token.""" + try: + return jwt.decode(token, secret, algorithms=["HS256"]) + except jwt.ExpiredSignatureError as err: + raise HTTPException(status_code=401, detail="Token expired") from err + except jwt.InvalidTokenError as err: + raise HTTPException(status_code=401, detail="Invalid token") from err + + +# --- Request models --- + + +class RegisterRequest(BaseModel): + email: str + password: str + display_name: str + + +class LoginRequest(BaseModel): + email: str + password: str + + +class TokenResponse(BaseModel): + access_token: str + token_type: str = "bearer" + user_id: str + role: str + + +class UserResponse(BaseModel): + id: str + email: str + display_name: str + role: str + is_active: bool + + +# --- Dependency: get current user from JWT --- + + +async def get_current_user(request: Request) -> Any: + """FastAPI dependency: extract user from JWT 
Bearer token.""" + auth_header = request.headers.get("Authorization") + if not auth_header or not auth_header.startswith("Bearer "): + raise HTTPException( + status_code=401, detail="Missing or invalid Authorization header" + ) + + token = auth_header.split(" ", 1)[1] + config = request.app.state.config + payload = decode_token(token, config.auth.jwt_secret) + user_id = payload.get("sub") + + from sqlalchemy import select + + from duh.memory.models import User + + db_factory = request.app.state.db_factory + async with db_factory() as session: + stmt = select(User).where(User.id == user_id, User.is_active == True) # noqa: E712 + result = await session.execute(stmt) + user = result.scalar_one_or_none() + + if user is None: + raise HTTPException(status_code=401, detail="User not found or inactive") + + return user + + +# --- Endpoints --- + + +@router.post("/register", response_model=TokenResponse) +async def register(body: RegisterRequest, request: Request) -> TokenResponse: + """Register a new user.""" + config = request.app.state.config + if not config.auth.registration_enabled: + raise HTTPException(status_code=403, detail="Registration is disabled") + + if not config.auth.jwt_secret: + raise HTTPException(status_code=500, detail="JWT secret not configured") + + from sqlalchemy import select + + from duh.memory.models import User + + db_factory = request.app.state.db_factory + async with db_factory() as session: + # Check email uniqueness + stmt = select(User).where(User.email == body.email) + result = await session.execute(stmt) + if result.scalar_one_or_none() is not None: + raise HTTPException(status_code=409, detail="Email already registered") + + user = User( + email=body.email, + password_hash=hash_password(body.password), + display_name=body.display_name, + ) + session.add(user) + await session.commit() + await session.refresh(user) + + token = create_token( + user.id, config.auth.jwt_secret, config.auth.token_expiry_hours + ) + return 
TokenResponse(access_token=token, user_id=user.id, role=user.role) + + +@router.post("/login", response_model=TokenResponse) +async def login(body: LoginRequest, request: Request) -> TokenResponse: + """Authenticate and get token.""" + config = request.app.state.config + if not config.auth.jwt_secret: + raise HTTPException(status_code=500, detail="JWT secret not configured") + + from sqlalchemy import select + + from duh.memory.models import User + + db_factory = request.app.state.db_factory + async with db_factory() as session: + stmt = select(User).where(User.email == body.email) + result = await session.execute(stmt) + user = result.scalar_one_or_none() + + if user is None or not verify_password(body.password, user.password_hash): + raise HTTPException(status_code=401, detail="Invalid credentials") + + if not user.is_active: + raise HTTPException(status_code=403, detail="Account disabled") + + token = create_token( + user.id, config.auth.jwt_secret, config.auth.token_expiry_hours + ) + return TokenResponse(access_token=token, user_id=user.id, role=user.role) + + +@router.get("/me", response_model=UserResponse) +async def me(user: Any = Depends(get_current_user)) -> UserResponse: # noqa: B008 + """Get current user info.""" + return UserResponse( + id=user.id, + email=user.email, + display_name=user.display_name, + role=user.role, + is_active=user.is_active, + ) diff --git a/src/duh/api/health.py b/src/duh/api/health.py new file mode 100644 index 0000000..0e7f450 --- /dev/null +++ b/src/duh/api/health.py @@ -0,0 +1,63 @@ +"""Health check endpoints.""" + +from __future__ import annotations + +import time +from typing import Any + +from fastapi import APIRouter, Request + +router = APIRouter(tags=["health"]) + +_START_TIME = time.monotonic() + + +@router.get("/api/health") +async def health() -> dict[str, str]: + """Basic health check -- always returns quickly.""" + return {"status": "ok"} + + +@router.get("/api/health/detailed") +async def health_detailed(request: 
Request) -> dict[str, Any]: + """Detailed health check with component status.""" + from duh import __version__ + + checks: dict[str, Any] = { + "status": "ok", + "version": __version__, + "uptime_seconds": round(time.monotonic() - _START_TIME, 1), + "components": {}, + } + + # Database check + try: + db_factory = request.app.state.db_factory + async with db_factory() as session: + from sqlalchemy import text + + await session.execute(text("SELECT 1")) + checks["components"]["database"] = {"status": "ok"} + except Exception as e: + checks["components"]["database"] = {"status": "error", "detail": str(e)} + checks["status"] = "degraded" + + # Provider health checks + pm = getattr(request.app.state, "provider_manager", None) + if pm is not None: + provider_statuses: dict[str, dict[str, str]] = {} + for pid, provider in pm._providers.items(): + try: + healthy = await provider.health_check() + provider_statuses[pid] = {"status": "ok" if healthy else "unhealthy"} + except Exception: + provider_statuses[pid] = {"status": "error"} + checks["components"]["providers"] = provider_statuses + + # If all providers are unhealthy, status is degraded + if provider_statuses and all( + v["status"] != "ok" for v in provider_statuses.values() + ): + checks["status"] = "degraded" + + return checks diff --git a/src/duh/api/metrics.py b/src/duh/api/metrics.py new file mode 100644 index 0000000..a15a001 --- /dev/null +++ b/src/duh/api/metrics.py @@ -0,0 +1,227 @@ +"""Lightweight Prometheus metrics — no external dependencies.""" + +from __future__ import annotations + +import math +import threading +from typing import ClassVar + +from fastapi import APIRouter, Response + +router = APIRouter() + + +class Counter: + """Thread-safe monotonic counter.""" + + def __init__( + self, + name: str, + help_text: str, + labels: list[str] | None = None, + ) -> None: + self.name = name + self.help_text = help_text + self.labels = labels or [] + self._lock = threading.Lock() + # When labels are used, store 
per-label-combo values + self._values: dict[tuple[str, ...], float] = {} + if not self.labels: + self._values[()] = 0.0 + MetricsRegistry.get().register(self) + + def inc(self, value: float = 1.0, **label_values: str) -> None: + """Increment the counter.""" + key = tuple(label_values.get(lbl, "") for lbl in self.labels) + with self._lock: + self._values[key] = self._values.get(key, 0.0) + value + + def collect(self) -> str: + """Return Prometheus text format.""" + lines: list[str] = [ + f"# HELP {self.name} {self.help_text}", + f"# TYPE {self.name} counter", + ] + with self._lock: + for key, val in sorted(self._values.items()): + if self.labels: + label_str = ",".join( + f'{lbl}="{v}"' + for lbl, v in zip(self.labels, key, strict=True) + ) + lines.append(f"{self.name}{{{label_str}}} {_fmt(val)}") + else: + lines.append(f"{self.name} {_fmt(val)}") + return "\n".join(lines) + "\n" + + +class Histogram: + """Thread-safe histogram with predefined buckets.""" + + DEFAULT_BUCKETS: ClassVar[list[float]] = [ + 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, + ] + + def __init__( + self, + name: str, + help_text: str, + buckets: list[float] | None = None, + ) -> None: + self.name = name + self.help_text = help_text + self.buckets = sorted(buckets or self.DEFAULT_BUCKETS) + self._lock = threading.Lock() + self._bucket_counts: dict[float, int] = {b: 0 for b in self.buckets} + self._sum: float = 0.0 + self._count: int = 0 + MetricsRegistry.get().register(self) + + def observe(self, value: float) -> None: + """Record an observation.""" + with self._lock: + self._sum += value + self._count += 1 + for b in self.buckets: + if value <= b: + self._bucket_counts[b] += 1 + break + + def collect(self) -> str: + """Return Prometheus text format.""" + lines: list[str] = [ + f"# HELP {self.name} {self.help_text}", + f"# TYPE {self.name} histogram", + ] + with self._lock: + cumulative = 0 + for b in self.buckets: + cumulative += self._bucket_counts[b] + lines.append( + 
f'{self.name}_bucket{{le="{_fmt(b)}"}} {cumulative}' + ) + lines.append( + f'{self.name}_bucket{{le="+Inf"}} {self._count}' + ) + lines.append(f"{self.name}_sum {_fmt(self._sum)}") + lines.append(f"{self.name}_count {self._count}") + return "\n".join(lines) + "\n" + + +class Gauge: + """Thread-safe gauge (can go up and down).""" + + def __init__(self, name: str, help_text: str) -> None: + self.name = name + self.help_text = help_text + self._lock = threading.Lock() + self._value: float = 0.0 + MetricsRegistry.get().register(self) + + def set(self, value: float) -> None: + """Set to an absolute value.""" + with self._lock: + self._value = value + + def inc(self, value: float = 1.0) -> None: + """Increment.""" + with self._lock: + self._value += value + + def dec(self, value: float = 1.0) -> None: + """Decrement.""" + with self._lock: + self._value -= value + + def collect(self) -> str: + """Return Prometheus text format.""" + with self._lock: + val = self._value + return ( + f"# HELP {self.name} {self.help_text}\n" + f"# TYPE {self.name} gauge\n" + f"{self.name} {_fmt(val)}\n" + ) + + +class MetricsRegistry: + """Global registry of all metrics.""" + + _instance: ClassVar[MetricsRegistry | None] = None + _lock: ClassVar[threading.Lock] = threading.Lock() + + def __init__(self) -> None: + self._metrics: list[Counter | Histogram | Gauge] = [] + + @classmethod + def get(cls) -> MetricsRegistry: + """Return the singleton registry.""" + if cls._instance is None: + with cls._lock: + if cls._instance is None: + cls._instance = MetricsRegistry() + return cls._instance + + @classmethod + def reset(cls) -> None: + """Reset the singleton (for tests).""" + with cls._lock: + cls._instance = None + + def register(self, metric: Counter | Histogram | Gauge) -> None: + """Register a metric for collection.""" + self._metrics.append(metric) + + def collect_all(self) -> str: + """Return concatenated Prometheus text format for all metrics.""" + return "\n".join(m.collect() for m in 
self._metrics) + + +def _fmt(v: float) -> str: + """Format a float: use integer form when possible.""" + if math.isinf(v): + return "+Inf" + if v == int(v): + return str(int(v)) + return str(v) + + +# ── Pre-defined metrics ────────────────────────────────────────── + +REQUESTS_TOTAL = Counter( + "duh_requests_total", "Total HTTP requests", + labels=["method", "path", "status"], +) +CONSENSUS_RUNS_TOTAL = Counter( + "duh_consensus_runs_total", "Total consensus runs", +) +TOKENS_TOTAL = Counter( + "duh_tokens_total", "Total tokens consumed", + labels=["provider", "direction"], +) +ERRORS_TOTAL = Counter( + "duh_errors_total", "Total errors", + labels=["type"], +) +REQUEST_DURATION = Histogram( + "duh_request_duration_seconds", "Request duration", +) +CONSENSUS_DURATION = Histogram( + "duh_consensus_duration_seconds", "Consensus run duration", +) +ACTIVE_CONNECTIONS = Gauge( + "duh_active_connections", "Active connections", +) +PROVIDER_HEALTH = Gauge( + "duh_provider_health", "Provider health status", +) + + +@router.get("/api/metrics") +async def metrics_endpoint() -> Response: + """Serve all registered metrics in Prometheus text format.""" + registry = MetricsRegistry.get() + return Response( + content=registry.collect_all(), + media_type="text/plain; version=0.0.4", + ) diff --git a/src/duh/api/middleware.py b/src/duh/api/middleware.py index f0a8e63..68ec91a 100644 --- a/src/duh/api/middleware.py +++ b/src/duh/api/middleware.py @@ -26,6 +26,10 @@ class APIKeyMiddleware(BaseHTTPMiddleware): EXEMPT_PATHS: ClassVar[set[str]] = { "/api/health", + "/api/health/detailed", + "/api/metrics", + "/api/auth/register", + "/api/auth/login", "/docs", "/openapi.json", "/redoc", @@ -51,6 +55,23 @@ async def dispatch( if not path.startswith("/api/") and not path.startswith("/ws/"): return await call_next(request) + # Accept JWT Bearer token as alternative to API key + auth_header = request.headers.get("Authorization") + if auth_header and auth_header.startswith("Bearer "): + # 
Decode JWT and set user_id on request.state for rate limiting + token = auth_header.split(" ", 1)[1] + try: + config = request.app.state.config + from duh.api.auth import decode_token + + payload = decode_token(token, config.auth.jwt_secret) + request.state.user_id = payload.get("sub") + except Exception: + # Let the auth dependency handle full validation; + # middleware just extracts user_id if possible. + pass + return await call_next(request) + # Skip auth if no API keys are configured (dev mode) db_factory = request.app.state.db_factory @@ -100,10 +121,11 @@ def __init__(self, app: object, rate_limit: int = 60, window: int = 60) -> None: async def dispatch( self, request: Request, call_next: Callable[[Request], Awaitable[Response]] ) -> Response: - # Get key identifier (API key ID or IP for unauthenticated) - key_id = getattr(request.state, "api_key_id", None) or ( - request.client.host if request.client else "unknown" - ) + # Get key identifier: prefer user_id (JWT), then api_key_id, then IP + user_id = getattr(request.state, "user_id", None) + api_key_id = getattr(request.state, "api_key_id", None) + ip_addr = request.client.host if request.client else "unknown" + key_id = user_id or api_key_id or ip_addr now = time.monotonic() # Clean old entries @@ -126,4 +148,12 @@ async def dispatch( response.headers["X-RateLimit-Limit"] = str(self.rate_limit) response.headers["X-RateLimit-Remaining"] = str(remaining) + # Add rate limit identity headers + if user_id: + response.headers["X-RateLimit-Key"] = f"user:{user_id}" + elif api_key_id: + response.headers["X-RateLimit-Key"] = f"api_key:{api_key_id}" + else: + response.headers["X-RateLimit-Key"] = f"ip:{ip_addr}" + return response diff --git a/src/duh/api/rbac.py b/src/duh/api/rbac.py new file mode 100644 index 0000000..8f7b3f4 --- /dev/null +++ b/src/duh/api/rbac.py @@ -0,0 +1,59 @@ +"""Role-based access control for the duh API. + +Roles: admin > contributor > viewer. 
+Use ``require_role`` to create a FastAPI dependency that checks the +authenticated user has at least the given role level. + +Example:: + + @router.get("/admin-only") + async def admin_endpoint(user=Depends(require_role("admin"))): + ... +""" + +from __future__ import annotations + +from typing import Any + +from fastapi import Depends, HTTPException + +from duh.api.auth import get_current_user + +# Role hierarchy: higher number = more privileges. +ROLE_HIERARCHY: dict[str, int] = {"admin": 3, "contributor": 2, "viewer": 1} + + +def require_role(minimum_role: str): + """FastAPI dependency factory: require user has at least *minimum_role*. + + Args: + minimum_role: One of ``"admin"``, ``"contributor"``, ``"viewer"``. + + Returns: + An async FastAPI dependency callable that resolves to the + authenticated ``User`` if the role check passes. + + Raises: + HTTPException 401: If no authenticated user is present. + HTTPException 403: If the user's role is below the minimum. + """ + min_level = ROLE_HIERARCHY.get(minimum_role, 0) + + async def _check_role( + user: Any = Depends(get_current_user), # noqa: B008 + ) -> Any: + user_level = ROLE_HIERARCHY.get(getattr(user, "role", ""), 0) + if user_level < min_level: + raise HTTPException( + status_code=403, + detail=f"Requires {minimum_role} role", + ) + return user + + return _check_role + + +# Convenience pre-built dependencies. 
+require_admin = require_role("admin") +require_contributor = require_role("contributor") +require_viewer = require_role("viewer") diff --git a/src/duh/cli/app.py b/src/duh/cli/app.py index ccde107..09f4111 100644 --- a/src/duh/cli/app.py +++ b/src/duh/cli/app.py @@ -67,7 +67,27 @@ async def _create_db( if db_path and db_path != ":memory:": Path(db_path).parent.mkdir(parents=True, exist_ok=True) - engine = create_async_engine(url) + engine_kwargs: dict[str, object] = {} + if url.startswith("sqlite"): + if ":memory:" in url: + # In-memory SQLite needs StaticPool so all queries share + # the same connection (and thus the same in-memory DB). + from sqlalchemy.pool import StaticPool + + engine_kwargs["poolclass"] = StaticPool + engine_kwargs["connect_args"] = {"check_same_thread": False} + else: + from sqlalchemy.pool import NullPool + + engine_kwargs["poolclass"] = NullPool + else: + engine_kwargs["pool_size"] = config.database.pool_size + engine_kwargs["max_overflow"] = config.database.max_overflow + engine_kwargs["pool_timeout"] = config.database.pool_timeout + engine_kwargs["pool_recycle"] = config.database.pool_recycle + engine_kwargs["pool_pre_ping"] = True + + engine = create_async_engine(url, **engine_kwargs) # Enable foreign keys for SQLite if url.startswith("sqlite"): @@ -78,8 +98,12 @@ def _enable_fks(dbapi_conn, connection_record): # type: ignore[no-untyped-def] cursor.execute("PRAGMA foreign_keys=ON") cursor.close() - async with engine.begin() as conn: - await conn.run_sync(Base.metadata.create_all) + # Only use create_all for in-memory SQLite (tests/dev). + # File-based SQLite and PostgreSQL are managed by alembic migrations. 
+ is_memory = url.startswith("sqlite") and ":memory:" in url + if is_memory: + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) factory = async_sessionmaker(engine, expire_on_commit=False) return factory, engine @@ -99,9 +123,14 @@ async def _setup_providers(config: DuhConfig) -> ProviderManager: "openai", "google", "mistral", + "perplexity", ): continue # Skip providers without API keys + # Set provider rate limit if configured + if prov_config.rate_limit > 0: + pm.set_provider_rate_limit(name, prov_config.rate_limit) + if name == "anthropic": from duh.providers.anthropic import AnthropicProvider @@ -125,6 +154,11 @@ async def _setup_providers(config: DuhConfig) -> ProviderManager: mistral_prov = MistralProvider(api_key=prov_config.api_key) await pm.register(mistral_prov) # type: ignore[arg-type] + elif name == "perplexity": + from duh.providers.perplexity import PerplexityProvider + + perplexity_prov = PerplexityProvider(api_key=prov_config.api_key) + await pm.register(perplexity_prov) # type: ignore[arg-type] return pm @@ -1120,6 +1154,122 @@ async def _cost_async(config: DuhConfig) -> None: click.echo(f" {model_ref}: ${model_cost:.4f} ({call_count} calls)") +# ── backup ─────────────────────────────────────────────────────── + + +@cli.command() +@click.argument("path", type=click.Path()) +@click.option( + "--format", + "fmt", + type=click.Choice(["auto", "sqlite", "json"]), + default="auto", + help="Backup format (auto detects from db type).", +) +@click.option("--config", "config_path", default=None, help="Config file path.") +def backup(path: str, fmt: str, config_path: str | None) -> None: + """Backup the duh database to PATH.""" + config = _load_config(config_path) + try: + asyncio.run(_backup_async(config, path, fmt)) + except (DuhError, ValueError, FileNotFoundError, OSError) as e: + _error(str(e)) + + +async def _backup_async(config: DuhConfig, path: str, fmt: str) -> None: + """Async implementation for the backup 
command.""" + from duh.memory.backup import backup_json, backup_sqlite, detect_db_type + + db_url = config.database.url + if "~" in db_url: + db_url = db_url.replace("~", str(Path.home())) + + db_type = detect_db_type(db_url) + dest = Path(path) + + if fmt == "auto": + fmt = "sqlite" if db_type == "sqlite" else "json" + + if fmt == "sqlite" and db_type != "sqlite": + _error("Cannot use sqlite backup format for a PostgreSQL database.") + + if fmt == "sqlite": + result_path = await backup_sqlite(db_url, dest) + else: + factory, engine = await _create_db(config) + async with factory() as session: + result_path = await backup_json(session, dest) + await engine.dispose() + + size = result_path.stat().st_size + if size < 1024: + size_str = f"{size} B" + elif size < 1024 * 1024: + size_str = f"{size / 1024:.1f} KB" + else: + size_str = f"{size / (1024 * 1024):.1f} MB" + + click.echo(f"Backup saved to {result_path} ({size_str})") + + +# ── restore ───────────────────────────────────────────────────── + + +@cli.command() +@click.argument("path", type=click.Path(exists=True)) +@click.option( + "--merge", + is_flag=True, + default=False, + help="Merge with existing data instead of replacing.", +) +@click.option("--config", "config_path", default=None, help="Config file path.") +def restore(path: str, merge: bool, config_path: str | None) -> None: + """Restore the duh database from PATH.""" + config = _load_config(config_path) + try: + asyncio.run(_restore_async(config, path, merge)) + except (DuhError, ValueError, FileNotFoundError, OSError) as e: + _error(str(e)) + + +async def _restore_async(config: DuhConfig, path: str, merge: bool) -> None: + """Async implementation for the restore command.""" + from duh.memory.backup import ( + detect_backup_format, + detect_db_type, + restore_json, + restore_sqlite, + ) + + db_url = config.database.url + if "~" in db_url: + db_url = db_url.replace("~", str(Path.home())) + + source = Path(path) + fmt = detect_backup_format(source) + 
db_type = detect_db_type(db_url) + + if fmt == "sqlite" and db_type != "sqlite": + _error("Cannot restore a SQLite backup into a PostgreSQL database.") + + if fmt == "sqlite": + await restore_sqlite(source, db_url) + click.echo(f"Restored SQLite database from {source}") + else: + factory, engine = await _create_db(config) + async with factory() as session: + counts = await restore_json(session, source, merge=merge) + await engine.dispose() + + total = sum(counts.values()) + mode = "Merged" if merge else "Restored" + click.echo(f"{mode} {total} records from {source}") + for table_name, count in counts.items(): + if count > 0: + click.echo(f" {table_name}: {count}") + + # ── serve ──────────────────────────────────────────────────────── @@ -1393,3 +1543,107 @@ async def _batch_async( f"{total} questions | Total cost: ${total_cost:.4f} " f"| Elapsed: {elapsed:.1f}s" ) + + +# ── user-create ───────────────────────────────────────────── + + +@cli.command("user-create") +@click.option("--email", required=True) +@click.option("--password", required=True) +@click.option("--name", "display_name", required=True) +@click.option( + "--role", + type=click.Choice(["admin", "contributor", "viewer"]), + default="contributor", +) +@click.option("--config", "config_path", default=None) +def user_create( + email: str, + password: str, + display_name: str, + role: str, + config_path: str | None, +) -> None: + """Create a new user.""" + config = _load_config(config_path) + try: + asyncio.run(_user_create_async(config, email, password, display_name, role)) + except DuhError as e: + _error(str(e)) + + +async def _user_create_async( + config: DuhConfig, + email: str, + password: str, + display_name: str, + role: str, +) -> None: + """Async implementation for the user-create command.""" + from sqlalchemy import select + + from duh.api.auth import hash_password + from duh.memory.models import User + + factory, engine = await _create_db(config) + async with factory() as session: + # Check 
email uniqueness + stmt = select(User).where(User.email == email) + result = await session.execute(stmt) + if result.scalar_one_or_none() is not None: + await engine.dispose() + _error(f"Email already registered: {email}") + + user = User( + email=email, + password_hash=hash_password(password), + display_name=display_name, + role=role, + ) + session.add(user) + await session.commit() + await session.refresh(user) + + await engine.dispose() + click.echo(f"User created: {user.id} ({user.email}) role={user.role}") + + +# ── user-list ─────────────────────────────────────────────── + + +@cli.command("user-list") +@click.option("--config", "config_path", default=None) +def user_list(config_path: str | None) -> None: + """List all users.""" + config = _load_config(config_path) + try: + asyncio.run(_user_list_async(config)) + except DuhError as e: + _error(str(e)) + + +async def _user_list_async(config: DuhConfig) -> None: + """Async implementation for the user-list command.""" + from sqlalchemy import select + + from duh.memory.models import User + + factory, engine = await _create_db(config) + async with factory() as session: + stmt = select(User).order_by(User.created_at) + result = await session.execute(stmt) + users = result.scalars().all() + + await engine.dispose() + + if not users: + click.echo("No users found.") + return + + for user in users: + active = "active" if user.is_active else "disabled" + click.echo( + f" {user.id[:8]} {user.email} {user.display_name} " + f"role={user.role} {active}" + ) diff --git a/src/duh/config/schema.py b/src/duh/config/schema.py index 8e07069..5fcc5ff 100644 --- a/src/duh/config/schema.py +++ b/src/duh/config/schema.py @@ -15,6 +15,7 @@ class ProviderConfig(BaseModel): default_model: str | None = None models: list[str] = Field(default_factory=list) display_name: str | None = None + rate_limit: int = 0 # 0 = unlimited, >0 = requests per minute class ConsensusConfig(BaseModel): @@ -40,6 +41,10 @@ class DatabaseConfig(BaseModel): 
"""Database connection settings.""" url: str = "sqlite+aiosqlite:///~/.local/share/duh/duh.db" + pool_size: int = 5 + max_overflow: int = 10 + pool_timeout: int = 30 + pool_recycle: int = 3600 class LoggingConfig(BaseModel): @@ -96,6 +101,14 @@ class TaxonomyConfig(BaseModel): model_ref: str = "" +class AuthConfig(BaseModel): + """Authentication configuration.""" + + jwt_secret: str = "" # must be set in production + token_expiry_hours: int = 24 + registration_enabled: bool = True + + class APIConfig(BaseModel): """REST API server configuration.""" @@ -129,6 +142,7 @@ class DuhConfig(BaseModel): "openai": ProviderConfig(api_key_env="OPENAI_API_KEY"), "google": ProviderConfig(api_key_env="GOOGLE_API_KEY"), "mistral": ProviderConfig(api_key_env="MISTRAL_API_KEY"), + "perplexity": ProviderConfig(api_key_env="PERPLEXITY_API_KEY"), } ) consensus: ConsensusConfig = Field(default_factory=ConsensusConfig) @@ -138,3 +152,4 @@ class DuhConfig(BaseModel): decompose: DecomposeConfig = Field(default_factory=DecomposeConfig) taxonomy: TaxonomyConfig = Field(default_factory=TaxonomyConfig) api: APIConfig = Field(default_factory=APIConfig) + auth: AuthConfig = Field(default_factory=AuthConfig) diff --git a/src/duh/memory/backup.py b/src/duh/memory/backup.py new file mode 100644 index 0000000..8ac24b1 --- /dev/null +++ b/src/duh/memory/backup.py @@ -0,0 +1,267 @@ +"""Database backup and restore utilities.""" + +from __future__ import annotations + +import json +import shutil +from datetime import UTC, datetime +from pathlib import Path +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncSession + + +def detect_db_type(db_url: str) -> str: + """Return 'sqlite' or 'postgresql' based on URL.""" + if db_url.startswith("sqlite"): + return "sqlite" + if db_url.startswith("postgresql") or db_url.startswith("postgres"): + return "postgresql" + return "unknown" + + +async def backup_sqlite(db_url: str, dest: Path) -> Path: + """Copy SQLite 
file to destination.""" + # Extract file path from sqlite:///path or sqlite+aiosqlite:///path + if ":///" not in db_url: + msg = f"Cannot extract file path from URL: {db_url}" + raise ValueError(msg) + + raw_path = db_url.split("///", 1)[1] + if not raw_path or raw_path == ":memory:": + msg = "Cannot backup an in-memory SQLite database via file copy" + raise ValueError(msg) + + # Expand ~ in paths + src = Path(raw_path).expanduser() + if not src.exists(): + msg = f"SQLite database file not found: {src}" + raise FileNotFoundError(msg) + + dest.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(src, dest) + return dest + + +async def backup_json(session: AsyncSession, dest: Path) -> Path: + """Export all tables to portable JSON format.""" + from sqlalchemy import inspect, select + + from duh.memory.models import ( + APIKey, + Contribution, + Decision, + Outcome, + Subtask, + Thread, + ThreadSummary, + Turn, + TurnSummary, + Vote, + ) + + tables: dict[str, type[Any]] = { + "threads": Thread, + "turns": Turn, + "contributions": Contribution, + "turn_summaries": TurnSummary, + "thread_summaries": ThreadSummary, + "decisions": Decision, + "outcomes": Outcome, + "subtasks": Subtask, + "votes": Vote, + "api_keys": APIKey, + } + + # Check if users table exists (may not be migrated yet) + try: + from duh.memory.models import User + + tables["users"] = User + except ImportError: + pass + + data: dict[str, Any] = { + "version": "0.5.0", + "exported_at": datetime.now(UTC).isoformat(), + "tables": {}, + } + + for table_name, model_cls in tables.items(): + try: + stmt = select(model_cls) + result = await session.execute(stmt) + rows = result.scalars().all() + except Exception: + # Table may not exist yet in the database + data["tables"][table_name] = [] + continue + + row_list = [] + for row in rows: + mapper = inspect(type(row)) + row_dict: dict[str, Any] = {} + for col in mapper.columns: + val = getattr(row, col.key) + if isinstance(val, datetime): + val = val.isoformat() 
+ row_dict[col.key] = val + row_list.append(row_dict) + + data["tables"][table_name] = row_list + + dest.parent.mkdir(parents=True, exist_ok=True) + dest.write_text(json.dumps(data, indent=2), encoding="utf-8") + return dest + + +def detect_backup_format(source: Path) -> str: + """Detect if backup file is 'sqlite' or 'json'. + + Reads the first bytes of the file to determine format. + + Raises: + ValueError: If the file format cannot be determined. + """ + with open(source, "rb") as f: + header = f.read(16) + + if not header: + msg = f"Cannot detect format: file is empty: {source}" + raise ValueError(msg) + + if header.lstrip()[:1] in (b"{", b"["): + return "json" + if header.startswith(b"SQLite format"): + return "sqlite" + + msg = f"Cannot detect backup format for: {source}" + raise ValueError(msg) + + +async def restore_sqlite(source: Path, db_url: str) -> None: + """Restore SQLite database from a backup file. + + Copies the source file to the database path extracted from the URL, + overwriting the existing database. + """ + if ":///" not in db_url: + msg = f"Cannot extract file path from URL: {db_url}" + raise ValueError(msg) + + raw_path = db_url.split("///", 1)[1] + if not raw_path or raw_path == ":memory:": + msg = "Cannot restore to an in-memory SQLite database" + raise ValueError(msg) + + db_path = Path(raw_path).expanduser() + db_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(source, db_path) + + +async def restore_json( + session: AsyncSession, source: Path, *, merge: bool = False +) -> dict[str, int]: + """Restore database from JSON backup. + + Args: + session: Database session. + source: Path to JSON backup file. + merge: If True, add records (skip conflicts). If False, clear tables first. + + Returns: + dict with counts of restored records per table. 
+ """ + from sqlalchemy import DateTime as SADateTime + from sqlalchemy import delete, inspect + + from duh.memory.models import ( + APIKey, + Contribution, + Decision, + Outcome, + Subtask, + Thread, + ThreadSummary, + Turn, + TurnSummary, + Vote, + ) + + data = json.loads(source.read_text(encoding="utf-8")) + + if "tables" not in data: + msg = "Invalid backup: missing 'tables' key" + raise ValueError(msg) + + # Model map in dependency order (parents before children) + model_map: dict[str, type[Any]] = { + "threads": Thread, + "turns": Turn, + "contributions": Contribution, + "turn_summaries": TurnSummary, + "thread_summaries": ThreadSummary, + "decisions": Decision, + "outcomes": Outcome, + "subtasks": Subtask, + "votes": Vote, + "api_keys": APIKey, + } + + # Check if users table exists + try: + from duh.memory.models import User + + model_map = {"users": User, **model_map} + except ImportError: + pass + + # Delete order is reverse of insert (children before parents) + if not merge: + import contextlib + + delete_order = list(reversed(model_map.keys())) + for table_name in delete_order: + model_cls = model_map[table_name] + with contextlib.suppress(Exception): + await session.execute(delete(model_cls)) + await session.flush() + + counts: dict[str, int] = {} + + for table_name, model_cls in model_map.items(): + rows = data["tables"].get(table_name, []) + if not rows: + counts[table_name] = 0 + continue + + mapper = inspect(model_cls) + col_names = {col.key for col in mapper.columns} + # Identify datetime columns for ISO string parsing + dt_cols = { + col.key + for col in mapper.columns + if isinstance(col.type, SADateTime) + } + + count = 0 + for row_data in rows: + # Filter to only known columns + filtered = {k: v for k, v in row_data.items() if k in col_names} + # Convert ISO datetime strings to Python datetime objects + for key in dt_cols: + val = filtered.get(key) + if isinstance(val, str): + filtered[key] = datetime.fromisoformat(val) + obj = 
model_cls(**filtered) + if merge: + await session.merge(obj) + else: + session.add(obj) + count += 1 + + counts[table_name] = count + + await session.commit() + return counts diff --git a/src/duh/memory/models.py b/src/duh/memory/models.py index 4508da6..4ef6d7f 100644 --- a/src/duh/memory/models.py +++ b/src/duh/memory/models.py @@ -27,6 +27,29 @@ class Base(DeclarativeBase): """Declarative base for all duh models.""" +# ── Users ──────────────────────────────────────────────────────── + + +class User(Base): + """A registered user.""" + + __tablename__ = "users" + __table_args__ = (Index("ix_users_email", "email", unique=True),) + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid) + email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False) + password_hash: Mapped[str] = mapped_column(String(128), nullable=False) + display_name: Mapped[str] = mapped_column(String(100), nullable=False) + role: Mapped[str] = mapped_column(String(20), default="contributor") + is_active: Mapped[bool] = mapped_column(default=True) + created_at: Mapped[datetime] = mapped_column(DateTime, default=_utcnow) + updated_at: Mapped[datetime] = mapped_column( + DateTime, default=_utcnow, onupdate=_utcnow + ) + + threads: Mapped[list[Thread]] = relationship(back_populates="user") + + # ── Layer 1: Operational ───────────────────────────────────────── @@ -42,11 +65,15 @@ class Thread(Base): id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid) question: Mapped[str] = mapped_column(Text) status: Mapped[str] = mapped_column(String(20), default="active") + user_id: Mapped[str | None] = mapped_column( + ForeignKey("users.id"), nullable=True, index=True, default=None + ) created_at: Mapped[datetime] = mapped_column(DateTime, default=_utcnow) updated_at: Mapped[datetime] = mapped_column( DateTime, default=_utcnow, onupdate=_utcnow ) + user: Mapped[User | None] = relationship(back_populates="threads") turns: Mapped[list[Turn]] = 
relationship( back_populates="thread", cascade="all, delete-orphan", @@ -98,6 +125,7 @@ class Contribution(Base): """A single model's output within a turn.""" __tablename__ = "contributions" + __table_args__ = (Index("ix_contributions_turn_role", "turn_id", "role"),) id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid) turn_id: Mapped[str] = mapped_column(ForeignKey("turns.id"), index=True) @@ -148,10 +176,17 @@ class Decision(Base): """Committed decision from a consensus turn.""" __tablename__ = "decisions" + __table_args__ = ( + Index("ix_decisions_thread_created", "thread_id", "created_at"), + Index("ix_decisions_category_genus", "category", "genus"), + ) id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid) turn_id: Mapped[str] = mapped_column(ForeignKey("turns.id"), unique=True) thread_id: Mapped[str] = mapped_column(ForeignKey("threads.id"), index=True) + user_id: Mapped[str | None] = mapped_column( + ForeignKey("users.id"), nullable=True, index=True, default=None + ) content: Mapped[str] = mapped_column(Text) confidence: Mapped[float] = mapped_column(Float, default=0.0) dissent: Mapped[str | None] = mapped_column(Text, nullable=True, default=None) @@ -235,6 +270,9 @@ class APIKey(Base): id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid) key_hash: Mapped[str] = mapped_column(String(64), unique=True, index=True) name: Mapped[str] = mapped_column(String(100)) + user_id: Mapped[str | None] = mapped_column( + ForeignKey("users.id"), nullable=True, index=True, default=None + ) created_at: Mapped[datetime] = mapped_column(DateTime, default=_utcnow) revoked_at: Mapped[datetime | None] = mapped_column( DateTime, nullable=True, default=None diff --git a/src/duh/providers/manager.py b/src/duh/providers/manager.py index 51adf24..26cbde2 100644 --- a/src/duh/providers/manager.py +++ b/src/duh/providers/manager.py @@ -2,14 +2,31 @@ from __future__ import annotations +import time from typing import 
TYPE_CHECKING -from duh.core.errors import CostLimitExceededError, ModelNotFoundError +from duh.core.errors import CostLimitExceededError, ModelNotFoundError, ProviderError if TYPE_CHECKING: from duh.providers.base import ModelInfo, ModelProvider, TokenUsage +class ProviderQuotaExceededError(ProviderError): + """Raised when a provider's configured rate limit is exceeded. + + Distinct from ``ProviderRateLimitError`` in ``core.errors`` which + represents rate limits returned by the external provider API. + This error represents *our* configured per-provider quota. + """ + + def __init__(self, provider_id: str, rate_limit: int) -> None: + super().__init__( + provider_id, + f"Provider quota exceeded: {rate_limit} requests per minute", + ) + self.rate_limit = rate_limit + + class ProviderManager: """Central registry for provider adapters. @@ -29,6 +46,8 @@ def __init__(self, *, cost_hard_limit: float = 0.0) -> None: self._cost_hard_limit = cost_hard_limit self._total_cost: float = 0.0 self._cost_by_provider: dict[str, float] = {} + self._provider_rate_limits: dict[str, int] = {}  # provider_id -> rpm + self._provider_requests: dict[str, list[float]] = {}  # provider_id -> monotonic request timestamps + # ── Registration ───────────────────────────────────────────── @@ -92,13 +111,17 @@ def get_model_info(self, model_ref: str) -> ModelInfo: def get_provider(self, model_ref: str) -> tuple[ModelProvider, str]: """Resolve a model_ref to its provider and model_id. + Also checks the provider-level rate limit and records this call + against it, so each successful resolution consumes one unit of quota. + Returns: (provider, model_id) tuple for direct send/stream calls. Raises: ModelNotFoundError: If the model_ref is not in the index. + ProviderQuotaExceededError: If the provider's rate limit is exceeded. 
""" info = self.get_model_info(model_ref) + self.check_provider_rate_limit(info.provider_id) provider = self._providers[info.provider_id] return provider, info.model_id @@ -150,3 +173,59 @@ def reset_cost(self) -> None: """Reset the cost accumulator to zero.""" self._total_cost = 0.0 self._cost_by_provider.clear() + + # ── Provider rate limiting ─────────────────────────────────── + + def set_provider_rate_limit(self, provider_id: str, rpm: int) -> None: + """Set a per-minute rate limit for a provider. + + Args: + provider_id: The provider identifier. + rpm: Requests per minute. 0 = unlimited. + """ + self._provider_rate_limits[provider_id] = rpm + if provider_id not in self._provider_requests: + self._provider_requests[provider_id] = [] + + def check_provider_rate_limit(self, provider_id: str) -> None: + """Check if a provider's rate limit has been exceeded. + + Raises: + ProviderQuotaExceededError: If the provider's rate limit is exceeded. + """ + rpm = self._provider_rate_limits.get(provider_id, 0) + if rpm <= 0: + return # unlimited + + now = time.monotonic() + window = 60.0 # 1 minute + + # Clean old entries + self._provider_requests.setdefault(provider_id, []) + self._provider_requests[provider_id] = [ + t for t in self._provider_requests[provider_id] if now - t < window + ] + + if len(self._provider_requests[provider_id]) >= rpm: + raise ProviderQuotaExceededError(provider_id, rpm) + + # Record this request + self._provider_requests[provider_id].append(now) + + def get_provider_rate_limit_remaining(self, provider_id: str) -> int | None: + """Get remaining requests for a provider in the current window. + + Returns: + Number of remaining requests, or None if no limit is set. 
+ """ + rpm = self._provider_rate_limits.get(provider_id, 0) + if rpm <= 0: + return None + + now = time.monotonic() + window = 60.0 + self._provider_requests.setdefault(provider_id, []) + current = [ + t for t in self._provider_requests[provider_id] if now - t < window + ] + return max(0, rpm - len(current)) diff --git a/src/duh/providers/perplexity.py b/src/duh/providers/perplexity.py new file mode 100644 index 0000000..26b45ce --- /dev/null +++ b/src/duh/providers/perplexity.py @@ -0,0 +1,300 @@ +"""Perplexity provider adapter (OpenAI-compatible API).""" + +from __future__ import annotations + +import contextlib +import time +from typing import TYPE_CHECKING, Any + +import openai + +from duh.core.errors import ( + ModelNotFoundError, + ProviderAuthError, + ProviderOverloadedError, + ProviderRateLimitError, + ProviderTimeoutError, +) +from duh.providers.base import ( + ModelCapability, + ModelInfo, + ModelResponse, + StreamChunk, + TokenUsage, + ToolCallData, +) + +if TYPE_CHECKING: + from collections.abc import AsyncIterator + + from duh.providers.base import PromptMessage + +PROVIDER_ID = "perplexity" + +# Known Perplexity models with metadata. 
+_KNOWN_MODELS: list[dict[str, Any]] = [ + { + "model_id": "sonar", + "display_name": "Sonar", + "context_window": 128_000, + "max_output_tokens": 8_192, + "input_cost_per_mtok": 1.0, + "output_cost_per_mtok": 1.0, + }, + { + "model_id": "sonar-pro", + "display_name": "Sonar Pro", + "context_window": 200_000, + "max_output_tokens": 8_192, + "input_cost_per_mtok": 3.0, + "output_cost_per_mtok": 15.0, + }, + { + "model_id": "sonar-deep-research", + "display_name": "Sonar Deep Research", + "context_window": 128_000, + "max_output_tokens": 8_192, + "input_cost_per_mtok": 2.0, + "output_cost_per_mtok": 8.0, + }, +] + +_DEFAULT_CAPS = ( + ModelCapability.TEXT + | ModelCapability.STREAMING + | ModelCapability.SYSTEM_PROMPT + | ModelCapability.JSON_MODE +) + + +def _map_error(e: openai.APIError) -> Exception: + """Map OpenAI SDK errors to duh error hierarchy.""" + if isinstance(e, openai.AuthenticationError): + return ProviderAuthError(PROVIDER_ID, str(e)) + if isinstance(e, openai.RateLimitError): + retry_after = None + if hasattr(e, "response") and e.response is not None: + raw = e.response.headers.get("retry-after") + if raw is not None: + with contextlib.suppress(ValueError): + retry_after = float(raw) + return ProviderRateLimitError(PROVIDER_ID, retry_after=retry_after) + if isinstance(e, openai.APITimeoutError): + return ProviderTimeoutError(PROVIDER_ID, str(e)) + if isinstance(e, openai.InternalServerError): + return ProviderOverloadedError(PROVIDER_ID, str(e)) + if isinstance(e, openai.NotFoundError): + return ModelNotFoundError(PROVIDER_ID, str(e)) + # Fallback for unknown API errors + return ProviderOverloadedError(PROVIDER_ID, str(e)) + + +def _build_messages( + messages: list[PromptMessage], +) -> list[dict[str, str]]: + """Convert PromptMessages to OpenAI chat message format.""" + return [{"role": msg.role, "content": msg.content} for msg in messages] + + +class PerplexityProvider: + """Provider adapter for Perplexity's OpenAI-compatible API. 
+ + Perplexity uses the OpenAI SDK with a custom base_url. + Responses may include citations which are captured in raw_response. + """ + + def __init__( + self, + api_key: str | None = None, + *, + client: openai.AsyncOpenAI | None = None, + ) -> None: + if client is not None: + self._client = client + else: + kwargs: dict[str, Any] = { + "base_url": "https://api.perplexity.ai", + } + if api_key is not None: + kwargs["api_key"] = api_key + self._client = openai.AsyncOpenAI(**kwargs) + + @property + def provider_id(self) -> str: + return PROVIDER_ID + + async def list_models(self) -> list[ModelInfo]: + return [ + ModelInfo( + provider_id=PROVIDER_ID, + model_id=m["model_id"], + display_name=m["display_name"], + capabilities=_DEFAULT_CAPS, + context_window=m["context_window"], + max_output_tokens=m["max_output_tokens"], + input_cost_per_mtok=m["input_cost_per_mtok"], + output_cost_per_mtok=m["output_cost_per_mtok"], + ) + for m in _KNOWN_MODELS + ] + + async def send( + self, + messages: list[PromptMessage], + model_id: str, + *, + max_tokens: int = 4096, + temperature: float = 0.7, + stop_sequences: list[str] | None = None, + response_format: str | None = None, + tools: list[dict[str, object]] | None = None, + ) -> ModelResponse: + api_messages = _build_messages(messages) + + kwargs: dict[str, Any] = { + "model": model_id, + "max_completion_tokens": max_tokens, + "messages": api_messages, + "temperature": temperature, + } + if stop_sequences: + kwargs["stop"] = stop_sequences + if response_format == "json": + kwargs["response_format"] = {"type": "json_object"} + if tools: + kwargs["tools"] = tools + + start = time.monotonic() + try: + response = await self._client.chat.completions.create(**kwargs) + except openai.APIError as e: + raise _map_error(e) from e + + latency_ms = (time.monotonic() - start) * 1000 + + tool_calls_data: list[ToolCallData] | None = None + if response.choices: + content = response.choices[0].message.content or "" + finish_reason = 
response.choices[0].finish_reason or "stop" + # Parse tool calls from response + msg_tool_calls = response.choices[0].message.tool_calls + if msg_tool_calls: + tool_calls_data = [ + ToolCallData( + id=tc.id, + name=tc.function.name, + arguments=tc.function.arguments, + ) + for tc in msg_tool_calls + ] + else: + content = "" + finish_reason = "stop" + + if response.usage: + usage = TokenUsage( + input_tokens=response.usage.prompt_tokens, + output_tokens=response.usage.completion_tokens, + ) + else: + usage = TokenUsage(input_tokens=0, output_tokens=0) + + model_info = self._resolve_model_info(model_id) + + # Capture citations from Perplexity response if present + citations = getattr(response, "citations", None) + raw = response + if citations is not None: + raw = {"response": response, "citations": citations} + + return ModelResponse( + content=content, + model_info=model_info, + usage=usage, + finish_reason=finish_reason, + latency_ms=latency_ms, + raw_response=raw, + tool_calls=tool_calls_data, + ) + + async def stream( + self, + messages: list[PromptMessage], + model_id: str, + *, + max_tokens: int = 4096, + temperature: float = 0.7, + stop_sequences: list[str] | None = None, + ) -> AsyncIterator[StreamChunk]: + api_messages = _build_messages(messages) + + kwargs: dict[str, Any] = { + "model": model_id, + "max_completion_tokens": max_tokens, + "messages": api_messages, + "temperature": temperature, + "stream_options": {"include_usage": True}, + } + if stop_sequences: + kwargs["stop"] = stop_sequences + + try: + response = await self._client.chat.completions.create( + stream=True, + **kwargs, + ) + usage = None + async for chunk in response: + # Usage arrives in the final chunk (choices empty) + if chunk.usage is not None: + usage = TokenUsage( + input_tokens=chunk.usage.prompt_tokens, + output_tokens=chunk.usage.completion_tokens, + ) + + if chunk.choices and chunk.choices[0].delta.content: + yield StreamChunk( + text=chunk.choices[0].delta.content, + ) + + yield 
StreamChunk(text="", is_final=True, usage=usage) + + except openai.APIError as e: + raise _map_error(e) from e + + async def health_check(self) -> bool: + try: + await self._client.chat.completions.create( + model="sonar", + max_completion_tokens=1, + messages=[{"role": "user", "content": "ping"}], + ) + except Exception: + return False + return True + + def _resolve_model_info(self, model_id: str) -> ModelInfo: + """Look up ModelInfo for a model_id, or create a generic one.""" + for m in _KNOWN_MODELS: + if m["model_id"] == model_id: + return ModelInfo( + provider_id=PROVIDER_ID, + model_id=model_id, + display_name=m["display_name"], + capabilities=_DEFAULT_CAPS, + context_window=m["context_window"], + max_output_tokens=m["max_output_tokens"], + input_cost_per_mtok=m["input_cost_per_mtok"], + output_cost_per_mtok=m["output_cost_per_mtok"], + ) + # Unknown model -- return generic info + return ModelInfo( + provider_id=PROVIDER_ID, + model_id=model_id, + display_name=f"Perplexity ({model_id})", + capabilities=_DEFAULT_CAPS, + context_window=128_000, + max_output_tokens=8_192, + input_cost_per_mtok=0.0, + output_cost_per_mtok=0.0, + ) diff --git a/tests/load/__init__.py b/tests/load/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/load/conftest.py b/tests/load/conftest.py new file mode 100644 index 0000000..0c0bbc1 --- /dev/null +++ b/tests/load/conftest.py @@ -0,0 +1 @@ +"""Load test configuration.""" diff --git a/tests/load/test_load.py b/tests/load/test_load.py new file mode 100644 index 0000000..60c8b51 --- /dev/null +++ b/tests/load/test_load.py @@ -0,0 +1,417 @@ +"""Load tests for the duh REST API. + +Measures p50/p95/p99 latency, error rates under concurrent load, +and rate limiting behavior. Uses httpx AsyncClient with ASGITransport +for direct ASGI testing (no network required). 
+ +Run with: uv run python -m pytest tests/load/test_load.py -v -m load -s +""" + +from __future__ import annotations + +import asyncio +import statistics +import time +from types import SimpleNamespace +from typing import TYPE_CHECKING + +import pytest +from httpx import ASGITransport, AsyncClient +from sqlalchemy import event +from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine + +from duh.api.metrics import MetricsRegistry +from duh.memory.models import Base + +if TYPE_CHECKING: + from fastapi import FastAPI + + +# ── Helpers ───────────────────────────────────────────────────── + + +async def _make_load_app( + *, + rate_limit: int = 1000, + rate_limit_window: int = 60, +) -> FastAPI: + """Create a FastAPI app with in-memory DB for load testing. + + Follows the pattern from tests/unit/test_auth.py:_make_auth_app. + Uses a high default rate limit to avoid throttling during latency tests. + """ + from fastapi import FastAPI + + from duh.api.auth import router as auth_router + from duh.api.health import router as health_router + from duh.api.metrics import router as metrics_router + from duh.api.middleware import APIKeyMiddleware, RateLimitMiddleware + from duh.api.routes.threads import router as threads_router + + engine = create_async_engine("sqlite+aiosqlite://") + + @event.listens_for(engine.sync_engine, "connect") + def _enable_fks(dbapi_conn, connection_record): # type: ignore[no-untyped-def] + cursor = dbapi_conn.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + + factory = async_sessionmaker(engine, expire_on_commit=False) + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + app = FastAPI(title="duh-load-test") + app.state.config = SimpleNamespace( + auth=SimpleNamespace( + jwt_secret="load-test-secret", + registration_enabled=True, + token_expiry_hours=24, + ), + api=SimpleNamespace( + cors_origins=["http://localhost:3000"], + rate_limit=rate_limit, + 
rate_limit_window=rate_limit_window, + ), + ) + app.state.db_factory = factory + app.state.engine = engine + + # Middleware (reverse order: auth runs first, then rate-limit) + app.add_middleware( + RateLimitMiddleware, + rate_limit=rate_limit, + window=rate_limit_window, + ) + app.add_middleware(APIKeyMiddleware) + + app.include_router(health_router) + app.include_router(metrics_router) + app.include_router(threads_router) + app.include_router(auth_router) + + return app + + +def _percentile(data: list[float], pct: float) -> float: + """Return the value at the given percentile (0-100).""" + if not data: + return 0.0 + sorted_data = sorted(data) + idx = int(len(sorted_data) * pct / 100) + idx = min(idx, len(sorted_data) - 1) + return sorted_data[idx] + + +def _report_latencies( + label: str, + latencies: list[float], +) -> None: + """Print a summary of latency distribution.""" + p50 = statistics.median(latencies) + p95 = _percentile(latencies, 95) + p99 = _percentile(latencies, 99) + mean = statistics.mean(latencies) + print( + f"\n {label}: " + f"p50={p50:.1f}ms p95={p95:.1f}ms p99={p99:.1f}ms " + f"mean={mean:.1f}ms n={len(latencies)}" + ) + + +# ── Latency tests ────────────────────────────────────────────── + + +@pytest.mark.load +async def test_health_endpoint_latency(): + """Measure p50/p95/p99 latency for GET /api/health.""" + MetricsRegistry.reset() + app = await _make_load_app() + transport = ASGITransport(app=app) # type: ignore[arg-type] + + async with AsyncClient(transport=transport, base_url="http://test") as client: + latencies: list[float] = [] + for _ in range(100): + start = time.perf_counter() + resp = await client.get("/api/health") + elapsed = (time.perf_counter() - start) * 1000 + latencies.append(elapsed) + assert resp.status_code == 200 + + _report_latencies("GET /api/health", latencies) + assert statistics.median(latencies) < 100, "p50 latency should be under 100ms" + await app.state.engine.dispose() + + +@pytest.mark.load +async def 
test_health_detailed_endpoint_latency(): + """Measure p50/p95/p99 latency for GET /api/health/detailed.""" + MetricsRegistry.reset() + app = await _make_load_app() + transport = ASGITransport(app=app) # type: ignore[arg-type] + + async with AsyncClient(transport=transport, base_url="http://test") as client: + latencies: list[float] = [] + for _ in range(100): + start = time.perf_counter() + resp = await client.get("/api/health/detailed") + elapsed = (time.perf_counter() - start) * 1000 + latencies.append(elapsed) + assert resp.status_code == 200 + + _report_latencies("GET /api/health/detailed", latencies) + assert statistics.median(latencies) < 200, "p50 latency should be under 200ms" + await app.state.engine.dispose() + + +@pytest.mark.load +async def test_threads_endpoint_latency(): + """Measure p50/p95/p99 latency for GET /api/threads (empty list).""" + MetricsRegistry.reset() + app = await _make_load_app() + transport = ASGITransport(app=app) # type: ignore[arg-type] + + async with AsyncClient(transport=transport, base_url="http://test") as client: + latencies: list[float] = [] + for _ in range(100): + start = time.perf_counter() + resp = await client.get("/api/threads") + elapsed = (time.perf_counter() - start) * 1000 + latencies.append(elapsed) + assert resp.status_code == 200 + + _report_latencies("GET /api/threads", latencies) + assert statistics.median(latencies) < 200, "p50 latency should be under 200ms" + await app.state.engine.dispose() + + +@pytest.mark.load +async def test_metrics_endpoint_latency(): + """Measure p50/p95/p99 latency for GET /api/metrics.""" + MetricsRegistry.reset() + app = await _make_load_app() + transport = ASGITransport(app=app) # type: ignore[arg-type] + + async with AsyncClient(transport=transport, base_url="http://test") as client: + latencies: list[float] = [] + for _ in range(100): + start = time.perf_counter() + resp = await client.get("/api/metrics") + elapsed = (time.perf_counter() - start) * 1000 + 
latencies.append(elapsed) + assert resp.status_code == 200 + + _report_latencies("GET /api/metrics", latencies) + assert statistics.median(latencies) < 100, "p50 latency should be under 100ms" + await app.state.engine.dispose() + + +# ── Concurrent request tests ─────────────────────────────────── + + +async def _run_concurrent( + client: AsyncClient, + method: str, + url: str, + concurrency: int, +) -> tuple[list[float], list[int]]: + """Fire `concurrency` requests in parallel, return (latencies, status_codes).""" + + async def _single_request() -> tuple[float, int]: + start = time.perf_counter() + if method == "GET": + resp = await client.get(url) + else: + resp = await client.post(url) + elapsed = (time.perf_counter() - start) * 1000 + return elapsed, resp.status_code + + results = await asyncio.gather(*[_single_request() for _ in range(concurrency)]) + latencies = [r[0] for r in results] + status_codes = [r[1] for r in results] + return latencies, status_codes + + +@pytest.mark.load +async def test_concurrent_health_10(): + """10 concurrent requests to /api/health -- all should succeed.""" + MetricsRegistry.reset() + app = await _make_load_app() + transport = ASGITransport(app=app) # type: ignore[arg-type] + + async with AsyncClient(transport=transport, base_url="http://test") as client: + latencies, codes = await _run_concurrent(client, "GET", "/api/health", 10) + + _report_latencies("10 concurrent GET /api/health", latencies) + error_count = sum(1 for c in codes if c >= 500) + error_rate = error_count / len(codes) + print(f" Error rate: {error_rate:.1%} ({error_count}/{len(codes)})") + assert error_rate < 0.01, f"Error rate {error_rate:.1%} exceeds 1%" + await app.state.engine.dispose() + + +@pytest.mark.load +async def test_concurrent_health_50(): + """50 concurrent requests to /api/health -- error rate < 1%.""" + MetricsRegistry.reset() + app = await _make_load_app() + transport = ASGITransport(app=app) # type: ignore[arg-type] + + async with 
AsyncClient(transport=transport, base_url="http://test") as client: + latencies, codes = await _run_concurrent(client, "GET", "/api/health", 50) + + _report_latencies("50 concurrent GET /api/health", latencies) + error_count = sum(1 for c in codes if c >= 500) + error_rate = error_count / len(codes) + print(f" Error rate: {error_rate:.1%} ({error_count}/{len(codes)})") + assert error_rate < 0.01, f"Error rate {error_rate:.1%} exceeds 1%" + await app.state.engine.dispose() + + +@pytest.mark.load +async def test_concurrent_health_100(): + """100 concurrent requests to /api/health -- error rate < 1%.""" + MetricsRegistry.reset() + app = await _make_load_app() + transport = ASGITransport(app=app) # type: ignore[arg-type] + + async with AsyncClient(transport=transport, base_url="http://test") as client: + latencies, codes = await _run_concurrent(client, "GET", "/api/health", 100) + + _report_latencies("100 concurrent GET /api/health", latencies) + error_count = sum(1 for c in codes if c >= 500) + error_rate = error_count / len(codes) + print(f" Error rate: {error_rate:.1%} ({error_count}/{len(codes)})") + assert error_rate < 0.01, f"Error rate {error_rate:.1%} exceeds 1%" + await app.state.engine.dispose() + + +@pytest.mark.load +async def test_concurrent_threads_50(): + """50 concurrent requests to /api/threads -- error rate < 1%.""" + MetricsRegistry.reset() + app = await _make_load_app() + transport = ASGITransport(app=app) # type: ignore[arg-type] + + async with AsyncClient(transport=transport, base_url="http://test") as client: + latencies, codes = await _run_concurrent(client, "GET", "/api/threads", 50) + + _report_latencies("50 concurrent GET /api/threads", latencies) + error_count = sum(1 for c in codes if c >= 500) + error_rate = error_count / len(codes) + print(f" Error rate: {error_rate:.1%} ({error_count}/{len(codes)})") + assert error_rate < 0.01, f"Error rate {error_rate:.1%} exceeds 1%" + await app.state.engine.dispose() + + +@pytest.mark.load +async def 
test_concurrent_mixed_endpoints_50(): + """50 concurrent requests across health, threads, and metrics.""" + MetricsRegistry.reset() + app = await _make_load_app() + transport = ASGITransport(app=app) # type: ignore[arg-type] + + endpoints = ["/api/health", "/api/threads", "/api/metrics"] + + async def _single_request(url: str) -> tuple[float, int]: + start = time.perf_counter() + resp = await client.get(url) + elapsed = (time.perf_counter() - start) * 1000 + return elapsed, resp.status_code + + async with AsyncClient(transport=transport, base_url="http://test") as client: + tasks = [_single_request(endpoints[i % len(endpoints)]) for i in range(50)] + results = await asyncio.gather(*tasks) + + latencies = [r[0] for r in results] + codes = [r[1] for r in results] + + _report_latencies("50 concurrent mixed endpoints", latencies) + error_count = sum(1 for c in codes if c >= 500) + error_rate = error_count / len(codes) + print(f" Error rate: {error_rate:.1%} ({error_count}/{len(codes)})") + assert error_rate < 0.01, f"Error rate {error_rate:.1%} exceeds 1%" + await app.state.engine.dispose() + + +# ── Rate limiting under load ─────────────────────────────────── + + +@pytest.mark.load +async def test_rate_limiting_under_load(): + """Verify rate limiter triggers when limit is exceeded under concurrent load. + + Sets a low rate limit (10 req/60s) and fires 25 requests concurrently. + Expects some requests to get 429 Too Many Requests. 
+ """ + MetricsRegistry.reset() + app = await _make_load_app(rate_limit=10, rate_limit_window=60) + transport = ASGITransport(app=app) # type: ignore[arg-type] + + async with AsyncClient(transport=transport, base_url="http://test") as client: + _latencies, codes = await _run_concurrent(client, "GET", "/api/health", 25) + + ok_count = sum(1 for c in codes if c == 200) + limited_count = sum(1 for c in codes if c == 429) + + print(f"\n Rate limit test: {ok_count} OK, {limited_count} rate-limited") + print(" Rate limit: 10 req/60s, sent 25 requests concurrently") + + # The rate limiter should have allowed at most 10 requests + assert ok_count <= 10, f"Expected at most 10 OK responses, got {ok_count}" + assert limited_count >= 15, ( + f"Expected at least 15 rate-limited, got {limited_count}" + ) + await app.state.engine.dispose() + + +@pytest.mark.load +async def test_rate_limit_headers_present(): + """Verify rate limit response headers are present under load.""" + MetricsRegistry.reset() + app = await _make_load_app(rate_limit=100, rate_limit_window=60) + transport = ASGITransport(app=app) # type: ignore[arg-type] + + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.get("/api/health") + assert resp.status_code == 200 + assert "x-ratelimit-limit" in resp.headers + assert "x-ratelimit-remaining" in resp.headers + assert resp.headers["x-ratelimit-limit"] == "100" + + await app.state.engine.dispose() + + +# ── Sustained throughput ─────────────────────────────────────── + + +@pytest.mark.load +async def test_sustained_throughput(): + """Run 5 bursts of 20 concurrent requests, verify consistent performance.""" + MetricsRegistry.reset() + app = await _make_load_app() + transport = ASGITransport(app=app) # type: ignore[arg-type] + + burst_p50s: list[float] = [] + + async with AsyncClient(transport=transport, base_url="http://test") as client: + for burst_num in range(5): + latencies, codes = await _run_concurrent(client, 
"GET", "/api/health", 20) + error_count = sum(1 for c in codes if c >= 500) + assert error_count == 0, f"Burst {burst_num}: {error_count} errors" + + p50 = statistics.median(latencies) + burst_p50s.append(p50) + _report_latencies(f"Burst {burst_num + 1}/5", latencies) + + # Verify no significant degradation across bursts + # Last burst p50 should not be more than 5x the first burst p50 + if burst_p50s[0] > 0: + degradation = burst_p50s[-1] / burst_p50s[0] + print(f"\n Degradation ratio (last/first): {degradation:.2f}x") + assert degradation < 5.0, ( + f"Performance degradation: {degradation:.2f}x " + f"(first p50={burst_p50s[0]:.1f}ms, last p50={burst_p50s[-1]:.1f}ms)" + ) + + await app.state.engine.dispose() diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py new file mode 100644 index 0000000..0b1ecef --- /dev/null +++ b/tests/unit/test_auth.py @@ -0,0 +1,342 @@ +"""Tests for JWT authentication: hashing, tokens, endpoints, middleware.""" + +from __future__ import annotations + +import time +from types import SimpleNamespace + +import jwt +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient +from sqlalchemy import event +from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine + +from duh.api.auth import ( + create_token, + decode_token, + hash_password, + verify_password, +) +from duh.api.auth import ( + router as auth_router, +) +from duh.api.middleware import APIKeyMiddleware, RateLimitMiddleware +from duh.memory.models import Base + +# ── Helpers ──────────────────────────────────────────────────── + + +async def _make_auth_app( + *, + jwt_secret: str = "test-secret-key", + registration_enabled: bool = True, + token_expiry_hours: int = 24, +) -> FastAPI: + """Create a minimal FastAPI app with auth routes and in-memory DB.""" + engine = create_async_engine("sqlite+aiosqlite://") + + @event.listens_for(engine.sync_engine, "connect") + def _enable_fks(dbapi_conn, connection_record): # type: 
ignore[no-untyped-def] + cursor = dbapi_conn.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + + factory = async_sessionmaker(engine, expire_on_commit=False) + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + app = FastAPI(title="test-auth") + app.state.config = SimpleNamespace( + auth=SimpleNamespace( + jwt_secret=jwt_secret, + registration_enabled=registration_enabled, + token_expiry_hours=token_expiry_hours, + ), + api=SimpleNamespace( + cors_origins=["http://localhost:3000"], + rate_limit=100, + rate_limit_window=60, + ), + ) + app.state.db_factory = factory + app.state.engine = engine + + # Add middleware (same order as production) + app.add_middleware(RateLimitMiddleware, rate_limit=100, window=60) + app.add_middleware(APIKeyMiddleware) + + app.include_router(auth_router) + + @app.get("/api/test") + async def test_endpoint() -> dict[str, str]: + return {"msg": "ok"} + + return app + + +async def _register_user( + client: TestClient, + email: str = "test@example.com", + password: str = "strong-pass-123", + display_name: str = "Test User", +) -> dict: # type: ignore[type-arg] + """Helper to register a user and return the response JSON.""" + resp = client.post( + "/api/auth/register", + json={ + "email": email, + "password": password, + "display_name": display_name, + }, + ) + return resp.json() # type: ignore[no-any-return] + + +# ── Password hashing ────────────────────────────────────────── + + +class TestHashPassword: + def test_hash_password(self) -> None: + """Hash produces a valid bcrypt hash string.""" + hashed = hash_password("mypassword") + assert hashed.startswith("$2") + assert len(hashed) == 60 + + def test_verify_password_correct(self) -> None: + """Correct password verifies successfully.""" + hashed = hash_password("correct-password") + assert verify_password("correct-password", hashed) is True + + def test_verify_password_wrong(self) -> None: + """Wrong password fails verification.""" + 
hashed = hash_password("correct-password") + assert verify_password("wrong-password", hashed) is False + + +# ── JWT tokens ──────────────────────────────────────────────── + + +class TestJWTTokens: + def test_create_token(self) -> None: + """Token is a valid JWT string.""" + token = create_token("user-123", "secret") + assert isinstance(token, str) + assert len(token) > 0 + # Should be decodable + payload = jwt.decode(token, "secret", algorithms=["HS256"]) + assert payload["sub"] == "user-123" + + def test_decode_token_valid(self) -> None: + """Decode returns payload with sub.""" + token = create_token("user-456", "secret", expiry_hours=1) + payload = decode_token(token, "secret") + assert payload["sub"] == "user-456" + assert "exp" in payload + assert "iat" in payload + + def test_decode_token_expired(self) -> None: + """Expired token raises HTTPException.""" + from fastapi import HTTPException + + # Create a token that's already expired + payload = { + "sub": "user-789", + "exp": time.time() - 3600, # 1 hour ago + "iat": time.time() - 7200, + } + token = jwt.encode(payload, "secret", algorithm="HS256") + + with pytest.raises(HTTPException) as exc_info: + decode_token(token, "secret") + assert exc_info.value.status_code == 401 + assert "expired" in exc_info.value.detail.lower() + + def test_decode_token_invalid(self) -> None: + """Invalid token raises HTTPException.""" + from fastapi import HTTPException + + with pytest.raises(HTTPException) as exc_info: + decode_token("not-a-valid-token", "secret") + assert exc_info.value.status_code == 401 + assert "Invalid token" in exc_info.value.detail + + +# ── Register endpoint ───────────────────────────────────────── + + +class TestRegisterEndpoint: + async def test_register_endpoint(self) -> None: + """POST /api/auth/register creates user and returns token.""" + app = await _make_auth_app() + client = TestClient(app, raise_server_exceptions=False) + resp = client.post( + "/api/auth/register", + json={ + "email": 
"new@example.com", + "password": "password123", + "display_name": "New User", + }, + ) + assert resp.status_code == 200 + data = resp.json() + assert "access_token" in data + assert data["token_type"] == "bearer" + assert data["user_id"] + assert data["role"] == "contributor" + + async def test_register_duplicate_email(self) -> None: + """Duplicate email returns 409.""" + app = await _make_auth_app() + client = TestClient(app, raise_server_exceptions=False) + # Register first time + await _register_user(client, email="dup@example.com") + # Register again with same email + resp = client.post( + "/api/auth/register", + json={ + "email": "dup@example.com", + "password": "pass2", + "display_name": "Dup User", + }, + ) + assert resp.status_code == 409 + assert "already registered" in resp.json()["detail"] + + async def test_register_disabled(self) -> None: + """Returns 403 when registration_enabled=False.""" + app = await _make_auth_app(registration_enabled=False) + client = TestClient(app, raise_server_exceptions=False) + resp = client.post( + "/api/auth/register", + json={ + "email": "new@example.com", + "password": "pass", + "display_name": "User", + }, + ) + assert resp.status_code == 403 + assert "disabled" in resp.json()["detail"].lower() + + +# ── Login endpoint ──────────────────────────────────────────── + + +class TestLoginEndpoint: + async def test_login_success(self) -> None: + """Correct credentials return a token.""" + app = await _make_auth_app() + client = TestClient(app, raise_server_exceptions=False) + await _register_user(client, email="login@example.com", password="mypass") + + resp = client.post( + "/api/auth/login", + json={"email": "login@example.com", "password": "mypass"}, + ) + assert resp.status_code == 200 + data = resp.json() + assert "access_token" in data + assert data["role"] == "contributor" + + async def test_login_wrong_password(self) -> None: + """Wrong password returns 401.""" + app = await _make_auth_app() + client = TestClient(app, 
raise_server_exceptions=False) + await _register_user(client, email="wp@example.com", password="correct") + + resp = client.post( + "/api/auth/login", + json={"email": "wp@example.com", "password": "wrong"}, + ) + assert resp.status_code == 401 + assert "Invalid credentials" in resp.json()["detail"] + + async def test_login_nonexistent_user(self) -> None: + """Nonexistent email returns 401.""" + app = await _make_auth_app() + client = TestClient(app, raise_server_exceptions=False) + resp = client.post( + "/api/auth/login", + json={"email": "nobody@example.com", "password": "pass"}, + ) + assert resp.status_code == 401 + assert "Invalid credentials" in resp.json()["detail"] + + +# ── /me endpoint ────────────────────────────────────────────── + + +class TestMeEndpoint: + async def test_me_endpoint(self) -> None: + """GET /api/auth/me returns user info when authenticated.""" + app = await _make_auth_app() + client = TestClient(app, raise_server_exceptions=False) + reg_data = await _register_user(client, email="me@example.com") + token = reg_data["access_token"] + + resp = client.get( + "/api/auth/me", + headers={"Authorization": f"Bearer {token}"}, + ) + assert resp.status_code == 200 + data = resp.json() + assert data["email"] == "me@example.com" + assert data["display_name"] == "Test User" + assert data["role"] == "contributor" + assert data["is_active"] is True + + async def test_me_no_token(self) -> None: + """GET /api/auth/me without token returns 401.""" + app = await _make_auth_app() + client = TestClient(app, raise_server_exceptions=False) + resp = client.get("/api/auth/me") + assert resp.status_code == 401 + + +# ── JWT middleware integration ──────────────────────────────── + + +class TestJWTMiddlewareIntegration: + async def test_bearer_token_accepted_by_middleware(self) -> None: + """Bearer token accepted by APIKeyMiddleware as alternative to X-API-Key.""" + app = await _make_auth_app() + client = TestClient(app, raise_server_exceptions=False) + + # Seed 
an API key so the middleware would reject requests without auth + from duh.api.middleware import hash_api_key + from duh.memory.repository import MemoryRepository + + async with app.state.db_factory() as session: + repo = MemoryRepository(session) + await repo.create_api_key("test-key", hash_api_key("secret-key")) + await session.commit() + + # Register a user and get a JWT token + reg_data = await _register_user(client, email="jwt@example.com") + token = reg_data["access_token"] + + # Access /api/test with Bearer token (no X-API-Key) + resp = client.get( + "/api/test", + headers={"Authorization": f"Bearer {token}"}, + ) + assert resp.status_code == 200 + assert resp.json() == {"msg": "ok"} + + async def test_no_auth_rejected_when_keys_exist(self) -> None: + """Request with no auth is rejected when API keys exist in DB.""" + app = await _make_auth_app() + client = TestClient(app, raise_server_exceptions=False) + + from duh.api.middleware import hash_api_key + from duh.memory.repository import MemoryRepository + + async with app.state.db_factory() as session: + repo = MemoryRepository(session) + await repo.create_api_key("test-key", hash_api_key("secret-key")) + await session.commit() + + # Access /api/test with no auth at all + resp = client.get("/api/test") + assert resp.status_code == 401 diff --git a/tests/unit/test_backup.py b/tests/unit/test_backup.py new file mode 100644 index 0000000..8e515ae --- /dev/null +++ b/tests/unit/test_backup.py @@ -0,0 +1,315 @@ +"""Tests for database backup utilities and CLI command.""" + +from __future__ import annotations + +import asyncio +import json +import sqlite3 +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from pathlib import Path +from unittest.mock import AsyncMock, patch + +import pytest +from click.testing import CliRunner + +from duh.cli.app import cli +from duh.memory.backup import backup_json, backup_sqlite, detect_db_type + +# ── detect_db_type ────────────────────────────────────────────── + + 
+class TestDetectDbType: + def test_sqlite_plain(self) -> None: + assert detect_db_type("sqlite:///path/to/db.sqlite") == "sqlite" + + def test_sqlite_aiosqlite(self) -> None: + assert detect_db_type("sqlite+aiosqlite:///path/to/db.db") == "sqlite" + + def test_sqlite_memory(self) -> None: + assert detect_db_type("sqlite+aiosqlite://") == "sqlite" + + def test_postgresql_plain(self) -> None: + assert detect_db_type("postgresql://user:pass@host/db") == "postgresql" + + def test_postgresql_asyncpg(self) -> None: + assert detect_db_type("postgresql+asyncpg://user:pass@host/db") == "postgresql" + + def test_postgres_shorthand(self) -> None: + assert detect_db_type("postgres://user:pass@host/db") == "postgresql" + + def test_unknown_url(self) -> None: + assert detect_db_type("mysql://user:pass@host/db") == "unknown" + + +# ── backup_sqlite ─────────────────────────────────────────────── + + +class TestBackupSqlite: + def test_copies_file(self, tmp_path: Path) -> None: + """Create a temp SQLite DB, backup, verify copy exists and is valid.""" + src_db = tmp_path / "source.db" + conn = sqlite3.connect(str(src_db)) + conn.execute("CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)") + conn.execute("INSERT INTO test VALUES (1, 'hello')") + conn.commit() + conn.close() + + dest = tmp_path / "backup" / "backup.db" + db_url = f"sqlite:///{src_db}" + + result = asyncio.run(backup_sqlite(db_url, dest)) + + assert result == dest + assert dest.exists() + # Verify the copy is a valid SQLite database + conn2 = sqlite3.connect(str(dest)) + rows = conn2.execute("SELECT * FROM test").fetchall() + conn2.close() + assert rows == [(1, "hello")] + + def test_memory_db_raises(self, tmp_path: Path) -> None: + dest = tmp_path / "backup.db" + with pytest.raises(ValueError, match="Cannot"): + asyncio.run(backup_sqlite("sqlite+aiosqlite://", dest)) + + def test_memory_db_triple_slash_raises(self, tmp_path: Path) -> None: + dest = tmp_path / "backup.db" + with pytest.raises(ValueError, 
match="in-memory"): + asyncio.run(backup_sqlite("sqlite+aiosqlite:///:memory:", dest)) + + def test_missing_source_raises(self, tmp_path: Path) -> None: + dest = tmp_path / "backup.db" + with pytest.raises(FileNotFoundError): + asyncio.run( + backup_sqlite("sqlite:///nonexistent/path/db.sqlite", dest) + ) + + def test_aiosqlite_url(self, tmp_path: Path) -> None: + """Works with sqlite+aiosqlite:/// prefix too.""" + src_db = tmp_path / "source.db" + conn = sqlite3.connect(str(src_db)) + conn.execute("CREATE TABLE t (x INTEGER)") + conn.commit() + conn.close() + + dest = tmp_path / "copy.db" + db_url = f"sqlite+aiosqlite:///{src_db}" + + result = asyncio.run(backup_sqlite(db_url, dest)) + assert result == dest + assert dest.exists() + + +# ── backup_json ───────────────────────────────────────────────── + + +def _make_async_session() -> tuple[Any, Any]: + """Create an in-memory SQLite async session with tables created.""" + from sqlalchemy import event + from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine + from sqlalchemy.pool import StaticPool + + from duh.memory.models import Base + + engine = create_async_engine( + "sqlite+aiosqlite://", + poolclass=StaticPool, + connect_args={"check_same_thread": False}, + ) + + @event.listens_for(engine.sync_engine, "connect") + def _enable_fks(dbapi_conn, connection_record): # type: ignore[no-untyped-def] + cursor = dbapi_conn.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + + async def _init() -> None: + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + asyncio.run(_init()) + factory = async_sessionmaker(engine, expire_on_commit=False) + return factory, engine + + +class TestBackupJson: + def test_exports_all_tables(self, tmp_path: Path) -> None: + """Export to JSON and verify structure includes all expected tables.""" + factory, engine = _make_async_session() + + async def _run() -> Path: + async with factory() as session: + return await 
backup_json(session, tmp_path / "backup.json") + + result = asyncio.run(_run()) + assert result.exists() + + data = json.loads(result.read_text()) + expected_tables = { + "threads", + "turns", + "contributions", + "turn_summaries", + "thread_summaries", + "decisions", + "outcomes", + "subtasks", + "votes", + "api_keys", + } + assert expected_tables.issubset(set(data["tables"].keys())) + asyncio.run(engine.dispose()) + + def test_version_field(self, tmp_path: Path) -> None: + """Verify version and exported_at in output.""" + factory, engine = _make_async_session() + + async def _run() -> Path: + async with factory() as session: + return await backup_json(session, tmp_path / "backup.json") + + result = asyncio.run(_run()) + data = json.loads(result.read_text()) + + assert data["version"] == "0.5.0" + assert "exported_at" in data + # Verify exported_at is a valid ISO timestamp + assert "T" in data["exported_at"] + asyncio.run(engine.dispose()) + + def test_includes_data(self, tmp_path: Path) -> None: + """Create data in DB, export to JSON, verify rows present.""" + factory, engine = _make_async_session() + + async def _seed_and_export() -> Path: + from duh.memory.repository import MemoryRepository + + async with factory() as session: + repo = MemoryRepository(session) + thread = await repo.create_thread("Test question") + turn = await repo.create_turn(thread.id, 1, "COMMIT") + await repo.add_contribution( + turn.id, "mock:model", "proposer", "Answer" + ) + await repo.save_decision( + turn.id, thread.id, "Decision text", 0.9 + ) + await session.commit() + + async with factory() as session: + return await backup_json(session, tmp_path / "backup.json") + + result = asyncio.run(_seed_and_export()) + data = json.loads(result.read_text()) + + assert len(data["tables"]["threads"]) == 1 + assert data["tables"]["threads"][0]["question"] == "Test question" + assert len(data["tables"]["turns"]) == 1 + assert len(data["tables"]["contributions"]) == 1 + assert 
len(data["tables"]["decisions"]) == 1 + asyncio.run(engine.dispose()) + + def test_empty_db(self, tmp_path: Path) -> None: + """Backup works on empty database — all tables present but empty.""" + factory, engine = _make_async_session() + + async def _run() -> Path: + async with factory() as session: + return await backup_json(session, tmp_path / "backup.json") + + result = asyncio.run(_run()) + data = json.loads(result.read_text()) + + for table_name, rows in data["tables"].items(): + assert isinstance(rows, list), f"Table {table_name} should be a list" + assert len(rows) == 0, f"Table {table_name} should be empty" + + asyncio.run(engine.dispose()) + + +# ── CLI command ───────────────────────────────────────────────── + + +class TestBackupCli: + @pytest.fixture + def runner(self) -> CliRunner: + return CliRunner() + + def test_help(self, runner: CliRunner) -> None: + result = runner.invoke(cli, ["backup", "--help"]) + assert result.exit_code == 0 + assert "PATH" in result.output + assert "--format" in result.output + + def test_backup_json_via_cli(self, runner: CliRunner, tmp_path: Path) -> None: + """Use CliRunner to test the CLI command with a temp DB.""" + factory, engine = _make_async_session() + from duh.config.schema import DatabaseConfig, DuhConfig + + config = DuhConfig( + database=DatabaseConfig(url="sqlite+aiosqlite://"), + ) + + dest = tmp_path / "cli_backup.json" + + with ( + patch("duh.cli.app.load_config", return_value=config), + patch( + "duh.cli.app._create_db", + new_callable=AsyncMock, + return_value=(factory, engine), + ), + ): + result = runner.invoke(cli, ["backup", "--format", "json", str(dest)]) + + assert result.exit_code == 0, result.output + assert "Backup saved to" in result.output + assert dest.exists() + + data = json.loads(dest.read_text()) + assert data["version"] == "0.5.0" + asyncio.run(engine.dispose()) + + def test_backup_format_auto_sqlite( + self, runner: CliRunner, tmp_path: Path + ) -> None: + """Auto format uses sqlite copy 
for sqlite DB.""" + src_db = tmp_path / "source.db" + conn = sqlite3.connect(str(src_db)) + conn.execute("CREATE TABLE t (id INTEGER)") + conn.commit() + conn.close() + + from duh.config.schema import DatabaseConfig, DuhConfig + + config = DuhConfig( + database=DatabaseConfig(url=f"sqlite+aiosqlite:///{src_db}"), + ) + + dest = tmp_path / "auto_backup.db" + + with patch("duh.cli.app.load_config", return_value=config): + result = runner.invoke(cli, ["backup", str(dest)]) + + assert result.exit_code == 0, result.output + assert "Backup saved to" in result.output + assert dest.exists() + + def test_backup_sqlite_format_pg_errors( + self, runner: CliRunner, tmp_path: Path + ) -> None: + """Cannot use sqlite backup format for a PostgreSQL database.""" + from duh.config.schema import DatabaseConfig, DuhConfig + + config = DuhConfig( + database=DatabaseConfig(url="postgresql+asyncpg://user:pass@host/db"), + ) + + dest = tmp_path / "backup.db" + + with patch("duh.cli.app.load_config", return_value=config): + result = runner.invoke(cli, ["backup", "--format", "sqlite", str(dest)]) + + assert result.exit_code != 0 diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index 5a870a3..341d36a 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -30,7 +30,7 @@ def test_version(self, runner: CliRunner) -> None: result = runner.invoke(cli, ["--version"]) assert result.exit_code == 0 assert "duh" in result.output - assert "0.4.0" in result.output + assert "0.5.0" in result.output def test_help(self, runner: CliRunner) -> None: result = runner.invoke(cli, ["--help"]) diff --git a/tests/unit/test_connection_pool.py b/tests/unit/test_connection_pool.py new file mode 100644 index 0000000..1cbdc98 --- /dev/null +++ b/tests/unit/test_connection_pool.py @@ -0,0 +1,298 @@ +"""Tests for connection pooling optimization (T5). 
+ +Verifies: +- SQLite in-memory uses StaticPool +- SQLite file uses NullPool +- PostgreSQL uses configured pool settings (pool_size, max_overflow, etc.) +- pool_pre_ping=True enabled for PostgreSQL +- Repository queries use selectinload for eager loading +""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from duh.config.schema import DatabaseConfig, DuhConfig + + +def _mock_engine(): + """Create a mock async engine with proper async context managers.""" + engine = MagicMock() + conn = AsyncMock() + conn.run_sync = AsyncMock() + engine.begin.return_value.__aenter__ = AsyncMock(return_value=conn) + engine.begin.return_value.__aexit__ = AsyncMock(return_value=False) + return engine + + +# ── SQLite Memory: StaticPool ──────────────────────────────── + + +class TestSQLiteMemoryUsesStaticPool: + @pytest.mark.asyncio + async def test_sqlite_memory_uses_static_pool(self) -> None: + """In-memory SQLite must use StaticPool so all queries share one connection.""" + from duh.cli.app import _create_db + + config = DuhConfig( + database=DatabaseConfig(url="sqlite+aiosqlite:///:memory:") + ) + + mock_engine = _mock_engine() + with patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ) as mock_create, patch( + "sqlalchemy.event.listens_for", + return_value=lambda fn: fn, + ): + await _create_db(config) + + mock_create.assert_called_once() + call_kwargs = mock_create.call_args.kwargs + from sqlalchemy.pool import StaticPool + + assert call_kwargs.get("poolclass") is StaticPool + assert call_kwargs.get("connect_args") == {"check_same_thread": False} + # Should NOT have pool_size for sqlite memory + assert "pool_size" not in call_kwargs + assert "pool_pre_ping" not in call_kwargs + + +# ── SQLite File: NullPool ──────────────────────────────────── + + +class TestSQLiteFileUsesNullPool: + @pytest.mark.asyncio + async def test_sqlite_file_uses_null_pool(self, tmp_path) -> None: + 
"""File-based SQLite must use NullPool (no connection pooling).""" + from duh.cli.app import _create_db + + config = DuhConfig( + database=DatabaseConfig( + url=f"sqlite+aiosqlite:///{tmp_path}/test.db" + ) + ) + + mock_engine = _mock_engine() + with patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ) as mock_create, patch( + "sqlalchemy.event.listens_for", + return_value=lambda fn: fn, + ): + await _create_db(config) + + mock_create.assert_called_once() + call_kwargs = mock_create.call_args.kwargs + from sqlalchemy.pool import NullPool + + assert call_kwargs.get("poolclass") is NullPool + # Should NOT have pool_size for sqlite + assert "pool_size" not in call_kwargs + assert "pool_pre_ping" not in call_kwargs + + +# ── PostgreSQL: Configured Pool ────────────────────────────── + + +class TestPostgreSQLUsesConfiguredPoolSize: + @pytest.mark.asyncio + async def test_postgresql_uses_configured_pool_size(self) -> None: + """PostgreSQL uses QueuePool with user-configured pool_size and max_overflow.""" + from duh.cli.app import _create_db + + config = DuhConfig( + database=DatabaseConfig( + url="postgresql+asyncpg://user:pass@localhost/duh", + pool_size=20, + max_overflow=40, + pool_timeout=45, + pool_recycle=7200, + ) + ) + + mock_engine = _mock_engine() + with patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ) as mock_create: + await _create_db(config) + + mock_create.assert_called_once() + call_kwargs = mock_create.call_args.kwargs + assert call_kwargs["pool_size"] == 20 + assert call_kwargs["max_overflow"] == 40 + assert call_kwargs["pool_timeout"] == 45 + assert call_kwargs["pool_recycle"] == 7200 + # Should NOT have poolclass for postgresql (uses default QueuePool) + assert "poolclass" not in call_kwargs + + @pytest.mark.asyncio + async def test_postgresql_default_pool_settings(self) -> None: + """PostgreSQL with default pool settings.""" + from duh.cli.app import _create_db + + config = 
DuhConfig( + database=DatabaseConfig( + url="postgresql+asyncpg://localhost/duh", + ) + ) + + mock_engine = _mock_engine() + with patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ) as mock_create: + await _create_db(config) + + call_kwargs = mock_create.call_args.kwargs + assert call_kwargs["pool_size"] == 5 # default + assert call_kwargs["max_overflow"] == 10 # default + assert call_kwargs["pool_timeout"] == 30 # default + assert call_kwargs["pool_recycle"] == 3600 # default + + +# ── pool_pre_ping for PostgreSQL ───────────────────────────── + + +class TestPoolPrePingEnabledForPostgreSQL: + @pytest.mark.asyncio + async def test_pool_pre_ping_enabled_for_postgresql(self) -> None: + """PostgreSQL connections must use pool_pre_ping=True.""" + from duh.cli.app import _create_db + + config = DuhConfig( + database=DatabaseConfig( + url="postgresql+asyncpg://user:pass@localhost/duh", + ) + ) + + mock_engine = _mock_engine() + with patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ) as mock_create: + await _create_db(config) + + call_kwargs = mock_create.call_args.kwargs + assert call_kwargs.get("pool_pre_ping") is True + + @pytest.mark.asyncio + async def test_pool_pre_ping_not_set_for_sqlite(self, tmp_path) -> None: + """SQLite should NOT set pool_pre_ping (irrelevant for NullPool/StaticPool).""" + from duh.cli.app import _create_db + + config = DuhConfig( + database=DatabaseConfig( + url=f"sqlite+aiosqlite:///{tmp_path}/test.db" + ) + ) + + mock_engine = _mock_engine() + with patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ) as mock_create, patch( + "sqlalchemy.event.listens_for", + return_value=lambda fn: fn, + ): + await _create_db(config) + + call_kwargs = mock_create.call_args.kwargs + assert "pool_pre_ping" not in call_kwargs + + +# ── Repository uses selectinload ───────────────────────────── + + +class TestRepositoryUsesSelectinload: + 
@pytest.mark.asyncio + async def test_get_thread_uses_selectinload(self, db_session) -> None: + """get_thread eagerly loads turns, contributions, decisions, summaries.""" + from duh.memory.repository import MemoryRepository + + repo = MemoryRepository(db_session) + thread = await repo.create_thread("test question") + + # Create a turn with a contribution + turn = await repo.create_turn(thread.id, 1, "PROPOSE") + await repo.add_contribution( + turn.id, "test:model", "proposer", "test content" + ) + await db_session.commit() + + # Load the thread (should eagerly load turns and contributions) + loaded = await repo.get_thread(thread.id) + assert loaded is not None + assert len(loaded.turns) == 1 + assert len(loaded.turns[0].contributions) == 1 + assert loaded.turns[0].contributions[0].content == "test content" + + @pytest.mark.asyncio + async def test_get_turn_uses_selectinload(self, db_session) -> None: + """get_turn eagerly loads contributions, decision, summary.""" + from duh.memory.repository import MemoryRepository + + repo = MemoryRepository(db_session) + thread = await repo.create_thread("test question") + turn = await repo.create_turn(thread.id, 1, "PROPOSE") + await repo.add_contribution( + turn.id, "test:model", "proposer", "test content" + ) + await db_session.commit() + + loaded = await repo.get_turn(turn.id) + assert loaded is not None + assert len(loaded.contributions) == 1 + + @pytest.mark.asyncio + async def test_get_decisions_with_outcomes_uses_selectinload( + self, db_session + ) -> None: + """get_decisions_with_outcomes eagerly loads outcome relationship.""" + from duh.memory.repository import MemoryRepository + + repo = MemoryRepository(db_session) + thread = await repo.create_thread("test question") + turn = await repo.create_turn(thread.id, 1, "COMMIT") + decision = await repo.save_decision( + turn.id, thread.id, "test decision", 0.9 + ) + await repo.save_outcome(decision.id, thread.id, "success", notes="worked") + await db_session.commit() + + 
decisions = await repo.get_decisions_with_outcomes(thread.id) + assert len(decisions) == 1 + assert decisions[0].outcome is not None + assert decisions[0].outcome.result == "success" + + @pytest.mark.asyncio + async def test_get_all_decisions_for_space_uses_selectinload( + self, db_session + ) -> None: + """get_all_decisions_for_space eagerly loads outcome and thread.""" + from duh.memory.repository import MemoryRepository + + repo = MemoryRepository(db_session) + thread = await repo.create_thread("test question") + turn = await repo.create_turn(thread.id, 1, "COMMIT") + decision = await repo.save_decision( + turn.id, + thread.id, + "test decision", + 0.9, + category="technical", + genus="architecture", + ) + await repo.save_outcome(decision.id, thread.id, "success") + await db_session.commit() + + decisions = await repo.get_all_decisions_for_space() + assert len(decisions) == 1 + assert decisions[0].outcome is not None + assert decisions[0].thread is not None + assert decisions[0].thread.question == "test question" diff --git a/tests/unit/test_health.py b/tests/unit/test_health.py new file mode 100644 index 0000000..d43cc16 --- /dev/null +++ b/tests/unit/test_health.py @@ -0,0 +1,182 @@ +"""Tests for health check endpoints.""" + +from __future__ import annotations + +import hashlib +from unittest.mock import AsyncMock, MagicMock + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient +from sqlalchemy import event +from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine + +from duh.api.health import router as health_router +from duh.api.middleware import APIKeyMiddleware +from duh.memory.models import Base + + +@pytest.fixture +async def health_app(): + """FastAPI app with health router and an in-memory DB.""" + engine = create_async_engine("sqlite+aiosqlite://") + + @event.listens_for(engine.sync_engine, "connect") + def _enable_fks(dbapi_conn, connection_record): + cursor = dbapi_conn.cursor() + 
cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + + factory = async_sessionmaker(engine, expire_on_commit=False) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + app = FastAPI() + app.state.db_factory = factory + app.state.engine = engine + app.include_router(health_router) + + yield app + + await engine.dispose() + + +class TestHealthBasic: + def test_health_basic(self): + """GET /api/health returns {"status": "ok"}.""" + from duh.api.app import create_app + from duh.config.schema import DuhConfig + + config = DuhConfig() + config.database.url = "sqlite+aiosqlite:///:memory:" + app = create_app(config) + client = TestClient(app, raise_server_exceptions=False) + resp = client.get("/api/health") + assert resp.status_code == 200 + assert resp.json() == {"status": "ok"} + + +class TestHealthDetailed: + async def test_health_detailed_ok(self, health_app): + """GET /api/health/detailed returns status, version, uptime, components.""" + client = TestClient(health_app, raise_server_exceptions=False) + resp = client.get("/api/health/detailed") + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "ok" + assert "version" in data + assert "uptime_seconds" in data + assert "components" in data + + async def test_health_detailed_db_check(self, health_app): + """Database component shows ok with working DB.""" + client = TestClient(health_app, raise_server_exceptions=False) + resp = client.get("/api/health/detailed") + data = resp.json() + assert data["components"]["database"]["status"] == "ok" + + async def test_health_detailed_db_failure(self, health_app): + """Database component shows error when DB fails.""" + # Replace db_factory with one that raises + async def broken_factory(): + raise RuntimeError("DB is down") + + health_app.state.db_factory = MagicMock(side_effect=RuntimeError("DB is down")) + + client = TestClient(health_app, raise_server_exceptions=False) + resp = 
client.get("/api/health/detailed") + data = resp.json() + assert data["components"]["database"]["status"] == "error" + assert "DB is down" in data["components"]["database"]["detail"] + assert data["status"] == "degraded" + + async def test_health_detailed_provider_healthy(self, health_app): + """Provider shows ok when health_check returns True.""" + mock_provider = AsyncMock() + mock_provider.health_check.return_value = True + + pm = MagicMock() + pm._providers = {"test-provider": mock_provider} + health_app.state.provider_manager = pm + + client = TestClient(health_app, raise_server_exceptions=False) + resp = client.get("/api/health/detailed") + data = resp.json() + assert data["components"]["providers"]["test-provider"]["status"] == "ok" + assert data["status"] == "ok" + + async def test_health_detailed_provider_unhealthy(self, health_app): + """Provider shows unhealthy, status = degraded when all providers fail.""" + mock_provider = AsyncMock() + mock_provider.health_check.return_value = False + + pm = MagicMock() + pm._providers = {"bad-provider": mock_provider} + health_app.state.provider_manager = pm + + client = TestClient(health_app, raise_server_exceptions=False) + resp = client.get("/api/health/detailed") + data = resp.json() + assert data["components"]["providers"]["bad-provider"]["status"] == "unhealthy" + assert data["status"] == "degraded" + + async def test_health_detailed_uptime(self, health_app): + """Uptime is a positive number.""" + client = TestClient(health_app, raise_server_exceptions=False) + resp = client.get("/api/health/detailed") + data = resp.json() + assert isinstance(data["uptime_seconds"], (int, float)) + assert data["uptime_seconds"] >= 0 + + +class TestHealthNoAuth: + async def test_health_no_auth_required(self): + """Both health endpoints are accessible without API key.""" + engine = create_async_engine("sqlite+aiosqlite://") + + @event.listens_for(engine.sync_engine, "connect") + def _enable_fks(dbapi_conn, connection_record): + 
cursor = dbapi_conn.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + + factory = async_sessionmaker(engine, expire_on_commit=False) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + # Seed an API key so auth is enforced + from duh.memory.repository import MemoryRepository + + async with factory() as session: + repo = MemoryRepository(session) + await repo.create_api_key( + "test-key", + hashlib.sha256(b"secret").hexdigest(), + ) + await session.commit() + + app = FastAPI() + app.state.db_factory = factory + app.state.engine = engine + app.add_middleware(APIKeyMiddleware) + app.include_router(health_router) + + @app.get("/api/protected") + async def protected(): + return {"ok": True} + + client = TestClient(app, raise_server_exceptions=False) + + # /api/health should be accessible without API key + resp = client.get("/api/health") + assert resp.status_code == 200 + + # /api/health/detailed should be accessible without API key + resp2 = client.get("/api/health/detailed") + assert resp2.status_code == 200 + + # A non-exempt API path should fail without a key + resp3 = client.get("/api/protected") + assert resp3.status_code == 401 + + await engine.dispose() diff --git a/tests/unit/test_metrics.py b/tests/unit/test_metrics.py new file mode 100644 index 0000000..a3762a0 --- /dev/null +++ b/tests/unit/test_metrics.py @@ -0,0 +1,177 @@ +"""Tests for lightweight Prometheus metrics module.""" + +from __future__ import annotations + +import pytest +from fastapi.testclient import TestClient + +from duh.api.metrics import ( + Counter, + Gauge, + Histogram, + MetricsRegistry, +) + + +@pytest.fixture(autouse=True) +def _reset_registry(): + """Reset the global metrics registry before each test.""" + MetricsRegistry.reset() + yield + MetricsRegistry.reset() + + +class TestCounter: + def test_counter_inc(self): + c = Counter("test_total", "A test counter") + c.inc() + c.inc(3.0) + text = c.collect() + assert "test_total 4" 
in text + + def test_counter_labels(self): + c = Counter("req_total", "Requests", labels=["method", "status"]) + c.inc(method="GET", status="200") + c.inc(method="GET", status="200") + c.inc(method="POST", status="201") + text = c.collect() + assert '# HELP req_total Requests' in text + assert '# TYPE req_total counter' in text + assert 'req_total{method="GET",status="200"} 2' in text + assert 'req_total{method="POST",status="201"} 1' in text + + +class TestHistogram: + def test_histogram_observe(self): + h = Histogram("dur", "Duration", buckets=[0.1, 0.5, 1.0]) + h.observe(0.05) # fits in 0.1, 0.5, 1.0 + h.observe(0.3) # fits in 0.5, 1.0 + h.observe(0.8) # fits in 1.0 + h.observe(2.0) # exceeds all buckets + + text = h.collect() + # Cumulative counts: 0.1→1, 0.5→2, 1.0→3, +Inf→4 + assert 'dur_bucket{le="0.1"} 1' in text + assert 'dur_bucket{le="0.5"} 2' in text + assert 'dur_bucket{le="1"} 3' in text + assert 'dur_bucket{le="+Inf"} 4' in text + + def test_histogram_collect(self): + h = Histogram("lat", "Latency", buckets=[0.01, 0.1]) + h.observe(0.005) + h.observe(0.05) + text = h.collect() + assert '# HELP lat Latency' in text + assert '# TYPE lat histogram' in text + assert 'lat_bucket{le="0.01"} 1' in text + assert 'lat_bucket{le="0.1"} 2' in text + assert 'lat_bucket{le="+Inf"} 2' in text + assert 'lat_sum 0.055' in text + assert 'lat_count 2' in text + + +class TestGauge: + def test_gauge_set_inc_dec(self): + g = Gauge("conn", "Connections") + g.set(5.0) + text = g.collect() + assert "conn 5" in text + + g.inc(3.0) + text = g.collect() + assert "conn 8" in text + + g.dec(2.0) + text = g.collect() + assert "conn 6" in text + + g.inc() + text = g.collect() + assert "conn 7" in text + + g.dec() + text = g.collect() + assert "conn 6" in text + + +class TestMetricsRegistry: + def test_registry_collect_all(self): + c = Counter("app_total", "App counter") + g = Gauge("app_gauge", "App gauge") + c.inc(10.0) + g.set(42.0) + + registry = MetricsRegistry.get() + output 
= registry.collect_all() + assert "app_total 10" in output + assert "app_gauge 42" in output + assert "# HELP app_total" in output + assert "# HELP app_gauge" in output + + +class TestMetricsEndpoint: + def test_metrics_endpoint(self): + from duh.api.app import create_app + from duh.config.schema import DuhConfig + + config = DuhConfig() + config.database.url = "sqlite+aiosqlite:///:memory:" + app = create_app(config) + client = TestClient(app, raise_server_exceptions=False) + resp = client.get("/api/metrics") + assert resp.status_code == 200 + assert "text/plain" in resp.headers["content-type"] + + async def test_metrics_no_auth_required(self): + """Metrics endpoint is exempt from API key middleware.""" + import hashlib + + from fastapi import FastAPI + from sqlalchemy import event + from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine + + from duh.api.metrics import router as metrics_router + from duh.api.middleware import APIKeyMiddleware + from duh.memory.models import Base + from duh.memory.repository import MemoryRepository + + engine = create_async_engine("sqlite+aiosqlite://") + + @event.listens_for(engine.sync_engine, "connect") + def _enable_fks(dbapi_conn, connection_record): + cursor = dbapi_conn.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + + factory = async_sessionmaker(engine, expire_on_commit=False) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + # Seed an API key so auth is enforced + async with factory() as session: + repo = MemoryRepository(session) + await repo.create_api_key( + "test-key", + hashlib.sha256(b"secret").hexdigest(), + ) + await session.commit() + + app = FastAPI() + app.state.db_factory = factory + app.state.engine = engine + app.add_middleware(APIKeyMiddleware) + app.include_router(metrics_router) + + @app.get("/api/protected") + async def protected(): + return {"ok": True} + + client = TestClient(app, raise_server_exceptions=False) + + # 
Without API key, metrics should still be accessible + resp = client.get("/api/metrics") + assert resp.status_code == 200 + + # A non-exempt API path should fail without a key + resp2 = client.get("/api/protected") + assert resp2.status_code == 401 diff --git a/tests/unit/test_models.py b/tests/unit/test_models.py index 4373a20..19e9959 100644 --- a/tests/unit/test_models.py +++ b/tests/unit/test_models.py @@ -330,6 +330,7 @@ def test_decision_thread_id_index(self) -> None: def test_all_tables_created(self) -> None: expected = { + "users", "threads", "turns", "contributions", diff --git a/tests/unit/test_multi_user_integration.py b/tests/unit/test_multi_user_integration.py new file mode 100644 index 0000000..5208c33 --- /dev/null +++ b/tests/unit/test_multi_user_integration.py @@ -0,0 +1,686 @@ +"""Multi-user integration tests for v0.5 user accounts, JWT auth, and RBAC. + +Tests: +- User isolation: threads are scoped to their owner at the data layer +- Admin sees all: admin user can see threads from all users +- Registration flow: register -> login -> /me +- Role enforcement: viewer < contributor < admin +- Per-user rate limiting: independent limits per JWT identity +- JWT token validation: expired, invalid, missing tokens rejected +- User deactivation: deactivated user's JWT is rejected +""" + +from __future__ import annotations + +import time +from types import SimpleNamespace + +import jwt +from fastapi import Depends, FastAPI +from fastapi.testclient import TestClient +from sqlalchemy import event, select +from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine + +from duh.api.auth import create_token, hash_password +from duh.api.auth import router as auth_router +from duh.api.middleware import APIKeyMiddleware, RateLimitMiddleware +from duh.api.rbac import require_admin, require_contributor, require_viewer +from duh.memory.models import Base, Thread, User + +# ── Helpers ──────────────────────────────────────────────────── + + +async def 
_make_multi_user_app( + *, + jwt_secret: str = "test-secret-key-32chars-long!!!!", + registration_enabled: bool = True, + token_expiry_hours: int = 24, + rate_limit: int = 100, + rate_limit_window: int = 60, +) -> FastAPI: + """Create a FastAPI app with auth, RBAC endpoints, and in-memory DB.""" + engine = create_async_engine("sqlite+aiosqlite://") + + @event.listens_for(engine.sync_engine, "connect") + def _enable_fks(dbapi_conn, connection_record): # type: ignore[no-untyped-def] + cursor = dbapi_conn.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + + factory = async_sessionmaker(engine, expire_on_commit=False) + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + app = FastAPI(title="test-multi-user") + app.state.config = SimpleNamespace( + auth=SimpleNamespace( + jwt_secret=jwt_secret, + registration_enabled=registration_enabled, + token_expiry_hours=token_expiry_hours, + ), + api=SimpleNamespace( + cors_origins=["http://localhost:3000"], + rate_limit=rate_limit, + rate_limit_window=rate_limit_window, + ), + ) + app.state.db_factory = factory + app.state.engine = engine + + # Add middleware (same order as production) + app.add_middleware( + RateLimitMiddleware, rate_limit=rate_limit, window=rate_limit_window + ) + app.add_middleware(APIKeyMiddleware) + + # Auth routes + app.include_router(auth_router) + + # RBAC-protected test endpoints + @app.get("/api/admin-only") + async def admin_endpoint( + user=Depends(require_admin), # noqa: B008 + ) -> dict[str, str]: + return {"role": user.role, "msg": "admin access granted"} + + @app.get("/api/contributor-only") + async def contributor_endpoint( + user=Depends(require_contributor), # noqa: B008 + ) -> dict[str, str]: + return {"role": user.role, "msg": "contributor access granted"} + + @app.get("/api/viewer-only") + async def viewer_endpoint( + user=Depends(require_viewer), # noqa: B008 + ) -> dict[str, str]: + return {"role": user.role, "msg": "viewer access 
granted"} + + @app.get("/api/test") + async def test_endpoint() -> dict[str, str]: + return {"msg": "ok"} + + return app + + +async def _create_user_in_db( + app: FastAPI, + *, + email: str, + password: str = "test-password-123", + display_name: str = "Test User", + role: str = "contributor", + is_active: bool = True, +) -> User: + """Insert a user directly into the DB and return the User object.""" + async with app.state.db_factory() as session: + user = User( + email=email, + password_hash=hash_password(password), + display_name=display_name, + role=role, + is_active=is_active, + ) + session.add(user) + await session.commit() + await session.refresh(user) + return user + + +async def _create_thread_for_user(app: FastAPI, user_id: str, question: str) -> Thread: + """Insert a thread directly into the DB, owned by user_id.""" + async with app.state.db_factory() as session: + thread = Thread(question=question, user_id=user_id) + session.add(thread) + await session.commit() + await session.refresh(thread) + return thread + + +def _get_token(user_id: str, secret: str = "test-secret-key-32chars-long!!!!") -> str: + """Create a JWT token for the given user ID.""" + return create_token(user_id, secret) + + +def _auth_headers(token: str) -> dict[str, str]: + """Return Authorization header dict for Bearer token.""" + return {"Authorization": f"Bearer {token}"} + + +# ── 1. 
User Isolation ───────────────────────────────────────── + + +class TestUserIsolation: + """User A's threads are not visible to User B when filtering by user_id.""" + + async def test_threads_have_user_id_fk(self) -> None: + """Threads created with a user_id are linked to that user in the DB.""" + app = await _make_multi_user_app() + user_a = await _create_user_in_db(app, email="alice@example.com") + user_b = await _create_user_in_db(app, email="bob@example.com") + + await _create_thread_for_user(app, user_a.id, "Alice's question") + await _create_thread_for_user(app, user_b.id, "Bob's question") + + async with app.state.db_factory() as session: + # Query threads filtered by user_a + stmt = select(Thread).where(Thread.user_id == user_a.id) + result = await session.execute(stmt) + alice_threads = list(result.scalars().all()) + + assert len(alice_threads) == 1 + assert alice_threads[0].question == "Alice's question" + assert alice_threads[0].user_id == user_a.id + + async def test_user_b_cannot_see_user_a_threads(self) -> None: + """Filtering threads by user_id isolates each user's data.""" + app = await _make_multi_user_app() + user_a = await _create_user_in_db(app, email="alice@example.com") + user_b = await _create_user_in_db(app, email="bob@example.com") + + await _create_thread_for_user(app, user_a.id, "Alice thread 1") + await _create_thread_for_user(app, user_a.id, "Alice thread 2") + await _create_thread_for_user(app, user_b.id, "Bob thread 1") + + async with app.state.db_factory() as session: + # Bob's filtered view + stmt = select(Thread).where(Thread.user_id == user_b.id) + result = await session.execute(stmt) + bob_threads = list(result.scalars().all()) + + assert len(bob_threads) == 1 + assert bob_threads[0].question == "Bob thread 1" + + # Alice's filtered view + stmt = select(Thread).where(Thread.user_id == user_a.id) + result = await session.execute(stmt) + alice_threads = list(result.scalars().all()) + + assert len(alice_threads) == 2 + + async def 
test_unowned_threads_have_null_user_id(self) -> None: + """Threads without a user_id (pre-v0.5 / anonymous) have null user_id.""" + app = await _make_multi_user_app() + + async with app.state.db_factory() as session: + thread = Thread(question="Anonymous question") + session.add(thread) + await session.commit() + await session.refresh(thread) + + assert thread.user_id is None + + +# ── 2. Admin Sees All ───────────────────────────────────────── + + +class TestAdminSeesAll: + """Admin user can see threads from all users.""" + + async def test_admin_can_query_all_threads(self) -> None: + """Admin's unfiltered query returns threads from all users.""" + app = await _make_multi_user_app() + user_a = await _create_user_in_db(app, email="alice@example.com") + user_b = await _create_user_in_db(app, email="bob@example.com") + admin = await _create_user_in_db(app, email="admin@example.com", role="admin") + + await _create_thread_for_user(app, user_a.id, "Alice thread") + await _create_thread_for_user(app, user_b.id, "Bob thread") + await _create_thread_for_user(app, admin.id, "Admin thread") + + async with app.state.db_factory() as session: + # Admin sees all (no user_id filter) + stmt = select(Thread) + result = await session.execute(stmt) + all_threads = list(result.scalars().all()) + assert len(all_threads) == 3 + + async def test_admin_can_see_specific_user_threads(self) -> None: + """Admin can filter to see a specific user's threads.""" + app = await _make_multi_user_app() + user_a = await _create_user_in_db(app, email="alice@example.com") + user_b = await _create_user_in_db(app, email="bob@example.com") + + await _create_thread_for_user(app, user_a.id, "Alice thread") + await _create_thread_for_user(app, user_b.id, "Bob thread") + + async with app.state.db_factory() as session: + # Admin can look at Bob's threads specifically + stmt = select(Thread).where(Thread.user_id == user_b.id) + result = await session.execute(stmt) + bob_threads = list(result.scalars().all()) + 
assert len(bob_threads) == 1 + assert bob_threads[0].question == "Bob thread" + + +# ── 3. Registration Flow ───────────────────────────────────── + + +class TestRegistrationFlow: + """Full registration -> login -> /me flow.""" + + async def test_register_login_me(self) -> None: + """Register a new user, login, and access /me with the token.""" + app = await _make_multi_user_app() + client = TestClient(app, raise_server_exceptions=False) + + # Step 1: Register + reg_resp = client.post( + "/api/auth/register", + json={ + "email": "newuser@example.com", + "password": "secure-password-123", + "display_name": "New User", + }, + ) + assert reg_resp.status_code == 200 + reg_data = reg_resp.json() + assert "access_token" in reg_data + assert reg_data["token_type"] == "bearer" + assert reg_data["role"] == "contributor" + + # Step 2: Login with same credentials + login_resp = client.post( + "/api/auth/login", + json={ + "email": "newuser@example.com", + "password": "secure-password-123", + }, + ) + assert login_resp.status_code == 200 + login_data = login_resp.json() + assert "access_token" in login_data + login_token = login_data["access_token"] + + # Step 3: Access /me with the login token + me_resp = client.get( + "/api/auth/me", + headers=_auth_headers(login_token), + ) + assert me_resp.status_code == 200 + me_data = me_resp.json() + assert me_data["email"] == "newuser@example.com" + assert me_data["display_name"] == "New User" + assert me_data["role"] == "contributor" + assert me_data["is_active"] is True + + async def test_register_returns_valid_jwt(self) -> None: + """Token from registration can be used immediately for /me.""" + app = await _make_multi_user_app() + client = TestClient(app, raise_server_exceptions=False) + + reg_resp = client.post( + "/api/auth/register", + json={ + "email": "immediate@example.com", + "password": "password123", + "display_name": "Immediate User", + }, + ) + assert reg_resp.status_code == 200 + token = reg_resp.json()["access_token"] + 
+ # Use registration token directly (no login needed) + me_resp = client.get( + "/api/auth/me", + headers=_auth_headers(token), + ) + assert me_resp.status_code == 200 + assert me_resp.json()["email"] == "immediate@example.com" + + async def test_two_users_register_independently(self) -> None: + """Two different users can register and access their own /me.""" + app = await _make_multi_user_app() + client = TestClient(app, raise_server_exceptions=False) + + # Register user 1 + resp1 = client.post( + "/api/auth/register", + json={ + "email": "user1@example.com", + "password": "pass1", + "display_name": "User One", + }, + ) + assert resp1.status_code == 200 + token1 = resp1.json()["access_token"] + + # Register user 2 + resp2 = client.post( + "/api/auth/register", + json={ + "email": "user2@example.com", + "password": "pass2", + "display_name": "User Two", + }, + ) + assert resp2.status_code == 200 + token2 = resp2.json()["access_token"] + + # Each user sees their own info + me1 = client.get("/api/auth/me", headers=_auth_headers(token1)) + assert me1.json()["email"] == "user1@example.com" + + me2 = client.get("/api/auth/me", headers=_auth_headers(token2)) + assert me2.json()["email"] == "user2@example.com" + + +# ── 4. 
Role Enforcement ────────────────────────────────────── + + +class TestRoleEnforcement: + """Viewer cannot perform contributor actions; contributor cannot admin.""" + + async def test_viewer_cannot_access_contributor_endpoint(self) -> None: + """Viewer role is denied at contributor-level endpoints.""" + app = await _make_multi_user_app() + viewer = await _create_user_in_db( + app, email="viewer@example.com", role="viewer" + ) + token = _get_token(viewer.id) + client = TestClient(app, raise_server_exceptions=False) + + resp = client.get("/api/contributor-only", headers=_auth_headers(token)) + assert resp.status_code == 403 + assert "contributor" in resp.json()["detail"].lower() + + async def test_viewer_cannot_access_admin_endpoint(self) -> None: + """Viewer role is denied at admin-level endpoints.""" + app = await _make_multi_user_app() + viewer = await _create_user_in_db( + app, email="viewer@example.com", role="viewer" + ) + token = _get_token(viewer.id) + client = TestClient(app, raise_server_exceptions=False) + + resp = client.get("/api/admin-only", headers=_auth_headers(token)) + assert resp.status_code == 403 + assert "admin" in resp.json()["detail"].lower() + + async def test_contributor_cannot_access_admin_endpoint(self) -> None: + """Contributor role is denied at admin-level endpoints.""" + app = await _make_multi_user_app() + contrib = await _create_user_in_db( + app, email="contrib@example.com", role="contributor" + ) + token = _get_token(contrib.id) + client = TestClient(app, raise_server_exceptions=False) + + resp = client.get("/api/admin-only", headers=_auth_headers(token)) + assert resp.status_code == 403 + + async def test_contributor_can_access_contributor_endpoint(self) -> None: + """Contributor role passes contributor-level check.""" + app = await _make_multi_user_app() + contrib = await _create_user_in_db( + app, email="contrib@example.com", role="contributor" + ) + token = _get_token(contrib.id) + client = TestClient(app, 
raise_server_exceptions=False) + + resp = client.get("/api/contributor-only", headers=_auth_headers(token)) + assert resp.status_code == 200 + assert resp.json()["role"] == "contributor" + + async def test_admin_can_access_all_endpoints(self) -> None: + """Admin role passes all role checks.""" + app = await _make_multi_user_app() + admin = await _create_user_in_db(app, email="admin@example.com", role="admin") + token = _get_token(admin.id) + client = TestClient(app, raise_server_exceptions=False) + + # Admin can access admin-only + resp = client.get("/api/admin-only", headers=_auth_headers(token)) + assert resp.status_code == 200 + assert resp.json()["role"] == "admin" + + # Admin can access contributor-only + resp = client.get("/api/contributor-only", headers=_auth_headers(token)) + assert resp.status_code == 200 + + # Admin can access viewer-only + resp = client.get("/api/viewer-only", headers=_auth_headers(token)) + assert resp.status_code == 200 + + async def test_viewer_can_access_viewer_endpoint(self) -> None: + """Viewer role passes viewer-level check.""" + app = await _make_multi_user_app() + viewer = await _create_user_in_db( + app, email="viewer@example.com", role="viewer" + ) + token = _get_token(viewer.id) + client = TestClient(app, raise_server_exceptions=False) + + resp = client.get("/api/viewer-only", headers=_auth_headers(token)) + assert resp.status_code == 200 + assert resp.json()["role"] == "viewer" + + +# ── 5. 
Per-User Rate Limiting ───────────────────────────────── + + +class TestPerUserRateLimiting: + """User A hitting rate limit doesn't affect User B.""" + + async def test_user_a_rate_limit_does_not_affect_user_b(self) -> None: + """Each user has independent rate limit counters.""" + app = await _make_multi_user_app(rate_limit=3, rate_limit_window=60) + user_a = await _create_user_in_db(app, email="alice@example.com") + user_b = await _create_user_in_db(app, email="bob@example.com") + + token_a = _get_token(user_a.id) + token_b = _get_token(user_b.id) + client = TestClient(app, raise_server_exceptions=False) + + # User A exhausts their rate limit + for _ in range(3): + resp = client.get("/api/test", headers=_auth_headers(token_a)) + assert resp.status_code == 200 + + # User A is now rate limited + resp = client.get("/api/test", headers=_auth_headers(token_a)) + assert resp.status_code == 429 + + # User B is unaffected + resp = client.get("/api/test", headers=_auth_headers(token_b)) + assert resp.status_code == 200 + + async def test_rate_limit_headers_show_user_identity(self) -> None: + """Rate limit response headers identify the user.""" + app = await _make_multi_user_app(rate_limit=10, rate_limit_window=60) + user = await _create_user_in_db(app, email="alice@example.com") + + token = _get_token(user.id) + client = TestClient(app, raise_server_exceptions=False) + + resp = client.get("/api/test", headers=_auth_headers(token)) + assert resp.status_code == 200 + assert resp.headers["X-RateLimit-Limit"] == "10" + assert resp.headers["X-RateLimit-Remaining"] == "9" + assert resp.headers["X-RateLimit-Key"] == f"user:{user.id}" + + async def test_rate_limit_remaining_decrements_per_user(self) -> None: + """Remaining count decrements independently per user.""" + app = await _make_multi_user_app(rate_limit=5, rate_limit_window=60) + user_a = await _create_user_in_db(app, email="alice@example.com") + user_b = await _create_user_in_db(app, email="bob@example.com") + + token_a = 
_get_token(user_a.id) + token_b = _get_token(user_b.id) + client = TestClient(app, raise_server_exceptions=False) + + # User A makes 3 requests + for _ in range(3): + client.get("/api/test", headers=_auth_headers(token_a)) + + # User A's 4th request leaves 1 remaining + resp_a = client.get("/api/test", headers=_auth_headers(token_a)) + assert resp_a.headers["X-RateLimit-Remaining"] == "1" + + # User B should still have 4 remaining (first request) + resp_b = client.get("/api/test", headers=_auth_headers(token_b)) + assert resp_b.headers["X-RateLimit-Remaining"] == "4" + + +# ── 6. JWT Token Validation ────────────────────────────────── + + +class TestJWTTokenValidation: + """Expired, invalid, and missing tokens are rejected.""" + + async def test_expired_token_rejected(self) -> None: + """An expired JWT is rejected with 401.""" + app = await _make_multi_user_app() + user = await _create_user_in_db(app, email="expired@example.com") + client = TestClient(app, raise_server_exceptions=False) + + # Create a token that expired in the past + payload = { + "sub": user.id, + "exp": time.time() - 3600, # 1 hour ago + "iat": time.time() - 7200, + } + expired_token = jwt.encode( + payload, "test-secret-key-32chars-long!!!!", algorithm="HS256" + ) + + resp = client.get( + "/api/auth/me", + headers=_auth_headers(expired_token), + ) + assert resp.status_code == 401 + assert "expired" in resp.json()["detail"].lower() + + async def test_invalid_token_rejected(self) -> None: + """A garbled JWT string is rejected with 401.""" + app = await _make_multi_user_app() + client = TestClient(app, raise_server_exceptions=False) + + resp = client.get( + "/api/auth/me", + headers=_auth_headers("not-a-valid-jwt-token"), + ) + assert resp.status_code == 401 + assert "invalid" in resp.json()["detail"].lower() + + async def test_wrong_secret_token_rejected(self) -> None: + """A JWT signed with the wrong secret is rejected.""" + app = await _make_multi_user_app() + user = await _create_user_in_db(app, 
email="wrong@example.com") + client = TestClient(app, raise_server_exceptions=False) + + # Sign with a different secret + bad_token = create_token(user.id, "wrong-secret-key-32chars-long!!") + + resp = client.get( + "/api/auth/me", + headers=_auth_headers(bad_token), + ) + assert resp.status_code == 401 + + async def test_missing_token_rejected(self) -> None: + """Request without Authorization header is rejected with 401.""" + app = await _make_multi_user_app() + client = TestClient(app, raise_server_exceptions=False) + + resp = client.get("/api/auth/me") + assert resp.status_code == 401 + + async def test_malformed_auth_header_rejected(self) -> None: + """Authorization header without 'Bearer ' prefix is rejected.""" + app = await _make_multi_user_app() + client = TestClient(app, raise_server_exceptions=False) + + resp = client.get( + "/api/auth/me", + headers={"Authorization": "Token some-token-value"}, + ) + assert resp.status_code == 401 + + async def test_token_for_nonexistent_user_rejected(self) -> None: + """A valid JWT for a user_id that doesn't exist in the DB is rejected.""" + app = await _make_multi_user_app() + client = TestClient(app, raise_server_exceptions=False) + + # Create a token for a user that doesn't exist + token = _get_token("nonexistent-user-id-000") + + resp = client.get( + "/api/auth/me", + headers=_auth_headers(token), + ) + assert resp.status_code == 401 + assert "not found" in resp.json()["detail"].lower() + + +# ── 7. 
User Deactivation ───────────────────────────────────── + + +class TestUserDeactivation: + """Deactivated user's JWT is rejected even if the token itself is valid.""" + + async def test_deactivated_user_rejected_at_me(self) -> None: + """A deactivated user cannot access /me even with a valid JWT.""" + app = await _make_multi_user_app() + user = await _create_user_in_db( + app, email="deactivated@example.com", is_active=True + ) + token = _get_token(user.id) + client = TestClient(app, raise_server_exceptions=False) + + # Verify user can access /me while active + resp = client.get("/api/auth/me", headers=_auth_headers(token)) + assert resp.status_code == 200 + + # Deactivate the user in the DB + async with app.state.db_factory() as session: + stmt = select(User).where(User.id == user.id) + result = await session.execute(stmt) + db_user = result.scalar_one() + db_user.is_active = False + await session.commit() + + # Same token should now be rejected + resp = client.get("/api/auth/me", headers=_auth_headers(token)) + assert resp.status_code == 401 + assert "not found or inactive" in resp.json()["detail"].lower() + + async def test_deactivated_user_cannot_login(self) -> None: + """A deactivated user cannot login even with correct credentials.""" + app = await _make_multi_user_app() + await _create_user_in_db( + app, email="disabled@example.com", password="mypass", is_active=False + ) + client = TestClient(app, raise_server_exceptions=False) + + resp = client.post( + "/api/auth/login", + json={"email": "disabled@example.com", "password": "mypass"}, + ) + assert resp.status_code == 403 + assert "disabled" in resp.json()["detail"].lower() + + async def test_deactivated_user_rejected_at_rbac_endpoint(self) -> None: + """A deactivated user cannot access RBAC-protected endpoints.""" + app = await _make_multi_user_app() + user = await _create_user_in_db( + app, email="deact-rbac@example.com", role="admin", is_active=True + ) + token = _get_token(user.id) + client = 
TestClient(app, raise_server_exceptions=False) + + # Verify access works while active + resp = client.get("/api/admin-only", headers=_auth_headers(token)) + assert resp.status_code == 200 + + # Deactivate + async with app.state.db_factory() as session: + stmt = select(User).where(User.id == user.id) + result = await session.execute(stmt) + db_user = result.scalar_one() + db_user.is_active = False + await session.commit() + + # Now admin endpoint should reject + resp = client.get("/api/admin-only", headers=_auth_headers(token)) + assert resp.status_code == 401 diff --git a/tests/unit/test_postgresql_config.py b/tests/unit/test_postgresql_config.py new file mode 100644 index 0000000..1539b52 --- /dev/null +++ b/tests/unit/test_postgresql_config.py @@ -0,0 +1,156 @@ +"""Tests for PostgreSQL configuration and async driver support.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from duh.config.schema import DatabaseConfig, DuhConfig + +# ─── DatabaseConfig Defaults ───────────────────────────────── + + +class TestDatabaseConfigDefaults: + def test_database_config_defaults(self): + cfg = DatabaseConfig() + assert cfg.url == "sqlite+aiosqlite:///~/.local/share/duh/duh.db" + assert cfg.pool_size == 5 + assert cfg.max_overflow == 10 + assert cfg.pool_timeout == 30 + assert cfg.pool_recycle == 3600 + + def test_database_config_postgresql_url(self): + cfg = DatabaseConfig( + url="postgresql+asyncpg://user:pass@localhost/duh" + ) + assert cfg.url == "postgresql+asyncpg://user:pass@localhost/duh" + assert cfg.pool_size == 5 + assert cfg.max_overflow == 10 + + def test_database_config_custom_pool(self): + cfg = DatabaseConfig( + url="postgresql+asyncpg://localhost/duh", + pool_size=20, + max_overflow=40, + pool_timeout=60, + pool_recycle=1800, + ) + assert cfg.pool_size == 20 + assert cfg.max_overflow == 40 + assert cfg.pool_timeout == 60 + assert cfg.pool_recycle == 1800 + + +def _mock_engine(): + """Create 
a mock async engine with proper async context managers.""" + engine = MagicMock() + conn = AsyncMock() + conn.run_sync = AsyncMock() + engine.begin.return_value.__aenter__ = AsyncMock(return_value=conn) + engine.begin.return_value.__aexit__ = AsyncMock(return_value=False) + return engine + + +# ─── _create_db Pool Behavior ──────────────────────────────── + + +class TestCreateDbPoolBehavior: + @pytest.mark.asyncio + async def test_create_db_sqlite_uses_null_pool(self, tmp_path): + """Verify NullPool is used for sqlite URLs.""" + from duh.cli.app import _create_db + + config = DuhConfig( + database=DatabaseConfig( + url=f"sqlite+aiosqlite:///{tmp_path}/test.db" + ) + ) + + mock_engine = _mock_engine() + with patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ) as mock_create, patch( + "sqlalchemy.event.listens_for", + return_value=lambda fn: fn, + ): + await _create_db(config) + + mock_create.assert_called_once() + call_kwargs = mock_create.call_args.kwargs + from sqlalchemy.pool import NullPool + + assert call_kwargs.get("poolclass") is NullPool + # Should NOT have pool_size for sqlite + assert "pool_size" not in call_kwargs + + @pytest.mark.asyncio + async def test_create_db_postgresql_uses_queue_pool(self): + """Verify pool settings are applied for postgresql URLs.""" + from duh.cli.app import _create_db + + config = DuhConfig( + database=DatabaseConfig( + url="postgresql+asyncpg://user:pass@localhost/duh", + pool_size=15, + max_overflow=25, + pool_timeout=45, + pool_recycle=7200, + ) + ) + + mock_engine = _mock_engine() + with patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ) as mock_create: + await _create_db(config) + + mock_create.assert_called_once() + call_kwargs = mock_create.call_args.kwargs + assert call_kwargs["pool_size"] == 15 + assert call_kwargs["max_overflow"] == 25 + assert call_kwargs["pool_timeout"] == 45 + assert call_kwargs["pool_recycle"] == 7200 + # Should NOT have 
poolclass for postgresql + assert "poolclass" not in call_kwargs + + +# ─── Alembic Async Driver Detection ───────────────────────── + + +class TestAlembicAsyncDrivers: + def test_alembic_env_detects_async_drivers(self): + """Verify asyncpg is in the async drivers list in alembic/env.py.""" + import ast + from pathlib import Path + + env_path = Path(__file__).resolve().parents[2] / "alembic" / "env.py" + source = env_path.read_text() + tree = ast.parse(source) + + # Find _ASYNC_DRIVERS assignment + async_drivers = None + for node in ast.walk(tree): + if isinstance(node, ast.Assign): + for target in node.targets: + if isinstance(target, ast.Name) and target.id == "_ASYNC_DRIVERS": + async_drivers = ast.literal_eval(node.value) + + assert async_drivers is not None, "_ASYNC_DRIVERS not found in alembic/env.py" + assert "asyncpg" in async_drivers + assert "aiosqlite" in async_drivers + + def test_is_async_url_logic(self): + """Verify the _is_async_url logic works for asyncpg URLs.""" + # Replicate the logic from alembic/env.py to test it directly + async_drivers = {"aiosqlite", "asyncpg", "aiomysql"} + + def _is_async_url(url: str) -> bool: + return any(f"+{d}" in url for d in async_drivers) + + assert _is_async_url("postgresql+asyncpg://localhost/duh") is True + assert _is_async_url("sqlite+aiosqlite:///test.db") is True + assert _is_async_url("postgresql://localhost/duh") is False + assert _is_async_url("mysql+aiomysql://localhost/duh") is True diff --git a/tests/unit/test_providers_perplexity.py b/tests/unit/test_providers_perplexity.py new file mode 100644 index 0000000..77b336a --- /dev/null +++ b/tests/unit/test_providers_perplexity.py @@ -0,0 +1,421 @@ +"""Tests for Perplexity provider adapter (mocked SDK).""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import openai +import pytest + +from duh.core.errors import ( + ModelNotFoundError, + ProviderAuthError, + ProviderOverloadedError, + 
ProviderRateLimitError, + ProviderTimeoutError, +) +from duh.providers.base import ( + ModelInfo, + ModelProvider, + ModelResponse, + PromptMessage, + TokenUsage, +) +from duh.providers.perplexity import ( + PROVIDER_ID, + PerplexityProvider, + _map_error, +) + +# ─── Helpers ────────────────────────────────────────────────── + + +def _make_usage(prompt_tokens: int = 100, completion_tokens: int = 50) -> MagicMock: + usage = MagicMock() + usage.prompt_tokens = prompt_tokens + usage.completion_tokens = completion_tokens + return usage + + +def _make_response( + text: str = "Hello world", + finish_reason: str = "stop", + prompt_tokens: int = 100, + completion_tokens: int = 50, + citations: list[str] | None = None, +) -> MagicMock: + choice = MagicMock() + choice.message.content = text + choice.message.tool_calls = None + choice.finish_reason = finish_reason + + response = MagicMock() + response.choices = [choice] + response.usage = _make_usage(prompt_tokens, completion_tokens) + # Perplexity-specific: citations field + if citations is not None: + response.citations = citations + else: + # Simulate absence of citations attribute + del response.citations + return response + + +def _make_client(response: Any = None) -> MagicMock: + """Create a mocked AsyncOpenAI client.""" + client = MagicMock(spec=openai.AsyncOpenAI) + client.chat = MagicMock() + client.chat.completions = MagicMock() + client.chat.completions.create = AsyncMock( + return_value=response or _make_response(), + ) + return client + + +class _AsyncChunkIter: + """Async iterator over mock stream chunks.""" + + def __init__(self, chunks: list[Any]) -> None: + self._chunks = chunks + self._idx = 0 + + def __aiter__(self) -> _AsyncChunkIter: + return self + + async def __anext__(self) -> Any: + if self._idx >= len(self._chunks): + raise StopAsyncIteration + chunk = self._chunks[self._idx] + self._idx += 1 + return chunk + + +def _make_stream_chunk( + content: str | None = None, + finish_reason: str | None = 
None, + usage: MagicMock | None = None, +) -> MagicMock: + """Create a mock ChatCompletionChunk.""" + chunk = MagicMock() + if content is not None: + choice = MagicMock() + choice.delta.content = content + choice.finish_reason = finish_reason + chunk.choices = [choice] + else: + chunk.choices = [] + chunk.usage = usage + return chunk + + +# ─── Protocol ───────────────────────────────────────────────── + + +class TestProtocol: + def test_provider_id(self): + provider = PerplexityProvider(client=_make_client()) + assert provider.provider_id == "perplexity" + + def test_satisfies_protocol(self): + provider = PerplexityProvider(client=_make_client()) + assert isinstance(provider, ModelProvider) + + +# ─── list_models ────────────────────────────────────────────── + + +class TestListModels: + async def test_returns_three_models(self): + provider = PerplexityProvider(client=_make_client()) + models = await provider.list_models() + assert len(models) == 3 + assert all(isinstance(m, ModelInfo) for m in models) + + async def test_all_models_are_perplexity(self): + provider = PerplexityProvider(client=_make_client()) + models = await provider.list_models() + assert all(m.provider_id == PROVIDER_ID for m in models) + + async def test_expected_model_ids(self): + provider = PerplexityProvider(client=_make_client()) + models = await provider.list_models() + ids = {m.model_id for m in models} + assert ids == {"sonar", "sonar-pro", "sonar-deep-research"} + + async def test_models_have_costs(self): + provider = PerplexityProvider(client=_make_client()) + models = await provider.list_models() + for m in models: + assert m.input_cost_per_mtok > 0 + assert m.output_cost_per_mtok > 0 + + +# ─── send ───────────────────────────────────────────────────── + + +class TestSend: + async def test_returns_model_response(self): + client = _make_client() + provider = PerplexityProvider(client=client) + msgs = [PromptMessage(role="user", content="test")] + resp = await provider.send(msgs, 
"sonar") + assert isinstance(resp, ModelResponse) + + async def test_content_extracted(self): + client = _make_client(_make_response(text="The answer is 42")) + provider = PerplexityProvider(client=client) + msgs = [PromptMessage(role="user", content="test")] + resp = await provider.send(msgs, "sonar") + assert resp.content == "The answer is 42" + + async def test_usage_extracted(self): + client = _make_client( + _make_response(prompt_tokens=200, completion_tokens=80), + ) + provider = PerplexityProvider(client=client) + msgs = [PromptMessage(role="user", content="test")] + resp = await provider.send(msgs, "sonar") + assert isinstance(resp.usage, TokenUsage) + assert resp.usage.input_tokens == 200 + assert resp.usage.output_tokens == 80 + + async def test_passes_correct_base_url_model_messages(self): + client = _make_client() + provider = PerplexityProvider(client=client) + msgs = [ + PromptMessage(role="system", content="Be concise"), + PromptMessage(role="user", content="test"), + ] + await provider.send(msgs, "sonar-pro", max_tokens=1000, temperature=0.5) + call_kwargs = client.chat.completions.create.call_args.kwargs + assert call_kwargs["model"] == "sonar-pro" + assert call_kwargs["max_completion_tokens"] == 1000 + assert call_kwargs["temperature"] == 0.5 + assert call_kwargs["messages"][0]["role"] == "system" + + async def test_send_with_citations(self): + """Perplexity responses may include citations; verify they're captured.""" + citations = [ + "https://example.com/source1", + "https://example.com/source2", + ] + mock_resp = _make_response( + text="According to sources, the answer is 42.", + citations=citations, + ) + client = _make_client(mock_resp) + provider = PerplexityProvider(client=client) + msgs = [PromptMessage(role="user", content="test")] + resp = await provider.send(msgs, "sonar") + # Citations should be captured in raw_response + assert isinstance(resp.raw_response, dict) + assert resp.raw_response["citations"] == citations + assert 
resp.raw_response["response"] is mock_resp + + async def test_send_without_citations(self): + """When no citations are present, raw_response is the raw response object.""" + mock_resp = _make_response(text="Hello") + client = _make_client(mock_resp) + provider = PerplexityProvider(client=client) + msgs = [PromptMessage(role="user", content="test")] + resp = await provider.send(msgs, "sonar") + assert resp.raw_response is mock_resp + + async def test_latency_tracked(self): + client = _make_client() + provider = PerplexityProvider(client=client) + msgs = [PromptMessage(role="user", content="test")] + resp = await provider.send(msgs, "sonar") + assert resp.latency_ms >= 0 + + +# ─── stream ─────────────────────────────────────────────────── + + +class TestStream: + async def test_yields_content_chunks(self): + chunks = [ + _make_stream_chunk(content="Hello"), + _make_stream_chunk(content=" world"), + _make_stream_chunk(usage=_make_usage(100, 50)), + ] + client = _make_client() + client.chat.completions.create = AsyncMock( + return_value=_AsyncChunkIter(chunks), + ) + provider = PerplexityProvider(client=client) + msgs = [PromptMessage(role="user", content="test")] + + result = [] + async for chunk in provider.stream(msgs, "sonar"): + result.append(chunk) + + assert len(result) == 3 + assert result[0].text == "Hello" + assert result[1].text == " world" + assert result[2].is_final + assert result[2].text == "" + + async def test_final_chunk_has_usage(self): + chunks = [ + _make_stream_chunk(content="Hi"), + _make_stream_chunk(usage=_make_usage(150, 75)), + ] + client = _make_client() + client.chat.completions.create = AsyncMock( + return_value=_AsyncChunkIter(chunks), + ) + provider = PerplexityProvider(client=client) + msgs = [PromptMessage(role="user", content="test")] + + result = [] + async for chunk in provider.stream(msgs, "sonar"): + result.append(chunk) + + final = result[-1] + assert final.is_final + assert final.usage is not None + assert 
final.usage.input_tokens == 150 + assert final.usage.output_tokens == 75 + + async def test_error_during_stream(self): + client = _make_client() + client.chat.completions.create = AsyncMock( + side_effect=openai.AuthenticationError( + message="bad key", + response=MagicMock(status_code=401, headers={}), + body=None, + ), + ) + provider = PerplexityProvider(client=client) + msgs = [PromptMessage(role="user", content="test")] + with pytest.raises(ProviderAuthError): + async for _ in provider.stream(msgs, "sonar"): + pass + + async def test_passes_stream_options(self): + chunks = [_make_stream_chunk(usage=_make_usage(10, 5))] + client = _make_client() + client.chat.completions.create = AsyncMock( + return_value=_AsyncChunkIter(chunks), + ) + provider = PerplexityProvider(client=client) + msgs = [PromptMessage(role="user", content="test")] + + async for _ in provider.stream(msgs, "sonar"): + pass + + call_kwargs = client.chat.completions.create.call_args.kwargs + assert call_kwargs["stream"] is True + assert call_kwargs["stream_options"] == {"include_usage": True} + + +# ─── health_check ───────────────────────────────────────────── + + +class TestHealthCheck: + async def test_health_check_success(self): + client = _make_client() + provider = PerplexityProvider(client=client) + assert await provider.health_check() is True + # Verify it uses the "sonar" model (cheapest) + call_kwargs = client.chat.completions.create.call_args.kwargs + assert call_kwargs["model"] == "sonar" + + async def test_health_check_failure(self): + client = _make_client() + client.chat.completions.create.side_effect = Exception("connection failed") + provider = PerplexityProvider(client=client) + assert await provider.health_check() is False + + +# ─── Error Mapping ──────────────────────────────────────────── + + +class TestErrorMapping: + def _make_api_error(self, cls: type, status_code: int = 400) -> openai.APIError: + response = MagicMock() + response.status_code = status_code + 
response.headers = {} + return cls( + message="test error", + response=response, + body=None, + ) + + def test_auth_error(self): + err = self._make_api_error(openai.AuthenticationError, 401) + mapped = _map_error(err) + assert isinstance(mapped, ProviderAuthError) + + def test_rate_limit_error(self): + err = self._make_api_error(openai.RateLimitError, 429) + mapped = _map_error(err) + assert isinstance(mapped, ProviderRateLimitError) + + def test_rate_limit_with_retry_after(self): + err = self._make_api_error(openai.RateLimitError, 429) + err.response.headers = {"retry-after": "30"} + mapped = _map_error(err) + assert isinstance(mapped, ProviderRateLimitError) + assert mapped.retry_after == 30.0 + + def test_timeout_error(self): + err = openai.APITimeoutError(request=MagicMock()) + mapped = _map_error(err) + assert isinstance(mapped, ProviderTimeoutError) + + def test_internal_server_error(self): + err = self._make_api_error(openai.InternalServerError, 500) + mapped = _map_error(err) + assert isinstance(mapped, ProviderOverloadedError) + + def test_not_found_error(self): + err = self._make_api_error(openai.NotFoundError, 404) + mapped = _map_error(err) + assert isinstance(mapped, ModelNotFoundError) + + def test_unknown_api_error_maps_to_overloaded(self): + err = self._make_api_error(openai.UnprocessableEntityError, 422) + mapped = _map_error(err) + assert isinstance(mapped, ProviderOverloadedError) + + async def test_send_raises_mapped_error(self): + client = _make_client() + client.chat.completions.create.side_effect = openai.AuthenticationError( + message="bad key", + response=MagicMock(status_code=401, headers={}), + body=None, + ) + provider = PerplexityProvider(client=client) + msgs = [PromptMessage(role="user", content="test")] + with pytest.raises(ProviderAuthError): + await provider.send(msgs, "sonar") + + +# ─── _resolve_model_info ───────────────────────────────────── + + +class TestResolveModelInfo: + async def 
test_known_model_returns_correct_info(self): + client = _make_client() + provider = PerplexityProvider(client=client) + msgs = [PromptMessage(role="user", content="test")] + resp = await provider.send(msgs, "sonar-pro") + assert resp.model_info.model_id == "sonar-pro" + assert resp.model_info.provider_id == PROVIDER_ID + assert resp.model_info.display_name == "Sonar Pro" + assert resp.model_info.context_window == 200_000 + assert resp.model_info.input_cost_per_mtok == 3.0 + assert resp.model_info.output_cost_per_mtok == 15.0 + + async def test_unknown_model_returns_generic_info(self): + client = _make_client() + provider = PerplexityProvider(client=client) + msgs = [PromptMessage(role="user", content="test")] + resp = await provider.send(msgs, "sonar-future-99") + assert resp.model_info.model_id == "sonar-future-99" + assert resp.model_info.display_name == "Perplexity (sonar-future-99)" + assert resp.model_info.input_cost_per_mtok == 0.0 + assert resp.model_info.output_cost_per_mtok == 0.0 diff --git a/tests/unit/test_rate_limiting.py b/tests/unit/test_rate_limiting.py new file mode 100644 index 0000000..a6eeb70 --- /dev/null +++ b/tests/unit/test_rate_limiting.py @@ -0,0 +1,415 @@ +"""Tests for per-user + per-provider rate limiting (T6). 
+ +Verifies: +- Rate limiting by user_id (JWT auth) +- Rate limiting by api_key_id +- Rate limiting by IP fallback +- ProviderConfig accepts rate_limit field +- ProviderManager respects provider-level rate limits +- Response includes rate limit headers with identity info +""" + +from __future__ import annotations + +import hashlib +import time +from types import SimpleNamespace + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient +from sqlalchemy import event +from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine + +from duh.api.middleware import APIKeyMiddleware, RateLimitMiddleware +from duh.config.schema import ProviderConfig +from duh.memory.models import Base +from duh.memory.repository import MemoryRepository + + +def _hash(key: str) -> str: + return hashlib.sha256(key.encode()).hexdigest() + + +# ── Test App Helpers ──────────────────────────────────────── + + +async def _make_app( + *, + rate_limit: int = 100, + window: int = 60, + jwt_secret: str = "test-secret", +) -> FastAPI: + """Create a minimal FastAPI app with middleware and in-memory DB.""" + engine = create_async_engine("sqlite+aiosqlite://") + + @event.listens_for(engine.sync_engine, "connect") + def _enable_fks(dbapi_conn, connection_record): # type: ignore[no-untyped-def] + cursor = dbapi_conn.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + + factory = async_sessionmaker(engine, expire_on_commit=False) + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + app = FastAPI(title="test") + app.state.config = SimpleNamespace( + api=SimpleNamespace( + cors_origins=["http://localhost:3000"], + rate_limit=rate_limit, + rate_limit_window=window, + ), + auth=SimpleNamespace(jwt_secret=jwt_secret), + ) + app.state.db_factory = factory + app.state.engine = engine + + app.add_middleware(RateLimitMiddleware, rate_limit=rate_limit, window=window) + app.add_middleware(APIKeyMiddleware) + + 
@app.get("/api/health") + async def health() -> dict[str, str]: + return {"status": "ok"} + + @app.get("/api/test") + async def test_endpoint() -> dict[str, str]: + return {"msg": "ok"} + + return app + + +async def _seed_key(app: FastAPI, name: str, raw_key: str) -> str: + """Insert an API key into the test DB and return its ID.""" + async with app.state.db_factory() as session: + repo = MemoryRepository(session) + api_key = await repo.create_api_key(name, _hash(raw_key)) + await session.commit() + return api_key.id + + +# ── Rate Limit by User ID ─────────────────────────────────── + + +class TestRateLimitByUserId: + @pytest.mark.asyncio + async def test_rate_limit_by_user_id(self) -> None: + """Requests with JWT token should be rate-limited by user_id.""" + from duh.api.auth import create_token + + app = await _make_app(rate_limit=3, window=60, jwt_secret="test-secret") + client = TestClient(app, raise_server_exceptions=False) + + token = create_token("user-123", "test-secret") + headers = {"Authorization": f"Bearer {token}"} + + # First 3 requests should succeed + for _ in range(3): + resp = client.get("/api/test", headers=headers) + assert resp.status_code == 200 + + # 4th request should be rate limited + resp = client.get("/api/test", headers=headers) + assert resp.status_code == 429 + + @pytest.mark.asyncio + async def test_different_users_have_separate_limits(self) -> None: + """Different users should have independent rate limits.""" + from duh.api.auth import create_token + + app = await _make_app(rate_limit=2, window=60, jwt_secret="test-secret") + client = TestClient(app, raise_server_exceptions=False) + + token_a = create_token("user-A", "test-secret") + token_b = create_token("user-B", "test-secret") + + # User A uses 2 requests + headers_a = {"Authorization": f"Bearer {token_a}"} + for _ in range(2): + resp = client.get("/api/test", headers=headers_a) + assert resp.status_code == 200 + + # User A is now limited + resp = client.get("/api/test", 
headers=headers_a) + assert resp.status_code == 429 + + # User B should still be fine + headers_b = {"Authorization": f"Bearer {token_b}"} + resp = client.get("/api/test", headers=headers_b) + assert resp.status_code == 200 + + +# ── Rate Limit by API Key ─────────────────────────────────── + + +class TestRateLimitByApiKey: + @pytest.mark.asyncio + async def test_rate_limit_by_api_key(self) -> None: + """Requests with API key should be rate-limited by api_key_id.""" + app = await _make_app(rate_limit=3, window=60) + await _seed_key(app, "test-key", "my-api-key") + client = TestClient(app, raise_server_exceptions=False) + + headers = {"X-API-Key": "my-api-key"} + + # First 3 requests should succeed + for _ in range(3): + resp = client.get("/api/test", headers=headers) + assert resp.status_code == 200 + + # 4th request should be rate limited + resp = client.get("/api/test", headers=headers) + assert resp.status_code == 429 + + @pytest.mark.asyncio + async def test_different_api_keys_have_separate_limits(self) -> None: + """Different API keys should have independent rate limits.""" + app = await _make_app(rate_limit=2, window=60) + await _seed_key(app, "key-1", "api-key-1") + await _seed_key(app, "key-2", "api-key-2") + client = TestClient(app, raise_server_exceptions=False) + + # Key 1 uses 2 requests + for _ in range(2): + resp = client.get("/api/test", headers={"X-API-Key": "api-key-1"}) + assert resp.status_code == 200 + + # Key 1 is now limited + resp = client.get("/api/test", headers={"X-API-Key": "api-key-1"}) + assert resp.status_code == 429 + + # Key 2 should still be fine + resp = client.get("/api/test", headers={"X-API-Key": "api-key-2"}) + assert resp.status_code == 200 + + +# ── Rate Limit by IP ──────────────────────────────────────── + + +class TestRateLimitByIp: + @pytest.mark.asyncio + async def test_rate_limit_by_ip(self) -> None: + """Unauthenticated requests should fall back to IP rate limiting.""" + app = await _make_app(rate_limit=3, window=60) + 
client = TestClient(app, raise_server_exceptions=False) + + # No keys in DB, so unauthenticated access allowed + for _ in range(3): + resp = client.get("/api/test") + assert resp.status_code == 200 + + # 4th request should be rate limited + resp = client.get("/api/test") + assert resp.status_code == 429 + + +# ── Provider Rate Limit Config ─────────────────────────────── + + +class TestProviderRateLimitConfig: + def test_provider_config_accepts_rate_limit(self) -> None: + """ProviderConfig should accept a rate_limit field.""" + config = ProviderConfig(rate_limit=100) + assert config.rate_limit == 100 + + def test_provider_config_default_rate_limit_zero(self) -> None: + """Default rate_limit should be 0 (unlimited).""" + config = ProviderConfig() + assert config.rate_limit == 0 + + def test_provider_config_rate_limit_in_dict(self) -> None: + """rate_limit should appear in model dump.""" + config = ProviderConfig(rate_limit=50) + data = config.model_dump() + assert data["rate_limit"] == 50 + + +# ── Provider Rate Limit Enforcement ────────────────────────── + + +class TestProviderRateLimitEnforcement: + def test_provider_manager_respects_rate_limits(self) -> None: + """ProviderManager should enforce per-provider rate limits.""" + from duh.providers.manager import ProviderManager, ProviderQuotaExceededError + + pm = ProviderManager() + pm.set_provider_rate_limit("openai", 3) + + # First 3 checks should pass + for _ in range(3): + pm.check_provider_rate_limit("openai") + + # 4th check should raise + with pytest.raises(ProviderQuotaExceededError) as exc_info: + pm.check_provider_rate_limit("openai") + assert exc_info.value.rate_limit == 3 + assert exc_info.value.provider_id == "openai" + + def test_unlimited_provider_never_limited(self) -> None: + """Provider with rate_limit=0 should never be limited.""" + from duh.providers.manager import ProviderManager + + pm = ProviderManager() + pm.set_provider_rate_limit("anthropic", 0) + + # Should never raise + for _ in 
range(1000): + pm.check_provider_rate_limit("anthropic") + + def test_unconfigured_provider_never_limited(self) -> None: + """Provider without a configured rate limit should never be limited.""" + from duh.providers.manager import ProviderManager + + pm = ProviderManager() + + # Should never raise + for _ in range(100): + pm.check_provider_rate_limit("any-provider") + + def test_different_providers_have_separate_limits(self) -> None: + """Different providers should have independent rate limits.""" + from duh.providers.manager import ProviderManager, ProviderQuotaExceededError + + pm = ProviderManager() + pm.set_provider_rate_limit("openai", 2) + pm.set_provider_rate_limit("anthropic", 2) + + # Exhaust openai + pm.check_provider_rate_limit("openai") + pm.check_provider_rate_limit("openai") + with pytest.raises(ProviderQuotaExceededError): + pm.check_provider_rate_limit("openai") + + # anthropic should still be fine + pm.check_provider_rate_limit("anthropic") + pm.check_provider_rate_limit("anthropic") + + def test_rate_limit_resets_after_window(self) -> None: + """Provider rate limit should reset after the 60-second window.""" + from duh.providers.manager import ProviderManager + + pm = ProviderManager() + pm.set_provider_rate_limit("openai", 2) + + # Exhaust the limit + pm.check_provider_rate_limit("openai") + pm.check_provider_rate_limit("openai") + + # Simulate time passing by manipulating the timestamps + pm._provider_requests["openai"] = [ + time.monotonic() - 61.0, + time.monotonic() - 61.0, + ] + + # Should work again + pm.check_provider_rate_limit("openai") + + def test_get_provider_rate_limit_remaining(self) -> None: + """get_provider_rate_limit_remaining returns correct count.""" + from duh.providers.manager import ProviderManager + + pm = ProviderManager() + pm.set_provider_rate_limit("openai", 5) + + assert pm.get_provider_rate_limit_remaining("openai") == 5 + + pm.check_provider_rate_limit("openai") + assert 
pm.get_provider_rate_limit_remaining("openai") == 4 + + pm.check_provider_rate_limit("openai") + assert pm.get_provider_rate_limit_remaining("openai") == 3 + + def test_get_provider_rate_limit_remaining_no_limit(self) -> None: + """get_provider_rate_limit_remaining returns None when no limit set.""" + from duh.providers.manager import ProviderManager + + pm = ProviderManager() + assert pm.get_provider_rate_limit_remaining("openai") is None + + @pytest.mark.asyncio + async def test_get_provider_checks_rate_limit(self) -> None: + """get_provider should check rate limit before returning provider.""" + from duh.providers.manager import ProviderManager, ProviderQuotaExceededError + + pm = ProviderManager() + + # Register a mock provider + from tests.fixtures.providers import MockProvider + from tests.fixtures.responses import MINIMAL + + provider = MockProvider(provider_id="mock-minimal", responses=MINIMAL) + await pm.register(provider) # type: ignore[arg-type] + + # Get models so we can use a valid model_ref + models = pm.list_all_models() + assert len(models) > 0 + model_ref = models[0].model_ref + + # Set rate limit + pm.set_provider_rate_limit("mock-minimal", 2) + + # First 2 calls should work + pm.get_provider(model_ref) + pm.get_provider(model_ref) + + # 3rd should raise + with pytest.raises(ProviderQuotaExceededError): + pm.get_provider(model_ref) + + +# ── Rate Limit Headers ─────────────────────────────────────── + + +class TestRateLimitHeaders: + @pytest.mark.asyncio + async def test_rate_limit_headers_with_user_id(self) -> None: + """Response should include X-RateLimit-Key with user info when JWT auth used.""" + from duh.api.auth import create_token + + app = await _make_app(rate_limit=10, window=60, jwt_secret="test-secret") + client = TestClient(app, raise_server_exceptions=False) + + token = create_token("user-42", "test-secret") + resp = client.get("/api/test", headers={"Authorization": f"Bearer {token}"}) + assert resp.status_code == 200 + assert 
resp.headers["X-RateLimit-Limit"] == "10" + assert resp.headers["X-RateLimit-Remaining"] == "9" + assert resp.headers["X-RateLimit-Key"] == "user:user-42" + + @pytest.mark.asyncio + async def test_rate_limit_headers_with_api_key(self) -> None: + """Response should include X-RateLimit-Key with api_key info.""" + app = await _make_app(rate_limit=10, window=60) + key_id = await _seed_key(app, "test-key", "my-api-key") + client = TestClient(app, raise_server_exceptions=False) + + resp = client.get("/api/test", headers={"X-API-Key": "my-api-key"}) + assert resp.status_code == 200 + assert resp.headers["X-RateLimit-Limit"] == "10" + assert resp.headers["X-RateLimit-Remaining"] == "9" + assert resp.headers["X-RateLimit-Key"] == f"api_key:{key_id}" + + @pytest.mark.asyncio + async def test_rate_limit_headers_with_ip_fallback(self) -> None: + """Response should include X-RateLimit-Key with IP when no auth.""" + app = await _make_app(rate_limit=10, window=60) + client = TestClient(app, raise_server_exceptions=False) + + resp = client.get("/api/test") + assert resp.status_code == 200 + assert resp.headers["X-RateLimit-Limit"] == "10" + assert resp.headers["X-RateLimit-Remaining"] == "9" + # IP-based key + assert resp.headers["X-RateLimit-Key"].startswith("ip:") + + @pytest.mark.asyncio + async def test_rate_limit_headers_remaining_decrements(self) -> None: + """X-RateLimit-Remaining should decrement with each request.""" + app = await _make_app(rate_limit=5, window=60) + client = TestClient(app, raise_server_exceptions=False) + + for expected_remaining in range(4, -1, -1): + resp = client.get("/api/test") + assert resp.status_code == 200 + assert resp.headers["X-RateLimit-Remaining"] == str(expected_remaining) diff --git a/tests/unit/test_rbac.py b/tests/unit/test_rbac.py new file mode 100644 index 0000000..d645ba1 --- /dev/null +++ b/tests/unit/test_rbac.py @@ -0,0 +1,221 @@ +"""Tests for role-based access control (RBAC).""" + +from __future__ import annotations + +from 
types import SimpleNamespace + +from fastapi import Depends, FastAPI +from fastapi.testclient import TestClient + +from duh.api.rbac import ( + ROLE_HIERARCHY, + require_admin, + require_contributor, + require_role, + require_viewer, +) + +# ── ROLE_HIERARCHY ───────────────────────────────────────────── + + +class TestRoleHierarchy: + def test_admin_is_highest(self) -> None: + assert ROLE_HIERARCHY["admin"] > ROLE_HIERARCHY["contributor"] + assert ROLE_HIERARCHY["admin"] > ROLE_HIERARCHY["viewer"] + + def test_contributor_above_viewer(self) -> None: + assert ROLE_HIERARCHY["contributor"] > ROLE_HIERARCHY["viewer"] + + def test_all_roles_present(self) -> None: + assert set(ROLE_HIERARCHY.keys()) == {"admin", "contributor", "viewer"} + + def test_hierarchy_values_ascending(self) -> None: + assert ( + ROLE_HIERARCHY["viewer"] + < ROLE_HIERARCHY["contributor"] + < ROLE_HIERARCHY["admin"] + ) + + +# ── require_role ─────────────────────────────────────────────── + + +def _make_user(role: str = "contributor") -> SimpleNamespace: + """Create a fake user object with the given role.""" + return SimpleNamespace( + id="user-1", + email="test@example.com", + display_name="Test", + role=role, + is_active=True, + ) + + +def _build_app(minimum_role: str, user: SimpleNamespace | None = None) -> FastAPI: + """Build a tiny FastAPI app that uses require_role on a test endpoint. + + Overrides ``get_current_user`` so no real DB/JWT is needed. 
+ """ + from duh.api.auth import get_current_user + + app = FastAPI() + + dep = require_role(minimum_role) + + @app.get("/protected") + async def protected(u=Depends(dep)): # noqa: B008 + return {"role": u.role, "id": u.id} + + if user is not None: + app.dependency_overrides[get_current_user] = lambda: user + + return app + + +class TestRequireRole: + """Test the require_role dependency factory via real HTTP calls.""" + + # ── admin-level endpoint ────────────────────────────────── + + def test_admin_passes_admin_check(self) -> None: + app = _build_app("admin", _make_user("admin")) + client = TestClient(app) + resp = client.get("/protected") + assert resp.status_code == 200 + assert resp.json()["role"] == "admin" + + def test_contributor_fails_admin_check(self) -> None: + app = _build_app("admin", _make_user("contributor")) + client = TestClient(app) + resp = client.get("/protected") + assert resp.status_code == 403 + assert "admin" in resp.json()["detail"].lower() + + def test_viewer_fails_admin_check(self) -> None: + app = _build_app("admin", _make_user("viewer")) + client = TestClient(app) + resp = client.get("/protected") + assert resp.status_code == 403 + + # ── contributor-level endpoint ──────────────────────────── + + def test_admin_passes_contributor_check(self) -> None: + app = _build_app("contributor", _make_user("admin")) + client = TestClient(app) + resp = client.get("/protected") + assert resp.status_code == 200 + + def test_contributor_passes_contributor_check(self) -> None: + app = _build_app("contributor", _make_user("contributor")) + client = TestClient(app) + resp = client.get("/protected") + assert resp.status_code == 200 + + def test_viewer_fails_contributor_check(self) -> None: + app = _build_app("contributor", _make_user("viewer")) + client = TestClient(app) + resp = client.get("/protected") + assert resp.status_code == 403 + assert "contributor" in resp.json()["detail"].lower() + + # ── viewer-level endpoint ───────────────────────────────── 
+ + def test_admin_passes_viewer_check(self) -> None: + app = _build_app("viewer", _make_user("admin")) + client = TestClient(app) + resp = client.get("/protected") + assert resp.status_code == 200 + + def test_contributor_passes_viewer_check(self) -> None: + app = _build_app("viewer", _make_user("contributor")) + client = TestClient(app) + resp = client.get("/protected") + assert resp.status_code == 200 + + def test_viewer_passes_viewer_check(self) -> None: + app = _build_app("viewer", _make_user("viewer")) + client = TestClient(app) + resp = client.get("/protected") + assert resp.status_code == 200 + + # ── edge cases ──────────────────────────────────────────── + + def test_unknown_role_denied(self) -> None: + """A user with an unrecognised role (level 0) is denied.""" + app = _build_app("viewer", _make_user("unknown_role")) + client = TestClient(app) + resp = client.get("/protected") + assert resp.status_code == 403 + + def test_no_role_attr_denied(self) -> None: + """User object missing 'role' attribute is treated as level 0.""" + user = SimpleNamespace(id="user-1", email="a@b.com") + app = _build_app("viewer", user) + client = TestClient(app) + resp = client.get("/protected") + assert resp.status_code == 403 + + def test_unknown_minimum_role_accepts_any(self) -> None: + """If minimum_role is not in ROLE_HIERARCHY its level defaults to 0. + + Any valid user (viewer=1 > 0) should pass. 
+ """ + app = _build_app("nonexistent", _make_user("viewer")) + client = TestClient(app) + resp = client.get("/protected") + assert resp.status_code == 200 + + +# ── convenience aliases ──────────────────────────────────────── + + +class TestConvenienceAliases: + """Verify pre-built require_admin / require_contributor / require_viewer.""" + + def test_require_admin_callable(self) -> None: + assert callable(require_admin) + + def test_require_contributor_callable(self) -> None: + assert callable(require_contributor) + + def test_require_viewer_callable(self) -> None: + assert callable(require_viewer) + + def _build_alias_app( + self, dep: object, user: SimpleNamespace + ) -> FastAPI: + from duh.api.auth import get_current_user + + app = FastAPI() + + @app.get("/test") + async def endpoint(u=Depends(dep)): # noqa: B008 + return {"ok": True} + + app.dependency_overrides[get_current_user] = lambda: user + return app + + def test_require_admin_blocks_contributor(self) -> None: + app = self._build_alias_app(require_admin, _make_user("contributor")) + resp = TestClient(app).get("/test") + assert resp.status_code == 403 + + def test_require_admin_passes_admin(self) -> None: + app = self._build_alias_app(require_admin, _make_user("admin")) + resp = TestClient(app).get("/test") + assert resp.status_code == 200 + + def test_require_contributor_passes_contributor(self) -> None: + app = self._build_alias_app(require_contributor, _make_user("contributor")) + resp = TestClient(app).get("/test") + assert resp.status_code == 200 + + def test_require_viewer_passes_viewer(self) -> None: + app = self._build_alias_app(require_viewer, _make_user("viewer")) + resp = TestClient(app).get("/test") + assert resp.status_code == 200 + + def test_require_viewer_blocks_unknown(self) -> None: + app = self._build_alias_app(require_viewer, _make_user("guest")) + resp = TestClient(app).get("/test") + assert resp.status_code == 403 diff --git a/tests/unit/test_restore.py b/tests/unit/test_restore.py 
new file mode 100644 index 0000000..a91d908 --- /dev/null +++ b/tests/unit/test_restore.py @@ -0,0 +1,455 @@ +"""Tests for database restore utilities and CLI command.""" + +from __future__ import annotations + +import asyncio +import json +import sqlite3 +from typing import TYPE_CHECKING, Any +from unittest.mock import AsyncMock, patch + +import pytest +from click.testing import CliRunner + +from duh.cli.app import cli +from duh.memory.backup import detect_backup_format, restore_json, restore_sqlite + +if TYPE_CHECKING: + from pathlib import Path + + +# ── helpers ──────────────────────────────────────────────────── + + +def _make_async_session() -> tuple[Any, Any]: + """Create an in-memory SQLite async session with tables created.""" + from sqlalchemy import event + from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine + from sqlalchemy.pool import StaticPool + + from duh.memory.models import Base + + engine = create_async_engine( + "sqlite+aiosqlite://", + poolclass=StaticPool, + connect_args={"check_same_thread": False}, + ) + + @event.listens_for(engine.sync_engine, "connect") + def _enable_fks(dbapi_conn, connection_record): # type: ignore[no-untyped-def] + cursor = dbapi_conn.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + + async def _init() -> None: + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + asyncio.run(_init()) + factory = async_sessionmaker(engine, expire_on_commit=False) + return factory, engine + + +def _make_json_backup( + tmp_path: Path, + tables: dict[str, list[dict[str, Any]]] | None = None, + *, + version: str = "0.5.0", + filename: str = "backup.json", +) -> Path: + """Create a JSON backup file for testing.""" + data = { + "version": version, + "exported_at": "2026-01-01T00:00:00+00:00", + "tables": tables or {}, + } + dest = tmp_path / filename + dest.write_text(json.dumps(data, indent=2), encoding="utf-8") + return dest + + +# ── detect_backup_format 
─────────────────────────────────────── + + +class TestDetectBackupFormat: + def test_json_file(self, tmp_path: Path) -> None: + f = tmp_path / "backup.json" + f.write_text('{"version": "0.5.0", "tables": {}}') + assert detect_backup_format(f) == "json" + + def test_json_array(self, tmp_path: Path) -> None: + f = tmp_path / "backup.json" + f.write_text("[1, 2, 3]") + assert detect_backup_format(f) == "json" + + def test_sqlite_file(self, tmp_path: Path) -> None: + f = tmp_path / "backup.db" + conn = sqlite3.connect(str(f)) + conn.execute("CREATE TABLE t (id INTEGER)") + conn.commit() + conn.close() + assert detect_backup_format(f) == "sqlite" + + def test_invalid_file(self, tmp_path: Path) -> None: + f = tmp_path / "backup.bin" + f.write_bytes(b"\x00\x01\x02\x03random binary") + with pytest.raises(ValueError, match="Cannot detect"): + detect_backup_format(f) + + def test_empty_file(self, tmp_path: Path) -> None: + f = tmp_path / "empty.dat" + f.write_bytes(b"") + with pytest.raises(ValueError, match="empty"): + detect_backup_format(f) + + +# ── restore_json ─────────────────────────────────────────────── + + +class TestRestoreJson: + def test_restore_empty(self, tmp_path: Path) -> None: + """Restore from a backup of an empty DB works.""" + factory, engine = _make_async_session() + backup_file = _make_json_backup(tmp_path, tables={ + "threads": [], + "turns": [], + "contributions": [], + "decisions": [], + }) + + async def _run() -> dict[str, int]: + async with factory() as session: + return await restore_json(session, backup_file) + + counts = asyncio.run(_run()) + assert counts["threads"] == 0 + assert counts["decisions"] == 0 + asyncio.run(engine.dispose()) + + def test_restore_with_data(self, tmp_path: Path) -> None: + """Restore from backup with threads/decisions, verify data present.""" + import uuid + + factory, engine = _make_async_session() + + thread_id = str(uuid.uuid4()) + turn_id = str(uuid.uuid4()) + decision_id = str(uuid.uuid4()) + + backup_file = 
_make_json_backup(tmp_path, tables={ + "users": [], + "threads": [{ + "id": thread_id, + "question": "Test question?", + "status": "complete", + "created_at": "2026-01-01T00:00:00+00:00", + "updated_at": "2026-01-01T00:00:00+00:00", + }], + "turns": [{ + "id": turn_id, + "thread_id": thread_id, + "round_number": 1, + "state": "COMMIT", + "created_at": "2026-01-01T00:00:00+00:00", + }], + "contributions": [], + "turn_summaries": [], + "thread_summaries": [], + "decisions": [{ + "id": decision_id, + "turn_id": turn_id, + "thread_id": thread_id, + "content": "The answer is 42", + "confidence": 0.95, + "created_at": "2026-01-01T00:00:00+00:00", + }], + "outcomes": [], + "subtasks": [], + "votes": [], + "api_keys": [], + }) + + async def _run() -> dict[str, int]: + async with factory() as session: + return await restore_json(session, backup_file) + + counts = asyncio.run(_run()) + assert counts["threads"] == 1 + assert counts["decisions"] == 1 + + # Verify data is actually in the DB + async def _verify() -> None: + from sqlalchemy import select + + from duh.memory.models import Decision, Thread + + async with factory() as session: + result = await session.execute(select(Thread)) + threads = result.scalars().all() + assert len(threads) == 1 + assert threads[0].question == "Test question?" 
+ + result = await session.execute(select(Decision)) + decisions = result.scalars().all() + assert len(decisions) == 1 + assert decisions[0].content == "The answer is 42" + + asyncio.run(_verify()) + asyncio.run(engine.dispose()) + + def test_restore_clears_existing(self, tmp_path: Path) -> None: + """Non-merge mode clears existing data first.""" + import uuid + + factory, engine = _make_async_session() + + # Seed existing data + async def _seed() -> None: + from duh.memory.repository import MemoryRepository + + async with factory() as session: + repo = MemoryRepository(session) + await repo.create_thread("Old question") + await session.commit() + + asyncio.run(_seed()) + + # Restore with new data (non-merge) + thread_id = str(uuid.uuid4()) + backup_file = _make_json_backup(tmp_path, tables={ + "users": [], + "threads": [{ + "id": thread_id, + "question": "New question", + "status": "active", + "created_at": "2026-01-01T00:00:00+00:00", + "updated_at": "2026-01-01T00:00:00+00:00", + }], + "turns": [], + "contributions": [], + "turn_summaries": [], + "thread_summaries": [], + "decisions": [], + "outcomes": [], + "subtasks": [], + "votes": [], + "api_keys": [], + }) + + async def _restore() -> dict[str, int]: + async with factory() as session: + return await restore_json(session, backup_file) + + counts = asyncio.run(_restore()) + assert counts["threads"] == 1 + + # Verify only the new data exists + async def _verify() -> None: + from sqlalchemy import select + + from duh.memory.models import Thread + + async with factory() as session: + result = await session.execute(select(Thread)) + threads = result.scalars().all() + assert len(threads) == 1 + assert threads[0].question == "New question" + + asyncio.run(_verify()) + asyncio.run(engine.dispose()) + + def test_restore_merge_mode(self, tmp_path: Path) -> None: + """Merge mode keeps existing data and adds new.""" + import uuid + + factory, engine = _make_async_session() + + # Seed existing data + async def _seed() -> 
None: + from duh.memory.repository import MemoryRepository + + async with factory() as session: + repo = MemoryRepository(session) + await repo.create_thread("Existing question") + await session.commit() + + asyncio.run(_seed()) + + # Restore with additional data (merge mode) + new_thread_id = str(uuid.uuid4()) + backup_file = _make_json_backup(tmp_path, tables={ + "users": [], + "threads": [{ + "id": new_thread_id, + "question": "New merged question", + "status": "active", + "created_at": "2026-01-01T00:00:00+00:00", + "updated_at": "2026-01-01T00:00:00+00:00", + }], + "turns": [], + "contributions": [], + "turn_summaries": [], + "thread_summaries": [], + "decisions": [], + "outcomes": [], + "subtasks": [], + "votes": [], + "api_keys": [], + }) + + async def _restore() -> dict[str, int]: + async with factory() as session: + return await restore_json(session, backup_file, merge=True) + + counts = asyncio.run(_restore()) + assert counts["threads"] == 1 # 1 new record processed + + # Verify both old and new data exist + async def _verify() -> None: + from sqlalchemy import select + + from duh.memory.models import Thread + + async with factory() as session: + result = await session.execute(select(Thread)) + threads = result.scalars().all() + assert len(threads) == 2 + questions = {t.question for t in threads} + assert "Existing question" in questions + assert "New merged question" in questions + + asyncio.run(_verify()) + asyncio.run(engine.dispose()) + + def test_restore_validates_structure(self, tmp_path: Path) -> None: + """Missing 'tables' key raises ValueError.""" + bad_backup = tmp_path / "bad.json" + bad_backup.write_text(json.dumps({"version": "0.5.0"}), encoding="utf-8") + + factory, engine = _make_async_session() + + async def _run() -> dict[str, int]: + async with factory() as session: + return await restore_json(session, bad_backup) + + with pytest.raises(ValueError, match="missing 'tables'"): + asyncio.run(_run()) + + asyncio.run(engine.dispose()) + + +# 
── restore_sqlite ───────────────────────────────────────────── + + +class TestRestoreSqlite: + def test_copies_file(self, tmp_path: Path) -> None: + """SQLite restore replaces the DB file.""" + # Create a backup SQLite file + backup_db = tmp_path / "backup.db" + conn = sqlite3.connect(str(backup_db)) + conn.execute("CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)") + conn.execute("INSERT INTO test VALUES (1, 'restored')") + conn.commit() + conn.close() + + # Create target DB path + target_db = tmp_path / "target" / "duh.db" + target_db.parent.mkdir(parents=True, exist_ok=True) + # Create an empty target + conn2 = sqlite3.connect(str(target_db)) + conn2.execute("CREATE TABLE empty (id INTEGER)") + conn2.commit() + conn2.close() + + db_url = f"sqlite:///{target_db}" + asyncio.run(restore_sqlite(backup_db, db_url)) + + # Verify the restored data + conn3 = sqlite3.connect(str(target_db)) + rows = conn3.execute("SELECT * FROM test").fetchall() + conn3.close() + assert rows == [(1, "restored")] + + def test_memory_db_raises(self, tmp_path: Path) -> None: + backup_db = tmp_path / "backup.db" + backup_db.write_bytes(b"") + with pytest.raises(ValueError, match="in-memory"): + asyncio.run(restore_sqlite(backup_db, "sqlite+aiosqlite:///:memory:")) + + def test_no_triple_slash_raises(self, tmp_path: Path) -> None: + backup_db = tmp_path / "backup.db" + backup_db.write_bytes(b"") + with pytest.raises(ValueError, match="Cannot extract"): + asyncio.run(restore_sqlite(backup_db, "sqlite://badurl")) + + +# ── CLI command ──────────────────────────────────────────────── + + +class TestRestoreCli: + @pytest.fixture() + def runner(self) -> CliRunner: + return CliRunner() + + def test_help(self, runner: CliRunner) -> None: + result = runner.invoke(cli, ["restore", "--help"]) + assert result.exit_code == 0 + assert "PATH" in result.output + assert "--merge" in result.output + + def test_restore_json_via_cli(self, runner: CliRunner, tmp_path: Path) -> None: + """Use CliRunner to 
test the restore command with JSON backup.""" + factory, engine = _make_async_session() + from duh.config.schema import DatabaseConfig, DuhConfig + + config = DuhConfig( + database=DatabaseConfig(url="sqlite+aiosqlite://"), + ) + + backup_file = _make_json_backup(tmp_path, tables={ + "threads": [], + "turns": [], + "contributions": [], + "turn_summaries": [], + "thread_summaries": [], + "decisions": [], + "outcomes": [], + "subtasks": [], + "votes": [], + "api_keys": [], + }) + + with ( + patch("duh.cli.app.load_config", return_value=config), + patch( + "duh.cli.app._create_db", + new_callable=AsyncMock, + return_value=(factory, engine), + ), + ): + result = runner.invoke(cli, ["restore", str(backup_file)]) + + assert result.exit_code == 0, result.output + assert "Restored" in result.output + asyncio.run(engine.dispose()) + + def test_restore_sqlite_pg_errors( + self, runner: CliRunner, tmp_path: Path + ) -> None: + """Cannot restore a SQLite backup into a PostgreSQL database.""" + from duh.config.schema import DatabaseConfig, DuhConfig + + config = DuhConfig( + database=DatabaseConfig(url="postgresql+asyncpg://user:pass@host/db"), + ) + + backup_db = tmp_path / "backup.db" + conn = sqlite3.connect(str(backup_db)) + conn.execute("CREATE TABLE t (id INTEGER)") + conn.commit() + conn.close() + + with patch("duh.cli.app.load_config", return_value=config): + result = runner.invoke(cli, ["restore", str(backup_db)]) + + assert result.exit_code != 0 diff --git a/tests/unit/test_smoke.py b/tests/unit/test_smoke.py index 32d0573..f424738 100644 --- a/tests/unit/test_smoke.py +++ b/tests/unit/test_smoke.py @@ -7,14 +7,14 @@ def test_version_string(): - assert __version__ == "0.4.0" + assert __version__ == "0.5.0" def test_cli_version(): runner = CliRunner() result = runner.invoke(cli, ["--version"]) assert result.exit_code == 0 - assert "0.4.0" in result.output + assert "0.5.0" in result.output def test_cli_help(): diff --git a/tests/unit/test_user_model.py 
b/tests/unit/test_user_model.py new file mode 100644 index 0000000..68dd646 --- /dev/null +++ b/tests/unit/test_user_model.py @@ -0,0 +1,157 @@ +"""Tests for User model and user_id foreign keys.""" + +from __future__ import annotations + +import hashlib +from typing import TYPE_CHECKING + +import pytest +from sqlalchemy.exc import IntegrityError + +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncSession + +from duh.memory.models import APIKey, Decision, Thread, Turn, User + + +def _make_user( + email: str = "alice@example.com", + password_hash: str = "hashed_pw_placeholder", + display_name: str = "Alice", + **kwargs: object, +) -> User: + return User( + email=email, + password_hash=password_hash, + display_name=display_name, + **kwargs, + ) + + +def _make_thread(question: str = "What is AI?", **kwargs: object) -> Thread: + return Thread(question=question, **kwargs) + + +def _make_turn( + thread: Thread, round_number: int = 1, state: str = "propose" +) -> Turn: + return Turn(thread=thread, round_number=round_number, state=state) + + +# ── User Creation ──────────────────────────────────────────────── + + +class TestUserCreation: + async def test_user_creation(self, db_session: AsyncSession) -> None: + user = _make_user() + db_session.add(user) + await db_session.commit() + + assert user.id is not None + assert len(user.id) == 36 + assert user.email == "alice@example.com" + assert user.password_hash == "hashed_pw_placeholder" + assert user.display_name == "Alice" + assert user.created_at is not None + assert user.updated_at is not None + + async def test_user_defaults(self, db_session: AsyncSession) -> None: + user = _make_user() + db_session.add(user) + await db_session.commit() + + assert user.role == "contributor" + assert user.is_active is True + + async def test_user_email_unique(self, db_session: AsyncSession) -> None: + user1 = _make_user(email="dup@example.com") + db_session.add(user1) + await db_session.commit() + + user2 = 
_make_user(email="dup@example.com", display_name="Bob") + db_session.add(user2) + with pytest.raises(IntegrityError): + await db_session.flush() + await db_session.rollback() + + +# ── Relationships ──────────────────────────────────────────────── + + +class TestUserRelationships: + async def test_thread_user_relationship(self, db_session: AsyncSession) -> None: + user = _make_user() + thread = _make_thread(user=user) + db_session.add(thread) + await db_session.commit() + + assert thread.user_id == user.id + assert thread.user is user + assert thread in user.threads + + async def test_decision_user_id(self, db_session: AsyncSession) -> None: + user = _make_user() + db_session.add(user) + await db_session.flush() + + thread = _make_thread() + turn = _make_turn(thread) + decision = Decision( + turn=turn, + thread=thread, + content="Answer", + confidence=0.8, + user_id=user.id, + ) + db_session.add(decision) + await db_session.commit() + + assert decision.user_id == user.id + + async def test_api_key_user_id(self, db_session: AsyncSession) -> None: + user = _make_user() + db_session.add(user) + await db_session.flush() + + key_hash = hashlib.sha256(b"test-key").hexdigest() + api_key = APIKey( + key_hash=key_hash, + name="test-key", + user_id=user.id, + ) + db_session.add(api_key) + await db_session.commit() + + assert api_key.user_id == user.id + + +# ── Backward Compatibility ─────────────────────────────────────── + + +class TestNullableUserId: + async def test_thread_without_user(self, db_session: AsyncSession) -> None: + thread = _make_thread() + db_session.add(thread) + await db_session.commit() + + assert thread.user_id is None + assert thread.user is None + + async def test_decision_without_user(self, db_session: AsyncSession) -> None: + thread = _make_thread() + turn = _make_turn(thread) + decision = Decision( + turn=turn, thread=thread, content="Answer", confidence=0.8 + ) + db_session.add(decision) + await db_session.commit() + + assert decision.user_id is 
None + + async def test_api_key_without_user(self, db_session: AsyncSession) -> None: + key_hash = hashlib.sha256(b"orphan-key").hexdigest() + api_key = APIKey(key_hash=key_hash, name="orphan") + db_session.add(api_key) + await db_session.commit() + + assert api_key.user_id is None diff --git a/uv.lock b/uv.lock index 47290b0..d16b8a4 100644 --- a/uv.lock +++ b/uv.lock @@ -75,6 +75,54 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/38/0e/27be9fdef66e72d64c0cdc3cc2823101b80585f8119b5c112c2e8f5f7dab/anyio-4.12.1-py3-none-any.whl", hash = "sha256:d405828884fc140aa80a3c667b8beed277f1dfedec42ba031bd6ac3db606ab6c", size = 113592 }, ] +[[package]] +name = "asyncpg" +version = "0.31.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/fe/cc/d18065ce2380d80b1bcce927c24a2642efd38918e33fd724bc4bca904877/asyncpg-0.31.0.tar.gz", hash = "sha256:c989386c83940bfbd787180f2b1519415e2d3d6277a70d9d0f0145ac73500735", size = 993667 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/08/17/cc02bc49bc350623d050fa139e34ea512cd6e020562f2a7312a7bcae4bc9/asyncpg-0.31.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:eee690960e8ab85063ba93af2ce128c0f52fd655fdff9fdb1a28df01329f031d", size = 643159 }, + { url = "https://files.pythonhosted.org/packages/a4/62/4ded7d400a7b651adf06f49ea8f73100cca07c6df012119594d1e3447aa6/asyncpg-0.31.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2657204552b75f8288de08ca60faf4a99a65deef3a71d1467454123205a88fab", size = 638157 }, + { url = "https://files.pythonhosted.org/packages/d6/5b/4179538a9a72166a0bf60ad783b1ef16efb7960e4d7b9afe9f77a5551680/asyncpg-0.31.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a429e842a3a4b4ea240ea52d7fe3f82d5149853249306f7ff166cb9948faa46c", size = 2918051 }, + { url = 
"https://files.pythonhosted.org/packages/e6/35/c27719ae0536c5b6e61e4701391ffe435ef59539e9360959240d6e47c8c8/asyncpg-0.31.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c0807be46c32c963ae40d329b3a686356e417f674c976c07fa49f1b30303f109", size = 2972640 }, + { url = "https://files.pythonhosted.org/packages/43/f4/01ebb9207f29e645a64699b9ce0eefeff8e7a33494e1d29bb53736f7766b/asyncpg-0.31.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:e5d5098f63beeae93512ee513d4c0c53dc12e9aa2b7a1af5a81cddf93fe4e4da", size = 2851050 }, + { url = "https://files.pythonhosted.org/packages/3e/f4/03ff1426acc87be0f4e8d40fa2bff5c3952bef0080062af9efc2212e3be8/asyncpg-0.31.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:37fc6c00a814e18eef51833545d1891cac9aa69140598bb076b4cd29b3e010b9", size = 2962574 }, + { url = "https://files.pythonhosted.org/packages/c7/39/cc788dfca3d4060f9d93e67be396ceec458dfc429e26139059e58c2c244d/asyncpg-0.31.0-cp311-cp311-win32.whl", hash = "sha256:5a4af56edf82a701aece93190cc4e094d2df7d33f6e915c222fb09efbb5afc24", size = 521076 }, + { url = "https://files.pythonhosted.org/packages/28/fc/735af5384c029eb7f1ca60ccb8fa95521dbdaeef788edf4cecfc604c3cab/asyncpg-0.31.0-cp311-cp311-win_amd64.whl", hash = "sha256:480c4befbdf079c14c9ca43c8c5e1fe8b6296c96f1f927158d4f1e750aacc047", size = 584980 }, + { url = "https://files.pythonhosted.org/packages/2a/a6/59d0a146e61d20e18db7396583242e32e0f120693b67a8de43f1557033e2/asyncpg-0.31.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b44c31e1efc1c15188ef183f287c728e2046abb1d26af4d20858215d50d91fad", size = 662042 }, + { url = "https://files.pythonhosted.org/packages/36/01/ffaa189dcb63a2471720615e60185c3f6327716fdc0fc04334436fbb7c65/asyncpg-0.31.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0c89ccf741c067614c9b5fc7f1fc6f3b61ab05ae4aaa966e6fd6b93097c7d20d", size = 638504 }, + { url = 
"https://files.pythonhosted.org/packages/9f/62/3f699ba45d8bd24c5d65392190d19656d74ff0185f42e19d0bbd973bb371/asyncpg-0.31.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:12b3b2e39dc5470abd5e98c8d3373e4b1d1234d9fbdedf538798b2c13c64460a", size = 3426241 }, + { url = "https://files.pythonhosted.org/packages/8c/d1/a867c2150f9c6e7af6462637f613ba67f78a314b00db220cd26ff559d532/asyncpg-0.31.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:aad7a33913fb8bcb5454313377cc330fbb19a0cd5faa7272407d8a0c4257b671", size = 3520321 }, + { url = "https://files.pythonhosted.org/packages/7a/1a/cce4c3f246805ecd285a3591222a2611141f1669d002163abef999b60f98/asyncpg-0.31.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3df118d94f46d85b2e434fd62c84cb66d5834d5a890725fe625f498e72e4d5ec", size = 3316685 }, + { url = "https://files.pythonhosted.org/packages/40/ae/0fc961179e78cc579e138fad6eb580448ecae64908f95b8cb8ee2f241f67/asyncpg-0.31.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:bd5b6efff3c17c3202d4b37189969acf8927438a238c6257f66be3c426beba20", size = 3471858 }, + { url = "https://files.pythonhosted.org/packages/52/b2/b20e09670be031afa4cbfabd645caece7f85ec62d69c312239de568e058e/asyncpg-0.31.0-cp312-cp312-win32.whl", hash = "sha256:027eaa61361ec735926566f995d959ade4796f6a49d3bde17e5134b9964f9ba8", size = 527852 }, + { url = "https://files.pythonhosted.org/packages/b5/f0/f2ed1de154e15b107dc692262395b3c17fc34eafe2a78fc2115931561730/asyncpg-0.31.0-cp312-cp312-win_amd64.whl", hash = "sha256:72d6bdcbc93d608a1158f17932de2321f68b1a967a13e014998db87a72ed3186", size = 597175 }, + { url = "https://files.pythonhosted.org/packages/95/11/97b5c2af72a5d0b9bc3fa30cd4b9ce22284a9a943a150fdc768763caf035/asyncpg-0.31.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:c204fab1b91e08b0f47e90a75d1b3c62174dab21f670ad6c5d0f243a228f015b", size = 661111 }, + { url = 
"https://files.pythonhosted.org/packages/1b/71/157d611c791a5e2d0423f09f027bd499935f0906e0c2a416ce712ba51ef3/asyncpg-0.31.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:54a64f91839ba59008eccf7aad2e93d6e3de688d796f35803235ea1c4898ae1e", size = 636928 }, + { url = "https://files.pythonhosted.org/packages/2e/fc/9e3486fb2bbe69d4a867c0b76d68542650a7ff1574ca40e84c3111bb0c6e/asyncpg-0.31.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c0e0822b1038dc7253b337b0f3f676cadc4ac31b126c5d42691c39691962e403", size = 3424067 }, + { url = "https://files.pythonhosted.org/packages/12/c6/8c9d076f73f07f995013c791e018a1cd5f31823c2a3187fc8581706aa00f/asyncpg-0.31.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bef056aa502ee34204c161c72ca1f3c274917596877f825968368b2c33f585f4", size = 3518156 }, + { url = "https://files.pythonhosted.org/packages/ae/3b/60683a0baf50fbc546499cfb53132cb6835b92b529a05f6a81471ab60d0c/asyncpg-0.31.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:0bfbcc5b7ffcd9b75ab1558f00db2ae07db9c80637ad1b2469c43df79d7a5ae2", size = 3319636 }, + { url = "https://files.pythonhosted.org/packages/50/dc/8487df0f69bd398a61e1792b3cba0e47477f214eff085ba0efa7eac9ce87/asyncpg-0.31.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:22bc525ebbdc24d1261ecbf6f504998244d4e3be1721784b5f64664d61fbe602", size = 3472079 }, + { url = "https://files.pythonhosted.org/packages/13/a1/c5bbeeb8531c05c89135cb8b28575ac2fac618bcb60119ee9696c3faf71c/asyncpg-0.31.0-cp313-cp313-win32.whl", hash = "sha256:f890de5e1e4f7e14023619399a471ce4b71f5418cd67a51853b9910fdfa73696", size = 527606 }, + { url = "https://files.pythonhosted.org/packages/91/66/b25ccb84a246b470eb943b0107c07edcae51804912b824054b3413995a10/asyncpg-0.31.0-cp313-cp313-win_amd64.whl", hash = "sha256:dc5f2fa9916f292e5c5c8b2ac2813763bcd7f58e130055b4ad8a0531314201ab", size = 596569 }, + { url = 
"https://files.pythonhosted.org/packages/3c/36/e9450d62e84a13aea6580c83a47a437f26c7ca6fa0f0fd40b6670793ea30/asyncpg-0.31.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:f6b56b91bb0ffc328c4e3ed113136cddd9deefdf5f79ab448598b9772831df44", size = 660867 }, + { url = "https://files.pythonhosted.org/packages/82/4b/1d0a2b33b3102d210439338e1beea616a6122267c0df459ff0265cd5807a/asyncpg-0.31.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:334dec28cf20d7f5bb9e45b39546ddf247f8042a690bff9b9573d00086e69cb5", size = 638349 }, + { url = "https://files.pythonhosted.org/packages/41/aa/e7f7ac9a7974f08eff9183e392b2d62516f90412686532d27e196c0f0eeb/asyncpg-0.31.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:98cc158c53f46de7bb677fd20c417e264fc02b36d901cc2a43bd6cb0dc6dbfd2", size = 3410428 }, + { url = "https://files.pythonhosted.org/packages/6f/de/bf1b60de3dede5c2731e6788617a512bc0ebd9693eac297ee74086f101d7/asyncpg-0.31.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9322b563e2661a52e3cdbc93eed3be7748b289f792e0011cb2720d278b366ce2", size = 3471678 }, + { url = "https://files.pythonhosted.org/packages/46/78/fc3ade003e22d8bd53aaf8f75f4be48f0b460fa73738f0391b9c856a9147/asyncpg-0.31.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:19857a358fc811d82227449b7ca40afb46e75b33eb8897240c3839dd8b744218", size = 3313505 }, + { url = "https://files.pythonhosted.org/packages/bf/e9/73eb8a6789e927816f4705291be21f2225687bfa97321e40cd23055e903a/asyncpg-0.31.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:ba5f8886e850882ff2c2ace5732300e99193823e8107e2c53ef01c1ebfa1e85d", size = 3434744 }, + { url = "https://files.pythonhosted.org/packages/08/4b/f10b880534413c65c5b5862f79b8e81553a8f364e5238832ad4c0af71b7f/asyncpg-0.31.0-cp314-cp314-win32.whl", hash = "sha256:cea3a0b2a14f95834cee29432e4ddc399b95700eb1d51bbc5bfee8f31fa07b2b", size = 532251 }, + { url = 
"https://files.pythonhosted.org/packages/d3/2d/7aa40750b7a19efa5d66e67fc06008ca0f27ba1bd082e457ad82f59aba49/asyncpg-0.31.0-cp314-cp314-win_amd64.whl", hash = "sha256:04d19392716af6b029411a0264d92093b6e5e8285ae97a39957b9a9c14ea72be", size = 604901 }, + { url = "https://files.pythonhosted.org/packages/ce/fe/b9dfe349b83b9dee28cc42360d2c86b2cdce4cb551a2c2d27e156bcac84d/asyncpg-0.31.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:bdb957706da132e982cc6856bb2f7b740603472b54c3ebc77fe60ea3e57e1bd2", size = 702280 }, + { url = "https://files.pythonhosted.org/packages/6a/81/e6be6e37e560bd91e6c23ea8a6138a04fd057b08cf63d3c5055c98e81c1d/asyncpg-0.31.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:6d11b198111a72f47154fa03b85799f9be63701e068b43f84ac25da0bda9cb31", size = 682931 }, + { url = "https://files.pythonhosted.org/packages/a6/45/6009040da85a1648dd5bc75b3b0a062081c483e75a1a29041ae63a0bf0dc/asyncpg-0.31.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:18c83b03bc0d1b23e6230f5bf8d4f217dc9bc08644ce0502a9d91dc9e634a9c7", size = 3581608 }, + { url = "https://files.pythonhosted.org/packages/7e/06/2e3d4d7608b0b2b3adbee0d0bd6a2d29ca0fc4d8a78f8277df04e2d1fd7b/asyncpg-0.31.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e009abc333464ff18b8f6fd146addffd9aaf63e79aa3bb40ab7a4c332d0c5e9e", size = 3498738 }, + { url = "https://files.pythonhosted.org/packages/7d/aa/7d75ede780033141c51d83577ea23236ba7d3a23593929b32b49db8ed36e/asyncpg-0.31.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:3b1fbcb0e396a5ca435a8826a87e5c2c2cc0c8c68eb6fadf82168056b0e53a8c", size = 3401026 }, + { url = "https://files.pythonhosted.org/packages/ba/7a/15e37d45e7f7c94facc1e9148c0e455e8f33c08f0b8a0b1deb2c5171771b/asyncpg-0.31.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:8df714dba348efcc162d2adf02d213e5fab1bd9f557e1305633e851a61814a7a", size = 3429426 }, + { url = 
"https://files.pythonhosted.org/packages/13/d5/71437c5f6ae5f307828710efbe62163974e71237d5d46ebd2869ea052d10/asyncpg-0.31.0-cp314-cp314t-win32.whl", hash = "sha256:1b41f1afb1033f2b44f3234993b15096ddc9cd71b21a42dbd87fc6a57b43d65d", size = 614495 }, + { url = "https://files.pythonhosted.org/packages/3c/d7/8fb3044eaef08a310acfe23dae9a8e2e07d305edc29a53497e52bc76eca7/asyncpg-0.31.0-cp314-cp314t-win_amd64.whl", hash = "sha256:bd4107bb7cdd0e9e65fae66a62afd3a249663b844fa34d479f6d5b3bef9c04c3", size = 706062 }, +] + [[package]] name = "attrs" version = "25.4.0" @@ -107,6 +155,76 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/02/e3/a4fa1946722c4c7b063cc25043a12d9ce9b4323777f89643be74cef2993c/backrefs-6.1-py39-none-any.whl", hash = "sha256:a9e99b8a4867852cad177a6430e31b0f6e495d65f8c6c134b68c14c3c95bf4b0", size = 381058 }, ] +[[package]] +name = "bcrypt" +version = "5.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d4/36/3329e2518d70ad8e2e5817d5a4cac6bba05a47767ec416c7d020a965f408/bcrypt-5.0.0.tar.gz", hash = "sha256:f748f7c2d6fd375cc93d3fba7ef4a9e3a092421b8dbf34d8d4dc06be9492dfdd", size = 25386 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/13/85/3e65e01985fddf25b64ca67275bb5bdb4040bd1a53b66d355c6c37c8a680/bcrypt-5.0.0-cp313-cp313t-macosx_10_12_universal2.whl", hash = "sha256:f3c08197f3039bec79cee59a606d62b96b16669cff3949f21e74796b6e3cd2be", size = 481806 }, + { url = "https://files.pythonhosted.org/packages/44/dc/01eb79f12b177017a726cbf78330eb0eb442fae0e7b3dfd84ea2849552f3/bcrypt-5.0.0-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:200af71bc25f22006f4069060c88ed36f8aa4ff7f53e67ff04d2ab3f1e79a5b2", size = 268626 }, + { url = "https://files.pythonhosted.org/packages/8c/cf/e82388ad5959c40d6afd94fb4743cc077129d45b952d46bdc3180310e2df/bcrypt-5.0.0-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = 
"sha256:baade0a5657654c2984468efb7d6c110db87ea63ef5a4b54732e7e337253e44f", size = 271853 }, + { url = "https://files.pythonhosted.org/packages/ec/86/7134b9dae7cf0efa85671651341f6afa695857fae172615e960fb6a466fa/bcrypt-5.0.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:c58b56cdfb03202b3bcc9fd8daee8e8e9b6d7e3163aa97c631dfcfcc24d36c86", size = 269793 }, + { url = "https://files.pythonhosted.org/packages/cc/82/6296688ac1b9e503d034e7d0614d56e80c5d1a08402ff856a4549cb59207/bcrypt-5.0.0-cp313-cp313t-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:4bfd2a34de661f34d0bda43c3e4e79df586e4716ef401fe31ea39d69d581ef23", size = 289930 }, + { url = "https://files.pythonhosted.org/packages/d1/18/884a44aa47f2a3b88dd09bc05a1e40b57878ecd111d17e5bba6f09f8bb77/bcrypt-5.0.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:ed2e1365e31fc73f1825fa830f1c8f8917ca1b3ca6185773b349c20fd606cec2", size = 272194 }, + { url = "https://files.pythonhosted.org/packages/0e/8f/371a3ab33c6982070b674f1788e05b656cfbf5685894acbfef0c65483a59/bcrypt-5.0.0-cp313-cp313t-manylinux_2_34_aarch64.whl", hash = "sha256:83e787d7a84dbbfba6f250dd7a5efd689e935f03dd83b0f919d39349e1f23f83", size = 269381 }, + { url = "https://files.pythonhosted.org/packages/b1/34/7e4e6abb7a8778db6422e88b1f06eb07c47682313997ee8a8f9352e5a6f1/bcrypt-5.0.0-cp313-cp313t-manylinux_2_34_x86_64.whl", hash = "sha256:137c5156524328a24b9fac1cb5db0ba618bc97d11970b39184c1d87dc4bf1746", size = 271750 }, + { url = "https://files.pythonhosted.org/packages/c0/1b/54f416be2499bd72123c70d98d36c6cd61a4e33d9b89562c22481c81bb30/bcrypt-5.0.0-cp313-cp313t-musllinux_1_1_aarch64.whl", hash = "sha256:38cac74101777a6a7d3b3e3cfefa57089b5ada650dce2baf0cbdd9d65db22a9e", size = 303757 }, + { url = "https://files.pythonhosted.org/packages/13/62/062c24c7bcf9d2826a1a843d0d605c65a755bc98002923d01fd61270705a/bcrypt-5.0.0-cp313-cp313t-musllinux_1_1_x86_64.whl", hash = "sha256:d8d65b564ec849643d9f7ea05c6d9f0cd7ca23bdd4ac0c2dbef1104ab504543d", size 
= 306740 }, + { url = "https://files.pythonhosted.org/packages/d5/c8/1fdbfc8c0f20875b6b4020f3c7dc447b8de60aa0be5faaf009d24242aec9/bcrypt-5.0.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:741449132f64b3524e95cd30e5cd3343006ce146088f074f31ab26b94e6c75ba", size = 334197 }, + { url = "https://files.pythonhosted.org/packages/a6/c1/8b84545382d75bef226fbc6588af0f7b7d095f7cd6a670b42a86243183cd/bcrypt-5.0.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:212139484ab3207b1f0c00633d3be92fef3c5f0af17cad155679d03ff2ee1e41", size = 352974 }, + { url = "https://files.pythonhosted.org/packages/10/a6/ffb49d4254ed085e62e3e5dd05982b4393e32fe1e49bb1130186617c29cd/bcrypt-5.0.0-cp313-cp313t-win32.whl", hash = "sha256:9d52ed507c2488eddd6a95bccee4e808d3234fa78dd370e24bac65a21212b861", size = 148498 }, + { url = "https://files.pythonhosted.org/packages/48/a9/259559edc85258b6d5fc5471a62a3299a6aa37a6611a169756bf4689323c/bcrypt-5.0.0-cp313-cp313t-win_amd64.whl", hash = "sha256:f6984a24db30548fd39a44360532898c33528b74aedf81c26cf29c51ee47057e", size = 145853 }, + { url = "https://files.pythonhosted.org/packages/2d/df/9714173403c7e8b245acf8e4be8876aac64a209d1b392af457c79e60492e/bcrypt-5.0.0-cp313-cp313t-win_arm64.whl", hash = "sha256:9fffdb387abe6aa775af36ef16f55e318dcda4194ddbf82007a6f21da29de8f5", size = 139626 }, + { url = "https://files.pythonhosted.org/packages/f8/14/c18006f91816606a4abe294ccc5d1e6f0e42304df5a33710e9e8e95416e1/bcrypt-5.0.0-cp314-cp314t-macosx_10_12_universal2.whl", hash = "sha256:4870a52610537037adb382444fefd3706d96d663ac44cbb2f37e3919dca3d7ef", size = 481862 }, + { url = "https://files.pythonhosted.org/packages/67/49/dd074d831f00e589537e07a0725cf0e220d1f0d5d8e85ad5bbff251c45aa/bcrypt-5.0.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:48f753100931605686f74e27a7b49238122aa761a9aefe9373265b8b7aa43ea4", size = 268544 }, + { url = 
"https://files.pythonhosted.org/packages/f5/91/50ccba088b8c474545b034a1424d05195d9fcbaaf802ab8bfe2be5a4e0d7/bcrypt-5.0.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f70aadb7a809305226daedf75d90379c397b094755a710d7014b8b117df1ebbf", size = 271787 }, + { url = "https://files.pythonhosted.org/packages/aa/e7/d7dba133e02abcda3b52087a7eea8c0d4f64d3e593b4fffc10c31b7061f3/bcrypt-5.0.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:744d3c6b164caa658adcb72cb8cc9ad9b4b75c7db507ab4bc2480474a51989da", size = 269753 }, + { url = "https://files.pythonhosted.org/packages/33/fc/5b145673c4b8d01018307b5c2c1fc87a6f5a436f0ad56607aee389de8ee3/bcrypt-5.0.0-cp314-cp314t-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:a28bc05039bdf3289d757f49d616ab3efe8cf40d8e8001ccdd621cd4f98f4fc9", size = 289587 }, + { url = "https://files.pythonhosted.org/packages/27/d7/1ff22703ec6d4f90e62f1a5654b8867ef96bafb8e8102c2288333e1a6ca6/bcrypt-5.0.0-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:7f277a4b3390ab4bebe597800a90da0edae882c6196d3038a73adf446c4f969f", size = 272178 }, + { url = "https://files.pythonhosted.org/packages/c8/88/815b6d558a1e4d40ece04a2f84865b0fef233513bd85fd0e40c294272d62/bcrypt-5.0.0-cp314-cp314t-manylinux_2_34_aarch64.whl", hash = "sha256:79cfa161eda8d2ddf29acad370356b47f02387153b11d46042e93a0a95127493", size = 269295 }, + { url = "https://files.pythonhosted.org/packages/51/8c/e0db387c79ab4931fc89827d37608c31cc57b6edc08ccd2386139028dc0d/bcrypt-5.0.0-cp314-cp314t-manylinux_2_34_x86_64.whl", hash = "sha256:a5393eae5722bcef046a990b84dff02b954904c36a194f6cfc817d7dca6c6f0b", size = 271700 }, + { url = "https://files.pythonhosted.org/packages/06/83/1570edddd150f572dbe9fc00f6203a89fc7d4226821f67328a85c330f239/bcrypt-5.0.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:7f4c94dec1b5ab5d522750cb059bb9409ea8872d4494fd152b53cca99f1ddd8c", size = 334034 }, + { url = 
"https://files.pythonhosted.org/packages/c9/f2/ea64e51a65e56ae7a8a4ec236c2bfbdd4b23008abd50ac33fbb2d1d15424/bcrypt-5.0.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:0cae4cb350934dfd74c020525eeae0a5f79257e8a201c0c176f4b84fdbf2a4b4", size = 352766 }, + { url = "https://files.pythonhosted.org/packages/d7/d4/1a388d21ee66876f27d1a1f41287897d0c0f1712ef97d395d708ba93004c/bcrypt-5.0.0-cp314-cp314t-win32.whl", hash = "sha256:b17366316c654e1ad0306a6858e189fc835eca39f7eb2cafd6aaca8ce0c40a2e", size = 152449 }, + { url = "https://files.pythonhosted.org/packages/3f/61/3291c2243ae0229e5bca5d19f4032cecad5dfb05a2557169d3a69dc0ba91/bcrypt-5.0.0-cp314-cp314t-win_amd64.whl", hash = "sha256:92864f54fb48b4c718fc92a32825d0e42265a627f956bc0361fe869f1adc3e7d", size = 149310 }, + { url = "https://files.pythonhosted.org/packages/3e/89/4b01c52ae0c1a681d4021e5dd3e45b111a8fb47254a274fa9a378d8d834b/bcrypt-5.0.0-cp314-cp314t-win_arm64.whl", hash = "sha256:dd19cf5184a90c873009244586396a6a884d591a5323f0e8a5922560718d4993", size = 143761 }, + { url = "https://files.pythonhosted.org/packages/84/29/6237f151fbfe295fe3e074ecc6d44228faa1e842a81f6d34a02937ee1736/bcrypt-5.0.0-cp38-abi3-macosx_10_12_universal2.whl", hash = "sha256:fc746432b951e92b58317af8e0ca746efe93e66555f1b40888865ef5bf56446b", size = 494553 }, + { url = "https://files.pythonhosted.org/packages/45/b6/4c1205dde5e464ea3bd88e8742e19f899c16fa8916fb8510a851fae985b5/bcrypt-5.0.0-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:c2388ca94ffee269b6038d48747f4ce8df0ffbea43f31abfa18ac72f0218effb", size = 275009 }, + { url = "https://files.pythonhosted.org/packages/3b/71/427945e6ead72ccffe77894b2655b695ccf14ae1866cd977e185d606dd2f/bcrypt-5.0.0-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:560ddb6ec730386e7b3b26b8b4c88197aaed924430e7b74666a586ac997249ef", size = 278029 }, + { url = 
"https://files.pythonhosted.org/packages/17/72/c344825e3b83c5389a369c8a8e58ffe1480b8a699f46c127c34580c4666b/bcrypt-5.0.0-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:d79e5c65dcc9af213594d6f7f1fa2c98ad3fc10431e7aa53c176b441943efbdd", size = 275907 }, + { url = "https://files.pythonhosted.org/packages/0b/7e/d4e47d2df1641a36d1212e5c0514f5291e1a956a7749f1e595c07a972038/bcrypt-5.0.0-cp38-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:2b732e7d388fa22d48920baa267ba5d97cca38070b69c0e2d37087b381c681fd", size = 296500 }, + { url = "https://files.pythonhosted.org/packages/0f/c3/0ae57a68be2039287ec28bc463b82e4b8dc23f9d12c0be331f4782e19108/bcrypt-5.0.0-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:0c8e093ea2532601a6f686edbc2c6b2ec24131ff5c52f7610dd64fa4553b5464", size = 278412 }, + { url = "https://files.pythonhosted.org/packages/45/2b/77424511adb11e6a99e3a00dcc7745034bee89036ad7d7e255a7e47be7d8/bcrypt-5.0.0-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:5b1589f4839a0899c146e8892efe320c0fa096568abd9b95593efac50a87cb75", size = 275486 }, + { url = "https://files.pythonhosted.org/packages/43/0a/405c753f6158e0f3f14b00b462d8bca31296f7ecfc8fc8bc7919c0c7d73a/bcrypt-5.0.0-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:89042e61b5e808b67daf24a434d89bab164d4de1746b37a8d173b6b14f3db9ff", size = 277940 }, + { url = "https://files.pythonhosted.org/packages/62/83/b3efc285d4aadc1fa83db385ec64dcfa1707e890eb42f03b127d66ac1b7b/bcrypt-5.0.0-cp38-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:e3cf5b2560c7b5a142286f69bde914494b6d8f901aaa71e453078388a50881c4", size = 310776 }, + { url = "https://files.pythonhosted.org/packages/95/7d/47ee337dacecde6d234890fe929936cb03ebc4c3a7460854bbd9c97780b8/bcrypt-5.0.0-cp38-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:f632fd56fc4e61564f78b46a2269153122db34988e78b6be8b32d28507b7eaeb", size = 312922 }, + { url = 
"https://files.pythonhosted.org/packages/d6/3a/43d494dfb728f55f4e1cf8fd435d50c16a2d75493225b54c8d06122523c6/bcrypt-5.0.0-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:801cad5ccb6b87d1b430f183269b94c24f248dddbbc5c1f78b6ed231743e001c", size = 341367 }, + { url = "https://files.pythonhosted.org/packages/55/ab/a0727a4547e383e2e22a630e0f908113db37904f58719dc48d4622139b5c/bcrypt-5.0.0-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:3cf67a804fc66fc217e6914a5635000259fbbbb12e78a99488e4d5ba445a71eb", size = 359187 }, + { url = "https://files.pythonhosted.org/packages/1b/bb/461f352fdca663524b4643d8b09e8435b4990f17fbf4fea6bc2a90aa0cc7/bcrypt-5.0.0-cp38-abi3-win32.whl", hash = "sha256:3abeb543874b2c0524ff40c57a4e14e5d3a66ff33fb423529c88f180fd756538", size = 153752 }, + { url = "https://files.pythonhosted.org/packages/41/aa/4190e60921927b7056820291f56fc57d00d04757c8b316b2d3c0d1d6da2c/bcrypt-5.0.0-cp38-abi3-win_amd64.whl", hash = "sha256:35a77ec55b541e5e583eb3436ffbbf53b0ffa1fa16ca6782279daf95d146dcd9", size = 150881 }, + { url = "https://files.pythonhosted.org/packages/54/12/cd77221719d0b39ac0b55dbd39358db1cd1246e0282e104366ebbfb8266a/bcrypt-5.0.0-cp38-abi3-win_arm64.whl", hash = "sha256:cde08734f12c6a4e28dc6755cd11d3bdfea608d93d958fffbe95a7026ebe4980", size = 144931 }, + { url = "https://files.pythonhosted.org/packages/5d/ba/2af136406e1c3839aea9ecadc2f6be2bcd1eff255bd451dd39bcf302c47a/bcrypt-5.0.0-cp39-abi3-macosx_10_12_universal2.whl", hash = "sha256:0c418ca99fd47e9c59a301744d63328f17798b5947b0f791e9af3c1c499c2d0a", size = 495313 }, + { url = "https://files.pythonhosted.org/packages/ac/ee/2f4985dbad090ace5ad1f7dd8ff94477fe089b5fab2040bd784a3d5f187b/bcrypt-5.0.0-cp39-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ddb4e1500f6efdd402218ffe34d040a1196c072e07929b9820f363a1fd1f4191", size = 275290 }, + { url = 
"https://files.pythonhosted.org/packages/e4/6e/b77ade812672d15cf50842e167eead80ac3514f3beacac8902915417f8b7/bcrypt-5.0.0-cp39-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7aeef54b60ceddb6f30ee3db090351ecf0d40ec6e2abf41430997407a46d2254", size = 278253 }, + { url = "https://files.pythonhosted.org/packages/36/c4/ed00ed32f1040f7990dac7115f82273e3c03da1e1a1587a778d8cea496d8/bcrypt-5.0.0-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:f0ce778135f60799d89c9693b9b398819d15f1921ba15fe719acb3178215a7db", size = 276084 }, + { url = "https://files.pythonhosted.org/packages/e7/c4/fa6e16145e145e87f1fa351bbd54b429354fd72145cd3d4e0c5157cf4c70/bcrypt-5.0.0-cp39-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:a71f70ee269671460b37a449f5ff26982a6f2ba493b3eabdd687b4bf35f875ac", size = 297185 }, + { url = "https://files.pythonhosted.org/packages/24/b4/11f8a31d8b67cca3371e046db49baa7c0594d71eb40ac8121e2fc0888db0/bcrypt-5.0.0-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:f8429e1c410b4073944f03bd778a9e066e7fad723564a52ff91841d278dfc822", size = 278656 }, + { url = "https://files.pythonhosted.org/packages/ac/31/79f11865f8078e192847d2cb526e3fa27c200933c982c5b2869720fa5fce/bcrypt-5.0.0-cp39-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:edfcdcedd0d0f05850c52ba3127b1fce70b9f89e0fe5ff16517df7e81fa3cbb8", size = 275662 }, + { url = "https://files.pythonhosted.org/packages/d4/8d/5e43d9584b3b3591a6f9b68f755a4da879a59712981ef5ad2a0ac1379f7a/bcrypt-5.0.0-cp39-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:611f0a17aa4a25a69362dcc299fda5c8a3d4f160e2abb3831041feb77393a14a", size = 278240 }, + { url = "https://files.pythonhosted.org/packages/89/48/44590e3fc158620f680a978aafe8f87a4c4320da81ed11552f0323aa9a57/bcrypt-5.0.0-cp39-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:db99dca3b1fdc3db87d7c57eac0c82281242d1eabf19dcb8a6b10eb29a2e72d1", size = 311152 }, + { url = 
"https://files.pythonhosted.org/packages/5f/85/e4fbfc46f14f47b0d20493669a625da5827d07e8a88ee460af6cd9768b44/bcrypt-5.0.0-cp39-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:5feebf85a9cefda32966d8171f5db7e3ba964b77fdfe31919622256f80f9cf42", size = 313284 }, + { url = "https://files.pythonhosted.org/packages/25/ae/479f81d3f4594456a01ea2f05b132a519eff9ab5768a70430fa1132384b1/bcrypt-5.0.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:3ca8a166b1140436e058298a34d88032ab62f15aae1c598580333dc21d27ef10", size = 341643 }, + { url = "https://files.pythonhosted.org/packages/df/d2/36a086dee1473b14276cd6ea7f61aef3b2648710b5d7f1c9e032c29b859f/bcrypt-5.0.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:61afc381250c3182d9078551e3ac3a41da14154fbff647ddf52a769f588c4172", size = 359698 }, + { url = "https://files.pythonhosted.org/packages/c0/f6/688d2cd64bfd0b14d805ddb8a565e11ca1fb0fd6817175d58b10052b6d88/bcrypt-5.0.0-cp39-abi3-win32.whl", hash = "sha256:64d7ce196203e468c457c37ec22390f1a61c85c6f0b8160fd752940ccfb3a683", size = 153725 }, + { url = "https://files.pythonhosted.org/packages/9f/b9/9d9a641194a730bda138b3dfe53f584d61c58cd5230e37566e83ec2ffa0d/bcrypt-5.0.0-cp39-abi3-win_amd64.whl", hash = "sha256:64ee8434b0da054d830fa8e89e1c8bf30061d539044a39524ff7dec90481e5c2", size = 150912 }, + { url = "https://files.pythonhosted.org/packages/27/44/d2ef5e87509158ad2187f4dd0852df80695bb1ee0cfe0a684727b01a69e0/bcrypt-5.0.0-cp39-abi3-win_arm64.whl", hash = "sha256:f2347d3534e76bf50bca5500989d6c1d05ed64b440408057a37673282c654927", size = 144953 }, + { url = "https://files.pythonhosted.org/packages/8a/75/4aa9f5a4d40d762892066ba1046000b329c7cd58e888a6db878019b282dc/bcrypt-5.0.0-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:7edda91d5ab52b15636d9c30da87d2cc84f426c72b9dba7a9b4fe142ba11f534", size = 271180 }, + { url = 
"https://files.pythonhosted.org/packages/54/79/875f9558179573d40a9cc743038ac2bf67dfb79cecb1e8b5d70e88c94c3d/bcrypt-5.0.0-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:046ad6db88edb3c5ece4369af997938fb1c19d6a699b9c1b27b0db432faae4c4", size = 273791 }, + { url = "https://files.pythonhosted.org/packages/bc/fe/975adb8c216174bf70fc17535f75e85ac06ed5252ea077be10d9cff5ce24/bcrypt-5.0.0-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:dcd58e2b3a908b5ecc9b9df2f0085592506ac2d5110786018ee5e160f28e0911", size = 270746 }, + { url = "https://files.pythonhosted.org/packages/e4/f8/972c96f5a2b6c4b3deca57009d93e946bbdbe2241dca9806d502f29dd3ee/bcrypt-5.0.0-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:6b8f520b61e8781efee73cba14e3e8c9556ccfb375623f4f97429544734545b4", size = 273375 }, +] + [[package]] name = "certifi" version = "2026.1.4" @@ -477,12 +595,14 @@ wheels = [ [[package]] name = "duh" -version = "0.4.0" +version = "0.5.0" source = { editable = "." } dependencies = [ { name = "aiosqlite" }, { name = "alembic" }, { name = "anthropic" }, + { name = "asyncpg" }, + { name = "bcrypt" }, { name = "click" }, { name = "duckduckgo-search" }, { name = "fastapi" }, @@ -493,6 +613,7 @@ dependencies = [ { name = "openai" }, { name = "pydantic" }, { name = "pydantic-settings" }, + { name = "pyjwt" }, { name = "rich" }, { name = "sqlalchemy", extra = ["asyncio"] }, { name = "uvicorn", extra = ["standard"] }, @@ -518,6 +639,8 @@ requires-dist = [ { name = "aiosqlite", specifier = ">=0.20.0" }, { name = "alembic", specifier = ">=1.13" }, { name = "anthropic", specifier = ">=0.40.0" }, + { name = "asyncpg", specifier = ">=0.29" }, + { name = "bcrypt", specifier = ">=4.0" }, { name = "click", specifier = ">=8.1" }, { name = "duckduckgo-search", specifier = ">=7.0" }, { name = "fastapi", specifier = ">=0.115" }, @@ -530,6 +653,7 @@ requires-dist = [ { name = "openai", specifier = ">=1.50.0" }, { name = "pydantic", specifier = ">=2.0" }, { name = 
"pydantic-settings", specifier = ">=2.0" }, + { name = "pyjwt", specifier = ">=2.8" }, { name = "rich", specifier = ">=13.0" }, { name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0" }, { name = "uvicorn", extras = ["standard"], specifier = ">=0.30" }, diff --git a/web/e2e/consensus.spec.ts b/web/e2e/consensus.spec.ts new file mode 100644 index 0000000..f786abd --- /dev/null +++ b/web/e2e/consensus.spec.ts @@ -0,0 +1,56 @@ +import { test, expect } from '@playwright/test'; + +test.describe('Consensus Page', () => { + test.beforeEach(async ({ page }) => { + await page.goto('/'); + }); + + test('has question textarea', async ({ page }) => { + const textarea = page.locator('textarea'); + await expect(textarea).toBeVisible(); + await expect(textarea).toHaveAttribute( + 'placeholder', + 'Ask a question to reach consensus...', + ); + }); + + test('has rounds selector', async ({ page }) => { + const roundsSelect = page.locator('select').first(); + await expect(roundsSelect).toBeVisible(); + + // Verify round options 1-5 are available + const options = roundsSelect.locator('option'); + await expect(options).toHaveCount(5); + }); + + test('has protocol selector', async ({ page }) => { + const selects = page.locator('select'); + // Second select is the protocol selector + const protocolSelect = selects.nth(1); + await expect(protocolSelect).toBeVisible(); + + const options = protocolSelect.locator('option'); + await expect(options).toHaveCount(3); + await expect(options.nth(0)).toHaveText('consensus'); + await expect(options.nth(1)).toHaveText('voting'); + await expect(options.nth(2)).toHaveText('auto'); + }); + + test('has Ask submit button', async ({ page }) => { + const button = page.getByRole('button', { name: /ask/i }); + await expect(button).toBeVisible(); + }); + + test('Ask button is disabled when textarea is empty', async ({ page }) => { + const button = page.getByRole('button', { name: /ask/i }); + await expect(button).toBeDisabled(); + }); + + test('Ask 
button is enabled when question is entered', async ({ page }) => { + const textarea = page.locator('textarea'); + await textarea.fill('Should we use TypeScript?'); + + const button = page.getByRole('button', { name: /ask/i }); + await expect(button).toBeEnabled(); + }); +}); diff --git a/web/e2e/decision-space.spec.ts b/web/e2e/decision-space.spec.ts new file mode 100644 index 0000000..f77e8db --- /dev/null +++ b/web/e2e/decision-space.spec.ts @@ -0,0 +1,26 @@ +import { test, expect } from '@playwright/test'; + +test.describe('Decision Space', () => { + test('page loads without error', async ({ page }) => { + await page.goto('/space'); + await expect(page.locator('body')).toBeVisible(); + }); + + test('no console errors on load', async ({ page }) => { + const errors: string[] = []; + page.on('console', (msg) => { + if (msg.type() === 'error') { + errors.push(msg.text()); + } + }); + + await page.goto('/space'); + await page.waitForTimeout(1000); + + // Filter out expected errors (e.g. WebSocket connection failures in test env) + const unexpectedErrors = errors.filter( + (e) => !e.includes('WebSocket') && !e.includes('ERR_CONNECTION_REFUSED'), + ); + expect(unexpectedErrors).toEqual([]); + }); +}); diff --git a/web/e2e/navigation.spec.ts b/web/e2e/navigation.spec.ts new file mode 100644 index 0000000..e46c3b5 --- /dev/null +++ b/web/e2e/navigation.spec.ts @@ -0,0 +1,66 @@ +import { test, expect } from '@playwright/test'; + +test.describe('Navigation', () => { + test('home page loads with correct title', async ({ page }) => { + await page.goto('/'); + await expect(page).toHaveTitle(/duh/i); + }); + + test('sidebar navigation is visible on desktop', async ({ page }) => { + // Shell renders a