diff --git a/CLAUDE.md b/CLAUDE.md index 47dc3e3d..bcd4050c 120000 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -1 +1,150 @@ -AGENTS.md \ No newline at end of file +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +LTX Desktop is an open-source Electron app for AI video generation using LTX models. It supports local generation on Windows NVIDIA GPUs (32GB+ VRAM) and API-only mode for unsupported hardware and macOS. + +Three-layer architecture: + +``` +Renderer (React + TS) --HTTP: localhost:8000--> Backend (FastAPI + Python) +Renderer (React + TS) --IPC: window.electronAPI--> Electron main (TS) +Electron main --> OS integration (files, dialogs, ffmpeg, process mgmt) +Backend --> Local models + GPU | External APIs (when API-backed) +``` + +- **Frontend** (`frontend/`): React 18 + TypeScript + Tailwind CSS renderer +- **Electron** (`electron/`): Main process managing app lifecycle, IPC, Python backend process, ffmpeg export. Renderer is sandboxed (`contextIsolation: true`, `nodeIntegration: false`). +- **Backend** (`backend/`): Python FastAPI server (port 8000) handling ML model orchestration and generation + +## Common Commands + +| Command | Purpose | +|---|---| +| `pnpm dev` | Start dev server (Vite + Electron + Python backend) | +| `pnpm dev:debug` | Dev with Electron inspector (port 9229) + Python debugpy | +| `pnpm typecheck` | Run TypeScript (`tsc --noEmit`) and Python (`pyright`) type checks | +| `pnpm typecheck:ts` | TypeScript only | +| `pnpm typecheck:py` | Python pyright only (`cd backend && uv run pyright`) | +| `pnpm backend:test` | Run Python pytest tests (`cd backend && uv sync --frozen --extra test --extra dev && uv run pytest -v --tb=short`) | +| `pnpm build:frontend` | Vite frontend build only | +| `pnpm build:win` / `pnpm build:mac` | Full platform builds (installer) | +| `pnpm build:fast:win` / `pnpm build:fast:mac` | Unpacked build, skip Python bundling | +| `pnpm setup:dev:win` / `pnpm setup:dev:mac` | One-time dev environment setup | + +Run a single backend test: `cd backend && uv run pytest tests/test_generation.py -v --tb=short` + +Run a single test function: `cd backend && uv run pytest tests/test_generation.py::test_name -v --tb=short` + +## CI Checks + +PRs must pass: `pnpm typecheck` + `pnpm backend:test` + frontend Vite build. + +## Frontend Architecture + +- **Path alias**: `@/*` maps to `frontend/*` (configured in `tsconfig.json` and `vite.config.ts`) +- **State management**: React contexts only (`ProjectContext`, `AppSettingsContext`, `KeyboardShortcutsContext`) — no Redux/Zustand +- **Routing**: View-based via `ProjectContext` with views: `home`, `project`, `playground`, plus library views (`Gallery`, `PromptLibrary`, `Characters`, `Styles`, `References`, `Wildcards`) +- **IPC bridge**: All Electron communication through `window.electronAPI` (defined in `electron/preload.ts`). Key methods: `getBackendUrl`, `readLocalFile`, `checkGpu`, `getAppInfo`, `exportVideo`, `showSaveDialog`, `showItemInFolder` +- **Backend calls**: Frontend calls `http://localhost:8000` directly +- **Styling**: Tailwind with custom semantic color tokens via CSS variables; utilities from `class-variance-authority` + `clsx` + `tailwind-merge` +- **Views**: `Home.tsx`, `GenSpace.tsx`, `Project.tsx`, `Playground.tsx`, `VideoEditor.tsx` (largest frontend file), `editor/` subdirectory, plus library views (`Gallery.tsx`, `PromptLibrary.tsx`, `Characters.tsx`, `Styles.tsx`, `References.tsx`, `Wildcards.tsx`) +- **Generation hook**: `useGeneration()` manages the full generate → poll → complete lifecycle. Submits jobs to `/api/queue/submit`, polls `/api/queue/status` every 500ms, maps backend phases to user-facing status messages. +- **LoRA support**: `GenerationSettings` includes `loraPath`, `loraWeight`, `loraTriggerPhrase`, and `loraTriggerMode` (`'prepend' | 'append' | 'off'`). Trigger phrase is applied client-side before submission. +- **No frontend tests** currently exist + +## Backend Architecture + +Request flow: `_routes/* (thin) -> AppHandler -> handlers/* (logic) -> services/* (side effects) + state/* (mutations)` + +Key patterns: +- **Routes** (`_routes/`): Thin plumbing only — parse input, call handler, return typed output. No business logic. +- **AppHandler** (`app_handler.py`): Single composition root owning all sub-handlers, state, and lock. Sub-handlers accessed as `handler.health`, `handler.models`, `handler.downloads`, etc. +- **State** (`state/`): Centralized `AppState` using discriminated union types for state machines (e.g., `GenerationState = GenerationRunning | GenerationComplete | GenerationError | GenerationCancelled`) +- **Services** (`services/`): Protocol interfaces with real implementations and fake test implementations. The test boundary for heavy side effects (GPU, network). +- **Concurrency**: Thread pool with shared `RLock`. Pattern: lock -> read/validate -> unlock -> heavy work -> lock -> write. Never hold lock during heavy compute/IO. Use `handlers.base.with_state_lock` decorator. +- **Exception handling**: Boundary-owned traceback policy. Handlers raise `HTTPError` with `from exc` chaining; `app_factory.py` owns logging. Don't `logger.exception()` then rethrow. +- **Naming**: `*Payload` for DTOs/TypedDicts, `*Like` for structural wrappers, `Fake*` for test implementations + +### Job Queue System + +Generation requests flow through a persistent job queue rather than direct handler calls: + +``` +Frontend POST /api/queue/submit → JobQueue.submit() → QueueWorker.tick() → JobExecutor.execute() +Frontend polls GET /api/queue/status for progress updates +``` + +- **JobQueue** (`state/job_queue.py`): Persistent dataclass-based queue with JSON file backing. Jobs have `slot` (`gpu` | `api`) determining which executor runs them. On app restart, any `running` jobs are marked `error`. +- **QueueWorker** (`handlers/queue_worker.py`): Ticks on a timer, dispatches one job per slot concurrently via daemon threads. Two independent slots: `gpu` (local models) and `api` (cloud APIs). +- **JobExecutor** (`handlers/job_executors.py`): Protocol with `execute(job) -> list[str]`. GPU executor delegates to `VideoGenerationHandler`/`ImageGenerationHandler`; API executor calls external APIs. +- **Phase reporting**: Handlers report granular phases (`preparing_gpu`, `unloading_video_model`, `cleaning_gpu`, `loading_image_model`, `loading_lora`, `inference`, `decoding`, etc.) via `on_phase` callbacks through the pipeline chain. Frontend maps these to user-facing messages. + +### Pipeline Lifecycle + +GPU is shared between video and image models. Only one model type loaded at a time: + +- `PipelinesHandler` manages swap lifecycle: unload current → clean VRAM → load new +- ZIT (image model) can be parked on CPU when video model needs GPU, then restored +- `load_zit_to_gpu(on_phase=...)` and `load_gpu_pipeline(model_type, on_phase=...)` accept phase callbacks for progress reporting during model swaps + +### Backend Composition Roots + +- `ltx2_server.py`: Runtime bootstrap (logging, `RuntimeConfig`, `AppHandler`, `uvicorn`) +- `app_factory.py`: FastAPI app factory (routers, DI init, exception handling) — importable from tests +- `state/deps.py`: FastAPI dependency hook (`get_state_service()` returns shared `AppHandler`; tests override via `set_state_service_for_tests()`) + +### Backend Testing + +- Integration-first using Starlette `TestClient` against real FastAPI app +- **No mocks**: `test_no_mock_usage.py` enforces no `unittest.mock`. Swap services via `ServiceBundle` fakes only. +- Fakes live in `tests/fakes/`; `conftest.py` wires fresh `AppHandler` per test +- Pyright strict mode is also enforced as a test (`test_pyright.py`) + +### Backend Route Domains + +Core: `health`, `settings`, `models`, `generation`, `image_gen`, `queue` +Video modes: `retake`, `ic_lora` +Library/content: `gallery`, `library`, `prompts`, `style_guide`, `contact_sheet`, `enhance_prompt` +Integration: `sync` (Palette cloud sync), `receive_job` (incoming cloud jobs) + +### Adding a Backend Feature + +1. Define request/response models in `api_types.py` +2. Add endpoint in `_routes/.py` delegating to handler +3. Implement logic in `handlers/_handler.py` with lock-aware state transitions +4. If new heavy side effect needed, add service in `services/` with Protocol + real + fake implementations +5. Add integration test in `tests/` using fake services + +## TypeScript Config + +- Strict mode with `noUnusedLocals`, `noUnusedParameters` +- Frontend: ES2020 target, React JSX +- Electron main process: ESNext, compiled to `dist-electron/` +- Preload script must be CommonJS (configured in `vite.config.ts` rollup output) + +## Python Config + +- Python 3.12+ required (`.python-version` pins 3.13), managed with `uv` +- Pyright strict mode (`backend/pyrightconfig.json`) — tests are excluded from pyright +- Dependencies in `backend/pyproject.toml`, lock in `backend/uv.lock` +- PyTorch uses CUDA 12.8 index on Windows/Linux (`tool.uv.sources`) + +## Key File Locations + +- Backend architecture doc: `backend/architecture.md` +- Default app settings schema: `settings.json` +- Electron builder config: `electron-builder.yml` +- Video editor (largest frontend file): `frontend/views/VideoEditor.tsx` +- Project types: `frontend/types/project.ts` +- IPC API surface: `electron/preload.ts` +- Python backend entry: `backend/ltx2_server.py` +- Build/setup scripts: `scripts/` (platform-specific `.sh` and `.ps1` variants) +- Job queue: `backend/state/job_queue.py` +- Queue worker: `backend/handlers/queue_worker.py` +- Job executors: `backend/handlers/job_executors.py` +- Queue routes: `backend/_routes/queue.py` +- Generation hook: `frontend/hooks/use-generation.ts` +- Hero banner video: `public/hero-video.mp4` (2092x480, ~4.36:1, 30fps, 11s, H.264; CSS gradient overlay in `Home.tsx`) diff --git a/README.md b/README.md index d6eec206..8a5d5db9 100644 --- a/README.md +++ b/README.md @@ -19,12 +19,43 @@ LTX Desktop is an open-source desktop app for generating videos with LTX models ## Features -- Text-to-video generation -- Image-to-video generation -- Audio-to-video generation -- Video edit generation (Retake) -- Video Editor Interface -- Video Editing Projects +### Generation +- **Text-to-video** — generate video clips from text prompts +- **Image-to-video** — animate a still image into video +- **Audio-to-video** — drive video generation from an audio track +- **Image generation** — create images with ZIT (local) or fal API +- **Image editing (img2img)** — edit existing images with ZIT Edit +- **Video Retake** — re-generate portions of an existing video +- **IC-LoRA** — identity-consistent generation with LoRA weights +- **Video Extend** — continue generating from the last frame of a video +- **Prompt Enhancement** — AI-powered prompt rewriting (via LTX API or Palette) + +### Batch Generation +- **Batch Builder** — queue multiple generation jobs at once +- **List mode** — add prompts one-by-one with per-job settings +- **Import mode** — bulk-import prompts from CSV, JSON, or plain text files +- **Grid Sweep mode** — combinatorial parameter sweeps (prompts × seeds × models) +- **Timeline import** — import an edited timeline as a batch to re-generate all segments + +### Library & Organization +- **Gallery** — browse, filter, and manage all generated images and videos +- **Prompt Library** — save, tag, and reuse favorite prompts +- **Characters** — store character descriptions for consistent generation +- **Styles** — save and apply visual style presets +- **References** — manage reference images for guided generation +- **Wildcards** — define placeholder tokens that expand to random values + +### Palette Cloud Integration +- **Directors Palette sync** — connect to [Directors Palette](https://directorspal.com) for cloud-synced library content +- **Email/password login** — authenticate directly or via deep link +- **Credit balance & cost tracking** — view remaining credits in the header, see estimated cost on Generate buttons before submitting +- **Automatic credit deduction** — API-slot jobs automatically deduct credits after successful generation +- **Seedance video generation** — generate videos via Seedance 1.5 Pro through the Replicate API + +### Editor & Export +- **Video Editor** — multi-track timeline editor with clips, transitions, and keyframes +- **Video Projects** — save and reopen editing sessions +- **FFmpeg export** — export final videos with configurable codec and quality settings ## Local vs API mode @@ -37,6 +68,27 @@ LTX Desktop is an open-source desktop app for generating videos with LTX models In API-only mode, available resolutions/durations may be limited to what the API supports. +## Custom Video Models + +Directors Desktop supports multiple LTX 2.3 model formats, so you can run on GPUs with less VRAM. + +| Your GPU VRAM | Recommended Format | File Size | +|---------------|-------------------|-----------| +| 32 GB+ | BF16 (auto-downloaded) | ~43 GB | +| 20-31 GB | [FP8 Checkpoint](https://huggingface.co/Lightricks/LTX-Video-2.3-22b-distilled) | ~22 GB | +| 16-19 GB | [GGUF Q5_K](https://huggingface.co/city96/LTX-Video-2.3-22b-0.9.7-dev-gguf) | ~15 GB | +| 10-15 GB | [GGUF Q4_K](https://huggingface.co/city96/LTX-Video-2.3-22b-0.9.7-dev-gguf) | ~12 GB | + +### Setup + +1. Download the model file for your GPU from the links above +2. Open **Settings → Models** and set your model folder +3. If using GGUF or NF4, also download the [distilled LoRA](https://huggingface.co/Lightricks/LTX-Video-2.3-22b-distilled) +4. Select your model from the dropdown +5. Generate! + +The app also has a built-in **Model Guide** (Settings → Models → Open Model Guide) that detects your GPU and recommends the best format automatically. + ## System requirements ### Windows (local generation) @@ -96,10 +148,22 @@ Used for Z Image Turbo text-to-image generation in API mode. When enabled, image Create an API key in the [fal dashboard](https://fal.ai/dashboard/keys). +### Replicate API key (optional) + +Used for Seedance 1.5 Pro video generation. When enabled, video generation requests are sent to Replicate. + +Create an API key in the [Replicate dashboard](https://replicate.com/account/api-tokens). + ### Gemini API key (optional) Used for AI prompt suggestions. When enabled, prompt context and frames may be sent to Google Gemini. +### Directors Palette account (optional) + +Connect to [Directors Palette](https://directorspal.com) to sync library content (characters, styles, references), use cloud-based prompt enhancement, and track credit usage. Sign in via email/password in **Settings > Palette Connection**. + +Credits are consumed when generating via API-backed models (cloud video, Seedance, cloud image). Local GPU generations are free. Credit balance and per-generation costs are displayed in the UI. + ## Architecture LTX Desktop is split into three main layers: @@ -113,6 +177,7 @@ LTX Desktop is split into three main layers: - **Backend (`backend/`)**: Python + FastAPI local server. - Orchestrates generation, model downloads, and GPU execution. - Calls external APIs only when API-backed features are used. + - Output files follow the naming convention `dd_{model}_{prompt_slug}_{timestamp}.{ext}` for easy identification. ```mermaid graph TD diff --git a/backend/_routes/batch.py b/backend/_routes/batch.py new file mode 100644 index 00000000..55d3637e --- /dev/null +++ b/backend/_routes/batch.py @@ -0,0 +1,44 @@ +"""Route handlers for /api/queue/batch/* — batch generation endpoints.""" + +from __future__ import annotations + +from fastapi import APIRouter, Depends + +from api_types import BatchSubmitRequest, BatchSubmitResponse, BatchStatusResponse +from state import get_state_service +from app_handler import AppHandler + +router = APIRouter(prefix="/api/queue", tags=["batch"]) + + +@router.post("/submit-batch", response_model=BatchSubmitResponse) +def route_batch_submit( + req: BatchSubmitRequest, + handler: AppHandler = Depends(get_state_service), +) -> BatchSubmitResponse: + return handler.batch.submit_batch(req, handler.job_queue) + + +@router.get("/batch/{batch_id}/status", response_model=BatchStatusResponse) +def route_batch_status( + batch_id: str, + handler: AppHandler = Depends(get_state_service), +) -> BatchStatusResponse: + return handler.batch.get_batch_status(batch_id, handler.job_queue) + + +@router.post("/batch/{batch_id}/cancel") +def route_batch_cancel( + batch_id: str, + handler: AppHandler = Depends(get_state_service), +) -> dict[str, str]: + handler.batch.cancel_batch(batch_id, handler.job_queue) + return {"status": "cancelled"} + + +@router.post("/batch/{batch_id}/retry-failed", response_model=BatchSubmitResponse) +def route_batch_retry( + batch_id: str, + handler: AppHandler = Depends(get_state_service), +) -> BatchSubmitResponse: + return handler.batch.retry_failed(batch_id, handler.job_queue) diff --git a/backend/_routes/contact_sheet.py b/backend/_routes/contact_sheet.py new file mode 100644 index 00000000..444ef0c2 --- /dev/null +++ b/backend/_routes/contact_sheet.py @@ -0,0 +1,18 @@ +"""Route for contact sheet generation.""" +from __future__ import annotations + +from fastapi import APIRouter, Depends + +from api_types import GenerateContactSheetRequest, GenerateContactSheetResponse +from app_handler import AppHandler +from state import get_state_service + +router = APIRouter(prefix="/api/contact-sheet", tags=["contact-sheet"]) + + +@router.post("/generate", response_model=GenerateContactSheetResponse) +def route_generate_contact_sheet( + req: GenerateContactSheetRequest, + handler: AppHandler = Depends(get_state_service), +) -> GenerateContactSheetResponse: + return handler.contact_sheet.generate(req) diff --git a/backend/_routes/enhance_prompt.py b/backend/_routes/enhance_prompt.py new file mode 100644 index 00000000..53a5f80a --- /dev/null +++ b/backend/_routes/enhance_prompt.py @@ -0,0 +1,27 @@ +"""Prompt enhancement route.""" +from __future__ import annotations + +from fastapi import APIRouter, Depends +from pydantic import BaseModel + +from app_handler import AppHandler +from state import get_state_service + +router = APIRouter(tags=["enhance"]) + + +class EnhancePromptRequest(BaseModel): + prompt: str = "" + mode: str = "text-to-video" + model: str = "ltx-fast" + imagePath: str | None = None + + +@router.post("/api/enhance-prompt") +def enhance_prompt( + req: EnhancePromptRequest, + handler: AppHandler = Depends(get_state_service), +): + return handler.enhance_prompt.enhance( + req.prompt, req.mode, req.model, image_path=req.imagePath, + ) diff --git a/backend/_routes/gallery.py b/backend/_routes/gallery.py new file mode 100644 index 00000000..e35af9d3 --- /dev/null +++ b/backend/_routes/gallery.py @@ -0,0 +1,49 @@ +"""Route handlers for /api/gallery/local endpoints.""" + +from __future__ import annotations + +from pathlib import Path + +from fastapi import APIRouter, Depends +from fastapi.responses import FileResponse + +from api_types import GalleryListResponse, StatusResponse +from app_handler import AppHandler +from _routes._errors import HTTPError +from state import get_state_service + +router = APIRouter(prefix="/api/gallery", tags=["gallery"]) + + +@router.get("/local", response_model=GalleryListResponse) +def route_list_local_assets( + page: int = 1, + per_page: int = 50, + type: str = "all", + handler: AppHandler = Depends(get_state_service), +) -> GalleryListResponse: + return handler.gallery.list_local_assets( + page=page, + per_page=per_page, + asset_type=type, + ) + + +@router.get("/local/file/{filename:path}") +def route_serve_local_file( + filename: str, + handler: AppHandler = Depends(get_state_service), +) -> FileResponse: + file_path = handler.config.outputs_dir / Path(filename).name + if not file_path.is_file(): + raise HTTPError(404, f"File not found: {filename}") + return FileResponse(path=str(file_path)) + + +@router.delete("/local/{asset_id}", response_model=StatusResponse) +def route_delete_local_asset( + asset_id: str, + handler: AppHandler = Depends(get_state_service), +) -> StatusResponse: + handler.gallery.delete_local_asset(asset_id) + return StatusResponse(status="ok") diff --git a/backend/_routes/generation.py b/backend/_routes/generation.py index ac525f5c..9d3b8a8c 100644 --- a/backend/_routes/generation.py +++ b/backend/_routes/generation.py @@ -6,6 +6,8 @@ from api_types import ( CancelResponse, + GenerateLongVideoRequest, + GenerateLongVideoResponse, GenerateVideoRequest, GenerateVideoResponse, GenerationProgressResponse, @@ -25,6 +27,36 @@ def route_generate( return handler.video_generation.generate(req) +@router.post("/generate/long", response_model=GenerateLongVideoResponse) +def route_generate_long( + req: GenerateLongVideoRequest, + handler: AppHandler = Depends(get_state_service), +) -> GenerateLongVideoResponse: + """POST /api/generate/long — chain-extend I2V to target duration.""" + try: + video_path = handler.video_generation.generate_long_video( + prompt=req.prompt, + image_path=req.imagePath, + target_duration=req.targetDuration, + resolution=req.resolution, + aspect_ratio=req.aspectRatio, + fps=req.fps, + segment_duration=req.segmentDuration, + camera_motion=req.cameraMotion, + lora_path=req.loraPath, + lora_weight=req.loraWeight, + ) + segments = max(1, (req.targetDuration + req.segmentDuration - 1) // req.segmentDuration) + return GenerateLongVideoResponse( + status="complete", video_path=video_path, + segments=segments, total_duration=req.targetDuration, + ) + except Exception as e: + if "cancelled" in str(e).lower(): + return GenerateLongVideoResponse(status="cancelled") + raise + + @router.post("/generate/cancel", response_model=CancelResponse) def route_generate_cancel(handler: AppHandler = Depends(get_state_service)) -> CancelResponse: """POST /api/generate/cancel.""" diff --git a/backend/_routes/library.py b/backend/_routes/library.py new file mode 100644 index 00000000..9ee0ed15 --- /dev/null +++ b/backend/_routes/library.py @@ -0,0 +1,179 @@ +"""Route handlers for /api/library/* (characters, styles, references).""" + +from __future__ import annotations + +from fastapi import APIRouter, Depends, Query + +from api_types import ( + CharacterCreate, + CharacterListResponse, + CharacterResponse, + CharacterUpdate, + LibraryReferenceCategory, + ReferenceCreate, + ReferenceListResponse, + ReferenceResponse, + StatusResponse, + StyleCreate, + StyleListResponse, + StyleResponse, +) +from app_handler import AppHandler +from state import get_state_service +from state.library_store import Character, Reference, Style + +router = APIRouter(prefix="/api/library", tags=["library"]) + + +def _character_response(c: Character) -> CharacterResponse: + return CharacterResponse( + id=c.id, + name=c.name, + role=c.role, + description=c.description, + reference_image_paths=c.reference_image_paths, + created_at=c.created_at, + ) + + +def _style_response(s: Style) -> StyleResponse: + return StyleResponse( + id=s.id, + name=s.name, + description=s.description, + reference_image_path=s.reference_image_path, + created_at=s.created_at, + ) + + +def _reference_response(r: Reference) -> ReferenceResponse: + return ReferenceResponse( + id=r.id, + name=r.name, + category=r.category, + image_path=r.image_path, + created_at=r.created_at, + ) + + +# ------------------------------------------------------------------ +# Characters +# ------------------------------------------------------------------ + + +@router.get("/characters", response_model=CharacterListResponse) +def route_list_characters( + handler: AppHandler = Depends(get_state_service), +) -> CharacterListResponse: + items = handler.library.list_characters() + return CharacterListResponse(characters=[_character_response(c) for c in items]) + + +@router.post("/characters", response_model=CharacterResponse) +def route_create_character( + req: CharacterCreate, + handler: AppHandler = Depends(get_state_service), +) -> CharacterResponse: + result = handler.library.create_character( + name=req.name, + role=req.role, + description=req.description, + reference_image_paths=req.reference_image_paths, + ) + return _character_response(result) + + +@router.put("/characters/{character_id}", response_model=CharacterResponse) +def route_update_character( + character_id: str, + req: CharacterUpdate, + handler: AppHandler = Depends(get_state_service), +) -> CharacterResponse: + result = handler.library.update_character( + character_id, + name=req.name, + role=req.role, + description=req.description, + reference_image_paths=req.reference_image_paths, + ) + return _character_response(result) + + +@router.delete("/characters/{character_id}", response_model=StatusResponse) +def route_delete_character( + character_id: str, + handler: AppHandler = Depends(get_state_service), +) -> StatusResponse: + handler.library.delete_character(character_id) + return StatusResponse(status="ok") + + +# ------------------------------------------------------------------ +# Styles +# ------------------------------------------------------------------ + + +@router.get("/styles", response_model=StyleListResponse) +def route_list_styles( + handler: AppHandler = Depends(get_state_service), +) -> StyleListResponse: + items = handler.library.list_styles() + return StyleListResponse(styles=[_style_response(s) for s in items]) + + +@router.post("/styles", response_model=StyleResponse) +def route_create_style( + req: StyleCreate, + handler: AppHandler = Depends(get_state_service), +) -> StyleResponse: + result = handler.library.create_style( + name=req.name, + description=req.description, + reference_image_path=req.reference_image_path, + ) + return _style_response(result) + + +@router.delete("/styles/{style_id}", response_model=StatusResponse) +def route_delete_style( + style_id: str, + handler: AppHandler = Depends(get_state_service), +) -> StatusResponse: + handler.library.delete_style(style_id) + return StatusResponse(status="ok") + + +# ------------------------------------------------------------------ +# References +# ------------------------------------------------------------------ + + +@router.get("/references", response_model=ReferenceListResponse) +def route_list_references( + category: LibraryReferenceCategory | None = Query(default=None), + handler: AppHandler = Depends(get_state_service), +) -> ReferenceListResponse: + items = handler.library.list_references(category) + return ReferenceListResponse(references=[_reference_response(r) for r in items]) + + +@router.post("/references", response_model=ReferenceResponse) +def route_create_reference( + req: ReferenceCreate, + handler: AppHandler = Depends(get_state_service), +) -> ReferenceResponse: + result = handler.library.create_reference( + name=req.name, + category=req.category, + image_path=req.image_path, + ) + return _reference_response(result) + + +@router.delete("/references/{reference_id}", response_model=StatusResponse) +def route_delete_reference( + reference_id: str, + handler: AppHandler = Depends(get_state_service), +) -> StatusResponse: + handler.library.delete_reference(reference_id) + return StatusResponse(status="ok") diff --git a/backend/_routes/lora.py b/backend/_routes/lora.py new file mode 100644 index 00000000..fe1738eb --- /dev/null +++ b/backend/_routes/lora.py @@ -0,0 +1,156 @@ +"""Routes for /api/lora — CivitAI search, download, and local library management.""" + +from __future__ import annotations + +from typing import Any + +from fastapi import APIRouter, Depends +from pydantic import BaseModel + +from app_handler import AppHandler +from _routes._errors import HTTPError +from state import get_state_service + +router = APIRouter(prefix="/api/lora", tags=["lora"]) + + +# ── Request/Response Models ───────────────────────────────────────── + + +class LoraSearchRequest(BaseModel): + query: str = "" + baseModel: str = "" + sort: str = "Most Downloaded" + limit: int = 20 + page: int = 1 + nsfw: bool = False + + +class LoraDownloadRequest(BaseModel): + downloadUrl: str + fileName: str + name: str + thumbnailUrl: str = "" + triggerPhrase: str = "" + baseModel: str = "" + civitaiModelId: int | None = None + civitaiVersionId: int | None = None + description: str = "" + + +class LoraImportRequest(BaseModel): + filePath: str + name: str = "" + triggerPhrase: str = "" + thumbnailPath: str = "" + + +# ── Search ────────────────────────────────────────────────────────── + + +@router.post("/search") +def route_search_civitai( + body: LoraSearchRequest, + handler: AppHandler = Depends(get_state_service), +) -> dict[str, Any]: + try: + return handler.lora.search_civitai( + query=body.query, + base_model=body.baseModel, + sort=body.sort, + limit=body.limit, + page=body.page, + nsfw=body.nsfw, + ) + except Exception as exc: + raise HTTPError(502, f"CivitAI search failed: {exc}") from exc + + +# ── Download ──────────────────────────────────────────────────────── + + +@router.post("/download") +def route_download_lora( + body: LoraDownloadRequest, + handler: AppHandler = Depends(get_state_service), +) -> dict[str, Any]: + try: + entry = handler.lora.download_lora( + download_url=body.downloadUrl, + file_name=body.fileName, + name=body.name, + thumbnail_url=body.thumbnailUrl, + trigger_phrase=body.triggerPhrase, + base_model=body.baseModel, + civitai_model_id=body.civitaiModelId, + civitai_version_id=body.civitaiVersionId, + description=body.description, + ) + return {"status": "ok", "entry": _entry_to_dict(entry)} + except Exception as exc: + raise HTTPError(500, f"Download failed: {exc}") from exc + + +# ── Library ───────────────────────────────────────────────────────── + + +@router.get("/library") +def route_list_library( + handler: AppHandler = Depends(get_state_service), +) -> dict[str, Any]: + return {"entries": handler.lora.list_library()} + + +@router.delete("/library/{lora_id}") +def route_delete_lora( + lora_id: str, + handler: AppHandler = Depends(get_state_service), +) -> dict[str, Any]: + deleted = handler.lora.delete_lora(lora_id) + if not deleted: + raise HTTPError(404, f"LoRA not found: {lora_id}") + return {"status": "ok"} + + +@router.post("/import") +def route_import_lora( + body: LoraImportRequest, + handler: AppHandler = Depends(get_state_service), +) -> dict[str, Any]: + try: + entry = handler.lora.import_local_lora( + file_path=body.filePath, + name=body.name, + trigger_phrase=body.triggerPhrase, + thumbnail_path=body.thumbnailPath, + ) + return {"status": "ok", "entry": _entry_to_dict(entry)} + except FileNotFoundError as exc: + raise HTTPError(404, str(exc)) from exc + + +# ── Thumbnail serving ─────────────────────────────────────────────── + + +@router.get("/thumbnail/{lora_id}") +def route_serve_thumbnail( + lora_id: str, + handler: AppHandler = Depends(get_state_service), +) -> Any: + from pathlib import Path + from fastapi.responses import FileResponse + + entry = handler.lora.get_entry(lora_id) + if entry is None: + raise HTTPError(404, "LoRA not found") + + thumb = entry.get("thumbnail_url", "") + if thumb and Path(thumb).exists(): + return FileResponse(path=thumb) + + raise HTTPError(404, "No thumbnail available") + + +def _entry_to_dict(entry: Any) -> dict[str, Any]: + from dataclasses import asdict + return asdict(entry) diff --git a/backend/_routes/models.py b/backend/_routes/models.py index f5de347f..2386374a 100644 --- a/backend/_routes/models.py +++ b/backend/_routes/models.py @@ -12,7 +12,10 @@ ModelDownloadStartResponse, ModelInfo, ModelsStatusResponse, + SelectModelRequest, TextEncoderDownloadResponse, + VideoModelGuideResponse, + VideoModelScanResponse, ) from _routes._errors import HTTPError from state import get_state_service @@ -74,3 +77,23 @@ def route_text_encoder_download(handler: AppHandler = Depends(get_state_service) return TextEncoderDownloadResponse(status="started", message="Text encoder download started") raise HTTPError(400, "Failed to start download") + + +@router.get("/models/video/scan", response_model=VideoModelScanResponse) +def route_video_model_scan(handler: AppHandler = Depends(get_state_service)) -> VideoModelScanResponse: + return handler.models.scan_video_models() + + +@router.post("/models/video/select") +def route_video_model_select( + req: SelectModelRequest, + handler: AppHandler = Depends(get_state_service), +) -> dict[str, str]: + handler.models.select_video_model(req.model) + handler.settings.save_settings() + return {"status": "ok"} + + +@router.get("/models/video/guide", response_model=VideoModelGuideResponse) +def route_video_model_guide(handler: AppHandler = Depends(get_state_service)) -> VideoModelGuideResponse: + return handler.models.video_model_guide() diff --git a/backend/_routes/prompts.py b/backend/_routes/prompts.py new file mode 100644 index 00000000..18579501 --- /dev/null +++ b/backend/_routes/prompts.py @@ -0,0 +1,155 @@ +"""Route handlers for prompt library and wildcard endpoints.""" + +from __future__ import annotations + +from fastapi import APIRouter, Depends + +from _routes._errors import HTTPError +from api_types import ( + CreateWildcardRequest, + ExpandWildcardsRequest, + ExpandWildcardsResponse, + IncrementUsageResponse, + PromptListResponse, + SavedPromptResponse, + SavePromptRequest, + StatusResponse, + UpdateWildcardRequest, + WildcardListResponse, + WildcardResponse, +) +from app_handler import AppHandler +from state import get_state_service + +router = APIRouter(prefix="/api", tags=["prompts"]) + + +# ------------------------------------------------------------------ +# Prompts +# ------------------------------------------------------------------ + + +@router.get("/prompts", response_model=PromptListResponse) +def route_list_prompts( + search: str | None = None, + tag: str | None = None, + sort_by: str | None = None, + handler: AppHandler = Depends(get_state_service), +) -> PromptListResponse: + prompts = handler.prompts.list_prompts(search=search, tag=tag, sort_by=sort_by) + return PromptListResponse( + prompts=[ + SavedPromptResponse( + id=p.id, + text=p.text, + tags=p.tags, + category=p.category, + used_count=p.used_count, + created_at=p.created_at, + last_used_at=p.last_used_at, + ) + for p in prompts + ] + ) + + +@router.post("/prompts", response_model=SavedPromptResponse) +def route_save_prompt( + req: SavePromptRequest, + handler: AppHandler = Depends(get_state_service), +) -> SavedPromptResponse: + p = handler.prompts.save_prompt(text=req.text, tags=req.tags, category=req.category) + return SavedPromptResponse( + id=p.id, + text=p.text, + tags=p.tags, + category=p.category, + used_count=p.used_count, + created_at=p.created_at, + last_used_at=p.last_used_at, + ) + + +@router.delete("/prompts/{prompt_id}", response_model=StatusResponse) +def route_delete_prompt( + prompt_id: str, + handler: AppHandler = Depends(get_state_service), +) -> StatusResponse: + deleted = handler.prompts.delete_prompt(prompt_id) + if not deleted: + raise HTTPError(404, f"Prompt {prompt_id} not found") + return StatusResponse(status="ok") + + +@router.post("/prompts/{prompt_id}/usage", response_model=IncrementUsageResponse) +def route_increment_usage( + prompt_id: str, + handler: AppHandler = Depends(get_state_service), +) -> IncrementUsageResponse: + result = handler.prompts.increment_usage(prompt_id) + if result is None: + raise HTTPError(404, f"Prompt {prompt_id} not found") + return IncrementUsageResponse(status="ok", used_count=result.used_count) + + +# ------------------------------------------------------------------ +# Wildcards +# ------------------------------------------------------------------ + + +@router.get("/wildcards", response_model=WildcardListResponse) +def route_list_wildcards( + handler: AppHandler = Depends(get_state_service), +) -> WildcardListResponse: + wildcards = handler.prompts.list_wildcards() + return WildcardListResponse( + wildcards=[ + WildcardResponse( + id=w.id, name=w.name, values=w.values, created_at=w.created_at, + ) + for w in wildcards + ] + ) + + +@router.post("/wildcards", response_model=WildcardResponse) +def route_create_wildcard( + req: CreateWildcardRequest, + handler: AppHandler = Depends(get_state_service), +) -> WildcardResponse: + w = handler.prompts.create_wildcard(name=req.name, values=req.values) + return WildcardResponse(id=w.id, name=w.name, values=w.values, created_at=w.created_at) + + +@router.put("/wildcards/{wildcard_id}", response_model=WildcardResponse) +def route_update_wildcard( + wildcard_id: str, + req: UpdateWildcardRequest, + handler: AppHandler = Depends(get_state_service), +) -> WildcardResponse: + w = handler.prompts.update_wildcard(wildcard_id, values=req.values) + if w is None: + raise HTTPError(404, f"Wildcard {wildcard_id} not found") + return WildcardResponse(id=w.id, name=w.name, values=w.values, created_at=w.created_at) + + +@router.delete("/wildcards/{wildcard_id}", response_model=StatusResponse) +def route_delete_wildcard( + wildcard_id: str, + handler: AppHandler = Depends(get_state_service), +) -> StatusResponse: + deleted = handler.prompts.delete_wildcard(wildcard_id) + if not deleted: + raise HTTPError(404, f"Wildcard {wildcard_id} not found") + return StatusResponse(status="ok") + + +@router.post("/wildcards/expand", response_model=ExpandWildcardsResponse) +def route_expand_wildcards( + req: ExpandWildcardsRequest, + handler: AppHandler = Depends(get_state_service), +) -> ExpandWildcardsResponse: + expanded = handler.prompts.expand_wildcards( + prompt=req.prompt, mode=req.mode, count=req.count, + ) + return ExpandWildcardsResponse(expanded=expanded) diff --git a/backend/_routes/queue.py b/backend/_routes/queue.py new file mode 100644 index 00000000..006062d5 --- /dev/null +++ b/backend/_routes/queue.py @@ -0,0 +1,63 @@ +"""Route handlers for /api/queue/*.""" + +from __future__ import annotations + +from fastapi import APIRouter, Depends + +from api_types import QueueSubmitRequest, QueueSubmitResponse, QueueStatusResponse, QueueJobResponse +from state import get_state_service +from app_handler import AppHandler + +router = APIRouter(prefix="/api/queue", tags=["queue"]) + + +@router.post("/submit", response_model=QueueSubmitResponse) +def route_queue_submit( + req: QueueSubmitRequest, + handler: AppHandler = Depends(get_state_service), +) -> QueueSubmitResponse: + slot = handler.determine_slot(req.model) + job = handler.job_queue.submit( + job_type=req.type, + model=req.model, + params={k: v for k, v in req.params.items()}, + slot=slot, + ) + return QueueSubmitResponse(id=job.id, status=job.status) + + +@router.get("/status", response_model=QueueStatusResponse) +def route_queue_status( + handler: AppHandler = Depends(get_state_service), +) -> QueueStatusResponse: + jobs = handler.job_queue.get_all_jobs() + return QueueStatusResponse(jobs=[ + QueueJobResponse( + id=j.id, type=j.type, model=j.model, params={k: v for k, v in j.params.items()}, + status=j.status, slot=j.slot, progress=j.progress, phase=j.phase, + result_paths=j.result_paths, error=j.error, created_at=j.created_at, + batch_id=j.batch_id, batch_index=j.batch_index, tags=j.tags, + ) + for j in jobs + ]) + + +@router.post("/cancel/{job_id}", response_model=QueueSubmitResponse) +def route_queue_cancel( + job_id: str, + handler: AppHandler = Depends(get_state_service), +) -> QueueSubmitResponse: + handler.job_queue.cancel_job(job_id) + # Also cancel via GenerationHandler so running pipelines stop + handler.generation.cancel_generation() + job = handler.job_queue.get_job(job_id) + status = job.status if job else "not_found" + return QueueSubmitResponse(id=job_id, status=status) + + +@router.post("/clear", response_model=QueueStatusResponse) +def route_queue_clear( + handler: AppHandler = Depends(get_state_service), +) -> QueueStatusResponse: + handler.job_queue.clear_finished() + return route_queue_status(handler) diff --git a/backend/_routes/receive_job.py b/backend/_routes/receive_job.py new file mode 100644 index 00000000..918c6fc3 --- /dev/null +++ b/backend/_routes/receive_job.py @@ -0,0 +1,18 @@ +"""Route for receiving jobs from Director's Palette.""" +from __future__ import annotations + +from fastapi import APIRouter, Depends + +from api_types import ReceiveJobRequest, ReceiveJobResponse +from app_handler import AppHandler +from state import get_state_service + +router = APIRouter(prefix="/api/sync", tags=["sync"]) + + +@router.post("/receive-job", response_model=ReceiveJobResponse) +def route_receive_job( + req: ReceiveJobRequest, + handler: AppHandler = Depends(get_state_service), +) -> ReceiveJobResponse: + return handler.receive_job_handler.receive_job(req) diff --git a/backend/_routes/settings.py b/backend/_routes/settings.py index b20c404b..4dbfbcac 100644 --- a/backend/_routes/settings.py +++ b/backend/_routes/settings.py @@ -34,4 +34,8 @@ def route_post_settings( ", ".join(sorted(changed_roots)) if changed_roots else "none", ) + # Sync CivitAI API key to LoRA handler + if "civitai_api_key" in changed_roots: + handler.lora.set_api_key(_after.civitai_api_key) + return StatusResponse(status="ok") diff --git a/backend/_routes/style_guide.py b/backend/_routes/style_guide.py new file mode 100644 index 00000000..f35c52b5 --- /dev/null +++ b/backend/_routes/style_guide.py @@ -0,0 +1,18 @@ +"""Route for style guide grid generation.""" +from __future__ import annotations + +from fastapi import APIRouter, Depends + +from api_types import GenerateStyleGuideRequest, GenerateStyleGuideResponse +from app_handler import AppHandler +from state import get_state_service + +router = APIRouter(prefix="/api/style-guide", tags=["style-guide"]) + + +@router.post("/generate", response_model=GenerateStyleGuideResponse) +def route_generate_style_guide( + req: GenerateStyleGuideRequest, + handler: AppHandler = Depends(get_state_service), +) -> GenerateStyleGuideResponse: + return handler.style_guide.generate(req) diff --git a/backend/_routes/sync.py b/backend/_routes/sync.py new file mode 100644 index 00000000..9e608d2a --- /dev/null +++ b/backend/_routes/sync.py @@ -0,0 +1,127 @@ +"""Palette sync routes.""" +from __future__ import annotations + +from typing import Any + +from fastapi import APIRouter, Depends, Query +from pydantic import BaseModel + +from app_handler import AppHandler +from state import get_state_service + +router = APIRouter(prefix="/api/sync", tags=["sync"]) + + +class ConnectRequest(BaseModel): + token: str + + +class LoginRequest(BaseModel): + email: str + password: str + + +class EnhancePromptRequest(BaseModel): + prompt: str + level: str = "2x" + + +@router.get("/status") +def sync_status(handler: AppHandler = Depends(get_state_service)) -> dict[str, Any]: + return handler.sync.get_status() + + +@router.post("/connect") +def sync_connect(body: ConnectRequest, handler: AppHandler = Depends(get_state_service)) -> dict[str, Any]: + result = handler.sync.connect(body.token) + if result.get("connected"): + handler.settings.save_settings() + return result + + +@router.post("/login") +def sync_login(body: LoginRequest, handler: AppHandler = Depends(get_state_service)) -> dict[str, Any]: + result = handler.sync.login(body.email, body.password) + if result.get("connected"): + handler.settings.save_settings() + return result + + +@router.post("/disconnect") +def sync_disconnect(handler: AppHandler = Depends(get_state_service)) -> dict[str, Any]: + result = handler.sync.disconnect() + handler.settings.save_settings() + return result + + +class CheckCreditsRequest(BaseModel): + generation_type: str + count: int = 1 + + +class DeductCreditsRequest(BaseModel): + generation_type: str + count: int = 1 + metadata: dict[str, Any] | None = None + + +@router.get("/credits") +def sync_credits(handler: AppHandler = Depends(get_state_service)) -> dict[str, Any]: + return handler.sync.get_credits() + + +@router.post("/credits/check") +def sync_check_credits( + body: CheckCreditsRequest, + handler: AppHandler = Depends(get_state_service), +) -> dict[str, Any]: + return handler.sync.check_credits(body.generation_type, body.count) + + +@router.post("/credits/deduct") +def sync_deduct_credits( + body: DeductCreditsRequest, + handler: AppHandler = Depends(get_state_service), +) -> dict[str, Any]: + return handler.sync.deduct_credits(body.generation_type, body.count, body.metadata) + + +@router.get("/gallery") +def sync_gallery( + page: int = Query(default=1, ge=1), + per_page: int = Query(default=50, ge=1, le=100), + type: str = Query(default="all"), + handler: AppHandler = Depends(get_state_service), +) -> dict[str, Any]: + return handler.sync.list_gallery(page=page, per_page=per_page, asset_type=type) + + +@router.get("/library/characters") +def sync_characters(handler: AppHandler = Depends(get_state_service)) -> dict[str, Any]: + return handler.sync.list_characters() + + +@router.get("/library/styles") +def sync_styles(handler: AppHandler = Depends(get_state_service)) -> dict[str, Any]: + return handler.sync.list_styles() + + +@router.get("/library/references") +def sync_references( + category: str | None = Query(default=None), + handler: AppHandler = Depends(get_state_service), +) -> dict[str, Any]: + return handler.sync.list_references(category=category) + + +@router.post("/prompt/enhance") +def sync_enhance_prompt( + body: EnhancePromptRequest, + handler: AppHandler = Depends(get_state_service), +) -> dict[str, Any]: + return handler.sync.enhance_prompt(prompt=body.prompt, level=body.level) + + +@router.post("/library/sync-loras") +def sync_loras(handler: AppHandler = Depends(get_state_service)) -> dict[str, Any]: + return handler.sync.sync_loras() diff --git a/backend/api_types.py b/backend/api_types.py index 5bc87108..6794574f 100644 --- a/backend/api_types.py +++ b/backend/api_types.py @@ -45,7 +45,7 @@ class ModelDownloadState(TypedDict): files_completed: int total_files: int error: str | None - speed_mbps: int + speed_bytes_per_sec: float JsonObject: TypeAlias = dict[str, object] @@ -157,7 +157,7 @@ class DownloadProgressResponse(BaseModel): filesCompleted: int totalFiles: int error: str | None - speedMbps: int + speedBytesPerSec: float class IcLoraModel(BaseModel): @@ -237,6 +237,58 @@ class ErrorResponse(BaseModel): message: str | None = None +class QueueJobResponse(BaseModel): + id: str + type: str + model: str + params: dict[str, object] = {} + status: str + slot: str + progress: int + phase: str + result_paths: list[str] = [] + error: str | None = None + created_at: str = "" + batch_id: str | None = None + batch_index: int = 0 + tags: list[str] = [] + + +class QueueStatusResponse(BaseModel): + jobs: list[QueueJobResponse] + + +class QueueSubmitResponse(BaseModel): + id: str + status: str + + +# ============================================================ +# Gallery Models +# ============================================================ + +GalleryAssetType = Literal["image", "video"] + + +class GalleryAsset(BaseModel): + id: str + filename: str + path: str + url: str + type: GalleryAssetType + size_bytes: int + created_at: str + model_name: str | None = None + + +class GalleryListResponse(BaseModel): + items: list[GalleryAsset] + total: int + page: int + per_page: int + total_pages: int + + # ============================================================ # Request Models # ============================================================ @@ -253,7 +305,30 @@ class GenerateVideoRequest(BaseModel): audio: str = "false" imagePath: str | None = None audioPath: str | None = None + lastFramePath: str | None = None + aspectRatio: Literal["16:9", "9:16"] = "16:9" + loraPath: str | None = None + loraWeight: float = 1.0 + + +class GenerateLongVideoRequest(BaseModel): + prompt: NonEmptyPrompt + imagePath: str + targetDuration: int = 20 + resolution: str = "512p" aspectRatio: Literal["16:9", "9:16"] = "16:9" + fps: int = 24 + segmentDuration: int = 4 + cameraMotion: VideoCameraMotion = "none" + loraPath: str | None = None + loraWeight: float = 1.0 + + +class GenerateLongVideoResponse(BaseModel): + status: str + video_path: str | None = None + segments: int = 0 + total_duration: int = 0 class GenerateImageRequest(BaseModel): @@ -262,6 +337,89 @@ class GenerateImageRequest(BaseModel): height: int = 1024 numSteps: int = 4 numImages: int = 1 + loraPath: str | None = None + loraWeight: float = 1.0 + sourceImagePath: str | None = None + strength: float = 0.65 + + +class QueueSubmitRequest(BaseModel): + type: Literal["video", "image", "long_video"] + model: str + params: dict[str, object] = {} + + +# --- Batch Generation Types --- + + +class BatchJobItem(BaseModel): + type: Literal["video", "image"] + model: str + params: dict[str, object] = {} + + +class SweepAxis(BaseModel): + param: str + values: list[object] + mode: Literal["replace", "search_replace"] = "replace" + search: str | None = None + + +class SweepDefinition(BaseModel): + base_type: Literal["video", "image"] + base_model: str + base_params: dict[str, object] = {} + axes: list[SweepAxis] + + +class PipelineStep(BaseModel): + type: Literal["video", "image"] + model: str + params: dict[str, object] = {} + auto_prompt: bool = False + + +class PipelineDefinition(BaseModel): + steps: list[PipelineStep] + + +class BatchSubmitRequest(BaseModel): + mode: Literal["list", "sweep", "pipeline"] + target: Literal["local", "cloud"] + jobs: list[BatchJobItem] | None = None + sweep: SweepDefinition | None = None + pipeline: PipelineDefinition | None = None + + +class BatchSubmitResponse(BaseModel): + batch_id: str + job_ids: list[str] + total_jobs: int + + +class BatchReport(BaseModel): + batch_id: str + total: int + succeeded: int + failed: int + cancelled: int + duration_seconds: float + avg_job_seconds: float + result_paths: list[str] + failed_indices: list[int] + sweep_axes: list[str] | None = None + + +class BatchStatusResponse(BaseModel): + batch_id: str + total: int + completed: int + failed: int + running: int + queued: int + cancelled: int = 0 + jobs: list[QueueJobResponse] + report: BatchReport | None = None class ModelDownloadRequest(BaseModel): @@ -321,3 +479,241 @@ class IcLoraGenerateRequest(BaseModel): cfg_guidance_scale: float = 1.0 negative_prompt: str = "" images: list[IcLoraImageInput] = Field(default_factory=_default_ic_lora_images) + + +# ============================================================ +# Receive Job (from Palette) +# ============================================================ + + +class ReceiveJobSettings(BaseModel): + resolution: str = "512p" + duration: str = "2" + fps: str = "24" + aspect_ratio: Literal["16:9", "9:16"] = "16:9" + + +class ReceiveJobRequest(BaseModel): + prompt: NonEmptyPrompt + model: str = "ltx-fast" + settings: ReceiveJobSettings = Field(default_factory=ReceiveJobSettings) + character_id: str | None = None + first_frame_url: str | None = None + last_frame_url: str | None = None + priority: int = 0 + + +class ReceiveJobResponse(BaseModel): + id: str + status: str + + +# ============================================================ +# Contact Sheet +# ============================================================ + + +class GenerateContactSheetRequest(BaseModel): + reference_image_path: str + subject_description: NonEmptyPrompt + style: str = "" + + +class GenerateContactSheetResponse(BaseModel): + job_ids: list[str] + + +# ============================================================ +# Style Guide Grid +# ============================================================ + + +class GenerateStyleGuideRequest(BaseModel): + style_name: NonEmptyPrompt + style_description: str = "" + reference_image_path: str | None = None + + +class GenerateStyleGuideResponse(BaseModel): + job_ids: list[str] + + +# ============================================================ +# Library Models (Characters, Styles, References) +# ============================================================ + +LibraryReferenceCategory = Literal["people", "places", "props", "other"] + + +class CharacterCreate(BaseModel): + name: str + role: str = "" + description: str = "" + reference_image_paths: list[str] = Field(default_factory=list) + + +class CharacterUpdate(BaseModel): + name: str | None = None + role: str | None = None + description: str | None = None + reference_image_paths: list[str] | None = None + + +class CharacterResponse(BaseModel): + id: str + name: str + role: str + description: str + reference_image_paths: list[str] + created_at: str + + +class CharacterListResponse(BaseModel): + characters: list[CharacterResponse] + + +class StyleCreate(BaseModel): + name: str + description: str = "" + reference_image_path: str = "" + + +class StyleResponse(BaseModel): + id: str + name: str + description: str + reference_image_path: str + created_at: str + + +class StyleListResponse(BaseModel): + styles: list[StyleResponse] + + +class ReferenceCreate(BaseModel): + name: str + category: LibraryReferenceCategory + image_path: str = "" + + +class ReferenceResponse(BaseModel): + id: str + name: str + category: LibraryReferenceCategory + image_path: str + created_at: str + + +class ReferenceListResponse(BaseModel): + references: list[ReferenceResponse] + + +# ============================================================ +# Prompt Library Models +# ============================================================ + + +class SavedPromptResponse(BaseModel): + id: str + text: str + tags: list[str] + category: str + used_count: int + created_at: str + last_used_at: str | None + + +class PromptListResponse(BaseModel): + prompts: list[SavedPromptResponse] + + +class SavePromptRequest(BaseModel): + text: NonEmptyPrompt + tags: list[str] = Field(default_factory=list) + category: str = "" + + +class IncrementUsageResponse(BaseModel): + status: str + used_count: int + + +class WildcardResponse(BaseModel): + id: str + name: str + values: list[str] + created_at: str + + +class WildcardListResponse(BaseModel): + wildcards: list[WildcardResponse] + + +class CreateWildcardRequest(BaseModel): + name: str = Field(min_length=1) + values: list[str] = Field(min_length=1) + + +class UpdateWildcardRequest(BaseModel): + values: list[str] = Field(min_length=1) + + +class ExpandWildcardsRequest(BaseModel): + prompt: NonEmptyPrompt + mode: Literal["all", "random"] = "random" + count: int = Field(default=1, ge=1, le=1000) + + +class ExpandWildcardsResponse(BaseModel): + expanded: list[str] + + +# ============================================================ +# Video Model Scanner Types +# ============================================================ + + +class DetectedModel(BaseModel): + filename: str + path: str + model_format: str # "bf16" | "fp8" | "gguf" | "nf4" + quant_type: str | None = None + size_bytes: int + size_gb: float + is_distilled: bool + display_name: str + + +class ModelFormatInfo(BaseModel): + id: str + name: str + size_gb: float + min_vram_gb: int + quality_tier: str + needs_distilled_lora: bool + download_url: str + description: str + + +class DistilledLoraInfo(BaseModel): + name: str + size_gb: float + download_url: str + description: str + + +class VideoModelScanResponse(BaseModel): + models: list[DetectedModel] + distilled_lora_found: bool + + +class VideoModelGuideResponse(BaseModel): + gpu_name: str | None + vram_gb: int | None + recommended_format: str + formats: list[ModelFormatInfo] + distilled_lora: DistilledLoraInfo + + +class SelectModelRequest(BaseModel): + model: str diff --git a/backend/app_factory.py b/backend/app_factory.py index bd9651e4..d48d2c96 100644 --- a/backend/app_factory.py +++ b/backend/app_factory.py @@ -15,10 +15,21 @@ from _routes.ic_lora import router as ic_lora_router from _routes.image_gen import router as image_gen_router from _routes.models import router as models_router +from _routes.enhance_prompt import router as enhance_prompt_router from _routes.suggest_gap_prompt import router as suggest_gap_prompt_router from _routes.retake import router as retake_router +from _routes.queue import router as queue_router from _routes.runtime_policy import router as runtime_policy_router from _routes.settings import router as settings_router +from _routes.sync import router as sync_router +from _routes.receive_job import router as receive_job_router +from _routes.contact_sheet import router as contact_sheet_router +from _routes.style_guide import router as style_guide_router +from _routes.gallery import router as gallery_router +from _routes.library import router as library_router +from _routes.prompts import router as prompts_router +from _routes.batch import router as batch_router +from _routes.lora import router as lora_router from logging_policy import log_http_error, log_unhandled_exception from state import init_state_service @@ -73,8 +84,19 @@ async def _route_generic_error_handler(request: Request, exc: Exception) -> JSON app.include_router(settings_router) app.include_router(image_gen_router) app.include_router(suggest_gap_prompt_router) + app.include_router(enhance_prompt_router) app.include_router(retake_router) app.include_router(ic_lora_router) app.include_router(runtime_policy_router) + app.include_router(queue_router) + app.include_router(sync_router) + app.include_router(receive_job_router) + app.include_router(contact_sheet_router) + app.include_router(style_guide_router) + app.include_router(gallery_router) + app.include_router(library_router) + app.include_router(prompts_router) + app.include_router(batch_router) + app.include_router(lora_router) return app diff --git a/backend/app_handler.py b/backend/app_handler.py index 6ceef0e5..b3459352 100644 --- a/backend/app_handler.py +++ b/backend/app_handler.py @@ -21,12 +21,22 @@ TextHandler, VideoGenerationHandler, ) +from handlers.gallery_handler import GalleryHandler +from handlers.library_handler import LibraryHandler from runtime_config.runtime_config import RuntimeConfig +from handlers.contact_sheet_handler import ContactSheetHandler +from handlers.enhance_prompt_handler import EnhancePromptHandler +from handlers.prompt_handler import PromptHandler +from handlers.receive_job_handler import ReceiveJobHandler +from handlers.style_guide_handler import StyleGuideHandler +from handlers.sync_handler import SyncHandler from services.interfaces import ( A2VPipeline, FastVideoPipeline, - ZitAPIClient, + ImageAPIClient, ImageGenerationPipeline, + PaletteSyncClient, + VideoAPIClient, GpuCleaner, GpuInfo, HTTPClient, @@ -39,6 +49,7 @@ TextEncoder, VideoProcessor, ) +from services.model_scanner.model_scanner import ModelScanner from state.app_state_types import AppState, StartupPending, TextEncoderState @@ -57,13 +68,19 @@ def __init__( text_encoder: TextEncoder, task_runner: TaskRunner, ltx_api_client: LTXAPIClient, - zit_api_client: ZitAPIClient, + image_api_client: ImageAPIClient, + video_api_client: VideoAPIClient, + palette_sync_client: PaletteSyncClient, fast_video_pipeline_class: type[FastVideoPipeline], + gguf_video_pipeline_class: type[FastVideoPipeline] | None, + nf4_video_pipeline_class: type[FastVideoPipeline] | None, image_generation_pipeline_class: type[ImageGenerationPipeline], + flux_klein_pipeline_class: type[ImageGenerationPipeline] | None, ic_lora_pipeline_class: type[IcLoraPipeline], a2v_pipeline_class: type[A2VPipeline], retake_pipeline_class: type[RetakePipeline], ic_lora_model_downloader: IcLoraModelDownloader, + model_scanner: ModelScanner, ) -> None: self.config = config @@ -75,9 +92,14 @@ def __init__( self.video_processor = video_processor self.task_runner = task_runner self.ltx_api_client = ltx_api_client - self.zit_api_client = zit_api_client + self.image_api_client = image_api_client + self.video_api_client = video_api_client + self.palette_sync_client = palette_sync_client self.fast_video_pipeline_class = fast_video_pipeline_class + self.gguf_video_pipeline_class = gguf_video_pipeline_class + self.nf4_video_pipeline_class = nf4_video_pipeline_class self.image_generation_pipeline_class = image_generation_pipeline_class + self.flux_klein_pipeline_class = flux_klein_pipeline_class self.ic_lora_pipeline_class = ic_lora_pipeline_class self.a2v_pipeline_class = a2v_pipeline_class self.retake_pipeline_class = retake_pipeline_class @@ -91,6 +113,7 @@ def __init__( "upsampler": None, "text_encoder": None, "zit": None, + "flux_klein": None, }, downloading_session=None, gpu_slot=None, @@ -116,6 +139,8 @@ def __init__( state=self.state, lock=self._lock, config=config, + model_scanner=model_scanner, + gpu_info_service=gpu_info, ) self.downloads = DownloadHandler( @@ -139,7 +164,10 @@ def __init__( text_handler=self.text, gpu_cleaner=gpu_cleaner, fast_video_pipeline_class=fast_video_pipeline_class, + gguf_video_pipeline_class=gguf_video_pipeline_class, + nf4_video_pipeline_class=nf4_video_pipeline_class, image_generation_pipeline_class=image_generation_pipeline_class, + flux_klein_pipeline_class=flux_klein_pipeline_class, ic_lora_pipeline_class=ic_lora_pipeline_class, a2v_pipeline_class=a2v_pipeline_class, retake_pipeline_class=retake_pipeline_class, @@ -157,6 +185,7 @@ def __init__( pipelines_handler=self.pipelines, text_handler=self.text, ltx_api_client=ltx_api_client, + video_api_client=video_api_client, outputs_dir=config.outputs_dir, config=config, camera_motion_prompts=config.camera_motion_prompts, @@ -170,7 +199,7 @@ def __init__( pipelines_handler=self.pipelines, outputs_dir=config.outputs_dir, config=config, - zit_api_client=zit_api_client, + image_api_client=image_api_client, ) self.health = HealthHandler( @@ -191,6 +220,12 @@ def __init__( http=http, ) + self.enhance_prompt = EnhancePromptHandler( + state=self.state, + lock=self._lock, + http=http, + ) + self.retake = RetakeHandler( state=self.state, lock=self._lock, @@ -214,9 +249,87 @@ def __init__( outputs_dir=config.outputs_dir, ) + from state.lora_library import LoraLibraryStore + from handlers.lora_handler import LoraHandler + lora_store = LoraLibraryStore(config.models_dir / "loras") + self.lora = LoraHandler( + store=lora_store, + civitai_api_key=default_settings.civitai_api_key, + ) + + self.sync = SyncHandler( + state=self.state, + palette_sync_client=palette_sync_client, + http=http, + lora_store=lora_store, + loras_dir=config.models_dir / "loras", + ) + + self.gallery = GalleryHandler(outputs_dir=config.outputs_dir) + self.downloads.cleanup_downloading_dir() + + from state.job_queue import JobQueue + self.job_queue = JobQueue(persistence_path=config.settings_file.parent / "job_queue.json") + + from state.library_store import LibraryStore + library_store = LibraryStore(config.settings_file.parent / "library") + self.library = LibraryHandler(store=library_store) + + self.prompts = PromptHandler( + state=self.state, + lock=self._lock, + store_path=config.settings_file.parent / "prompt_store.json", + ) + + self.receive_job_handler = ReceiveJobHandler( + state=self.state, + http=http, + job_queue=self.job_queue, + ) + + self.contact_sheet = ContactSheetHandler(job_queue=self.job_queue) + self.style_guide = StyleGuideHandler(job_queue=self.job_queue) + + from handlers.batch_handler import BatchHandler + self.batch = BatchHandler() + + from handlers.job_executors import GpuJobExecutor, ApiJobExecutor + from handlers.queue_worker import QueueWorker + self._queue_worker = QueueWorker( + queue=self.job_queue, + gpu_executor=GpuJobExecutor(self), + api_executor=ApiJobExecutor(self), + gpu_cleaner=gpu_cleaner, + credit_deductor=self.sync, + ) + self._queue_stop = threading.Event() + self._queue_thread = threading.Thread(target=self._queue_loop, daemon=True) + self._queue_thread.start() + self.models.refresh_available_files() + def _queue_loop(self) -> None: + """Background loop that ticks the queue worker every second.""" + import logging + _logger = logging.getLogger(__name__) + _logger.info("Queue worker background thread started") + while not self._queue_stop.is_set(): + try: + self._queue_worker.tick() + except Exception as exc: + _logger.error("Queue worker tick error: %s", exc) + self._queue_stop.wait(1.0) + + def determine_slot(self, model: str) -> str: + """Determine whether a job should use the gpu or api slot.""" + always_api_models = {"seedance-1.5-pro", "nano-banana-2"} + if model in always_api_models: + return "api" + if self.config.force_api_generations: + return "api" + return "gpu" + @dataclass class ServiceBundle: @@ -228,19 +341,28 @@ class ServiceBundle: text_encoder: TextEncoder task_runner: TaskRunner ltx_api_client: LTXAPIClient - zit_api_client: ZitAPIClient + image_api_client: ImageAPIClient + video_api_client: VideoAPIClient + palette_sync_client: PaletteSyncClient fast_video_pipeline_class: type[FastVideoPipeline] + gguf_video_pipeline_class: type[FastVideoPipeline] | None + nf4_video_pipeline_class: type[FastVideoPipeline] | None image_generation_pipeline_class: type[ImageGenerationPipeline] + flux_klein_pipeline_class: type[ImageGenerationPipeline] | None ic_lora_pipeline_class: type[IcLoraPipeline] a2v_pipeline_class: type[A2VPipeline] retake_pipeline_class: type[RetakePipeline] ic_lora_model_downloader: IcLoraModelDownloader + model_scanner: ModelScanner def build_default_service_bundle(config: RuntimeConfig) -> ServiceBundle: """Build real runtime services with lazy heavy imports isolated from tests.""" from services.fast_video_pipeline.ltx_fast_video_pipeline import LTXFastVideoPipeline - from services.zit_api_client.zit_api_client_impl import ZitAPIClientImpl + from services.fast_video_pipeline.gguf_fast_video_pipeline import GGUFFastVideoPipeline + from services.fast_video_pipeline.nf4_fast_video_pipeline import NF4FastVideoPipeline + from services.image_api_client.replicate_client_impl import ReplicateImageClientImpl + from services.video_api_client.replicate_video_client_impl import ReplicateVideoClientImpl from services.gpu_cleaner.torch_cleaner import TorchCleaner from services.gpu_info.gpu_info_impl import GpuInfoImpl from services.http_client.http_client_impl import HTTPClientImpl @@ -248,11 +370,14 @@ def build_default_service_bundle(config: RuntimeConfig) -> ServiceBundle: from services.a2v_pipeline.ltx_a2v_pipeline import LTXa2vPipeline from services.ic_lora_pipeline.ltx_ic_lora_pipeline import LTXIcLoraPipeline from services.image_generation_pipeline.zit_image_generation_pipeline import ZitImageGenerationPipeline + from services.image_generation_pipeline.flux_klein_pipeline import FluxKleinImagePipeline from services.ltx_api_client.ltx_api_client_impl import LTXAPIClientImpl from services.model_downloader.hugging_face_downloader import HuggingFaceDownloader + from services.model_scanner.model_scanner_impl import ModelScannerImpl from services.retake_pipeline.ltx_retake_pipeline import LTXRetakePipeline from services.task_runner.threading_runner import ThreadingRunner from services.text_encoder.ltx_text_encoder import LTXTextEncoder + from services.palette_sync_client.palette_sync_client_impl import PaletteSyncClientImpl from services.video_processor.video_processor_impl import VideoProcessorImpl http = HTTPClientImpl() @@ -270,13 +395,19 @@ def build_default_service_bundle(config: RuntimeConfig) -> ServiceBundle: ), task_runner=ThreadingRunner(), ltx_api_client=LTXAPIClientImpl(http=http, ltx_api_base_url=config.ltx_api_base_url), - zit_api_client=ZitAPIClientImpl(http=http), + image_api_client=ReplicateImageClientImpl(http=http), + video_api_client=ReplicateVideoClientImpl(http=http), + palette_sync_client=PaletteSyncClientImpl(http=http), fast_video_pipeline_class=LTXFastVideoPipeline, + gguf_video_pipeline_class=GGUFFastVideoPipeline, + nf4_video_pipeline_class=NF4FastVideoPipeline, image_generation_pipeline_class=ZitImageGenerationPipeline, + flux_klein_pipeline_class=FluxKleinImagePipeline, ic_lora_pipeline_class=LTXIcLoraPipeline, a2v_pipeline_class=LTXa2vPipeline, retake_pipeline_class=LTXRetakePipeline, ic_lora_model_downloader=IcLoraModelDownloaderImpl(), + model_scanner=ModelScannerImpl(), ) @@ -298,11 +429,17 @@ def build_initial_state( text_encoder=bundle.text_encoder, task_runner=bundle.task_runner, ltx_api_client=bundle.ltx_api_client, - zit_api_client=bundle.zit_api_client, + image_api_client=bundle.image_api_client, + video_api_client=bundle.video_api_client, + palette_sync_client=bundle.palette_sync_client, fast_video_pipeline_class=bundle.fast_video_pipeline_class, + gguf_video_pipeline_class=bundle.gguf_video_pipeline_class, + nf4_video_pipeline_class=bundle.nf4_video_pipeline_class, image_generation_pipeline_class=bundle.image_generation_pipeline_class, + flux_klein_pipeline_class=bundle.flux_klein_pipeline_class, ic_lora_pipeline_class=bundle.ic_lora_pipeline_class, a2v_pipeline_class=bundle.a2v_pipeline_class, retake_pipeline_class=bundle.retake_pipeline_class, ic_lora_model_downloader=bundle.ic_lora_model_downloader, + model_scanner=bundle.model_scanner, ) diff --git a/backend/handlers/__init__.py b/backend/handlers/__init__.py index 32da0178..46877dee 100644 --- a/backend/handlers/__init__.py +++ b/backend/handlers/__init__.py @@ -1,6 +1,7 @@ """State handler exports.""" from handlers.download_handler import DownloadHandler +from handlers.gallery_handler import GalleryHandler from handlers.generation_handler import GenerationHandler from handlers.health_handler import HealthHandler from handlers.ic_lora_handler import IcLoraHandler @@ -24,6 +25,7 @@ "VideoGenerationHandler", "ImageGenerationHandler", "HealthHandler", + "GalleryHandler", "SuggestGapPromptHandler", "RetakeHandler", "RuntimePolicyHandler", diff --git a/backend/handlers/batch_handler.py b/backend/handlers/batch_handler.py new file mode 100644 index 00000000..431ca8ff --- /dev/null +++ b/backend/handlers/batch_handler.py @@ -0,0 +1,225 @@ +"""Batch generation handler — expands batch definitions into individual queue jobs.""" + +from __future__ import annotations + +import itertools +import uuid +from typing import Any + +from api_types import ( + BatchJobItem, + BatchSubmitRequest, + BatchSubmitResponse, + BatchStatusResponse, + BatchReport, + QueueJobResponse, + SweepDefinition, + PipelineDefinition, +) +from state.job_queue import JobQueue, QueueJob + + +class BatchHandler: + def submit_batch(self, request: BatchSubmitRequest, queue: JobQueue) -> BatchSubmitResponse: + batch_id = uuid.uuid4().hex[:8] + slot: str = "api" if request.target == "cloud" else "gpu" + + if request.mode == "list": + job_defs = self._expand_list(request.jobs or [], batch_id, slot) + tags = [f"batch:{batch_id}"] + elif request.mode == "sweep": + if request.sweep is None: + raise ValueError("sweep mode requires sweep definition") + job_defs = self._expand_sweep(request.sweep, batch_id, slot) + tags = [f"batch:{batch_id}"] + [f"sweep:{a.param}" for a in request.sweep.axes] + elif request.mode == "pipeline": + if request.pipeline is None: + raise ValueError("pipeline mode requires pipeline definition") + return self._submit_pipeline(request.pipeline, batch_id, slot, queue) + else: + raise ValueError(f"Unknown batch mode: {request.mode}") + + job_ids: list[str] = [] + for job_def in job_defs: + job = queue.submit( + job_type=str(job_def["type"]), + model=str(job_def["model"]), + params=dict(job_def["params"]), # type: ignore[arg-type] + slot=slot, + batch_id=batch_id, + batch_index=int(job_def["batch_index"]), # type: ignore[arg-type] + tags=list(tags), + ) + job_ids.append(job.id) + + return BatchSubmitResponse(batch_id=batch_id, job_ids=job_ids, total_jobs=len(job_ids)) + + def get_batch_status(self, batch_id: str, queue: JobQueue) -> BatchStatusResponse: + jobs = queue.jobs_for_batch(batch_id) + if not jobs: + raise ValueError(f"Batch {batch_id} not found") + + completed = sum(1 for j in jobs if j.status == "complete") + failed = sum(1 for j in jobs if j.status == "error") + running = sum(1 for j in jobs if j.status == "running") + cancelled = sum(1 for j in jobs if j.status == "cancelled") + queued = sum(1 for j in jobs if j.status == "queued") + + report = None + if queued == 0 and running == 0: + report = self._build_report(batch_id, jobs) + + return BatchStatusResponse( + batch_id=batch_id, + total=len(jobs), + completed=completed, + failed=failed, + running=running, + queued=queued, + cancelled=cancelled, + jobs=[self._job_to_response(j) for j in jobs], + report=report, + ) + + def cancel_batch(self, batch_id: str, queue: JobQueue) -> None: + for job in queue.jobs_for_batch(batch_id): + if job.status == "queued": + queue.update_job(job.id, status="cancelled") + + def retry_failed(self, batch_id: str, queue: JobQueue) -> BatchSubmitResponse: + failed_jobs = [j for j in queue.jobs_for_batch(batch_id) if j.status == "error"] + new_ids: list[str] = [] + for job in failed_jobs: + new_job = queue.submit( + job_type=job.type, + model=job.model, + params=job.params, + slot=job.slot, + batch_id=batch_id, + batch_index=job.batch_index, + tags=job.tags, + ) + new_ids.append(new_job.id) + return BatchSubmitResponse(batch_id=batch_id, job_ids=new_ids, total_jobs=len(new_ids)) + + def _expand_list( + self, items: list[BatchJobItem], batch_id: str, slot: str + ) -> list[dict[str, object]]: + return [ + { + "type": item.type, + "model": item.model, + "params": dict(item.params), + "batch_index": i, + } + for i, item in enumerate(items) + ] + + def _expand_sweep( + self, sweep: SweepDefinition, batch_id: str, slot: str + ) -> list[dict[str, object]]: + axis_values: list[list[tuple[str, object]]] = [] + for axis in sweep.axes: + pairs: list[tuple[str, object]] = [] + for val in axis.values: + pairs.append((axis.param, val)) + axis_values.append(pairs) + + combos = list(itertools.product(*axis_values)) + results: list[dict[str, object]] = [] + for i, combo in enumerate(combos): + params: dict[str, Any] = dict(sweep.base_params) + for param_name, value in combo: + axis_def = next(a for a in sweep.axes if a.param == param_name) + if axis_def.mode == "search_replace" and axis_def.search and param_name in params: + current = str(params[param_name]) + params[param_name] = current.replace(axis_def.search, str(value)) + else: + params[param_name] = value + results.append({ + "type": sweep.base_type, + "model": sweep.base_model, + "params": params, + "batch_index": i, + }) + + return results + + def _submit_pipeline( + self, pipeline: PipelineDefinition, batch_id: str, slot: str, queue: JobQueue + ) -> BatchSubmitResponse: + job_ids: list[str] = [] + prev_job_id: str | None = None + tags = [f"batch:{batch_id}"] + + for i, step in enumerate(pipeline.steps): + auto_params: dict[str, str] = {} + depends_on: str | None = None + if prev_job_id is not None: + depends_on = prev_job_id + auto_params["imagePath"] = "$dep.result_paths[0]" + if step.auto_prompt: + auto_params["auto_prompt"] = "true" + + job = queue.submit( + job_type=step.type, + model=step.model, + params=dict(step.params), + slot=slot, + batch_id=batch_id, + batch_index=i, + depends_on=depends_on, + auto_params=auto_params, + tags=list(tags), + ) + job_ids.append(job.id) + prev_job_id = job.id + + return BatchSubmitResponse(batch_id=batch_id, job_ids=job_ids, total_jobs=len(job_ids)) + + def _build_report(self, batch_id: str, jobs: list[QueueJob]) -> BatchReport: + succeeded = sum(1 for j in jobs if j.status == "complete") + failed = sum(1 for j in jobs if j.status == "error") + cancelled = sum(1 for j in jobs if j.status == "cancelled") + result_paths: list[str] = [] + failed_indices: list[int] = [] + for j in jobs: + result_paths.extend(j.result_paths) + if j.status == "error": + failed_indices.append(j.batch_index) + + sweep_axes: list[str] | None = None + sweep_tags = [t for t in (jobs[0].tags if jobs else []) if t.startswith("sweep:")] + if sweep_tags: + sweep_axes = [t.split(":", 1)[1] for t in sweep_tags] + + return BatchReport( + batch_id=batch_id, + total=len(jobs), + succeeded=succeeded, + failed=failed, + cancelled=cancelled, + duration_seconds=0.0, + avg_job_seconds=0.0, + result_paths=result_paths, + failed_indices=failed_indices, + sweep_axes=sweep_axes, + ) + + def _job_to_response(self, job: QueueJob) -> QueueJobResponse: + return QueueJobResponse( + id=job.id, + type=job.type, + model=job.model, + params=job.params, + status=job.status, + slot=job.slot, + progress=job.progress, + phase=job.phase, + result_paths=job.result_paths, + error=job.error, + created_at=job.created_at, + batch_id=job.batch_id, + batch_index=job.batch_index, + tags=job.tags, + ) diff --git a/backend/handlers/contact_sheet_handler.py b/backend/handlers/contact_sheet_handler.py new file mode 100644 index 00000000..c5dfa188 --- /dev/null +++ b/backend/handlers/contact_sheet_handler.py @@ -0,0 +1,50 @@ +"""Handler for contact sheet generation (3x3 cinematic angle grid).""" +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from api_types import GenerateContactSheetRequest, GenerateContactSheetResponse + +if TYPE_CHECKING: + from state.job_queue import JobQueue + +logger = logging.getLogger(__name__) + +CAMERA_ANGLES: list[str] = [ + "Close-up portrait, tight framing on the face", + "Medium shot, waist-up framing", + "Full body shot, head to toe framing", + "Over-the-shoulder shot, looking past one subject", + "Low angle hero shot, looking up dramatically", + "High angle bird's eye view, looking down from above", + "Profile side view, lateral perspective", + "Three-quarter view, angled between front and side", + "Wide establishing shot, showing full environment", +] + + +class ContactSheetHandler: + def __init__(self, job_queue: "JobQueue") -> None: + self._job_queue = job_queue + + def generate(self, req: GenerateContactSheetRequest) -> GenerateContactSheetResponse: + job_ids: list[str] = [] + style_suffix = f", {req.style}" if req.style else "" + + for angle in CAMERA_ANGLES: + prompt = f"{req.subject_description}, {angle}{style_suffix}" + job = self._job_queue.submit( + job_type="image", + model="z-image-turbo", + params={ + "prompt": prompt, + "width": 1024, + "height": 1024, + "reference_image_path": req.reference_image_path, + }, + slot="api", + ) + job_ids.append(job.id) + + return GenerateContactSheetResponse(job_ids=job_ids) diff --git a/backend/handlers/download_handler.py b/backend/handlers/download_handler.py index 44ffd82a..8c85473b 100644 --- a/backend/handlers/download_handler.py +++ b/backend/handlers/download_handler.py @@ -50,13 +50,13 @@ def start_download(self, files: dict[ModelFileType, tuple[str, int]]) -> None: progress=0.0, downloaded_bytes=0, total_bytes=size, - speed_mbps=0.0, + speed_bytes_per_sec=0.0, ) for file_type, (target, size) in files.items() } @with_state_lock - def update_file_progress(self, file_type: ModelFileType, downloaded: int, total: int, speed_mbps: float) -> None: + def update_file_progress(self, file_type: ModelFileType, downloaded: int, total: int, speed_bytes_per_sec: float) -> None: match self.state.downloading_session: case dict() as files: if file_type not in files: @@ -66,7 +66,7 @@ def update_file_progress(self, file_type: ModelFileType, downloaded: int, total: running.downloaded_bytes = downloaded running.total_bytes = total running.progress = 0.0 if total == 0 else min(1.0, max(0.0, downloaded / total)) - running.speed_mbps = speed_mbps + running.speed_bytes_per_sec = speed_bytes_per_sec case FileDownloadCompleted(): return case _: @@ -86,12 +86,25 @@ def fail_download(self, error: str) -> None: self.state.downloading_session = DownloadError(error=error) def _make_progress_callback(self, file_type: ModelFileType) -> Callable[[int, int], None]: - start_time = time.monotonic() + last_sample_time = time.monotonic() + last_sample_bytes = 0 + smoothed_speed = 0.0 def on_progress(downloaded: int, total: int) -> None: - elapsed = time.monotonic() - start_time - speed_mbps = (downloaded / elapsed / (1024 * 1024)) if elapsed > 0 else 0.0 - self.update_file_progress(file_type, downloaded, total, speed_mbps) + nonlocal last_sample_time, last_sample_bytes, smoothed_speed + now = time.monotonic() + elapsed = now - last_sample_time + if elapsed >= 1.0: + instant_speed = (downloaded - last_sample_bytes) / elapsed + # EWMA: weight new sample at 30%, keep 70% of previous. + # On first sample (smoothed_speed == 0) use instant value. + if smoothed_speed == 0.0: + smoothed_speed = instant_speed + else: + smoothed_speed = 0.3 * instant_speed + 0.7 * smoothed_speed + last_sample_time = now + last_sample_bytes = downloaded + self.update_file_progress(file_type, downloaded, total, smoothed_speed) return on_progress @@ -103,7 +116,7 @@ def get_download_progress(self) -> DownloadProgressResponse: status = "idle" current_file = "" current_file_progress = 0 - speed_mbps = 0 + speed_bytes_per_sec: float = 0.0 downloaded_bytes = 0 total_bytes = 0 files_completed = 0 @@ -127,7 +140,7 @@ def get_download_progress(self) -> DownloadProgressResponse: case FileDownloadRunning() as running: current_file = file_type current_file_progress = int(running.progress * 100) - speed_mbps = int(running.speed_mbps) + speed_bytes_per_sec = running.speed_bytes_per_sec downloaded_bytes += running.downloaded_bytes case _: status = "idle" @@ -144,7 +157,7 @@ def get_download_progress(self) -> DownloadProgressResponse: filesCompleted=files_completed, totalFiles=total_files, error=error, - speedMbps=speed_mbps, + speedBytesPerSec=speed_bytes_per_sec, ) def _move_to_final(self, file_type: ModelFileType) -> None: diff --git a/backend/handlers/enhance_prompt_handler.py b/backend/handlers/enhance_prompt_handler.py new file mode 100644 index 00000000..420b1c41 --- /dev/null +++ b/backend/handlers/enhance_prompt_handler.py @@ -0,0 +1,427 @@ +"""Prompt enhancement handler (Palette proxy or Gemini fallback). + +Two modes: +- **Generate** (empty prompt): create a random cinematic prompt from scratch +- **Enhance** (has prompt): expand a rough prompt into a detailed, model-optimized one + +System prompts are tailored per model family (LTX-Video, Seedance, image models). +""" +from __future__ import annotations + +import logging +from threading import RLock +from typing import Any + +from _routes._errors import HTTPError +from handlers.base import StateHandlerBase +from pydantic import BaseModel, Field, ValidationError +from services.interfaces import HTTPClient, HttpTimeoutError, JSONValue +from state.app_state_types import AppState + +logger = logging.getLogger(__name__) + +PALETTE_BASE_URL = "https://directorspalette.com" + +# --------------------------------------------------------------------------- +# Gemini response parsing +# --------------------------------------------------------------------------- + +class _GeminiPart(BaseModel): + text: str + + +class _GeminiContent(BaseModel): + parts: list[_GeminiPart] = Field(min_length=1) + + +class _GeminiCandidate(BaseModel): + content: _GeminiContent + + +class _GeminiResponsePayload(BaseModel): + candidates: list[_GeminiCandidate] = Field(min_length=1) + + +def _extract_gemini_text(payload: object) -> str: + try: + parsed = _GeminiResponsePayload.model_validate(payload) + except ValidationError: + raise HTTPError(500, "GEMINI_PARSE_ERROR") + return parsed.candidates[0].content.parts[0].text + + +# --------------------------------------------------------------------------- +# Palette response parsing +# --------------------------------------------------------------------------- + +class _PaletteResponsePayload(BaseModel): + enhanced_prompt: str | None = Field(None, alias="enhanced_prompt") + expanded_prompt: str | None = Field(None, alias="expandedPrompt") + + +class _OpenRouterMessage(BaseModel): + content: str + + +class _OpenRouterChoice(BaseModel): + message: _OpenRouterMessage + + +class _OpenRouterResponse(BaseModel): + choices: list[_OpenRouterChoice] = Field(min_length=1) + + +def _extract_openrouter_text(payload: object) -> str: + """Parse OpenAI-compatible chat completion response.""" + try: + parsed = _OpenRouterResponse.model_validate(payload) + except ValidationError: + raise HTTPError(500, "OPENROUTER_PARSE_ERROR") + text = parsed.choices[0].message.content.strip() + if not text: + raise HTTPError(500, "OPENROUTER_PARSE_ERROR") + return text + + +def _extract_palette_text(payload: object) -> str: + try: + parsed = _PaletteResponsePayload.model_validate(payload) + except ValidationError: + raise HTTPError(500, "PALETTE_PARSE_ERROR") + text = parsed.enhanced_prompt or parsed.expanded_prompt or "" + if not text.strip(): + raise HTTPError(500, "PALETTE_PARSE_ERROR") + return text.strip() + + +# --------------------------------------------------------------------------- +# Model-specific system prompts +# --------------------------------------------------------------------------- + +_LTX_VIDEO_RULES = ( + "LTX-2.3 PROMPTING RULES (follow strictly):\n" + "- Direct the scene like a director: specify spatial layout (left/right, " + "foreground/background, facing toward/away)\n" + "- Use specific action verbs for motion: who moves, what moves, how they move, " + "what the camera does\n" + "- Describe texture and material: fabric types, hair texture, surface finish, " + "environmental wear\n" + "- Avoid static photo-like descriptions — include movement to reduce frozen outputs\n" + "- Specify camera motion explicitly: 'camera slowly pushes forward', " + "'camera tracks right', 'camera holds steady'\n" + "- Be specific about the number of subjects and their positions\n" + "- Describe lighting: source direction, color temperature, quality (soft/hard)\n" +) + +_LTX_I2V_EXTRA = ( + "- IMPORTANT for image-to-video: describe only motion and camera movement. " + "Do NOT redescribe what is already visible in the image. Focus on what CHANGES: " + "who moves, what moves, camera direction. Explicitly state if subjects should " + "remain still. State 'no other people enter the frame' if the scene should stay " + "contained.\n" +) + +_SEEDANCE_RULES = ( + "SEEDANCE PROMPTING RULES:\n" + "- Seedance excels at dance, human motion, and dynamic physical movement\n" + "- Describe body movement in detail: gestures, dance styles, footwork\n" + "- Specify camera angle and movement\n" + "- Include environment and lighting details\n" + "- Describe clothing and how it moves with the body\n" +) + +_IMAGE_RULES = ( + "IMAGE GENERATION RULES:\n" + "- Expand vague descriptions into specific, vivid details\n" + "- Add lighting direction, color temperature, mood, and atmosphere\n" + "- Describe texture, material, and surface detail precisely\n" + "- Specify camera angle, lens characteristics (shallow DOF, wide angle, etc)\n" + "- Include color palette and tonal mood\n" +) + + +def _get_system_prompt(*, model: str, mode: str, is_generate: bool, has_image: bool = False) -> str: + """Build a system prompt tailored to the model, mode, and action.""" + is_image = mode in ("text-to-image", "t2i") + is_i2v = mode in ("image-to-video", "i2v") or has_image + is_ltx = model.startswith("ltx") + is_seedance = "seedance" in model + + if is_generate and has_image: + action = ( + "The user has provided a starting image. Analyze what you see in the image — " + "the subjects, setting, lighting, mood, composition — and create a motion-aware " + "prompt that will animate this image into a cinematic video. Focus on describing " + "MOTION: what should move, how the camera should move, what stays still. " + "Do NOT just describe the image statically. Direct the scene.\n\n" + ) + elif is_generate: + action = ( + "The user wants you to invent a creative, cinematic prompt from scratch. " + "Come up with something visually stunning and unexpected. Vary your ideas — " + "don't default to the same themes. Mix genres: sci-fi, nature, fashion, " + "documentary, abstract, horror, comedy, noir, fantasy. " + "Write only the prompt, nothing else.\n\n" + ) + elif has_image: + action = ( + "The user has provided a starting image and a rough prompt. Look at the image, " + "understand the scene, and enhance the prompt into a detailed, motion-aware " + "description. Keep the user's intent but add specificity about motion, camera, " + "and what should happen in the scene.\n\n" + ) + else: + action = ( + "The user provides a rough prompt. Your job is to enhance it into a " + "detailed, production-ready description while keeping the core intent.\n\n" + ) + + if is_image: + rules = _IMAGE_RULES + elif is_seedance: + rules = _SEEDANCE_RULES + elif is_ltx: + rules = _LTX_VIDEO_RULES + (_LTX_I2V_EXTRA if is_i2v else "") + else: + rules = _LTX_VIDEO_RULES # default to LTX rules for unknown video models + + return ( + f"You are a creative director's assistant specializing in AI " + f"{'image' if is_image else 'video'} generation.\n\n" + + action + + rules + + "\nOutput format:\n" + "- Write 2-4 sentences max\n" + "- Write ONLY the prompt, no labels, explanations, or quotation marks\n" + ) + + +# --------------------------------------------------------------------------- +# Handler +# --------------------------------------------------------------------------- + +class EnhancePromptHandler(StateHandlerBase): + def __init__(self, state: AppState, lock: RLock, http: HTTPClient) -> None: + super().__init__(state, lock) + self._http = http + + def enhance( + self, prompt: str, mode: str, model: str = "ltx-fast", *, image_path: str | None = None, + ) -> dict[str, str]: + """Enhance an existing prompt or generate one from scratch (empty prompt). + + When *image_path* is provided the handler asks the AI to describe the + image and craft a motion-aware prompt based on what it sees. + """ + palette_api_key = self.state.app_settings.palette_api_key + if palette_api_key: + return self._enhance_via_palette(prompt, palette_api_key, mode, model) + + gemini_api_key = self.state.app_settings.gemini_api_key + if gemini_api_key: + return self._enhance_via_gemini(prompt, mode, model, gemini_api_key, image_path=image_path) + + openrouter_api_key = self.state.app_settings.openrouter_api_key + if openrouter_api_key: + return self._enhance_via_openrouter(prompt, mode, model, openrouter_api_key, image_path=image_path) + + raise HTTPError(400, "NO_AI_SERVICE_CONFIGURED") + + # --- Provider: Palette API --- + + def enhance_i2v_motion(self, image_path: str) -> str: + """Generate an i2v motion prompt (used by queue worker for extend chains).""" + result = self.enhance("", "image-to-video", "ltx-fast") + return result.get("enhancedPrompt", "") + + def _enhance_via_palette( + self, prompt: str, api_key: str, mode: str = "text-to-video", model: str = "ltx-fast", + ) -> dict[str, str]: + """Proxy to Director's Palette /api/prompt-expander endpoint.""" + url = f"{PALETTE_BASE_URL}/api/prompt-expander" + payload: dict[str, Any] = { + "prompt": prompt, + "level": "2x", + "mode": mode, + "model": model, + } + if not prompt.strip(): + payload["action"] = "generate" + + try: + response = self._http.post( + url, + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {api_key}", + }, + json_payload=payload, + timeout=30, + ) + except HttpTimeoutError as exc: + raise HTTPError(504, "Palette prompt expander request timed out") from exc + except Exception as exc: + raise HTTPError(500, str(exc)) from exc + + if response.status_code != 200: + # Fall through to Gemini if Palette can't handle generate mode + gemini_api_key = self.state.app_settings.gemini_api_key + if gemini_api_key: + return self._enhance_via_gemini(prompt, mode, model, gemini_api_key) + raise HTTPError(response.status_code, f"Palette API error: {response.text}") + + enhanced = _extract_palette_text(response.json()) + return {"enhancedPrompt": enhanced.strip()} + + # --- Provider: Gemini --- + + def _enhance_via_gemini( + self, + prompt: str, + mode: str, + model: str, + gemini_api_key: str, + *, + image_path: str | None = None, + ) -> dict[str, str]: + """Enhance or generate prompt using Gemini API (supports multimodal with image).""" + has_image = bool(image_path) + is_generate = not prompt.strip() + system_text = _get_system_prompt( + model=model, mode=mode, is_generate=is_generate, has_image=has_image, + ) + + # Build user content parts (text + optional image) + user_parts: list[dict[str, JSONValue]] = [] + + if has_image and image_path: + image_b64 = self._read_image_as_base64(image_path) + if image_b64: + inline: dict[str, JSONValue] = {"mime_type": "image/jpeg", "data": image_b64} + user_parts.append({"inline_data": inline}) + + if is_generate and has_image: + user_parts.append({"text": "Look at this image and create a cinematic motion prompt for it."}) + elif is_generate: + user_parts.append({"text": "Generate a creative, cinematic prompt."}) + elif has_image: + user_parts.append({"text": f"Look at this image and enhance this prompt for it: {prompt}"}) + else: + user_parts.append({"text": f"Enhance this prompt: {prompt}"}) + + gemini_url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent" + parts_list: JSONValue = user_parts # type: ignore[assignment] + gemini_payload: dict[str, JSONValue] = { + "contents": [{"role": "user", "parts": parts_list}], + "systemInstruction": {"parts": [{"text": system_text}]}, + "generationConfig": { + "temperature": 0.9 if is_generate else 0.7, + "maxOutputTokens": 256, + }, + } + + try: + response = self._http.post( + gemini_url, + headers={"Content-Type": "application/json", "x-goog-api-key": gemini_api_key}, + json_payload=gemini_payload, + timeout=30, + ) + except HttpTimeoutError as exc: + raise HTTPError(504, "Gemini API request timed out") from exc + except Exception as exc: + raise HTTPError(500, str(exc)) from exc + + if response.status_code != 200: + raise HTTPError(response.status_code, f"Gemini API error: {response.text}") + + enhanced = _extract_gemini_text(response.json()).strip() + return {"enhancedPrompt": enhanced} + + # --- Provider: OpenRouter (OpenAI-compatible) --- + + def _enhance_via_openrouter( + self, + prompt: str, + mode: str, + model: str, + openrouter_api_key: str, + *, + image_path: str | None = None, + ) -> dict[str, str]: + """Enhance or generate prompt using OpenRouter (OpenAI chat completions API).""" + has_image = bool(image_path) + is_generate = not prompt.strip() + system_text = _get_system_prompt( + model=model, mode=mode, is_generate=is_generate, has_image=has_image, + ) + + # Build user message content (text + optional image) + content: list[dict[str, JSONValue]] = [] + + if has_image and image_path: + image_b64 = self._read_image_as_base64(image_path) + if image_b64: + content.append({ + "type": "image_url", + "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"}, + }) + + if is_generate and has_image: + content.append({"type": "text", "text": "Look at this image and create a cinematic motion prompt for it."}) + elif is_generate: + content.append({"type": "text", "text": "Generate a creative, cinematic prompt."}) + elif has_image: + content.append({"type": "text", "text": f"Look at this image and enhance this prompt for it: {prompt}"}) + else: + content.append({"type": "text", "text": f"Enhance this prompt: {prompt}"}) + + # Use a vision model when image is provided, otherwise a fast text model + or_model = "google/gemini-2.0-flash-001" if has_image else "google/gemini-2.0-flash-001" + + user_msg_content: JSONValue = content # type: ignore[assignment] + messages: JSONValue = [ + {"role": "system", "content": system_text}, + {"role": "user", "content": user_msg_content}, + ] + openrouter_payload: dict[str, JSONValue] = { + "model": or_model, + "messages": messages, + "temperature": 0.9 if is_generate else 0.7, + "max_tokens": 256, + } + + try: + response = self._http.post( + "https://openrouter.ai/api/v1/chat/completions", + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {openrouter_api_key}", + }, + json_payload=openrouter_payload, + timeout=30, + ) + except HttpTimeoutError as exc: + raise HTTPError(504, "OpenRouter request timed out") from exc + except Exception as exc: + raise HTTPError(500, str(exc)) from exc + + if response.status_code != 200: + raise HTTPError(response.status_code, f"OpenRouter API error: {response.text}") + + enhanced = _extract_openrouter_text(response.json()) + return {"enhancedPrompt": enhanced} + + @staticmethod + def _read_image_as_base64(image_path: str) -> str | None: + """Read an image file and return base64-encoded string.""" + import base64 + from pathlib import Path + p = Path(image_path) + if not p.exists() or not p.is_file(): + return None + try: + raw = p.read_bytes() + return base64.b64encode(raw).decode("ascii") + except Exception: + return None diff --git a/backend/handlers/gallery_handler.py b/backend/handlers/gallery_handler.py new file mode 100644 index 00000000..b8ea0b84 --- /dev/null +++ b/backend/handlers/gallery_handler.py @@ -0,0 +1,166 @@ +"""Gallery handler for listing and managing local generated assets.""" + +from __future__ import annotations + +import hashlib +import logging +import math +import os +from pathlib import Path + +from _routes._errors import HTTPError +from api_types import GalleryAsset, GalleryAssetType, GalleryListResponse + +logger = logging.getLogger(__name__) + +IMAGE_EXTENSIONS = frozenset({".png", ".jpg", ".jpeg"}) +VIDEO_EXTENSIONS = frozenset({".mp4", ".webm"}) +ALL_EXTENSIONS = IMAGE_EXTENSIONS | VIDEO_EXTENSIONS + +# Known filename prefixes that indicate which model produced the file. +# Legacy prefixes kept for backwards-compat with existing output files. +_LEGACY_MODEL_PREFIXES: list[tuple[str, str]] = [ + ("zit_edit_", "zit-edit"), + ("zit_image_", "zit"), + ("api_image_", "api"), + ("ltx_fast_", "ltx-fast"), + ("ltx2_", "ltx2"), + ("ltx_", "ltx"), + ("api_video_", "api-video"), + ("seedance_", "seedance"), + ("nano_banana_", "nano-banana"), + ("ic_lora_", "ic-lora"), + ("retake_", "retake"), +] + +_DD_PREFIX = "dd_" + + +def _parse_model_name(filename: str) -> str | None: + """Extract model name from filename. + + New format: ``dd_{model}_{prompt}_{timestamp}.ext`` + Legacy format: ``{prefix}{timestamp}_{id}.ext`` + """ + lower = filename.lower() + # New dd_ naming: dd_{model}_{prompt_slug}_{timestamp}.ext + if lower.startswith(_DD_PREFIX): + rest = lower[len(_DD_PREFIX):] + # Model name is the segment before the next underscore + parts = rest.split("_", 1) + if parts: + return parts[0] + return None + # Legacy prefixes + for prefix, model in _LEGACY_MODEL_PREFIXES: + if lower.startswith(prefix): + return model + return None + + +def _classify_file(ext: str) -> GalleryAssetType | None: + """Return 'image' or 'video' based on extension, or None if unsupported.""" + if ext in IMAGE_EXTENSIONS: + return "image" + if ext in VIDEO_EXTENSIONS: + return "video" + return None + + +def _asset_id(filepath: Path) -> str: + """Produce a stable, short identifier from the file path.""" + return hashlib.sha256(str(filepath).encode()).hexdigest()[:16] + + +class GalleryHandler: + """Scans an output directory for generated image/video files.""" + + def __init__(self, outputs_dir: Path) -> None: + self._outputs_dir = outputs_dir + + def list_local_assets( + self, + page: int = 1, + per_page: int = 50, + asset_type: str = "all", + ) -> GalleryListResponse: + """List generated assets with pagination and optional type filtering.""" + if page < 1: + page = 1 + if per_page < 1: + per_page = 50 + + allowed_exts: frozenset[str] + if asset_type == "image": + allowed_exts = IMAGE_EXTENSIONS + elif asset_type == "video": + allowed_exts = VIDEO_EXTENSIONS + else: + allowed_exts = ALL_EXTENSIONS + + assets: list[GalleryAsset] = [] + + if self._outputs_dir.is_dir(): + for entry in self._outputs_dir.iterdir(): + if not entry.is_file(): + continue + ext = entry.suffix.lower() + if ext not in allowed_exts: + continue + + file_type = _classify_file(ext) + if file_type is None: + continue + + stat = entry.stat() + # Use st_ctime_ns for cross-platform compatibility (st_ctime + # property is deprecated on Python 3.12+). + created_at_secs = stat.st_ctime_ns / 1_000_000_000 + assets.append( + GalleryAsset( + id=_asset_id(entry), + filename=entry.name, + path=str(entry), + url=f"/api/gallery/local/file/{entry.name}", + type=file_type, + size_bytes=stat.st_size, + created_at=str(created_at_secs), + model_name=_parse_model_name(entry.name), + ) + ) + + # Sort by created_at descending (newest first). + assets.sort(key=lambda a: float(a.created_at), reverse=True) + + total = len(assets) + total_pages = max(1, math.ceil(total / per_page)) + start = (page - 1) * per_page + end = start + per_page + + return GalleryListResponse( + items=assets[start:end], + total=total, + page=page, + per_page=per_page, + total_pages=total_pages, + ) + + def delete_local_asset(self, asset_id: str) -> None: + """Delete a local asset file by its ID.""" + if not self._outputs_dir.is_dir(): + raise HTTPError(404, f"Asset {asset_id} not found") + + for entry in self._outputs_dir.iterdir(): + if not entry.is_file(): + continue + ext = entry.suffix.lower() + if ext not in ALL_EXTENSIONS: + continue + if _asset_id(entry) == asset_id: + try: + os.remove(entry) + except OSError as exc: + raise HTTPError(500, f"Failed to delete asset: {exc}") from exc + return + + raise HTTPError(404, f"Asset {asset_id} not found") diff --git a/backend/handlers/health_handler.py b/backend/handlers/health_handler.py index 11d0427d..37cce57a 100644 --- a/backend/handlers/health_handler.py +++ b/backend/handlers/health_handler.py @@ -119,6 +119,8 @@ def default_warmup(self) -> None: case _: pass + # Only preload ZIT to CPU if it's downloaded — Flux Klein (now default) + # doesn't support CPU parking, it rebuilds each time on demand. zit_models_path = self._config.model_path("zit") zit_exists = zit_models_path.exists() and any(zit_models_path.iterdir()) if zit_exists: diff --git a/backend/handlers/ic_lora_handler.py b/backend/handlers/ic_lora_handler.py index 43559b6e..2d9f1a0e 100644 --- a/backend/handlers/ic_lora_handler.py +++ b/backend/handlers/ic_lora_handler.py @@ -5,10 +5,10 @@ import base64 import logging import uuid -from datetime import datetime from pathlib import Path from threading import RLock +from server_utils.output_naming import make_output_path from api_types import ( IcLoraDownloadRequest, IcLoraDownloadResponse, @@ -160,7 +160,7 @@ def generate(self, req: IcLoraGenerateRequest) -> IcLoraGenerateResponse: height = round(req.height / 64) * 64 width = round(req.width / 64) * 64 - output_path = self._outputs_dir / f"ic_lora_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}.mp4" + output_path = make_output_path(self._outputs_dir, model="ic-lora", prompt=req.prompt, ext="mp4") ic_state.pipeline.generate( prompt=req.prompt, diff --git a/backend/handlers/image_generation_handler.py b/backend/handlers/image_generation_handler.py index 1e95e9ab..95f99932 100644 --- a/backend/handlers/image_generation_handler.py +++ b/backend/handlers/image_generation_handler.py @@ -5,17 +5,19 @@ import logging import time import uuid -from datetime import datetime from pathlib import Path from threading import RLock from typing import TYPE_CHECKING +from PIL import Image + from _routes._errors import HTTPError from api_types import GenerateImageRequest, GenerateImageResponse from handlers.base import StateHandlerBase from handlers.generation_handler import GenerationHandler from handlers.pipelines_handler import PipelinesHandler -from services.interfaces import ZitAPIClient +from server_utils.output_naming import make_output_path +from services.interfaces import ImageAPIClient from state.app_state_types import AppState if TYPE_CHECKING: @@ -33,14 +35,14 @@ def __init__( pipelines_handler: PipelinesHandler, outputs_dir: Path, config: RuntimeConfig, - zit_api_client: ZitAPIClient, + image_api_client: ImageAPIClient, ) -> None: super().__init__(state, lock) self._generation = generation_handler self._pipelines = pipelines_handler self._outputs_dir = outputs_dir self._config = config - self._zit_api_client = zit_api_client + self._image_api_client = image_api_client def generate(self, req: GenerateImageRequest) -> GenerateImageResponse: if self._generation.is_generation_running(): @@ -69,7 +71,8 @@ def generate(self, req: GenerateImageRequest) -> GenerateImageResponse: ) try: - self._pipelines.load_zit_to_gpu() + image_model = settings.image_model + self._pipelines.load_image_model_to_gpu(image_model) self._generation.start_generation(generation_id) output_paths = self.generate_image( prompt=req.prompt, @@ -78,6 +81,11 @@ def generate(self, req: GenerateImageRequest) -> GenerateImageResponse: num_inference_steps=req.numSteps, seed=seed, num_images=num_images, + lora_path=req.loraPath, + lora_weight=req.loraWeight, + source_image_path=req.sourceImagePath, + strength=req.strength, + image_model=image_model, ) self._generation.complete_generation(output_paths) return GenerateImageResponse(status="complete", image_paths=output_paths) @@ -96,19 +104,52 @@ def generate_image( num_inference_steps: int, seed: int | None, num_images: int, + lora_path: str | None = None, + lora_weight: float = 1.0, + source_image_path: str | None = None, + strength: float = 0.65, + image_model: str = "flux-klein-9b", ) -> list[str]: if self._generation.is_generation_cancelled(): raise RuntimeError("Generation was cancelled") - self._generation.update_progress("loading_model", 5, 0, num_inference_steps) - zit = self._pipelines.load_zit_to_gpu() + self._generation.update_progress("preparing_gpu", 3, 0, num_inference_steps) + pipeline = self._pipelines.load_image_model_to_gpu( + model_name=image_model, + on_phase=lambda phase: self._generation.update_progress(phase, 5, 0, num_inference_steps), + ) + + if lora_path: + logger.info("Loading LoRA: %s (weight=%.2f)", lora_path, lora_weight) + self._generation.update_progress("loading_lora", 10, 0, num_inference_steps) + pipeline.load_lora(lora_path, weight=lora_weight) + else: + pipeline.unload_lora() + + # Load and prepare source image for img2img + source_image = None + if source_image_path: + self._generation.update_progress("encoding_image", 12, 0, num_inference_steps) + source_image = Image.open(source_image_path).convert("RGB") + width = (source_image.width // 16) * 16 + height = (source_image.height // 16) * 16 + source_image = source_image.resize((width, height), Image.Resampling.LANCZOS) + self._generation.update_progress("inference", 15, 0, num_inference_steps) if seed is None: seed = int(time.time()) % 2147483647 + is_flux = "flux" in image_model + is_edit = source_image is not None + if is_flux: + model_label = "flux-klein-edit" if is_edit else "flux-klein" + else: + model_label = "zit-edit" if is_edit else "zit" + + # FLUX Klein uses guidance_scale=4.0; ZIT uses 0.0 + guidance = 4.0 if is_flux else 0.0 outputs: list[str] = [] - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") for i in range(num_images): if self._generation.is_generation_cancelled(): @@ -117,16 +158,28 @@ def generate_image( progress = 15 + int((i / num_images) * 80) self._generation.update_progress("inference", progress, i, num_images) - result = zit.generate( - prompt=prompt, - height=height, - width=width, - guidance_scale=0.0, - num_inference_steps=num_inference_steps, - seed=seed + i, - ) + if source_image is not None: + result = pipeline.img2img( + prompt=prompt, + image=source_image, + strength=strength, + height=height, + width=width, + guidance_scale=guidance, + num_inference_steps=num_inference_steps, + seed=seed + i, + ) + else: + result = pipeline.generate( + prompt=prompt, + height=height, + width=width, + guidance_scale=guidance, + num_inference_steps=num_inference_steps, + seed=seed + i, + ) - output_path = self._outputs_dir / f"zit_image_{timestamp}_{uuid.uuid4().hex[:8]}.png" + output_path = make_output_path(self._outputs_dir, model=model_label, prompt=prompt, ext="png") result.images[0].save(str(output_path)) outputs.append(str(output_path)) @@ -148,15 +201,14 @@ def _generate_via_api( ) -> GenerateImageResponse: generation_id = uuid.uuid4().hex[:8] output_paths: list[Path] = [] - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") settings = self.state.app_settings.model_copy(deep=True) try: self._generation.start_api_generation(generation_id) self._generation.update_progress("validating_request", 5, None, None) - if not settings.fal_api_key.strip(): - raise HTTPError(500, "FAL_API_KEY_NOT_CONFIGURED") + if not settings.replicate_api_key.strip(): + raise HTTPError(500, "REPLICATE_API_KEY_NOT_CONFIGURED") for idx in range(num_images): if self._generation.is_generation_cancelled(): @@ -164,8 +216,9 @@ def _generate_via_api( inference_progress = 15 + int((idx / num_images) * 60) self._generation.update_progress("inference", inference_progress, None, None) - image_bytes = self._zit_api_client.generate_text_to_image( - api_key=settings.fal_api_key, + image_bytes = self._image_api_client.generate_text_to_image( + api_key=settings.replicate_api_key, + model=settings.image_model, prompt=prompt, width=width, height=height, @@ -179,7 +232,7 @@ def _generate_via_api( download_progress = 75 + int(((idx + 1) / num_images) * 20) self._generation.update_progress("downloading_output", download_progress, None, None) - output_path = self._outputs_dir / f"zit_api_image_{timestamp}_{uuid.uuid4().hex[:8]}.png" + output_path = make_output_path(self._outputs_dir, model=settings.image_model, prompt=prompt, ext="png") output_path.write_bytes(image_bytes) output_paths.append(output_path) diff --git a/backend/handlers/job_executors.py b/backend/handlers/job_executors.py new file mode 100644 index 00000000..b2ec957f --- /dev/null +++ b/backend/handlers/job_executors.py @@ -0,0 +1,218 @@ +"""Job executors that bridge the queue system to existing generation handlers.""" + +from __future__ import annotations + +import logging +import threading +from pathlib import Path +from typing import TYPE_CHECKING + +from api_types import GenerateImageRequest, GenerateVideoRequest + +if TYPE_CHECKING: + from app_handler import AppHandler + from state.job_queue import QueueJob + +logger = logging.getLogger(__name__) + + +class _ProgressSyncer: + """Copies progress from GenerationHandler to the job queue in a background thread.""" + + def __init__(self, handler: AppHandler, job_id: str) -> None: + self._handler = handler + self._job_id = job_id + self._stop = threading.Event() + self._thread = threading.Thread(target=self._loop, daemon=True) + + def start(self) -> None: + self._thread.start() + + def stop(self) -> None: + self._stop.set() + self._thread.join(timeout=2) + + def _loop(self) -> None: + while not self._stop.is_set(): + try: + prog = self._handler.generation.get_generation_progress() + if prog.phase and prog.phase not in ("", "idle"): + self._handler.job_queue.update_job( + self._job_id, + phase=prog.phase, + progress=prog.progress, + ) + except Exception: + pass + self._stop.wait(0.3) + + +class GpuJobExecutor: + """Executes GPU-slot jobs using existing generation handlers.""" + + def __init__(self, handler: AppHandler) -> None: + self._handler = handler + + def execute(self, job: QueueJob) -> list[str]: + syncer = _ProgressSyncer(self._handler, job.id) + syncer.start() + try: + if job.type == "image": + result = self._execute_image(job) + elif job.type == "video": + result = self._execute_video(job) + elif job.type == "long_video": + result = self._execute_long_video(job) + else: + raise ValueError(f"Unknown job type: {job.type}") + self._try_upload_to_r2(job, result) + return result + finally: + syncer.stop() + + def _try_upload_to_r2(self, job: QueueJob, result_paths: list[str]) -> None: + """Upload results to R2 if configured.""" + settings = self._handler.state.app_settings + if not settings.auto_upload_to_r2: + return + if not (settings.r2_access_key_id and settings.r2_endpoint): + return + + from services.r2_client.r2_client_impl import R2ClientImpl + + client = R2ClientImpl( + access_key_id=settings.r2_access_key_id, + secret_access_key=settings.r2_secret_access_key, + endpoint=settings.r2_endpoint, + bucket=settings.r2_bucket, + public_url=settings.r2_public_url, + ) + + for path in result_paths: + try: + ext = Path(path).suffix + content_type = "video/mp4" if ext == ".mp4" else "image/png" + remote_key = f"videos/{job.id}{ext}" + client.upload_file(local_path=path, remote_key=remote_key, content_type=content_type) + except Exception as exc: + logger.warning("R2 upload failed for %s: %s", path, exc) + + def _execute_image(self, job: QueueJob) -> list[str]: + params = job.params + req = GenerateImageRequest( + prompt=str(params.get("prompt", "")), + width=int(params.get("width", 1024)), + height=int(params.get("height", 1024)), + numSteps=int(params.get("numSteps", 4)), + numImages=int(params.get("numImages", 1)), + loraPath=str(params.get("loraPath")) if params.get("loraPath") else None, + loraWeight=float(params.get("loraWeight", 1.0)), + sourceImagePath=str(params.get("sourceImagePath")) if params.get("sourceImagePath") else None, + strength=float(params.get("strength", 0.65)), + ) + result = self._handler.image_generation.generate(req) + if result.status == "cancelled": + raise RuntimeError("Generation was cancelled") + return result.image_paths or [] + + def _execute_video(self, job: QueueJob) -> list[str]: + params = job.params + req = GenerateVideoRequest( + prompt=str(params.get("prompt", "")), + resolution=str(params.get("resolution", "512p")), + model=job.model, + cameraMotion=str(params.get("cameraMotion", "none")), # type: ignore[arg-type] + duration=str(params.get("duration", "2")), + fps=str(params.get("fps", "24")), + audio=str(params.get("audio", "false")), + imagePath=str(params.get("imagePath")) if params.get("imagePath") else None, + audioPath=str(params.get("audioPath")) if params.get("audioPath") else None, + lastFramePath=str(params.get("lastFramePath")) if params.get("lastFramePath") else None, + aspectRatio=str(params.get("aspectRatio", "16:9")), # type: ignore[arg-type] + loraPath=str(params.get("loraPath")) if params.get("loraPath") else None, + loraWeight=float(params.get("loraWeight", 1.0)), + ) + result = self._handler.video_generation.generate(req) + if result.status == "cancelled": + raise RuntimeError("Generation was cancelled") + if result.video_path is None: + return [] + return [result.video_path] + + def _execute_long_video(self, job: QueueJob) -> list[str]: + params = job.params + video_path = self._handler.video_generation.generate_long_video( + prompt=str(params.get("prompt", "")), + image_path=str(params.get("imagePath", "")), + target_duration=int(params.get("targetDuration", 20)), + resolution=str(params.get("resolution", "512p")), + aspect_ratio=str(params.get("aspectRatio", "16:9")), + fps=int(params.get("fps", 24)), + segment_duration=int(params.get("segmentDuration", 4)), + camera_motion=str(params.get("cameraMotion", "none")), # type: ignore[arg-type] + lora_path=str(params.get("loraPath")) if params.get("loraPath") else None, + lora_weight=float(params.get("loraWeight", 1.0)), + ) + return [video_path] + + +class ApiJobExecutor: + """Executes API-slot jobs using existing generation handlers.""" + + def __init__(self, handler: AppHandler) -> None: + self._handler = handler + + def execute(self, job: QueueJob) -> list[str]: + syncer = _ProgressSyncer(self._handler, job.id) + syncer.start() + try: + if job.type == "image": + return self._execute_image(job) + elif job.type == "video": + return self._execute_video(job) + else: + raise ValueError(f"Unknown job type: {job.type}") + finally: + syncer.stop() + + def _execute_image(self, job: QueueJob) -> list[str]: + params = job.params + req = GenerateImageRequest( + prompt=str(params.get("prompt", "")), + width=int(params.get("width", 1024)), + height=int(params.get("height", 1024)), + numSteps=int(params.get("numSteps", 4)), + numImages=int(params.get("numImages", 1)), + loraPath=str(params.get("loraPath")) if params.get("loraPath") else None, + loraWeight=float(params.get("loraWeight", 1.0)), + sourceImagePath=str(params.get("sourceImagePath")) if params.get("sourceImagePath") else None, + strength=float(params.get("strength", 0.65)), + ) + result = self._handler.image_generation.generate(req) + if result.status == "cancelled": + raise RuntimeError("Generation was cancelled") + return result.image_paths or [] + + def _execute_video(self, job: QueueJob) -> list[str]: + params = job.params + req = GenerateVideoRequest( + prompt=str(params.get("prompt", "")), + resolution=str(params.get("resolution", "512p")), + model=job.model, + cameraMotion=str(params.get("cameraMotion", "none")), # type: ignore[arg-type] + duration=str(params.get("duration", "2")), + fps=str(params.get("fps", "24")), + audio=str(params.get("audio", "false")), + imagePath=str(params.get("imagePath")) if params.get("imagePath") else None, + audioPath=str(params.get("audioPath")) if params.get("audioPath") else None, + lastFramePath=str(params.get("lastFramePath")) if params.get("lastFramePath") else None, + aspectRatio=str(params.get("aspectRatio", "16:9")), # type: ignore[arg-type] + loraPath=str(params.get("loraPath")) if params.get("loraPath") else None, + loraWeight=float(params.get("loraWeight", 1.0)), + ) + result = self._handler.video_generation.generate(req) + if result.status == "cancelled": + raise RuntimeError("Generation was cancelled") + if result.video_path is None: + return [] + return [result.video_path] diff --git a/backend/handlers/library_handler.py b/backend/handlers/library_handler.py new file mode 100644 index 00000000..16a93490 --- /dev/null +++ b/backend/handlers/library_handler.py @@ -0,0 +1,118 @@ +"""Library handler for characters, styles, and references.""" + +from __future__ import annotations + +from _routes._errors import HTTPError +from state.library_store import Character, LibraryStore, Reference, ReferenceCategory, Style + + +class LibraryHandler: + """Business logic for the local library (characters, styles, references).""" + + def __init__(self, store: LibraryStore) -> None: + self._store = store + + # ------------------------------------------------------------------ + # Characters + # ------------------------------------------------------------------ + + def list_characters(self) -> list[Character]: + return self._store.list_characters() + + def create_character( + self, + *, + name: str, + role: str, + description: str, + reference_image_paths: list[str] | None = None, + ) -> Character: + if not name.strip(): + raise HTTPError(400, "Character name must not be empty") + return self._store.create_character( + name=name, + role=role, + description=description, + reference_image_paths=reference_image_paths, + ) + + def update_character( + self, + character_id: str, + *, + name: str | None = None, + role: str | None = None, + description: str | None = None, + reference_image_paths: list[str] | None = None, + ) -> Character: + if name is not None and not name.strip(): + raise HTTPError(400, "Character name must not be empty") + character = self._store.update_character( + character_id, + name=name, + role=role, + description=description, + reference_image_paths=reference_image_paths, + ) + if character is None: + raise HTTPError(404, f"Character {character_id} not found") + return character + + def delete_character(self, character_id: str) -> None: + deleted = self._store.delete_character(character_id) + if not deleted: + raise HTTPError(404, f"Character {character_id} not found") + + # ------------------------------------------------------------------ + # Styles + # ------------------------------------------------------------------ + + def list_styles(self) -> list[Style]: + return self._store.list_styles() + + def create_style( + self, + *, + name: str, + description: str, + reference_image_path: str = "", + ) -> Style: + if not name.strip(): + raise HTTPError(400, "Style name must not be empty") + return self._store.create_style( + name=name, + description=description, + reference_image_path=reference_image_path, + ) + + def delete_style(self, style_id: str) -> None: + deleted = self._store.delete_style(style_id) + if not deleted: + raise HTTPError(404, f"Style {style_id} not found") + + # ------------------------------------------------------------------ + # References + # ------------------------------------------------------------------ + + def list_references(self, category: ReferenceCategory | None = None) -> list[Reference]: + return self._store.list_references(category) + + def create_reference( + self, + *, + name: str, + category: ReferenceCategory, + image_path: str = "", + ) -> Reference: + if not name.strip(): + raise HTTPError(400, "Reference name must not be empty") + return self._store.create_reference( + name=name, + category=category, + image_path=image_path, + ) + + def delete_reference(self, reference_id: str) -> None: + deleted = self._store.delete_reference(reference_id) + if not deleted: + raise HTTPError(404, f"Reference {reference_id} not found") diff --git a/backend/handlers/lora_handler.py b/backend/handlers/lora_handler.py new file mode 100644 index 00000000..2c58d8ab --- /dev/null +++ b/backend/handlers/lora_handler.py @@ -0,0 +1,273 @@ +"""Handler for LoRA library operations — search CivitAI, download, manage local catalog.""" + +from __future__ import annotations + +import hashlib +import logging +import shutil +from dataclasses import asdict +from pathlib import Path +from typing import Any + +import requests + +from state.lora_library import LoraEntry, LoraLibraryStore + +_logger = logging.getLogger(__name__) + +CIVITAI_API_BASE = "https://civitai.com/api/v1" + + +class LoraHandler: + """Search CivitAI, download LoRAs, and manage the local catalog.""" + + def __init__(self, store: LoraLibraryStore, civitai_api_key: str = "") -> None: + self._store = store + self._civitai_api_key = civitai_api_key + + def set_api_key(self, key: str) -> None: + self._civitai_api_key = key + + # ── CivitAI Search ────────────────────────────────────────────── + + def search_civitai( + self, + query: str = "", + base_model: str = "", + sort: str = "Most Downloaded", + limit: int = 20, + page: int = 1, + nsfw: bool = False, + ) -> dict[str, Any]: + """Search CivitAI for LORA models. Returns raw API response.""" + params: dict[str, str | int | bool] = { + "types": "LORA", + "limit": limit, + "page": page, + "sort": sort, + "nsfw": nsfw, + } + if query: + params["query"] = query + if base_model: + params["baseModels"] = base_model + + headers: dict[str, str] = {} + if self._civitai_api_key: + headers["Authorization"] = f"Bearer {self._civitai_api_key}" + + resp = requests.get( + f"{CIVITAI_API_BASE}/models", + params=params, # type: ignore[arg-type] + headers=headers, + timeout=15, + ) + resp.raise_for_status() + data: dict[str, Any] = resp.json() + + # Normalize to a cleaner format for the frontend + items: list[dict[str, Any]] = [] + for model in data.get("items", []): + versions = model.get("modelVersions", []) + if not versions: + continue + latest = versions[0] + files = latest.get("files", []) + safetensors_file = next( + (f for f in files if f.get("name", "").endswith(".safetensors")), + files[0] if files else None, + ) + + # Get preview image + images = latest.get("images", []) + thumbnail = images[0].get("url", "") if images else "" + + # Get trigger words + trigger_words = latest.get("trainedWords", []) + + items.append({ + "civitaiModelId": model.get("id"), + "civitaiVersionId": latest.get("id"), + "name": model.get("name", "Unknown"), + "description": (model.get("description") or "")[:200], + "thumbnailUrl": thumbnail, + "triggerPhrase": ", ".join(trigger_words) if trigger_words else "", + "baseModel": latest.get("baseModel", ""), + "downloadUrl": safetensors_file.get("downloadUrl", "") if safetensors_file else "", + "fileSizeBytes": safetensors_file.get("sizeKB", 0) * 1024 if safetensors_file else 0, + "fileName": safetensors_file.get("name", "") if safetensors_file else "", + "stats": { + "downloadCount": model.get("stats", {}).get("downloadCount", 0), + "favoriteCount": model.get("stats", {}).get("favoriteCount", 0), + "thumbsUpCount": model.get("stats", {}).get("thumbsUpCount", 0), + "rating": model.get("stats", {}).get("rating", 0), + }, + # Check if already downloaded + "isDownloaded": self._is_downloaded(model.get("id"), latest.get("id")), + }) + + return { + "items": items, + "metadata": data.get("metadata", {}), + } + + def _is_downloaded(self, model_id: int | None, version_id: int | None) -> bool: + if model_id is None: + return False + for entry in self._store.list_all(): + if entry.civitai_model_id == model_id: + if version_id is None or entry.civitai_version_id == version_id: + return True + return False + + # ── Download ──────────────────────────────────────────────────── + + def download_lora( + self, + download_url: str, + file_name: str, + name: str, + thumbnail_url: str = "", + trigger_phrase: str = "", + base_model: str = "", + civitai_model_id: int | None = None, + civitai_version_id: int | None = None, + description: str = "", + on_progress: Any = None, + ) -> LoraEntry: + """Download a LoRA file and add it to the catalog.""" + loras_dir = self._store.loras_dir + dest = loras_dir / file_name + + headers: dict[str, str] = {} + if self._civitai_api_key: + headers["Authorization"] = f"Bearer {self._civitai_api_key}" + + _logger.info("Downloading LoRA %s to %s", name, dest) + resp = requests.get(download_url, headers=headers, stream=True, timeout=30) + resp.raise_for_status() + + total = int(resp.headers.get("content-length", 0)) + downloaded = 0 + + with open(dest, "wb") as f: + for chunk in resp.iter_content(chunk_size=1024 * 1024): + f.write(chunk) + downloaded += len(chunk) + if on_progress and total > 0: + on_progress(downloaded, total) + + # Download thumbnail locally + local_thumb = "" + if thumbnail_url: + try: + local_thumb = self._download_thumbnail(thumbnail_url, file_name) + except Exception: + _logger.warning("Failed to download thumbnail for %s", name, exc_info=True) + local_thumb = thumbnail_url # Fall back to remote URL + + lora_id = hashlib.sha256(f"{civitai_model_id}:{civitai_version_id}:{file_name}".encode()).hexdigest()[:16] + + entry = LoraEntry( + id=lora_id, + name=name, + file_path=str(dest), + file_size_bytes=dest.stat().st_size, + thumbnail_url=local_thumb, + trigger_phrase=trigger_phrase, + base_model=base_model, + civitai_model_id=civitai_model_id, + civitai_version_id=civitai_version_id, + description=description, + ) + self._store.add(entry) + _logger.info("LoRA %s downloaded and cataloged (id=%s)", name, lora_id) + return entry + + def _download_thumbnail(self, url: str, lora_filename: str) -> str: + """Download thumbnail image to loras/thumbnails/ and return local path.""" + thumb_dir = self._store.loras_dir / "thumbnails" + thumb_dir.mkdir(exist_ok=True) + + stem = Path(lora_filename).stem + # Detect extension from URL + ext = ".jpg" + if ".png" in url.lower(): + ext = ".png" + elif ".webp" in url.lower(): + ext = ".webp" + + thumb_path = thumb_dir / f"{stem}{ext}" + resp = requests.get(url, timeout=15) + resp.raise_for_status() + thumb_path.write_bytes(resp.content) + return str(thumb_path) + + # ── Local Library ─────────────────────────────────────────────── + + def get_entry(self, lora_id: str) -> dict[str, Any] | None: + entry = self._store.get(lora_id) + if entry is None: + return None + return asdict(entry) + + def list_library(self) -> list[dict[str, Any]]: + return [asdict(e) for e in self._store.list_all()] + + def delete_lora(self, lora_id: str) -> bool: + entry = self._store.get(lora_id) + if entry is None: + return False + + # Delete the file + lora_path = Path(entry.file_path) + if lora_path.exists(): + lora_path.unlink() + + # Delete thumbnail if local + if entry.thumbnail_url and Path(entry.thumbnail_url).exists(): + try: + Path(entry.thumbnail_url).unlink() + except Exception: + pass + + return self._store.remove(lora_id) + + def import_local_lora( + self, + file_path: str, + name: str = "", + trigger_phrase: str = "", + thumbnail_path: str = "", + ) -> LoraEntry: + """Import a LoRA from the local filesystem into the library.""" + src = Path(file_path) + if not src.exists(): + raise FileNotFoundError(f"LoRA file not found: {file_path}") + + # Copy to loras dir if not already there + dest = self._store.loras_dir / src.name + if dest != src: + shutil.copy2(src, dest) + + lora_id = hashlib.sha256(f"local:{src.name}".encode()).hexdigest()[:16] + + # Copy thumbnail if provided + local_thumb = "" + if thumbnail_path and Path(thumbnail_path).exists(): + thumb_dir = self._store.loras_dir / "thumbnails" + thumb_dir.mkdir(exist_ok=True) + thumb_dest = thumb_dir / f"{src.stem}{Path(thumbnail_path).suffix}" + shutil.copy2(thumbnail_path, thumb_dest) + local_thumb = str(thumb_dest) + + entry = LoraEntry( + id=lora_id, + name=name or src.stem, + file_path=str(dest), + file_size_bytes=dest.stat().st_size, + thumbnail_url=local_thumb, + trigger_phrase=trigger_phrase, + ) + self._store.add(entry) + return entry diff --git a/backend/handlers/models_handler.py b/backend/handlers/models_handler.py index 12244037..4d9393bf 100644 --- a/backend/handlers/models_handler.py +++ b/backend/handlers/models_handler.py @@ -6,13 +6,16 @@ from threading import RLock from typing import TYPE_CHECKING -from api_types import ModelFileStatus, ModelInfo, ModelsStatusResponse, TextEncoderStatus +from api_types import ModelFileStatus, ModelInfo, ModelsStatusResponse, TextEncoderStatus, VideoModelGuideResponse, VideoModelScanResponse from handlers.base import StateHandlerBase, with_state_lock from runtime_config.model_download_specs import MODEL_FILE_ORDER, resolve_required_model_types +from services.model_scanner.model_guide_data import DISTILLED_LORA_INFO, MODEL_FORMATS, recommend_format from state.app_state_types import AppState, AvailableFiles if TYPE_CHECKING: from runtime_config.runtime_config import RuntimeConfig + from services.gpu_info.gpu_info import GpuInfo + from services.model_scanner.model_scanner import ModelScanner class ModelsHandler(StateHandlerBase): @@ -21,9 +24,13 @@ def __init__( state: AppState, lock: RLock, config: RuntimeConfig, + model_scanner: ModelScanner, + gpu_info_service: GpuInfo, ) -> None: super().__init__(state, lock) self._config = config + self._model_scanner = model_scanner + self._gpu_info = gpu_info_service @staticmethod def _path_size(path: Path, is_folder: bool) -> int: @@ -107,6 +114,13 @@ def get_models_status(self, has_api_key: bool | None = None) -> ModelsStatusResp if model_type == "text_encoder": description += " (optional with API key)" if has_api_key else "" optional_reason = "Uses LTX API for text encoding" if has_api_key else None + if model_type == "text_encoder_abliterated": + description += " (optional)" + optional_reason = ( + "Uncensored text encoder variant. Removes safety alignment from Gemma-3 " + "so prompts are encoded without content filtering, giving more flexible " + "and faithful prompt interpretation for video generation." + ) models.append( ModelFileStatus( @@ -135,3 +149,70 @@ def get_models_status(self, has_api_key: bool | None = None) -> ModelsStatusResp text_encoder_status=self.get_text_encoder_status(), use_local_text_encoder=settings.use_local_text_encoder, ) + + def _get_video_models_dir(self) -> Path: + """Return the custom video model path if set, else the default models dir.""" + custom = self.state.app_settings.custom_video_model_path + if custom: + return Path(custom) + return self._config.models_dir + + def scan_video_models(self) -> VideoModelScanResponse: + folder = self._get_video_models_dir() + models = self._model_scanner.scan_video_models(folder) + distilled_lora_found = self._check_distilled_lora(folder) + return VideoModelScanResponse( + models=models, + distilled_lora_found=distilled_lora_found, + ) + + @with_state_lock + def select_video_model(self, model: str) -> None: + from _routes._errors import HTTPError + from state.app_state_types import GenerationRunning, GpuSlot + + match self.state.gpu_slot: + case GpuSlot(generation=GenerationRunning()): + raise HTTPError(409, "Cannot change model while generation is running") + case _: + pass + + folder = self._get_video_models_dir() + model_path = folder / model + if not model_path.exists(): + raise HTTPError(400, f"Model file not found: {model}") + + self.state.app_settings.selected_video_model = model + + def video_model_guide(self) -> VideoModelGuideResponse: + vram_gb: int | None = None + gpu_name: str | None = None + try: + vram_gb = self._gpu_info.get_vram_total_gb() + gpu_name = self._gpu_info.get_device_name() + except Exception: + pass + + return VideoModelGuideResponse( + recommended_format=recommend_format(vram_gb), + formats=MODEL_FORMATS, + distilled_lora=DISTILLED_LORA_INFO, + vram_gb=vram_gb, + gpu_name=gpu_name, + ) + + @staticmethod + def _check_distilled_lora(folder: Path) -> bool: + if not folder.exists(): + return False + search_dirs = [folder, folder / "loras"] + for search_dir in search_dirs: + if not search_dir.exists(): + continue + try: + for entry in search_dir.iterdir(): + if entry.is_file() and "distill" in entry.name.lower(): + return True + except OSError: + continue + return False diff --git a/backend/handlers/pipelines_handler.py b/backend/handlers/pipelines_handler.py index 21b50059..51fe5f2b 100644 --- a/backend/handlers/pipelines_handler.py +++ b/backend/handlers/pipelines_handler.py @@ -5,6 +5,7 @@ import logging from pathlib import Path from threading import RLock +from collections.abc import Callable from typing import TYPE_CHECKING import torch @@ -20,6 +21,8 @@ RetakePipeline, VideoPipelineModelType, ) +from services.gpu_optimizations.ffn_chunking import patch_ffn_chunking +from services.gpu_optimizations.tea_cache import install_tea_cache_patch from services.services_utils import device_supports_fp8, get_device_type from state.app_state_types import ( A2VPipelineState, @@ -47,7 +50,10 @@ def __init__( text_handler: TextHandler, gpu_cleaner: GpuCleaner, fast_video_pipeline_class: type[FastVideoPipeline], + gguf_video_pipeline_class: type[FastVideoPipeline] | None, + nf4_video_pipeline_class: type[FastVideoPipeline] | None, image_generation_pipeline_class: type[ImageGenerationPipeline], + flux_klein_pipeline_class: type[ImageGenerationPipeline] | None, ic_lora_pipeline_class: type[IcLoraPipeline], a2v_pipeline_class: type[A2VPipeline], retake_pipeline_class: type[RetakePipeline], @@ -59,7 +65,10 @@ def __init__( self._text_handler = text_handler self._gpu_cleaner = gpu_cleaner self._fast_video_pipeline_class = fast_video_pipeline_class + self._gguf_video_pipeline_class = gguf_video_pipeline_class + self._nf4_video_pipeline_class = nf4_video_pipeline_class self._image_generation_pipeline_class = image_generation_pipeline_class + self._flux_klein_pipeline_class = flux_klein_pipeline_class self._ic_lora_pipeline_class = ic_lora_pipeline_class self._a2v_pipeline_class = a2v_pipeline_class self._retake_pipeline_class = retake_pipeline_class @@ -67,6 +76,8 @@ def __init__( self._outputs_dir = outputs_dir self._device = device self._runtime_device = get_device_type(device) + # Track which image model is currently loaded (zit or flux_klein) + self._loaded_image_model: str | None = None def _ensure_no_running_generation(self) -> None: match self.state.gpu_slot: @@ -117,25 +128,76 @@ def _compile_if_enabled(self, state: VideoPipelineState) -> VideoPipelineState: logger.warning("Failed to compile transformer: %s", exc, exc_info=True) return state - def _create_video_pipeline(self, model_type: VideoPipelineModelType) -> VideoPipelineState: + def _create_video_pipeline( + self, + model_type: VideoPipelineModelType, + lora_path: str | None = None, + lora_weight: float = 1.0, + ) -> VideoPipelineState: gemma_root = self._text_handler.resolve_gemma_root() - checkpoint_path = str(self._config.model_path("checkpoint")) + # Determine checkpoint path and pipeline class based on selected model + selected = self.state.app_settings.selected_video_model + custom_dir = self.state.app_settings.custom_video_model_path + pipeline_class = self._fast_video_pipeline_class + + if selected: + base_dir = Path(custom_dir) if custom_dir else self._config.models_dir + model_path = base_dir / selected + checkpoint_path = str(model_path) + + # Validate model file/folder still exists on disk + if not model_path.exists(): + raise FileNotFoundError( + f"Selected model not found: {checkpoint_path}. " + "Go to Settings → Models to select a different model." + ) + + if selected.endswith(".gguf") and self._gguf_video_pipeline_class is not None: + pipeline_class = self._gguf_video_pipeline_class + elif model_path.is_dir() and self._nf4_video_pipeline_class is not None: + # NF4 models are folders + pipeline_class = self._nf4_video_pipeline_class + # else: safetensors files use default pipeline + else: + checkpoint_path = str(self._config.model_path("checkpoint")) + upsampler_path = str(self._config.model_path("upsampler")) - pipeline = self._fast_video_pipeline_class.create( + pipeline = pipeline_class.create( checkpoint_path, gemma_root, upsampler_path, self._device, + lora_path=lora_path, + lora_weight=lora_weight, ) state = VideoPipelineState( pipeline=pipeline, warmth=VideoPipelineWarmth.COLD, is_compiled=False, + lora_path=lora_path, ) - return self._compile_if_enabled(state) + state = self._compile_if_enabled(state) + + # Apply FFN chunking if enabled and torch.compile is not active + chunk_count = self.state.app_settings.ffn_chunk_count + if chunk_count > 0 and not state.is_compiled: + try: + transformer: torch.nn.Module = state.pipeline.pipeline.model_ledger.transformer() # type: ignore[union-attr] + patch_ffn_chunking(transformer, num_chunks=chunk_count) # pyright: ignore[reportUnknownArgumentType] + except AttributeError: + logger.debug("FFN chunking skipped — pipeline has no model_ledger") + + # Install TeaCache denoising loop patch + tea_threshold = self.state.app_settings.tea_cache_threshold + try: + install_tea_cache_patch(tea_threshold) + except (ImportError, AttributeError): + logger.debug("TeaCache skipped — ltx_pipelines not available") + + return state def unload_gpu_pipeline(self) -> None: with self._lock: @@ -170,13 +232,26 @@ def park_zit_on_cpu(self) -> None: self.state.cpu_slot = CpuSlot(active_pipeline=zit) self._assert_invariants() - def load_zit_to_gpu(self) -> ImageGenerationPipeline: + def load_zit_to_gpu( + self, + on_phase: "Callable[[str], None] | None" = None, + ) -> ImageGenerationPipeline: + def _report(phase: str) -> None: + if on_phase is not None: + on_phase(phase) + with self._lock: if self.state.gpu_slot is not None: active = self.state.gpu_slot.active_pipeline if not isinstance(active, (VideoPipelineState, ICLoraState, A2VPipelineState, RetakePipelineState)): return active self._ensure_no_running_generation() + # Unload the video/other pipeline from GPU before loading ZIT + _report("unloading_video_model") + self.state.gpu_slot = None + self._assert_invariants() + _report("cleaning_gpu") + self._gpu_cleaner.cleanup() zit_service: ImageGenerationPipeline | None = None @@ -188,6 +263,7 @@ def load_zit_to_gpu(self) -> ImageGenerationPipeline: case _: zit_service = None + _report("loading_image_model") if zit_service is None: zit_path = self._config.model_path("zit") if not (zit_path.exists() and any(zit_path.iterdir())): @@ -196,14 +272,73 @@ def load_zit_to_gpu(self) -> ImageGenerationPipeline: else: zit_service.to(self._runtime_device) - self._gpu_cleaner.cleanup() - with self._lock: self.state.gpu_slot = GpuSlot(active_pipeline=zit_service, generation=None) + self._loaded_image_model = "zit" self._assert_invariants() return zit_service + def load_image_model_to_gpu( + self, + model_name: str = "flux-klein-9b", + on_phase: "Callable[[str], None] | None" = None, + ) -> ImageGenerationPipeline: + """Load the requested image model to GPU. + + Supports 'zit' (Z-Image-Turbo) and 'flux_klein' (FLUX.2 Klein 9B). + If the requested model is already loaded, returns it immediately. + """ + if model_name == "flux-klein-9b" or model_name == "flux_klein": + return self._load_flux_klein_to_gpu(on_phase=on_phase) + # Default: load ZIT + return self.load_zit_to_gpu(on_phase=on_phase) + + def _load_flux_klein_to_gpu( + self, + on_phase: "Callable[[str], None] | None" = None, + ) -> ImageGenerationPipeline: + """Load FLUX.2 Klein 9B to GPU.""" + def _report(phase: str) -> None: + if on_phase is not None: + on_phase(phase) + + # If FLUX Klein is already on GPU, return it (unless pipeline was + # destroyed after previous generation to avoid Windows VAE segfault) + with self._lock: + if self.state.gpu_slot is not None and self._loaded_image_model == "flux_klein": + active = self.state.gpu_slot.active_pipeline + if not isinstance(active, (VideoPipelineState, ICLoraState, A2VPipelineState, RetakePipelineState)): + if hasattr(active, "pipeline"): + return active + # Pipeline was destroyed after last gen; fall through to reload + + # Evict whatever is on GPU + if self.state.gpu_slot is not None: + self._ensure_no_running_generation() + _report("unloading_image_model") + self.state.gpu_slot = None + self.state.cpu_slot = None # Don't cache ZIT when switching to FLUX + self._assert_invariants() + + _report("cleaning_gpu") + self._gpu_cleaner.cleanup() + + _report("loading_image_model") + if self._flux_klein_pipeline_class is None: + raise RuntimeError("FLUX.2 Klein pipeline class not configured") + flux_path = self._config.model_path("flux_klein") + if not (flux_path.exists() and any(flux_path.iterdir())): + raise RuntimeError("FLUX.2 Klein 9B model not downloaded. Please download it from the Model Status menu.") + flux_service = self._flux_klein_pipeline_class.create(str(flux_path), self._runtime_device) + + with self._lock: + self.state.gpu_slot = GpuSlot(active_pipeline=flux_service, generation=None) + self._loaded_image_model = "flux_klein" + self._assert_invariants() + + return flux_service + def preload_zit_to_cpu(self) -> ImageGenerationPipeline: with self._lock: match self.state.cpu_slot: @@ -246,21 +381,38 @@ def _evict_gpu_pipeline_for_swap(self) -> None: elif should_cleanup: self._gpu_cleaner.cleanup() - def load_gpu_pipeline(self, model_type: VideoPipelineModelType, should_warm: bool = False) -> VideoPipelineState: + def load_gpu_pipeline( + self, + model_type: VideoPipelineModelType, + should_warm: bool = False, + on_phase: Callable[[str], None] | None = None, + lora_path: str | None = None, + lora_weight: float = 1.0, + ) -> VideoPipelineState: self._install_text_patches_if_needed() + def _report(phase: str) -> None: + if on_phase is not None: + on_phase(phase) + state: VideoPipelineState | None = None with self._lock: if self._pipeline_matches_model_type(model_type): match self.state.gpu_slot: case GpuSlot(active_pipeline=VideoPipelineState() as existing_state): - state = existing_state + # Reload if the LoRA changed + if existing_state.lora_path == lora_path: + state = existing_state case _: pass if state is None: + _report("unloading_image_model") self._evict_gpu_pipeline_for_swap() - state = self._create_video_pipeline(model_type) + if lora_path: + _report("loading_lora") + _report("loading_video_model") + state = self._create_video_pipeline(model_type, lora_path=lora_path, lora_weight=lora_weight) with self._lock: self.state.gpu_slot = GpuSlot(active_pipeline=state, generation=None) self._assert_invariants() diff --git a/backend/handlers/prompt_handler.py b/backend/handlers/prompt_handler.py new file mode 100644 index 00000000..4f8d4f7b --- /dev/null +++ b/backend/handlers/prompt_handler.py @@ -0,0 +1,86 @@ +"""Prompt library and wildcard management handler.""" + +from __future__ import annotations + +import logging +from pathlib import Path +from threading import RLock +from typing import Literal + +from handlers.base import StateHandlerBase, with_state_lock +from services.wildcard_parser import WildcardDef, expand_prompt, expand_random +from state.app_state_types import AppState +from state.prompt_store import PromptStore, SavedPrompt, WildcardEntry + +logger = logging.getLogger(__name__) + + +class PromptHandler(StateHandlerBase): + """Domain handler for saved prompts and wildcard definitions.""" + + def __init__(self, state: AppState, lock: RLock, store_path: Path) -> None: + super().__init__(state, lock) + self._store = PromptStore(store_path) + + # ------------------------------------------------------------------ + # Prompts + # ------------------------------------------------------------------ + + @with_state_lock + def list_prompts( + self, + search: str | None = None, + tag: str | None = None, + sort_by: str | None = None, + ) -> list[SavedPrompt]: + return self._store.list_prompts(search=search, tag=tag, sort_by=sort_by) + + @with_state_lock + def save_prompt(self, text: str, tags: list[str], category: str) -> SavedPrompt: + return self._store.save_prompt(text, tags, category) + + @with_state_lock + def delete_prompt(self, prompt_id: str) -> bool: + return self._store.delete_prompt(prompt_id) + + @with_state_lock + def increment_usage(self, prompt_id: str) -> SavedPrompt | None: + return self._store.increment_usage(prompt_id) + + # ------------------------------------------------------------------ + # Wildcards + # ------------------------------------------------------------------ + + @with_state_lock + def list_wildcards(self) -> list[WildcardEntry]: + return self._store.list_wildcards() + + @with_state_lock + def create_wildcard(self, name: str, values: list[str]) -> WildcardEntry: + return self._store.create_wildcard(name, values) + + @with_state_lock + def update_wildcard(self, wildcard_id: str, values: list[str]) -> WildcardEntry | None: + return self._store.update_wildcard(wildcard_id, values) + + @with_state_lock + def delete_wildcard(self, wildcard_id: str) -> bool: + return self._store.delete_wildcard(wildcard_id) + + # ------------------------------------------------------------------ + # Wildcard expansion + # ------------------------------------------------------------------ + + @with_state_lock + def expand_wildcards( + self, + prompt: str, + mode: Literal["all", "random"] = "random", + count: int = 1, + ) -> list[str]: + """Expand wildcards in *prompt* using stored wildcard definitions.""" + entries = self._store.list_wildcards() + defs = [WildcardDef(name=e.name, values=e.values) for e in entries] + if mode == "all": + return expand_prompt(prompt, defs) + return expand_random(prompt, defs, count=count) diff --git a/backend/handlers/queue_worker.py b/backend/handlers/queue_worker.py new file mode 100644 index 00000000..7981596c --- /dev/null +++ b/backend/handlers/queue_worker.py @@ -0,0 +1,208 @@ +"""Background queue worker that processes jobs from the job queue.""" + +from __future__ import annotations + +import logging +import threading +from typing import Callable, Protocol + +from services.interfaces import GpuCleaner +from state.job_queue import JobQueue, QueueJob + +logger = logging.getLogger(__name__) + + +class JobExecutor(Protocol): + def execute(self, job: QueueJob) -> list[str]: + ... + + +class EnhancePromptProvider(Protocol): + def enhance_i2v_motion(self, image_path: str) -> str: + ... + + +class CreditDeductor(Protocol): + def deduct_credits( + self, generation_type: str, count: int, + metadata: dict[str, object] | None, + ) -> dict[str, object]: + ... + + +def _credit_type_for_job(job: QueueJob) -> str | None: + """Map a completed API job to its Palette credit type. Returns None for local GPU jobs.""" + model = job.model.lower() + if "seedance" in model: + return "video_seedance" + if job.type == "image": + return "image" + has_image = bool(job.params.get("imagePath")) + if has_image: + return "video_i2v" + return "video_t2v" + + +class QueueWorker: + def __init__( + self, + *, + queue: JobQueue, + gpu_executor: JobExecutor, + api_executor: JobExecutor, + gpu_cleaner: GpuCleaner | None = None, + on_batch_complete: Callable[[str, list[QueueJob]], None] | None = None, + enhance_handler: EnhancePromptProvider | None = None, + credit_deductor: CreditDeductor | None = None, + ) -> None: + self._queue = queue + self._gpu_executor = gpu_executor + self._api_executor = api_executor + self._gpu_cleaner = gpu_cleaner + self._gpu_busy = False + self._api_busy = False + self._lock = threading.Lock() + self._on_batch_complete = on_batch_complete + self._enhance_handler = enhance_handler + self._credit_deductor = credit_deductor + self._notified_batches: set[str] = set() + + def tick(self) -> None: + """Process one round: pick up available jobs for each free slot. + + Non-blocking — spawns daemon threads for each job so the tick loop + can keep checking for new jobs on other slots. + """ + # First, fail any jobs whose dependencies have errored/cancelled + self._fail_orphaned_dependents() + + # Recover from stuck busy flags — if slot is busy but no job is actually + # running for that slot, reset the flag. This handles cases where a job + # was cancelled externally while the executor thread was still working. + self._recover_stuck_slots() + + gpu_job: QueueJob | None = None + api_job: QueueJob | None = None + + with self._lock: + if not self._gpu_busy: + gpu_job = self._next_ready_job("gpu") + if gpu_job is not None: + self._gpu_busy = True + self._queue.update_job(gpu_job.id, status="running", phase="starting") + + if not self._api_busy: + api_job = self._next_ready_job("api") + if api_job is not None: + self._api_busy = True + self._queue.update_job(api_job.id, status="running", phase="starting") + + if gpu_job is not None: + t = threading.Thread(target=self._run_job, args=(gpu_job, self._gpu_executor, "gpu"), daemon=True) + t.start() + + if api_job is not None: + t = threading.Thread(target=self._run_job, args=(api_job, self._api_executor, "api"), daemon=True) + t.start() + + self._check_batch_completions() + + def _recover_stuck_slots(self) -> None: + """Reset busy flags when no job is actually running for that slot. + + This handles the case where a job is cancelled via the API while the + executor thread is mid-work. The thread eventually finishes (or the + process was restarted), but _gpu_busy / _api_busy stayed True. + """ + has_running_gpu = any( + j.status == "running" and j.slot == "gpu" for j in self._queue.all_jobs() + ) + has_running_api = any( + j.status == "running" and j.slot == "api" for j in self._queue.all_jobs() + ) + with self._lock: + if self._gpu_busy and not has_running_gpu: + logger.info("Recovering stuck GPU slot — no running GPU jobs found") + self._gpu_busy = False + if self._api_busy and not has_running_api: + logger.info("Recovering stuck API slot — no running API jobs found") + self._api_busy = False + + def _next_ready_job(self, slot: str) -> QueueJob | None: + for job in self._queue.queued_jobs_for_slot(slot): + if job.depends_on is None: + return job + dep = self._queue.get_job(job.depends_on) + if dep is None: + return job # Dependency missing, run anyway + if dep.status == "complete": + self._resolve_auto_params(job, dep) + return job + # dep still queued/running or already handled by _fail_orphaned_dependents + continue + return None + + def _fail_orphaned_dependents(self) -> None: + for job in self._queue.all_jobs(): + if job.status != "queued" or job.depends_on is None: + continue + dep = self._queue.get_job(job.depends_on) + if dep is not None and dep.status in ("error", "cancelled"): + self._queue.update_job( + job.id, + status="error", + error=f"Upstream job {dep.id} failed: {dep.error or dep.status}", + ) + + def _resolve_auto_params(self, job: QueueJob, dep: QueueJob) -> None: + for key, template in list(job.auto_params.items()): + if template == "$dep.result_paths[0]" and dep.result_paths: + job.params[key] = dep.result_paths[0] + + if job.auto_params.get("auto_prompt") == "true" and self._enhance_handler: + image_path = job.params.get("imagePath", dep.result_paths[0] if dep.result_paths else "") + if image_path: + motion_prompt = self._enhance_handler.enhance_i2v_motion(str(image_path)) + job.params["prompt"] = motion_prompt + + def _check_batch_completions(self) -> None: + seen: set[str] = set() + for job in self._queue.all_jobs(): + if job.batch_id and job.batch_id not in self._notified_batches: + seen.add(job.batch_id) + for batch_id in seen: + jobs = self._queue.jobs_for_batch(batch_id) + if all(j.status in ("complete", "error", "cancelled") for j in jobs): + self._notified_batches.add(batch_id) + if self._on_batch_complete: + self._on_batch_complete(batch_id, jobs) + + def _run_job(self, job: QueueJob, executor: JobExecutor, slot: str) -> None: + try: + result_paths = executor.execute(job) + self._queue.update_job(job.id, status="complete", progress=100, phase="complete", result_paths=result_paths) + # Deduct credits for API-slot jobs (local GPU jobs are free) + if slot == "api" and self._credit_deductor is not None: + credit_type = _credit_type_for_job(job) + if credit_type: + try: + self._credit_deductor.deduct_credits( + credit_type, 1, + {"model": job.model, "job_id": job.id}, + ) + except Exception as exc: + logger.warning("Credit deduction failed for job %s: %s", job.id, exc) + except Exception as exc: + logger.error("Job %s failed: %s", job.id, exc) + self._queue.update_job(job.id, status="error", error=str(exc)) + finally: + if slot == "gpu" and self._gpu_cleaner is not None: + try: + self._gpu_cleaner.deep_cleanup() + except Exception: + pass + with self._lock: + if slot == "gpu": + self._gpu_busy = False + else: + self._api_busy = False diff --git a/backend/handlers/receive_job_handler.py b/backend/handlers/receive_job_handler.py new file mode 100644 index 00000000..ee6e0be1 --- /dev/null +++ b/backend/handlers/receive_job_handler.py @@ -0,0 +1,78 @@ +"""Handler for receiving generation jobs from Director's Palette.""" +from __future__ import annotations + +import logging +import tempfile +from typing import TYPE_CHECKING + +from _routes._errors import HTTPError +from api_types import ReceiveJobRequest, ReceiveJobResponse + +if TYPE_CHECKING: + from services.interfaces import HTTPClient + from state.app_state_types import AppState + from state.job_queue import JobQueue + +logger = logging.getLogger(__name__) + + +class ReceiveJobHandler: + def __init__( + self, + state: AppState, + http: "HTTPClient", + job_queue: "JobQueue", + ) -> None: + self._state = state + self._http = http + self._job_queue = job_queue + + def receive_job(self, req: ReceiveJobRequest) -> ReceiveJobResponse: + api_key = self._state.app_settings.palette_api_key + if not api_key: + raise HTTPError(403, "Not connected to Director's Palette. Set a palette API key first.") + + params: dict[str, object] = { + "prompt": req.prompt, + "resolution": req.settings.resolution, + "duration": req.settings.duration, + "fps": req.settings.fps, + "aspectRatio": req.settings.aspect_ratio, + } + + if req.character_id is not None: + params["character_id"] = req.character_id + + if req.first_frame_url is not None: + local_path = self._download_remote_image(req.first_frame_url, "first_frame") + params["imagePath"] = local_path + + if req.last_frame_url is not None: + local_path = self._download_remote_image(req.last_frame_url, "last_frame") + params["lastFramePath"] = local_path + + job = self._job_queue.submit( + job_type="video", + model=req.model, + params=params, + slot="api", + ) + + return ReceiveJobResponse(id=job.id, status=job.status) + + def _download_remote_image(self, url: str, prefix: str) -> str: + try: + resp = self._http.get(url, timeout=30) + if resp.status_code != 200: + raise HTTPError(400, f"Failed to download image from {url}: HTTP {resp.status_code}") + suffix = ".png" + if ".jpg" in url or ".jpeg" in url: + suffix = ".jpg" + tmp = tempfile.NamedTemporaryFile(prefix=f"{prefix}_", suffix=suffix, delete=False) + tmp.write(resp.content) + tmp.close() + return tmp.name + except HTTPError: + raise + except Exception as exc: + raise HTTPError(400, f"Failed to download image from {url}") from exc diff --git a/backend/handlers/retake_handler.py b/backend/handlers/retake_handler.py index ebc7856e..f309af65 100644 --- a/backend/handlers/retake_handler.py +++ b/backend/handlers/retake_handler.py @@ -3,7 +3,6 @@ from __future__ import annotations import uuid -from datetime import datetime from pathlib import Path from threading import RLock import time @@ -15,6 +14,7 @@ from handlers.pipelines_handler import PipelinesHandler from handlers.text_handler import TextHandler from runtime_config.runtime_config import RuntimeConfig +from server_utils.output_naming import make_output_path from services.ltx_api_client.ltx_api_client import LTXAPIClientError from services.interfaces import LTXAPIClient from state.app_state_types import AppState @@ -103,7 +103,7 @@ def _run_api_retake( raise HTTPError(exc.status_code, exc.detail) from exc if result.video_bytes is not None: - output = self._outputs_dir / f"retake_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}.mp4" + output = make_output_path(self._outputs_dir, model="retake", prompt=prompt, ext="mp4") with open(output, "wb") as out: out.write(result.video_bytes) return RetakeResponse(status="complete", video_path=str(output)) @@ -138,7 +138,7 @@ def _run_local_retake( generation_id = uuid.uuid4().hex[:8] seed = self._resolve_seed() - output_path = self._outputs_dir / f"retake_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{generation_id}.mp4" + output_path = make_output_path(self._outputs_dir, model="retake", prompt=prompt, ext="mp4") regenerate_video, regenerate_audio = self._resolve_retake_mode(mode) try: diff --git a/backend/handlers/settings_handler.py b/backend/handlers/settings_handler.py index 27704d60..a30b19e2 100644 --- a/backend/handlers/settings_handler.py +++ b/backend/handlers/settings_handler.py @@ -63,7 +63,7 @@ def get_settings_snapshot(self) -> AppSettings: def update_settings(self, patch: UpdateSettingsRequest) -> tuple[AppSettings, AppSettings, set[str]]: patch_payload = strip_none_values(ensure_json_object(patch.model_dump(by_alias=False, exclude_unset=True))) - for key_field in ("ltx_api_key", "gemini_api_key", "fal_api_key"): + for key_field in ("ltx_api_key", "gemini_api_key", "replicate_api_key", "palette_api_key"): if key_field in patch_payload and patch_payload[key_field] == "": del patch_payload[key_field] diff --git a/backend/handlers/style_guide_handler.py b/backend/handlers/style_guide_handler.py new file mode 100644 index 00000000..7b2a2ef4 --- /dev/null +++ b/backend/handlers/style_guide_handler.py @@ -0,0 +1,53 @@ +"""Handler for style guide grid generation (3x3 style-across-subjects grid).""" +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from api_types import GenerateStyleGuideRequest, GenerateStyleGuideResponse + +if TYPE_CHECKING: + from state.job_queue import JobQueue + +logger = logging.getLogger(__name__) + +STYLE_SUBJECTS: list[str] = [ + "Portrait of a person", + "Cityscape", + "Nature landscape", + "Interior room", + "Food still life", + "Vehicle on a road", + "Animal in its habitat", + "Architecture detail", + "Abstract pattern", +] + + +class StyleGuideHandler: + def __init__(self, job_queue: "JobQueue") -> None: + self._job_queue = job_queue + + def generate(self, req: GenerateStyleGuideRequest) -> GenerateStyleGuideResponse: + job_ids: list[str] = [] + description_suffix = f": {req.style_description}" if req.style_description else "" + + for subject in STYLE_SUBJECTS: + prompt = f"{subject}, in the style of {req.style_name}{description_suffix}" + params: dict[str, object] = { + "prompt": prompt, + "width": 1024, + "height": 1024, + } + if req.reference_image_path is not None: + params["reference_image_path"] = req.reference_image_path + + job = self._job_queue.submit( + job_type="image", + model="z-image-turbo", + params=params, + slot="api", + ) + job_ids.append(job.id) + + return GenerateStyleGuideResponse(job_ids=job_ids) diff --git a/backend/handlers/sync_handler.py b/backend/handlers/sync_handler.py new file mode 100644 index 00000000..fcb84c91 --- /dev/null +++ b/backend/handlers/sync_handler.py @@ -0,0 +1,328 @@ +"""Handler for Palette sync operations.""" +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Any + +from services.http_client.http_client import HTTPClient +from services.palette_sync_client.palette_sync_client import PaletteSyncClient +from state.app_state_types import AppState +from state.lora_library import LoraEntry, LoraLibraryStore + +logger = logging.getLogger(__name__) + +# Fallback pricing (cents) when the Palette credits endpoint doesn't return it. +# Values sourced from the live Palette /api/desktop/credits/check endpoint. +_DEFAULT_PRICING: dict[str, int] = { + "video_t2v": 10, + "video_i2v": 16, + "video_seedance": 5, + "image": 6, + "image_edit": 20, + "audio": 15, + "text_enhance": 3, +} + +# Known FLUX Klein 9B LoRA weights URLs from Palette's built-in library. +# The Palette API may not include download URLs for all LoRAs, so we map +# known IDs to their weights URLs as a fallback. +_KNOWN_LORA_WEIGHTS: dict[str, tuple[str, float]] = { + "claymation-k9b": ( + "https://huuezdiitpmafkljkvui.supabase.co/storage/v1/object/public/loras/community/claymation_flux_lora_v1.safetensors", + 1.0, + ), + "inflate-k9b": ( + "https://huuezdiitpmafkljkvui.supabase.co/storage/v1/object/public/loras/community/inflate_it.safetensors", + 1.0, + ), + "disney-golden-age-k9b": ( + "https://huuezdiitpmafkljkvui.supabase.co/storage/v1/object/public/loras/community/disney_golden_age.safetensors", + 1.0, + ), + "nava-k9b": ( + "https://v3.fal.media/files/monkey/oF3DkwBOmrzohIKhCfNie_pytorch_lora_weights.safetensors", + 1.0, + ), + "dcau-k9b": ( + "https://huuezdiitpmafkljkvui.supabase.co/storage/v1/object/public/loras/community/jRB4slNlO3KYd18ROU5Up_pytorch_lora_weights_comfy_converted.safetensors", + 1.0, + ), + "cinematic-filmstill-k9b": ( + "https://huuezdiitpmafkljkvui.supabase.co/storage/v1/object/public/loras/community/cinematic_filmstill.safetensors", + 1.0, + ), + "consistency-k9b": ( + "https://pub-060813fba4064da4815db04b08604ce7.r2.dev/consistency_lora_v3.safetensors", + 0.8, + ), +} + + +class SyncHandler: + def __init__( + self, + state: AppState, + palette_sync_client: PaletteSyncClient, + http: HTTPClient, + lora_store: LoraLibraryStore, + loras_dir: Path, + ) -> None: + self._state = state + self._client = palette_sync_client + self._http = http + self._lora_store = lora_store + self._loras_dir = loras_dir + self._cached_user: dict[str, Any] | None = None + + def _try_refresh(self) -> dict[str, Any] | None: + """Attempt to refresh an expired JWT. Returns user info or None.""" + refresh_token = self._state.app_settings.palette_refresh_token + if not refresh_token: + return None + try: + result = self._client.refresh_access_token(refresh_token=refresh_token) + self._state.app_settings.palette_api_key = result["access_token"] + self._state.app_settings.palette_refresh_token = result["refresh_token"] + self._cached_user = result["user"] + return result["user"] + except Exception: + return None + + def get_status(self) -> dict[str, Any]: + api_key = self._state.app_settings.palette_api_key + if not api_key: + return {"connected": False, "user": None} + if self._cached_user is not None: + return {"connected": True, "user": self._cached_user} + try: + user = self._client.validate_connection(api_key=api_key) + self._cached_user = user + return {"connected": True, "user": user} + except Exception as exc: + # JWT might be expired — try refreshing + user = self._try_refresh() + if user is not None: + return {"connected": True, "user": user} + self._cached_user = None + return {"connected": False, "user": None, "error": str(exc)} + + def connect(self, token: str) -> dict[str, Any]: + """Store an auth token and validate it. Returns status.""" + try: + user = self._client.validate_connection(api_key=token) + except Exception as exc: + return {"connected": False, "error": str(exc)} + self._state.app_settings.palette_api_key = token + self._cached_user = user + return {"connected": True, "user": user} + + def login(self, email: str, password: str) -> dict[str, Any]: + """Sign in with email/password and store the session tokens.""" + try: + result = self._client.sign_in_with_email(email=email, password=password) + except Exception as exc: + return {"connected": False, "error": str(exc)} + self._state.app_settings.palette_api_key = result["access_token"] + self._state.app_settings.palette_refresh_token = result["refresh_token"] + self._cached_user = result["user"] + return {"connected": True, "user": result["user"]} + + def disconnect(self) -> dict[str, Any]: + """Clear the stored auth token and cached user.""" + self._state.app_settings.palette_api_key = "" + self._state.app_settings.palette_refresh_token = "" + self._cached_user = None + return {"connected": False, "user": None} + + def get_credits(self) -> dict[str, Any]: + api_key = self._state.app_settings.palette_api_key + if not api_key: + return {"connected": False, "balance_cents": None, "pricing": None} + try: + credits = self._client.get_credits(api_key=api_key) + result: dict[str, Any] = {"connected": True, **credits} + except Exception: + result = {"connected": True, "balance_cents": None, "pricing": None} + + # If the credits endpoint didn't return balance_cents, fall back + # to the check endpoint which reliably includes the balance. + if result.get("balance_cents") is None: + try: + check = self._client.check_credits( + api_key=api_key, generation_type="image", count=1, + ) + result["balance_cents"] = check.get("balance_cents") + except Exception: + pass + + # Ensure pricing is present — fall back to known defaults if the + # credits endpoint didn't provide it. + if not result.get("pricing"): + result["pricing"] = _DEFAULT_PRICING + + return result + + def check_credits(self, generation_type: str, count: int = 1) -> dict[str, Any]: + api_key = self._state.app_settings.palette_api_key + if not api_key: + return {"connected": False, "can_afford": True} + try: + return {"connected": True, **self._client.check_credits( + api_key=api_key, generation_type=generation_type, count=count, + )} + except Exception as exc: + logger.warning("Credit check failed: %s", exc) + # Fail open — don't block generation if credit check is unavailable + return {"connected": False, "can_afford": True} + + def deduct_credits( + self, generation_type: str, count: int = 1, + metadata: dict[str, Any] | None = None, + ) -> dict[str, Any]: + api_key = self._state.app_settings.palette_api_key + if not api_key: + return {"deducted": False} + try: + result = self._client.deduct_credits( + api_key=api_key, generation_type=generation_type, + count=count, metadata=metadata, + ) + return {"deducted": True, **result} + except Exception as exc: + logger.warning("Credit deduction failed: %s", exc) + return {"deducted": False, "error": str(exc)} + + def list_gallery( + self, page: int = 1, per_page: int = 50, asset_type: str = "all", + ) -> dict[str, Any]: + api_key = self._state.app_settings.palette_api_key + if not api_key: + return {"connected": False, "items": []} + try: + return { + "connected": True, + **self._client.list_gallery( + api_key=api_key, page=page, per_page=per_page, asset_type=asset_type, + ), + } + except Exception as exc: + logger.warning("Palette gallery list failed: %s", exc) + return {"connected": False, "items": [], "error": str(exc)} + + def list_characters(self) -> dict[str, Any]: + api_key = self._state.app_settings.palette_api_key + if not api_key: + return {"connected": False, "characters": []} + try: + return {"connected": True, **self._client.list_characters(api_key=api_key)} + except Exception as exc: + logger.warning("Palette characters list failed: %s", exc) + return {"connected": False, "characters": [], "error": str(exc)} + + def list_styles(self) -> dict[str, Any]: + api_key = self._state.app_settings.palette_api_key + if not api_key: + return {"connected": False, "styles": []} + try: + return {"connected": True, **self._client.list_styles(api_key=api_key)} + except Exception as exc: + logger.warning("Palette styles list failed: %s", exc) + return {"connected": False, "styles": [], "error": str(exc)} + + def list_references(self, category: str | None = None) -> dict[str, Any]: + api_key = self._state.app_settings.palette_api_key + if not api_key: + return {"connected": False, "references": []} + try: + return { + "connected": True, + **self._client.list_references(api_key=api_key, category=category), + } + except Exception as exc: + logger.warning("Palette references list failed: %s", exc) + return {"connected": False, "references": [], "error": str(exc)} + + def enhance_prompt(self, prompt: str, level: str = "2x") -> dict[str, Any]: + api_key = self._state.app_settings.palette_api_key + if not api_key: + return {"error": "Not connected to Palette"} + try: + return self._client.enhance_prompt(api_key=api_key, prompt=prompt, level=level) + except Exception as exc: + logger.warning("Palette prompt enhance failed: %s", exc) + return {"error": str(exc)} + + def sync_loras(self) -> dict[str, Any]: + """Fetch LoRA catalog from Palette and download any new LoRAs locally. + + Returns {"synced": N, "skipped": N, "failed": N}. + """ + api_key = self._state.app_settings.palette_api_key + if not api_key: + return {"connected": False, "error": "Not connected to Palette"} + + try: + data = self._client.list_loras(api_key=api_key) + except Exception as exc: + logger.warning("Palette LoRA list failed: %s", exc) + return {"connected": False, "error": str(exc)} + + palette_loras = data.get("loras", []) + existing_ids = {e.id for e in self._lora_store.list_all()} + + synced = 0 + skipped = 0 + failed = 0 + + for lora in palette_loras: + lora_id = lora.get("id", "") + catalog_id = f"palette:{lora_id}" + if catalog_id in existing_ids: + skipped += 1 + continue + + # Get download URL — either from API response or hardcoded map + weights_url = lora.get("weights_url") or lora.get("download_url") or "" + if not weights_url: + weights_url = _KNOWN_LORA_WEIGHTS.get(lora_id, ("", 1.0))[0] + if not weights_url: + logger.debug("Skipping LoRA %s — no download URL available", lora_id) + skipped += 1 + continue + + try: + self._download_and_register_lora(lora, catalog_id, weights_url) + synced += 1 + except Exception: + logger.warning("Failed to sync LoRA %s", lora_id, exc_info=True) + failed += 1 + + return {"connected": True, "synced": synced, "skipped": skipped, "failed": failed} + + def _download_and_register_lora( + self, lora: dict[str, Any], catalog_id: str, weights_url: str, + ) -> None: + """Download LoRA weights and register in local catalog.""" + lora_id = lora.get("id", "unknown") + filename = f"palette_{lora_id}.safetensors" + dest = self._loras_dir / filename + + if not dest.exists(): + logger.info("Downloading LoRA %s from %s", lora_id, weights_url) + resp = self._http.get(weights_url, timeout=300) + if resp.status_code != 200: + raise RuntimeError(f"Download failed: HTTP {resp.status_code}") + dest.parent.mkdir(parents=True, exist_ok=True) + dest.write_bytes(resp.content) + + entry = LoraEntry( + id=catalog_id, + name=f"[Palette] {lora.get('name', lora_id)}", + file_path=str(dest), + file_size_bytes=dest.stat().st_size, + thumbnail_url=lora.get("thumbnail_url", ""), + trigger_phrase=lora.get("trigger_word", ""), + base_model="flux-klein-9b", + ) + self._lora_store.add(entry) diff --git a/backend/handlers/text_handler.py b/backend/handlers/text_handler.py index 31b9d015..3146447a 100644 --- a/backend/handlers/text_handler.py +++ b/backend/handlers/text_handler.py @@ -97,6 +97,11 @@ def prepare_text_encoding(self, prompt: str, enhance_prompt: bool) -> None: def resolve_gemma_root(self) -> str | None: if not self.should_use_local_encoding(): return None + settings = self.state.app_settings.model_copy(deep=True) + if settings.use_abliterated_text_encoder: + abliterated_dir = self._config.model_path("text_encoder_abliterated") + if abliterated_dir.exists() and any(abliterated_dir.iterdir()): + return str(abliterated_dir) text_encoder_dir = self._config.model_path("text_encoder") return str(text_encoder_dir) diff --git a/backend/handlers/video_generation_handler.py b/backend/handlers/video_generation_handler.py index 3147f2ee..986bbf87 100644 --- a/backend/handlers/video_generation_handler.py +++ b/backend/handlers/video_generation_handler.py @@ -7,7 +7,6 @@ import tempfile import time import uuid -from datetime import datetime from pathlib import Path from threading import RLock from typing import TYPE_CHECKING @@ -25,10 +24,13 @@ validate_audio_file, validate_image_file, ) -from services.interfaces import LTXAPIClient +from server_utils.output_naming import make_output_path +from services.interfaces import LTXAPIClient, VideoAPIClient from state.app_state_types import AppState from state.app_settings import should_video_generate_with_ltx_api +REPLICATE_VIDEO_MODELS = {"seedance-1.5-pro"} + if TYPE_CHECKING: from runtime_config.runtime_config import RuntimeConfig @@ -63,6 +65,7 @@ def __init__( pipelines_handler: PipelinesHandler, text_handler: TextHandler, ltx_api_client: LTXAPIClient, + video_api_client: VideoAPIClient, outputs_dir: Path, config: RuntimeConfig, camera_motion_prompts: dict[str, str], @@ -73,12 +76,16 @@ def __init__( self._pipelines = pipelines_handler self._text = text_handler self._ltx_api_client = ltx_api_client + self._video_api_client = video_api_client self._outputs_dir = outputs_dir self._config = config self._camera_motion_prompts = camera_motion_prompts self._default_negative_prompt = default_negative_prompt def generate(self, req: GenerateVideoRequest) -> GenerateVideoResponse: + if req.model in REPLICATE_VIDEO_MODELS: + return self._generate_via_replicate(req) + if should_video_generate_with_ltx_api( force_api_generations=self._config.force_api_generations, settings=self.state.app_settings, @@ -126,16 +133,26 @@ def get_9_16_size(res: str) -> tuple[int, int]: image = self._prepare_image(image_path, width, height) logger.info("Image: %s -> %sx%s", image_path, width, height) + last_frame_image = None + last_frame_path = normalize_optional_path(req.lastFramePath) + if last_frame_path: + last_frame_image = self._prepare_image(last_frame_path, width, height) + logger.info("Last frame: %s -> %sx%s", last_frame_path, width, height) + generation_id = self._make_generation_id() seed = self._resolve_seed() try: - self._pipelines.load_gpu_pipeline("fast", should_warm=False) + self._pipelines.load_gpu_pipeline( + "fast", should_warm=False, + lora_path=req.loraPath, lora_weight=req.loraWeight, + ) self._generation.start_generation(generation_id) output_path = self.generate_video( prompt=req.prompt, image=image, + last_frame_image=last_frame_image, height=height, width=width, num_frames=num_frames, @@ -143,6 +160,8 @@ def get_9_16_size(res: str) -> tuple[int, int]: seed=seed, camera_motion=req.cameraMotion, negative_prompt=req.negativePrompt, + lora_path=req.loraPath, + lora_weight=req.loraWeight, ) self._generation.complete_generation(output_path) @@ -167,6 +186,9 @@ def generate_video( seed: int, camera_motion: VideoCameraMotion, negative_prompt: str, + last_frame_image: Image.Image | None = None, + lora_path: str | None = None, + lora_weight: float = 1.0, ) -> str: t_total_start = time.perf_counter() gen_mode = "i2v" if image is not None else "t2v" @@ -180,9 +202,15 @@ def generate_video( total_steps = 8 - self._generation.update_progress("loading_model", 5, 0, total_steps) + self._generation.update_progress("preparing_gpu", 3, 0, total_steps) t_load_start = time.perf_counter() - pipeline_state = self._pipelines.load_gpu_pipeline("fast", should_warm=False) + pipeline_state = self._pipelines.load_gpu_pipeline( + "fast", + should_warm=False, + on_phase=lambda phase: self._generation.update_progress(phase, 5, 0, total_steps), + lora_path=lora_path, + lora_weight=lora_weight, + ) t_load_end = time.perf_counter() logger.info("[%s] Pipeline load: %.2fs", gen_mode, t_load_end - t_load_start) @@ -192,12 +220,21 @@ def generate_video( images: list[ImageConditioningInput] = [] temp_image_path: str | None = None + temp_last_frame_path: str | None = None if image is not None: temp_image_path = tempfile.NamedTemporaryFile(suffix=".png", delete=False).name image.save(temp_image_path) - images = [ImageConditioningInput(path=temp_image_path, frame_idx=0, strength=1.0)] - - output_path = self._make_output_path() + images.append(ImageConditioningInput(path=temp_image_path, frame_idx=0, strength=1.0)) + if last_frame_image is not None: + temp_last_frame_path = tempfile.NamedTemporaryFile(suffix=".png", delete=False).name + last_frame_image.save(temp_last_frame_path) + # Use frame_idx=0 (first-frame conditioning) for extend. + # The distilled pipeline doesn't support last-frame conditioning + # (frame_idx=num_frames-1), but first-frame works identically — + # the new video continues from this frame. + images.append(ImageConditioningInput(path=temp_last_frame_path, frame_idx=0, strength=1.0)) + + output_path = self._make_output_path(model="ltx-fast", prompt=prompt) try: settings = self.state.app_settings @@ -248,6 +285,179 @@ def generate_video( self._text.clear_api_embeddings() if temp_image_path and os.path.exists(temp_image_path): os.unlink(temp_image_path) + if temp_last_frame_path and os.path.exists(temp_last_frame_path): + os.unlink(temp_last_frame_path) + + def generate_long_video( + self, + prompt: str, + image_path: str, + target_duration: int, + resolution: str = "512p", + aspect_ratio: str = "16:9", + fps: int = 24, + segment_duration: int = 4, + camera_motion: VideoCameraMotion = "none", + lora_path: str | None = None, + lora_weight: float = 1.0, + ) -> str: + """Generate a long video by chaining I2V + extend segments. + + 1. Generate initial segment from source image (I2V) + 2. Extract last frame, generate next segment conditioned on it + 3. Repeat until target_duration is reached + 4. Concatenate all segments (trimming first frame of extensions) + """ + RESOLUTION_MAP_16_9: dict[str, tuple[int, int]] = { + "512p": (960, 544), "540p": (960, 544), + "720p": (1280, 704), "1080p": (1920, 1088), + } + + def get_size(res: str, ar: str) -> tuple[int, int]: + w, h = RESOLUTION_MAP_16_9.get(res, (960, 544)) + return (h, w) if ar == "9:16" else (w, h) + + width, height = get_size(resolution, aspect_ratio) + num_segments = max(1, (target_duration + segment_duration - 1) // segment_duration) + logger.info("[long] Starting %ds video: %d segments of %ds (%dx%d)", + target_duration, num_segments, segment_duration, width, height) + + ffmpeg_path = self._find_ffmpeg() + segment_paths: list[str] = [] + temp_files: list[str] = [] + + try: + # --- Segment 1: I2V from source image --- + image = self._prepare_image(image_path, width, height) + num_frames = self._compute_num_frames(segment_duration, fps) + seed = self._resolve_seed() + + generation_id = self._make_generation_id() + self._pipelines.load_gpu_pipeline( + "fast", should_warm=False, + lora_path=lora_path, lora_weight=lora_weight, + ) + self._generation.start_generation(generation_id) + + try: + self._generation.update_progress("generating_segment", 5, 1, num_segments) + seg1_path = self.generate_video( + prompt=prompt, image=image, height=height, width=width, + num_frames=num_frames, fps=float(fps), seed=seed, + camera_motion=camera_motion, negative_prompt="", + lora_path=lora_path, lora_weight=lora_weight, + ) + segment_paths.append(seg1_path) + logger.info("[long] Segment 1/%d complete: %s", num_segments, seg1_path) + except Exception: + self._generation.fail_generation("Segment 1 failed") + raise + + # --- Segments 2..N: extend from last frame --- + for seg_idx in range(2, num_segments + 1): + if self._generation.is_generation_cancelled(): + raise RuntimeError("Generation was cancelled") + + prev_path = segment_paths[-1] + last_frame = self._extract_last_frame(prev_path, ffmpeg_path) + temp_files.append(last_frame) + + last_frame_image = Image.open(last_frame).convert("RGB") + seed = self._resolve_seed() + + self._generation.update_progress( + "generating_segment", int(15 + 70 * seg_idx / num_segments), + seg_idx, num_segments, + ) + + seg_path = self.generate_video( + prompt=prompt, image=None, last_frame_image=last_frame_image, + height=height, width=width, num_frames=num_frames, fps=float(fps), + seed=seed, camera_motion=camera_motion, negative_prompt="", + lora_path=lora_path, lora_weight=lora_weight, + ) + segment_paths.append(seg_path) + logger.info("[long] Segment %d/%d complete: %s", seg_idx, num_segments, seg_path) + + # --- Concatenate segments --- + self._generation.update_progress("concatenating", 90, num_segments, num_segments) + output_path = self._make_output_path(model="ltx-fast-long", prompt=prompt) + + self._concatenate_segments( + segment_paths, str(output_path), ffmpeg_path, fps, + ) + logger.info("[long] Final video: %s (%d segments)", output_path, len(segment_paths)) + + self._generation.complete_generation(str(output_path)) + return str(output_path) + + except Exception as e: + if "cancelled" not in str(e).lower(): + self._generation.fail_generation(str(e)) + raise + finally: + for f in temp_files: + if os.path.exists(f): + os.unlink(f) + + @staticmethod + def _find_ffmpeg() -> str: + """Find ffmpeg binary — bundled with imageio-ffmpeg.""" + try: + import imageio_ffmpeg + return imageio_ffmpeg.get_ffmpeg_exe() + except Exception: + pass + raise RuntimeError("ffmpeg not found. Install imageio-ffmpeg.") + + @staticmethod + def _extract_last_frame(video_path: str, ffmpeg_path: str) -> str: + """Extract the last frame of a video to a temp PNG.""" + import subprocess + + out = tempfile.NamedTemporaryFile(suffix=".png", delete=False) + out.close() + subprocess.run( + [ffmpeg_path, "-y", "-sseof", "-0.05", "-i", video_path, + "-frames:v", "1", "-update", "1", out.name], + capture_output=True, check=True, + ) + return out.name + + @staticmethod + def _concatenate_segments( + segment_paths: list[str], output_path: str, ffmpeg_path: str, fps: int, + ) -> None: + """Concatenate video segments into one file.""" + import subprocess + + if len(segment_paths) == 1: + import shutil + shutil.copy2(segment_paths[0], output_path) + return + + # Use concat demuxer — simple and reliable. + concat_file = tempfile.NamedTemporaryFile( + mode="w", suffix=".txt", delete=False, + ) + try: + for seg in segment_paths: + # ffmpeg concat demuxer needs forward slashes or escaped backslashes + safe_path = seg.replace("\\", "/") + concat_file.write(f"file '{safe_path}'\n") + concat_file.close() + + cmd = [ + ffmpeg_path, "-y", + "-f", "concat", "-safe", "0", + "-i", concat_file.name, + "-c", "copy", + output_path, + ] + subprocess.run(cmd, capture_output=True, check=True) + finally: + if os.path.exists(concat_file.name): + os.unlink(concat_file.name) def _generate_a2v( self, req: GenerateVideoRequest, duration: int, fps: int, *, audio_path: str @@ -268,10 +478,16 @@ def _generate_a2v( image = None temp_image_path: str | None = None + temp_last_frame_path: str | None = None image_path = normalize_optional_path(req.imagePath) if image_path: image = self._prepare_image(image_path, width, height) + last_frame_image = None + last_frame_path = normalize_optional_path(req.lastFramePath) + if last_frame_path: + last_frame_image = self._prepare_image(last_frame_path, width, height) + seed = self._resolve_seed() generation_id = self._make_generation_id() @@ -287,9 +503,13 @@ def _generate_a2v( if image is not None: temp_image_path = tempfile.NamedTemporaryFile(suffix=".png", delete=False).name image.save(temp_image_path) - images = [ImageConditioningInput(path=temp_image_path, frame_idx=0, strength=1.0)] + images.append(ImageConditioningInput(path=temp_image_path, frame_idx=0, strength=1.0)) + if last_frame_image is not None: + temp_last_frame_path = tempfile.NamedTemporaryFile(suffix=".png", delete=False).name + last_frame_image.save(temp_last_frame_path) + images.append(ImageConditioningInput(path=temp_last_frame_path, frame_idx=0, strength=1.0)) - output_path = self._make_output_path() + output_path = self._make_output_path(model="ltx-pro", prompt=req.prompt) total_steps = 11 # distilled: 8 steps (stage 1) + 3 steps (stage 2) @@ -340,6 +560,8 @@ def _generate_a2v( self._text.clear_api_embeddings() if temp_image_path and os.path.exists(temp_image_path): os.unlink(temp_image_path) + if temp_last_frame_path and os.path.exists(temp_last_frame_path): + os.unlink(temp_last_frame_path) def _prepare_image(self, image_path: str, width: int, height: int) -> Image.Image: validated_path = validate_image_file(image_path) @@ -377,9 +599,72 @@ def _resolve_seed(self) -> int: return settings.locked_seed return int(time.time()) % 2147483647 - def _make_output_path(self) -> Path: - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - return self._outputs_dir / f"ltx2_video_{timestamp}_{self._make_generation_id()}.mp4" + def _make_output_path(self, *, model: str, prompt: str) -> Path: + return make_output_path(self._outputs_dir, model=model, prompt=prompt, ext="mp4") + + def _generate_via_replicate(self, req: GenerateVideoRequest) -> GenerateVideoResponse: + if self._generation.is_generation_running(): + raise HTTPError(409, "Generation already in progress") + + generation_id = self._make_generation_id() + self._generation.start_api_generation(generation_id) + + try: + self._generation.update_progress("validating_request", 5, None, None) + + api_key = self.state.app_settings.replicate_api_key.strip() + if not api_key: + raise HTTPError(400, "REPLICATE_API_KEY_NOT_CONFIGURED") + + duration = self._parse_forced_numeric_field(req.duration, "INVALID_FORCED_API_DURATION") + aspect_ratio = req.aspectRatio.strip() if req.aspectRatio else "16:9" + resolution = req.resolution or "720p" + generate_audio = self._parse_audio_flag(req.audio) + + if self._generation.is_generation_cancelled(): + raise RuntimeError("Generation was cancelled") + + # Support image-to-video for Seedance via last_frame + last_frame_url: str | None = None + image_path = normalize_optional_path(req.imagePath) + if image_path is not None: + validated = validate_image_file(image_path) + import base64 + raw = validated.read_bytes() + b64 = base64.b64encode(raw).decode("ascii") + ext = validated.suffix.lstrip(".") + mime = "image/png" if ext == "png" else "image/jpeg" + last_frame_url = f"data:{mime};base64,{b64}" + + self._generation.update_progress("inference", 20, None, None) + video_bytes = self._video_api_client.generate_text_to_video( + api_key=api_key, + model=req.model, + prompt=req.prompt, + duration=duration, + resolution=resolution, + aspect_ratio=aspect_ratio, + generate_audio=generate_audio, + last_frame=last_frame_url, + ) + self._generation.update_progress("downloading_output", 85, None, None) + + if self._generation.is_generation_cancelled(): + raise RuntimeError("Generation was cancelled") + + output_path = self._write_forced_api_video(video_bytes, model=req.model, prompt=req.prompt) + self._generation.update_progress("complete", 100, None, None) + self._generation.complete_generation(str(output_path)) + return GenerateVideoResponse(status="complete", video_path=str(output_path)) + except HTTPError as e: + self._generation.fail_generation(e.detail) + raise + except Exception as e: + self._generation.fail_generation(str(e)) + if "cancelled" in str(e).lower(): + logger.info("Generation cancelled by user") + return GenerateVideoResponse(status="cancelled") + raise HTTPError(500, str(e)) from e def _generate_forced_api(self, req: GenerateVideoRequest) -> GenerateVideoResponse: if self._generation.is_generation_running(): @@ -390,6 +675,7 @@ def _generate_forced_api(self, req: GenerateVideoRequest) -> GenerateVideoRespon audio_path = normalize_optional_path(req.audioPath) image_path = normalize_optional_path(req.imagePath) + last_frame_path = normalize_optional_path(req.lastFramePath) has_input_audio = bool(audio_path) has_input_image = bool(image_path) @@ -472,11 +758,20 @@ def _generate_forced_api(self, req: GenerateVideoRequest) -> GenerateVideoRespon api_key=api_key, file_path=str(validated_image_path), ) + last_frame_uri: str | None = None + if last_frame_path is not None: + validated_last_frame_path = validate_image_file(last_frame_path) + self._generation.update_progress("uploading_last_frame", 35, None, None) + last_frame_uri = self._ltx_api_client.upload_file( + api_key=api_key, + file_path=str(validated_last_frame_path), + ) self._generation.update_progress("inference", 55, None, None) video_bytes = self._ltx_api_client.generate_image_to_video( api_key=api_key, prompt=prompt, image_uri=image_uri, + last_frame_uri=last_frame_uri, model=api_model_id, resolution=api_resolution, duration=float(duration), @@ -494,10 +789,19 @@ def _generate_forced_api(self, req: GenerateVideoRequest) -> GenerateVideoRespon raise HTTPError(400, "INVALID_FORCED_API_DURATION") generate_audio = self._parse_audio_flag(req.audio) + t2v_last_frame_uri: str | None = None + if last_frame_path is not None: + validated_last_frame_path = validate_image_file(last_frame_path) + self._generation.update_progress("uploading_last_frame", 20, None, None) + t2v_last_frame_uri = self._ltx_api_client.upload_file( + api_key=api_key, + file_path=str(validated_last_frame_path), + ) self._generation.update_progress("inference", 55, None, None) video_bytes = self._ltx_api_client.generate_text_to_video( api_key=api_key, prompt=prompt, + last_frame_uri=t2v_last_frame_uri, model=api_model_id, resolution=api_resolution, duration=float(duration), @@ -510,7 +814,8 @@ def _generate_forced_api(self, req: GenerateVideoRequest) -> GenerateVideoRespon if self._generation.is_generation_cancelled(): raise RuntimeError("Generation was cancelled") - output_path = self._write_forced_api_video(video_bytes) + api_model_label = f"ltx-{requested_model}" + output_path = self._write_forced_api_video(video_bytes, model=api_model_label, prompt=prompt) if self._generation.is_generation_cancelled(): output_path.unlink(missing_ok=True) raise RuntimeError("Generation was cancelled") @@ -528,8 +833,8 @@ def _generate_forced_api(self, req: GenerateVideoRequest) -> GenerateVideoRespon return GenerateVideoResponse(status="cancelled") raise HTTPError(500, str(e)) from e - def _write_forced_api_video(self, video_bytes: bytes) -> Path: - output_path = self._make_output_path() + def _write_forced_api_video(self, video_bytes: bytes, *, model: str, prompt: str) -> Path: + output_path = self._make_output_path(model=model, prompt=prompt) output_path.write_bytes(video_bytes) return output_path diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 797e3bf8..a8b95d5b 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -22,6 +22,7 @@ dependencies = [ "transformers>=4.52,<5", "sentencepiece>=0.1.99", "sageattention>=1.0.0; sys_platform != 'darwin'", + "gguf>=0.10.0", "opencv-python-headless>=4.8.0", "fastapi>=0.115.0", "uvicorn[standard]>=0.30.0", diff --git a/backend/pyrightconfig.json b/backend/pyrightconfig.json index e54dd0d3..117ea971 100644 --- a/backend/pyrightconfig.json +++ b/backend/pyrightconfig.json @@ -6,6 +6,7 @@ "venv": ".venv", "exclude": [ "tests", + "scripts", ".venv", "node_modules", "dist" diff --git a/backend/requirements-dist.txt b/backend/requirements-dist.txt new file mode 100644 index 00000000..5d0be396 Binary files /dev/null and b/backend/requirements-dist.txt differ diff --git a/backend/runtime_config/model_download_specs.py b/backend/runtime_config/model_download_specs.py index c808bbc2..84e87a4f 100644 --- a/backend/runtime_config/model_download_specs.py +++ b/backend/runtime_config/model_download_specs.py @@ -26,7 +26,9 @@ def name(self) -> str: "checkpoint", "upsampler", "text_encoder", + "text_encoder_abliterated", "zit", + "flux_klein", ) @@ -52,6 +54,13 @@ def name(self) -> str: repo_id="Lightricks/gemma-3-12b-it-qat-q4_0-unquantized", description="Gemma text encoder (bfloat16)", ), + "text_encoder_abliterated": ModelFileDownloadSpec( + relative_path=Path("gemma-3-12b-it-abliterated"), + expected_size_bytes=24_400_000_000, + is_folder=True, + repo_id="mlabonne/gemma-3-12b-it-abliterated", + description="Abliterated Gemma text encoder (~24.4 GB)", + ), "zit": ModelFileDownloadSpec( relative_path=Path("Z-Image-Turbo"), expected_size_bytes=31_000_000_000, @@ -59,11 +68,18 @@ def name(self) -> str: repo_id="Tongyi-MAI/Z-Image-Turbo", description="Z-Image-Turbo model for text-to-image generation", ), + "flux_klein": ModelFileDownloadSpec( + relative_path=Path("FLUX.2-klein-base-9B"), + expected_size_bytes=50_000_000_000, + is_folder=True, + repo_id="black-forest-labs/FLUX.2-klein-base-9B", + description="FLUX.2 Klein 9B Base — text-to-image with LoRA support", + ), } DEFAULT_REQUIRED_MODEL_TYPES: frozenset[ModelFileType] = frozenset( - {"checkpoint", "upsampler", "zit"} + {"checkpoint", "upsampler", "flux_klein"} ) diff --git a/backend/runtime_config/runtime_policy.py b/backend/runtime_config/runtime_policy.py index cf64b66b..c6e20eff 100644 --- a/backend/runtime_config/runtime_policy.py +++ b/backend/runtime_config/runtime_policy.py @@ -13,7 +13,7 @@ def decide_force_api_generations(system: str, cuda_available: bool, vram_gb: int return True if vram_gb is None: return True - return vram_gb < 31 + return vram_gb < 20 # Fail closed for non-target platforms unless explicitly relaxed. return True diff --git a/backend/scripts/test_flux_klein_nf4.py b/backend/scripts/test_flux_klein_nf4.py new file mode 100644 index 00000000..bcf9edc7 --- /dev/null +++ b/backend/scripts/test_flux_klein_nf4.py @@ -0,0 +1,133 @@ +"""Test FLUX.2 Klein 9B with bitsandbytes NF4 quantization + LoRA. + +Run: cd backend && uv run python scripts/test_flux_klein_nf4.py +""" + +import time +import gc +import torch +from pathlib import Path + +MODEL_PATH = "C:/Users/taskm/AppData/Local/LTXDesktop/models/FLUX.2-klein-base-9B" +LORA_PATH = "C:/Users/taskm/AppData/Local/LTXDesktop/models/loras/jRB4slNlO3KYd18ROU5Up_pytorch_lora_weights_comfy_converted.safetensors" +OUTPUT_PATH = "D:/git/directors-desktop/backend/outputs/test_nf4_lora.png" + +def main(): + from diffusers import BitsAndBytesConfig, Flux2KleinPipeline, AutoencoderKL, PipelineQuantizationConfig + import numpy as np + from PIL import Image + + print("=" * 60) + print("FLUX.2 Klein 9B — NF4 Quantization + LoRA Test") + print("=" * 60) + + # Step 1: Load pipeline with NF4 quantization + print("\n[1/5] Loading pipeline with NF4 quantization...") + t0 = time.time() + + nf4_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.bfloat16, + ) + + pipeline_quant_config = PipelineQuantizationConfig( + quant_mapping={"transformer": nf4_config}, + ) + + pipe = Flux2KleinPipeline.from_pretrained( + MODEL_PATH, + quantization_config=pipeline_quant_config, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + ) + + # NF4 transformer (~5GB) fits on GPU, but T5-XXL text_encoder (~9GB bf16) + # pushes total to ~21GB leaving no room for activations at 1024x1024. + # CPU offload moves text_encoder off GPU after encoding, so the NF4 + # transformer gets full VRAM for inference. + pipe.enable_model_cpu_offload() + + load_time = time.time() - t0 + print(f" Pipeline loaded in {load_time:.1f}s") + + # Check VRAM usage after loading + vram_mb = torch.cuda.memory_allocated() / 1024**2 + print(f" VRAM after load: {vram_mb:.0f} MB") + + # Step 2: Load LoRA + print("\n[2/5] Loading LoRA...") + t1 = time.time() + pipe.load_lora_weights(LORA_PATH, adapter_name="user_lora") + pipe.set_adapters(["user_lora"], adapter_weights=[1.0]) + lora_time = time.time() - t1 + print(f" LoRA loaded in {lora_time:.1f}s") + + vram_mb = torch.cuda.memory_allocated() / 1024**2 + print(f" VRAM after LoRA: {vram_mb:.0f} MB") + + # Step 3: Generate (latent output to avoid VAE segfault) + print("\n[3/5] Generating image (1024x1024, 28 steps)...") + prompt = "DC animation style,with bold outlines,cel-shaded & muted color palette, A powerful superhero standing on a city rooftop at sunset, dramatic lighting, cape flowing in the wind, Gotham-style cityscape in background" + + generator = torch.Generator(device="cpu").manual_seed(42) + + t2 = time.time() + output = pipe( + prompt=prompt, + height=1024, + width=1024, + guidance_scale=4.0, + num_inference_steps=28, + generator=generator, + output_type="latent", + return_dict=True, + ) + gen_time = time.time() - t2 + print(f" Inference completed in {gen_time:.1f}s") + + latents = output.images.to("cpu") + + # Step 4: Destroy pipeline, decode with fresh VAE + print("\n[4/5] Decoding latents with fresh VAE on CPU...") + del pipe + gc.collect() + torch.cuda.empty_cache() + + t3 = time.time() + vae_path = str(Path(MODEL_PATH) / "vae") + vae = AutoencoderKL.from_pretrained(vae_path, torch_dtype=torch.float32) + vae = vae.to("cpu") + vae.eval() + + latents_f32 = latents.to(dtype=torch.float32) + with torch.no_grad(): + decoded = vae.decode(latents_f32, return_dict=False)[0] + + decoded = (decoded / 2 + 0.5).clamp(0, 1) + arr = decoded[0].permute(1, 2, 0).numpy() + pil_img = Image.fromarray((arr * 255).astype(np.uint8)) + decode_time = time.time() - t3 + print(f" VAE decode completed in {decode_time:.1f}s") + + # Step 5: Save + pil_img.save(OUTPUT_PATH) + print(f"\n[5/5] Saved to {OUTPUT_PATH}") + + del vae, latents, latents_f32, decoded + gc.collect() + + total = load_time + lora_time + gen_time + decode_time + print("\n" + "=" * 60) + print(f"RESULTS:") + print(f" Pipeline load (NF4): {load_time:.1f}s") + print(f" LoRA load: {lora_time:.1f}s") + print(f" Inference (28 steps): {gen_time:.1f}s") + print(f" VAE decode (CPU): {decode_time:.1f}s") + print(f" TOTAL: {total:.1f}s") + print(f" Peak VRAM: {torch.cuda.max_memory_allocated() / 1024**2:.0f} MB") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/backend/server_utils/output_naming.py b/backend/server_utils/output_naming.py new file mode 100644 index 00000000..1ce5aed0 --- /dev/null +++ b/backend/server_utils/output_naming.py @@ -0,0 +1,53 @@ +"""Shared output filename generation for all generation handlers.""" + +from __future__ import annotations + +import re +from datetime import datetime +from pathlib import Path + + +def _slugify_prompt(prompt: str, max_words: int = 5) -> str: + """Turn a prompt into a short filesystem-safe slug. + + Takes the first *max_words* words, lowercases, strips non-alphanumeric + characters, and joins with hyphens. Returns ``"untitled"`` for empty input. + """ + cleaned = re.sub(r"[^a-zA-Z0-9\s]", "", prompt).strip() + words = cleaned.lower().split()[:max_words] + slug = "-".join(words) + return slug or "untitled" + + +def make_output_filename( + *, + model: str, + prompt: str, + ext: str = "mp4", +) -> str: + """Build a descriptive output filename. + + Pattern: ``dd_{model}_{prompt_slug}_{timestamp}.{ext}`` + + Examples:: + + dd_ltx-fast_elegant-woman-luxury-handbag_20260309_144341.mp4 + dd_seedance_confident-woman-walks-runway_20260309_144342.mp4 + dd_zit_cyberpunk-cityscape-neon-rain_20260309_144343.png + """ + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + slug = _slugify_prompt(prompt) + # Sanitise model name for filesystem + safe_model = re.sub(r"[^a-zA-Z0-9_-]", "", model) or "unknown" + return f"dd_{safe_model}_{slug}_{timestamp}.{ext}" + + +def make_output_path( + outputs_dir: Path, + *, + model: str, + prompt: str, + ext: str = "mp4", +) -> Path: + """Build a full output path under *outputs_dir*.""" + return outputs_dir / make_output_filename(model=model, prompt=prompt, ext=ext) diff --git a/backend/services/__init__.py b/backend/services/__init__.py index 4ed822b8..b3c419ff 100644 --- a/backend/services/__init__.py +++ b/backend/services/__init__.py @@ -2,7 +2,7 @@ from services.interfaces import ( FastVideoPipeline, - ZitAPIClient, + ImageAPIClient, ImageGenerationPipeline, GpuCleaner, GpuInfo, @@ -31,7 +31,7 @@ "TextEncoder", "VideoPipelineModelType", "FastVideoPipeline", - "ZitAPIClient", + "ImageAPIClient", "ImageGenerationPipeline", "IcLoraPipeline", "IcLoraModelDownloader", diff --git a/backend/services/fast_video_pipeline/fast_video_pipeline.py b/backend/services/fast_video_pipeline/fast_video_pipeline.py index e415524a..be521828 100644 --- a/backend/services/fast_video_pipeline/fast_video_pipeline.py +++ b/backend/services/fast_video_pipeline/fast_video_pipeline.py @@ -19,6 +19,8 @@ def create( gemma_root: str | None, upsampler_path: str, device: torch.device, + lora_path: str | None = None, + lora_weight: float = 1.0, ) -> "FastVideoPipeline": ... diff --git a/backend/services/fast_video_pipeline/gguf_fast_video_pipeline.py b/backend/services/fast_video_pipeline/gguf_fast_video_pipeline.py new file mode 100644 index 00000000..b59224ad --- /dev/null +++ b/backend/services/fast_video_pipeline/gguf_fast_video_pipeline.py @@ -0,0 +1,336 @@ +"""GGUF quantized LTX video pipeline. + +Loads LTX-Video transformer weights from a GGUF file using the diffusers +GGUFLinear layers (on-the-fly dequantization), while loading VAE, text +encoder, audio models, and upsampler from the base BF16 checkpoint. + +The base BF16 checkpoint must be present in the same directory as the +GGUF file (auto-discovered) because the GGUF file only contains the +transformer weights. +""" + +from __future__ import annotations + +import logging +import os +from collections.abc import Iterator +from pathlib import Path +from typing import Any, Final, cast + +import numpy as np +import torch + +from api_types import ImageConditioningInput +from services.ltx_pipeline_common import default_tiling_config, encode_video_output, video_chunks_number +from services.services_utils import AudioOrNone, TilingConfigType + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_BASE_CHECKPOINT_NAME = "ltx-2.3-22b-distilled.safetensors" + + +def _find_base_checkpoint(gguf_path: str, upsampler_path: str) -> str: + """Locate the base BF16 safetensors checkpoint for non-transformer components. + + Searches in: + 1. Same directory as the GGUF file + 2. Parent directory of the upsampler (models_dir) + """ + search_dirs = [ + Path(gguf_path).parent, + Path(upsampler_path).parent, + ] + for d in search_dirs: + candidate = d / _BASE_CHECKPOINT_NAME + if candidate.is_file(): + return str(candidate) + + # Fallback: largest safetensors file in the models directory + for d in search_dirs: + safetensors = sorted(d.glob("*.safetensors"), key=lambda p: p.stat().st_size, reverse=True) + for sf in safetensors: + if sf.stat().st_size > 10_000_000_000: # > 10 GB — likely the base model + return str(sf) + + raise FileNotFoundError( + f"Cannot find base BF16 checkpoint ({_BASE_CHECKPOINT_NAME}) near {gguf_path}. " + "The base model is required for VAE, audio, and text encoder components. " + "Download it from the Models tab in Settings." + ) + + +def _read_gguf_state_dict(gguf_path: str) -> dict[str, torch.nn.Parameter]: + """Read a GGUF file and return a state dict of GGUFParameter/Tensor objects.""" + import gguf + from diffusers.quantizers.gguf.utils import GGUFParameter + + UNQUANTIZED = { + gguf.GGMLQuantizationType.F32, + gguf.GGMLQuantizationType.F16, + gguf.GGMLQuantizationType.BF16, + } + + logger.info("Reading GGUF file: %s", gguf_path) + reader = gguf.GGUFReader(gguf_path) + + state_dict: dict[str, torch.nn.Parameter] = {} + for tensor in reader.tensors: + name = str(tensor.name) + data: np.ndarray[Any, Any] = tensor.data # type: ignore[assignment] + quant_type = tensor.tensor_type + + if quant_type in UNQUANTIZED: + # Unquantized — convert to proper dtype + if quant_type == gguf.GGMLQuantizationType.F32: + t = torch.from_numpy(data.copy()).to(torch.float32) + elif quant_type == gguf.GGMLQuantizationType.F16: + t = torch.from_numpy(data.copy()).to(torch.float16) + else: # BF16 + t = torch.from_numpy(data.copy()).view(torch.bfloat16) + # Reshape to original shape + shape = tuple(int(s) for s in tensor.shape) + if shape: + t = t.reshape(shape) + state_dict[name] = torch.nn.Parameter(t, requires_grad=False) + else: + # Quantized — wrap as GGUFParameter for on-the-fly dequantization + t = torch.from_numpy(data.copy()) + param = GGUFParameter(t, requires_grad=False, quant_type=quant_type) + state_dict[name] = param + + logger.info("Read %d tensors from GGUF file", len(state_dict)) + return state_dict + + +def _remap_gguf_keys(state_dict: dict[str, torch.nn.Parameter]) -> dict[str, torch.nn.Parameter]: + """Strip common prefixes from GGUF tensor names to match ltx_core model keys. + + City96 GGUF files use ``model.diffusion_model.`` prefix; ltx_core expects + keys without that prefix. Also handles ``diffusion_model.`` prefix. + """ + PREFIXES_TO_STRIP = [ + "model.diffusion_model.", + "diffusion_model.", + ] + + remapped: dict[str, torch.nn.Parameter] = {} + for key, value in state_dict.items(): + new_key = key + for prefix in PREFIXES_TO_STRIP: + if new_key.startswith(prefix): + new_key = new_key[len(prefix):] + break + remapped[new_key] = value + + return remapped + + +def _create_gguf_transformer( + state_dict: dict[str, torch.nn.Parameter], + device: torch.device, + compute_dtype: torch.dtype = torch.bfloat16, +) -> Any: + """Create an X0Model transformer with GGUFLinear layers and load quantized weights.""" + from diffusers.quantizers.gguf.utils import GGUFParameter, _replace_with_gguf_linear + from ltx_core.model.transformer import LTXModelConfigurator, X0Model + + # Build model config with defaults for LTX-Video 2.3 22B + # The GGUF file doesn't contain model config metadata, so we use defaults + # which match the 22B distilled model architecture. + default_config: dict[str, Any] = {} + + with torch.device("meta"): + model = LTXModelConfigurator.from_config(default_config) + + # Replace nn.Linear with GGUFLinear where state dict has GGUFParameter + _replace_with_gguf_linear(model, compute_dtype, state_dict) + + # Load state dict — assign=True replaces meta tensors with real data + missing, unexpected = model.load_state_dict( + {k: v for k, v in state_dict.items()}, + strict=False, + assign=True, + ) + if missing: + logger.warning("GGUF transformer load — missing keys: %s", missing[:10]) + if unexpected: + logger.debug("GGUF transformer load — unexpected keys: %s", unexpected[:10]) + + x0_model = X0Model(model) + return x0_model.to(device).eval() + + +# --------------------------------------------------------------------------- +# Pipeline class +# --------------------------------------------------------------------------- + + +class GGUFFastVideoPipeline: + """FastVideoPipeline that loads the transformer from a GGUF file. + + Uses the ``DistilledPipeline`` from ltx_pipelines for two-stage video + generation (half-res → upsample+refine), but overrides the transformer + loading to use GGUF quantized weights with on-the-fly dequantization + via diffusers' ``GGUFLinear`` layers. + + Requires the base BF16 checkpoint to be present for VAE, audio, and + text encoder components. + + Note: LoRA fusion with GGUF quantized weights is not yet supported. + The distilled LoRA must be baked into the GGUF file. + """ + + pipeline_kind: Final = "fast" + + @staticmethod + def create( + checkpoint_path: str, + gemma_root: str | None, + upsampler_path: str, + device: torch.device, + lora_path: str | None = None, + lora_weight: float = 1.0, + ) -> "GGUFFastVideoPipeline": + return GGUFFastVideoPipeline( + checkpoint_path=checkpoint_path, + gemma_root=gemma_root, + upsampler_path=upsampler_path, + device=device, + lora_path=lora_path, + lora_weight=lora_weight, + ) + + def __init__( + self, + checkpoint_path: str, + gemma_root: str | None, + upsampler_path: str, + device: torch.device, + lora_path: str | None = None, + lora_weight: float = 1.0, + ) -> None: + try: + import gguf # noqa: F401 # pyright: ignore[reportUnusedImport] + except ImportError: + raise RuntimeError( + "GGUF model support requires the 'gguf' package. " + "Install it with: pip install gguf>=0.10.0" + ) from None + + gguf_path = checkpoint_path + base_checkpoint = _find_base_checkpoint(gguf_path, upsampler_path) + logger.info("GGUF pipeline: transformer from %s, base model from %s", gguf_path, base_checkpoint) + + # Pre-read and remap the GGUF state dict (stays on CPU, cached for reuse) + raw_sd = _read_gguf_state_dict(gguf_path) + self._gguf_state_dict = _remap_gguf_keys(raw_sd) + self._device = device + + # Build DistilledPipeline with the base BF16 checkpoint for all non-transformer components. + # LoRA is NOT passed here because it can't be fused with GGUF weights. + from ltx_pipelines.distilled import DistilledPipeline + + if lora_path: + logger.warning( + "LoRA fusion with GGUF quantized weights is not supported. " + "The LoRA at %s will be ignored. Use a GGUF file that includes the LoRA baked in.", + lora_path, + ) + + self.pipeline = DistilledPipeline( + distilled_checkpoint_path=base_checkpoint, + gemma_root=cast(str, gemma_root), + spatial_upsampler_path=upsampler_path, + loras=[], + device=device, + quantization=None, + ) + + # Override model_ledger.transformer() to load from GGUF instead of safetensors + self.pipeline.model_ledger.transformer = self._build_gguf_transformer # type: ignore[assignment] + + def _build_gguf_transformer(self) -> Any: + """Build the GGUF transformer model on each call (matches ModelLedger contract).""" + return _create_gguf_transformer( + state_dict=self._gguf_state_dict, + device=self._device, + ) + + def _run_inference( + self, + prompt: str, + seed: int, + height: int, + width: int, + num_frames: int, + frame_rate: float, + images: list[ImageConditioningInput], + tiling_config: TilingConfigType, + ) -> tuple[torch.Tensor | Iterator[torch.Tensor], AudioOrNone]: + from ltx_pipelines.utils.args import ImageConditioningInput as _LtxImageInput + + return self.pipeline( + prompt=prompt, + seed=seed, + height=height, + width=width, + num_frames=num_frames, + frame_rate=frame_rate, + images=[_LtxImageInput(img.path, img.frame_idx, img.strength) for img in images], + tiling_config=tiling_config, + ) + + @torch.inference_mode() + def generate( + self, + prompt: str, + seed: int, + height: int, + width: int, + num_frames: int, + frame_rate: float, + images: list[ImageConditioningInput], + output_path: str, + ) -> None: + tiling_config = default_tiling_config() + video, audio = self._run_inference( + prompt=prompt, + seed=seed, + height=height, + width=width, + num_frames=num_frames, + frame_rate=frame_rate, + images=images, + tiling_config=tiling_config, + ) + chunks = video_chunks_number(num_frames, tiling_config) + encode_video_output(video=video, audio=audio, fps=int(frame_rate), output_path=output_path, video_chunks_number_value=chunks) + + @torch.inference_mode() + def warmup(self, output_path: str) -> None: + warmup_frames = 9 + tiling_config = default_tiling_config() + + try: + video, audio = self._run_inference( + prompt="test warmup", + seed=42, + height=256, + width=384, + num_frames=warmup_frames, + frame_rate=8, + images=[], + tiling_config=tiling_config, + ) + chunks = video_chunks_number(warmup_frames, tiling_config) + encode_video_output(video=video, audio=audio, fps=8, output_path=output_path, video_chunks_number_value=chunks) + finally: + if os.path.exists(output_path): + os.unlink(output_path) + + def compile_transformer(self) -> None: + logger.info("Skipping torch.compile for GGUF pipeline — not supported with quantized weights") diff --git a/backend/services/fast_video_pipeline/ltx_fast_video_pipeline.py b/backend/services/fast_video_pipeline/ltx_fast_video_pipeline.py index d24c5c32..0236c2e3 100644 --- a/backend/services/fast_video_pipeline/ltx_fast_video_pipeline.py +++ b/backend/services/fast_video_pipeline/ltx_fast_video_pipeline.py @@ -4,7 +4,7 @@ from collections.abc import Iterator import os -from typing import Final, cast +from typing import Any, Final, cast import torch @@ -22,26 +22,57 @@ def create( gemma_root: str | None, upsampler_path: str, device: torch.device, + lora_path: str | None = None, + lora_weight: float = 1.0, ) -> "LTXFastVideoPipeline": return LTXFastVideoPipeline( checkpoint_path=checkpoint_path, gemma_root=gemma_root, upsampler_path=upsampler_path, device=device, + lora_path=lora_path, + lora_weight=lora_weight, ) - def __init__(self, checkpoint_path: str, gemma_root: str | None, upsampler_path: str, device: torch.device) -> None: + def __init__( + self, + checkpoint_path: str, + gemma_root: str | None, + upsampler_path: str, + device: torch.device, + lora_path: str | None = None, + lora_weight: float = 1.0, + ) -> None: from ltx_core.quantization import QuantizationPolicy from ltx_pipelines.distilled import DistilledPipeline + lora_entries: list[Any] = [] + if lora_path: + from ltx_core.loader.primitives import LoraPathStrengthAndSDOps # pyright: ignore[reportMissingImports] + + sd_ops: Any = None + try: + import importlib + _ser = importlib.import_module("ltx_core.loader.serialization") + sd_ops = getattr(_ser, "LTXV_LORA_COMFY_RENAMING_MAP", None) + except (ImportError, AttributeError): + pass + + lora_entries = [LoraPathStrengthAndSDOps( + path=lora_path, + strength=lora_weight, + sd_ops=sd_ops, + )] + self.pipeline = DistilledPipeline( distilled_checkpoint_path=checkpoint_path, gemma_root=cast(str, gemma_root), spatial_upsampler_path=upsampler_path, - loras=[], + loras=lora_entries, device=device, quantization=QuantizationPolicy.fp8_cast() if device_supports_fp8(device) else None, ) + self.lora_path = lora_path def _run_inference( self, diff --git a/backend/services/fast_video_pipeline/nf4_fast_video_pipeline.py b/backend/services/fast_video_pipeline/nf4_fast_video_pipeline.py new file mode 100644 index 00000000..a5a40936 --- /dev/null +++ b/backend/services/fast_video_pipeline/nf4_fast_video_pipeline.py @@ -0,0 +1,295 @@ +"""NF4 (4-bit BitsAndBytes) quantized LTX video pipeline. + +Uses BitsAndBytes NF4 quantization to reduce the LTX transformer's VRAM +footprint from ~22 GB (BF16) to ~6 GB (NF4). The base BF16 checkpoint +is loaded and the transformer's Linear layers are replaced with +``bitsandbytes.nn.Linear4bit`` before being moved to GPU — the +quantization happens automatically during the ``.to(device)`` call. + +Requires: + - ``bitsandbytes`` package (``pip install bitsandbytes``) + - Base BF16 checkpoint in the same directory or auto-discoverable +""" + +from __future__ import annotations + +import logging +import os +from collections.abc import Iterator +from pathlib import Path +from typing import Any, Final, cast + +import torch + +from api_types import ImageConditioningInput +from services.ltx_pipeline_common import default_tiling_config, encode_video_output, video_chunks_number +from services.services_utils import AudioOrNone, TilingConfigType + +logger = logging.getLogger(__name__) + +_BASE_CHECKPOINT_NAME = "ltx-2.3-22b-distilled.safetensors" + + +def _find_base_checkpoint(nf4_path: str, upsampler_path: str) -> str: + """Locate the base BF16 safetensors checkpoint for all components. + + NF4 runtime quantization starts from the BF16 weights and quantizes + on the fly. Searches same directory as the NF4 folder, then models_dir. + """ + search_dirs = [ + Path(nf4_path).parent, + Path(upsampler_path).parent, + ] + for d in search_dirs: + candidate = d / _BASE_CHECKPOINT_NAME + if candidate.is_file(): + return str(candidate) + + # Fallback: largest safetensors in models dir + for d in search_dirs: + safetensors = sorted(d.glob("*.safetensors"), key=lambda p: p.stat().st_size, reverse=True) + for sf in safetensors: + if sf.stat().st_size > 10_000_000_000: + return str(sf) + + raise FileNotFoundError( + f"Cannot find base BF16 checkpoint ({_BASE_CHECKPOINT_NAME}) near {nf4_path}. " + "The base model is required for NF4 runtime quantization. " + "Download it from the Models tab in Settings." + ) + + +def _apply_nf4_quantization(model: torch.nn.Module) -> None: + """Replace all ``nn.Linear`` layers in *model* with ``bnb.nn.Linear4bit`` (NF4). + + Weights are stored as ``Params4bit`` objects that quantize to NF4 when + moved to a CUDA device. The replacement happens in-place. + """ + import bitsandbytes as bnb # type: ignore[import-untyped] + + replacements: list[tuple[torch.nn.Module, str, torch.nn.Module]] = [] + + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear): + nf4_linear = bnb.nn.Linear4bit( + module.in_features, + module.out_features, + bias=module.bias is not None, + compute_dtype=torch.bfloat16, + quant_type="nf4", + ) + # Store BF16 weights as Params4bit — quantized on .to(cuda) + nf4_linear.weight = bnb.nn.Params4bit( + module.weight.data.to(torch.bfloat16), + requires_grad=False, + quant_type="nf4", + ) + if module.bias is not None: + nf4_linear.bias = module.bias + + # Find parent module to set attribute + parts = name.rsplit(".", 1) + if len(parts) == 2: + parent_name, child_name = parts + else: + parent_name, child_name = "", parts[0] + + parent = model.get_submodule(parent_name) if parent_name else model + replacements.append((parent, child_name, nf4_linear)) + + for parent, child_name, new_module in replacements: + setattr(parent, child_name, new_module) + + n_replaced = len(replacements) + logger.info("Replaced %d Linear layers with NF4 Linear4bit", n_replaced) + + +class NF4FastVideoPipeline: + """FastVideoPipeline that quantizes the transformer to NF4 at runtime. + + Uses the ``DistilledPipeline`` for two-stage video generation but + overrides the transformer loading to: + 1. Build the BF16 model on CPU + 2. Replace all ``nn.Linear`` with ``bnb.nn.Linear4bit`` (NF4) + 3. Move to GPU — this triggers 4-bit quantization + + Peak CPU RAM: ~43 GB (full BF16 model during build) + Peak GPU VRAM: ~6 GB (NF4 transformer) + ~3 GB (VAE) = ~9 GB + """ + + pipeline_kind: Final = "fast" + + @staticmethod + def create( + checkpoint_path: str, + gemma_root: str | None, + upsampler_path: str, + device: torch.device, + lora_path: str | None = None, + lora_weight: float = 1.0, + ) -> "NF4FastVideoPipeline": + return NF4FastVideoPipeline( + checkpoint_path=checkpoint_path, + gemma_root=gemma_root, + upsampler_path=upsampler_path, + device=device, + lora_path=lora_path, + lora_weight=lora_weight, + ) + + def __init__( + self, + checkpoint_path: str, + gemma_root: str | None, + upsampler_path: str, + device: torch.device, + lora_path: str | None = None, + lora_weight: float = 1.0, + ) -> None: + try: + import bitsandbytes # noqa: F401 # pyright: ignore[reportUnusedImport,reportMissingImports] + except ImportError: + raise RuntimeError( + "NF4 quantization requires the 'bitsandbytes' package. " + "Install it with: pip install bitsandbytes" + ) from None + + # checkpoint_path is the NF4 folder (from model scanner selection). + # We need the base BF16 model for actual loading + runtime quantization. + base_checkpoint = _find_base_checkpoint(checkpoint_path, upsampler_path) + logger.info("NF4 pipeline: base model from %s, runtime NF4 quantization", base_checkpoint) + + self._device = device + + # Build DistilledPipeline with the base BF16 model. + # LoRA is passed normally — the BF16 model builds with fused LoRA, + # then we quantize the result to NF4. + from ltx_pipelines.distilled import DistilledPipeline + + lora_entries: list[Any] = [] + if lora_path: + from ltx_core.loader.primitives import LoraPathStrengthAndSDOps # pyright: ignore[reportMissingImports] + + sd_ops: Any = None + try: + import importlib + _ser = importlib.import_module("ltx_core.loader.serialization") + sd_ops = getattr(_ser, "LTXV_LORA_COMFY_RENAMING_MAP", None) + except (ImportError, AttributeError): + pass + + lora_entries = [LoraPathStrengthAndSDOps( + path=lora_path, + strength=lora_weight, + sd_ops=sd_ops, + )] + + self.pipeline = DistilledPipeline( + distilled_checkpoint_path=base_checkpoint, + gemma_root=cast(str, gemma_root), + spatial_upsampler_path=upsampler_path, + loras=lora_entries, + device=device, + quantization=None, # We handle quantization ourselves + ) + + # Override model_ledger.transformer() to build BF16 on CPU then quantize to NF4 + self._original_transformer = self.pipeline.model_ledger.transformer + self.pipeline.model_ledger.transformer = self._build_nf4_transformer # type: ignore[assignment] + + def _build_nf4_transformer(self) -> Any: + """Build the transformer with NF4 quantization.""" + # Temporarily redirect model building to CPU + original_device = self.pipeline.model_ledger.device + self.pipeline.model_ledger.device = torch.device("cpu") + + try: + # Build full BF16 model on CPU (uses ~43GB CPU RAM) + model = self._original_transformer() + logger.info("Built BF16 transformer on CPU, applying NF4 quantization...") + finally: + self.pipeline.model_ledger.device = original_device + + # Replace Linear layers with NF4 + # model is an X0Model wrapping the actual transformer + inner_model = model.model if hasattr(model, "model") else model + _apply_nf4_quantization(inner_model) + + # Move to GPU — triggers NF4 quantization of Params4bit + logger.info("Moving NF4 transformer to %s", self._device) + return model.to(self._device).eval() + + def _run_inference( + self, + prompt: str, + seed: int, + height: int, + width: int, + num_frames: int, + frame_rate: float, + images: list[ImageConditioningInput], + tiling_config: TilingConfigType, + ) -> tuple[torch.Tensor | Iterator[torch.Tensor], AudioOrNone]: + from ltx_pipelines.utils.args import ImageConditioningInput as _LtxImageInput + + return self.pipeline( + prompt=prompt, + seed=seed, + height=height, + width=width, + num_frames=num_frames, + frame_rate=frame_rate, + images=[_LtxImageInput(img.path, img.frame_idx, img.strength) for img in images], + tiling_config=tiling_config, + ) + + @torch.inference_mode() + def generate( + self, + prompt: str, + seed: int, + height: int, + width: int, + num_frames: int, + frame_rate: float, + images: list[ImageConditioningInput], + output_path: str, + ) -> None: + tiling_config = default_tiling_config() + video, audio = self._run_inference( + prompt=prompt, + seed=seed, + height=height, + width=width, + num_frames=num_frames, + frame_rate=frame_rate, + images=images, + tiling_config=tiling_config, + ) + chunks = video_chunks_number(num_frames, tiling_config) + encode_video_output(video=video, audio=audio, fps=int(frame_rate), output_path=output_path, video_chunks_number_value=chunks) + + @torch.inference_mode() + def warmup(self, output_path: str) -> None: + warmup_frames = 9 + tiling_config = default_tiling_config() + + try: + video, audio = self._run_inference( + prompt="test warmup", + seed=42, + height=256, + width=384, + num_frames=warmup_frames, + frame_rate=8, + images=[], + tiling_config=tiling_config, + ) + chunks = video_chunks_number(warmup_frames, tiling_config) + encode_video_output(video=video, audio=audio, fps=8, output_path=output_path, video_chunks_number_value=chunks) + finally: + if os.path.exists(output_path): + os.unlink(output_path) + + def compile_transformer(self) -> None: + logger.info("Skipping torch.compile for NF4 pipeline — not supported with quantized weights") diff --git a/backend/services/gpu_cleaner/gpu_cleaner.py b/backend/services/gpu_cleaner/gpu_cleaner.py index 6910ae30..10a7e519 100644 --- a/backend/services/gpu_cleaner/gpu_cleaner.py +++ b/backend/services/gpu_cleaner/gpu_cleaner.py @@ -8,3 +8,6 @@ class GpuCleaner(Protocol): def cleanup(self) -> None: ... + + def deep_cleanup(self) -> None: + ... diff --git a/backend/services/gpu_cleaner/torch_cleaner.py b/backend/services/gpu_cleaner/torch_cleaner.py index 3e22e73b..a814e4d4 100644 --- a/backend/services/gpu_cleaner/torch_cleaner.py +++ b/backend/services/gpu_cleaner/torch_cleaner.py @@ -18,3 +18,12 @@ def __init__(self, device: str | torch.device = "cpu") -> None: def cleanup(self) -> None: empty_device_cache(self._device) gc.collect() + + def deep_cleanup(self) -> None: + """Aggressive cleanup for after heavy GPU workloads.""" + gc.collect() + empty_device_cache(self._device) + gc.collect() + empty_device_cache(self._device) + if str(self._device) != "cpu" and torch.cuda.is_available(): + torch.cuda.synchronize() diff --git a/backend/services/gpu_optimizations/__init__.py b/backend/services/gpu_optimizations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/services/gpu_optimizations/ffn_chunking.py b/backend/services/gpu_optimizations/ffn_chunking.py new file mode 100644 index 00000000..07148a75 --- /dev/null +++ b/backend/services/gpu_optimizations/ffn_chunking.py @@ -0,0 +1,67 @@ +"""Chunked feedforward optimization for LTX transformer. + +Splits FeedForward.forward along the sequence dimension (dim=1) to reduce +peak VRAM. Output is mathematically identical to unchunked forward — +FeedForward is pointwise along the sequence dimension so chunking is lossless. + +Reference: RandomInternetPreson/ComfyUI_LTX-2_VRAM_Memory_Management (V3.1) +""" + +from __future__ import annotations + +import logging +from collections.abc import Callable +import torch + +logger = logging.getLogger(__name__) + +_MIN_SEQ_PER_CHUNK = 100 # skip chunking for short sequences + + +def _make_chunked_forward( + original_forward: Callable[[torch.Tensor], torch.Tensor], + num_chunks: int, +) -> Callable[[torch.Tensor], torch.Tensor]: + """Return a drop-in replacement for FeedForward.forward that chunks along dim=1.""" + + def chunked_forward(x: torch.Tensor) -> torch.Tensor: + if x.dim() != 3: + return original_forward(x) + + seq_len = x.shape[1] + if seq_len < num_chunks * _MIN_SEQ_PER_CHUNK: + return original_forward(x) + + chunk_size = (seq_len + num_chunks - 1) // num_chunks + outputs: list[torch.Tensor] = [] + for start in range(0, seq_len, chunk_size): + end = min(start + chunk_size, seq_len) + outputs.append(original_forward(x[:, start:end, :])) + return torch.cat(outputs, dim=1) + + return chunked_forward + + +def patch_ffn_chunking(model: torch.nn.Module, num_chunks: int = 8) -> int: + """Monkey-patch all FeedForward modules in *model* to use chunked forward. + + Returns the number of modules patched. + """ + patched = 0 + named: list[tuple[str, torch.nn.Module]] = list(model.named_modules()) # pyright: ignore[reportUnknownArgumentType] + for name, module in named: + net = getattr(module, "net", None) + if net is None: + continue + if not isinstance(net, torch.nn.Sequential): + continue + if not (name.endswith(".ff") or name.endswith(".audio_ff")): + continue + + original: Callable[[torch.Tensor], torch.Tensor] = module.forward # type: ignore[assignment] + module.forward = _make_chunked_forward(original, num_chunks) # type: ignore[assignment] + patched += 1 + + if patched: + logger.info("FFN chunking: patched %d feedforward modules (chunks=%d)", patched, num_chunks) + return patched diff --git a/backend/services/gpu_optimizations/tea_cache.py b/backend/services/gpu_optimizations/tea_cache.py new file mode 100644 index 00000000..e474d433 --- /dev/null +++ b/backend/services/gpu_optimizations/tea_cache.py @@ -0,0 +1,162 @@ +"""TeaCache: Timestep-Aware Caching for diffusion denoising loops. + +Wraps a denoising function to skip transformer forward passes when the +timestep embedding hasn't changed significantly from the previous step. +First and last steps are always computed. + +Reference: ali-vilab/TeaCache (TeaCache4LTX-Video) +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import Any + +import numpy as np +import torch + +logger = logging.getLogger(__name__) + +# Polynomial fitted to LTX-Video noise schedule for rescaling relative L1 distance +_RESCALE_COEFFICIENTS = [2.14700694e+01, -1.28016453e+01, 2.31279151e+00, 7.92487521e-01, 9.69274326e-03] +_rescale_poly = np.poly1d(_RESCALE_COEFFICIENTS) + +_original_euler_loop: Any = None + + +@dataclass +class TeaCacheState: + """Mutable state held across denoising steps.""" + accumulated_distance: float = 0.0 + previous_residual: torch.Tensor | None = None + previous_modulated_input: torch.Tensor | None = None + step_count: int = 0 + skipped: int = 0 + computed: int = 0 + + +def wrap_denoise_fn_with_tea_cache( + denoise_fn: Any, + num_steps: int, + threshold: float, +) -> Any: + """Wrap a denoising function with TeaCache. + + The wrapped function has the same signature as the original: + denoise_fn(video_state, audio_state, sigmas, step_index) + -> (denoised_video, denoised_audio) + + When the relative L1 distance of the timestep-modulated input is below + *threshold*, the previous residual is reused instead of calling the + transformer. + + Args: + denoise_fn: Original denoising function. + num_steps: Total number of denoising steps (len(sigmas) - 1). + threshold: Caching threshold. 0 disables. 0.03 = balanced. 0.05 = aggressive. + """ + if threshold <= 0: + return denoise_fn + + state = TeaCacheState() + + def cached_denoise( + video_state: Any, + audio_state: Any, + sigmas: torch.Tensor, + step_index: int, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Always compute first and last steps + if step_index == 0 or step_index == num_steps - 1: + should_compute = True + state.accumulated_distance = 0.0 + elif state.previous_modulated_input is not None: + # Estimate change using video_state latent as proxy for modulated input + current = video_state.latent + prev = state.previous_modulated_input + rel_l1 = ((current - prev).abs().mean() / prev.abs().mean().clamp(min=1e-8)).item() + rescaled = float(_rescale_poly(rel_l1)) + state.accumulated_distance += rescaled + + if state.accumulated_distance < threshold: + should_compute = False + else: + should_compute = True + state.accumulated_distance = 0.0 + else: + should_compute = True + state.accumulated_distance = 0.0 + + state.previous_modulated_input = video_state.latent.clone() + state.step_count += 1 + + if not should_compute and state.previous_residual is not None: + # Reuse cached residual + cached_video = video_state.latent + state.previous_residual + state.skipped += 1 + return cached_video, audio_state.latent + else: + # Full computation + original_latent = video_state.latent.clone() + denoised_video, denoised_audio = denoise_fn(video_state, audio_state, sigmas, step_index) + state.previous_residual = denoised_video - original_latent + state.computed += 1 + return denoised_video, denoised_audio + + cached_denoise._tea_cache_state = state # type: ignore[attr-defined] + return cached_denoise + + +def install_tea_cache_patch(threshold: float) -> None: + """Monkey-patch euler_denoising_loop in ltx_pipelines to apply TeaCache. + + This patches the function at the module level so that DistilledPipeline + (which imports euler_denoising_loop inside __call__) picks it up + automatically on each generation. + """ + global _original_euler_loop + + import ltx_pipelines.utils.samplers as samplers_mod + + if _original_euler_loop is None: + _original_euler_loop = samplers_mod.euler_denoising_loop + + if threshold <= 0: + # Restore original + samplers_mod.euler_denoising_loop = _original_euler_loop + logger.info("TeaCache disabled — restored original euler_denoising_loop") + return + + original = _original_euler_loop + + def tea_cache_euler_loop( + sigmas: torch.Tensor, + video_state: Any, + audio_state: Any, + stepper: Any, + denoise_fn: Any, + **kwargs: Any, + ) -> Any: + num_steps = len(sigmas) - 1 + cached_fn = wrap_denoise_fn_with_tea_cache(denoise_fn, num_steps=num_steps, threshold=threshold) + result = original( + sigmas=sigmas, + video_state=video_state, + audio_state=audio_state, + stepper=stepper, + denoise_fn=cached_fn, + **kwargs, + ) + if hasattr(cached_fn, "_tea_cache_state"): + s = cached_fn._tea_cache_state + logger.info("TeaCache: computed %d, skipped %d of %d steps", s.computed, s.skipped, s.step_count) + return result + + samplers_mod.euler_denoising_loop = tea_cache_euler_loop # type: ignore[assignment] + logger.info("TeaCache installed (threshold=%.3f)", threshold) + + +def uninstall_tea_cache_patch() -> None: + """Restore the original euler_denoising_loop.""" + install_tea_cache_patch(0.0) diff --git a/backend/services/image_api_client/__init__.py b/backend/services/image_api_client/__init__.py new file mode 100644 index 00000000..18912c7d --- /dev/null +++ b/backend/services/image_api_client/__init__.py @@ -0,0 +1,6 @@ +"""Image API client exports.""" + +from services.image_api_client.image_api_client import ImageAPIClient +from services.image_api_client.replicate_client_impl import ReplicateImageClientImpl + +__all__ = ["ImageAPIClient", "ReplicateImageClientImpl"] diff --git a/backend/services/zit_api_client/zit_api_client.py b/backend/services/image_api_client/image_api_client.py similarity index 72% rename from backend/services/zit_api_client/zit_api_client.py rename to backend/services/image_api_client/image_api_client.py index 9dddb1f1..8d8d0dad 100644 --- a/backend/services/zit_api_client/zit_api_client.py +++ b/backend/services/image_api_client/image_api_client.py @@ -1,15 +1,16 @@ -"""Z-Image Turbo API client protocol for FAL endpoints.""" +"""Image API client protocol for cloud image generation.""" from __future__ import annotations from typing import Protocol -class ZitAPIClient(Protocol): +class ImageAPIClient(Protocol): def generate_text_to_image( self, *, api_key: str, + model: str, prompt: str, width: int, height: int, diff --git a/backend/services/image_api_client/replicate_client_impl.py b/backend/services/image_api_client/replicate_client_impl.py new file mode 100644 index 00000000..c97f2473 --- /dev/null +++ b/backend/services/image_api_client/replicate_client_impl.py @@ -0,0 +1,213 @@ +"""Replicate API client implementation for cloud image generation.""" + +from __future__ import annotations + +import time +from typing import Any, cast + +from services.http_client.http_client import HTTPClient +from services.services_utils import JSONValue + +REPLICATE_API_BASE_URL = "https://api.replicate.com/v1" + +_MODEL_ROUTES: dict[str, str] = { + "z-image-turbo": "prunaai/z-image-turbo", + "nano-banana-2": "google/nano-banana-2", +} + +_NANO_BANANA_ASPECT_RATIOS = [ + (1, 1), + (2, 3), + (3, 2), + (3, 4), + (4, 3), + (4, 5), + (5, 4), + (9, 16), + (16, 9), +] + +_POLL_INTERVAL_SECONDS = 2 +_POLL_TIMEOUT_SECONDS = 120 + + +def _closest_aspect_ratio(width: int, height: int) -> str: + target = width / height + best: tuple[int, int] | None = None + best_diff = float("inf") + for w, h in _NANO_BANANA_ASPECT_RATIOS: + diff = abs(target - w / h) + if diff < best_diff: + best_diff = diff + best = (w, h) + assert best is not None + return f"{best[0]}:{best[1]}" + + +def _resolution_bucket(width: int, height: int) -> str: + largest = max(width, height) + if largest <= 512: + return "512px" + if largest <= 1024: + return "1K" + if largest <= 2048: + return "2K" + return "4K" + + +class ReplicateImageClientImpl: + def __init__(self, http: HTTPClient, *, api_base_url: str = REPLICATE_API_BASE_URL) -> None: + self._http = http + self._base_url = api_base_url.rstrip("/") + + def generate_text_to_image( + self, + *, + api_key: str, + model: str, + prompt: str, + width: int, + height: int, + seed: int, + num_inference_steps: int, + ) -> bytes: + replicate_model = _MODEL_ROUTES.get(model) + if replicate_model is None: + raise RuntimeError(f"Unknown image model: {model}") + + input_payload = self._build_input( + model=model, + prompt=prompt, + width=width, + height=height, + seed=seed, + num_inference_steps=num_inference_steps, + ) + + prediction = self._create_prediction( + api_key=api_key, + replicate_model=replicate_model, + input_payload=input_payload, + ) + + output_url = self._wait_for_output(api_key, prediction) + return self._download_image(output_url) + + def _build_input( + self, + *, + model: str, + prompt: str, + width: int, + height: int, + seed: int, + num_inference_steps: int, + ) -> dict[str, JSONValue]: + if model == "nano-banana-2": + return { + "prompt": prompt, + "aspect_ratio": _closest_aspect_ratio(width, height), + "resolution": _resolution_bucket(width, height), + "output_format": "png", + "seed": seed, + } + return { + "prompt": prompt, + "width": width, + "height": height, + "seed": seed, + "num_inference_steps": num_inference_steps, + } + + def _create_prediction( + self, + *, + api_key: str, + replicate_model: str, + input_payload: dict[str, JSONValue], + ) -> dict[str, Any]: + url = f"{self._base_url}/models/{replicate_model}/predictions" + payload: dict[str, JSONValue] = {"input": input_payload} + + response = self._http.post( + url, + headers=self._headers(api_key, prefer_wait=True), + json_payload=payload, + timeout=180, + ) + if response.status_code not in (200, 201): + detail = response.text[:500] if response.text else "Unknown error" + raise RuntimeError(f"Replicate prediction failed ({response.status_code}): {detail}") + + return self._json_object(response.json(), context="create prediction") + + def _wait_for_output(self, api_key: str, prediction: dict[str, Any]) -> str: + status = prediction.get("status", "") + if status == "succeeded": + return self._extract_output_url(prediction) + + if status in ("failed", "canceled"): + error = prediction.get("error", "Unknown error") + raise RuntimeError(f"Replicate prediction {status}: {error}") + + poll_url = prediction.get("urls", {}).get("get") + if not isinstance(poll_url, str) or not poll_url: + prediction_id = prediction.get("id", "") + poll_url = f"{self._base_url}/predictions/{prediction_id}" + + deadline = time.monotonic() + _POLL_TIMEOUT_SECONDS + while time.monotonic() < deadline: + time.sleep(_POLL_INTERVAL_SECONDS) + resp = self._http.get(poll_url, headers=self._headers(api_key), timeout=30) + if resp.status_code != 200: + detail = resp.text[:500] if resp.text else "Unknown error" + raise RuntimeError(f"Replicate poll failed ({resp.status_code}): {detail}") + + data = self._json_object(resp.json(), context="poll") + poll_status = data.get("status", "") + if poll_status == "succeeded": + return self._extract_output_url(data) + if poll_status in ("failed", "canceled"): + error = data.get("error", "Unknown error") + raise RuntimeError(f"Replicate prediction {poll_status}: {error}") + + raise RuntimeError("Replicate prediction timed out") + + def _download_image(self, url: str) -> bytes: + download = self._http.get(url, timeout=120) + if download.status_code != 200: + detail = download.text[:500] if download.text else "Unknown error" + raise RuntimeError(f"Replicate image download failed ({download.status_code}): {detail}") + if not download.content: + raise RuntimeError("Replicate image download returned empty body") + return download.content + + @staticmethod + def _headers(api_key: str, *, prefer_wait: bool = False) -> dict[str, str]: + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + } + if prefer_wait: + headers["Prefer"] = "wait" + return headers + + @staticmethod + def _extract_output_url(prediction: dict[str, Any]) -> str: + output = prediction.get("output") + if isinstance(output, list) and output: + output_list = cast(list[object], output) + first = output_list[0] + if isinstance(first, str) and first: + return first + + if isinstance(output, str) and output: + return output + + raise RuntimeError("Replicate response missing output URL") + + @staticmethod + def _json_object(payload: object, *, context: str) -> dict[str, Any]: + if isinstance(payload, dict): + return cast(dict[str, Any], payload) + raise RuntimeError(f"Unexpected Replicate {context} response format") diff --git a/backend/services/image_generation_pipeline/flux_klein_pipeline.py b/backend/services/image_generation_pipeline/flux_klein_pipeline.py new file mode 100644 index 00000000..b46c2818 --- /dev/null +++ b/backend/services/image_generation_pipeline/flux_klein_pipeline.py @@ -0,0 +1,276 @@ +"""FLUX.2 Klein 9B Base image generation pipeline wrapper.""" + +from __future__ import annotations + +import gc +import logging +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Any, cast + +import numpy as np +import torch +from diffusers import AutoencoderKL, BitsAndBytesConfig, Flux2KleinPipeline, PipelineQuantizationConfig # type: ignore[reportUnknownVariableType] +from PIL import Image +from PIL.Image import Image as PILImage + +from services.services_utils import ImagePipelineOutputLike, PILImageType, get_device_type + +_logger = logging.getLogger(__name__) + + +@dataclass(slots=True) +class _FluxKleinOutput: + images: Sequence[PILImageType] + + +def _latents_to_pil(model_path: str, latents: torch.Tensor) -> list[PILImageType]: + """Decode latents to PIL via a fresh VAE loaded from disk. + + The pipeline's built-in AutoencoderKLFlux2 segfaults on Windows/CUDA + when accelerate model-cpu-offload hooks are active. Loading a clean + AutoencoderKL from the ``vae/`` subfolder and decoding on CPU in + float32 is the only reliable workaround. + """ + import pathlib + + latents_cpu = latents.to("cpu") + + gc.collect() + torch.cuda.empty_cache() + + vae_path = str(pathlib.Path(model_path) / "vae") + _logger.info("Loading fresh VAE from %s for decode", vae_path) + vae = AutoencoderKL.from_pretrained(vae_path, torch_dtype=torch.float32) # type: ignore[reportUnknownMemberType] + vae = vae.to("cpu") # type: ignore[reportUnknownMemberType] + vae.eval() # type: ignore[reportUnknownMemberType] + + latents_f32 = latents_cpu.to(dtype=torch.float32) + with torch.no_grad(): + decoded = vae.decode(latents_f32, return_dict=False)[0] # type: ignore[reportUnknownMemberType] + + decoded = (decoded / 2 + 0.5).clamp(0, 1) + images: list[PILImageType] = [] + for i in range(decoded.shape[0]): + arr = decoded[i].permute(1, 2, 0).numpy() + pil_img = Image.fromarray((arr * 255).astype(np.uint8)) + images.append(pil_img) + + del vae, latents_cpu, latents_f32, decoded + gc.collect() + + return images + + +class FluxKleinImagePipeline: + """FLUX.2 Klein 9B Base — text-to-image, img2img, and LoRA support. + + Loads the transformer with bitsandbytes NF4 quantization (~5GB instead + of ~18GB bf16). Still uses enable_model_cpu_offload() because the + T5-XXL text encoder (~9GB bf16) plus the NF4 transformer exceeds 24GB + VRAM when activation memory is included. + + After denoising, the pipeline is destroyed to release the accelerate + hooks that cause Windows/CUDA VAE segfaults, then latents are decoded + via a fresh VAE on CPU. + + The PipelinesHandler detects the destroyed pipeline and recreates it + on the next generation request (~10s rebuild with NF4). + """ + + @staticmethod + def create( + model_path: str, + device: str | None = None, + ) -> "FluxKleinImagePipeline": + return FluxKleinImagePipeline(model_path=model_path, device=device) + + def __init__(self, model_path: str, device: str | None = None) -> None: + self._device: str | None = None + self._model_offload_active = False + self._lora_loaded: str | None = None + self._model_path = model_path + + nf4_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.bfloat16, + ) + quant_config = PipelineQuantizationConfig( + quant_mapping={"transformer": nf4_config}, + ) + + self.pipeline = Flux2KleinPipeline.from_pretrained( # type: ignore[reportUnknownMemberType] + model_path, + quantization_config=quant_config, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + ) + if device is not None: + self.to(device) + + def _resolve_generator_device(self) -> str: + if self._model_offload_active: + return "cpu" + if self._device is not None: + return self._device + execution_device = getattr(self.pipeline, "_execution_device", None) + return get_device_type(execution_device) + + @staticmethod + def _normalize_output(output: object) -> ImagePipelineOutputLike: + images = getattr(output, "images", None) + if not isinstance(images, Sequence): + raise RuntimeError("Unexpected FLUX Klein pipeline output format: missing images sequence") + + images_list = cast(Sequence[object], images) + validated_images: list[PILImageType] = [] + for image in images_list: + if not isinstance(image, PILImage): + raise RuntimeError("Unexpected FLUX Klein pipeline output: images must be PIL.Image instances") + validated_images.append(image) + + return _FluxKleinOutput(images=validated_images) + + @torch.inference_mode() + def generate( + self, + prompt: str, + height: int, + width: int, + guidance_scale: float, + num_inference_steps: int, + seed: int, + ) -> ImagePipelineOutputLike: + generator = torch.Generator(device=self._resolve_generator_device()).manual_seed(seed) + pipeline = cast(Any, self.pipeline) + + steps = num_inference_steps if num_inference_steps > 4 else 28 + gs = guidance_scale if guidance_scale > 0 else 4.0 + + if self._model_offload_active: + torch.cuda.empty_cache() + output = pipeline( + prompt=prompt, + height=height, + width=width, + guidance_scale=gs, + num_inference_steps=steps, + generator=generator, + output_type="latent", + return_dict=True, + ) + latents = output.images.to("cpu") + # Destroy pipeline to release accelerate hooks before VAE decode. + self._destroy_pipeline() + pil_images = _latents_to_pil(self._model_path, latents) + return _FluxKleinOutput(images=pil_images) + + output = pipeline( + prompt=prompt, + height=height, + width=width, + guidance_scale=gs, + num_inference_steps=steps, + generator=generator, + output_type="pil", + return_dict=True, + ) + return self._normalize_output(output) + + @torch.inference_mode() + def img2img( + self, + prompt: str, + image: PILImageType, + strength: float, + height: int, + width: int, + guidance_scale: float, + num_inference_steps: int, + seed: int, + ) -> ImagePipelineOutputLike: + generator = torch.Generator(device=self._resolve_generator_device()).manual_seed(seed) + pipeline = cast(Any, self.pipeline) + + base_steps = num_inference_steps if num_inference_steps > 4 else 28 + effective_steps = max(1, int(base_steps * strength)) + gs = guidance_scale if guidance_scale > 0 else 4.0 + + if self._model_offload_active: + torch.cuda.empty_cache() + output = pipeline( + prompt=prompt, + image=image, + height=height, + width=width, + guidance_scale=gs, + num_inference_steps=effective_steps, + generator=generator, + output_type="latent", + return_dict=True, + ) + latents = output.images.to("cpu") + self._destroy_pipeline() + pil_images = _latents_to_pil(self._model_path, latents) + return _FluxKleinOutput(images=pil_images) + + output = pipeline( + prompt=prompt, + image=image, + height=height, + width=width, + guidance_scale=gs, + num_inference_steps=effective_steps, + generator=generator, + output_type="pil", + return_dict=True, + ) + return self._normalize_output(output) + + def _destroy_pipeline(self) -> None: + """Destroy the diffusers pipeline to release accelerate hooks and VRAM. + + Required on Windows/CUDA: the accelerate model-cpu-offload hooks + cause a segfault during VAE decode. Destroying the pipeline before + loading a fresh VAE avoids the crash entirely. + + After calling this, the pipeline object is gone. The PipelinesHandler + will detect this and recreate the pipeline on the next generation. + """ + _logger.info("Destroying FLUX Klein pipeline to release accelerate hooks") + del self.pipeline + self._lora_loaded = None + self._model_offload_active = False + gc.collect() + torch.cuda.empty_cache() + + def to(self, device: str) -> None: + runtime_device = get_device_type(device) + if runtime_device in ("cuda", "mps"): + # Model-level CPU offload: moves whole sub-models to GPU one at + # a time. NF4 transformer (~5GB) + T5-XXL text_encoder (~9GB) + # + activations still exceeds 24GB VRAM at 1024x1024. + self.pipeline.enable_model_cpu_offload() # type: ignore[reportUnknownMemberType] + self._model_offload_active = True + else: + self._model_offload_active = False + self.pipeline.to(runtime_device) # type: ignore[reportUnknownMemberType] + self._device = runtime_device + + def load_lora(self, lora_path: str, weight: float = 1.0) -> None: + if self._lora_loaded == lora_path: + return + if self._lora_loaded is not None: + self.unload_lora() + pipeline = cast(Any, self.pipeline) + pipeline.load_lora_weights(lora_path, adapter_name="user_lora") + pipeline.set_adapters(["user_lora"], adapter_weights=[weight]) + self._lora_loaded = lora_path + + def unload_lora(self) -> None: + if self._lora_loaded is None: + return + pipeline = cast(Any, self.pipeline) + pipeline.unload_lora_weights() + self._lora_loaded = None diff --git a/backend/services/image_generation_pipeline/image_generation_pipeline.py b/backend/services/image_generation_pipeline/image_generation_pipeline.py index 0439569f..67a4b384 100644 --- a/backend/services/image_generation_pipeline/image_generation_pipeline.py +++ b/backend/services/image_generation_pipeline/image_generation_pipeline.py @@ -4,7 +4,7 @@ from typing import Protocol -from services.services_utils import ImagePipelineOutputLike +from services.services_utils import ImagePipelineOutputLike, PILImageType class ImageGenerationPipeline(Protocol): @@ -26,5 +26,24 @@ def generate( ) -> ImagePipelineOutputLike: ... + def img2img( + self, + prompt: str, + image: PILImageType, + strength: float, + height: int, + width: int, + guidance_scale: float, + num_inference_steps: int, + seed: int, + ) -> ImagePipelineOutputLike: + ... + def to(self, device: str) -> None: ... + + def load_lora(self, lora_path: str, weight: float = 1.0) -> None: + ... + + def unload_lora(self) -> None: + ... diff --git a/backend/services/image_generation_pipeline/zit_image_generation_pipeline.py b/backend/services/image_generation_pipeline/zit_image_generation_pipeline.py index 798ce130..1349f620 100644 --- a/backend/services/image_generation_pipeline/zit_image_generation_pipeline.py +++ b/backend/services/image_generation_pipeline/zit_image_generation_pipeline.py @@ -7,7 +7,7 @@ from typing import Any, cast import torch -from diffusers.pipelines.auto_pipeline import ZImagePipeline # type: ignore[reportUnknownVariableType] +from diffusers.pipelines.auto_pipeline import ZImageImg2ImgPipeline, ZImagePipeline # type: ignore[reportUnknownVariableType] from PIL.Image import Image as PILImage from services.services_utils import ImagePipelineOutputLike, PILImageType, get_device_type @@ -29,10 +29,20 @@ def create( def __init__(self, model_path: str, device: str | None = None) -> None: self._device: str | None = None self._cpu_offload_active = False + self._lora_loaded: str | None = None self.pipeline = ZImagePipeline.from_pretrained( # type: ignore[reportUnknownMemberType] model_path, torch_dtype=torch.bfloat16, ) + # Create img2img pipeline sharing the same model components — no extra VRAM. + pipeline_any = cast(Any, self.pipeline) + self._img2img_pipeline: Any = ZImageImg2ImgPipeline( # type: ignore[reportUnknownMemberType] + scheduler=pipeline_any.scheduler, + vae=pipeline_any.vae, + text_encoder=pipeline_any.text_encoder, + tokenizer=pipeline_any.tokenizer, + transformer=pipeline_any.transformer, + ) if device is not None: self.to(device) @@ -86,12 +96,59 @@ def generate( ) return self._normalize_output(output) + @torch.inference_mode() + def img2img( + self, + prompt: str, + image: PILImageType, + strength: float, + height: int, + width: int, + guidance_scale: float, + num_inference_steps: int, + seed: int, + ) -> ImagePipelineOutputLike: + _ = guidance_scale + generator = torch.Generator(device=self._resolve_generator_device()).manual_seed(seed) + output = self._img2img_pipeline( + prompt=prompt, + image=image, + strength=strength, + height=height, + width=width, + guidance_scale=0.0, + num_inference_steps=num_inference_steps, + generator=generator, + output_type="pil", + return_dict=True, + ) + return self._normalize_output(output) + def to(self, device: str) -> None: runtime_device = get_device_type(device) if runtime_device in ("cuda", "mps"): self.pipeline.enable_model_cpu_offload() # type: ignore[reportUnknownMemberType] + self._img2img_pipeline.enable_model_cpu_offload() # type: ignore[reportUnknownMemberType] self._cpu_offload_active = True else: self._cpu_offload_active = False self.pipeline.to(runtime_device) # type: ignore[reportUnknownMemberType] + self._img2img_pipeline.to(runtime_device) # type: ignore[reportUnknownMemberType] self._device = runtime_device + + def load_lora(self, lora_path: str, weight: float = 1.0) -> None: + if self._lora_loaded == lora_path: + return + if self._lora_loaded is not None: + self.unload_lora() + pipeline = cast(Any, self.pipeline) + pipeline.load_lora_weights(lora_path, adapter_name="user_lora") + pipeline.set_adapters(["user_lora"], adapter_weights=[weight]) + self._lora_loaded = lora_path + + def unload_lora(self) -> None: + if self._lora_loaded is None: + return + pipeline = cast(Any, self.pipeline) + pipeline.unload_lora_weights() + self._lora_loaded = None diff --git a/backend/services/interfaces.py b/backend/services/interfaces.py index cbb560c3..844ae8bf 100644 --- a/backend/services/interfaces.py +++ b/backend/services/interfaces.py @@ -6,7 +6,7 @@ from services.a2v_pipeline.a2v_pipeline import A2VPipeline from services.fast_video_pipeline.fast_video_pipeline import FastVideoPipeline -from services.zit_api_client.zit_api_client import ZitAPIClient +from services.image_api_client.image_api_client import ImageAPIClient from services.gpu_cleaner.gpu_cleaner import GpuCleaner from services.gpu_info.gpu_info import GpuInfo, GpuTelemetryPayload from services.http_client.http_client import HTTPClient, HttpResponseLike, HttpTimeoutError @@ -23,6 +23,9 @@ from services.services_utils import JSONScalar, JSONValue from services.task_runner.task_runner import TaskRunner from services.text_encoder.text_encoder import TextEncoder +from services.model_scanner.model_scanner import ModelScanner +from services.palette_sync_client.palette_sync_client import PaletteSyncClient +from services.video_api_client.video_api_client import VideoAPIClient from services.video_processor.video_processor import VideoInfoPayload, VideoProcessor VideoPipelineModelType = Literal["fast"] @@ -45,11 +48,14 @@ "TaskRunner", "VideoPipelineModelType", "FastVideoPipeline", - "ZitAPIClient", + "ImageAPIClient", "ImageGenerationPipeline", "IcLoraPipeline", "IcLoraModelDownloader", "LTXAPIClient", + "ModelScanner", + "PaletteSyncClient", "RetakePipeline", "TextEncoder", + "VideoAPIClient", ] diff --git a/backend/services/ltx_api_client/ltx_api_client.py b/backend/services/ltx_api_client/ltx_api_client.py index 69d43d5a..0d7a750a 100644 --- a/backend/services/ltx_api_client/ltx_api_client.py +++ b/backend/services/ltx_api_client/ltx_api_client.py @@ -42,6 +42,7 @@ def generate_text_to_video( fps: float, generate_audio: bool, camera_motion: VideoCameraMotion = "none", + last_frame_uri: str | None = None, ) -> bytes: ... @@ -57,6 +58,7 @@ def generate_image_to_video( fps: float, generate_audio: bool, camera_motion: VideoCameraMotion = "none", + last_frame_uri: str | None = None, ) -> bytes: ... diff --git a/backend/services/ltx_api_client/ltx_api_client_impl.py b/backend/services/ltx_api_client/ltx_api_client_impl.py index ec1d95e1..67de8aa6 100644 --- a/backend/services/ltx_api_client/ltx_api_client_impl.py +++ b/backend/services/ltx_api_client/ltx_api_client_impl.py @@ -68,6 +68,7 @@ def generate_text_to_video( fps: float, generate_audio: bool, camera_motion: VideoCameraMotion = "none", + last_frame_uri: str | None = None, ) -> bytes: payload: dict[str, JSONValue] = { "prompt": prompt, @@ -80,6 +81,8 @@ def generate_text_to_video( mapped_camera_motion = self._map_camera_motion(camera_motion) if mapped_camera_motion is not None: payload["camera_motion"] = mapped_camera_motion + if last_frame_uri is not None: + payload["last_frame_uri"] = last_frame_uri response = self._http.post( f"{self._base_url}/v1/text-to-video", headers=self._json_headers(api_key), @@ -100,6 +103,7 @@ def generate_image_to_video( fps: float, generate_audio: bool, camera_motion: VideoCameraMotion = "none", + last_frame_uri: str | None = None, ) -> bytes: payload: dict[str, JSONValue] = { "prompt": prompt, @@ -113,6 +117,8 @@ def generate_image_to_video( mapped_camera_motion = self._map_camera_motion(camera_motion) if mapped_camera_motion is not None: payload["camera_motion"] = mapped_camera_motion + if last_frame_uri is not None: + payload["last_frame_uri"] = last_frame_uri response = self._http.post( f"{self._base_url}/v1/image-to-video", headers=self._json_headers(api_key), diff --git a/backend/services/model_downloader/hugging_face_downloader.py b/backend/services/model_downloader/hugging_face_downloader.py index c675cbb8..ba2e8691 100644 --- a/backend/services/model_downloader/hugging_face_downloader.py +++ b/backend/services/model_downloader/hugging_face_downloader.py @@ -46,14 +46,15 @@ def update(self, n: float | int | None = 1) -> bool | None: # type: ignore[repo @contextlib.contextmanager -def _patch_http_get_progress(callback: Callable[[int, int], None]) -> Iterator[None]: +def _patch_download_progress(callback: Callable[[int, int], None]) -> Iterator[None]: """Temporarily monkey-patch ``huggingface_hub.file_download.http_get`` - to inject a custom tqdm bar that forwards progress to *callback*. + and ``xet_get`` to inject a custom tqdm bar that forwards progress to + *callback*. ``hf_hub_download`` does not expose a ``tqdm_class`` parameter (unlike - ``snapshot_download``), but its internal ``http_get`` accepts a private - ``_tqdm_bar`` kwarg. We wrap ``http_get`` to inject our own bar when - the caller hasn't already provided one. + ``snapshot_download``), but its internal ``http_get`` and ``xet_get`` + both accept a private ``_tqdm_bar`` kwarg. We wrap them to inject our + own bar when the caller hasn't already provided one. See ``test_http_get_accepts_tqdm_bar`` — if that test breaks after a huggingface_hub upgrade, this patch needs to be revisited. @@ -66,8 +67,20 @@ def _wrapped_http_get(*args: Any, **kwargs: Any) -> None: kwargs["_tqdm_bar"] = tqdm_cls(disable=True) return original_http_get(*args, **kwargs) + xet_get_fn: Callable[..., Any] | None = getattr(file_download, "xet_get", None) + + def _wrapped_xet_get(*args: Any, **kwargs: Any) -> None: + if kwargs.get("_tqdm_bar") is None: + kwargs["_tqdm_bar"] = tqdm_cls(disable=True) + assert xet_get_fn is not None + return xet_get_fn(*args, **kwargs) + with patch.object(file_download, "http_get", _wrapped_http_get): - yield + if xet_get_fn is not None: + with patch.object(file_download, "xet_get", _wrapped_xet_get): + yield + else: + yield class HuggingFaceDownloader: @@ -80,7 +93,7 @@ def download_file( local_dir: str, on_progress: Callable[[int, int], None] | None = None, ) -> Path: - ctx = _patch_http_get_progress(on_progress) if on_progress is not None else contextlib.nullcontext() + ctx = _patch_download_progress(on_progress) if on_progress is not None else contextlib.nullcontext() with ctx: path: str = hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir) return Path(path) @@ -91,7 +104,7 @@ def download_snapshot( local_dir: str, on_progress: Callable[[int, int], None] | None = None, ) -> Path: - ctx = _patch_http_get_progress(on_progress) if on_progress is not None else contextlib.nullcontext() + ctx = _patch_download_progress(on_progress) if on_progress is not None else contextlib.nullcontext() with ctx: path: str = snapshot_download(repo_id=repo_id, local_dir=local_dir) return Path(path) diff --git a/backend/services/model_scanner/__init__.py b/backend/services/model_scanner/__init__.py new file mode 100644 index 00000000..409a47d7 --- /dev/null +++ b/backend/services/model_scanner/__init__.py @@ -0,0 +1,4 @@ +from services.model_scanner.model_scanner import ModelScanner +from services.model_scanner.model_scanner_impl import ModelScannerImpl + +__all__ = ["ModelScanner", "ModelScannerImpl"] diff --git a/backend/services/model_scanner/model_guide_data.py b/backend/services/model_scanner/model_guide_data.py new file mode 100644 index 00000000..595fae32 --- /dev/null +++ b/backend/services/model_scanner/model_guide_data.py @@ -0,0 +1,118 @@ +"""Static metadata about available video model formats and download URLs.""" + +from __future__ import annotations + +from api_types import DistilledLoraInfo, ModelFormatInfo + +MODEL_FORMATS: list[ModelFormatInfo] = [ + ModelFormatInfo( + id="bf16", + name="Full Quality (BF16)", + size_gb=43, + min_vram_gb=32, + quality_tier="Best", + needs_distilled_lora=False, + download_url="https://huggingface.co/Lightricks/LTX-Video-2.3-22b-distilled", + description=( + "The original, uncompressed model. Best possible video quality but " + "needs a high-end GPU with 32 GB+ of video memory (e.g. RTX 4090, A100). " + "This is the default — it downloads automatically on first run." + ), + ), + ModelFormatInfo( + id="fp8", + name="Half-Size (FP8)", + size_gb=22, + min_vram_gb=20, + quality_tier="Excellent", + needs_distilled_lora=False, + download_url="https://huggingface.co/Lightricks/LTX-Video-2.3-22b-distilled", + description=( + "Same model, compressed to half the file size with almost no quality loss. " + "Great for GPUs with 20–31 GB of video memory (e.g. RTX 3090, RTX 4080). " + "Drop-in replacement — just swap the file." + ), + ), + ModelFormatInfo( + id="gguf_q8", + name="Compressed Q8 (GGUF)", + size_gb=22, + min_vram_gb=18, + quality_tier="Excellent", + needs_distilled_lora=True, + download_url="https://huggingface.co/city96/LTX-Video-2.3-22b-0.9.7-dev-gguf", + description=( + "High-quality compressed model. Very close to the original but uses " + "less video memory. Needs the Speed Boost LoRA file (see below). " + "Good for 20–24 GB GPUs." + ), + ), + ModelFormatInfo( + id="gguf_q5k", + name="Compressed Q5 (GGUF)", + size_gb=15, + min_vram_gb=13, + quality_tier="Very Good", + needs_distilled_lora=True, + download_url="https://huggingface.co/city96/LTX-Video-2.3-22b-0.9.7-dev-gguf", + description=( + "Nicely balanced — smaller file, still looks great. Best pick for " + "16 GB GPUs like the RTX 4060 Ti 16GB or RTX 3080. " + "Needs the Speed Boost LoRA file (see below)." + ), + ), + ModelFormatInfo( + id="gguf_q4k", + name="Compressed Q4 (GGUF)", + size_gb=12, + min_vram_gb=10, + quality_tier="Good", + needs_distilled_lora=True, + download_url="https://huggingface.co/city96/LTX-Video-2.3-22b-0.9.7-dev-gguf", + description=( + "Smallest file, runs on GPUs with as little as 10 GB of video memory " + "(e.g. RTX 3060 12GB). Some quality loss but still very usable. " + "Needs the Speed Boost LoRA file (see below)." + ), + ), + ModelFormatInfo( + id="nf4", + name="4-Bit Compressed (NF4)", + size_gb=12, + min_vram_gb=10, + quality_tier="Good", + needs_distilled_lora=True, + download_url="https://huggingface.co/Lightricks/LTX-Video-2.3-22b-distilled", + description=( + "Another way to run on 10–15 GB GPUs. Uses a different compression method " + "that needs extra software (bitsandbytes). Try the Q4 GGUF option first — " + "it's simpler to set up. Needs the Speed Boost LoRA file (see below)." + ), + ), +] + +DISTILLED_LORA_INFO = DistilledLoraInfo( + name="Speed Boost LoRA (Required for Compressed Models)", + size_gb=0.5, + download_url="https://huggingface.co/Lightricks/LTX-Video-2.3-22b-distilled", + description=( + "A small add-on file that makes compressed models generate videos fast. " + "You MUST download this if you're using any of the compressed (GGUF or NF4) models. " + "Just put it in the same folder as your model file." + ), +) + + +def recommend_format(vram_gb: int | None) -> str: + """Return the recommended format ID based on available VRAM.""" + if vram_gb is None: + return "bf16" + if vram_gb >= 32: + return "bf16" + if vram_gb >= 20: + return "fp8" + if vram_gb >= 16: + return "gguf_q5k" + if vram_gb >= 10: + return "gguf_q4k" + return "api_only" diff --git a/backend/services/model_scanner/model_scanner.py b/backend/services/model_scanner/model_scanner.py new file mode 100644 index 00000000..d139541c --- /dev/null +++ b/backend/services/model_scanner/model_scanner.py @@ -0,0 +1,14 @@ +"""Protocol for scanning model files in a directory.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Protocol + +from api_types import DetectedModel + + +class ModelScanner(Protocol): + def scan_video_models(self, folder: Path) -> list[DetectedModel]: + """Scan a folder for video model files and return structured results.""" + ... diff --git a/backend/services/model_scanner/model_scanner_impl.py b/backend/services/model_scanner/model_scanner_impl.py new file mode 100644 index 00000000..a8da1e08 --- /dev/null +++ b/backend/services/model_scanner/model_scanner_impl.py @@ -0,0 +1,203 @@ +"""Real implementation of ModelScanner — scans a folder for video model files.""" + +from __future__ import annotations + +import json +import struct +from pathlib import Path + +from api_types import DetectedModel + +# GGUF magic bytes (4-byte little-endian magic + 4-byte version) +_GGUF_MAGIC = b"GGUF" +_GGUF_MIN_FILE_SIZE = 8 # magic (4) + version (4) + +# Safetensors header sentinel (little-endian uint64 for header length) +_SAFETENSORS_HEADER_LENGTH_BYTES = 8 + + +class ModelScannerImpl: + """Scans a directory for video model files using file metadata (not size heuristics).""" + + def scan_video_models(self, folder: Path) -> list[DetectedModel]: + """Return all detected video models in *folder*. Returns [] if folder doesn't exist.""" + if not folder.exists() or not folder.is_dir(): + return [] + + results: list[DetectedModel] = [] + + for entry in sorted(folder.iterdir()): + try: + if entry.is_file(): + if entry.suffix.lower() == ".gguf": + model = self._scan_gguf(entry) + if model is not None: + results.append(model) + elif entry.suffix.lower() == ".safetensors": + model = self._scan_safetensors(entry) + if model is not None: + results.append(model) + elif entry.is_dir(): + model = self._scan_nf4_folder(entry) + if model is not None: + results.append(model) + except Exception: + continue # skip inaccessible entries + + return results + + # ------------------------------------------------------------------ + # GGUF + # ------------------------------------------------------------------ + + def _scan_gguf(self, path: Path) -> DetectedModel | None: + """Return a DetectedModel if *path* is a valid GGUF file, else None.""" + try: + with path.open("rb") as f: + header = f.read(_GGUF_MIN_FILE_SIZE) + if len(header) < _GGUF_MIN_FILE_SIZE: + return None + magic = header[:4] + if magic != _GGUF_MAGIC: + return None + version = struct.unpack_from(" str | None: + """Extract quant type like Q8_0, Q5_K_M, Q4_K_M from a GGUF filename.""" + name_upper = filename.upper() + # Common GGUF quant suffixes ordered from most to least specific + candidates = [ + "Q8_0", "Q5_K_M", "Q5_K_S", "Q4_K_M", "Q4_K_S", + "Q3_K_M", "Q3_K_S", "Q2_K", "F16", "F32", + ] + for cand in candidates: + if cand in name_upper: + return cand + return None + + def _gguf_display_name(self, filename: str, quant_type: str | None) -> str: + stem = Path(filename).stem + if quant_type: + return f"{stem} ({quant_type})" + return stem + + # ------------------------------------------------------------------ + # Safetensors + # ------------------------------------------------------------------ + + def _scan_safetensors(self, path: Path) -> DetectedModel | None: + """Return a DetectedModel if *path* is a valid .safetensors video model, else None.""" + try: + with path.open("rb") as f: + raw_len = f.read(_SAFETENSORS_HEADER_LENGTH_BYTES) + if len(raw_len) < _SAFETENSORS_HEADER_LENGTH_BYTES: + return None + except OSError: + return None + + fmt = self._detect_safetensors_format(path) + stat = path.stat() + size_bytes = stat.st_size + size_gb = round(size_bytes / (1024**3), 2) + + return DetectedModel( + filename=path.name, + path=str(path), + model_format=fmt, + quant_type=None, + size_bytes=size_bytes, + size_gb=size_gb, + is_distilled=False, + display_name=path.stem, + ) + + def _detect_safetensors_format(self, path: Path) -> str: + """Determine bf16 vs fp8 by inspecting companion config.json or safetensors header.""" + # Check sibling config.json for torch_dtype + config_path = path.parent / "config.json" + if config_path.exists(): + try: + data = json.loads(config_path.read_text(encoding="utf-8")) + dtype = str(data.get("torch_dtype", "")).lower() + if "fp8" in dtype or "float8" in dtype: + return "fp8" + if "bf16" in dtype or "bfloat16" in dtype: + return "bf16" + except (OSError, json.JSONDecodeError): + pass + + # Fall back to checking the safetensors header for dtype strings + try: + with path.open("rb") as f: + raw_len = f.read(8) + if len(raw_len) < 8: + return "bf16" + header_len = struct.unpack_from(" DetectedModel | None: + """Return a DetectedModel if *folder* contains a quantize_config.json with quant_type=nf4.""" + config_path = folder / "quantize_config.json" + if not config_path.exists(): + return None + + try: + data = json.loads(config_path.read_text(encoding="utf-8")) + except (OSError, json.JSONDecodeError): + return None + + quant_type = str(data.get("quant_type", "")).lower() + if quant_type != "nf4": + return None + + # Sum up all files in the folder for size + size_bytes = sum( + f.stat().st_size + for f in folder.rglob("*") + if f.is_file() + ) + size_gb = round(size_bytes / (1024**3), 2) + + return DetectedModel( + filename=folder.name, + path=str(folder), + model_format="nf4", + quant_type="nf4", + size_bytes=size_bytes, + size_gb=size_gb, + is_distilled=False, + display_name=folder.name, + ) diff --git a/backend/services/palette_sync_client/__init__.py b/backend/services/palette_sync_client/__init__.py new file mode 100644 index 00000000..282637f0 --- /dev/null +++ b/backend/services/palette_sync_client/__init__.py @@ -0,0 +1,4 @@ +from services.palette_sync_client.palette_sync_client import PaletteSyncClient +from services.palette_sync_client.palette_sync_client_impl import PaletteSyncClientImpl + +__all__ = ["PaletteSyncClient", "PaletteSyncClientImpl"] diff --git a/backend/services/palette_sync_client/palette_sync_client.py b/backend/services/palette_sync_client/palette_sync_client.py new file mode 100644 index 00000000..579c4bdf --- /dev/null +++ b/backend/services/palette_sync_client/palette_sync_client.py @@ -0,0 +1,66 @@ +"""Protocol for communicating with Director's Palette cloud API.""" + +from __future__ import annotations + +from typing import Any, Protocol + + +class PaletteSyncClient(Protocol): + def validate_connection(self, *, api_key: str) -> dict[str, Any]: + """Validate API key and return user info. Raises on failure.""" + ... + + def sign_in_with_email(self, *, email: str, password: str) -> dict[str, Any]: + """Sign in with email/password. Returns access_token, refresh_token, user info.""" + ... + + def refresh_access_token(self, *, refresh_token: str) -> dict[str, Any]: + """Refresh an expired access token. Returns new access_token, refresh_token.""" + ... + + def get_credits(self, *, api_key: str) -> dict[str, Any]: + """Return credit balance and pricing for the authenticated user.""" + ... + + def check_credits( + self, *, api_key: str, generation_type: str, count: int, + ) -> dict[str, Any]: + """Check whether the user can afford a generation. Does not deduct.""" + ... + + def deduct_credits( + self, *, api_key: str, generation_type: str, count: int, + metadata: dict[str, Any] | None, + ) -> dict[str, Any]: + """Deduct credits after a successful generation.""" + ... + + def list_gallery( + self, *, api_key: str, page: int, per_page: int, asset_type: str, + ) -> dict[str, Any]: + """List cloud gallery items with pagination.""" + ... + + def list_characters(self, *, api_key: str) -> dict[str, Any]: + """List characters from all user storyboards.""" + ... + + def list_styles(self, *, api_key: str) -> dict[str, Any]: + """List user style guides and brands.""" + ... + + def list_references( + self, *, api_key: str, category: str | None, + ) -> dict[str, Any]: + """List reference images with optional category filter.""" + ... + + def list_loras(self, *, api_key: str) -> dict[str, Any]: + """List available LoRAs from the Palette library.""" + ... + + def enhance_prompt( + self, *, api_key: str, prompt: str, level: str, + ) -> dict[str, Any]: + """Enhance a prompt using Palette's prompt expander.""" + ... diff --git a/backend/services/palette_sync_client/palette_sync_client_impl.py b/backend/services/palette_sync_client/palette_sync_client_impl.py new file mode 100644 index 00000000..dba9e135 --- /dev/null +++ b/backend/services/palette_sync_client/palette_sync_client_impl.py @@ -0,0 +1,234 @@ +"""HTTP implementation of PaletteSyncClient.""" + +from __future__ import annotations + +import os +from typing import Any, cast + +from services.http_client.http_client import HTTPClient + +_DEFAULT_BASE = "https://directorspal.com" +_SUPABASE_URL = os.environ.get( + "PALETTE_SUPABASE_URL", + "https://tarohelkwuurakbxjyxm.supabase.co", +) +_SUPABASE_ANON_KEY = os.environ.get( + "PALETTE_SUPABASE_ANON_KEY", + # Supabase anon keys are designed for public/client-side use. + # This default is safe to ship but can be overridden via env var. + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9." + "eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6InRhcm9oZWxrd3V1cmFrYnhqeXhtIiwicm9sZSI6ImFub24iLCJpYXQiOjE3NTU4OTQzMDYsImV4cCI6MjA3MTQ3MDMwNn0." + "uTeDJ0YYVGjP2FTL9oIpCRPeqXBbxNnh8y7CH0gIABs", +) + + +class PaletteSyncClientImpl: + def __init__(self, http: HTTPClient, base_url: str = _DEFAULT_BASE) -> None: + self._http = http + self._base_url = base_url + + def _headers(self, api_key: str) -> dict[str, str]: + return {"Authorization": f"Bearer {api_key}"} + + def _parse_supabase_user(self, user_data: dict[str, Any]) -> dict[str, Any]: + """Normalize a Supabase user response into a simple user dict.""" + metadata: dict[str, Any] = user_data.get("user_metadata", {}) + return { + "id": user_data.get("id"), + "email": user_data.get("email"), + "name": metadata.get("full_name") or metadata.get("name") or user_data.get("email"), + } + + def validate_connection(self, *, api_key: str) -> dict[str, Any]: + if api_key.startswith("dp_"): + # dp_ API keys must be validated by the Palette app, which has + # the api_keys table and hashing logic. + resp = self._http.get( + f"{self._base_url}/api/desktop/me", + headers=self._headers(api_key), + timeout=10, + ) + if resp.status_code == 404: + raise RuntimeError( + "dp_ API keys are not supported yet. " + "The Directors Palette team needs to deploy the /api/desktop endpoints. " + "Use 'Login with Email' instead." + ) + if resp.status_code != 200: + raise RuntimeError(f"Palette auth failed: {resp.status_code}") + return cast(dict[str, Any], resp.json()) + + # JWT token — validate directly with Supabase + resp = self._http.get( + f"{_SUPABASE_URL}/auth/v1/user", + headers={ + "Authorization": f"Bearer {api_key}", + "apikey": _SUPABASE_ANON_KEY, + }, + timeout=10, + ) + if resp.status_code != 200: + raise RuntimeError(f"Authentication failed (HTTP {resp.status_code}). Check your token.") + user_data = cast(dict[str, Any], resp.json()) + return self._parse_supabase_user(user_data) + + def sign_in_with_email(self, *, email: str, password: str) -> dict[str, Any]: + """Sign in with email/password via Supabase auth.""" + resp = self._http.post( + f"{_SUPABASE_URL}/auth/v1/token?grant_type=password", + headers={ + "apikey": _SUPABASE_ANON_KEY, + "Content-Type": "application/json", + }, + json_payload={"email": email, "password": password}, + timeout=15, + ) + if resp.status_code == 400: + data = cast(dict[str, Any], resp.json()) + msg = data.get("error_description") or data.get("msg") or "Invalid email or password" + raise RuntimeError(msg) + if resp.status_code != 200: + raise RuntimeError(f"Sign-in failed (HTTP {resp.status_code})") + data = cast(dict[str, Any], resp.json()) + user_data = cast(dict[str, Any], data.get("user", {})) + return { + "access_token": data["access_token"], + "refresh_token": data["refresh_token"], + "user": self._parse_supabase_user(user_data), + } + + def refresh_access_token(self, *, refresh_token: str) -> dict[str, Any]: + """Refresh an expired Supabase JWT.""" + resp = self._http.post( + f"{_SUPABASE_URL}/auth/v1/token?grant_type=refresh_token", + headers={ + "apikey": _SUPABASE_ANON_KEY, + "Content-Type": "application/json", + }, + json_payload={"refresh_token": refresh_token}, + timeout=10, + ) + if resp.status_code != 200: + raise RuntimeError("Session expired. Please log in again.") + data = cast(dict[str, Any], resp.json()) + user_data = cast(dict[str, Any], data.get("user", {})) + return { + "access_token": data["access_token"], + "refresh_token": data["refresh_token"], + "user": self._parse_supabase_user(user_data), + } + + def get_credits(self, *, api_key: str) -> dict[str, Any]: + try: + resp = self._http.get( + f"{self._base_url}/api/desktop/credits", + headers=self._headers(api_key), + timeout=10, + ) + if resp.status_code != 200: + return {"balance_cents": None} + return cast(dict[str, Any], resp.json()) + except Exception: + return {"balance_cents": None} + + def check_credits( + self, *, api_key: str, generation_type: str, count: int, + ) -> dict[str, Any]: + resp = self._http.post( + f"{self._base_url}/api/desktop/credits/check", + headers={**self._headers(api_key), "Content-Type": "application/json"}, + json_payload={"generation_type": generation_type, "count": count}, + timeout=10, + ) + if resp.status_code != 200: + raise RuntimeError(f"Credit check failed: {resp.status_code}") + return cast(dict[str, Any], resp.json()) + + def deduct_credits( + self, *, api_key: str, generation_type: str, count: int, + metadata: dict[str, Any] | None, + ) -> dict[str, Any]: + payload: dict[str, Any] = {"generation_type": generation_type, "count": count} + if metadata: + payload["metadata"] = metadata + resp = self._http.post( + f"{self._base_url}/api/desktop/credits/deduct", + headers={**self._headers(api_key), "Content-Type": "application/json"}, + json_payload=payload, + timeout=10, + ) + if resp.status_code == 402: + data = cast(dict[str, Any], resp.json()) + raise RuntimeError(f"Insufficient credits: balance={data.get('balance_cents')}") + if resp.status_code != 200: + raise RuntimeError(f"Credit deduction failed: {resp.status_code}") + return cast(dict[str, Any], resp.json()) + + def list_gallery( + self, *, api_key: str, page: int, per_page: int, asset_type: str, + ) -> dict[str, Any]: + params = f"?page={page}&per_page={per_page}&type={asset_type}" + resp = self._http.get( + f"{self._base_url}/api/desktop/gallery{params}", + headers=self._headers(api_key), + timeout=15, + ) + if resp.status_code != 200: + raise RuntimeError(f"Palette gallery failed: {resp.status_code}") + return cast(dict[str, Any], resp.json()) + + def list_characters(self, *, api_key: str) -> dict[str, Any]: + resp = self._http.get( + f"{self._base_url}/api/desktop/library/characters", + headers=self._headers(api_key), + timeout=10, + ) + if resp.status_code != 200: + raise RuntimeError(f"Palette characters failed: {resp.status_code}") + return cast(dict[str, Any], resp.json()) + + def list_styles(self, *, api_key: str) -> dict[str, Any]: + resp = self._http.get( + f"{self._base_url}/api/desktop/library/styles", + headers=self._headers(api_key), + timeout=10, + ) + if resp.status_code != 200: + raise RuntimeError(f"Palette styles failed: {resp.status_code}") + return cast(dict[str, Any], resp.json()) + + def list_references( + self, *, api_key: str, category: str | None, + ) -> dict[str, Any]: + params = f"?category={category}" if category else "" + resp = self._http.get( + f"{self._base_url}/api/desktop/library/references{params}", + headers=self._headers(api_key), + timeout=10, + ) + if resp.status_code != 200: + raise RuntimeError(f"Palette references failed: {resp.status_code}") + return cast(dict[str, Any], resp.json()) + + def list_loras(self, *, api_key: str) -> dict[str, Any]: + resp = self._http.get( + f"{self._base_url}/api/desktop/library/loras", + headers=self._headers(api_key), + timeout=15, + ) + if resp.status_code != 200: + raise RuntimeError(f"Palette loras failed: {resp.status_code}") + return cast(dict[str, Any], resp.json()) + + def enhance_prompt( + self, *, api_key: str, prompt: str, level: str, + ) -> dict[str, Any]: + resp = self._http.post( + f"{self._base_url}/api/desktop/prompt/enhance", + headers=self._headers(api_key), + json_payload={"prompt": prompt, "level": level}, + timeout=30, + ) + if resp.status_code != 200: + raise RuntimeError(f"Palette prompt enhance failed: {resp.status_code}") + return cast(dict[str, Any], resp.json()) diff --git a/backend/services/r2_client/__init__.py b/backend/services/r2_client/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/services/r2_client/r2_client.py b/backend/services/r2_client/r2_client.py new file mode 100644 index 00000000..c4e149a4 --- /dev/null +++ b/backend/services/r2_client/r2_client.py @@ -0,0 +1,17 @@ +"""Protocol for R2/S3 compatible object storage.""" + +from __future__ import annotations + +from typing import Protocol + + +class R2Client(Protocol): + def upload_file( + self, *, local_path: str, remote_key: str, content_type: str, + ) -> str: + """Upload a local file. Returns the public URL.""" + ... + + def is_configured(self) -> bool: + """Return True if R2 credentials are set.""" + ... diff --git a/backend/services/r2_client/r2_client_impl.py b/backend/services/r2_client/r2_client_impl.py new file mode 100644 index 00000000..09ad93e2 --- /dev/null +++ b/backend/services/r2_client/r2_client_impl.py @@ -0,0 +1,54 @@ +"""Cloudflare R2 client using boto3 S3-compatible API.""" + +from __future__ import annotations + +import logging +import mimetypes +from pathlib import Path +from typing import Any + +logger = logging.getLogger(__name__) + + +class R2ClientImpl: + def __init__( + self, + access_key_id: str, + secret_access_key: str, + endpoint: str, + bucket: str, + public_url: str, + ) -> None: + self._access_key_id = access_key_id + self._secret_access_key = secret_access_key + self._endpoint = endpoint + self._bucket = bucket + self._public_url = public_url.rstrip("/") + + def is_configured(self) -> bool: + return bool(self._access_key_id and self._secret_access_key and self._endpoint and self._bucket) + + def upload_file(self, *, local_path: str, remote_key: str, content_type: str) -> str: + if not self.is_configured(): + raise RuntimeError("R2 credentials not configured") + + boto3: Any = __import__("boto3") + + s3: Any = boto3.client( + "s3", + endpoint_url=self._endpoint, + aws_access_key_id=self._access_key_id, + aws_secret_access_key=self._secret_access_key, + ) + + content_type = content_type or mimetypes.guess_type(local_path)[0] or "application/octet-stream" + s3.upload_file( + local_path, + self._bucket, + remote_key, + ExtraArgs={"ContentType": content_type}, + ) + + public_url = f"{self._public_url}/{remote_key}" + logger.info("Uploaded %s -> %s", Path(local_path).name, public_url) + return public_url diff --git a/backend/services/video_api_client/__init__.py b/backend/services/video_api_client/__init__.py new file mode 100644 index 00000000..45b0c46f --- /dev/null +++ b/backend/services/video_api_client/__init__.py @@ -0,0 +1,4 @@ +from services.video_api_client.video_api_client import VideoAPIClient +from services.video_api_client.replicate_video_client_impl import ReplicateVideoClientImpl + +__all__ = ["VideoAPIClient", "ReplicateVideoClientImpl"] diff --git a/backend/services/video_api_client/replicate_video_client_impl.py b/backend/services/video_api_client/replicate_video_client_impl.py new file mode 100644 index 00000000..742ce4b3 --- /dev/null +++ b/backend/services/video_api_client/replicate_video_client_impl.py @@ -0,0 +1,155 @@ +"""Replicate API client implementation for cloud video generation.""" + +from __future__ import annotations + +import time +from typing import Any, cast + +from services.http_client.http_client import HTTPClient +from services.services_utils import JSONValue + +REPLICATE_API_BASE_URL = "https://api.replicate.com/v1" + +_MODEL_ROUTES: dict[str, str] = { + "seedance-1.5-pro": "bytedance/seedance-1.5-pro", +} + +_POLL_INTERVAL_SECONDS = 2 +_POLL_TIMEOUT_SECONDS = 300 + + +class ReplicateVideoClientImpl: + def __init__(self, http: HTTPClient, *, api_base_url: str = REPLICATE_API_BASE_URL) -> None: + self._http = http + self._base_url = api_base_url.rstrip("/") + + def generate_text_to_video( + self, + *, + api_key: str, + model: str, + prompt: str, + duration: int, + resolution: str, + aspect_ratio: str, + generate_audio: bool, + last_frame: str | None = None, + ) -> bytes: + replicate_model = _MODEL_ROUTES.get(model) + if replicate_model is None: + raise RuntimeError(f"Unknown video model: {model}") + + seed = int(time.time()) % 2_147_483_647 + + input_payload: dict[str, JSONValue] = { + "prompt": prompt, + "duration": duration, + "resolution": resolution, + "aspect_ratio": aspect_ratio, + "generate_audio": generate_audio, + "seed": seed, + } + if last_frame is not None: + input_payload["last_frame"] = last_frame + + prediction = self._create_prediction( + api_key=api_key, + replicate_model=replicate_model, + input_payload=input_payload, + ) + + output_url = self._wait_for_output(api_key, prediction) + return self._download_video(api_key, output_url) + + def _create_prediction( + self, + *, + api_key: str, + replicate_model: str, + input_payload: dict[str, JSONValue], + ) -> dict[str, Any]: + url = f"{self._base_url}/models/{replicate_model}/predictions" + payload: dict[str, JSONValue] = {"input": input_payload} + + response = self._http.post( + url, + headers=self._headers(api_key, prefer_wait=True), + json_payload=payload, + timeout=300, + ) + if response.status_code not in (200, 201): + detail = response.text[:500] if response.text else "Unknown error" + raise RuntimeError(f"Replicate prediction failed ({response.status_code}): {detail}") + + return self._json_object(response.json(), context="create prediction") + + def _wait_for_output(self, api_key: str, prediction: dict[str, Any]) -> str: + status = prediction.get("status", "") + if status == "succeeded": + return self._extract_output_url(prediction) + + if status in ("failed", "canceled"): + error = prediction.get("error", "Unknown error") + raise RuntimeError(f"Replicate prediction {status}: {error}") + + poll_url = prediction.get("urls", {}).get("get") + if not isinstance(poll_url, str) or not poll_url: + prediction_id = prediction.get("id", "") + poll_url = f"{self._base_url}/predictions/{prediction_id}" + + deadline = time.monotonic() + _POLL_TIMEOUT_SECONDS + while time.monotonic() < deadline: + time.sleep(_POLL_INTERVAL_SECONDS) + resp = self._http.get(poll_url, headers=self._headers(api_key), timeout=30) + if resp.status_code != 200: + detail = resp.text[:500] if resp.text else "Unknown error" + raise RuntimeError(f"Replicate poll failed ({resp.status_code}): {detail}") + + data = self._json_object(resp.json(), context="poll") + poll_status = data.get("status", "") + if poll_status == "succeeded": + return self._extract_output_url(data) + if poll_status in ("failed", "canceled"): + error = data.get("error", "Unknown error") + raise RuntimeError(f"Replicate prediction {poll_status}: {error}") + + raise RuntimeError("Replicate prediction timed out") + + def _download_video(self, api_key: str, url: str) -> bytes: + download = self._http.get(url, headers=self._headers(api_key), timeout=300) + if download.status_code != 200: + detail = download.text[:500] if download.text else "Unknown error" + raise RuntimeError(f"Replicate video download failed ({download.status_code}): {detail}") + if not download.content: + raise RuntimeError("Replicate video download returned empty body") + return download.content + + @staticmethod + def _headers(api_key: str, *, prefer_wait: bool = False) -> dict[str, str]: + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + } + if prefer_wait: + headers["Prefer"] = "wait" + return headers + + @staticmethod + def _extract_output_url(prediction: dict[str, Any]) -> str: + output = prediction.get("output") + if isinstance(output, list) and output: + output_list = cast(list[object], output) + first = output_list[0] + if isinstance(first, str) and first: + return first + + if isinstance(output, str) and output: + return output + + raise RuntimeError("Replicate response missing output URL") + + @staticmethod + def _json_object(payload: object, *, context: str) -> dict[str, Any]: + if isinstance(payload, dict): + return cast(dict[str, Any], payload) + raise RuntimeError(f"Unexpected Replicate {context} response format") diff --git a/backend/services/video_api_client/video_api_client.py b/backend/services/video_api_client/video_api_client.py new file mode 100644 index 00000000..9bf1ef87 --- /dev/null +++ b/backend/services/video_api_client/video_api_client.py @@ -0,0 +1,21 @@ +"""Video API client protocol for cloud video generation.""" + +from __future__ import annotations + +from typing import Protocol + + +class VideoAPIClient(Protocol): + def generate_text_to_video( + self, + *, + api_key: str, + model: str, + prompt: str, + duration: int, + resolution: str, + aspect_ratio: str, + generate_audio: bool, + last_frame: str | None = None, + ) -> bytes: + ... diff --git a/backend/services/wildcard_parser.py b/backend/services/wildcard_parser.py new file mode 100644 index 00000000..42231561 --- /dev/null +++ b/backend/services/wildcard_parser.py @@ -0,0 +1,122 @@ +"""Wildcard expansion for prompt templates. + +Syntax: ``_wildcard_name_`` inside a prompt string is replaced by values +from a matching wildcard definition. The parser supports: + +* **All-combinations** expansion (Cartesian product of every wildcard slot). +* **Random-selection** expansion (pick *count* random fully-expanded prompts). +* **Nested wildcards** – a wildcard value may itself contain ``_other_`` + references which are recursively expanded. +""" + +from __future__ import annotations + +import random +import re +from dataclasses import dataclass + +# Pattern: underscore-delimited wildcard name, e.g. ``_color_`` +_WILDCARD_RE = re.compile(r"_([A-Za-z][A-Za-z0-9_]*)_") + + +@dataclass(frozen=True) +class WildcardDef: + """A single wildcard definition.""" + + name: str + values: list[str] + + +def _find_wildcard_names(prompt: str) -> list[str]: + """Return unique wildcard names found in *prompt*, preserving first-seen order.""" + seen: set[str] = set() + names: list[str] = [] + for m in _WILDCARD_RE.finditer(prompt): + name = m.group(1) + if name not in seen: + seen.add(name) + names.append(name) + return names + + +def _replace_single(prompt: str, name: str, value: str) -> str: + """Replace all occurrences of ``_name_`` in *prompt* with *value*.""" + return prompt.replace(f"_{name}_", value) + + +def _resolve_nested( + prompt: str, + lookup: dict[str, list[str]], + depth: int = 0, + max_depth: int = 10, +) -> list[str]: + """Recursively expand wildcards in *prompt*. + + Returns a list of all combinations produced by every wildcard slot. + """ + if depth > max_depth: + return [prompt] + + names = _find_wildcard_names(prompt) + if not names: + return [prompt] + + # Expand the first wildcard found, then recurse on each variant. + first = names[0] + values = lookup.get(first, [f"_{first}_"]) # keep literal if undefined + results: list[str] = [] + for val in values: + replaced = _replace_single(prompt, first, val) + results.extend(_resolve_nested(replaced, lookup, depth + 1, max_depth)) + return results + + +def _build_lookup(wildcards: list[WildcardDef]) -> dict[str, list[str]]: + return {w.name: w.values for w in wildcards} + + +def expand_prompt(prompt: str, wildcards: list[WildcardDef]) -> list[str]: + """Return **all** expanded combinations of *prompt* with the given wildcards. + + Example:: + + expand_prompt( + "A _color_ _animal_", + [WildcardDef("color", ["red", "blue"]), + WildcardDef("animal", ["cat", "dog"])], + ) + # → ["A red cat", "A red dog", "A blue cat", "A blue dog"] + """ + lookup = _build_lookup(wildcards) + return _resolve_nested(prompt, lookup) + + +def expand_random( + prompt: str, + wildcards: list[WildcardDef], + count: int = 1, + *, + rng: random.Random | None = None, +) -> list[str]: + """Return *count* randomly-expanded variants of *prompt*. + + Each returned string has every wildcard slot filled by a random value + drawn independently. Duplicates are possible when *count* exceeds the + number of unique combinations. + """ + _rng = rng or random.Random() + lookup = _build_lookup(wildcards) + + results: list[str] = [] + for _ in range(count): + text = prompt + for _ in range(20): # bounded recursion for nested wildcards + names = _find_wildcard_names(text) + if not names: + break + for name in names: + values = lookup.get(name) + if values: + text = _replace_single(text, name, _rng.choice(values)) + results.append(text) + return results diff --git a/backend/services/zit_api_client/__init__.py b/backend/services/zit_api_client/__init__.py deleted file mode 100644 index 17970ba7..00000000 --- a/backend/services/zit_api_client/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Z-Image Turbo API client exports.""" - -from services.zit_api_client.zit_api_client import ZitAPIClient -from services.zit_api_client.zit_api_client_impl import ZitAPIClientImpl - -__all__ = ["ZitAPIClient", "ZitAPIClientImpl"] diff --git a/backend/services/zit_api_client/zit_api_client_impl.py b/backend/services/zit_api_client/zit_api_client_impl.py deleted file mode 100644 index f5983bf8..00000000 --- a/backend/services/zit_api_client/zit_api_client_impl.py +++ /dev/null @@ -1,109 +0,0 @@ -"""Z-Image Turbo API client implementation for FAL endpoints.""" - -from __future__ import annotations - -from typing import Any, cast - -from services.http_client.http_client import HTTPClient -from services.services_utils import JSONValue - -FAL_API_BASE_URL = "https://fal.run" -FAL_TEXT_TO_IMAGE_ENDPOINT = "/fal-ai/z-image/turbo" - -DEFAULT_OUTPUT_FORMAT = "png" -DEFAULT_ACCELERATION = "regular" -DEFAULT_ENABLE_SAFETY_CHECKER = True - - -class ZitAPIClientImpl: - def __init__(self, http: HTTPClient, *, fal_api_base_url: str = FAL_API_BASE_URL) -> None: - self._http = http - self._base_url = fal_api_base_url.rstrip("/") - - def generate_text_to_image( - self, - *, - api_key: str, - prompt: str, - width: int, - height: int, - seed: int, - num_inference_steps: int, - ) -> bytes: - payload: dict[str, JSONValue] = { - "prompt": prompt, - "image_size": {"width": width, "height": height}, - "num_inference_steps": num_inference_steps, - "seed": seed, - "num_images": 1, - "output_format": DEFAULT_OUTPUT_FORMAT, - "acceleration": DEFAULT_ACCELERATION, - "enable_safety_checker": DEFAULT_ENABLE_SAFETY_CHECKER, - } - return self._submit_and_download( - endpoint=FAL_TEXT_TO_IMAGE_ENDPOINT, - api_key=api_key, - payload=payload, - ) - - def _submit_and_download( - self, - *, - endpoint: str, - api_key: str, - payload: dict[str, JSONValue], - ) -> bytes: - response = self._http.post( - f"{self._base_url}{endpoint}", - headers=self._json_headers(api_key), - json_payload=payload, - timeout=180, - ) - if response.status_code != 200: - detail = response.text[:500] if response.text else "Unknown error" - raise RuntimeError(f"FAL submit failed ({response.status_code}): {detail}") - - response_payload = self._json_object(response.json(), context="submit") - image_url = self._extract_image_url(response_payload) - - download = self._http.get(image_url, timeout=120) - if download.status_code != 200: - detail = download.text[:500] if download.text else "Unknown error" - raise RuntimeError(f"FAL image download failed ({download.status_code}): {detail}") - if not download.content: - raise RuntimeError("FAL image download returned empty body") - return download.content - - @staticmethod - def _json_headers(api_key: str) -> dict[str, str]: - return { - "Authorization": f"Key {api_key}", - "Content-Type": "application/json", - } - - @staticmethod - def _extract_image_url(payload: dict[str, Any]) -> str: - images = payload.get("images") - if isinstance(images, list) and images: - images_list = cast(list[object], images) - first = images_list[0] - if isinstance(first, dict): - first_payload = cast(dict[str, Any], first) - url = first_payload.get("url") - if isinstance(url, str) and url: - return url - if isinstance(first, str) and first: - return first - - for key in ("image_url", "imageUrl", "url"): - url = payload.get(key) - if isinstance(url, str) and url: - return url - - raise RuntimeError("FAL response missing image url") - - @staticmethod - def _json_object(payload: object, *, context: str) -> dict[str, Any]: - if isinstance(payload, dict): - return cast(dict[str, Any], payload) - raise RuntimeError(f"Unexpected FAL {context} response format") diff --git a/backend/state/app_settings.py b/backend/state/app_settings.py index 5b3e67ad..bb0706c2 100644 --- a/backend/state/app_settings.py +++ b/backend/state/app_settings.py @@ -64,16 +64,39 @@ class AppSettings(SettingsBaseModel): load_on_startup: bool = False ltx_api_key: str = "" user_prefers_ltx_api_video_generations: bool = False - fal_api_key: str = "" + replicate_api_key: str = "" + palette_api_key: str = "" + palette_refresh_token: str = "" + image_model: str = "flux-klein-9b" + video_model: str = "ltx-fast" use_local_text_encoder: bool = False + use_abliterated_text_encoder: bool = False fast_model: FastModelSettings = Field(default_factory=FastModelSettings) pro_model: ProModelSettings = Field(default_factory=ProModelSettings) prompt_cache_size: int = 100 prompt_enhancer_enabled_t2v: bool = True prompt_enhancer_enabled_i2v: bool = False gemini_api_key: str = "" + openrouter_api_key: str = "" seed_locked: bool = False locked_seed: int = 42 + batch_sound_enabled: bool = True + ffn_chunk_count: int = 8 + tea_cache_threshold: float = 0.0 + r2_access_key_id: str = "" + r2_secret_access_key: str = "" + r2_endpoint: str = "" + r2_bucket: str = "" + r2_public_url: str = "" + auto_upload_to_r2: bool = False + civitai_api_key: str = "" + custom_video_model_path: str = "" + selected_video_model: str = "" + + @field_validator("ffn_chunk_count", mode="before") + @classmethod + def _clamp_ffn_chunk_count(cls, value: Any) -> int: + return _clamp_int(value, minimum=0, maximum=32, default=8) @field_validator("prompt_cache_size", mode="before") @classmethod @@ -135,26 +158,52 @@ class SettingsResponse(SettingsBaseModel): load_on_startup: bool = False has_ltx_api_key: bool = False user_prefers_ltx_api_video_generations: bool = False - has_fal_api_key: bool = False + has_replicate_api_key: bool = False + has_palette_api_key: bool = False + image_model: str = "flux-klein-9b" + video_model: str = "ltx-fast" use_local_text_encoder: bool = False + use_abliterated_text_encoder: bool = False fast_model: FastModelSettings = Field(default_factory=FastModelSettings) pro_model: ProModelSettings = Field(default_factory=ProModelSettings) prompt_cache_size: int = 100 prompt_enhancer_enabled_t2v: bool = True prompt_enhancer_enabled_i2v: bool = False has_gemini_api_key: bool = False + has_openrouter_api_key: bool = False seed_locked: bool = False locked_seed: int = 42 + batch_sound_enabled: bool = True + ffn_chunk_count: int = 8 + tea_cache_threshold: float = 0.0 + has_r2_credentials: bool = False + auto_upload_to_r2: bool = False + has_civitai_api_key: bool = False + custom_video_model_path: str = "" + selected_video_model: str = "" def to_settings_response(settings: AppSettings) -> SettingsResponse: data = settings.model_dump(by_alias=False) ltx_key = data.pop("ltx_api_key", "") - fal_key = data.pop("fal_api_key", "") + replicate_key = data.pop("replicate_api_key", "") + palette_key = data.pop("palette_api_key", "") + data.pop("palette_refresh_token", "") gemini_key = data.pop("gemini_api_key", "") + openrouter_key = data.pop("openrouter_api_key", "") data["has_ltx_api_key"] = bool(ltx_key) - data["has_fal_api_key"] = bool(fal_key) + data["has_replicate_api_key"] = bool(replicate_key) + data["has_palette_api_key"] = bool(palette_key) data["has_gemini_api_key"] = bool(gemini_key) + data["has_openrouter_api_key"] = bool(openrouter_key) + r2_key = data.pop("r2_access_key_id", "") + data.pop("r2_secret_access_key", "") + data.pop("r2_endpoint", "") + data.pop("r2_bucket", "") + data.pop("r2_public_url", "") + data["has_r2_credentials"] = bool(r2_key) + civitai_key = data.pop("civitai_api_key", "") + data["has_civitai_api_key"] = bool(civitai_key) return SettingsResponse.model_validate(data) diff --git a/backend/state/app_state_types.py b/backend/state/app_state_types.py index ba1f4afc..8ccd4508 100644 --- a/backend/state/app_state_types.py +++ b/backend/state/app_state_types.py @@ -24,7 +24,7 @@ # Model file availability (disk truth) # ============================================================ -ModelFileType = Literal["checkpoint", "upsampler", "text_encoder", "zit"] +ModelFileType = Literal["checkpoint", "upsampler", "text_encoder", "text_encoder_abliterated", "zit", "flux_klein"] # Availability and download are orthogonal concerns. AvailableFiles = dict[ModelFileType, Path | None] @@ -41,7 +41,7 @@ class FileDownloadRunning: progress: float downloaded_bytes: int total_bytes: int - speed_mbps: float + speed_bytes_per_sec: float @dataclass @@ -104,6 +104,7 @@ class VideoPipelineState: pipeline: FastVideoPipeline warmth: VideoPipelineWarmth is_compiled: bool + lora_path: str | None = None @dataclass diff --git a/backend/state/job_queue.py b/backend/state/job_queue.py new file mode 100644 index 00000000..698e4117 --- /dev/null +++ b/backend/state/job_queue.py @@ -0,0 +1,165 @@ +"""Persistent job queue for sequential generation processing.""" + +from __future__ import annotations + +import json +import uuid +from dataclasses import asdict, dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Literal + + +@dataclass +class QueueJob: + id: str + type: Literal["video", "image"] + model: str + params: dict[str, Any] + status: Literal["queued", "running", "complete", "error", "cancelled"] + slot: Literal["gpu", "api"] + progress: int = 0 + phase: str = "queued" + result_paths: list[str] = field(default_factory=lambda: list[str]()) + error: str | None = None + created_at: str = "" + # Batch fields + batch_id: str | None = None + batch_index: int = 0 + depends_on: str | None = None + auto_params: dict[str, str] = field(default_factory=lambda: dict[str, str]()) + tags: list[str] = field(default_factory=lambda: list[str]()) + + +class JobQueue: + def __init__(self, persistence_path: Path) -> None: + self._path = persistence_path + self._jobs: list[QueueJob] = [] + self._load() + + def submit( + self, + *, + job_type: str, + model: str, + params: dict[str, Any], + slot: str, + job_id: str | None = None, + batch_id: str | None = None, + batch_index: int = 0, + depends_on: str | None = None, + auto_params: dict[str, str] | None = None, + tags: list[str] | None = None, + ) -> QueueJob: + job = QueueJob( + id=job_id or uuid.uuid4().hex[:8], + type=job_type, # type: ignore[arg-type] + model=model, + params=params, + status="queued", + slot=slot, # type: ignore[arg-type] + progress=0, + phase="queued", + result_paths=[], + error=None, + created_at=datetime.now(timezone.utc).isoformat(), + batch_id=batch_id, + batch_index=batch_index, + depends_on=depends_on, + auto_params=auto_params or {}, + tags=tags or [], + ) + self._jobs.append(job) + self._save() + return job + + def get_all_jobs(self) -> list[QueueJob]: + return list(self._jobs) + + def get_job(self, job_id: str) -> QueueJob | None: + for job in self._jobs: + if job.id == job_id: + return job + return None + + def next_queued_for_slot(self, slot: str) -> QueueJob | None: + for job in self._jobs: + if job.status == "queued" and job.slot == slot: + return job + return None + + def update_job( + self, + job_id: str, + *, + status: str | None = None, + progress: int | None = None, + phase: str | None = None, + result_paths: list[str] | None = None, + error: str | None = None, + ) -> None: + job = self.get_job(job_id) + if job is None: + return + if status is not None: + job.status = status # type: ignore[assignment] + if progress is not None: + job.progress = progress + if phase is not None: + job.phase = phase + if result_paths is not None: + job.result_paths = result_paths + if error is not None: + job.error = error + self._save() + + def cancel_job(self, job_id: str) -> None: + self.update_job(job_id, status="cancelled", phase="cancelled") + + def jobs_for_batch(self, batch_id: str) -> list[QueueJob]: + return sorted( + [j for j in self._jobs if j.batch_id == batch_id], + key=lambda j: j.batch_index, + ) + + def active_batch_ids(self) -> list[str]: + batch_ids: set[str] = set() + for job in self._jobs: + if job.batch_id and job.status in ("queued", "running"): + batch_ids.add(job.batch_id) + return sorted(batch_ids) + + def all_jobs(self) -> list[QueueJob]: + return list(self._jobs) + + def queued_jobs_for_slot(self, slot: str) -> list[QueueJob]: + return [j for j in self._jobs if j.status == "queued" and j.slot == slot] + + def clear_finished(self) -> None: + self._jobs = [j for j in self._jobs if j.status not in ("complete", "error", "cancelled")] + self._save() + + def _save(self) -> None: + data = {"jobs": [asdict(j) for j in self._jobs]} + self._path.parent.mkdir(parents=True, exist_ok=True) + self._path.write_text(json.dumps(data, indent=2), encoding="utf-8") + + def _load(self) -> None: + if not self._path.exists(): + return + try: + raw = json.loads(self._path.read_text(encoding="utf-8")) + for item in raw.get("jobs", []): + # Backwards compat: provide defaults for batch fields + item.setdefault("batch_id", None) + item.setdefault("batch_index", 0) + item.setdefault("depends_on", None) + item.setdefault("auto_params", {}) + item.setdefault("tags", []) + job = QueueJob(**item) + if job.status == "running": + job.status = "error" + job.error = "Interrupted by app restart" + self._jobs.append(job) + except (json.JSONDecodeError, TypeError, KeyError): + self._jobs = [] diff --git a/backend/state/library_store.py b/backend/state/library_store.py new file mode 100644 index 00000000..27f13719 --- /dev/null +++ b/backend/state/library_store.py @@ -0,0 +1,239 @@ +"""Persistent local library store for characters, styles, and references.""" + +from __future__ import annotations + +import json +import uuid +from dataclasses import asdict, dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import Literal, TypeVar + + +ReferenceCategory = Literal["people", "places", "props", "other"] + +_T = TypeVar("_T") + + +@dataclass +class Character: + id: str + name: str + role: str + description: str + reference_image_paths: list[str] = field(default_factory=lambda: list[str]()) + created_at: str = "" + + +@dataclass +class Style: + id: str + name: str + description: str + reference_image_path: str = "" + created_at: str = "" + + +@dataclass +class Reference: + id: str + name: str + category: ReferenceCategory + image_path: str = "" + created_at: str = "" + + +def _now_iso() -> str: + return datetime.now(timezone.utc).isoformat() + + +def _new_id() -> str: + return uuid.uuid4().hex[:12] + + +def _load_json_list(path: Path, cls: type[_T]) -> list[_T]: + """Load a list of dataclass instances from a JSON file.""" + if not path.exists(): + return [] + try: + raw = json.loads(path.read_text(encoding="utf-8")) + return [cls(**item) for item in raw] # type: ignore[arg-type] + except (json.JSONDecodeError, TypeError, KeyError): + return [] + + +def _write_json(path: Path, data: list[dict[str, object]]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(data, indent=2), encoding="utf-8") + + +class LibraryStore: + """Manages JSON file persistence for library entities.""" + + def __init__(self, library_dir: Path) -> None: + self._dir = library_dir + self._dir.mkdir(parents=True, exist_ok=True) + self._characters_file = self._dir / "characters.json" + self._styles_file = self._dir / "styles.json" + self._references_file = self._dir / "references.json" + + self._characters: list[Character] = _load_json_list(self._characters_file, Character) + self._styles: list[Style] = _load_json_list(self._styles_file, Style) + self._references: list[Reference] = _load_json_list(self._references_file, Reference) + + # ------------------------------------------------------------------ + # Characters + # ------------------------------------------------------------------ + + def list_characters(self) -> list[Character]: + return list(self._characters) + + def get_character(self, character_id: str) -> Character | None: + for c in self._characters: + if c.id == character_id: + return c + return None + + def create_character( + self, + *, + name: str, + role: str, + description: str, + reference_image_paths: list[str] | None = None, + ) -> Character: + character = Character( + id=_new_id(), + name=name, + role=role, + description=description, + reference_image_paths=reference_image_paths or [], + created_at=_now_iso(), + ) + self._characters.append(character) + self._save_characters() + return character + + def update_character( + self, + character_id: str, + *, + name: str | None = None, + role: str | None = None, + description: str | None = None, + reference_image_paths: list[str] | None = None, + ) -> Character | None: + character = self.get_character(character_id) + if character is None: + return None + if name is not None: + character.name = name + if role is not None: + character.role = role + if description is not None: + character.description = description + if reference_image_paths is not None: + character.reference_image_paths = reference_image_paths + self._save_characters() + return character + + def delete_character(self, character_id: str) -> bool: + before = len(self._characters) + self._characters = [c for c in self._characters if c.id != character_id] + if len(self._characters) < before: + self._save_characters() + return True + return False + + # ------------------------------------------------------------------ + # Styles + # ------------------------------------------------------------------ + + def list_styles(self) -> list[Style]: + return list(self._styles) + + def get_style(self, style_id: str) -> Style | None: + for s in self._styles: + if s.id == style_id: + return s + return None + + def create_style( + self, + *, + name: str, + description: str, + reference_image_path: str = "", + ) -> Style: + style = Style( + id=_new_id(), + name=name, + description=description, + reference_image_path=reference_image_path, + created_at=_now_iso(), + ) + self._styles.append(style) + self._save_styles() + return style + + def delete_style(self, style_id: str) -> bool: + before = len(self._styles) + self._styles = [s for s in self._styles if s.id != style_id] + if len(self._styles) < before: + self._save_styles() + return True + return False + + # ------------------------------------------------------------------ + # References + # ------------------------------------------------------------------ + + def list_references(self, category: ReferenceCategory | None = None) -> list[Reference]: + if category is None: + return list(self._references) + return [r for r in self._references if r.category == category] + + def get_reference(self, reference_id: str) -> Reference | None: + for r in self._references: + if r.id == reference_id: + return r + return None + + def create_reference( + self, + *, + name: str, + category: ReferenceCategory, + image_path: str = "", + ) -> Reference: + ref = Reference( + id=_new_id(), + name=name, + category=category, + image_path=image_path, + created_at=_now_iso(), + ) + self._references.append(ref) + self._save_references() + return ref + + def delete_reference(self, reference_id: str) -> bool: + before = len(self._references) + self._references = [r for r in self._references if r.id != reference_id] + if len(self._references) < before: + self._save_references() + return True + return False + + # ------------------------------------------------------------------ + # Persistence helpers + # ------------------------------------------------------------------ + + def _save_characters(self) -> None: + _write_json(self._characters_file, [asdict(c) for c in self._characters]) + + def _save_styles(self) -> None: + _write_json(self._styles_file, [asdict(s) for s in self._styles]) + + def _save_references(self) -> None: + _write_json(self._references_file, [asdict(r) for r in self._references]) diff --git a/backend/state/lora_library.py b/backend/state/lora_library.py new file mode 100644 index 00000000..08439477 --- /dev/null +++ b/backend/state/lora_library.py @@ -0,0 +1,98 @@ +"""Persistent LoRA library — tracks downloaded LoRAs with metadata.""" + +from __future__ import annotations + +import json +import logging +import threading +from dataclasses import asdict, dataclass, field +from pathlib import Path + +_logger = logging.getLogger(__name__) + + +@dataclass +class LoraEntry: + """A single LoRA in the local library.""" + + id: str + name: str + file_path: str + file_size_bytes: int = 0 + thumbnail_url: str = "" + trigger_phrase: str = "" + base_model: str = "" + civitai_model_id: int | None = None + civitai_version_id: int | None = None + description: str = "" + + +@dataclass +class LoraLibrary: + entries: list[LoraEntry] = field(default_factory=lambda: list[LoraEntry]()) + + +class LoraLibraryStore: + """Thread-safe, JSON-backed LoRA catalog.""" + + def __init__(self, loras_dir: Path) -> None: + self._loras_dir = loras_dir + self._loras_dir.mkdir(parents=True, exist_ok=True) + self._catalog_path = loras_dir / "catalog.json" + self._lock = threading.Lock() + self._library = self._load() + + @property + def loras_dir(self) -> Path: + return self._loras_dir + + def _load(self) -> LoraLibrary: + if not self._catalog_path.exists(): + return LoraLibrary() + try: + raw = json.loads(self._catalog_path.read_text(encoding="utf-8")) + entries = [LoraEntry(**e) for e in raw.get("entries", [])] + return LoraLibrary(entries=entries) + except Exception: + _logger.warning("Failed to load LoRA catalog, starting fresh", exc_info=True) + return LoraLibrary() + + def _save(self) -> None: + data = {"entries": [asdict(e) for e in self._library.entries]} + self._catalog_path.write_text(json.dumps(data, indent=2), encoding="utf-8") + + def list_all(self) -> list[LoraEntry]: + with self._lock: + return list(self._library.entries) + + def get(self, lora_id: str) -> LoraEntry | None: + with self._lock: + for entry in self._library.entries: + if entry.id == lora_id: + return entry + return None + + def add(self, entry: LoraEntry) -> None: + with self._lock: + # Replace if same ID exists + self._library.entries = [e for e in self._library.entries if e.id != entry.id] + self._library.entries.append(entry) + self._save() + + def remove(self, lora_id: str) -> bool: + with self._lock: + before = len(self._library.entries) + self._library.entries = [e for e in self._library.entries if e.id != lora_id] + if len(self._library.entries) < before: + self._save() + return True + return False + + def update_thumbnail(self, lora_id: str, thumbnail_url: str) -> bool: + with self._lock: + for entry in self._library.entries: + if entry.id == lora_id: + entry.thumbnail_url = thumbnail_url + self._save() + return True + return False diff --git a/backend/state/prompt_store.py b/backend/state/prompt_store.py new file mode 100644 index 00000000..a3295cf3 --- /dev/null +++ b/backend/state/prompt_store.py @@ -0,0 +1,193 @@ +"""JSON-file-backed persistence for saved prompts and wildcard definitions.""" + +from __future__ import annotations + +import json +import logging +import uuid +from dataclasses import dataclass, field, asdict +from datetime import datetime, timezone +from pathlib import Path + +logger = logging.getLogger(__name__) + + +@dataclass +class SavedPrompt: + id: str + text: str + tags: list[str] + category: str + used_count: int + created_at: str + last_used_at: str | None + + +@dataclass +class WildcardEntry: + id: str + name: str + values: list[str] + created_at: str + + +def _empty_prompts() -> list[SavedPrompt]: + return [] + + +def _empty_wildcards() -> list[WildcardEntry]: + return [] + + +@dataclass +class PromptStoreData: + prompts: list[SavedPrompt] = field(default_factory=_empty_prompts) + wildcards: list[WildcardEntry] = field(default_factory=_empty_wildcards) + + +def _now_iso() -> str: + return datetime.now(timezone.utc).isoformat() + + +def _new_id() -> str: + return uuid.uuid4().hex[:12] + + +class PromptStore: + """Simple JSON-file persistence for prompts and wildcards. + + Not thread-safe on its own; callers (the handler) are expected to + hold the shared lock when calling mutating methods. + """ + + def __init__(self, path: Path) -> None: + self._path = path + self._data = PromptStoreData() + self._load() + + # ------------------------------------------------------------------ + # Persistence helpers + # ------------------------------------------------------------------ + + def _load(self) -> None: + if not self._path.exists(): + return + try: + raw = json.loads(self._path.read_text(encoding="utf-8")) + prompts = [SavedPrompt(**p) for p in raw.get("prompts", [])] + wildcards = [WildcardEntry(**w) for w in raw.get("wildcards", [])] + self._data = PromptStoreData(prompts=prompts, wildcards=wildcards) + except Exception as exc: + logger.warning("Could not load prompt store from %s: %s", self._path, exc) + + def _save(self) -> None: + try: + self._path.parent.mkdir(parents=True, exist_ok=True) + payload = { + "prompts": [asdict(p) for p in self._data.prompts], + "wildcards": [asdict(w) for w in self._data.wildcards], + } + self._path.write_text(json.dumps(payload, indent=2), encoding="utf-8") + except Exception as exc: + logger.warning("Could not save prompt store to %s: %s", self._path, exc) + + # ------------------------------------------------------------------ + # Prompt CRUD + # ------------------------------------------------------------------ + + def list_prompts( + self, + search: str | None = None, + tag: str | None = None, + sort_by: str | None = None, + ) -> list[SavedPrompt]: + results = list(self._data.prompts) + if search: + lower = search.lower() + results = [p for p in results if lower in p.text.lower()] + if tag: + results = [p for p in results if tag in p.tags] + if sort_by == "used_count": + results.sort(key=lambda p: p.used_count, reverse=True) + elif sort_by == "created_at": + results.sort(key=lambda p: p.created_at, reverse=True) + elif sort_by == "last_used_at": + results.sort(key=lambda p: p.last_used_at or "", reverse=True) + return results + + def save_prompt(self, text: str, tags: list[str], category: str) -> SavedPrompt: + prompt = SavedPrompt( + id=_new_id(), + text=text, + tags=tags, + category=category, + used_count=0, + created_at=_now_iso(), + last_used_at=None, + ) + self._data.prompts.append(prompt) + self._save() + return prompt + + def get_prompt(self, prompt_id: str) -> SavedPrompt | None: + for p in self._data.prompts: + if p.id == prompt_id: + return p + return None + + def delete_prompt(self, prompt_id: str) -> bool: + before = len(self._data.prompts) + self._data.prompts = [p for p in self._data.prompts if p.id != prompt_id] + if len(self._data.prompts) < before: + self._save() + return True + return False + + def increment_usage(self, prompt_id: str) -> SavedPrompt | None: + for p in self._data.prompts: + if p.id == prompt_id: + p.used_count += 1 + p.last_used_at = _now_iso() + self._save() + return p + return None + + # ------------------------------------------------------------------ + # Wildcard CRUD + # ------------------------------------------------------------------ + + def list_wildcards(self) -> list[WildcardEntry]: + return list(self._data.wildcards) + + def create_wildcard(self, name: str, values: list[str]) -> WildcardEntry: + entry = WildcardEntry( + id=_new_id(), + name=name, + values=values, + created_at=_now_iso(), + ) + self._data.wildcards.append(entry) + self._save() + return entry + + def get_wildcard(self, wildcard_id: str) -> WildcardEntry | None: + for w in self._data.wildcards: + if w.id == wildcard_id: + return w + return None + + def update_wildcard(self, wildcard_id: str, values: list[str]) -> WildcardEntry | None: + for w in self._data.wildcards: + if w.id == wildcard_id: + w.values = values + self._save() + return w + return None + + def delete_wildcard(self, wildcard_id: str) -> bool: + before = len(self._data.wildcards) + self._data.wildcards = [w for w in self._data.wildcards if w.id != wildcard_id] + if len(self._data.wildcards) < before: + self._save() + return True + return False diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 891524ef..1abdccca 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -74,13 +74,19 @@ def test_state(tmp_path: Path, fake_services: FakeServices): text_encoder=fake_services.text_encoder, task_runner=fake_services.task_runner, ltx_api_client=fake_services.ltx_api_client, - zit_api_client=fake_services.zit_api_client, + image_api_client=fake_services.image_api_client, + video_api_client=fake_services.video_api_client, + palette_sync_client=fake_services.palette_sync_client, fast_video_pipeline_class=type(fake_services.fast_video_pipeline), + gguf_video_pipeline_class=None, + nf4_video_pipeline_class=None, image_generation_pipeline_class=type(fake_services.image_generation_pipeline), + flux_klein_pipeline_class=None, ic_lora_pipeline_class=type(fake_services.ic_lora_pipeline), a2v_pipeline_class=type(fake_services.a2v_pipeline), retake_pipeline_class=type(fake_services.retake_pipeline), ic_lora_model_downloader=fake_services.ic_lora_model_downloader, + model_scanner=fake_services.model_scanner, ) handler = build_initial_state( @@ -108,7 +114,7 @@ def default_app_settings() -> AppSettings: @pytest.fixture def create_fake_model_files(test_state): - def _create(include_zit: bool = False): + def _create(include_zit: bool = False, include_flux_klein: bool = False): for path in ( test_state.config.model_path("checkpoint"), test_state.config.model_path("upsampler"), @@ -125,6 +131,14 @@ def _create(include_zit: bool = False): zit_dir = test_state.config.model_path("zit") zit_dir.mkdir(parents=True, exist_ok=True) (zit_dir / "model.safetensors").write_bytes(b"\x00" * 1024) + # Default is now flux-klein-9b; switch to ZIT when creating ZIT files + test_state.state.app_settings.image_model = "z-image-turbo" + + if include_flux_klein: + flux_dir = test_state.config.model_path("flux_klein") + flux_dir.mkdir(parents=True, exist_ok=True) + (flux_dir / "model_index.json").write_bytes(b"{}") + (flux_dir / "model.safetensors").write_bytes(b"\x00" * 1024) return _create diff --git a/backend/tests/fakes/services.py b/backend/tests/fakes/services.py index 1abf285a..7fc71295 100644 --- a/backend/tests/fakes/services.py +++ b/backend/tests/fakes/services.py @@ -8,7 +8,7 @@ from typing import Any, ClassVar from PIL import Image -from api_types import ImageConditioningInput, VideoCameraMotion +from api_types import DetectedModel, ImageConditioningInput, VideoCameraMotion from services.interfaces import IcLoraDownloadPayload, IcLoraModelPayload, VideoInfoPayload from services.ltx_api_client.ltx_api_client import LTXRetakeResult from tests.fakes.fake_gpu_info import FakeGpuInfo @@ -158,6 +158,7 @@ def generate_text_to_video( fps: float, generate_audio: bool, camera_motion: VideoCameraMotion = "none", + last_frame_uri: str | None = None, ) -> bytes: self.text_to_video_calls.append( { @@ -169,6 +170,7 @@ def generate_text_to_video( "fps": fps, "generate_audio": generate_audio, "camera_motion": camera_motion, + "last_frame_uri": last_frame_uri, } ) if self.raise_on_text_to_video is not None: @@ -187,6 +189,7 @@ def generate_image_to_video( fps: float, generate_audio: bool, camera_motion: VideoCameraMotion = "none", + last_frame_uri: str | None = None, ) -> bytes: self.image_to_video_calls.append( { @@ -199,6 +202,7 @@ def generate_image_to_video( "fps": fps, "generate_audio": generate_audio, "camera_motion": camera_motion, + "last_frame_uri": last_frame_uri, } ) if self.raise_on_image_to_video is not None: @@ -254,12 +258,12 @@ def retake( return self.retake_result -class FakeZitAPIClient: +class FakeImageAPIClient: def __init__(self) -> None: self.configured = True self.text_to_image_calls: list[dict[str, Any]] = [] self.raise_on_text_to_image: Exception | None = None - self.text_to_image_result = b"fake-zit-api-image" + self.text_to_image_result = b"fake-api-image" def is_configured(self) -> bool: return self.configured @@ -268,6 +272,7 @@ def generate_text_to_image( self, *, api_key: str, + model: str, prompt: str, width: int, height: int, @@ -277,6 +282,7 @@ def generate_text_to_image( self.text_to_image_calls.append( { "api_key": api_key, + "model": model, "prompt": prompt, "width": width, "height": height, @@ -289,6 +295,144 @@ def generate_text_to_image( return self.text_to_image_result +class FakeVideoAPIClient: + def __init__(self) -> None: + self.text_to_video_calls: list[dict[str, Any]] = [] + self.raise_on_text_to_video: Exception | None = None + self.text_to_video_result = b"fake-seedance-video" + + def generate_text_to_video( + self, + *, + api_key: str, + model: str, + prompt: str, + duration: int, + resolution: str, + aspect_ratio: str, + generate_audio: bool, + last_frame: str | None = None, + ) -> bytes: + self.text_to_video_calls.append({ + "api_key": api_key, + "model": model, + "prompt": prompt, + "duration": duration, + "resolution": resolution, + "aspect_ratio": aspect_ratio, + "generate_audio": generate_audio, + "last_frame": last_frame, + }) + if self.raise_on_text_to_video is not None: + raise self.raise_on_text_to_video + return self.text_to_video_result + + +class FakePaletteSyncClient: + def __init__(self) -> None: + self.validate_calls: list[str] = [] + self.credits_calls: list[str] = [] + self.login_calls: list[tuple[str, str]] = [] + self.raise_on_validate: Exception | None = None + self.raise_on_login: Exception | None = None + self.user_info: dict[str, Any] = {"id": "user-123", "email": "test@example.com", "name": "Test User"} + self.credits_info: dict[str, Any] = { + "balance_cents": 5000, + "lifetime_purchased_cents": 10000, + "lifetime_used_cents": 5000, + "pricing": { + "video_t2v": 40, "video_i2v": 40, "video_seedance": 80, + "image": 20, "image_edit": 20, "audio": 15, "text_enhance": 3, + }, + } + + def validate_connection(self, *, api_key: str) -> dict[str, Any]: + self.validate_calls.append(api_key) + if self.raise_on_validate is not None: + raise self.raise_on_validate + return self.user_info + + def sign_in_with_email(self, *, email: str, password: str) -> dict[str, Any]: + self.login_calls.append((email, password)) + if self.raise_on_login is not None: + raise self.raise_on_login + return { + "access_token": "fake-jwt-token", + "refresh_token": "fake-refresh-token", + "user": self.user_info, + } + + def refresh_access_token(self, *, refresh_token: str) -> dict[str, Any]: + return { + "access_token": "refreshed-jwt-token", + "refresh_token": "refreshed-refresh-token", + "user": self.user_info, + } + + def get_credits(self, *, api_key: str) -> dict[str, Any]: + self.credits_calls.append(api_key) + return self.credits_info + + def check_credits( + self, *, api_key: str, generation_type: str, count: int, + ) -> dict[str, Any]: + cost = {"video_t2v": 40, "video_i2v": 40, "video_seedance": 80, "image": 20, "text_enhance": 3}.get(generation_type, 40) + total = cost * count + balance = self.credits_info.get("balance_cents", 5000) + if not isinstance(balance, int): + balance = 5000 + return { + "can_afford": balance >= total, + "cost_cents": total, + "balance_cents": balance, + "balance_after_cents": balance - total, + } + + def deduct_credits( + self, *, api_key: str, generation_type: str, count: int, + metadata: dict[str, Any] | None, + ) -> dict[str, Any]: + cost = {"video_t2v": 40, "video_i2v": 40, "video_seedance": 80, "image": 20, "text_enhance": 3}.get(generation_type, 40) + total = cost * count + balance = self.credits_info.get("balance_cents", 5000) + if not isinstance(balance, int): + balance = 5000 + new_balance = balance - total + self.credits_info["balance_cents"] = new_balance + return {"deducted_cents": total, "balance_cents": new_balance} + + def list_gallery( + self, *, api_key: str, page: int, per_page: int, asset_type: str, + ) -> dict[str, Any]: + return {"items": [], "total": 0, "page": page, "per_page": per_page, "total_pages": 0} + + def list_characters(self, *, api_key: str) -> dict[str, Any]: + return {"characters": []} + + def list_styles(self, *, api_key: str) -> dict[str, Any]: + return {"styles": [], "brands": []} + + def list_references(self, *, api_key: str, category: str | None) -> dict[str, Any]: + return {"references": []} + + def list_loras(self, *, api_key: str) -> dict[str, Any]: + return {"loras": []} + + def enhance_prompt(self, *, api_key: str, prompt: str, level: str) -> dict[str, Any]: + return {"enhanced_prompt": f"Enhanced ({level}): {prompt}"} + + +class FakeModelScanner: + def __init__(self) -> None: + self._models: list[DetectedModel] = [] + + def set_models(self, models: list[DetectedModel]) -> None: + self._models = list(models) + + def scan_video_models(self, folder: Path) -> list[DetectedModel]: # noqa: ARG002 + return list(self._models) + + class FakeModelDownloader: def __init__(self) -> None: self.calls: list[dict[str, Any]] = [] @@ -355,6 +499,9 @@ def __init__(self) -> None: def cleanup(self) -> None: self.cleanup_calls += 1 + def deep_cleanup(self) -> None: + self.cleanup_calls += 1 + class FakeCapture: def __init__( @@ -485,8 +632,10 @@ def create( gemma_root: str | None, upsampler_path: str, device: str | object, + lora_path: str | None = None, + lora_weight: float = 1.0, ) -> "FakeFastVideoPipeline": - del checkpoint_path, gemma_root, upsampler_path, device + del checkpoint_path, gemma_root, upsampler_path, device, lora_path, lora_weight pipeline = FakeFastVideoPipeline._singleton if pipeline is None: raise RuntimeError("FakeFastVideoPipeline singleton is not bound") @@ -545,6 +694,7 @@ def create( def __init__(self) -> None: self.device: str | None = None self.generate_calls: list[dict[str, Any]] = [] + self.img2img_calls: list[dict[str, Any]] = [] self.raise_on_generate: Exception | None = None def generate(self, **kwargs: Any) -> FakeZitOutput: @@ -553,9 +703,21 @@ def generate(self, **kwargs: Any) -> FakeZitOutput: raise self.raise_on_generate return FakeZitOutput(color="blue") + def img2img(self, **kwargs: Any) -> FakeZitOutput: + self.img2img_calls.append(kwargs) + if self.raise_on_generate is not None: + raise self.raise_on_generate + return FakeZitOutput(color="green") + def to(self, device: str) -> None: self.device = device + def load_lora(self, lora_path: str, weight: float = 1.0) -> None: + pass + + def unload_lora(self) -> None: + pass + class FakeIcLoraPipeline: _singleton: ClassVar["FakeIcLoraPipeline | None"] = None @@ -753,13 +915,16 @@ class FakeServices: text_encoder: FakeTextEncoder = field(default_factory=FakeTextEncoder) task_runner: FakeTaskRunner = field(default_factory=FakeTaskRunner) ltx_api_client: FakeLTXAPIClient = field(default_factory=FakeLTXAPIClient) - zit_api_client: FakeZitAPIClient = field(default_factory=FakeZitAPIClient) + image_api_client: FakeImageAPIClient = field(default_factory=FakeImageAPIClient) + video_api_client: FakeVideoAPIClient = field(default_factory=FakeVideoAPIClient) + palette_sync_client: FakePaletteSyncClient = field(default_factory=FakePaletteSyncClient) fast_video_pipeline: FakeFastVideoPipeline = field(default_factory=FakeFastVideoPipeline) image_generation_pipeline: FakeImageGenerationPipeline = field(default_factory=FakeImageGenerationPipeline) ic_lora_pipeline: FakeIcLoraPipeline = field(default_factory=FakeIcLoraPipeline) a2v_pipeline: FakeA2VPipeline = field(default_factory=FakeA2VPipeline) retake_pipeline: FakeRetakePipeline = field(default_factory=FakeRetakePipeline) ic_lora_model_downloader: FakeIcLoraModelDownloader = field(default_factory=FakeIcLoraModelDownloader) + model_scanner: FakeModelScanner = field(default_factory=FakeModelScanner) def __post_init__(self) -> None: FakeFastVideoPipeline.bind_singleton(self.fast_video_pipeline) diff --git a/backend/tests/test_batch.py b/backend/tests/test_batch.py new file mode 100644 index 00000000..d7e2f331 --- /dev/null +++ b/backend/tests/test_batch.py @@ -0,0 +1,348 @@ +"""Tests for batch generation handler and API types.""" + +from __future__ import annotations + +from pathlib import Path + +from starlette.testclient import TestClient + +from api_types import ( + BatchJobItem, + BatchReport, + BatchSubmitRequest, + BatchSubmitResponse, + BatchStatusResponse, + PipelineDefinition, + PipelineStep, + SweepAxis, + SweepDefinition, +) +from handlers.batch_handler import BatchHandler +from state.job_queue import JobQueue + + +# --- Task 3: API types --- + + +def test_batch_submit_request_list_mode() -> None: + req = BatchSubmitRequest( + mode="list", + target="local", + jobs=[ + BatchJobItem(type="image", model="zit", params={"prompt": "a cat"}), + BatchJobItem(type="image", model="zit", params={"prompt": "a dog"}), + ], + ) + assert req.mode == "list" + assert len(req.jobs) == 2 # type: ignore[arg-type] + + +def test_batch_submit_request_sweep_mode() -> None: + req = BatchSubmitRequest( + mode="sweep", + target="cloud", + sweep=SweepDefinition( + base_type="image", + base_model="zit", + base_params={"prompt": "a cat", "width": 1024, "height": 1024}, + axes=[ + SweepAxis(param="loraWeight", values=[0.5, 0.75, 1.0]), + SweepAxis(param="prompt", values=["a cat", "a dog"], mode="search_replace", search="a cat"), + ], + ), + ) + assert req.sweep is not None + assert len(req.sweep.axes) == 2 + + +def test_batch_submit_request_pipeline_mode() -> None: + req = BatchSubmitRequest( + mode="pipeline", + target="local", + pipeline=PipelineDefinition( + steps=[ + PipelineStep(type="image", model="zit", params={"prompt": "a landscape"}), + PipelineStep(type="video", model="fast", params={}, auto_prompt=True), + ], + ), + ) + assert req.pipeline is not None + assert req.pipeline.steps[1].auto_prompt is True + + +def test_batch_report_model() -> None: + report = BatchReport( + batch_id="abc123", + total=10, + succeeded=8, + failed=2, + cancelled=0, + duration_seconds=120.5, + avg_job_seconds=12.05, + result_paths=["/out/1.png", "/out/2.png"], + failed_indices=[3, 7], + sweep_axes=["loraWeight"], + ) + assert report.succeeded + report.failed == report.total + + +# --- Task 4: BatchHandler list mode --- + + +def test_batch_handler_expand_list(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + handler = BatchHandler() + + request = BatchSubmitRequest( + mode="list", + target="local", + jobs=[ + BatchJobItem(type="image", model="zit", params={"prompt": "a cat"}), + BatchJobItem(type="video", model="fast", params={"prompt": "a dog running"}), + ], + ) + response = handler.submit_batch(request, queue) + + assert response.total_jobs == 2 + assert len(response.job_ids) == 2 + + jobs = queue.jobs_for_batch(response.batch_id) + assert len(jobs) == 2 + assert jobs[0].type == "image" + assert jobs[0].slot == "gpu" + assert jobs[0].batch_index == 0 + assert f"batch:{response.batch_id}" in jobs[0].tags + assert jobs[1].type == "video" + assert jobs[1].batch_index == 1 + + +def test_batch_handler_list_cloud_target(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + handler = BatchHandler() + + request = BatchSubmitRequest( + mode="list", + target="cloud", + jobs=[BatchJobItem(type="image", model="zit", params={"prompt": "a cat"})], + ) + response = handler.submit_batch(request, queue) + job = queue.get_job(response.job_ids[0]) + assert job is not None + assert job.slot == "api" + + +# --- Task 5: BatchHandler sweep mode --- + + +def test_batch_handler_sweep_single_axis(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + handler = BatchHandler() + + request = BatchSubmitRequest( + mode="sweep", + target="local", + sweep=SweepDefinition( + base_type="image", + base_model="zit", + base_params={"prompt": "a cat", "width": 1024, "height": 1024}, + axes=[SweepAxis(param="loraWeight", values=[0.5, 0.75, 1.0])], + ), + ) + response = handler.submit_batch(request, queue) + + assert response.total_jobs == 3 + jobs = queue.jobs_for_batch(response.batch_id) + assert jobs[0].params["loraWeight"] == 0.5 + assert jobs[1].params["loraWeight"] == 0.75 + assert jobs[2].params["loraWeight"] == 1.0 + assert all(j.params["prompt"] == "a cat" for j in jobs) + assert "sweep:loraWeight" in jobs[0].tags + + +def test_batch_handler_sweep_two_axes_cartesian(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + handler = BatchHandler() + + request = BatchSubmitRequest( + mode="sweep", + target="local", + sweep=SweepDefinition( + base_type="image", + base_model="zit", + base_params={"prompt": "a cat", "width": 1024}, + axes=[ + SweepAxis(param="loraWeight", values=[0.5, 1.0]), + SweepAxis(param="numSteps", values=[4, 8]), + ], + ), + ) + response = handler.submit_batch(request, queue) + + assert response.total_jobs == 4 # 2 x 2 cartesian product + jobs = queue.jobs_for_batch(response.batch_id) + combos = [(j.params["loraWeight"], j.params["numSteps"]) for j in jobs] + assert (0.5, 4) in combos + assert (0.5, 8) in combos + assert (1.0, 4) in combos + assert (1.0, 8) in combos + + +def test_batch_handler_sweep_search_replace(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + handler = BatchHandler() + + request = BatchSubmitRequest( + mode="sweep", + target="local", + sweep=SweepDefinition( + base_type="image", + base_model="zit", + base_params={"prompt": "a cute cat in a garden"}, + axes=[ + SweepAxis( + param="prompt", + values=["cat", "dog", "horse"], + mode="search_replace", + search="cat", + ), + ], + ), + ) + response = handler.submit_batch(request, queue) + + assert response.total_jobs == 3 + jobs = queue.jobs_for_batch(response.batch_id) + assert jobs[0].params["prompt"] == "a cute cat in a garden" + assert jobs[1].params["prompt"] == "a cute dog in a garden" + assert jobs[2].params["prompt"] == "a cute horse in a garden" + + +# --- Task 6: BatchHandler pipeline mode --- + + +def test_batch_handler_pipeline_creates_chained_jobs(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + handler = BatchHandler() + + request = BatchSubmitRequest( + mode="pipeline", + target="local", + pipeline=PipelineDefinition( + steps=[ + PipelineStep(type="image", model="zit", params={"prompt": "a landscape"}), + PipelineStep(type="video", model="fast", params={"duration": "4"}, auto_prompt=True), + ], + ), + ) + response = handler.submit_batch(request, queue) + + assert response.total_jobs == 2 + jobs = queue.jobs_for_batch(response.batch_id) + img_job = jobs[0] + vid_job = jobs[1] + + assert img_job.type == "image" + assert img_job.depends_on is None + assert vid_job.type == "video" + assert vid_job.depends_on == img_job.id + assert vid_job.auto_params == {"imagePath": "$dep.result_paths[0]", "auto_prompt": "true"} + + +def test_batch_handler_pipeline_three_steps(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + handler = BatchHandler() + + request = BatchSubmitRequest( + mode="pipeline", + target="local", + pipeline=PipelineDefinition( + steps=[ + PipelineStep(type="image", model="zit", params={"prompt": "frame 1"}), + PipelineStep(type="video", model="fast", params={}, auto_prompt=True), + PipelineStep(type="video", model="pro", params={}, auto_prompt=False), + ], + ), + ) + response = handler.submit_batch(request, queue) + + assert response.total_jobs == 3 + jobs = queue.jobs_for_batch(response.batch_id) + assert jobs[0].depends_on is None + assert jobs[1].depends_on == jobs[0].id + assert jobs[2].depends_on == jobs[1].id + assert jobs[2].auto_params == {"imagePath": "$dep.result_paths[0]"} # No auto_prompt + + +# --- Task 9: Integration tests --- + + +def test_batch_submit_and_status_integration(client: TestClient) -> None: + resp = client.post("/api/queue/submit-batch", json={ + "mode": "list", + "target": "local", + "jobs": [ + {"type": "image", "model": "zit", "params": {"prompt": "cat"}}, + {"type": "image", "model": "zit", "params": {"prompt": "dog"}}, + ], + }) + assert resp.status_code == 200 + data = resp.json() + batch_id = data["batch_id"] + assert data["total_jobs"] == 2 + + resp = client.get(f"/api/queue/batch/{batch_id}/status") + assert resp.status_code == 200 + status = resp.json() + assert status["batch_id"] == batch_id + assert status["total"] == 2 + assert status["queued"] == 2 + assert status["report"] is None + + +def test_batch_cancel_integration(client: TestClient) -> None: + resp = client.post("/api/queue/submit-batch", json={ + "mode": "list", + "target": "local", + "jobs": [ + {"type": "image", "model": "zit", "params": {"prompt": "cat"}}, + {"type": "image", "model": "zit", "params": {"prompt": "dog"}}, + ], + }) + batch_id = resp.json()["batch_id"] + + resp = client.post(f"/api/queue/batch/{batch_id}/cancel") + assert resp.status_code == 200 + + resp = client.get(f"/api/queue/batch/{batch_id}/status") + status = resp.json() + assert status["cancelled"] == 2 + + +def test_batch_retry_failed_integration(client: TestClient) -> None: + resp = client.post("/api/queue/submit-batch", json={ + "mode": "list", + "target": "local", + "jobs": [ + {"type": "image", "model": "zit", "params": {"prompt": "cat"}}, + ], + }) + batch_id = resp.json()["batch_id"] + + resp = client.post(f"/api/queue/batch/{batch_id}/retry-failed") + assert resp.status_code == 200 + + +def test_queue_status_includes_batch_fields(client: TestClient) -> None: + resp = client.post("/api/queue/submit-batch", json={ + "mode": "list", + "target": "local", + "jobs": [{"type": "image", "model": "zit", "params": {"prompt": "cat"}}], + }) + batch_id = resp.json()["batch_id"] + + resp = client.get("/api/queue/status") + data = resp.json() + job = data["jobs"][0] + assert job["batch_id"] == batch_id + assert job["batch_index"] == 0 + assert f"batch:{batch_id}" in job["tags"] diff --git a/backend/tests/test_contact_sheet.py b/backend/tests/test_contact_sheet.py new file mode 100644 index 00000000..31aaae68 --- /dev/null +++ b/backend/tests/test_contact_sheet.py @@ -0,0 +1,81 @@ +"""Tests for contact sheet generation endpoint.""" +from __future__ import annotations + + +class TestContactSheetGenerate: + def test_generates_9_jobs(self, client): + resp = client.post("/api/contact-sheet/generate", json={ + "reference_image_path": "/tmp/ref.png", + "subject_description": "A woman in a red dress", + }) + assert resp.status_code == 200 + data = resp.json() + assert len(data["job_ids"]) == 9 + + # All IDs should be unique + assert len(set(data["job_ids"])) == 9 + + def test_jobs_appear_in_queue(self, client): + resp = client.post("/api/contact-sheet/generate", json={ + "reference_image_path": "/tmp/ref.png", + "subject_description": "A man in a suit", + }) + job_ids = resp.json()["job_ids"] + + queue = client.get("/api/queue/status").json() + queue_ids = {j["id"] for j in queue["jobs"]} + for jid in job_ids: + assert jid in queue_ids + + def test_prompts_include_subject_and_angles(self, client): + resp = client.post("/api/contact-sheet/generate", json={ + "reference_image_path": "/tmp/ref.png", + "subject_description": "A warrior elf", + }) + job_ids = resp.json()["job_ids"] + + queue = client.get("/api/queue/status").json() + prompts = [j["params"]["prompt"] for j in queue["jobs"]] + + # Every prompt should contain the subject + for prompt in prompts: + assert "A warrior elf" in prompt + + # Check that different camera angles are present + angle_keywords = [ + "Close-up", + "Medium shot", + "Full body", + "Over-the-shoulder", + "Low angle", + "High angle", + "Profile", + "Three-quarter", + "Wide establishing", + ] + for keyword in angle_keywords: + assert any(keyword in p for p in prompts), f"Missing angle keyword: {keyword}" + + def test_prompts_include_style(self, client): + resp = client.post("/api/contact-sheet/generate", json={ + "reference_image_path": "/tmp/ref.png", + "subject_description": "A robot", + "style": "cyberpunk neon", + }) + job_ids = resp.json()["job_ids"] + + queue = client.get("/api/queue/status").json() + prompts = [j["params"]["prompt"] for j in queue["jobs"]] + + for prompt in prompts: + assert "cyberpunk neon" in prompt + + def test_returns_9_job_ids(self, client): + resp = client.post("/api/contact-sheet/generate", json={ + "reference_image_path": "/tmp/ref.png", + "subject_description": "A cat", + }) + assert resp.status_code == 200 + data = resp.json() + assert "job_ids" in data + assert len(data["job_ids"]) == 9 diff --git a/backend/tests/test_enhance_prompt.py b/backend/tests/test_enhance_prompt.py new file mode 100644 index 00000000..7400c0ff --- /dev/null +++ b/backend/tests/test_enhance_prompt.py @@ -0,0 +1,146 @@ +"""Tests for prompt enhancement route.""" +from __future__ import annotations + +from tests.fakes.services import FakeResponse + + +# ── Gemini path ──────────────────────────────────────────────── + + +def test_enhance_prompt_returns_enhanced_text(client, test_state, fake_services): + """Enhance prompt route should return an enhanced version of the input via Gemini.""" + test_state.state.app_settings.gemini_api_key = "test-gemini-key" + + fake_services.http.queue("post", FakeResponse( + status_code=200, + json_payload={ + "candidates": [{"content": {"parts": [{"text": "A cinematic shot of a majestic cat walking gracefully across a sun-drenched room, golden hour lighting, shallow depth of field"}]}}] + }, + )) + + resp = client.post("/api/enhance-prompt", json={ + "prompt": "cat walking in room", + "mode": "text-to-video", + }) + assert resp.status_code == 200 + data = resp.json() + assert "enhancedPrompt" in data + assert len(data["enhancedPrompt"]) > len("cat walking in room") + + +def test_enhance_prompt_image_mode(client, test_state, fake_services): + """Should work with image mode via Gemini.""" + test_state.state.app_settings.gemini_api_key = "test-key" + + fake_services.http.queue("post", FakeResponse( + status_code=200, + json_payload={ + "candidates": [{"content": {"parts": [{"text": "A stunning photograph of a cat"}]}}] + }, + )) + + resp = client.post("/api/enhance-prompt", json={ + "prompt": "cat photo", + "mode": "text-to-image", + }) + assert resp.status_code == 200 + assert "enhancedPrompt" in resp.json() + + +# ── Palette proxy path ───────────────────────────────────────── + + +def test_enhance_prompt_via_palette(client, test_state, fake_services): + """When palette_api_key is set, should proxy to Palette prompt-expander.""" + test_state.state.app_settings.palette_api_key = "pal-test-key" + + fake_services.http.queue("post", FakeResponse( + status_code=200, + json_payload={"enhanced_prompt": "A breathtaking cinematic scene of a cat"}, + )) + + resp = client.post("/api/enhance-prompt", json={ + "prompt": "cat scene", + "mode": "text-to-video", + }) + assert resp.status_code == 200 + data = resp.json() + assert data["enhancedPrompt"] == "A breathtaking cinematic scene of a cat" + + # Verify the correct URL and auth header were used + call = fake_services.http.calls[-1] + assert call.method == "post" + assert "/api/prompt-expander" in call.url + assert call.headers is not None + assert call.headers["Authorization"] == "Bearer pal-test-key" + assert call.json_payload is not None + assert call.json_payload["prompt"] == "cat scene" + assert call.json_payload["level"] == "2x" + + +def test_enhance_prompt_palette_takes_priority_over_gemini(client, test_state, fake_services): + """When both keys are set, Palette should be used instead of Gemini.""" + test_state.state.app_settings.palette_api_key = "pal-key" + test_state.state.app_settings.gemini_api_key = "gem-key" + + fake_services.http.queue("post", FakeResponse( + status_code=200, + json_payload={"enhanced_prompt": "Enhanced via palette"}, + )) + + resp = client.post("/api/enhance-prompt", json={ + "prompt": "test prompt", + "mode": "text-to-video", + }) + assert resp.status_code == 200 + assert resp.json()["enhancedPrompt"] == "Enhanced via palette" + + # Should have called Palette, not Gemini + call = fake_services.http.calls[-1] + assert "/api/prompt-expander" in call.url + + +def test_enhance_prompt_palette_error_propagates(client, test_state, fake_services): + """Palette API errors should propagate as HTTP errors.""" + test_state.state.app_settings.palette_api_key = "pal-key" + + fake_services.http.queue("post", FakeResponse( + status_code=500, + text="Internal Server Error", + )) + + resp = client.post("/api/enhance-prompt", json={ + "prompt": "test", + "mode": "text-to-video", + }) + assert resp.status_code == 500 + + +def test_enhance_prompt_palette_expanded_prompt_field(client, test_state, fake_services): + """Should also accept expandedPrompt field from Palette response.""" + test_state.state.app_settings.palette_api_key = "pal-key" + + fake_services.http.queue("post", FakeResponse( + status_code=200, + json_payload={"expandedPrompt": "Expanded prompt text"}, + )) + + resp = client.post("/api/enhance-prompt", json={ + "prompt": "short prompt", + "mode": "text-to-video", + }) + assert resp.status_code == 200 + assert resp.json()["enhancedPrompt"] == "Expanded prompt text" + + +# ── No service configured ────────────────────────────────────── + + +def test_enhance_prompt_no_service_configured(client): + """Should return error when neither Palette nor Gemini is configured.""" + resp = client.post("/api/enhance-prompt", json={ + "prompt": "cat walking", + "mode": "text-to-video", + }) + assert resp.status_code == 400 + assert "NO_AI_SERVICE_CONFIGURED" in resp.json()["error"] diff --git a/backend/tests/test_ffn_chunking.py b/backend/tests/test_ffn_chunking.py new file mode 100644 index 00000000..3a9cd38a --- /dev/null +++ b/backend/tests/test_ffn_chunking.py @@ -0,0 +1,69 @@ +"""Tests for FFN chunked feedforward optimization.""" + +from __future__ import annotations + +import torch + +from services.gpu_optimizations.ffn_chunking import _make_chunked_forward, patch_ffn_chunking + + +class _FakeFeedForward(torch.nn.Module): + def __init__(self, dim: int) -> None: + super().__init__() + self.net = torch.nn.Sequential( + torch.nn.Linear(dim, dim * 4), + torch.nn.GELU(), + torch.nn.Linear(dim * 4, dim), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.net(x) + + +class _FakeTransformerBlock(torch.nn.Module): + def __init__(self, dim: int) -> None: + super().__init__() + self.ff = _FakeFeedForward(dim) + self.audio_ff = _FakeFeedForward(dim) + + +class _FakeTransformer(torch.nn.Module): + def __init__(self, dim: int, num_blocks: int) -> None: + super().__init__() + self.blocks = torch.nn.ModuleList([_FakeTransformerBlock(dim) for _ in range(num_blocks)]) + + +def test_chunked_forward_matches_original() -> None: + dim = 32 + ff = _FakeFeedForward(dim) + x = torch.randn(1, 2000, dim) + + original_out = ff(x) + chunked_fn = _make_chunked_forward(ff.forward, num_chunks=4) + chunked_out = chunked_fn(x) + + assert torch.allclose(original_out, chunked_out, atol=1e-5) + + +def test_chunked_forward_skips_short_sequences() -> None: + dim = 16 + ff = _FakeFeedForward(dim) + x = torch.randn(1, 50, dim) # too short to chunk + + original_out = ff(x) + chunked_fn = _make_chunked_forward(ff.forward, num_chunks=4) + chunked_out = chunked_fn(x) + + assert torch.allclose(original_out, chunked_out, atol=1e-5) + + +def test_patch_ffn_chunking_patches_correct_modules() -> None: + model = _FakeTransformer(dim=16, num_blocks=3) + count = patch_ffn_chunking(model, num_chunks=4) + assert count == 6 # 3 blocks x 2 (ff + audio_ff) + + +def test_patch_ffn_chunking_zero_when_no_match() -> None: + model = torch.nn.Linear(16, 16) + count = patch_ffn_chunking(model, num_chunks=4) + assert count == 0 diff --git a/backend/tests/test_flux_klein.py b/backend/tests/test_flux_klein.py new file mode 100644 index 00000000..5634c516 --- /dev/null +++ b/backend/tests/test_flux_klein.py @@ -0,0 +1,284 @@ +"""Integration tests for FLUX.2 Klein 9B image generation pipeline.""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + +from runtime_config.model_download_specs import DEFAULT_MODEL_DOWNLOAD_SPECS + + +def _create_flux_klein_model_files(test_state) -> None: + """Create fake FLUX Klein model directory so pipelines_handler finds it.""" + flux_dir = test_state.config.model_path("flux_klein") + flux_dir.mkdir(parents=True, exist_ok=True) + (flux_dir / "model_index.json").write_bytes(b"{}") + (flux_dir / "model.safetensors").write_bytes(b"\x00" * 1024) + + +def _create_zit_model_files(test_state) -> None: + """Create fake ZIT model directory and set it as the active image model.""" + zit_dir = test_state.config.model_path("zit") + zit_dir.mkdir(parents=True, exist_ok=True) + (zit_dir / "model.safetensors").write_bytes(b"\x00" * 1024) + test_state.state.app_settings.image_model = "z-image-turbo" + + +# ── 1. Model Registration ───────────────────────────────────────────── + + +class TestFluxKleinModelRegistration: + def test_flux_klein_in_download_specs(self): + assert "flux_klein" in DEFAULT_MODEL_DOWNLOAD_SPECS + + def test_flux_klein_spec_repo_id(self): + spec = DEFAULT_MODEL_DOWNLOAD_SPECS["flux_klein"] + assert spec.repo_id == "black-forest-labs/FLUX.2-klein-base-9B" + + def test_flux_klein_spec_is_folder(self): + spec = DEFAULT_MODEL_DOWNLOAD_SPECS["flux_klein"] + assert spec.is_folder is True + + def test_flux_klein_in_available_files(self, test_state): + assert "flux_klein" in test_state.state.available_files + + def test_flux_klein_model_path_resolves(self, test_state): + path = test_state.config.model_path("flux_klein") + assert path.name == "FLUX.2-klein-base-9B" + + +# ── 2. Pipeline Handler: Model Routing ──────────────────────────────── + + +class TestFluxKleinPipelineRouting: + def test_load_image_model_zit_by_default(self, test_state, fake_services): + _create_zit_model_files(test_state) + pipeline = test_state.pipelines.load_image_model_to_gpu("zit") + # Should return the ZIT fake singleton + assert pipeline is fake_services.image_generation_pipeline + + def test_load_image_model_flux_klein_not_configured(self, test_state): + """FLUX Klein pipeline class is None in tests, so loading should raise.""" + _create_flux_klein_model_files(test_state) + with pytest.raises(RuntimeError, match="FLUX.2 Klein pipeline class not configured"): + test_state.pipelines.load_image_model_to_gpu("flux-klein-9b") + + def test_load_image_model_flux_klein_alias(self, test_state): + """Both 'flux-klein-9b' and 'flux_klein' should route to FLUX loader.""" + _create_flux_klein_model_files(test_state) + with pytest.raises(RuntimeError, match="FLUX.2 Klein pipeline class not configured"): + test_state.pipelines.load_image_model_to_gpu("flux_klein") + + +# ── 3. Image Generation: ZIT Default Path ───────────────────────────── + + +class TestFluxKleinImageGeneration: + def test_generate_image_zit_explicit(self, client, fake_services, test_state): + """Explicitly selecting z-image-turbo uses ZIT pipeline.""" + _create_zit_model_files(test_state) + r = client.post( + "/api/generate-image", + json={"prompt": "A cat", "width": 1024, "height": 1024, "numSteps": 4}, + ) + assert r.status_code == 200 + data = r.json() + assert data["status"] == "complete" + assert len(data["image_paths"]) == 1 + assert len(fake_services.image_generation_pipeline.generate_calls) == 1 + + def test_generate_image_flux_klein_setting_uses_flux_path(self, client, test_state): + """Setting image_model=flux-klein-9b routes to FLUX handler (fails without class).""" + _create_flux_klein_model_files(test_state) + test_state.state.app_settings.image_model = "flux-klein-9b" + + r = client.post( + "/api/generate-image", + json={"prompt": "A cat", "width": 1024, "height": 1024, "numSteps": 50}, + ) + # Should fail because flux_klein_pipeline_class is None in tests + assert r.status_code == 500 + + def test_generate_image_guidance_scale_for_flux(self, test_state, fake_services): + """Verify guidance_scale=4.0 is used for FLUX models in generate_image.""" + _create_zit_model_files(test_state) + # Use ZIT but check the image_model routing produces correct guidance + handler = test_state.image_generation + handler._pipelines.load_image_model_to_gpu("zit") + test_state.generation.start_generation("test-id") + paths = handler.generate_image( + prompt="test", + width=512, + height=512, + num_inference_steps=4, + seed=42, + num_images=1, + image_model="z-image-turbo", + ) + # ZIT uses guidance_scale=0.0 + call = fake_services.image_generation_pipeline.generate_calls[0] + assert call["guidance_scale"] == 0.0 + assert len(paths) == 1 + + +# ── 4. Dimension and Clamp Tests ────────────────────────────────────── + + +class TestFluxKleinDimensionClamping: + def test_dimensions_rounded_to_16(self, client, test_state, fake_services): + _create_zit_model_files(test_state) + r = client.post( + "/api/generate-image", + json={"prompt": "test", "width": 1023, "height": 1023}, + ) + assert r.status_code == 200 + call = fake_services.image_generation_pipeline.generate_calls[0] + assert call["width"] % 16 == 0 + assert call["height"] % 16 == 0 + + def test_num_images_clamped_to_12(self, client, test_state, fake_services): + _create_zit_model_files(test_state) + r = client.post( + "/api/generate-image", + json={"prompt": "test", "numImages": 20}, + ) + assert r.status_code == 200 + assert len(fake_services.image_generation_pipeline.generate_calls) == 12 + + +# ── 5. Model Label Routing ──────────────────────────────────────────── + + +class TestFluxKleinModelLabeling: + def test_zit_model_label_in_output_path(self, test_state, fake_services): + _create_zit_model_files(test_state) + handler = test_state.image_generation + handler._pipelines.load_image_model_to_gpu("zit") + test_state.generation.start_generation("label-test") + paths = handler.generate_image( + prompt="a dog", + width=512, + height=512, + num_inference_steps=4, + seed=42, + num_images=1, + image_model="z-image-turbo", + ) + # Output path should contain "zit" in the filename + assert "zit" in Path(paths[0]).name.lower() + + +# ── 6. Error Handling ───────────────────────────────────────────────── + + +class TestFluxKleinErrorHandling: + def test_flux_not_downloaded_returns_error(self, test_state): + """Requesting FLUX Klein when model not downloaded gives clear error.""" + with pytest.raises(RuntimeError, match="FLUX.2 Klein pipeline class not configured"): + test_state.pipelines.load_image_model_to_gpu("flux-klein-9b") + + def test_generation_error_returns_500(self, client, test_state, fake_services): + _create_zit_model_files(test_state) + fake_services.image_generation_pipeline.raise_on_generate = RuntimeError("GPU OOM") + r = client.post("/api/generate-image", json={"prompt": "test"}) + assert r.status_code == 500 + + def test_cancellation_returns_cancelled(self, client, test_state, fake_services): + _create_zit_model_files(test_state) + fake_services.image_generation_pipeline.raise_on_generate = RuntimeError("cancelled") + r = client.post("/api/generate-image", json={"prompt": "test"}) + assert r.status_code == 200 + assert r.json()["status"] == "cancelled" + + +# ── 7. Settings Integration ─────────────────────────────────────────── + + +class TestFluxKleinSettingsIntegration: + def test_image_model_setting_defaults_to_flux_klein(self, test_state): + assert test_state.state.app_settings.image_model == "flux-klein-9b" + + def test_image_model_setting_can_be_changed(self, test_state): + test_state.state.app_settings.image_model = "flux-klein-9b" + assert test_state.state.app_settings.image_model == "flux-klein-9b" + + def test_settings_endpoint_accepts_image_model(self, client, test_state): + r = client.post( + "/api/settings", + json={"imageModel": "flux-klein-9b"}, + ) + assert r.status_code == 200 + assert test_state.state.app_settings.image_model == "flux-klein-9b" + + +# ── 8. Seed Behavior ───────────────────────────────────────────────── + + +class TestFluxKleinSeedBehavior: + def test_locked_seed_used(self, client, test_state, fake_services): + _create_zit_model_files(test_state) + test_state.state.app_settings.seed_locked = True + test_state.state.app_settings.locked_seed = 12345 + r = client.post( + "/api/generate-image", + json={"prompt": "test", "width": 512, "height": 512, "numSteps": 4}, + ) + assert r.status_code == 200 + call = fake_services.image_generation_pipeline.generate_calls[0] + assert call["seed"] == 12345 + + def test_multiple_images_increment_seed(self, client, test_state, fake_services): + _create_zit_model_files(test_state) + test_state.state.app_settings.seed_locked = True + test_state.state.app_settings.locked_seed = 100 + r = client.post( + "/api/generate-image", + json={"prompt": "test", "numImages": 3}, + ) + assert r.status_code == 200 + seeds = [c["seed"] for c in fake_services.image_generation_pipeline.generate_calls] + assert seeds == [100, 101, 102] + + +# ── 9. Img2Img Routing ─────────────────────────────────────────────── + + +class TestFluxKleinImg2Img: + def test_source_image_routes_to_img2img(self, client, test_state, fake_services, make_test_image, tmp_path): + _create_zit_model_files(test_state) + img_buf = make_test_image(64, 64) + img_path = tmp_path / "source.png" + img_path.write_bytes(img_buf.read()) + + r = client.post( + "/api/generate-image", + json={ + "prompt": "A cat", + "width": 512, + "height": 512, + "numSteps": 4, + "sourceImagePath": str(img_path), + "strength": 0.7, + }, + ) + assert r.status_code == 200 + assert len(fake_services.image_generation_pipeline.img2img_calls) == 1 + assert len(fake_services.image_generation_pipeline.generate_calls) == 0 + + +# ── 10. Concurrent Generation Guard ────────────────────────────────── + + +class TestFluxKleinConcurrencyGuard: + def test_rejects_concurrent_generation(self, client, test_state, fake_services): + _create_zit_model_files(test_state) + # Load a pipeline to GPU so we can start a generation + test_state.pipelines.load_image_model_to_gpu("zit") + test_state.generation.start_generation("existing-gen") + + r = client.post( + "/api/generate-image", + json={"prompt": "test", "width": 512, "height": 512}, + ) + assert r.status_code == 409 diff --git a/backend/tests/test_gallery.py b/backend/tests/test_gallery.py new file mode 100644 index 00000000..cf21e35f --- /dev/null +++ b/backend/tests/test_gallery.py @@ -0,0 +1,142 @@ +"""Tests for /api/gallery/local endpoints.""" + +from __future__ import annotations + +from pathlib import Path + + +def _outputs_dir(test_state) -> Path: + return test_state.config.outputs_dir + + +def _create_files(outputs: Path, filenames: list[str]) -> None: + """Create dummy files in the outputs directory.""" + outputs.mkdir(parents=True, exist_ok=True) + for name in filenames: + (outputs / name).write_bytes(b"\x00" * 128) + + +class TestListLocalAssets: + def test_empty_gallery(self, client): + r = client.get("/api/gallery/local") + assert r.status_code == 200 + data = r.json() + assert data["items"] == [] + assert data["total"] == 0 + assert data["page"] == 1 + assert data["total_pages"] == 1 + + def test_list_with_files(self, client, test_state): + outputs = _outputs_dir(test_state) + _create_files(outputs, ["test_image.png", "test_video.mp4", "readme.txt"]) + + r = client.get("/api/gallery/local") + assert r.status_code == 200 + data = r.json() + # txt file should be excluded + assert data["total"] == 2 + filenames = {item["filename"] for item in data["items"]} + assert filenames == {"test_image.png", "test_video.mp4"} + + def test_filter_by_image(self, client, test_state): + outputs = _outputs_dir(test_state) + _create_files(outputs, ["photo.png", "photo.jpg", "clip.mp4"]) + + r = client.get("/api/gallery/local", params={"type": "image"}) + assert r.status_code == 200 + data = r.json() + assert data["total"] == 2 + types = {item["type"] for item in data["items"]} + assert types == {"image"} + + def test_filter_by_video(self, client, test_state): + outputs = _outputs_dir(test_state) + _create_files(outputs, ["photo.png", "clip.mp4", "clip2.webm"]) + + r = client.get("/api/gallery/local", params={"type": "video"}) + assert r.status_code == 200 + data = r.json() + assert data["total"] == 2 + types = {item["type"] for item in data["items"]} + assert types == {"video"} + + def test_pagination(self, client, test_state): + outputs = _outputs_dir(test_state) + _create_files(outputs, [f"img_{i:03d}.png" for i in range(5)]) + + r = client.get("/api/gallery/local", params={"page": 1, "per_page": 2}) + assert r.status_code == 200 + data = r.json() + assert data["total"] == 5 + assert data["page"] == 1 + assert data["per_page"] == 2 + assert data["total_pages"] == 3 + assert len(data["items"]) == 2 + + # Page 3 should have 1 item + r2 = client.get("/api/gallery/local", params={"page": 3, "per_page": 2}) + data2 = r2.json() + assert len(data2["items"]) == 1 + + def test_model_name_parsed_from_prefix(self, client, test_state): + outputs = _outputs_dir(test_state) + _create_files(outputs, ["zit_image_001.png", "api_image_002.jpg", "ltx_fast_003.mp4", "random_004.png"]) + + r = client.get("/api/gallery/local") + assert r.status_code == 200 + data = r.json() + by_filename = {item["filename"]: item["model_name"] for item in data["items"]} + assert by_filename["zit_image_001.png"] == "zit" + assert by_filename["api_image_002.jpg"] == "api" + assert by_filename["ltx_fast_003.mp4"] == "ltx-fast" + assert by_filename["random_004.png"] is None + + def test_asset_has_expected_fields(self, client, test_state): + outputs = _outputs_dir(test_state) + _create_files(outputs, ["sample.png"]) + + r = client.get("/api/gallery/local") + data = r.json() + item = data["items"][0] + assert "id" in item + assert item["filename"] == "sample.png" + assert item["type"] == "image" + assert item["size_bytes"] == 128 + assert "created_at" in item + assert "path" in item + assert "url" in item + + def test_subdirectories_ignored(self, client, test_state): + outputs = _outputs_dir(test_state) + _create_files(outputs, ["top.png"]) + subdir = outputs / "subdir" + subdir.mkdir() + (subdir / "nested.png").write_bytes(b"\x00" * 64) + + r = client.get("/api/gallery/local") + data = r.json() + assert data["total"] == 1 + assert data["items"][0]["filename"] == "top.png" + + +class TestDeleteLocalAsset: + def test_delete_existing_asset(self, client, test_state): + outputs = _outputs_dir(test_state) + _create_files(outputs, ["to_delete.png"]) + + # Get the asset ID first + r = client.get("/api/gallery/local") + asset_id = r.json()["items"][0]["id"] + + # Delete it + r2 = client.delete(f"/api/gallery/local/{asset_id}") + assert r2.status_code == 200 + assert r2.json()["status"] == "ok" + + # Verify it's gone + r3 = client.get("/api/gallery/local") + assert r3.json()["total"] == 0 + + def test_delete_nonexistent_asset(self, client): + r = client.delete("/api/gallery/local/nonexistent_id_1234") + assert r.status_code == 404 diff --git a/backend/tests/test_generation.py b/backend/tests/test_generation.py index c56ccabf..ca7903ee 100644 --- a/backend/tests/test_generation.py +++ b/backend/tests/test_generation.py @@ -1091,9 +1091,9 @@ def test_cancelled(self, client, fake_services, create_fake_model_files): class TestForcedApiGenerateImage: - def test_generate_image_routes_to_zit_api(self, client, test_state, fake_services): + def test_generate_image_routes_to_api(self, client, test_state, fake_services): test_state.config.force_api_generations = True - test_state.state.app_settings.fal_api_key = "fal-key" + test_state.state.app_settings.replicate_api_key = "rep-key" r = client.post( "/api/generate-image", @@ -1104,22 +1104,35 @@ def test_generate_image_routes_to_zit_api(self, client, test_state, fake_service data = r.json() assert data["status"] == "complete" assert len(data["image_paths"]) == 2 - assert len(fake_services.zit_api_client.text_to_image_calls) == 2 + assert len(fake_services.image_api_client.text_to_image_calls) == 2 assert len(fake_services.image_generation_pipeline.generate_calls) == 0 - def test_generate_image_missing_fal_key(self, client, test_state, fake_services): + def test_generate_image_passes_model(self, client, test_state, fake_services): test_state.config.force_api_generations = True - test_state.state.app_settings.fal_api_key = "" + test_state.state.app_settings.replicate_api_key = "rep-key" + test_state.state.app_settings.image_model = "nano-banana-2" + + r = client.post( + "/api/generate-image", + json={"prompt": "A cat", "width": 1024, "height": 1024, "numSteps": 4, "numImages": 1}, + ) + + assert r.status_code == 200 + assert fake_services.image_api_client.text_to_image_calls[0]["model"] == "nano-banana-2" + + def test_generate_image_missing_replicate_key(self, client, test_state, fake_services): + test_state.config.force_api_generations = True + test_state.state.app_settings.replicate_api_key = "" r = client.post("/api/generate-image", json={"prompt": "A cat"}) assert r.status_code == 500 - assert r.json()["error"] == "FAL_API_KEY_NOT_CONFIGURED" + assert r.json()["error"] == "REPLICATE_API_KEY_NOT_CONFIGURED" def test_generate_image_cancelled(self, client, test_state, fake_services): test_state.config.force_api_generations = True - test_state.state.app_settings.fal_api_key = "fal-key" - fake_services.zit_api_client.raise_on_text_to_image = RuntimeError("cancelled") + test_state.state.app_settings.replicate_api_key = "rep-key" + fake_services.image_api_client.raise_on_text_to_image = RuntimeError("cancelled") r = client.post("/api/generate-image", json={"prompt": "A cat"}) @@ -1243,3 +1256,15 @@ def test_local_encoding_skips_api(self, client, test_state, fake_services, creat assert r.status_code == 200 assert len(fake_services.text_encoder.encode_calls) == 0 + + +def test_generate_video_request_accepts_last_frame_path(): + from api_types import GenerateVideoRequest + req = GenerateVideoRequest(prompt="test", lastFramePath="/path/to/last.png") + assert req.lastFramePath == "/path/to/last.png" + + +def test_generate_video_request_last_frame_defaults_none(): + from api_types import GenerateVideoRequest + req = GenerateVideoRequest(prompt="test") + assert req.lastFramePath is None diff --git a/backend/tests/test_job_queue.py b/backend/tests/test_job_queue.py new file mode 100644 index 00000000..4b203bce --- /dev/null +++ b/backend/tests/test_job_queue.py @@ -0,0 +1,157 @@ +"""Tests for the persistent job queue.""" + +from __future__ import annotations + +from pathlib import Path + +from state.job_queue import JobQueue, QueueJob + + +def test_submit_job_assigns_id_and_status(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + job = queue.submit( + job_type="video", + model="seedance-1.5-pro", + params={"prompt": "hello"}, + slot="api", + ) + assert job.id + assert job.status == "queued" + assert job.slot == "api" + assert job.progress == 0 + + +def test_get_all_jobs_returns_ordered(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + j1 = queue.submit(job_type="video", model="ltx-fast", params={}, slot="gpu") + j2 = queue.submit(job_type="image", model="z-image-turbo", params={}, slot="gpu") + jobs = queue.get_all_jobs() + assert [j.id for j in jobs] == [j1.id, j2.id] + + +def test_next_queued_for_slot(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + queue.submit(job_type="video", model="seedance-1.5-pro", params={}, slot="api") + queue.submit(job_type="video", model="ltx-fast", params={}, slot="gpu") + + gpu_job = queue.next_queued_for_slot("gpu") + assert gpu_job is not None + assert gpu_job.slot == "gpu" + + api_job = queue.next_queued_for_slot("api") + assert api_job is not None + assert api_job.slot == "api" + + +def test_update_job_status(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + job = queue.submit(job_type="video", model="ltx-fast", params={}, slot="gpu") + queue.update_job(job.id, status="running", progress=50, phase="inference") + updated = queue.get_job(job.id) + assert updated is not None + assert updated.status == "running" + assert updated.progress == 50 + assert updated.phase == "inference" + + +def test_cancel_job(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + job = queue.submit(job_type="video", model="ltx-fast", params={}, slot="gpu") + queue.cancel_job(job.id) + updated = queue.get_job(job.id) + assert updated is not None + assert updated.status == "cancelled" + + +def test_clear_finished_jobs(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + j1 = queue.submit(job_type="video", model="ltx-fast", params={}, slot="gpu") + j2 = queue.submit(job_type="video", model="ltx-fast", params={}, slot="gpu") + queue.update_job(j1.id, status="complete") + queue.clear_finished() + remaining = queue.get_all_jobs() + assert len(remaining) == 1 + assert remaining[0].id == j2.id + + +def test_persistence_survives_reload(tmp_path: Path) -> None: + path = tmp_path / "queue.json" + queue1 = JobQueue(persistence_path=path) + job = queue1.submit(job_type="video", model="ltx-fast", params={"prompt": "test"}, slot="gpu") + + queue2 = JobQueue(persistence_path=path) + loaded = queue2.get_job(job.id) + assert loaded is not None + assert loaded.params == {"prompt": "test"} + + +def test_submit_job_with_batch_fields(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + job = queue.submit( + job_type="image", + model="zit", + params={"prompt": "a cat"}, + slot="gpu", + batch_id="batch_001", + batch_index=3, + depends_on="job_abc", + tags=["batch:batch_001", "sweep:lora_weight"], + ) + assert job.batch_id == "batch_001" + assert job.batch_index == 3 + assert job.depends_on == "job_abc" + assert job.tags == ["batch:batch_001", "sweep:lora_weight"] + + # Verify persistence round-trip + queue2 = JobQueue(persistence_path=tmp_path / "queue.json") + loaded = queue2.get_job(job.id) + assert loaded is not None + assert loaded.batch_id == "batch_001" + assert loaded.batch_index == 3 + assert loaded.depends_on == "job_abc" + assert loaded.tags == ["batch:batch_001", "sweep:lora_weight"] + + +def test_jobs_for_batch(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + queue.submit(job_type="image", model="zit", params={}, slot="gpu", batch_id="b1", batch_index=0) + queue.submit(job_type="image", model="zit", params={}, slot="gpu", batch_id="b1", batch_index=1) + queue.submit(job_type="image", model="zit", params={}, slot="gpu", batch_id="b2", batch_index=0) + queue.submit(job_type="video", model="fast", params={}, slot="gpu") # No batch + + batch_jobs = queue.jobs_for_batch("b1") + assert len(batch_jobs) == 2 + assert all(j.batch_id == "b1" for j in batch_jobs) + assert [j.batch_index for j in batch_jobs] == [0, 1] + + +def test_active_batch_ids(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + queue.submit(job_type="image", model="zit", params={}, slot="gpu", batch_id="b1") + queue.submit(job_type="image", model="zit", params={}, slot="gpu", batch_id="b2") + queue.submit(job_type="video", model="fast", params={}, slot="gpu") + + ids = queue.active_batch_ids() + assert set(ids) == {"b1", "b2"} + + +def test_active_batch_ids_excludes_fully_resolved(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + job = queue.submit(job_type="image", model="zit", params={}, slot="gpu", batch_id="b1") + queue.update_job(job.id, status="complete", result_paths=["/out.png"]) + + ids = queue.active_batch_ids() + assert ids == [] # b1 is fully resolved + + +def test_running_jobs_reset_to_queued_on_load(tmp_path: Path) -> None: + path = tmp_path / "queue.json" + queue1 = JobQueue(persistence_path=path) + job = queue1.submit(job_type="video", model="ltx-fast", params={}, slot="gpu") + queue1.update_job(job.id, status="running") + + queue2 = JobQueue(persistence_path=path) + loaded = queue2.get_job(job.id) + assert loaded is not None + assert loaded.status == "error" + assert loaded.error == "Interrupted by app restart" diff --git a/backend/tests/test_library.py b/backend/tests/test_library.py new file mode 100644 index 00000000..c919d3ef --- /dev/null +++ b/backend/tests/test_library.py @@ -0,0 +1,209 @@ +"""Tests for /api/library/* endpoints (characters, styles, references).""" + +from __future__ import annotations + + +class TestCharacters: + def test_list_empty(self, client): + r = client.get("/api/library/characters") + assert r.status_code == 200 + assert r.json()["characters"] == [] + + def test_create_character(self, client): + r = client.post("/api/library/characters", json={ + "name": "Alice", + "role": "protagonist", + "description": "A curious adventurer", + }) + assert r.status_code == 200 + data = r.json() + assert data["name"] == "Alice" + assert data["role"] == "protagonist" + assert data["description"] == "A curious adventurer" + assert data["reference_image_paths"] == [] + assert "id" in data + assert "created_at" in data + + def test_create_character_with_images(self, client): + r = client.post("/api/library/characters", json={ + "name": "Bob", + "role": "sidekick", + "description": "Loyal friend", + "reference_image_paths": ["/img/bob1.png", "/img/bob2.png"], + }) + assert r.status_code == 200 + assert r.json()["reference_image_paths"] == ["/img/bob1.png", "/img/bob2.png"] + + def test_list_after_create(self, client): + client.post("/api/library/characters", json={"name": "Alice"}) + client.post("/api/library/characters", json={"name": "Bob"}) + r = client.get("/api/library/characters") + assert r.status_code == 200 + names = [c["name"] for c in r.json()["characters"]] + assert "Alice" in names + assert "Bob" in names + + def test_update_character(self, client): + r = client.post("/api/library/characters", json={"name": "Alice", "role": "hero"}) + cid = r.json()["id"] + + r = client.put(f"/api/library/characters/{cid}", json={"role": "villain", "description": "Turned evil"}) + assert r.status_code == 200 + data = r.json() + assert data["name"] == "Alice" + assert data["role"] == "villain" + assert data["description"] == "Turned evil" + + def test_update_partial_preserves_fields(self, client): + r = client.post("/api/library/characters", json={ + "name": "Alice", + "role": "hero", + "description": "Brave", + }) + cid = r.json()["id"] + + r = client.put(f"/api/library/characters/{cid}", json={"role": "mentor"}) + assert r.status_code == 200 + data = r.json() + assert data["name"] == "Alice" + assert data["role"] == "mentor" + assert data["description"] == "Brave" + + def test_update_nonexistent_returns_404(self, client): + r = client.put("/api/library/characters/doesnotexist", json={"name": "X"}) + assert r.status_code == 404 + + def test_delete_character(self, client): + r = client.post("/api/library/characters", json={"name": "Alice"}) + cid = r.json()["id"] + + r = client.delete(f"/api/library/characters/{cid}") + assert r.status_code == 200 + + r = client.get("/api/library/characters") + assert len(r.json()["characters"]) == 0 + + def test_delete_nonexistent_returns_404(self, client): + r = client.delete("/api/library/characters/doesnotexist") + assert r.status_code == 404 + + def test_create_empty_name_returns_400(self, client): + r = client.post("/api/library/characters", json={"name": " "}) + assert r.status_code == 400 + + +class TestStyles: + def test_list_empty(self, client): + r = client.get("/api/library/styles") + assert r.status_code == 200 + assert r.json()["styles"] == [] + + def test_create_style(self, client): + r = client.post("/api/library/styles", json={ + "name": "Noir", + "description": "Dark and moody", + "reference_image_path": "/img/noir.png", + }) + assert r.status_code == 200 + data = r.json() + assert data["name"] == "Noir" + assert data["description"] == "Dark and moody" + assert data["reference_image_path"] == "/img/noir.png" + assert "id" in data + + def test_list_after_create(self, client): + client.post("/api/library/styles", json={"name": "Noir"}) + client.post("/api/library/styles", json={"name": "Pastel"}) + r = client.get("/api/library/styles") + names = [s["name"] for s in r.json()["styles"]] + assert "Noir" in names + assert "Pastel" in names + + def test_delete_style(self, client): + r = client.post("/api/library/styles", json={"name": "Noir"}) + sid = r.json()["id"] + + r = client.delete(f"/api/library/styles/{sid}") + assert r.status_code == 200 + + r = client.get("/api/library/styles") + assert len(r.json()["styles"]) == 0 + + def test_delete_nonexistent_returns_404(self, client): + r = client.delete("/api/library/styles/doesnotexist") + assert r.status_code == 404 + + def test_create_empty_name_returns_400(self, client): + r = client.post("/api/library/styles", json={"name": " "}) + assert r.status_code == 400 + + +class TestReferences: + def test_list_empty(self, client): + r = client.get("/api/library/references") + assert r.status_code == 200 + assert r.json()["references"] == [] + + def test_create_reference(self, client): + r = client.post("/api/library/references", json={ + "name": "City Park", + "category": "places", + "image_path": "/img/park.png", + }) + assert r.status_code == 200 + data = r.json() + assert data["name"] == "City Park" + assert data["category"] == "places" + assert data["image_path"] == "/img/park.png" + assert "id" in data + + def test_list_all_categories(self, client): + client.post("/api/library/references", json={"name": "Park", "category": "places"}) + client.post("/api/library/references", json={"name": "Sword", "category": "props"}) + client.post("/api/library/references", json={"name": "Hero", "category": "people"}) + + r = client.get("/api/library/references") + assert len(r.json()["references"]) == 3 + + def test_filter_by_category(self, client): + client.post("/api/library/references", json={"name": "Park", "category": "places"}) + client.post("/api/library/references", json={"name": "Sword", "category": "props"}) + client.post("/api/library/references", json={"name": "Beach", "category": "places"}) + + r = client.get("/api/library/references?category=places") + refs = r.json()["references"] + assert len(refs) == 2 + assert all(ref["category"] == "places" for ref in refs) + + r = client.get("/api/library/references?category=props") + refs = r.json()["references"] + assert len(refs) == 1 + assert refs[0]["name"] == "Sword" + + def test_filter_empty_category(self, client): + client.post("/api/library/references", json={"name": "Park", "category": "places"}) + + r = client.get("/api/library/references?category=people") + assert r.json()["references"] == [] + + def test_delete_reference(self, client): + r = client.post("/api/library/references", json={"name": "Park", "category": "places"}) + rid = r.json()["id"] + + r = client.delete(f"/api/library/references/{rid}") + assert r.status_code == 200 + + r = client.get("/api/library/references") + assert len(r.json()["references"]) == 0 + + def test_delete_nonexistent_returns_404(self, client): + r = client.delete("/api/library/references/doesnotexist") + assert r.status_code == 404 + + def test_create_empty_name_returns_400(self, client): + r = client.post("/api/library/references", json={"name": " ", "category": "other"}) + assert r.status_code == 400 + + def test_invalid_category_rejected(self, client): + r = client.post("/api/library/references", json={"name": "X", "category": "invalid"}) + assert r.status_code == 422 diff --git a/backend/tests/test_long_video.py b/backend/tests/test_long_video.py new file mode 100644 index 00000000..ec844e57 --- /dev/null +++ b/backend/tests/test_long_video.py @@ -0,0 +1,485 @@ +"""Tests for the long video pipeline (chain-extend + concat).""" + +from __future__ import annotations + +import os +import tempfile +from pathlib import Path + +from state.job_queue import JobQueue, QueueJob +from handlers.queue_worker import QueueWorker +from handlers.video_generation_handler import VideoGenerationHandler + + +# --------------------------------------------------------------------------- +# Category 1: Route registration — POST /api/generate/long +# --------------------------------------------------------------------------- + + +class TestLongVideoRoute: + def test_route_exists_and_validates_request(self, client): + """POST /api/generate/long should exist and validate body.""" + resp = client.post("/api/generate/long", json={ + "prompt": "A cinematic landscape", + "imagePath": "/nonexistent/image.png", + "targetDuration": 12, + }) + # Route exists (not 404/405). It may fail because models aren't loaded, + # but it should at least parse the request (not 422 validation error). + assert resp.status_code != 404 + assert resp.status_code != 405 + + def test_route_rejects_empty_prompt(self, client): + """Empty prompt should be rejected by validation.""" + resp = client.post("/api/generate/long", json={ + "prompt": "", + "imagePath": "/some/image.png", + "targetDuration": 12, + }) + assert resp.status_code == 422 + + def test_route_rejects_missing_image_path(self, client): + """imagePath is required.""" + resp = client.post("/api/generate/long", json={ + "prompt": "test", + "targetDuration": 12, + }) + assert resp.status_code == 422 + + +# --------------------------------------------------------------------------- +# Category 2: Queue submission — type: "long_video" +# --------------------------------------------------------------------------- + + +class TestLongVideoQueueSubmission: + def test_submit_long_video_job(self, client): + """long_video type should be accepted and routed to GPU slot.""" + resp = client.post("/api/queue/submit", json={ + "type": "long_video", + "model": "fast", + "params": { + "prompt": "Epic landscape", + "imagePath": "/fake/image.png", + "targetDuration": 20, + "segmentDuration": 4, + "resolution": "540p", + }, + }) + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "queued" + assert "id" in data + + def test_long_video_job_appears_in_status(self, client): + """Submitted long_video job should appear in queue status.""" + client.post("/api/queue/submit", json={ + "type": "long_video", + "model": "fast", + "params": {"prompt": "test", "imagePath": "/img.png", "targetDuration": 16}, + }) + status = client.get("/api/queue/status") + jobs = status.json()["jobs"] + assert len(jobs) == 1 + assert jobs[0]["type"] == "long_video" + + def test_long_video_routes_to_gpu_slot(self, client): + """long_video should always route to GPU slot (local generation).""" + client.post("/api/queue/submit", json={ + "type": "long_video", + "model": "fast", + "params": {"prompt": "test", "imagePath": "/img.png", "targetDuration": 12}, + }) + status = client.get("/api/queue/status") + jobs = status.json()["jobs"] + assert jobs[0]["slot"] == "gpu" + + +# --------------------------------------------------------------------------- +# Category 3: Job executor dispatch — long_video type +# --------------------------------------------------------------------------- + + +class TestLongVideoExecutorDispatch: + def test_worker_dispatches_long_video_to_gpu_executor(self, tmp_path: Path): + """QueueWorker should route long_video jobs to GPU executor.""" + queue = JobQueue(persistence_path=tmp_path / "queue.json") + job = queue.submit( + job_type="long_video", + model="fast", + params={ + "prompt": "test", + "imagePath": "/fake/img.png", + "targetDuration": 16, + "segmentDuration": 4, + }, + slot="gpu", + ) + + class FakeExecutor: + def __init__(self) -> None: + self.executed_jobs: list[QueueJob] = [] + + def execute(self, job: QueueJob) -> list[str]: + self.executed_jobs.append(job) + return ["/fake/output_long.mp4"] + + gpu_exec = FakeExecutor() + api_exec = FakeExecutor() + worker = QueueWorker(queue=queue, gpu_executor=gpu_exec, api_executor=api_exec) + worker.tick() + + assert len(gpu_exec.executed_jobs) == 1 + assert gpu_exec.executed_jobs[0].id == job.id + assert gpu_exec.executed_jobs[0].type == "long_video" + assert len(api_exec.executed_jobs) == 0 + + def test_executor_passes_long_video_params(self, tmp_path: Path): + """Executor should forward all long_video params correctly.""" + queue = JobQueue(persistence_path=tmp_path / "queue.json") + params = { + "prompt": "cinematic shot", + "imagePath": "/image.png", + "targetDuration": 20, + "segmentDuration": 4, + "resolution": "720p", + "aspectRatio": "16:9", + "fps": 24, + "cameraMotion": "dolly_in", + } + job = queue.submit(job_type="long_video", model="fast", params=params, slot="gpu") + + class CapturingExecutor: + def __init__(self) -> None: + self.captured_params: dict[str, object] = {} + + def execute(self, job: QueueJob) -> list[str]: + self.captured_params = dict(job.params) + return ["/out.mp4"] + + executor = CapturingExecutor() + worker = QueueWorker(queue=queue, gpu_executor=executor, api_executor=executor) + worker.tick() + + assert executor.captured_params["prompt"] == "cinematic shot" + assert executor.captured_params["targetDuration"] == 20 + assert executor.captured_params["segmentDuration"] == 4 + assert executor.captured_params["cameraMotion"] == "dolly_in" + + +# --------------------------------------------------------------------------- +# Category 4: Segment calculation +# --------------------------------------------------------------------------- + + +class TestSegmentCalculation: + """Verify num_segments = ceil(target_duration / segment_duration).""" + + def test_exact_multiple(self): + """20s target / 4s segments = exactly 5.""" + n = max(1, (20 + 4 - 1) // 4) + assert n == 5 + + def test_remainder_rounds_up(self): + """10s target / 4s segments = 3 (not 2.5).""" + n = max(1, (10 + 4 - 1) // 4) + assert n == 3 + + def test_single_segment(self): + """4s target / 4s segments = 1.""" + n = max(1, (4 + 4 - 1) // 4) + assert n == 1 + + def test_very_long(self): + """60s target / 4s segments = 15.""" + n = max(1, (60 + 4 - 1) // 4) + assert n == 15 + + def test_short_target_below_segment(self): + """2s target / 4s segments = 1 (minimum 1).""" + n = max(1, (2 + 4 - 1) // 4) + assert n == 1 + + def test_12s_target(self): + """12s / 4s = exactly 3.""" + n = max(1, (12 + 4 - 1) // 4) + assert n == 3 + + +# --------------------------------------------------------------------------- +# Category 5: Concat logic — _concatenate_segments +# --------------------------------------------------------------------------- + + +class TestConcatenateSegments: + """Test the static _concatenate_segments method.""" + + def test_single_segment_copies_file(self, tmp_path: Path): + """Single segment should just copy the file.""" + seg = tmp_path / "segment_001.mp4" + seg.write_bytes(b"fake_video_content") + out = tmp_path / "final.mp4" + + VideoGenerationHandler._concatenate_segments( + segment_paths=[str(seg)], + output_path=str(out), + ffmpeg_path="unused", + fps=24, + ) + + assert out.exists() + assert out.read_bytes() == b"fake_video_content" + + def test_concat_file_format(self, tmp_path: Path): + """Multi-segment should write proper concat file with forward slashes.""" + # We can't run real ffmpeg in tests, but we can verify the concat file + # is written correctly by monkeypatching subprocess.run + import subprocess + calls: list[list[str]] = [] + + original_run = subprocess.run + + def fake_run(cmd: list[str], **kwargs: object) -> subprocess.CompletedProcess[bytes]: + calls.append(cmd) + # Read the concat file before it gets cleaned up + for i, arg in enumerate(cmd): + if arg == "-i" and i + 1 < len(cmd): + concat_path = cmd[i + 1] + content = Path(concat_path).read_text() + # Verify forward slashes + for line in content.strip().split("\n"): + assert "\\" not in line, f"Backslash found in concat file: {line}" + assert line.startswith("file '") + # Create the output file so the method succeeds + for i, arg in enumerate(cmd): + if arg == "-y" and len(cmd) > 1: + # The output is the last arg + Path(cmd[-1]).write_bytes(b"concatenated") + break + return subprocess.CompletedProcess(cmd, 0) + + subprocess.run = fake_run # type: ignore[assignment] + try: + seg1 = tmp_path / "seg1.mp4" + seg2 = tmp_path / "seg2.mp4" + seg1.write_bytes(b"video1") + seg2.write_bytes(b"video2") + out = tmp_path / "out.mp4" + + VideoGenerationHandler._concatenate_segments( + segment_paths=[str(seg1), str(seg2)], + output_path=str(out), + ffmpeg_path="ffmpeg", + fps=24, + ) + + assert len(calls) == 1 + assert "-f" in calls[0] + assert "concat" in calls[0] + assert "-c" in calls[0] + assert "copy" in calls[0] + finally: + subprocess.run = original_run # type: ignore[assignment] + + def test_concat_cleans_up_temp_file(self, tmp_path: Path): + """Concat file should be cleaned up even on success.""" + import subprocess + + created_files: list[str] = [] + + original_run = subprocess.run + + def fake_run(cmd: list[str], **kwargs: object) -> subprocess.CompletedProcess[bytes]: + for i, arg in enumerate(cmd): + if arg == "-i" and i + 1 < len(cmd): + created_files.append(cmd[i + 1]) + Path(cmd[-1]).write_bytes(b"result") + return subprocess.CompletedProcess(cmd, 0) + + subprocess.run = fake_run # type: ignore[assignment] + try: + seg1 = tmp_path / "a.mp4" + seg2 = tmp_path / "b.mp4" + seg1.write_bytes(b"x") + seg2.write_bytes(b"y") + + VideoGenerationHandler._concatenate_segments( + [str(seg1), str(seg2)], str(tmp_path / "out.mp4"), "ffmpeg", 24, + ) + + # The temp concat file should have been cleaned up + assert len(created_files) == 1 + assert not os.path.exists(created_files[0]) + finally: + subprocess.run = original_run # type: ignore[assignment] + + def test_windows_paths_converted_to_forward_slashes(self, tmp_path: Path): + """Backslashes in Windows paths should be converted to forward slashes.""" + import subprocess + + concat_content: list[str] = [] + + original_run = subprocess.run + + def fake_run(cmd: list[str], **kwargs: object) -> subprocess.CompletedProcess[bytes]: + for i, arg in enumerate(cmd): + if arg == "-i" and i + 1 < len(cmd): + concat_content.append(Path(cmd[i + 1]).read_text()) + Path(cmd[-1]).write_bytes(b"result") + return subprocess.CompletedProcess(cmd, 0) + + subprocess.run = fake_run # type: ignore[assignment] + try: + seg1 = tmp_path / "s1.mp4" + seg2 = tmp_path / "s2.mp4" + seg1.write_bytes(b"x") + seg2.write_bytes(b"y") + + # Simulate Windows-style paths + win_paths = [ + str(seg1).replace("/", "\\"), + str(seg2).replace("/", "\\"), + ] + + VideoGenerationHandler._concatenate_segments( + win_paths, str(tmp_path / "out.mp4"), "ffmpeg", 24, + ) + + assert len(concat_content) == 1 + for line in concat_content[0].strip().split("\n"): + assert "\\" not in line + finally: + subprocess.run = original_run # type: ignore[assignment] + + +# --------------------------------------------------------------------------- +# Category 6: Phase reporting — generating_segment, concatenating +# --------------------------------------------------------------------------- + + +class TestLongVideoPhaseReporting: + """Verify that phase names match what the frontend expects.""" + + def test_phase_names_are_valid_strings(self): + """The phase names used in generate_long_video must match frontend expectations.""" + # These are the phases reported by generate_long_video + expected_phases = {"generating_segment", "concatenating"} + # We can't run generate_long_video in tests (no GPU), but we can + # verify the constants by inspecting the source + import inspect + source = inspect.getsource(VideoGenerationHandler.generate_long_video) + for phase in expected_phases: + assert f'"{phase}"' in source, f"Phase '{phase}' not found in generate_long_video source" + + +# --------------------------------------------------------------------------- +# Category 7: Frontend submission — long_video type detection +# (These test the logic extracted from use-generation.ts) +# --------------------------------------------------------------------------- + + +class TestFrontendLongVideoLogic: + """Test the frontend submission logic (mirrored in Python for validation).""" + + @staticmethod + def _should_use_long_video( + duration: int, has_image: bool, has_audio: bool, has_last_frame: bool, + ) -> bool: + """Mirror of the frontend logic: useLongVideo = duration > 8 && imagePath && !audioPath && !lastFramePath""" + return duration > 8 and has_image and not has_audio and not has_last_frame + + def test_long_video_with_image_over_8s(self): + assert self._should_use_long_video(12, has_image=True, has_audio=False, has_last_frame=False) + + def test_regular_video_at_8s(self): + """8s should NOT trigger long video (must be > 8).""" + assert not self._should_use_long_video(8, has_image=True, has_audio=False, has_last_frame=False) + + def test_regular_video_at_4s(self): + assert not self._should_use_long_video(4, has_image=True, has_audio=False, has_last_frame=False) + + def test_no_image_stays_regular(self): + """Without an image, can't use long_video (needs I2V seed segment).""" + assert not self._should_use_long_video(20, has_image=False, has_audio=False, has_last_frame=False) + + def test_with_audio_stays_regular(self): + """Audio path present means A2V, not long video.""" + assert not self._should_use_long_video(20, has_image=True, has_audio=True, has_last_frame=False) + + def test_with_last_frame_stays_regular(self): + """Last frame means extend, not long video.""" + assert not self._should_use_long_video(20, has_image=True, has_audio=False, has_last_frame=True) + + def test_30s_long_video(self): + assert self._should_use_long_video(30, has_image=True, has_audio=False, has_last_frame=False) + + def test_60s_long_video(self): + assert self._should_use_long_video(60, has_image=True, has_audio=False, has_last_frame=False) + + +# --------------------------------------------------------------------------- +# Category 8: Phase message mapping +# --------------------------------------------------------------------------- + + +class TestPhaseMessageMapping: + """Verify frontend getPhaseMessage covers all long video phases. + + We can't import TS, but we validate the source file contains the mappings. + """ + + def test_phase_messages_exist_in_frontend(self): + """use-generation.ts must contain phase messages for long video phases.""" + frontend_path = Path(__file__).parent.parent.parent / "frontend" / "hooks" / "use-generation.ts" + if not frontend_path.exists(): + # Skip if frontend not available (CI without frontend) + return + + source = frontend_path.read_text() + required_phases = [ + "generating_segment", + "extracting_frame", + "concatenating", + ] + for phase in required_phases: + assert f"'{phase}'" in source, ( + f"Phase '{phase}' not found in use-generation.ts getPhaseMessage()" + ) + + +# --------------------------------------------------------------------------- +# Category 9: Duration options consistency +# --------------------------------------------------------------------------- + + +class TestDurationOptionsConsistency: + """Verify GenSpace.tsx and SettingsPanel.tsx have matching duration configs.""" + + def test_duration_options_match(self): + """Both files must have the same duration options and LOCAL_MAX_DURATION.""" + frontend_dir = Path(__file__).parent.parent.parent / "frontend" + genspace = frontend_dir / "views" / "GenSpace.tsx" + settings_panel = frontend_dir / "components" / "SettingsPanel.tsx" + + if not genspace.exists() or not settings_panel.exists(): + return # Skip in CI + + gs_src = genspace.read_text() + sp_src = settings_panel.read_text() + + # Both should have the same duration array + duration_pattern = "[4, 5, 6, 8, 10, 12, 16, 20, 30, 60]" + assert duration_pattern in gs_src, "GenSpace.tsx missing updated duration options" + assert duration_pattern in sp_src, "SettingsPanel.tsx missing updated duration options" + + # Both should have 540p: 60 + assert "'540p': 60" in gs_src, "GenSpace.tsx LOCAL_MAX_DURATION for 540p should be 60" + assert "'540p': 60" in sp_src, "SettingsPanel.tsx LOCAL_MAX_DURATION for 540p should be 60" + + # Both should have 720p: 10 + assert "'720p': 10" in gs_src + assert "'720p': 10" in sp_src + + # Both should have 1080p: 5 + assert "'1080p': 5" in gs_src + assert "'1080p': 5" in sp_src diff --git a/backend/tests/test_model_guide.py b/backend/tests/test_model_guide.py new file mode 100644 index 00000000..b54933a2 --- /dev/null +++ b/backend/tests/test_model_guide.py @@ -0,0 +1,47 @@ +"""Tests for model guide recommendation logic.""" + +from __future__ import annotations + +from services.model_scanner.model_guide_data import MODEL_FORMATS, DISTILLED_LORA_INFO, recommend_format + + +class TestRecommendFormat: + def test_48gb_recommends_bf16(self) -> None: + assert recommend_format(48) == "bf16" + + def test_32gb_recommends_bf16(self) -> None: + assert recommend_format(32) == "bf16" + + def test_24gb_recommends_fp8(self) -> None: + assert recommend_format(24) == "fp8" + + def test_20gb_recommends_fp8(self) -> None: + assert recommend_format(20) == "fp8" + + def test_16gb_recommends_gguf_q5k(self) -> None: + assert recommend_format(16) == "gguf_q5k" + + def test_12gb_recommends_gguf_q4k(self) -> None: + assert recommend_format(12) == "gguf_q4k" + + def test_10gb_recommends_gguf_q4k(self) -> None: + assert recommend_format(10) == "gguf_q4k" + + def test_8gb_recommends_api_only(self) -> None: + assert recommend_format(8) == "api_only" + + def test_none_vram_defaults_to_bf16(self) -> None: + assert recommend_format(None) == "bf16" + + +class TestModelFormatsData: + def test_all_formats_have_required_fields(self) -> None: + for fmt in MODEL_FORMATS: + assert fmt.id + assert fmt.name + assert fmt.min_vram_gb > 0 + assert fmt.download_url.startswith("https://") + + def test_distilled_lora_info_has_url(self) -> None: + assert DISTILLED_LORA_INFO.download_url.startswith("https://") + assert DISTILLED_LORA_INFO.size_gb > 0 diff --git a/backend/tests/test_model_scanner.py b/backend/tests/test_model_scanner.py new file mode 100644 index 00000000..e9d996dd --- /dev/null +++ b/backend/tests/test_model_scanner.py @@ -0,0 +1,145 @@ +"""Tests for ModelScannerImpl.""" + +from __future__ import annotations + +import json +import struct +from pathlib import Path + +import pytest + +from services.model_scanner.model_scanner_impl import ModelScannerImpl + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _write_minimal_gguf(path: Path, version: int = 3) -> None: + """Write a minimal valid GGUF file (magic + version only).""" + path.write_bytes(b"GGUF" + struct.pack(" None: + """Write a minimal valid safetensors file (8-byte header length + empty JSON header).""" + header = b"{}" + header_len = struct.pack(" ModelScannerImpl: + return ModelScannerImpl() + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_empty_folder_returns_empty_list(scanner: ModelScannerImpl, tmp_path: Path) -> None: + result = scanner.scan_video_models(tmp_path) + assert result == [] + + +def test_nonexistent_folder_returns_empty_list(scanner: ModelScannerImpl, tmp_path: Path) -> None: + missing = tmp_path / "does_not_exist" + result = scanner.scan_video_models(missing) + assert result == [] + + +def test_detects_gguf_file(scanner: ModelScannerImpl, tmp_path: Path) -> None: + gguf_path = tmp_path / "ltx-video-Q5_K_M.gguf" + _write_minimal_gguf(gguf_path) + + result = scanner.scan_video_models(tmp_path) + + assert len(result) == 1 + model = result[0] + assert model.model_format == "gguf" + assert model.quant_type == "Q5_K_M" + assert model.filename == "ltx-video-Q5_K_M.gguf" + assert model.path == str(gguf_path) + assert model.size_bytes == gguf_path.stat().st_size + assert model.is_distilled is False + + +def test_detects_safetensors_file(scanner: ModelScannerImpl, tmp_path: Path) -> None: + st_path = tmp_path / "ltx-video.safetensors" + _write_minimal_safetensors(st_path) + + result = scanner.scan_video_models(tmp_path) + + assert len(result) == 1 + model = result[0] + assert model.model_format == "bf16" + assert model.filename == "ltx-video.safetensors" + assert model.path == str(st_path) + assert model.size_bytes == st_path.stat().st_size + assert model.is_distilled is False + + +def test_detects_nf4_folder(scanner: ModelScannerImpl, tmp_path: Path) -> None: + nf4_dir = tmp_path / "ltx-video-nf4" + nf4_dir.mkdir() + (nf4_dir / "quantize_config.json").write_text( + json.dumps({"quant_type": "nf4", "bits": 4}), encoding="utf-8" + ) + (nf4_dir / "model.safetensors").write_bytes(b"\x00" * 1024) + + result = scanner.scan_video_models(tmp_path) + + assert len(result) == 1 + model = result[0] + assert model.model_format == "nf4" + assert model.quant_type == "nf4" + assert model.filename == "ltx-video-nf4" + assert model.path == str(nf4_dir) + assert model.size_bytes > 0 + assert model.is_distilled is False + + +def test_skips_corrupt_gguf_file(scanner: ModelScannerImpl, tmp_path: Path) -> None: + corrupt = tmp_path / "corrupt.gguf" + corrupt.write_bytes(b"NOTGGUF") + + result = scanner.scan_video_models(tmp_path) + assert result == [] + + +def test_skips_non_model_files(scanner: ModelScannerImpl, tmp_path: Path) -> None: + (tmp_path / "readme.txt").write_text("hello", encoding="utf-8") + (tmp_path / "config.json").write_text("{}", encoding="utf-8") + (tmp_path / "image.png").write_bytes(b"\x89PNG\r\n") + + result = scanner.scan_video_models(tmp_path) + assert result == [] + + +def test_multiple_models_in_folder(scanner: ModelScannerImpl, tmp_path: Path) -> None: + # GGUF + gguf_path = tmp_path / "ltx-Q8_0.gguf" + _write_minimal_gguf(gguf_path) + + # Safetensors + st_path = tmp_path / "ltx-bf16.safetensors" + _write_minimal_safetensors(st_path) + + # NF4 folder + nf4_dir = tmp_path / "ltx-nf4" + nf4_dir.mkdir() + (nf4_dir / "quantize_config.json").write_text( + json.dumps({"quant_type": "nf4"}), encoding="utf-8" + ) + + result = scanner.scan_video_models(tmp_path) + + assert len(result) == 3 + formats = {m.model_format for m in result} + assert formats == {"gguf", "bf16", "nf4"} diff --git a/backend/tests/test_model_selection.py b/backend/tests/test_model_selection.py new file mode 100644 index 00000000..276f1632 --- /dev/null +++ b/backend/tests/test_model_selection.py @@ -0,0 +1,54 @@ +"""Integration tests for video model scan, select, and guide endpoints.""" + +from __future__ import annotations + +import struct +from pathlib import Path + +from app_handler import AppHandler + + +def _write_minimal_gguf(path: Path) -> None: + with open(path, "wb") as f: + f.write(b"GGUF") + f.write(struct.pack(" None: + resp = client.get("/api/models/video/scan") + assert resp.status_code == 200 + data = resp.json() + assert data["models"] == [] + assert data["distilled_lora_found"] is False + + +class TestVideoModelSelect: + def test_select_nonexistent_model_returns_400(self, client) -> None: + resp = client.post("/api/models/video/select", json={"model": "nonexistent.gguf"}) + assert resp.status_code == 400 + + def test_select_valid_model_updates_settings(self, client, test_state: AppHandler) -> None: + models_dir = test_state.config.models_dir + gguf_path = models_dir / "test-model.gguf" + _write_minimal_gguf(gguf_path) + + client.post("/api/settings", json={"customVideoModelPath": str(models_dir)}) + + resp = client.post("/api/models/video/select", json={"model": "test-model.gguf"}) + assert resp.status_code == 200 + + assert test_state.state.app_settings.selected_video_model == "test-model.gguf" + + +class TestVideoModelGuide: + def test_guide_returns_formats_and_recommendation(self, client) -> None: + resp = client.get("/api/models/video/guide") + assert resp.status_code == 200 + data = resp.json() + assert "formats" in data + assert "recommended_format" in data + assert "distilled_lora" in data + assert len(data["formats"]) > 0 diff --git a/backend/tests/test_models.py b/backend/tests/test_models.py index 392c3e5e..2d380307 100644 --- a/backend/tests/test_models.py +++ b/backend/tests/test_models.py @@ -31,7 +31,7 @@ def test_nothing_downloaded(self, client): assert r.json()["all_downloaded"] is False def test_all_downloaded(self, client, create_fake_model_files): - create_fake_model_files(include_zit=True) + create_fake_model_files(include_flux_klein=True) r = client.get("/api/models/status") assert r.json()["all_downloaded"] is True @@ -68,7 +68,7 @@ def test_active(self, client, test_state): progress=0.5, downloaded_bytes=5_000_000_000, total_bytes=10_000_000_000, - speed_mbps=50, + speed_bytes_per_sec=50_000_000.0, ) } r = client.get("/api/models/download/progress") @@ -249,9 +249,10 @@ def test_failed_download_cleans_up_downloading_dir(self, test_state): class TestHuggingFaceInternals: """Guard tests for huggingface_hub internals we rely on. - We monkey-patch ``file_download.http_get`` to inject a custom tqdm bar - for progress tracking during ``hf_hub_download`` (which has no public - ``tqdm_class`` parameter, unlike ``snapshot_download``). + We monkey-patch ``file_download.http_get`` and ``file_download.xet_get`` + to inject a custom tqdm bar for progress tracking during + ``hf_hub_download`` (which has no public ``tqdm_class`` parameter, + unlike ``snapshot_download``). If these tests break after a huggingface_hub upgrade, the internal API has changed. Find an alternative approach and raise to a developer. @@ -268,3 +269,12 @@ def test_http_get_accepts_tqdm_bar(self): assert "_tqdm_bar" in sig.parameters, ( "file_download.http_get no longer accepts _tqdm_bar — progress patch for hf_hub_download is broken" ) + + def test_xet_get_accepts_tqdm_bar(self): + xet_get = getattr(file_download, "xet_get", None) + if xet_get is None: + return # xet_get not present in this version; patch skips it gracefully + sig = inspect.signature(xet_get) + assert "_tqdm_bar" in sig.parameters, ( + "file_download.xet_get no longer accepts _tqdm_bar — progress patch for xet downloads is broken" + ) diff --git a/backend/tests/test_palette_sync.py b/backend/tests/test_palette_sync.py new file mode 100644 index 00000000..83afec8b --- /dev/null +++ b/backend/tests/test_palette_sync.py @@ -0,0 +1,106 @@ +"""Integration tests for Palette LoRA sync.""" + +from __future__ import annotations + +from starlette.testclient import TestClient +from tests.fakes.services import FakeResponse + + +class TestSyncLorasNoConnection: + def test_sync_loras_without_api_key_returns_not_connected( + self, client: TestClient, test_state, + ) -> None: + # Ensure no API key is set + test_state.state.app_settings.palette_api_key = "" + resp = client.post("/api/sync/library/sync-loras") + assert resp.status_code == 200 + data = resp.json() + assert data["connected"] is False + + +class TestSyncLorasWithConnection: + def test_sync_loras_empty_catalog( + self, client: TestClient, test_state, fake_services, + ) -> None: + test_state.state.app_settings.palette_api_key = "test-jwt-token" + # FakePaletteSyncClient.list_loras returns {"loras": []} by default + resp = client.post("/api/sync/library/sync-loras") + assert resp.status_code == 200 + data = resp.json() + assert data["connected"] is True + assert data["synced"] == 0 + assert data["skipped"] == 0 + + def test_sync_loras_downloads_new_lora( + self, client: TestClient, test_state, fake_services, + ) -> None: + test_state.state.app_settings.palette_api_key = "test-jwt-token" + + # Override list_loras to return a LoRA with a known download URL + original_list = fake_services.palette_sync_client.list_loras + def mock_list_loras(*, api_key: str) -> dict: + return {"loras": [{ + "id": "dcau-k9b", + "name": "DCAU", + "type": "style", + "trigger_word": "DC animation style", + "thumbnail_url": "", + "compatible_models": ["flux-2-klein-9b"], + }]} + fake_services.palette_sync_client.list_loras = mock_list_loras + + # Queue a fake HTTP response for the LoRA download + fake_services.http.queue("get", FakeResponse( + status_code=200, + content=b"\x00" * 1024, # Fake safetensors content + )) + + resp = client.post("/api/sync/library/sync-loras") + assert resp.status_code == 200 + data = resp.json() + assert data["connected"] is True + assert data["synced"] == 1 + + # Verify the LoRA was registered in the catalog + entries = test_state.lora._store.list_all() + palette_entries = [e for e in entries if e.id.startswith("palette:")] + assert len(palette_entries) == 1 + assert palette_entries[0].name == "[Palette] DCAU" + assert palette_entries[0].trigger_phrase == "DC animation style" + + # Restore + fake_services.palette_sync_client.list_loras = original_list + + def test_sync_loras_skips_already_synced( + self, client: TestClient, test_state, fake_services, + ) -> None: + test_state.state.app_settings.palette_api_key = "test-jwt-token" + + # Pre-register a Palette LoRA in the catalog + from state.lora_library import LoraEntry + test_state.lora._store.add(LoraEntry( + id="palette:dcau-k9b", + name="[Palette] DCAU", + file_path="/fake/path.safetensors", + file_size_bytes=1024, + )) + + # Override list_loras to return the same LoRA + original_list = fake_services.palette_sync_client.list_loras + def mock_list_loras(*, api_key: str) -> dict: + return {"loras": [{ + "id": "dcau-k9b", + "name": "DCAU", + "type": "style", + "trigger_word": "DC animation style", + "thumbnail_url": "", + }]} + fake_services.palette_sync_client.list_loras = mock_list_loras + + resp = client.post("/api/sync/library/sync-loras") + assert resp.status_code == 200 + data = resp.json() + assert data["skipped"] == 1 + assert data["synced"] == 0 + + fake_services.palette_sync_client.list_loras = original_list diff --git a/backend/tests/test_prompts.py b/backend/tests/test_prompts.py new file mode 100644 index 00000000..19bb2686 --- /dev/null +++ b/backend/tests/test_prompts.py @@ -0,0 +1,296 @@ +"""Tests for prompt library and wildcard endpoints.""" + +from __future__ import annotations + +import random + +from services.wildcard_parser import WildcardDef, expand_prompt, expand_random + + +# ============================================================ +# Unit tests for wildcard_parser +# ============================================================ + + +class TestExpandPrompt: + def test_single_wildcard(self): + result = expand_prompt( + "A _color_ car", + [WildcardDef("color", ["red", "blue"])], + ) + assert sorted(result) == ["A blue car", "A red car"] + + def test_two_wildcards_cartesian_product(self): + result = expand_prompt( + "A _color_ _animal_", + [ + WildcardDef("color", ["red", "blue"]), + WildcardDef("animal", ["cat", "dog"]), + ], + ) + assert sorted(result) == [ + "A blue cat", + "A blue dog", + "A red cat", + "A red dog", + ] + + def test_no_wildcards_returns_original(self): + result = expand_prompt("plain prompt", []) + assert result == ["plain prompt"] + + def test_undefined_wildcard_kept_literal(self): + result = expand_prompt("A _missing_ thing", []) + assert result == ["A _missing_ thing"] + + def test_nested_wildcards(self): + """A wildcard value itself contains another wildcard reference.""" + result = expand_prompt( + "I see a _thing_", + [ + WildcardDef("thing", ["_color_ ball"]), + WildcardDef("color", ["red", "green"]), + ], + ) + assert sorted(result) == ["I see a green ball", "I see a red ball"] + + def test_repeated_wildcard(self): + """Same wildcard used twice in a prompt expands to same value in each slot.""" + result = expand_prompt( + "_color_ and _color_", + [WildcardDef("color", ["red", "blue"])], + ) + # Both slots get replaced by same value since replace is global + assert sorted(result) == ["blue and blue", "red and red"] + + +class TestExpandRandom: + def test_returns_requested_count(self): + result = expand_random( + "A _color_ car", + [WildcardDef("color", ["red", "blue", "green"])], + count=5, + rng=random.Random(42), + ) + assert len(result) == 5 + for r in result: + assert "_color_" not in r + + def test_single_result_default(self): + result = expand_random( + "A _color_ _animal_", + [ + WildcardDef("color", ["red"]), + WildcardDef("animal", ["cat"]), + ], + rng=random.Random(0), + ) + assert result == ["A red cat"] + + def test_no_wildcards_returns_original(self): + result = expand_random("no wildcards here", [], count=3) + assert result == ["no wildcards here"] * 3 + + def test_nested_random(self): + result = expand_random( + "A _thing_", + [ + WildcardDef("thing", ["_color_ ball"]), + WildcardDef("color", ["red"]), + ], + count=1, + rng=random.Random(0), + ) + assert result == ["A red ball"] + + +# ============================================================ +# Integration tests for routes +# ============================================================ + + +class TestPromptCRUD: + def test_save_and_list(self, client): + r = client.post("/api/prompts", json={"text": "a sunset over mountains", "tags": ["nature"], "category": "landscape"}) + assert r.status_code == 200 + data = r.json() + assert data["text"] == "a sunset over mountains" + assert data["tags"] == ["nature"] + assert data["category"] == "landscape" + assert data["used_count"] == 0 + prompt_id = data["id"] + + r = client.get("/api/prompts") + assert r.status_code == 200 + prompts = r.json()["prompts"] + assert len(prompts) == 1 + assert prompts[0]["id"] == prompt_id + + def test_delete_prompt(self, client): + r = client.post("/api/prompts", json={"text": "to delete"}) + prompt_id = r.json()["id"] + + r = client.delete(f"/api/prompts/{prompt_id}") + assert r.status_code == 200 + + r = client.get("/api/prompts") + assert len(r.json()["prompts"]) == 0 + + def test_delete_nonexistent_returns_404(self, client): + r = client.delete("/api/prompts/nonexistent") + assert r.status_code == 404 + + def test_save_prompt_minimal(self, client): + """Save with only required field (text).""" + r = client.post("/api/prompts", json={"text": "minimal prompt"}) + assert r.status_code == 200 + data = r.json() + assert data["tags"] == [] + assert data["category"] == "" + + +class TestPromptSearch: + def test_search_by_text(self, client): + client.post("/api/prompts", json={"text": "a red fox in snow"}) + client.post("/api/prompts", json={"text": "blue ocean waves"}) + + r = client.get("/api/prompts", params={"search": "fox"}) + prompts = r.json()["prompts"] + assert len(prompts) == 1 + assert "fox" in prompts[0]["text"] + + def test_filter_by_tag(self, client): + client.post("/api/prompts", json={"text": "prompt1", "tags": ["nature"]}) + client.post("/api/prompts", json={"text": "prompt2", "tags": ["urban"]}) + + r = client.get("/api/prompts", params={"tag": "nature"}) + prompts = r.json()["prompts"] + assert len(prompts) == 1 + assert prompts[0]["text"] == "prompt1" + + def test_sort_by_used_count(self, client): + r1 = client.post("/api/prompts", json={"text": "less used"}) + r2 = client.post("/api/prompts", json={"text": "more used"}) + id2 = r2.json()["id"] + + # Bump usage on second prompt + client.post(f"/api/prompts/{id2}/usage") + client.post(f"/api/prompts/{id2}/usage") + + r = client.get("/api/prompts", params={"sort_by": "used_count"}) + prompts = r.json()["prompts"] + assert prompts[0]["text"] == "more used" + assert prompts[0]["used_count"] == 2 + + +class TestUsageTracking: + def test_increment_usage(self, client): + r = client.post("/api/prompts", json={"text": "track me"}) + prompt_id = r.json()["id"] + + r = client.post(f"/api/prompts/{prompt_id}/usage") + assert r.status_code == 200 + assert r.json()["used_count"] == 1 + + r = client.post(f"/api/prompts/{prompt_id}/usage") + assert r.json()["used_count"] == 2 + + def test_increment_nonexistent_returns_404(self, client): + r = client.post("/api/prompts/nonexistent/usage") + assert r.status_code == 404 + + +class TestWildcardCRUD: + def test_create_and_list(self, client): + r = client.post("/api/wildcards", json={"name": "color", "values": ["red", "blue", "green"]}) + assert r.status_code == 200 + data = r.json() + assert data["name"] == "color" + assert data["values"] == ["red", "blue", "green"] + wc_id = data["id"] + + r = client.get("/api/wildcards") + assert r.status_code == 200 + wildcards = r.json()["wildcards"] + assert len(wildcards) == 1 + assert wildcards[0]["id"] == wc_id + + def test_update_wildcard(self, client): + r = client.post("/api/wildcards", json={"name": "animal", "values": ["cat"]}) + wc_id = r.json()["id"] + + r = client.put(f"/api/wildcards/{wc_id}", json={"values": ["cat", "dog", "bird"]}) + assert r.status_code == 200 + assert r.json()["values"] == ["cat", "dog", "bird"] + + def test_update_nonexistent_returns_404(self, client): + r = client.put("/api/wildcards/nonexistent", json={"values": ["x"]}) + assert r.status_code == 404 + + def test_delete_wildcard(self, client): + r = client.post("/api/wildcards", json={"name": "size", "values": ["big", "small"]}) + wc_id = r.json()["id"] + + r = client.delete(f"/api/wildcards/{wc_id}") + assert r.status_code == 200 + + r = client.get("/api/wildcards") + assert len(r.json()["wildcards"]) == 0 + + def test_delete_nonexistent_returns_404(self, client): + r = client.delete("/api/wildcards/nonexistent") + assert r.status_code == 404 + + +class TestWildcardExpansion: + def test_expand_all(self, client): + client.post("/api/wildcards", json={"name": "color", "values": ["red", "blue"]}) + client.post("/api/wildcards", json={"name": "animal", "values": ["cat", "dog"]}) + + r = client.post("/api/wildcards/expand", json={ + "prompt": "A _color_ _animal_", + "mode": "all", + }) + assert r.status_code == 200 + expanded = r.json()["expanded"] + assert sorted(expanded) == [ + "A blue cat", + "A blue dog", + "A red cat", + "A red dog", + ] + + def test_expand_random(self, client): + client.post("/api/wildcards", json={"name": "color", "values": ["red", "blue", "green"]}) + + r = client.post("/api/wildcards/expand", json={ + "prompt": "A _color_ thing", + "mode": "random", + "count": 3, + }) + assert r.status_code == 200 + expanded = r.json()["expanded"] + assert len(expanded) == 3 + for e in expanded: + assert "_color_" not in e + + def test_expand_no_wildcards(self, client): + r = client.post("/api/wildcards/expand", json={ + "prompt": "plain prompt", + "mode": "all", + }) + assert r.status_code == 200 + assert r.json()["expanded"] == ["plain prompt"] + + def test_expand_nested_wildcards(self, client): + """Wildcard value contains another wildcard reference.""" + client.post("/api/wildcards", json={"name": "scene", "values": ["_color_ sunset"]}) + client.post("/api/wildcards", json={"name": "color", "values": ["golden", "crimson"]}) + + r = client.post("/api/wildcards/expand", json={ + "prompt": "A _scene_", + "mode": "all", + }) + assert r.status_code == 200 + expanded = sorted(r.json()["expanded"]) + assert expanded == ["A crimson sunset", "A golden sunset"] diff --git a/backend/tests/test_queue_routes.py b/backend/tests/test_queue_routes.py new file mode 100644 index 00000000..3048de6a --- /dev/null +++ b/backend/tests/test_queue_routes.py @@ -0,0 +1,87 @@ +"""Tests for queue API routes.""" + +from __future__ import annotations + + +def test_submit_video_job(client): + resp = client.post("/api/queue/submit", json={ + "type": "video", + "model": "ltx-fast", + "params": {"prompt": "a cat", "duration": "6", "resolution": "720p", "aspectRatio": "16:9"}, + }) + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "queued" + assert "id" in data + + +def test_submit_image_job(client): + resp = client.post("/api/queue/submit", json={ + "type": "image", + "model": "z-image-turbo", + "params": {"prompt": "a dog", "width": 1024, "height": 1024}, + }) + assert resp.status_code == 200 + assert resp.json()["status"] == "queued" + + +def test_get_queue_status(client): + client.post("/api/queue/submit", json={ + "type": "video", + "model": "ltx-fast", + "params": {"prompt": "test"}, + }) + resp = client.get("/api/queue/status") + assert resp.status_code == 200 + jobs = resp.json()["jobs"] + assert len(jobs) == 1 + assert jobs[0]["status"] == "queued" + + +def test_cancel_job(client): + submit_resp = client.post("/api/queue/submit", json={ + "type": "video", + "model": "ltx-fast", + "params": {"prompt": "test"}, + }) + job_id = submit_resp.json()["id"] + cancel_resp = client.post(f"/api/queue/cancel/{job_id}") + assert cancel_resp.status_code == 200 + assert cancel_resp.json()["status"] == "cancelled" + + +def test_clear_finished_jobs(client): + submit_resp = client.post("/api/queue/submit", json={ + "type": "video", + "model": "ltx-fast", + "params": {"prompt": "test"}, + }) + job_id = submit_resp.json()["id"] + client.post(f"/api/queue/cancel/{job_id}") + client.post("/api/queue/clear") + status_resp = client.get("/api/queue/status") + assert len(status_resp.json()["jobs"]) == 0 + + +def test_seedance_routes_to_api_slot(client): + resp = client.post("/api/queue/submit", json={ + "type": "video", + "model": "seedance-1.5-pro", + "params": {"prompt": "test"}, + }) + assert resp.status_code == 200 + status = client.get("/api/queue/status") + jobs = status.json()["jobs"] + assert len(jobs) == 1 + assert jobs[0]["slot"] == "api" + + +def test_nano_banana_routes_to_api_slot(client): + resp = client.post("/api/queue/submit", json={ + "type": "image", + "model": "nano-banana-2", + "params": {"prompt": "test"}, + }) + assert resp.status_code == 200 + status = client.get("/api/queue/status") + assert status.json()["jobs"][0]["slot"] == "api" diff --git a/backend/tests/test_queue_worker.py b/backend/tests/test_queue_worker.py new file mode 100644 index 00000000..08576169 --- /dev/null +++ b/backend/tests/test_queue_worker.py @@ -0,0 +1,238 @@ +"""Tests for the queue worker.""" + +from __future__ import annotations + +from pathlib import Path + +from state.job_queue import JobQueue, QueueJob +from handlers.queue_worker import QueueWorker + + +class FakeJobExecutor: + def __init__(self, result_paths: list[str] | None = None) -> None: + self.executed_jobs: list[QueueJob] = [] + self.raise_on_execute: Exception | None = None + self._result_paths = result_paths if result_paths is not None else ["/fake/output.mp4"] + + def execute(self, job: QueueJob) -> list[str]: + self.executed_jobs.append(job) + if self.raise_on_execute is not None: + raise self.raise_on_execute + return list(self._result_paths) + + +class FakeEnhanceHandler: + def __init__(self, result: str = "") -> None: + self._result = result + + def enhance_i2v_motion(self, image_path: str) -> str: + return self._result + + +def test_worker_processes_gpu_job(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + job = queue.submit(job_type="video", model="ltx-fast", params={"prompt": "test"}, slot="gpu") + + executor = FakeJobExecutor() + worker = QueueWorker(queue=queue, gpu_executor=executor, api_executor=FakeJobExecutor()) + worker.tick() + + assert len(executor.executed_jobs) == 1 + assert executor.executed_jobs[0].id == job.id + updated = queue.get_job(job.id) + assert updated is not None + assert updated.status == "complete" + assert updated.result_paths == ["/fake/output.mp4"] + + +def test_worker_processes_api_job(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + job = queue.submit(job_type="video", model="seedance-1.5-pro", params={"prompt": "test"}, slot="api") + + api_executor = FakeJobExecutor() + worker = QueueWorker(queue=queue, gpu_executor=FakeJobExecutor(), api_executor=api_executor) + worker.tick() + + assert len(api_executor.executed_jobs) == 1 + updated = queue.get_job(job.id) + assert updated is not None + assert updated.status == "complete" + + +def test_worker_handles_execution_error(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + job = queue.submit(job_type="video", model="ltx-fast", params={}, slot="gpu") + + executor = FakeJobExecutor() + executor.raise_on_execute = RuntimeError("GPU exploded") + worker = QueueWorker(queue=queue, gpu_executor=executor, api_executor=FakeJobExecutor()) + worker.tick() + + # tick() dispatches to a thread — wait for it to complete + import time + for _ in range(50): + updated = queue.get_job(job.id) + if updated is not None and updated.status == "error": + break + time.sleep(0.05) + + updated = queue.get_job(job.id) + assert updated is not None + assert updated.status == "error" + assert updated.error == "GPU exploded" + + +def test_worker_runs_gpu_and_api_in_parallel(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + gpu_job = queue.submit(job_type="video", model="ltx-fast", params={}, slot="gpu") + api_job = queue.submit(job_type="video", model="seedance-1.5-pro", params={}, slot="api") + + gpu_executor = FakeJobExecutor() + api_executor = FakeJobExecutor() + worker = QueueWorker(queue=queue, gpu_executor=gpu_executor, api_executor=api_executor) + worker.tick() + + assert len(gpu_executor.executed_jobs) == 1 + assert len(api_executor.executed_jobs) == 1 + + +def test_worker_skips_cancelled_job(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + job = queue.submit(job_type="video", model="ltx-fast", params={}, slot="gpu") + queue.cancel_job(job.id) + + executor = FakeJobExecutor() + worker = QueueWorker(queue=queue, gpu_executor=executor, api_executor=FakeJobExecutor()) + worker.tick() + + assert len(executor.executed_jobs) == 0 + + +# --- Task 7: Dependency checking --- + + +def test_worker_skips_job_with_unresolved_dependency(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + parent = queue.submit(job_type="image", model="zit", params={"prompt": "a cat"}, slot="gpu") + child = queue.submit( + job_type="video", model="fast", params={}, slot="gpu", + depends_on=parent.id, + auto_params={"imagePath": "$dep.result_paths[0]"}, + ) + + executor = FakeJobExecutor(result_paths=[]) + worker = QueueWorker(queue=queue, gpu_executor=executor, api_executor=executor) + worker.tick() + + import time + time.sleep(0.1) # Let thread finish + + # Parent was picked up and executed, child should still be queued + # (parent completes but child wasn't dispatched in same tick) + assert len(executor.executed_jobs) == 1 + assert executor.executed_jobs[0].id == parent.id + c = queue.get_job(child.id) + assert c is not None and c.status == "queued" + + +def test_worker_dispatches_dependent_job_after_parent_completes(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + parent = queue.submit(job_type="image", model="zit", params={"prompt": "a cat"}, slot="gpu") + child = queue.submit( + job_type="video", model="fast", params={}, slot="gpu", + depends_on=parent.id, + auto_params={"imagePath": "$dep.result_paths[0]"}, + ) + + # Simulate parent completing + queue.update_job(parent.id, status="complete", result_paths=["/out/cat.png"]) + + executor = FakeJobExecutor(result_paths=["/out/video.mp4"]) + worker = QueueWorker(queue=queue, gpu_executor=executor, api_executor=executor) + worker.tick() + + # Child should now be dispatched + c = queue.get_job(child.id) + assert c is not None and c.status in ("running", "complete") + + +def test_worker_fails_dependent_job_when_parent_errors(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + parent = queue.submit(job_type="image", model="zit", params={"prompt": "a cat"}, slot="gpu") + child = queue.submit( + job_type="video", model="fast", params={}, slot="gpu", + depends_on=parent.id, + ) + + # Simulate parent failing + queue.update_job(parent.id, status="error", error="GPU OOM") + + executor = FakeJobExecutor(result_paths=[]) + worker = QueueWorker(queue=queue, gpu_executor=executor, api_executor=executor) + worker.tick() + + child_job = queue.get_job(child.id) + assert child_job is not None + assert child_job.status == "error" + assert "Upstream job" in (child_job.error or "") + + +# --- Task 8: Batch completion detection --- + + +def test_worker_detects_batch_completion(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + j1 = queue.submit(job_type="image", model="zit", params={}, slot="gpu", batch_id="b1", batch_index=0) + j2 = queue.submit(job_type="image", model="zit", params={}, slot="gpu", batch_id="b1", batch_index=1) + + completed_batches: list[str] = [] + + def on_batch_complete(batch_id: str, jobs: list[QueueJob]) -> None: + completed_batches.append(batch_id) + + executor = FakeJobExecutor(result_paths=["/out.png"]) + worker = QueueWorker(queue=queue, gpu_executor=executor, api_executor=executor, on_batch_complete=on_batch_complete) + + # Complete both jobs + queue.update_job(j1.id, status="complete", result_paths=["/out/1.png"]) + queue.update_job(j2.id, status="complete", result_paths=["/out/2.png"]) + + worker.tick() + + assert completed_batches == ["b1"] + + # Second tick should NOT re-notify + worker.tick() + assert completed_batches == ["b1"] + + +# --- Task 11: I2V auto-prompt --- + + +def test_worker_generates_i2v_prompt_for_auto_prompt_job(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + + # Parent image job already complete + parent = queue.submit(job_type="image", model="zit", params={"prompt": "a landscape"}, slot="gpu", batch_id="b1") + queue.update_job(parent.id, status="complete", result_paths=["/out/landscape.png"]) + + # Child video job with auto_prompt + child = queue.submit( + job_type="video", model="fast", params={"duration": "4"}, slot="gpu", + batch_id="b1", depends_on=parent.id, + auto_params={"imagePath": "$dep.result_paths[0]", "auto_prompt": "true"}, + ) + + fake_enhance = FakeEnhanceHandler(result="The camera pans across mountains.") + executor = FakeJobExecutor(result_paths=["/out/video.mp4"]) + worker = QueueWorker( + queue=queue, gpu_executor=executor, api_executor=executor, + enhance_handler=fake_enhance, + ) + worker.tick() + + # Verify the child job got the auto-generated prompt + child_job = queue.get_job(child.id) + assert child_job is not None + assert child_job.params.get("prompt") == "The camera pans across mountains." + assert child_job.params.get("imagePath") == "/out/landscape.png" diff --git a/backend/tests/test_r2_upload.py b/backend/tests/test_r2_upload.py new file mode 100644 index 00000000..f2fd2077 --- /dev/null +++ b/backend/tests/test_r2_upload.py @@ -0,0 +1,29 @@ +"""Tests for R2 upload integration.""" + +from __future__ import annotations + + +def test_r2_client_is_configured_when_credentials_present() -> None: + from services.r2_client.r2_client_impl import R2ClientImpl + + client = R2ClientImpl( + access_key_id="test", + secret_access_key="test", + endpoint="https://example.com", + bucket="test-bucket", + public_url="https://pub.example.com", + ) + assert client.is_configured() is True + + +def test_r2_client_not_configured_when_empty() -> None: + from services.r2_client.r2_client_impl import R2ClientImpl + + client = R2ClientImpl( + access_key_id="", + secret_access_key="", + endpoint="", + bucket="", + public_url="", + ) + assert client.is_configured() is False diff --git a/backend/tests/test_receive_job.py b/backend/tests/test_receive_job.py new file mode 100644 index 00000000..68e434f4 --- /dev/null +++ b/backend/tests/test_receive_job.py @@ -0,0 +1,83 @@ +"""Tests for receive-job endpoint (jobs from Director's Palette).""" +from __future__ import annotations + + +class TestReceiveJobConnected: + def test_receive_job_creates_queue_entry(self, client): + client.post("/api/settings", json={"paletteApiKey": "dp_valid_key"}) + resp = client.post("/api/sync/receive-job", json={ + "prompt": "A cinematic sunset over the ocean", + "model": "ltx-fast", + "settings": {"resolution": "720p", "duration": "4", "fps": "24", "aspect_ratio": "16:9"}, + }) + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "queued" + assert "id" in data + + # Verify it appeared in the queue + queue = client.get("/api/queue/status").json() + assert len(queue["jobs"]) == 1 + job = queue["jobs"][0] + assert job["id"] == data["id"] + assert job["params"]["prompt"] == "A cinematic sunset over the ocean" + + def test_receive_job_with_character_reference(self, client): + client.post("/api/settings", json={"paletteApiKey": "dp_valid_key"}) + resp = client.post("/api/sync/receive-job", json={ + "prompt": "Character walking in park", + "model": "ltx-fast", + "character_id": "char-abc-123", + }) + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "queued" + + queue = client.get("/api/queue/status").json() + job = queue["jobs"][0] + assert job["params"]["character_id"] == "char-abc-123" + + def test_receive_job_with_first_frame_url(self, client, fake_services): + client.post("/api/settings", json={"paletteApiKey": "dp_valid_key"}) + from tests.fakes.services import FakeResponse + fake_services.http.queue("get", FakeResponse(status_code=200, content=b"fake-image-data")) + resp = client.post("/api/sync/receive-job", json={ + "prompt": "Animate this scene", + "model": "ltx-fast", + "first_frame_url": "https://example.com/frame.png", + }) + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "queued" + + queue = client.get("/api/queue/status").json() + job = queue["jobs"][0] + assert "imagePath" in job["params"] + + +class TestReceiveJobDisconnected: + def test_receive_job_returns_403_without_api_key(self, client): + resp = client.post("/api/sync/receive-job", json={ + "prompt": "A cinematic sunset over the ocean", + "model": "ltx-fast", + }) + assert resp.status_code == 403 + data = resp.json() + assert "error" in data + + +class TestReceiveJobValidation: + def test_receive_job_rejects_empty_prompt(self, client): + client.post("/api/settings", json={"paletteApiKey": "dp_valid_key"}) + resp = client.post("/api/sync/receive-job", json={ + "prompt": " ", + "model": "ltx-fast", + }) + assert resp.status_code == 422 + + def test_receive_job_rejects_missing_prompt(self, client): + client.post("/api/settings", json={"paletteApiKey": "dp_valid_key"}) + resp = client.post("/api/sync/receive-job", json={ + "model": "ltx-fast", + }) + assert resp.status_code == 422 diff --git a/backend/tests/test_response_models.py b/backend/tests/test_response_models.py index 3b2d9b6a..b633f2ca 100644 --- a/backend/tests/test_response_models.py +++ b/backend/tests/test_response_models.py @@ -39,7 +39,7 @@ def test_camelcase_keys(self, client, test_state): progress=0.45, downloaded_bytes=5_000_000_000, total_bytes=19_000_000_000, - speed_mbps=50, + speed_bytes_per_sec=50_000_000.0, ) } @@ -57,7 +57,7 @@ def test_camelcase_keys(self, client, test_state): "filesCompleted", "totalFiles", "error", - "speedMbps", + "speedBytesPerSec", } assert set(data.keys()) == expected_keys @@ -74,7 +74,7 @@ def test_camelcase_keys(self, client): assert "fast_model" not in data assert "seedLocked" in data assert "seed_locked" not in data - assert "hasFalApiKey" in data + assert "hasReplicateApiKey" in data class TestGenerateSnakeCaseKeys: diff --git a/backend/tests/test_runtime_policy_decision.py b/backend/tests/test_runtime_policy_decision.py index 0999f709..2cfd971f 100644 --- a/backend/tests/test_runtime_policy_decision.py +++ b/backend/tests/test_runtime_policy_decision.py @@ -15,7 +15,7 @@ def test_windows_without_cuda_forces_api() -> None: def test_windows_with_low_vram_forces_api() -> None: - assert decide_force_api_generations(system="Windows", cuda_available=True, vram_gb=30) is True + assert decide_force_api_generations(system="Windows", cuda_available=True, vram_gb=16) is True def test_windows_with_unknown_vram_forces_api() -> None: @@ -23,7 +23,8 @@ def test_windows_with_unknown_vram_forces_api() -> None: def test_windows_with_required_vram_allows_local_mode() -> None: - assert decide_force_api_generations(system="Windows", cuda_available=True, vram_gb=31) is False + assert decide_force_api_generations(system="Windows", cuda_available=True, vram_gb=20) is False + assert decide_force_api_generations(system="Windows", cuda_available=True, vram_gb=24) is False def test_other_systems_fail_closed() -> None: diff --git a/backend/tests/test_settings.py b/backend/tests/test_settings.py index 28b34c5c..eb5166b3 100644 --- a/backend/tests/test_settings.py +++ b/backend/tests/test_settings.py @@ -19,7 +19,9 @@ def test_default_settings(self, client, default_app_settings): assert data["loadOnStartup"] is False assert data["hasLtxApiKey"] is False assert data["userPrefersLtxApiVideoGenerations"] is False - assert data["hasFalApiKey"] is False + assert data["hasReplicateApiKey"] is False + assert data["imageModel"] == "flux-klein-9b" + assert data["videoModel"] == "ltx-fast" assert data["useLocalTextEncoder"] is False assert data["fastModel"] == {"useUpscaler": True} assert data["proModel"] == {"steps": 20, "useUpscaler": True} @@ -29,8 +31,9 @@ def test_default_settings(self, client, default_app_settings): assert data["hasGeminiApiKey"] is False assert data["seedLocked"] is False assert data["lockedSeed"] == 42 + assert data["batchSoundEnabled"] is True assert "ltxApiKey" not in data - assert "falApiKey" not in data + assert "replicateApiKey" not in data assert "geminiApiKey" not in data def test_reflects_changed_settings(self, client, test_state): @@ -107,13 +110,13 @@ def test_update_api_keys(self, client, test_state): json={ "ltxApiKey": "ltx-key-abc", "geminiApiKey": "gemini-key-xyz", - "falApiKey": "fal-key-123", + "replicateApiKey": "rep-key-123", }, ) assert r.status_code == 200 assert test_state.state.app_settings.ltx_api_key == "ltx-key-abc" assert test_state.state.app_settings.gemini_api_key == "gemini-key-xyz" - assert test_state.state.app_settings.fal_api_key == "fal-key-123" + assert test_state.state.app_settings.replicate_api_key == "rep-key-123" def test_update_user_prefers_api_video_generations(self, client, test_state): r = client.post("/api/settings", json={"userPrefersLtxApiVideoGenerations": True}) @@ -122,11 +125,11 @@ def test_update_user_prefers_api_video_generations(self, client, test_state): def test_empty_string_does_not_erase_key(self, client, test_state): test_state.state.app_settings.ltx_api_key = "real-key" - test_state.state.app_settings.fal_api_key = "fal-key" - r = client.post("/api/settings", json={"ltxApiKey": "", "falApiKey": ""}) + test_state.state.app_settings.replicate_api_key = "rep-key" + r = client.post("/api/settings", json={"ltxApiKey": "", "replicateApiKey": ""}) assert r.status_code == 200 assert test_state.state.app_settings.ltx_api_key == "real-key" - assert test_state.state.app_settings.fal_api_key == "fal-key" + assert test_state.state.app_settings.replicate_api_key == "rep-key" def test_omitted_key_does_not_erase_key(self, client, test_state): test_state.state.app_settings.ltx_api_key = "real-key" @@ -139,6 +142,16 @@ def test_unknown_field_rejected(self, client): assert r.status_code == 422 +class TestVideoModel: + def test_video_model_roundtrips(self, client, test_state): + resp = client.post("/api/settings", json={"videoModel": "seedance-1.5-pro"}) + assert resp.status_code == 200 + assert test_state.state.app_settings.video_model == "seedance-1.5-pro" + + get_resp = client.get("/api/settings") + assert get_resp.json()["videoModel"] == "seedance-1.5-pro" + + class TestSettingsPersistence: def _new_state(self, test_state, default_app_settings): fake_services = FakeServices() @@ -151,13 +164,19 @@ def _new_state(self, test_state, default_app_settings): text_encoder=fake_services.text_encoder, task_runner=fake_services.task_runner, ltx_api_client=fake_services.ltx_api_client, - zit_api_client=fake_services.zit_api_client, + image_api_client=fake_services.image_api_client, + video_api_client=fake_services.video_api_client, + palette_sync_client=fake_services.palette_sync_client, fast_video_pipeline_class=type(fake_services.fast_video_pipeline), + gguf_video_pipeline_class=None, + nf4_video_pipeline_class=None, image_generation_pipeline_class=type(fake_services.image_generation_pipeline), + flux_klein_pipeline_class=None, ic_lora_pipeline_class=type(fake_services.ic_lora_pipeline), a2v_pipeline_class=type(fake_services.a2v_pipeline), retake_pipeline_class=type(fake_services.retake_pipeline), ic_lora_model_downloader=fake_services.ic_lora_model_downloader, + model_scanner=fake_services.model_scanner, ) return build_initial_state(test_state.config, default_app_settings.model_copy(deep=True), service_bundle=bundle) @@ -197,6 +216,61 @@ def test_user_prefers_api_video_generations_persists(self, client, test_state, d assert loaded.state.app_settings.user_prefers_ltx_api_video_generations is True +class TestPaletteApiKey: + def test_palette_api_key_roundtrip(self, client, default_app_settings): + """Palette API key can be saved and is masked in responses.""" + resp = client.post("/api/settings", json={"paletteApiKey": "dp_test_key_123"}) + assert resp.status_code == 200 + resp = client.get("/api/settings") + data = resp.json() + assert data["hasPaletteApiKey"] is True + assert "dp_test_key_123" not in resp.text + + +class TestAbliteratedTextEncoder: + def test_setting_roundtrips(self, client, test_state): + resp = client.post("/api/settings", json={"useAbliteratedTextEncoder": True}) + assert resp.status_code == 200 + assert test_state.state.app_settings.use_abliterated_text_encoder is True + + get_resp = client.get("/api/settings") + assert get_resp.json()["useAbliteratedTextEncoder"] is True + + def test_default_is_false(self, client): + resp = client.get("/api/settings") + assert resp.json()["useAbliteratedTextEncoder"] is False + + def test_resolve_gemma_root_uses_abliterated_when_enabled(self, test_state, create_fake_model_files): + create_fake_model_files() + test_state.state.app_settings.use_local_text_encoder = True + test_state.state.app_settings.use_abliterated_text_encoder = True + + # Create abliterated encoder directory + abliterated_dir = test_state.config.model_path("text_encoder_abliterated") + abliterated_dir.mkdir(parents=True, exist_ok=True) + (abliterated_dir / "model.safetensors").write_bytes(b"\x00" * 1024) + + gemma_root = test_state.text.resolve_gemma_root() + assert gemma_root is not None + assert "abliterated" in gemma_root + + def test_resolve_gemma_root_falls_back_when_abliterated_missing(self, test_state, create_fake_model_files): + create_fake_model_files() + test_state.state.app_settings.use_local_text_encoder = True + test_state.state.app_settings.use_abliterated_text_encoder = True + + # No abliterated directory — should fall back to standard + gemma_root = test_state.text.resolve_gemma_root() + assert gemma_root is not None + assert "abliterated" not in gemma_root + + def test_abliterated_not_required_for_download(self, client, test_state): + resp = client.get("/api/models/status") + models = resp.json()["models"] + abliterated = next(m for m in models if "abliterated" in m["name"].lower()) + assert abliterated["required"] is False + + class TestSettingsSchemaDrift: def test_update_request_tracks_app_settings_fields(self): assert set(AppSettings.model_fields) == set(UpdateSettingsRequest.model_fields) diff --git a/backend/tests/test_state_actions.py b/backend/tests/test_state_actions.py index 749df8b4..a9bcdde3 100644 --- a/backend/tests/test_state_actions.py +++ b/backend/tests/test_state_actions.py @@ -139,7 +139,7 @@ def test_mps_skips_torch_compile(test_state, fake_services): def test_startup_warmup_keeps_fast_on_gpu_and_preloads_zit_on_cpu(test_state, fake_services, create_fake_model_files): - create_fake_model_files(include_zit=True) + create_fake_model_files(include_zit=True, include_flux_klein=True) test_state.state.app_settings.load_on_startup = True test_state.health.default_warmup() diff --git a/backend/tests/test_style_guide.py b/backend/tests/test_style_guide.py new file mode 100644 index 00000000..864100a5 --- /dev/null +++ b/backend/tests/test_style_guide.py @@ -0,0 +1,93 @@ +"""Tests for style guide grid generation endpoint.""" +from __future__ import annotations + + +class TestStyleGuideGenerate: + def test_generates_9_jobs(self, client): + resp = client.post("/api/style-guide/generate", json={ + "style_name": "Impressionism", + "style_description": "soft brushstrokes, vibrant light", + }) + assert resp.status_code == 200 + data = resp.json() + assert len(data["job_ids"]) == 9 + assert len(set(data["job_ids"])) == 9 + + def test_jobs_appear_in_queue(self, client): + resp = client.post("/api/style-guide/generate", json={ + "style_name": "Art Deco", + }) + job_ids = resp.json()["job_ids"] + + queue = client.get("/api/queue/status").json() + queue_ids = {j["id"] for j in queue["jobs"]} + for jid in job_ids: + assert jid in queue_ids + + def test_prompts_include_style_name_and_description(self, client): + resp = client.post("/api/style-guide/generate", json={ + "style_name": "Film Noir", + "style_description": "high contrast, dramatic shadows", + }) + job_ids = resp.json()["job_ids"] + + queue = client.get("/api/queue/status").json() + prompts = [j["params"]["prompt"] for j in queue["jobs"]] + + for prompt in prompts: + assert "Film Noir" in prompt + assert "high contrast, dramatic shadows" in prompt + + def test_prompts_include_diverse_subjects(self, client): + resp = client.post("/api/style-guide/generate", json={ + "style_name": "Watercolor", + }) + job_ids = resp.json()["job_ids"] + + queue = client.get("/api/queue/status").json() + prompts = [j["params"]["prompt"] for j in queue["jobs"]] + + subject_keywords = [ + "Portrait", + "Cityscape", + "Nature landscape", + "Interior room", + "Food", + "Vehicle", + "Animal", + "Architecture", + "Abstract", + ] + for keyword in subject_keywords: + assert any(keyword in p for p in prompts), f"Missing subject keyword: {keyword}" + + def test_returns_9_job_ids(self, client): + resp = client.post("/api/style-guide/generate", json={ + "style_name": "Pop Art", + "style_description": "bold colors, comic style", + }) + assert resp.status_code == 200 + data = resp.json() + assert "job_ids" in data + assert len(data["job_ids"]) == 9 + + def test_reference_image_path_passed_to_jobs(self, client): + resp = client.post("/api/style-guide/generate", json={ + "style_name": "Minimalist", + "reference_image_path": "/tmp/style_ref.png", + }) + job_ids = resp.json()["job_ids"] + + queue = client.get("/api/queue/status").json() + for job in queue["jobs"]: + assert job["params"]["reference_image_path"] == "/tmp/style_ref.png" + + def test_no_reference_image_when_not_provided(self, client): + resp = client.post("/api/style-guide/generate", json={ + "style_name": "Gothic", + }) + job_ids = resp.json()["job_ids"] + + queue = client.get("/api/queue/status").json() + for job in queue["jobs"]: + assert "reference_image_path" not in job["params"] diff --git a/backend/tests/test_sync.py b/backend/tests/test_sync.py new file mode 100644 index 00000000..7f8dfbf0 --- /dev/null +++ b/backend/tests/test_sync.py @@ -0,0 +1,43 @@ +"""Tests for Palette sync routes.""" +from __future__ import annotations + + +class TestSyncStatus: + def test_disconnected_by_default(self, client): + resp = client.get("/api/sync/status") + assert resp.status_code == 200 + data = resp.json() + assert data["connected"] is False + assert data["user"] is None + + def test_connected_after_setting_api_key(self, client): + client.post("/api/settings", json={"paletteApiKey": "dp_valid_key"}) + resp = client.get("/api/sync/status") + assert resp.status_code == 200 + data = resp.json() + assert data["connected"] is True + assert data["user"]["email"] == "test@example.com" + + def test_connection_fails_with_invalid_key(self, client, fake_services): + fake_services.palette_sync_client.raise_on_validate = RuntimeError("Invalid API key") + client.post("/api/settings", json={"paletteApiKey": "dp_bad_key"}) + resp = client.get("/api/sync/status") + data = resp.json() + assert data["connected"] is False + + +class TestSyncCredits: + def test_credits_when_connected(self, client): + client.post("/api/settings", json={"paletteApiKey": "dp_valid_key"}) + resp = client.get("/api/sync/credits") + assert resp.status_code == 200 + data = resp.json() + assert data["balance_cents"] == 5000 + assert "pricing" in data + + def test_credits_when_disconnected(self, client): + resp = client.get("/api/sync/credits") + assert resp.status_code == 200 + data = resp.json() + assert data["balance_cents"] is None + assert data["connected"] is False diff --git a/backend/tests/test_tea_cache.py b/backend/tests/test_tea_cache.py new file mode 100644 index 00000000..42164b6e --- /dev/null +++ b/backend/tests/test_tea_cache.py @@ -0,0 +1,67 @@ +"""Tests for TeaCache denoising loop caching.""" + +from __future__ import annotations + +from dataclasses import dataclass + +import torch + +from services.gpu_optimizations.tea_cache import ( + TeaCacheState, + wrap_denoise_fn_with_tea_cache, +) + + +@dataclass +class FakeLatentState: + latent: torch.Tensor + denoise_mask: torch.Tensor + clean_latent: torch.Tensor + + +def _make_fake_denoise(): # type: ignore[no-untyped-def] + call_count = [0] + + def denoise_fn(video_state, audio_state, sigmas, step_index): # type: ignore[no-untyped-def] + call_count[0] += 1 + return video_state.latent * 0.9, audio_state.latent * 0.9 + + return denoise_fn, call_count + + +def test_tea_cache_disabled_when_threshold_zero() -> None: + original, call_count = _make_fake_denoise() + wrapped = wrap_denoise_fn_with_tea_cache(original, num_steps=10, threshold=0.0) + assert wrapped is original # no wrapping + + +def test_tea_cache_always_computes_first_step() -> None: + original, call_count = _make_fake_denoise() + wrapped = wrap_denoise_fn_with_tea_cache(original, num_steps=10, threshold=0.05) + latent = torch.randn(1, 100, 64) + vs = FakeLatentState(latent=latent, denoise_mask=torch.ones(1), clean_latent=latent) + _v, _a = wrapped(vs, vs, torch.linspace(1, 0, 11), 0) + assert call_count[0] == 1 + + +def test_tea_cache_always_computes_last_step() -> None: + original, call_count = _make_fake_denoise() + wrapped = wrap_denoise_fn_with_tea_cache(original, num_steps=10, threshold=0.05) + latent = torch.randn(1, 100, 64) + vs = FakeLatentState(latent=latent, denoise_mask=torch.ones(1), clean_latent=latent) + _v, _a = wrapped(vs, vs, torch.linspace(1, 0, 11), 0) + _v, _a = wrapped(vs, vs, torch.linspace(1, 0, 11), 9) + assert call_count[0] == 2 # both first and last always computed + + +def test_tea_cache_skips_similar_steps() -> None: + original, call_count = _make_fake_denoise() + wrapped = wrap_denoise_fn_with_tea_cache(original, num_steps=10, threshold=100.0) + latent = torch.ones(1, 100, 64) + vs = FakeLatentState(latent=latent, denoise_mask=torch.ones(1), clean_latent=latent) + # Step 0: always computed + wrapped(vs, vs, torch.linspace(1, 0, 11), 0) + # Step 1: very high threshold means this should be skipped + wrapped(vs, vs, torch.linspace(1, 0, 11), 1) + tea_state: TeaCacheState = wrapped._tea_cache_state + assert tea_state.skipped >= 1 diff --git a/backend/tests/test_video_api_client.py b/backend/tests/test_video_api_client.py new file mode 100644 index 00000000..18e1fcc2 --- /dev/null +++ b/backend/tests/test_video_api_client.py @@ -0,0 +1,150 @@ +"""Tests for the ReplicateVideoClientImpl video API client.""" + +from __future__ import annotations + +import pytest + +from services.video_api_client.replicate_video_client_impl import ReplicateVideoClientImpl +from tests.fakes.services import FakeHTTPClient, FakeResponse + + +API_KEY = "test-replicate-key" +SEEDANCE_MODEL = "seedance-1.5-pro" +BASE_URL = "https://api.replicate.com/v1" + + +def _make_client(http: FakeHTTPClient) -> ReplicateVideoClientImpl: + return ReplicateVideoClientImpl(http=http, api_base_url=BASE_URL) + + +def _default_kwargs() -> dict[str, object]: + return { + "api_key": API_KEY, + "model": SEEDANCE_MODEL, + "prompt": "A cat walking on a beach", + "duration": 5, + "resolution": "720p", + "aspect_ratio": "16:9", + "generate_audio": False, + } + + +def test_seedance_text_to_video_sync_success() -> None: + http = FakeHTTPClient() + + # POST returns succeeded immediately + http.queue( + "post", + FakeResponse( + status_code=201, + json_payload={ + "id": "pred123", + "status": "succeeded", + "output": "https://replicate.delivery/video.mp4", + }, + ), + ) + + # GET downloads video bytes + video_bytes = b"fake-mp4-video-content" + http.queue( + "get", + FakeResponse(status_code=200, content=video_bytes), + ) + + client = _make_client(http) + result = client.generate_text_to_video(**_default_kwargs()) + + assert result == video_bytes + assert len(http.calls) == 2 + assert http.calls[0].method == "post" + assert "bytedance/seedance-1.5-pro" in http.calls[0].url + assert http.calls[1].method == "get" + + +def test_seedance_text_to_video_polling_success() -> None: + http = FakeHTTPClient() + + # POST returns processing + http.queue( + "post", + FakeResponse( + status_code=201, + json_payload={ + "id": "pred456", + "status": "processing", + "urls": {"get": f"{BASE_URL}/predictions/pred456"}, + }, + ), + ) + + # First poll: still processing + http.queue( + "get", + FakeResponse( + status_code=200, + json_payload={ + "id": "pred456", + "status": "processing", + }, + ), + ) + + # Second poll: succeeded + http.queue( + "get", + FakeResponse( + status_code=200, + json_payload={ + "id": "pred456", + "status": "succeeded", + "output": "https://replicate.delivery/video2.mp4", + }, + ), + ) + + # Download + video_bytes = b"polled-video-content" + http.queue( + "get", + FakeResponse(status_code=200, content=video_bytes), + ) + + client = _make_client(http) + result = client.generate_text_to_video(**_default_kwargs()) + + assert result == video_bytes + # POST + 2 polls + 1 download = 4 calls + assert len(http.calls) == 4 + + +def test_unknown_model_raises() -> None: + http = FakeHTTPClient() + client = _make_client(http) + + kwargs = _default_kwargs() + kwargs["model"] = "nonexistent-model" + + with pytest.raises(RuntimeError, match="Unknown video model"): + client.generate_text_to_video(**kwargs) + + +def test_prediction_failure_raises() -> None: + http = FakeHTTPClient() + + http.queue( + "post", + FakeResponse( + status_code=201, + json_payload={ + "id": "pred789", + "status": "failed", + "error": "GPU out of memory", + }, + ), + ) + + client = _make_client(http) + + with pytest.raises(RuntimeError, match="Replicate prediction failed"): + client.generate_text_to_video(**_default_kwargs()) diff --git a/backend/uv.lock b/backend/uv.lock index 07a19c05..4bc85b9a 100644 --- a/backend/uv.lock +++ b/backend/uv.lock @@ -306,6 +306,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ab/6e/81d47999aebc1b155f81eca4477a616a70f238a2549848c38983f3c22a82/ftfy-6.3.1-py3-none-any.whl", hash = "sha256:7c70eb532015cd2f9adb53f101fb6c7945988d023a085d127d1573dc49dd0083", size = 44821, upload-time = "2024-10-26T00:50:33.425Z" }, ] +[[package]] +name = "gguf" +version = "0.18.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "pyyaml" }, + { name = "requests" }, + { name = "tqdm" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3f/26/7622a41c39db9d7090225a4bf8368550e59694dcf7313b44f9a82b501209/gguf-0.18.0.tar.gz", hash = "sha256:b4659093d5d0dccdb5902a904d54b327f4052879fe5e90946ad5fce9f8018c2e", size = 107170, upload-time = "2026-02-27T15:05:39.254Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5e/0c/e0f1eae7535a97476fb903f65301e35da2a66182b8161066b7eb312b2cb8/gguf-0.18.0-py3-none-any.whl", hash = "sha256:af93f7ef198a265cbde5fa6a6b3101528bca285903949ab0a3e591cd993a1864", size = 114244, upload-time = "2026-02-27T15:05:37.991Z" }, +] + [[package]] name = "h11" version = "0.16.0" @@ -513,6 +528,7 @@ dependencies = [ { name = "diffusers" }, { name = "fastapi" }, { name = "ftfy" }, + { name = "gguf" }, { name = "huggingface-hub" }, { name = "imageio" }, { name = "imageio-ffmpeg" }, @@ -552,6 +568,7 @@ requires-dist = [ { name = "diffusers", git = "https://github.com/huggingface/diffusers.git?rev=01de02e8b4f2cc91df4f3e91cb6535ebcbeb490c" }, { name = "fastapi", specifier = ">=0.115.0" }, { name = "ftfy", specifier = ">=6.0.0" }, + { name = "gguf", specifier = ">=0.10.0" }, { name = "httpx", marker = "extra == 'test'", specifier = ">=0.27" }, { name = "huggingface-hub", specifier = ">=0.23.0" }, { name = "imageio", specifier = ">=2.37.2" }, diff --git a/docs/gpu-optimization-results.md b/docs/gpu-optimization-results.md new file mode 100644 index 00000000..3a27bf78 --- /dev/null +++ b/docs/gpu-optimization-results.md @@ -0,0 +1,195 @@ +# GPU Optimization Results — March 9, 2026 + +**Hardware:** NVIDIA RTX 4090 24GB VRAM, Windows 11, CUDA 12.9 +**Model:** LTX-Video 2.3 (ltx-2.3-22b-distilled), 22B parameter distilled model + +--- + +## What We Did + +### 1. FFN Chunked Feedforward (Peak VRAM Reduction) + +**Problem:** The LTX transformer has 48 transformer blocks. Each block's FeedForward layer expands the hidden dimension by 4× (e.g., 3072 → 12288), creating enormous intermediate tensors. When processing long sequences (>200 frames), these intermediate tensors exceed available VRAM, forcing the GPU into a slow fallback path that causes the "nonlinear scaling cliff" — where 10s took 6.5× longer than 8s for only 25% more frames. + +**Solution:** We monkey-patch every FeedForward module in the transformer to split its computation along the sequence dimension into 8 chunks. Instead of computing all 241 frames at once through the 4× expansion, it processes ~30 frames at a time. The output is mathematically identical (FeedForward is pointwise, so chunking is lossless). This reduces peak VRAM usage by up to 8×. + +**Setting:** `ffnChunkCount` (default: 8, set to 0 to disable) +**File:** `backend/services/gpu_optimizations/ffn_chunking.py` + +### 2. TeaCache (Timestep-Aware Caching) + +**Problem:** During denoising, the transformer runs a full forward pass at every timestep. Many adjacent timesteps produce very similar outputs, wasting computation. + +**Solution:** TeaCache monitors how much the input changes between denoising steps using a relative L1 distance metric, rescaled by a polynomial fitted to the LTX-Video noise schedule. When the change is below a threshold, it reuses the previous step's residual (the difference between input and output) instead of running the full transformer. First and last steps are always fully computed. + +**Setting:** `teaCacheThreshold` (default: 0.0 = off, 0.03 = balanced quality/speed, 0.05 = aggressive) +**File:** `backend/services/gpu_optimizations/tea_cache.py` + +### 3. VRAM Deep Cleanup + +**Problem:** After heavy GPU workloads (especially long generations), VRAM fragmentation caused subsequent generations to stall at 15% progress indefinitely. The GPU showed 100% utilization but made no progress. + +**Solution:** After every GPU job completes, we now run an aggressive cleanup: two rounds of garbage collection + CUDA cache clearing + CUDA synchronize. This ensures VRAM is fully reclaimed before the next job starts. + +**File:** `backend/services/gpu_cleaner/torch_cleaner.py` (deep_cleanup method) + +### 4. R2 Cloud Storage Upload + +**Problem:** Generated videos/images only lived on the local machine. + +**Solution:** After each generation, results can be automatically uploaded to Cloudflare R2 (S3-compatible) storage. Configure via settings: `r2AccessKeyId`, `r2SecretAccessKey`, `r2Endpoint`, `r2Bucket`, `r2PublicUrl`, and toggle with `autoUploadToR2`. + +**File:** `backend/services/r2_client/r2_client_impl.py` + +--- + +## Benchmark Results + +All tests at 512p, 16:9 aspect ratio, 24fps, FFN chunking=8, TeaCache threshold=0.03. + +### Comparison Table + +| Duration | Frames | Baseline | Optimized | Speedup | Time Saved | +|----------|--------|----------|-----------|---------|------------| +| 2s | 49 | 37s | 42s (cold) / 59s* | — | Session warmth variance | +| 5s | 121 | 84s | **65s** | **1.29×** | 19s (23%) | +| 8s | 193 | 100s | **65s** | **1.54×** | 35s (35%) | +| 10s | 241 | 651s | **275s** | **2.37×** | 376s (58%) | +| 20s | 481 | ~11,820s (3.3 hrs) | Still slow** | ~2× est. | See notes | + +*The 2s "optimized" run was the first warm generation after cold start. The baseline 37s was also from a warm session. Session-dependent variance of ±20s is normal for short clips. + +**The 20s test reached 15% inference after ~50 minutes and was cancelled. Extrapolating: ~5-6 hours total, which is ~2× faster than baseline but still impractical. For 20s+ content, use extend chains (5× 4s clips ≈ 5 minutes). + +### Key Takeaway: The Nonlinear Scaling Cliff is Dramatically Reduced + +Before optimizations: +``` +8s (193 frames) = 100s +10s (241 frames) = 651s ← 6.5× jump for 25% more frames +``` + +After optimizations: +``` +8s (193 frames) = 65s +10s (241 frames) = 275s ← 4.2× jump for 25% more frames (was 6.5×) +``` + +The cliff is still there (attention is still quadratic in sequence length) but FFN chunking prevents the worst of the VRAM thrashing. The 10s generation went from "impractical" (11 min) to "tolerable" (4.5 min). + +### Scaling Curve (512p, Optimized) + +``` +Frames: 49 121 193 241 481 +Time: 42s 65s 65s 275s ~5hrs (est) +Per-frame: 0.86s 0.54s 0.34s 1.14s ~37s +``` + +The sweet spot is **5-8 seconds** (121-193 frames). Beyond 193 frames, the per-frame cost increases dramatically due to quadratic attention scaling. + +--- + +## Practical User Guidelines (Updated) + +| Use Case | Setting | Expected Time | +|----------|---------|---------------| +| Quick preview | 512p, 2s | ~40s | +| Standard clip | 512p, 5s | **~65s** (was 84s) | +| Longer clip | 512p, 8s | **~65s** (was 100s) | +| Extended clip | 512p, 10s | **~4.5 min** (was 11 min) | +| Long scene | 512p, 2s × 5 extend chain | ~5 min | +| High quality short | 720p, 2s | ~40s | + +**Avoid:** 512p ≥20s (hours), 720p ≥8s (hours), 1080p ≥5s (OOM crash) + +**For longer scenes:** Use the extend feature to chain 2-5s clips. Five 2s clips = ~3-5 min total, vs a single 10s clip = ~4.5 min — similar time but with better control. + +--- + +## Generated Samples + +All benchmark outputs are in `D:\git\directors-desktop\backend\outputs\`: + +### Today's Optimized Benchmark Outputs +| File | Duration | Time | Prompt | +|------|----------|------|--------| +| `ltx2_video_20260309_113236_01996534.mp4` | 2s (warmup) | 42s | Test warmup scene | +| `ltx2_video_20260309_113318_a6b79593.mp4` | 2s | 59s | Ocean waves crashing on rocky shore, sunset, cinematic | +| `ltx2_video_20260309_113417_9832bd5d.mp4` | 5s | 65s | Jellyfish glowing neon in dark ocean, bioluminescent | +| `ltx2_video_20260309_113522_68e618a6.mp4` | 8s | 65s | Samurai in bamboo forest, rain, cinematic | +| `ltx2_video_20260309_113627_5ab81499.mp4` | 10s | 275s | Rocket launch at dawn, smoke plume, slow motion | + +### Earlier Baseline Benchmark Outputs +| File | Duration | Time | Prompt | +|------|----------|------|--------| +| `ltx2_video_20260309_052523_29fd7111.mp4` | 5s | 84s | Ocean waves crashing on rocky shore, sunset, cinematic | +| `ltx2_video_20260309_052647_869c45d0.mp4` | 8s | 100s | Jellyfish glowing neon in dark ocean, bioluminescent | +| `ltx2_video_20260309_052826_071807c1.mp4` | 10s | 651s | Samurai in bamboo forest, rain, cinematic | +| `ltx2_video_20260309_053918_5b3bb04d.mp4` | 10s | 650s | Rocket launch at dawn, smoke plume, slow motion | +| `ltx2_video_20260309_055008_3f713c6c.mp4` | 20s | ~3.3hrs | Flower blooming timelapse, golden hour | + +### Image Outputs (ZIT model) +| File | Size | Notes | +|------|------|-------| +| `zit_image_20260309_095030_6f9f6d95.png` | 1.2MB | Robot in flower garden (API, nano-banana-2) | +| `zit_image_20260309_051257_40f12d39.png` | 1.2MB | Local ZIT generation | +| Plus 7 more ZIT images from earlier sessions | | | + +--- + +## Model & Feature Status + +### LTX-Video Version +**Yes, we're on LTX-2.3** — the latest. The checkpoint is `ltx-2.3-22b-distilled.safetensors` from `Lightricks/LTX-2.3` on HuggingFace. This is the 22-billion parameter distilled model, which is the most capable version available. + +### Prompt Enhancement (Magic Wand) +**Yes, prompt enhancement is implemented.** There's a magic wand button in the GenSpace UI that enhances prompts via: +1. **Directors Palette API** (priority) — if `paletteApiKey` is set, uses `/api/prompt-expander` endpoint with "2x" expansion level +2. **Gemini 2.0 Flash** (fallback) — if `geminiApiKey` is set, uses Gemini to expand prompts with cinematic details + +The enhancement adds lighting, camera angles, mood, and atmosphere to vague prompts. It works for both text-to-video and text-to-image modes. + +**Current state:** It works if you have either a Palette API key or Gemini API key configured. The Palette API key is already set based on your settings. + +### Aspect Ratios +Supported aspect ratios for local generation: +- **16:9** (default) — standard widescreen +- **9:16** — vertical/portrait +- **1:1** — square +- **4:3** — classic TV +- **3:4** — portrait + +For API-forced generation (Seedance), only 16:9 and 9:16 are allowed. + +The resolution mapping by aspect ratio is handled in `video_generation_handler.py`. Each resolution tier (512p, 720p, 1080p) has width/height values per aspect ratio. + +--- + +## What's Missing: Time Estimation UI + +Currently, the frontend shows these status phases during generation: +- "Queued — waiting..." +- "Starting up..." +- "Preparing GPU..." +- "Loading model..." / "Loading video model..." +- "Encoding prompt..." +- "Generating..." (inference phase) +- "Decoding video..." +- "Complete!" + +**What Directors Palette has that we don't:** +- **Elapsed time counter** (MM:SS since generation started) +- **Estimated time remaining** (based on benchmark data for the resolution/duration combo) +- **Time-based progress bar** (progress = elapsed / estimated × 100%) +- **Stage indicators** (Analysis → Generation → Complete) + +**Recommendation:** Add estimated durations based on our benchmark data. For example, when a user submits a 512p 8s job, show "Estimated: ~1:05" and count up elapsed time. The data: + +``` +512p 2s = ~40s estimated +512p 5s = ~65s estimated +512p 8s = ~65s estimated +512p 10s = ~275s estimated (4:35) +``` + +This would be a frontend change in `use-generation.ts` and the progress display component. diff --git a/docs/palette-api-spec.md b/docs/palette-api-spec.md new file mode 100644 index 00000000..ce694859 --- /dev/null +++ b/docs/palette-api-spec.md @@ -0,0 +1,762 @@ +# Director's Palette API Specification for Desktop Integration + +**Version:** 1.0 +**Date:** 2026-03-06 +**Audience:** Director's Palette engineering team +**Status:** Draft + +--- + +## Table of Contents + +1. [Overview](#overview) +2. [Priority Order](#priority-order) +3. [Infrastructure Requirements](#infrastructure-requirements) +4. [Authentication Middleware](#authentication-middleware) +5. [Endpoints: Auth](#endpoints-auth) +6. [Endpoints: Gallery](#endpoints-gallery) +7. [Endpoints: Library](#endpoints-library) +8. [Endpoints: Credits](#endpoints-credits) +9. [Endpoints: Prompts](#endpoints-prompts) +10. [Error Format](#error-format) +11. [Open Questions](#open-questions) + +--- + +## Overview + +Director's Desktop is an Electron app for local AI video generation. It needs to integrate with Director's Palette (the Next.js web app) so users can: + +- **Authenticate** from the desktop app using their existing Palette account +- **Browse** their cloud gallery, characters, styles, references, and brands +- **Upload** locally generated assets to their cloud gallery +- **Check credits** before running API-backed generations +- **Enhance prompts** using the existing prompt expander + +All new endpoints live under the `/api/desktop/*` prefix. This isolates Desktop-specific concerns (CORS, API key auth, pairing) from the existing web app routes. + +**Integration topology:** + +``` +Director's Desktop (Electron) + | + | HTTP requests to https://directors-palette-v2.vercel.app/api/desktop/* + | Auth: Bearer token (JWT or dp_ API key) + | +Director's Palette (Next.js on Vercel) + | + | Supabase (auth, DB, storage) +``` + +--- + +## Priority Order + +Build in this order. Each phase is independently useful. + +| Phase | What | Why | +|-------|------|-----| +| **Phase 1** | Infrastructure (CORS, auth middleware, API key expansion) | Nothing works without this | +| **Phase 2** | `GET /api/desktop/me` | Desktop can verify a token is valid | +| **Phase 3** | Auth pairing flow (pair, poll, complete) + deep link redirect | Users can log in from Desktop | +| **Phase 4** | `GET /api/desktop/gallery` + `DELETE` | Users can browse/manage cloud gallery | +| **Phase 5** | `POST /api/desktop/gallery/upload` | Users can push local generations to cloud | +| **Phase 6** | Library endpoints (characters, styles, references, brands) | Users can pull creative assets into Desktop | +| **Phase 7** | Credits endpoints | Desktop can gate API-backed generations | +| **Phase 8** | `POST /api/desktop/prompt/enhance` | Prompt enhancement from Desktop | + +--- + +## Infrastructure Requirements + +### 1. CORS on `/api/desktop/*` + +All routes under `/api/desktop/*` must return CORS headers allowing requests from the Desktop backend. + +**Required headers on every response (including preflight):** + +``` +Access-Control-Allow-Origin: http://localhost:8000 +Access-Control-Allow-Methods: GET, POST, PUT, DELETE, OPTIONS +Access-Control-Allow-Headers: Content-Type, Authorization +Access-Control-Max-Age: 86400 +``` + +For `OPTIONS` preflight requests, return `204 No Content` with the above headers. + +**Implementation note:** In Next.js App Router, this is best done with a middleware function that matches `/api/desktop/:path*` and injects headers. Alternatively, each route handler can call a shared `withCors()` helper. Either pattern is fine as long as every route and every OPTIONS response is covered. + +### 2. API Key Expansion + +Currently, `api_keys` table entries are admin-only. This must change: + +- **Any authenticated user** can create API keys for their own account. +- Existing table schema (`api_keys` with `dp_` + 32-hex format, SHA-256 hash lookup) is fine. No schema changes needed. +- Add two internal endpoints (used by the Palette web UI, not by Desktop): + - Create key: inserts a row with `user_id`, `hashed_key`, `name`, `created_at` + - Revoke key: soft-deletes or hard-deletes by key ID for the authenticated user +- Desktop does not call these endpoints directly. The user creates/manages keys in the Palette web UI and pastes the key into Desktop. + +### 3. Deep Link Redirect from OAuth Callback + +Modify the existing `/auth/callback` route to support a `redirect_to` query parameter: + +1. User clicks "Login with Palette" in Desktop. +2. Desktop opens the system browser to: `https://directors-palette-v2.vercel.app/auth/login?redirect_to=directorsdesktop://auth/callback` +3. User completes OAuth (Google/Apple) or email/password login as normal. +4. `/auth/callback` receives the Supabase auth code, exchanges it for a session. +5. If `redirect_to` starts with `directorsdesktop://`, redirect to: `directorsdesktop://auth/callback?access_token={JWT}&refresh_token={JWT}` +6. If `redirect_to` is missing or does not start with `directorsdesktop://`, proceed with existing behavior. + +**Security:** Only allow `redirect_to` values starting with `directorsdesktop://`. Reject all other custom schemes. + +--- + +## Authentication Middleware + +Every `/api/desktop/*` endpoint requires authentication. The middleware must support two token formats in the `Authorization` header: + +### Token Format 1: Supabase JWT + +``` +Authorization: Bearer eyJhbGciOiJIUzI1NiIs... +``` + +- Verify using Supabase's `createClient` / `getUser()` server-side. +- Extract `user.id` from the verified token. + +### Token Format 2: API Key + +``` +Authorization: Bearer dp_a1b2c3d4e5f6... +``` + +- Detect by the `dp_` prefix. +- SHA-256 hash the key, look up in `api_keys` table. +- Extract `user_id` from the matched row. +- Reject if no match or if the key is revoked/expired. + +### Middleware Behavior + +``` +function authenticateDesktop(request): + token = request.headers["Authorization"]?.replace("Bearer ", "") + + if not token: + return 401 { error: "missing_token", message: "Authorization header required" } + + if token.startsWith("dp_"): + user_id = lookupApiKey(sha256(token)) + if not user_id: + return 401 { error: "invalid_api_key", message: "API key not found or revoked" } + else: + user = supabase.auth.getUser(token) + if error: + return 401 { error: "invalid_token", message: "JWT expired or invalid" } + user_id = user.id + + // Attach user_id to request context for handler use +``` + +### Authenticated Context + +Every handler receives a resolved `user_id: string` (UUID). Handlers never deal with token parsing directly. + +--- + +## Endpoints: Auth + +### `GET /api/desktop/me` + +Validate the token and return normalized user information. + +**Auth:** Required + +**Request:** No parameters. + +**Response `200`:** + +```json +{ + "id": "uuid", + "email": "user@example.com", + "display_name": "Jane Smith", + "avatar_url": "https://lh3.googleusercontent.com/...", + "created_at": "2025-06-15T10:30:00Z" +} +``` + +**Field sources:** +- `id`: `auth.users.id` +- `email`: `auth.users.email` +- `display_name`: `auth.users.raw_user_meta_data->>'full_name'` (fall back to email prefix if null) +- `avatar_url`: `auth.users.raw_user_meta_data->>'avatar_url'` (nullable) +- `created_at`: `auth.users.created_at` + +**Errors:** +- `401` if token is invalid (see middleware) + +--- + +### `POST /api/desktop/auth/pair` + +Create a new pairing session. Desktop displays the resulting QR code. The user scans it with their phone or clicks the URL in their browser (where they are already logged in to Palette). + +**Auth:** Not required (this is how unauthenticated Desktop instances initiate login). + +**Request body:** None. + +**Response `201`:** + +```json +{ + "pairing_code": "A1B2C3", + "qr_url": "https://directors-palette-v2.vercel.app/pair/A1B2C3", + "expires_at": "2026-03-06T12:15:00Z" +} +``` + +**Implementation notes:** +- `pairing_code`: 6 alphanumeric characters, uppercase. Unique, not guessable. +- Store in a `pairing_sessions` table (or KV/cache): `code`, `status` (pending/completed/expired), `access_token` (null until completed), `created_at`, `expires_at`. +- Expires after 10 minutes. +- Consider using Supabase table or Vercel KV. Either works. + +**Errors:** +- `429` if rate limited (max 5 pairing requests per IP per minute) + +--- + +### `GET /api/desktop/auth/pair/[code]` + +Poll the status of a pairing session. Desktop calls this every 2-3 seconds until status is `completed` or `expired`. + +**Auth:** Not required. + +**Request params:** +- `code` (path parameter): The 6-character pairing code. + +**Response `200` (pending):** + +```json +{ + "status": "pending", + "expires_at": "2026-03-06T12:15:00Z" +} +``` + +**Response `200` (completed):** + +```json +{ + "status": "completed", + "access_token": "eyJhbGciOiJIUzI1NiIs...", + "refresh_token": "v1.MjQ1..." +} +``` + +**Response `200` (expired):** + +```json +{ + "status": "expired" +} +``` + +**Implementation notes:** +- Once status is `completed` and the tokens have been returned, immediately delete or invalidate the pairing session. Tokens must only be retrievable once. +- Return `expired` for codes past their `expires_at` or that have already been consumed. + +**Errors:** +- `404` if code does not exist: `{ "error": "not_found", "message": "Pairing code not found" }` + +--- + +### `POST /api/desktop/auth/pair/[code]/complete` + +Called from the Palette web app (or mobile) by an authenticated user to complete a pairing session. This is what "approves" the Desktop login. + +**Auth:** Required (the user completing the pairing must be logged in to Palette). + +**Request params:** +- `code` (path parameter): The 6-character pairing code. + +**Request body:** None. The user identity comes from the authenticated session. + +**Response `200`:** + +```json +{ + "status": "completed" +} +``` + +**Implementation notes:** +- Look up the pairing session by `code`. Verify it is still `pending` and not expired. +- Generate a new Supabase session/token pair for the authenticated user (or reuse the current session tokens). +- Store `access_token` and `refresh_token` in the pairing session row. Set status to `completed`. +- The next poll from Desktop on `GET /api/desktop/auth/pair/[code]` will pick up the tokens. + +**Errors:** +- `404` if code does not exist or is expired +- `409` if code is already completed: `{ "error": "already_completed", "message": "This pairing code has already been used" }` + +--- + +## Endpoints: Gallery + +### `GET /api/desktop/gallery` + +Paginated list of the user's gallery items. + +**Auth:** Required + +**Query parameters:** + +| Param | Type | Default | Description | +|-------|------|---------|-------------| +| `page` | integer | `1` | Page number (1-indexed) | +| `per_page` | integer | `50` | Items per page. Max `100`. | +| `type` | string | (all) | Filter by `generation_type`. Values: `image`, `video`, `audio`, or omit for all. | +| `folder_id` | string | (none) | Filter by folder. Omit for all folders. | +| `sort` | string | `created_at_desc` | Sort order. Values: `created_at_desc`, `created_at_asc`. | + +**Response `200`:** + +```json +{ + "items": [ + { + "id": "uuid", + "generation_type": "image", + "status": "completed", + "public_url": "https://xyz.supabase.co/storage/v1/object/public/directors-palette/generations/uid/image_abc.png", + "file_size": 2048576, + "mime_type": "image/png", + "metadata": { "prompt": "a sunset over mountains", "model": "flux-schnell" }, + "folder_id": "uuid-or-null", + "created_at": "2026-03-01T14:22:00Z", + "updated_at": "2026-03-01T14:22:00Z" + } + ], + "pagination": { + "page": 1, + "per_page": 50, + "total_items": 237, + "total_pages": 5 + } +} +``` + +**Notes:** +- Only return items where `user_id` matches the authenticated user. +- Exclude items with `status = 'error'` unless explicitly requested (future filter). +- `public_url` is the Supabase public bucket URL. No signed URLs needed. + +**Errors:** +- `400` if `per_page` > 100 or `type` is invalid + +--- + +### `POST /api/desktop/gallery/upload` + +Upload a locally generated asset to the user's cloud gallery. + +**Auth:** Required + +**Request:** `multipart/form-data` + +| Field | Type | Required | Description | +|-------|------|----------|-------------| +| `file` | file | Yes | The file to upload. Max 50 MB. | +| `generation_type` | string | Yes | `image`, `video`, or `audio` | +| `metadata` | string (JSON) | No | Stringified JSON with prompt, model, parameters, etc. | + +**Response `201`:** + +```json +{ + "id": "uuid", + "generation_type": "video", + "status": "completed", + "public_url": "https://xyz.supabase.co/storage/v1/object/public/directors-palette/generations/uid/video_abc.mp4", + "file_size": 15728640, + "mime_type": "video/mp4", + "metadata": { "prompt": "a cat playing piano", "model": "ltx-video" }, + "created_at": "2026-03-06T09:00:00Z" +} +``` + +**Implementation notes:** +- Upload to Supabase storage bucket `directors-palette` at path `generations/{userId}/{type}_{uniqueId}.{ext}` +- Insert a row into the `gallery` table with `status = 'completed'`, `user_id`, `public_url`, `file_size`, `mime_type`, `generation_type`, and `metadata`. +- Enforce the 500-image cap. If the user is at the limit, return `409`. +- For video uploads, set `expires_at` to 7 days from now (matching existing policy). + +**Errors:** +- `400` if file is missing or exceeds 50 MB +- `400` if `generation_type` is invalid +- `409` if user has reached the 500-item gallery cap: `{ "error": "gallery_limit_reached", "message": "Gallery limit of 500 items reached. Delete items to upload more." }` +- `413` if file exceeds 50 MB (may be caught at infra level before handler) + +--- + +### `DELETE /api/desktop/gallery/[id]` + +Delete a gallery item. + +**Auth:** Required + +**Request params:** +- `id` (path parameter): Gallery item UUID. + +**Response `200`:** + +```json +{ + "deleted": true +} +``` + +**Implementation notes:** +- Verify `gallery.user_id` matches the authenticated user before deleting. +- Delete the file from Supabase storage as well as the database row. +- If the gallery item is referenced by a `storyboard_characters.reference_gallery_id` or `style_guides.reference_gallery_id` or `reference.gallery_id`, either cascade-null those FKs or return an error. Recommendation: cascade-null and delete. + +**Errors:** +- `404` if item does not exist or does not belong to user +- `404` (not `403`) for items belonging to other users -- do not reveal existence + +--- + +## Endpoints: Library + +### `GET /api/desktop/library/characters` + +List all characters across all of the user's storyboards. + +**Auth:** Required + +**Query parameters:** + +| Param | Type | Default | Description | +|-------|------|---------|-------------| +| `storyboard_id` | string | (all) | Filter to a specific storyboard | + +**Response `200`:** + +```json +{ + "characters": [ + { + "id": "uuid", + "storyboard_id": "uuid", + "name": "Detective Harlow", + "description": "Tall, mid-40s, weathered trench coat, sharp eyes", + "has_reference": true, + "reference_image_url": "https://xyz.supabase.co/storage/v1/object/public/directors-palette/generations/uid/image_ref.png", + "mentions": 12, + "metadata": {} + } + ] +} +``` + +**Implementation notes:** +- Join `storyboard_characters` through their parent storyboards to filter by `user_id`. +- For `reference_image_url`: if `reference_gallery_id` is not null, join to `gallery` and return `public_url`. Otherwise null. +- Sort by `mentions` descending (most-used characters first). + +**Errors:** +- `400` if `storyboard_id` is provided but does not belong to the user + +--- + +### `GET /api/desktop/library/styles` + +List user's custom style guides plus the 9 preset styles. + +**Auth:** Required + +**Response `200`:** + +```json +{ + "styles": [ + { + "id": "uuid", + "name": "Noir Cinematic", + "description": "High contrast black and white...", + "style_prompt": "noir cinematic style, high contrast, deep shadows...", + "reference_image_url": "https://...", + "is_preset": false, + "metadata": {} + }, + { + "id": "preset_photorealistic", + "name": "Photorealistic", + "description": "...", + "style_prompt": "photorealistic, highly detailed...", + "reference_image_url": null, + "is_preset": true, + "metadata": {} + } + ] +} +``` + +**Implementation notes:** +- Query `style_guides` where `user_id` matches. +- Append the 9 hardcoded preset styles with `is_preset: true` and synthetic IDs prefixed `preset_`. +- For `reference_image_url`: join to `gallery` via `reference_gallery_id` if present. + +--- + +### `GET /api/desktop/library/references` + +List reference images with optional category filter. + +**Auth:** Required + +**Query parameters:** + +| Param | Type | Default | Description | +|-------|------|---------|-------------| +| `category` | string | (all) | Filter by category: `people`, `places`, `props`, `layouts` | + +**Response `200`:** + +```json +{ + "references": [ + { + "id": "uuid", + "category": "people", + "tags": ["protagonist", "female", "young"], + "gallery_item": { + "id": "uuid", + "public_url": "https://...", + "mime_type": "image/png", + "file_size": 1024000 + } + } + ] +} +``` + +**Implementation notes:** +- Join `reference` to `gallery` via `gallery_id` FK. +- Filter to items where the gallery item's `user_id` matches the authenticated user. +- If `category` is provided, filter by `reference.category`. +- Sort by `reference.id` descending (newest first). + +**Errors:** +- `400` if `category` is not one of the four valid values + +--- + +### `GET /api/desktop/library/brands` + +List user's brands. + +**Auth:** Required + +**Response `200`:** + +```json +{ + "brands": [ + { + "id": "uuid", + "name": "Acme Corp", + "slug": "acme-corp", + "logo_url": "https://...", + "tagline": "Innovation for everyone", + "industry": "Technology", + "audience": { "age_range": "25-45", "interests": ["tech", "productivity"] }, + "voice": { "tone": "professional", "personality": "authoritative" }, + "visual_identity": { "primary_color": "#2563EB", "font": "Inter" }, + "visual_style": { "mood": "clean", "photography": "lifestyle" }, + "music": { "genre": "electronic", "tempo": "moderate" }, + "brand_guide_image_url": "https://..." + } + ] +} +``` + +**Implementation notes:** +- Query `brands` table. Filter mechanism depends on how brands are associated with users. If there is a `user_id` column, filter by it. If brands are shared/org-level, return all brands the user has access to. +- Map `audience_json` to `audience`, `voice_json` to `voice`, etc. (drop the `_json` suffix in the API response for cleaner naming). + +--- + +## Endpoints: Credits + +### `GET /api/desktop/credits` + +Return the user's credit balance and model pricing. + +**Auth:** Required + +**Response `200`:** + +```json +{ + "balance_cents": 1500, + "lifetime_purchased_cents": 5000, + "lifetime_used_cents": 3500, + "pricing": { + "image": 20, + "video": 40, + "audio": 15, + "text": 3 + } +} +``` + +**Implementation notes:** +- `balance_cents`, `lifetime_purchased_cents`, `lifetime_used_cents` from `user_credits` table and aggregated from `credit_transactions`. +- `pricing` from `model_pricing` table. Map to the four generation types. Values are in cents. +- This is a superset of the existing `GET /api/credits` response, repackaged for clarity. + +--- + +### `POST /api/desktop/credits/check` + +Check whether the user can afford a specific generation type. Does not deduct. + +**Auth:** Required + +**Request body:** + +```json +{ + "generation_type": "video" +} +``` + +| Field | Type | Required | Description | +|-------|------|----------|-------------| +| `generation_type` | string | Yes | `image`, `video`, `audio`, or `text` | + +**Response `200`:** + +```json +{ + "can_afford": true, + "cost_cents": 40, + "balance_cents": 1500, + "balance_after_cents": 1460 +} +``` + +**Response `200` (insufficient balance):** + +```json +{ + "can_afford": false, + "cost_cents": 40, + "balance_cents": 30, + "balance_after_cents": -10 +} +``` + +**Notes:** +- This is read-only. It does not deduct credits. +- Desktop uses this before starting a generation to show the user a confirmation or a "top up credits" prompt. + +**Errors:** +- `400` if `generation_type` is invalid + +--- + +## Endpoints: Prompts + +### `POST /api/desktop/prompt/enhance` + +Proxy to the existing prompt expander. + +**Auth:** Required + +**Request body:** + +```json +{ + "prompt": "a cat sitting on a windowsill", + "level": "2x", + "style": "cinematic" +} +``` + +| Field | Type | Required | Description | +|-------|------|----------|-------------| +| `prompt` | string | Yes | The original prompt to enhance. Max 2000 characters. | +| `level` | string | No | Enhancement level: `2x` (default) or `3x` | +| `style` | string | No | Director style to apply (e.g., `cinematic`, `anime`, `documentary`). Omit for general enhancement. | + +**Response `200`:** + +```json +{ + "original_prompt": "a cat sitting on a windowsill", + "enhanced_prompt": "A fluffy orange tabby cat perched gracefully on a sun-drenched windowsill, warm golden hour light streaming through sheer curtains, dust motes floating in the air, shallow depth of field, intimate documentary photography style", + "level": "2x", + "style": "cinematic" +} +``` + +**Implementation notes:** +- This is a thin proxy over the existing `POST /api/prompt-expander` logic. Extract the shared logic into a function callable from both routes, or have this route call the existing one internally. +- Deduct text-generation credits (3 cents) per call. + +**Errors:** +- `400` if `prompt` is empty or exceeds 2000 characters +- `400` if `level` is not `2x` or `3x` +- `402` if user has insufficient credits: `{ "error": "insufficient_credits", "message": "Not enough credits for prompt enhancement", "cost_cents": 3, "balance_cents": 1 }` + +--- + +## Error Format + +All errors follow a consistent shape: + +```json +{ + "error": "error_code_snake_case", + "message": "Human-readable description of what went wrong" +} +``` + +### Standard Error Codes + +| HTTP Status | Error Code | When | +|-------------|-----------|------| +| `400` | `bad_request` | Malformed input, invalid parameters | +| `401` | `missing_token` | No Authorization header | +| `401` | `invalid_token` | JWT expired or invalid | +| `401` | `invalid_api_key` | API key not found or revoked | +| `402` | `insufficient_credits` | User cannot afford the operation | +| `404` | `not_found` | Resource does not exist (or user lacks access) | +| `409` | `already_completed` | Pairing code already used | +| `409` | `gallery_limit_reached` | 500-item gallery cap hit | +| `413` | `file_too_large` | Upload exceeds 50 MB | +| `429` | `rate_limited` | Too many requests | +| `500` | `internal_error` | Unexpected server error | + +--- + +## Open Questions + +These are decisions for the Palette team. Desktop will adapt to whatever you choose. + +1. **Pairing session storage:** Supabase table vs. Vercel KV? KV is simpler for ephemeral data with TTL. A Supabase table is fine too. + +2. **Brands ownership:** The `brands` table schema provided does not include a `user_id` column. How are brands associated with users? Is there a join table? Desktop needs to know which brands to return. + +3. **API key management UI:** Desktop assumes users create/revoke API keys in the Palette web settings page. Does the Palette team want to add this to an existing settings view, or create a new one? + +4. **Token refresh:** When Desktop receives a JWT via pairing or deep link, the JWT will eventually expire. Should Desktop call a `POST /api/desktop/auth/refresh` endpoint with the refresh token, or should it just re-trigger the pairing/login flow? Recommendation: add a refresh endpoint, but it is not blocking for Phase 1. + +5. **Prompt enhancement styles:** What are the valid `style` values for the prompt expander? Desktop needs an enum or a list endpoint. For now, Desktop will send freeform strings and let the expander handle unknowns gracefully. + +6. **Video expiry enforcement:** Desktop-uploaded videos get `expires_at` set to 7 days. Is there an existing cron/scheduled function that cleans up expired items, or does Desktop need to handle this awareness client-side only? diff --git a/docs/palette-credits-handoff.md b/docs/palette-credits-handoff.md new file mode 100644 index 00000000..64be3f0c --- /dev/null +++ b/docs/palette-credits-handoff.md @@ -0,0 +1,223 @@ +# Director's Desktop — Credits & Cost Display Integration + +**Date:** 2026-03-09 +**From:** Desktop team +**To:** Palette engineering team +**Status:** Needed for v1 launch — users need to see costs before generating + +--- + +## The Problem + +Desktop users generate videos and images using cloud APIs that cost credits. Right now they have **no visibility into how much anything costs** before they hit Generate, and **no running balance** visible while they work. + +We need three things from Palette to fix this: + +1. A way to fetch the user's current credit balance +2. A pricing table so we can show "this will cost X credits" before generation +3. A pre-check endpoint so we can block generation when credits are insufficient + +--- + +## What Desktop Needs + +### 1. `GET /api/desktop/credits` — Balance + Pricing + +We already have this in our spec (`docs/palette-api-spec.md` lines 595-620). Here's the contract: + +**Auth:** Bearer token (JWT or `dp_` API key) + +**Response `200`:** + +```json +{ + "balance_cents": 1500, + "lifetime_purchased_cents": 5000, + "lifetime_used_cents": 3500, + "pricing": { + "video_t2v": 40, + "video_i2v": 40, + "video_seedance": 80, + "image": 20, + "image_edit": 20, + "audio": 15, + "text_enhance": 3 + } +} +``` + +**Why we need granular pricing keys:** Desktop supports multiple generation types and models. A flat "video: 40" doesn't work when Seedance costs more than LTX. We need to show the user "This Seedance generation will cost 80 credits" vs "This LTX generation will cost 40 credits." + +**Pricing keys we need (at minimum):** + +| Key | What it covers | Desktop UI context | +|-----|---------------|-------------------| +| `video_t2v` | Text-to-video (LTX local/API) | User types prompt, hits Generate | +| `video_i2v` | Image-to-video (LTX) | User uploads image + prompt | +| `video_seedance` | Seedance cloud video | User selects Seedance model | +| `image` | Text-to-image (ZIT) | User generates image | +| `image_edit` | Image editing/img2img | User edits existing image | +| `text_enhance` | Prompt enhancement | User clicks magic wand | + +If your pricing is simpler (same cost for all video types), just return the same value for all video keys. We'd rather have too many keys than not enough. + +**How Desktop uses this:** +- We poll this endpoint when the app starts and after each generation +- We display balance in the header bar: `Credits: $15.00` +- We show cost before generation: `This will cost $0.40 (balance: $15.00)` + +--- + +### 2. `POST /api/desktop/credits/check` — Pre-Generation Check + +Already in our spec (lines 624-670). Before every cloud generation, Desktop calls this to verify the user can afford it. + +**Request:** + +```json +{ + "generation_type": "video_seedance", + "count": 1 +} +``` + +| Field | Type | Required | Description | +|-------|------|----------|-------------| +| `generation_type` | string | Yes | One of the pricing keys above | +| `count` | int | No | Number of generations (default 1, for bulk) | + +**Response `200` (can afford):** + +```json +{ + "can_afford": true, + "cost_cents": 80, + "balance_cents": 1500, + "balance_after_cents": 1420 +} +``` + +**Response `200` (insufficient):** + +```json +{ + "can_afford": false, + "cost_cents": 80, + "balance_cents": 30, + "balance_after_cents": -50, + "top_up_url": "https://directorspalette.com/settings/billing" +} +``` + +**Why `count`?** Desktop supports bulk generation — a user might queue 5 Seedance videos at once. We need to check if they can afford all 5 before submitting, not fail on video #3. + +**Why `top_up_url`?** When the user can't afford it, we want to show a "Top up credits" button that opens their browser to Palette's billing page. If this URL is static, we can hardcode it. If it varies per user, return it here. + +--- + +### 3. `POST /api/desktop/credits/deduct` — After Generation Completes + +When a generation finishes successfully, Desktop tells Palette to deduct the credits. + +**Request:** + +```json +{ + "generation_type": "video_seedance", + "count": 1, + "metadata": { + "model": "seedance-1.5-pro", + "duration_seconds": 5, + "resolution": "720p", + "job_id": "abc123" + } +} +``` + +**Response `200`:** + +```json +{ + "deducted_cents": 80, + "balance_cents": 1420 +} +``` + +**Response `402` (insufficient — race condition):** + +```json +{ + "error": "insufficient_credits", + "balance_cents": 30 +} +``` + +**Why deduct after, not before?** Generations can fail or be cancelled. We don't want to charge for failed jobs. The check endpoint gates the UI, the deduct endpoint charges after success. + +**Why metadata?** So you have an audit trail of what was generated. Optional fields, include whatever is useful for your analytics. + +--- + +## What Desktop Has Already Built + +| Component | Status | Notes | +|-----------|--------|-------| +| `GET /api/sync/credits` backend route | Built | Calls Palette, returns balance to frontend | +| `PaletteSyncClient.get_credits()` | Built | HTTP client method, handles errors gracefully | +| Credit display on Home screen | Built | Shows `Credits: $X.XX` when connected | +| Credit display in Settings | Built | Shows balance in Palette connection section | +| Fake test implementation | Built | Returns `{"balance": 5000, "currency": "credits"}` | + +**What we'll build once you deploy these endpoints:** + +1. Credit balance in the generation header bar (always visible while working) +2. Cost estimate shown on the Generate button: `Generate ($0.40)` +3. Pre-generation check — "Insufficient credits" dialog with "Top up" button +4. Post-generation deduction +5. Bulk generation cost calculation: `5 videos x $0.80 = $4.00` + +--- + +## Questions We Need Answered + +1. **Are the credit endpoints deployed?** The spec has been written since Phase 7 but we don't know if they're live. + +2. **Is pricing per-model or flat?** Does a Seedance video cost the same as an LTX video? Does resolution/duration affect price? + +3. **Who deducts credits?** Two options: + - **Option A (we prefer):** Desktop calls `/credits/deduct` after successful generation. Desktop controls when charges happen. + - **Option B:** Palette's prompt-expander and other proxied endpoints auto-deduct. Desktop only needs to read balance. + + Tell us which model you're using so we build accordingly. + +4. **What's the `top_up_url`?** Where do users go to buy more credits? Is it always `https://directorspalette.com/settings/billing` or something else? + +5. **Free tier / trial credits?** Do new users get any free credits? How many? We need to know for onboarding messaging. + +6. **Local generations cost credits?** When a user generates video on their own GPU (LTX local), does that cost Palette credits? Or only cloud API generations? + +--- + +## How to Test + +Once endpoints are deployed: + +```bash +# Get balance + pricing +curl -H "Authorization: Bearer dp_..." \ + https://directorspalette.com/api/desktop/credits + +# Check if user can afford a Seedance video +curl -X POST -H "Authorization: Bearer dp_..." \ + -H "Content-Type: application/json" \ + -d '{"generation_type": "video_seedance", "count": 1}' \ + https://directorspalette.com/api/desktop/credits/check + +# Deduct after successful generation +curl -X POST -H "Authorization: Bearer dp_..." \ + -H "Content-Type: application/json" \ + -d '{"generation_type": "video_seedance", "count": 1}' \ + https://directorspalette.com/api/desktop/credits/deduct +``` + +Let us know when it's live and we'll wire it up on our side. diff --git a/docs/palette-team-handoff.md b/docs/palette-team-handoff.md new file mode 100644 index 00000000..971faee6 --- /dev/null +++ b/docs/palette-team-handoff.md @@ -0,0 +1,169 @@ +# Director's Desktop — Palette Team Handoff + +**Date:** 2026-03-08 +**From:** Desktop team +**To:** Palette engineering team +**Status:** Blocking — Desktop auth is broken without this + +--- + +## What's happening + +Director's Desktop has a Settings screen where users connect their Palette account. Right now, **it doesn't work**. Users enter their `dp_` API key, hit Save, and nothing happens — it silently fails and shows "Not connected." + +We've built a workaround (email/password login via Supabase directly), but `dp_` API keys — which is what users actually have — won't work until you deploy one endpoint. + +--- + +## What we need from you (priority order) + +### 1. `GET /api/desktop/me` (BLOCKING — do this first) + +This is the only endpoint that is currently blocking us. Everything else can wait. + +**What it does:** Validates a token and returns user info. + +**Auth:** Required — `Authorization: Bearer {token}` + +The token will be either: +- A Supabase JWT (from `supabase.auth.getSession()`) +- A `dp_` API key (from the `api_keys` table) + +**Your auth middleware needs to handle both formats:** + +``` +if token starts with "dp_": + hash = sha256(token) + lookup hashed_key in api_keys table + get user_id from matched row + reject if no match or revoked +else: + user = supabase.auth.getUser(token) + reject if invalid/expired + user_id = user.id +``` + +**Response `200`:** + +```json +{ + "id": "uuid", + "email": "user@example.com", + "display_name": "Jane Smith", + "avatar_url": "https://lh3.googleusercontent.com/..." +} +``` + +**Field sources:** +- `id`: `auth.users.id` +- `email`: `auth.users.email` +- `display_name`: `auth.users.raw_user_meta_data->>'full_name'` (fall back to email prefix) +- `avatar_url`: `auth.users.raw_user_meta_data->>'avatar_url'` (nullable) + +**Error responses:** +- `401` if token missing: `{ "error": "missing_token" }` +- `401` if token invalid: `{ "error": "invalid_token" }` +- `401` if API key not found: `{ "error": "invalid_api_key" }` + +--- + +### 2. CORS on `/api/desktop/*` + +Desktop's Python backend (running on `localhost:8000`) makes HTTP requests to Palette. You need CORS headers on every `/api/desktop/*` response, including OPTIONS preflight: + +``` +Access-Control-Allow-Origin: * +Access-Control-Allow-Methods: GET, POST, PUT, DELETE, OPTIONS +Access-Control-Allow-Headers: Content-Type, Authorization +Access-Control-Max-Age: 86400 +``` + +We're using `*` for origin because the request comes from a local Python server, not a browser. If you want to lock it down, `http://localhost:8000` works, but `*` is simpler and fine for server-to-server. + +For OPTIONS requests, return `204 No Content` with the above headers. + +**In Next.js App Router**, the simplest approach is a middleware that matches `/api/desktop/:path*`. + +--- + +### 3. API key creation UI (nice-to-have, not blocking) + +Users need a way to create `dp_` API keys from the Palette web UI. If this doesn't exist yet, users can still use email/password login from Desktop (we built that). But eventually they'll want API keys for a set-it-and-forget-it connection. + +A simple settings page section: +- "Create API Key" button → generates `dp_` + 32 hex chars +- Show the key once, store SHA-256 hash in `api_keys` table +- List existing keys with revoke option + +--- + +## What Desktop has already built + +Here's the full integration infrastructure on our side, ready to go: + +| Feature | Status | What it does | +|---------|--------|-------------| +| Email/password login | Working now | Calls Supabase `/auth/v1/token` directly, stores JWT + refresh token | +| Token auto-refresh | Working now | Refreshes expired JWTs using stored refresh token | +| `dp_` API key auth | Built, waiting on you | Calls `GET /api/desktop/me` — currently gets 404 | +| Settings UI | Working now | Tabbed auth: "Login with Email" + "API Key" tabs, error feedback, disconnect | +| Gallery sync client | Built, waiting on you | Calls `GET /api/desktop/gallery`, `DELETE /api/desktop/gallery/{id}` | +| Library sync client | Built, waiting on you | Calls characters, styles, references endpoints | +| Credits check | Built, waiting on you | Calls `GET /api/desktop/credits` | +| Prompt enhancement | Built, waiting on you | Calls `POST /api/desktop/prompt/enhance` | + +**The only thing blocking a user-facing release is #1 (`GET /api/desktop/me`) and #2 (CORS).** Everything else can be phased in. + +--- + +## How Desktop calls you + +All requests go from Desktop's Python backend (NOT from the browser/Electron renderer): + +``` +User clicks "Connect" in Desktop + → Electron frontend calls localhost:8000/api/sync/connect + → Python backend calls https://directorspal.com/api/desktop/me + → With header: Authorization: Bearer dp_a1b2c3d4... + → Palette validates, returns user info + → Desktop stores the key and shows "Connected as user@example.com" +``` + +**Base URL we're hitting:** `https://directorspal.com` + +If this isn't right (staging URL, different domain, etc.), let us know and we'll update. + +--- + +## The full API spec + +The complete endpoint spec is already written: **`docs/palette-api-spec.md`** in this repo. It covers all 8 phases (auth, gallery, library, credits, prompts) with exact request/response shapes, error codes, and implementation notes. + +But **you don't need to read all of that right now.** Just build `/api/desktop/me` + CORS and we're unblocked. + +--- + +## How to test it + +Once you deploy `/api/desktop/me`: + +```bash +# Test with a Supabase JWT +curl -H "Authorization: Bearer eyJhbG..." \ + https://directorspal.com/api/desktop/me + +# Test with a dp_ API key +curl -H "Authorization: Bearer dp_a1b2c3d4e5f6..." \ + https://directorspal.com/api/desktop/me + +# Both should return: +# { "id": "...", "email": "...", "display_name": "...", "avatar_url": "..." } +``` + +Then tell us it's live and we'll test from Desktop. + +--- + +## Questions? + +If anything is unclear or you need to change the contract (different URL, different response shape, etc.), reach out before building. We can adapt on our side quickly. diff --git a/docs/palette-team-questionnaire.md b/docs/palette-team-questionnaire.md new file mode 100644 index 00000000..673fe40a --- /dev/null +++ b/docs/palette-team-questionnaire.md @@ -0,0 +1,178 @@ +# Director's Palette — Integration Questionnaire + +**Context:** We're building integration between Director's Desktop (Electron app for local AI video/image generation) and Director's Palette (web/mobile). Desktop needs a set of API endpoints from Palette to enable auth, gallery sync, library sync, and credits. This questionnaire helps us write the exact API spec for your team. + +Please answer each section. If something doesn't exist yet, just say "not built yet" — that's useful info too. + +--- + +## 1. Authentication & Users + +1a. **What Supabase Auth providers are enabled?** (check all that apply) +- [ ] Email/password +- [ ] Google OAuth +- [ ] Apple OAuth +- [ ] Other: ___________ + +1b. **Is there a `profiles` or `users` table** beyond Supabase's built-in `auth.users`? If yes, what columns does it have? (We need: display name, email, avatar URL at minimum) + +1c. **Do you have an API keys system?** Can users generate long-lived API keys (like "Developer API Keys" in settings)? If yes: +- What table stores them? +- What columns? (key hash, user_id, label, created_at, etc.) +- How are they validated? (lookup by hash? prefix matching?) + +1d. **Is there any existing endpoint that validates a token and returns user info?** (e.g., `GET /api/me` or similar) If yes, what's the route and response shape? + +1e. **What does a Supabase access token look like in your system?** When a user logs in via browser, do you use: +- Supabase JWT (from `supabase.auth.getSession()`) +- A custom session token +- Something else + +--- + +## 2. Gallery / Generated Assets + +2a. **What Supabase Storage bucket(s) store generated images/videos?** (bucket name(s)) + +2b. **Is there a database table that indexes gallery items?** If yes: +- Table name? +- Key columns? (id, user_id, filename, file_path/storage_key, type, prompt, model_name, created_at, etc.) +- Any RLS policies? (user can only see their own items?) + +2c. **What's the current gallery cap?** (The design doc mentions 500 images — is that enforced? Where?) + +2d. **Are there thumbnails?** Auto-generated, or same file served at different sizes? + +2e. **How are storage URLs generated?** +- Public bucket with direct URLs? +- Private bucket with signed URLs? (if so, what expiry?) +- Through an API proxy? + +2f. **What's the max file size for uploads?** Any format restrictions? + +--- + +## 3. Library — Characters + +3a. **Is there a `characters` table?** If yes: +- Table name and key columns? +- How are reference images stored? (array of storage paths? separate join table?) +- Is there RLS? (user sees only their characters?) + +3b. **What fields does a character have?** (name, role, description, reference_images, etc.) + +3c. **Are characters tied to a specific project/brand, or global to the user?** + +--- + +## 4. Library — Styles & Brands + +4a. **Is there a `styles` or `brands` table?** If yes: +- Table name and key columns? +- What fields? (name, description, reference_image, color palette, fonts, etc.) + +4b. **Are "style guides" a thing in Palette?** (the 3x3 grid generation feature) If yes, where are they stored? + +4c. **Is there a brand identity system?** (logos, color palettes, font selections tied to a brand) + +--- + +## 5. Library — References + +5a. **Is there a `references` or `reference_images` table?** If yes: +- Table name and key columns? +- Categories? (people, places, props, other — or different?) + +5b. **Are references shared across projects or per-project?** + +--- + +## 6. Prompts + +6a. **Is there a saved prompts / prompt library feature?** If yes: +- Table name and key columns? +- Fields? (text, tags, category, use_count, etc.) + +6b. **Is there a prompt enhancement/expansion feature?** (rough prompt -> detailed cinematic prompt) If yes, what model/API does it use? + +--- + +## 7. Credits + +7a. **How does the credit system work?** +- Table/column that stores balance? (e.g., `profiles.credits_balance` or separate `credits` table?) +- Are credits purchased? Subscription-based? Free tier? +- What costs credits? (image generation, video generation, prompt enhancement?) + +7b. **Is there an endpoint to check credit balance?** Route and response shape? + +7c. **Is there an endpoint to deduct credits?** Or do credits deduct automatically when a generation job runs? + +7d. **Credit costs per action:** +- Image generation: ___ credits +- Video generation: ___ credits +- Prompt enhancement: ___ credits +- Other: ___________ + +--- + +## 8. Existing API Routes + +8a. **List any existing API routes that might be relevant** to this integration. For each, provide: +- Route (method + path) +- What it does +- Auth required? (how?) +- Response shape (or link to code) + +Common ones we'd want to know about: +- User profile / me +- Gallery list +- Gallery upload +- Character CRUD +- Style CRUD +- Credits balance +- Prompt library + +8b. **What's your API auth pattern?** +- Bearer token in Authorization header? +- Cookie-based sessions? +- Supabase anon key + user JWT? +- Something else? + +8c. **Is there CORS configured?** Desktop will call from `http://localhost:8000` (backend proxy) — not from a browser directly. + +--- + +## 9. Technical Details + +9a. **What's the production URL for Palette?** (e.g., `https://directorspalette.com`, `https://app.directorspalette.com`, etc.) + +9b. **What framework/stack is the API built on?** (Next.js API routes? Separate backend? Supabase Edge Functions?) + +9c. **Is there a staging/dev environment** we can test against? + +9d. **Any rate limiting on API routes?** If yes, what are the limits? + +--- + +## 10. Desktop-Specific Needs + +These are features we want to build. Tell us if they conflict with anything or if you have preferences: + +10a. **Browser login redirect:** After login on Palette web, we want to redirect to `directorsdesktop://auth/callback?token=XXX`. Is there an existing OAuth callback flow we should hook into, or do you need to add a redirect? + +10b. **QR code pairing:** Desktop shows a QR code, user scans from Palette mobile app. We need a short-lived pairing endpoint. Do you have anything like this, or would it be net-new? + +10c. **"Send to Desktop" from Palette:** A button in Palette web that sends a generation job to the user's Desktop app. Any existing job/queue system we should integrate with, or is this net-new? + +10d. **Gallery upload from Desktop:** When a user generates locally and clicks "Push to Cloud," we upload the file + metadata. Any preferences on how this should work? (multipart form? presigned URL? base64?) + +--- + +## 11. Anything Else + +Anything we should know about the Palette architecture, conventions, or constraints that would affect this integration? Anything planned that might overlap? + +--- + +**Return this completed questionnaire to the Desktop team and we'll produce a detailed API spec for the endpoints we need you to build.** diff --git a/docs/performance-report.md b/docs/performance-report.md new file mode 100644 index 00000000..6cbc0c87 --- /dev/null +++ b/docs/performance-report.md @@ -0,0 +1,228 @@ +# Directors Desktop — Performance Report + +**Hardware:** NVIDIA RTX 4090 24GB VRAM, Windows 11, CUDA 12.9, Driver 576.80 +**Date:** March 9, 2026 +**Backend:** LTX-Video 0.9.7 (ltx-fast model), ZIT image model + +--- + +## Benchmark Results + +### Image Generation (ZIT, Local GPU, Warm) + +| Resolution | Time | +|-------------|-------| +| 1024×1024 | 10s | +| 768×1344 | 18s | + +### Video Generation (ltx-fast, Local GPU, Warm) + +| Resolution | Duration | Frames | Time | Notes | +|------------|----------|--------|------|-------| +| 512p | 2s | 49 | 37s | Baseline | +| 512p | 5s | 121 | 44–84s | Session-dependent (44s prior, 84s this session) | +| 512p | 8s | 193 | 86–100s | Session-dependent (86s prior, 100s this session) | +| 512p | 10s (run A) | 241 | 651.6s (~10.9 min) | Consistent | +| 512p | 10s (run B) | 241 | 650.4s (~10.8 min) | Consistent — within 0.2% of run A | +| 512p | 20s | 481 | ~11,820s (~3.3 hrs) | Timed out at 3hrs; completed ~17min later | +| 720p | 2s | 49 | 39s | | +| 720p | 5s | 121 | 83s | | +| 720p | 8s | 193 | CANCELLED | ~36 min/step, estimated 4+ hours total | +| 1080p | 2s | 49 | 499s (~8.3 min) | | +| 1080p | 5s | 121 | CRASH | OOM during VAE decode after ~40 min | + +### Video Extend (512p, 2s segments) + +Extend was tested with last-frame extraction via ffmpeg, then submitting with `lastFramePath`. +After the 20s generation, the extend base job stalled at 15% inference despite the GPU running at 100%. This indicates post-heavy-load degradation requiring a backend restart. + +### API Generation (Cloud) + +| Model | Type | Params | Time | Notes | +|-------|------|--------|------|-------| +| nano-banana-2 | Image | 1024×1024, 4 steps | ~40s | Cloud API | +| seedance-1.5-pro | Video | 720p 5s | ~2-3 min | Cloud API; rejects 512p (only 480p/720p/1080p) | + +### Cold Start + +| Scenario | Time | +|----------|------| +| First generation after app launch (512p 2s) | ~66s | +| Warm generation (512p 2s, model already loaded) | ~37s | +| Cold start overhead | ~29s (model loading + warmup) | + +--- + +## Scaling Analysis + +### Frame Count vs Time (512p) + +``` +Frames: 49 121 193 241 481 +Time: 37s 84s 100s 651s 11,820s +``` + +The relationship is **highly non-linear**. Key observations: + +1. **49 → 193 frames (4× more):** Time scales ~2.7× (37s → 100s) — roughly linear +2. **193 → 241 frames (1.25× more):** Time scales ~6.5× (100s → 651s) — massive jump +3. **241 → 481 frames (2× more):** Time scales ~18× (651s → 11,820s) — exponential blowup + +The dramatic scaling cliff at ~200 frames suggests the model hits a VRAM or attention computation threshold. Beyond 193 frames, inference shifts from being compute-bound to memory-bound, requiring tiling or sequential processing that dramatically slows throughput. + +### Resolution Scaling (2s / 49 frames) + +``` +512p (960×544): 37s +720p (1280×704): 39s +1080p (1920×1088): 499s +``` + +Similar cliff between 720p and 1080p — the attention computation quadruples but VRAM constraints force a much slower execution path. + +--- + +## Critical Issues Found + +### 1. Non-Linear Scaling Beyond 8s at 512p +- The jump from 8s (100s) to 10s (651s) is a **6.5× increase for only 25% more frames** +- This means anything beyond ~8s at 512p enters an extremely slow regime +- **Impact:** Users will experience unexpectedly long waits for clips >8s + +### 2. 20s Generation is Impractical (~3.3 hours) +- 20s at 512p takes 3+ hours on a high-end RTX 4090 +- **Recommendation:** Either cap duration at 10s with clear warning, or implement frame-chunked generation + +### 3. Post-Heavy-Load GPU Degradation +- After the 3+ hour 20s generation, a simple 2s job stalled at 15% indefinitely +- GPU showed 100% utilization but no progress — likely VRAM fragmentation +- **Impact:** Users may need to restart the app after long generations +- **Fix needed:** Explicit VRAM cleanup (torch.cuda.empty_cache(), gc.collect()) between generations + +### 4. 1080p 5s Crashes with OOM +- 1080p at 5s (121 frames) crashes during VAE decode +- VAE decode requires loading the full frame buffer into VRAM +- **Impact:** 1080p is limited to ≤2s clips (and those take ~8 min) + +### 5. Cancel Doesn't Stop GPU Inference +- `POST /api/queue/cancel/{id}` only marks the job status — the GPU continues working +- A cancelled long generation (e.g., 20s) will still consume GPU for hours +- **Fix needed:** Implement cooperative cancellation with a check in the inference loop + +### 6. Warmup Race Conditions +- Jobs submitted before model warmup completes can fail with shape mismatch errors +- Backend should either queue jobs until warmup is done, or block submission + +--- + +## Optimization Opportunities + +Based on benchmark analysis + research of LTX-Video forks and community projects. + +### Already Implemented + +The codebase already has: SageAttention v1, FP8 quantization (cast), torch.compile, VAE tiling, text encoder CPU offloading, API-based text encoding. + +### HIGH IMPACT — Speed + +1. **TeaCache (Timestep-Aware Caching)** + - Source: [ali-vilab/TeaCache](https://github.com/ali-vilab/TeaCache) + - Skips redundant transformer forward passes by caching outputs at timesteps where output changes minimally + - **Expected speedup: 1.6-2.1×** (training-free, drop-in) + - Could cut 512p 2s from 37s → ~18-23s + +2. **Frame-Chunked Generation for Long Clips** + - Instead of generating all 481 frames at once for 20s, generate in 2-5s chunks and stitch + - Already have the extend/lastFramePath mechanism — automate this internally + - Could reduce 20s from ~3.3 hours to ~5-10 minutes (5× 2s = ~185s) + +### HIGH IMPACT — VRAM / Stability + +3. **FFN Chunked Feedforward** + - Source: [RandomInternetPreson/ComfyUI_LTX-2_VRAM_Memory_Management](https://github.com/RandomInternetPreson/ComfyUI_LTX-2_VRAM_Memory_Management) + - LTX-2 transformer FFN layers expand hidden dim 4×, creating enormous intermediate tensors + - Chunking into 8-16 pieces reduces peak VRAM by up to 8× with zero quality loss + - **Benchmarks on RTX 4090**: 800 frames at 1920×1088 in ~16.5 GB, 900 frames in ~18.5 GB + - **This is likely the fix for 1080p OOM crashes and the 10s nonlinear scaling cliff** + +4. **Aggressive VRAM Cleanup Between Generations** + - Add `torch.cuda.empty_cache()` + `gc.collect()` after each generation completes + - Clear any intermediate tensors held in the pipeline + - This would prevent the post-heavy-load stall issue + +### MEDIUM IMPACT — Speed + +5. **SageAttention 2++ Upgrade** + - Source: [thu-ml/SageAttention](https://github.com/thu-ml/SageAttention) + - Current codebase pins v1.0.6; SageAttention 2++ provides 3.9× speedup over FlashAttention (vs 2.1× for v1) + +6. **FP8 Scaled MM (TensorRT-LLM)** + - Source: upstream ltx-core + - Switch from `QuantizationPolicy.fp8_cast()` to `fp8_scaled_mm()` — uses native FP8 matrix multiplication without upcasting + - Both faster and less memory; available on RTX 40xx + +7. **Guidance Skip Steps** + - `skip_step` param in `MultiModalGuiderParams` skips CFG computation every N steps + - Since guidance requires 2-3× forward passes per step, skipping alternating steps cuts total passes ~30-40% + +8. **Pre-Quantized FP8 Checkpoint** + - Use `Lightricks/LTX-2.3-fp8` from HuggingFace instead of runtime FP8 casting + - Faster load time (no conversion), potentially better quality (calibrated offline) + +### MEDIUM IMPACT — UX + +9. **Cooperative Cancellation** + - Thread a cancellation callback through the denoising loop + - Upstream `ltx-pipelines` denoising functions accept `on_step` callbacks + - Check cancel flag after each timestep — enables immediate cancel vs waiting hours + +10. **Duration Warnings in UI** + - Show estimated time before submission based on resolution + duration + - Warn users when estimated time exceeds 5 minutes + +11. **Resolution Cap Enforcement** + - Prevent 1080p ≥5s (will crash) + - Prevent 720p ≥8s (impractical — would take hours) + +### LOWER PRIORITY + +12. **Memory Profiles (a la Wan2GP)** + - Source: [deepbeepmeep/Wan2GP](https://github.com/deepbeepmeep/Wan2GP) + - Let users choose speed/memory tradeoff (tight → full-VRAM preload) + +13. **Multi-GPU Tensor Parallelism** + - Ring Attention + sequence parallelism for multi-GPU setups + - Niche but enables 600K+ token sequences + +### Priority Matrix for Our Specific Issues + +| Issue | Best Fix | Expected Impact | +|---|---|---| +| 1080p 5s OOM crash | FFN chunking (#3) | Eliminates crash | +| 512p 10s = 651s (nonlinear jump) | FFN chunking (#3) + TeaCache (#1) | 5-10× faster | +| 37s baseline for 512p 2s | TeaCache (#1) + SageAttention 2++ (#5) | ~15-20s | +| 20s = 3.3 hours | Auto-chunking (#2) | ~5-10 min | +| Post-heavy-load stall | VRAM cleanup (#4) | Eliminates stall | +| Cancel doesn't stop GPU | Step-level callback (#9) | Immediate cancel | +| Cold start overhead | FP8 pre-quantized checkpoint (#8) | ~30-50% faster load | + +--- + +## Practical User Guidelines + +Based on these benchmarks, here are the recommended settings for the RTX 4090: + +| Use Case | Recommended Setting | Expected Time | +|----------|-------------------|---------------| +| Quick preview | 512p, 2s | ~37s | +| Standard clip | 512p, 5s | ~1.5 min | +| Longer clip | 512p, 8s | ~1.5-2 min | +| Extended scene | 512p, 2s × 5 (extend chain) | ~3 min | +| High quality short | 720p, 2s | ~39s | +| High quality standard | 720p, 5s | ~1.5 min | +| Maximum quality | 1080p, 2s | ~8 min | +| Quick image | 1024×1024 | ~10s | + +**Avoid:** 512p ≥10s (10+ min), 720p ≥8s (hours), 1080p ≥5s (crash) + +**For longer scenes:** Use the extend feature to chain 2-5s clips instead of generating long durations in one shot. diff --git a/docs/plans/2026-03-06-desktop-palette-integration-v2.md b/docs/plans/2026-03-06-desktop-palette-integration-v2.md new file mode 100644 index 00000000..c8b9aabf --- /dev/null +++ b/docs/plans/2026-03-06-desktop-palette-integration-v2.md @@ -0,0 +1,903 @@ +# Desktop + Palette Integration Design v2 (Final) + +**Date:** 2026-03-06 +**Status:** Final -- incorporates all answers from the Palette team questionnaire +**Replaces:** `2026-03-06-palette-desktop-integration-design.md` + +--- + +## 1. Goal + +Director's Desktop is the local creative powerhouse (GPU rendering, NLE editor, unlimited +local gallery). Director's Palette is the cloud creative hub (storyboards, library, +characters, credits, web/mobile access). This design connects them so a user signed +in to Palette can browse their cloud library from Desktop, push local work to the cloud, +pull cloud assets locally, and use Palette's prompt expander -- all while Desktop remains +fully functional offline. + +--- + +## 2. System Context + +``` +Director's Palette (Next.js 15, Vercel) + Production: directors-palette-v2.vercel.app + Auth: Supabase cookie sessions (web) + API keys for Desktop (dp_xxx, SHA-256 hashed, admin-only today) + DB: Supabase PostgreSQL (gallery, storyboard_characters, + style_guides, reference, user_credits, brands) + Storage: Supabase Storage bucket "directors-palette" (public) + Limits: 500 image cap, videos expire 7 days, 50 MB upload max + No CORS headers today (must be added for /api/desktop/* routes) + No staging environment + | + | HTTPS /api/desktop/* (new routes, to be built on Palette side) + | Authorization: Bearer dp_xxx + v +Director's Desktop (Electron + FastAPI) + Electron main process + - Registers directorsdesktop:// protocol handler + - Stores token in safeStorage + - Passes token to backend via POST /api/sync/connect on startup + FastAPI backend (localhost:8000) + - Proxies all Palette API calls (frontend never calls Palette directly) + - Token held in memory: AppState.app_settings.palette_api_key + - Persisted to settings.json (encrypted at rest via Electron safeStorage) + React frontend + - Calls localhost:8000 only + - Cloud features appear conditionally when connected +``` + +**Key principle:** Desktop works fully offline. Cloud features are additive. + +--- + +## 3. Authentication + +### 3.1 Three Auth Methods + +``` +Method 1: Browser Login (default, best UX) + Desktop click "Sign In" + -> Electron opens system browser to: + directors-palette-v2.vercel.app/auth/desktop-login + -> User logs in via Supabase (email/password, Google OAuth) + -> Palette redirects to: directorsdesktop://auth/callback?token=dp_xxx + -> Electron intercepts protocol, extracts token + -> Electron calls POST localhost:8000/api/sync/connect {token: "dp_xxx"} + -> Backend validates via Palette /api/desktop/auth/validate + -> Settings saved to disk + +Method 2: API Key Paste (power users) + User generates API key in Palette admin panel + -> Copies dp_xxx key + -> Pastes into Desktop Settings > Palette Connection + -> Frontend calls POST /api/sync/connect {token: "dp_xxx"} + -> Same validation flow as above + +Method 3: QR Pairing (future -- mobile) + Desktop generates short-lived pairing code + QR + -> Displays QR in modal + -> Desktop polls Palette: GET /api/desktop/pair/poll?code=XXXX + -> User scans QR from Palette mobile (already logged in) + -> Palette associates code with user, returns token on next poll + -> Desktop receives token, connects +``` + +### 3.2 Token Format and Validation + +- Format: `dp_` prefix + 40 hex chars (e.g., `dp_a1b2c3...`) +- Storage on Palette: SHA-256 hash in `api_keys` table +- No `/api/me` endpoint exists; user info comes from `auth.users.user_metadata` + (display_name, avatar_url) returned by the validate endpoint + +### 3.3 Token Lifecycle + +``` +Electron startup + -> Read encrypted token from safeStorage + -> If token exists, POST /api/sync/connect to backend + -> Backend validates with Palette (caches user info in SyncHandler._cached_user) + -> If validation fails (expired/revoked), set connected=false, clear cache + -> Frontend polls GET /api/sync/status to show auth state + +Sign out + -> POST /api/sync/disconnect + -> Backend clears palette_api_key from AppState + -> Settings saved (key removed from disk) + -> Electron clears safeStorage +``` + +### 3.4 What Palette Must Build + +| Item | Description | +|------|-------------| +| `/api/desktop/auth/validate` | POST, accepts `Authorization: Bearer dp_xxx`, returns user_metadata | +| `/auth/desktop-login` page | Login flow that redirects to `directorsdesktop://auth/callback?token=dp_xxx` | +| Non-admin API key generation | Allow regular users to create API keys (currently admin-only) | +| CORS headers on `/api/desktop/*` | Not strictly needed (Desktop backend proxies), but good practice | + +--- + +## 4. Desktop Backend Routes + +All cloud features are proxied through the Desktop backend. The frontend never calls +Palette directly. This keeps auth tokens server-side and allows offline graceful degradation. + +### 4.1 Existing Routes (enhance) + +| Route | Method | Current | Enhancement | +|-------|--------|---------|-------------| +| `GET /api/sync/status` | GET | Returns `{connected, user}` | Add `credits_balance`, `gallery_count`, `video_expiry_warning` | +| `POST /api/sync/connect` | POST | Stores token, validates | No changes needed | +| `POST /api/sync/disconnect` | POST | Clears token | No changes needed | +| `GET /api/sync/credits` | GET | Returns `{connected, balance}` | Add `pricing` object with per-model costs | + +### 4.2 New Routes + +| Route | Method | Purpose | +|-------|--------|---------| +| `GET /api/sync/gallery` | GET | Proxy paginated gallery list from Palette | +| `POST /api/sync/gallery/upload` | POST | Proxy multipart upload to Palette (50 MB max) | +| `GET /api/sync/gallery/{id}/download` | GET | Download a cloud asset to local outputs dir | +| `GET /api/sync/library/characters` | GET | Proxy character list (flattened from per-storyboard) | +| `GET /api/sync/library/styles` | GET | Proxy style guides + brands | +| `GET /api/sync/library/references` | GET | Proxy references (people/places/props/layouts) | +| `POST /api/sync/prompt/enhance` | POST | Proxy to Palette prompt expander | + +### 4.3 Route Details + +#### GET /api/sync/gallery + +Proxies to Palette `GET /api/desktop/gallery?page=1&per_page=50&type=all`. + +Response: +```json +{ + "items": [ + { + "id": "uuid", + "filename": "gen_001.png", + "url": "https://...supabase.co/storage/v1/object/public/directors-palette/...", + "type": "image", + "size_bytes": 1234567, + "created_at": "2026-03-06T...", + "expires_at": null, + "is_video": false + } + ], + "total": 142, + "page": 1, + "per_page": 50, + "total_pages": 3 +} +``` + +Notes: +- Videos have `expires_at` set (7-day expiry from Palette). +- No thumbnails exist on Palette; Desktop must show full images or generate local thumbnails. +- 500 image cap is enforced server-side on Palette; Desktop shows remaining quota in gallery header. + +#### POST /api/sync/gallery/upload + +Accepts multipart form data with a file field. Proxies to Palette `POST /api/upload-file`. +50 MB max enforced on both Desktop (FastAPI request size limit) and Palette. + +Request: `multipart/form-data` with `file` field. + +Response: +```json +{ + "status": "ok", + "gallery_id": "uuid", + "url": "https://...supabase.co/storage/..." +} +``` + +Error when at cap: +```json +{ + "error": "Gallery full (500/500). Delete items in Palette to make room." +} +``` + +#### GET /api/sync/gallery/{id}/download + +Downloads the cloud asset and saves it to Desktop's local outputs directory. +Returns the local file path so the frontend can display it immediately. + +Response: +```json +{ + "status": "ok", + "local_path": "/path/to/outputs/cloud_abc123.png", + "type": "image" +} +``` + +#### GET /api/sync/library/characters + +Proxies to Palette `GET /api/desktop/library/characters`. + +Palette stores characters per-storyboard (`storyboard_characters` table). The Desktop +API endpoint flattens these into a single list, deduplicating by name. Each character +has a single reference image via a gallery FK. + +Response: +```json +{ + "characters": [ + { + "id": "uuid", + "name": "Maya", + "role": "protagonist", + "description": "A young filmmaker...", + "reference_image_url": "https://...supabase.co/storage/...", + "storyboard_name": "Episode 1", + "source": "cloud" + } + ] +} +``` + +#### GET /api/sync/library/styles + +Proxies to Palette `GET /api/desktop/library/styles`. + +Returns style guides (global to user, table `style_guides`) and brands (table `brands`). +Also includes the 9 hardcoded presets from Palette. + +Response: +```json +{ + "styles": [...], + "brands": [...], + "presets": ["cinematic", "anime", "noir", ...] +} +``` + +#### GET /api/sync/library/references + +Proxies to Palette `GET /api/desktop/library/references`. + +Categories: people, places, props, layouts. Tags searchable via GIN index on Palette. + +Query params: `?category=people&search=sunset` + +Response: +```json +{ + "references": [ + { + "id": "uuid", + "name": "Beach sunset", + "category": "places", + "tags": ["sunset", "ocean", "warm"], + "image_url": "https://...supabase.co/storage/...", + "source": "cloud" + } + ] +} +``` + +#### POST /api/sync/prompt/enhance + +Proxies to Palette `POST /api/prompt-expander`. + +When Desktop is connected to Palette, prompt enhancement uses Palette's GPT-4o-mini +expander (richer, director-style output). When disconnected, falls back to the existing +local Gemini-based `EnhancePromptHandler`. + +Request: +```json +{ + "prompt": "a woman walking in rain", + "level": "2x", + "director_style": "spielberg" +} +``` + +Palette expander supports: +- `level`: "2x" (moderate) or "3x" (maximum expansion) +- `director_style`: optional style influence (e.g., "spielberg", "kubrick", "nolan") + +Response: +```json +{ + "enhanced_prompt": "A determined young woman in a dark trenchcoat...", + "source": "palette" +} +``` + +Fallback (disconnected): +```json +{ + "enhanced_prompt": "...", + "source": "gemini" +} +``` + +### 4.4 Credits Enhancement + +`GET /api/sync/credits` enhanced response: + +```json +{ + "connected": true, + "balance_cents": 4250, + "balance_display": "$42.50", + "pricing": { + "image": 20, + "video": 40 + }, + "unit": "cents" +} +``` + +Credits are only deducted for cloud-based generation (Replicate API calls through Palette). +Local GPU generation is always free. + +--- + +## 5. Data Flow Diagrams + +### 5.1 Auth Flow (Browser Login) + +``` +User Desktop Frontend Desktop Backend Electron Main Palette Web + | | | | | + |--click "Sign In"----->| | | | + | |--IPC: openPaletteLoginPage---------->| | + | | | |--open browser->| + | | | | | + | | | (user logs in via Supabase) + | | | | | + | | | |<--redirect-----| + | | | | directorsdesktop:// + | | | | auth/callback?token=dp_xxx + | | | | | + | | |<--POST /api/sync/connect--------| + | | | {token: "dp_xxx"} | + | | | | | + | | |----validate-----|-------------->| + | | | GET /api/desktop/auth/validate + | | |<---user_metadata-|--------------| + | | | | | + | | |--save settings->| | + | | | |--safeStorage | + | | | | | + | |<--{connected:true}--| | | + |<--show user avatar----| | | | +``` + +### 5.2 Cloud Gallery Browse + Download + +``` +User Desktop Frontend Desktop Backend Palette API + | | | | + |--click Cloud tab----->| | | + | |--GET /api/sync/gallery?page=1----------->| + | | |--GET /api/desktop/gallery------>| + | | |<--{items, total, ...}-----------| + | |<--gallery items----| | + |<--render grid---------| | | + | | | | + |--click download------>| | | + | |--GET /api/sync/gallery/{id}/download---->| + | | |--download file from Supabase--->| + | | |<--binary file data-------------| + | | |--save to outputs/ | + | |<--{local_path}-----| | + |<--show in Local tab---| | | +``` + +### 5.3 Push Local Asset to Cloud + +``` +User Desktop Frontend Desktop Backend Palette API + | | | | + |--click "Push to Cloud" on local asset----->| | + | |--POST /api/sync/gallery/upload----------->| + | | (multipart: file from local path) | + | | |--POST /api/upload-file--------->| + | | | (multipart, 50MB max) | + | | |<--{gallery_id, url}-------------| + | |<--{status: "ok"}----| | + |<--show cloud badge----| | | +``` + +### 5.4 Prompt Enhancement (Dual Path) + +``` +User Desktop Frontend Desktop Backend Palette API + | | | | + |--click sparkle btn--->| | | + | |--POST /api/sync/prompt/enhance----------->| + | | {prompt, level, director_style} | + | | | | + | | [connected to Palette?] | + | | | | + | | YES: |--POST /api/prompt-expander---->| + | | | (GPT-4o-mini, director style)| + | | |<--{enhanced}-------------------| + | | | | + | | NO: |--Gemini API (local key) | + | | | (existing EnhancePromptHandler) + | | |<--{enhanced} | + | | | | + | |<--{enhanced_prompt, source}---------------| + |<--replace prompt------| | | +``` + +--- + +## 6. Backend Implementation Details + +### 6.1 PaletteSyncClient Protocol Extension + +The existing `PaletteSyncClient` protocol (`services/palette_sync_client/palette_sync_client.py`) +must be extended with new methods: + +```python +class PaletteSyncClient(Protocol): + # Existing + def validate_connection(self, *, api_key: str) -> dict[str, Any]: ... + def get_credits(self, *, api_key: str) -> dict[str, Any]: ... + + # New + def list_gallery(self, *, api_key: str, page: int, per_page: int, + asset_type: str) -> dict[str, Any]: ... + def download_asset(self, *, api_key: str, asset_id: str) -> bytes: ... + def upload_asset(self, *, api_key: str, file_data: bytes, + filename: str, content_type: str) -> dict[str, Any]: ... + def list_characters(self, *, api_key: str) -> dict[str, Any]: ... + def list_styles(self, *, api_key: str) -> dict[str, Any]: ... + def list_references(self, *, api_key: str, category: str | None, + search: str | None) -> dict[str, Any]: ... + def enhance_prompt(self, *, api_key: str, prompt: str, level: str, + director_style: str | None) -> dict[str, Any]: ... +``` + +The real implementation (`PaletteSyncClientImpl`) calls Palette's `/api/desktop/*` routes. +A `FakePaletteSyncClient` in `tests/fakes/` provides canned responses for testing. + +### 6.2 SyncHandler Extension + +The existing `SyncHandler` (`handlers/sync_handler.py`) gains new methods mapping 1:1 to +the new routes. Each method: + +1. Reads `api_key` from `self._state.app_settings.palette_api_key` +2. Returns `{"connected": false, ...}` if no key +3. Delegates to `self._client.(...)` +4. Returns the response dict + +For `gallery/download`, the handler also writes the downloaded bytes to the local outputs +directory, following the same naming convention as `GalleryHandler`. + +For `prompt/enhance`, the handler checks connection status. If connected, calls +`self._client.enhance_prompt(...)`. If disconnected, delegates to the existing +`EnhancePromptHandler.enhance()` method (Gemini path). + +### 6.3 New Route File + +`_routes/sync.py` already exists. The new gallery/library/prompt routes are added +to this same file, keeping the `/api/sync` prefix. + +### 6.4 AppState Changes + +No new fields on `AppState`. The `palette_api_key` field on `AppSettings` is sufficient. +Cached user info stays in `SyncHandler._cached_user` (already implemented). + +### 6.5 Concurrency + +Palette API calls are network I/O. They should NOT hold the shared lock. The existing +SyncHandler methods already follow this pattern (no lock usage). New methods follow +the same approach. + +--- + +## 7. Palette Backend Routes (To Be Built) + +These are the Next.js API routes that Palette must implement to support Desktop integration. +All routes require `Authorization: Bearer dp_xxx` header. All routes must add CORS +headers allowing `localhost:8000` origin. + +| Route | Method | Description | +|-------|--------|-------------| +| `/api/desktop/auth/validate` | POST | Validate API key, return user_metadata | +| `/api/desktop/gallery` | GET | Paginated gallery list (query: page, per_page, type) | +| `/api/desktop/gallery/download/{id}` | GET | Return signed URL or stream asset bytes | +| `/api/desktop/library/characters` | GET | All characters across storyboards (flattened) | +| `/api/desktop/library/styles` | GET | Style guides + brands + preset names | +| `/api/desktop/library/references` | GET | References with category/tag filtering | +| `/api/desktop/credits` | GET | Balance in cents + pricing table | +| `/api/desktop/pair/create` | POST | Create pairing code (for QR flow, future) | +| `/api/desktop/pair/poll` | GET | Poll for completed pairing (future) | + +Existing Palette routes that Desktop proxies directly (no `/api/desktop/` wrapper needed): + +| Route | Method | Notes | +|-------|--------|-------| +| `/api/upload-file` | POST | Multipart upload, 50 MB max. Needs CORS + API key auth added. | +| `/api/prompt-expander` | POST | GPT-4o-mini expander. Needs CORS + API key auth added. | + +--- + +## 8. Electron Changes + +### 8.1 Protocol Handler Registration + +In `electron/main.ts` (or equivalent), register the custom protocol on app startup: + +```typescript +app.setAsDefaultProtocolClient('directorsdesktop'); + +// Handle the protocol URL +app.on('open-url', (event, url) => { + event.preventDefault(); + handleDeepLink(url); +}); + +// Windows: protocol URL comes via second-instance +app.on('second-instance', (event, argv) => { + const deepLink = argv.find(arg => arg.startsWith('directorsdesktop://')); + if (deepLink) handleDeepLink(deepLink); + // Focus existing window + if (mainWindow) { + if (mainWindow.isMinimized()) mainWindow.restore(); + mainWindow.focus(); + } +}); +``` + +### 8.2 Deep Link Handler + +```typescript +async function handleDeepLink(url: string) { + const parsed = new URL(url); + if (parsed.hostname === 'auth' && parsed.pathname === '/callback') { + const token = parsed.searchParams.get('token'); + if (token) { + // Store in safeStorage for persistence across restarts + safeStorage.encryptString(token); // save to disk + + // Send to backend + await fetch('http://localhost:8000/api/sync/connect', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ token }), + }); + + // Notify renderer to refresh auth state + mainWindow?.webContents.send('auth-state-changed'); + } + } +} +``` + +### 8.3 Startup Token Restoration + +On app launch, after the backend is ready: + +```typescript +const encryptedToken = readFromDisk(); // safeStorage file +if (encryptedToken) { + const token = safeStorage.decryptString(encryptedToken); + await fetch('http://localhost:8000/api/sync/connect', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ token }), + }); +} +``` + +### 8.4 New IPC Handler + +Add to `electron/preload.ts`: + +```typescript +openPaletteLoginPage: () => ipcRenderer.invoke('open-palette-login-page'), +// Already exists in preload.ts -- just needs the main process handler +// to open: directors-palette-v2.vercel.app/auth/desktop-login +``` + +### 8.5 Electron Builder Config + +Add protocol association to `electron-builder.yml`: + +```yaml +protocols: + - name: "Directors Desktop" + schemes: + - directorsdesktop +``` + +--- + +## 9. Frontend Changes + +### 9.1 Gallery View (Gallery.tsx) + +Add tabs: **Local** | **Cloud** + +- **Local tab** (existing): Shows local gallery from `GET /api/gallery/local` +- **Cloud tab** (new): Shows cloud gallery from `GET /api/sync/gallery` + - Paginated grid with lazy loading + - Each item shows a cloud badge icon + - Click to preview (full-size, no thumbnails from Palette) + - "Download" button saves to local and switches to Local tab + - Shows video expiry warning: "Expires in X days" for video items + - Shows quota: "142 / 500 images" + - Disabled with "Sign in to access cloud gallery" when not connected + +**Push to Cloud** button on local items: +- Available on Local tab items +- Grayed out when not connected or at 500 cap +- Triggers `POST /api/sync/gallery/upload` + +### 9.2 Characters View (Characters.tsx) + +Show two sections or a toggle: **Local** | **Cloud** + +- **Local characters** (existing): From `GET /api/library/characters` +- **Cloud characters** (new): From `GET /api/sync/library/characters` + - Shows storyboard origin as subtitle: "from Episode 1" + - Cloud badge on items + - "Pull to Local" downloads reference image and creates a local character + - Palette characters are per-storyboard; Desktop flattens into a single list, + deduplicating by name (keep most recently modified) + +### 9.3 Styles View (Styles.tsx) + +Similar Local/Cloud split: +- Cloud styles from `GET /api/sync/library/styles` +- Includes brands (separate from style guides on Palette) +- Shows 9 hardcoded presets (always available, no pull needed) +- "Pull to Local" creates a local style entry with downloaded reference image + +### 9.4 References View (References.tsx) + +- Cloud references from `GET /api/sync/library/references` +- Category filter: people / places / props / layouts (matches Palette categories) +- Tag search (GIN-indexed on Palette, passed as query param) +- "Pull to Local" downloads image, creates local reference + +### 9.5 Home.tsx Sidebar + +Enhance the account section at the bottom of the sidebar: + +``` +Connected state: + [Avatar] display_name + Credits: $42.50 + [Sign Out] + +Disconnected state: + [Sign In to Director's Palette] + (or paste API key in Settings) +``` + +### 9.6 Settings View + +Add "Director's Palette" section: + +``` +Director's Palette + Status: Connected as "John Doe" [Disconnect] + --or-- + Status: Not connected [Sign In] [Paste API Key] + + API Key: dp_xxxx...xxxx [Clear] + + Prompt Enhancement: [Palette (GPT-4o-mini)] / [Local (Gemini)] + When connected, defaults to Palette. Toggle to override. +``` + +### 9.7 Prompt Enhancement UI + +The sparkle button next to the prompt field in Playground.tsx / GenSpace.tsx: + +- Calls `POST /api/sync/prompt/enhance` (unified endpoint) +- Shows source badge: "Enhanced by Palette" or "Enhanced by Gemini" +- Level selector dropdown: "2x" or "3x" +- Optional director style selector (only when Palette connected) + +--- + +## 10. Generation Pipeline: Last Frame Wiring + +The `lastFramePath` field exists on `GenerateVideoRequest` but is not yet wired through +the generation pipeline. This section specifies how to complete the wiring. + +### 10.1 Queue Pipeline + +`QueueSubmitRequest.params` already passes arbitrary params. The `lastFramePath` key +needs to flow through: + +``` +Frontend: POST /api/queue/submit + params: { prompt: "...", imagePath: "first.png", lastFramePath: "last.png", ... } + | + v +JobQueue.submit() -> Job record with params dict + | + v +Job runner picks job, reads params["lastFramePath"] + | + v +VideoGenerationHandler or API client +``` + +### 10.2 LTX Local (Fast Pipeline) + +LTX supports multiple `ImageConditioningInput` entries. Currently, first frame is: + +```python +ImageConditioningInput(path=image_path, frame_idx=0, strength=1.0) +``` + +Last frame adds a second entry: + +```python +ImageConditioningInput(path=last_frame_path, frame_idx=num_frames - 1, strength=1.0) +``` + +Both can be provided simultaneously for start-to-end guided generation. + +The `num_frames` value comes from the duration and fps settings: +`num_frames = int(float(duration) * float(fps)) + 1` + +### 10.3 Seedance (Replicate API) + +Seedance supports `last_frame` as a separate parameter (not an image conditioning list). +The `ReplicateVideoClientImpl` must accept and pass `last_frame` to the Replicate API +input params. + +### 10.4 Receive Job from Palette + +The existing `ReceiveJobHandler` already handles `last_frame_url` on `ReceiveJobRequest`. +It downloads the URL to a temp file and passes it as `lastFramePath` in the job params. +This path just needs to be wired through to the generation handler as described above. + +--- + +## 11. Key Decisions + +| # | Decision | Rationale | +|---|----------|-----------| +| 1 | Desktop works fully offline | Cloud is additive, not required | +| 2 | Delete local does not delete cloud | One-way protection, avoids accidental data loss | +| 3 | Push to cloud is explicit | User chooses what uploads; respects 500 cap | +| 4 | Credits only for cloud generation | Local GPU is free; credits = Palette API costs | +| 5 | Videos expire after 7 days on Palette | Desktop warns with countdown badge on video items | +| 6 | Characters are per-storyboard in Palette, global in Desktop | Desktop flattens + deduplicates for simpler browsing | +| 7 | Prompt enhancement: Palette when connected, Gemini when not | Palette uses GPT-4o-mini with director styles; Gemini is the offline fallback | +| 8 | Frontend never calls Palette directly | All proxied through Desktop backend; keeps tokens server-side | +| 9 | Token stored in Electron safeStorage | OS-level encryption; backend holds in memory only | +| 10 | No thumbnails from Palette | Display full images; consider generating local thumbnails for performance | +| 11 | Palette base URL: `directors-palette-v2.vercel.app` | No staging env; single production target | +| 12 | Upload limit: 50 MB | Matches Palette's existing `POST /api/upload-file` limit | + +--- + +## 12. Error Handling + +### 12.1 Network Failures + +All sync routes return graceful disconnected responses on network failure: + +```json +{ + "connected": false, + "error": "Could not reach Director's Palette. Working offline." +} +``` + +Frontend shows a subtle toast, not a blocking error. Local features remain fully functional. + +### 12.2 Token Expiry/Revocation + +If Palette returns 401 on any proxy call: + +1. Backend clears `palette_api_key` from AppState +2. Clears `_cached_user` +3. Returns `{"connected": false, "error": "Session expired. Please sign in again."}` +4. Frontend transitions to disconnected state + +### 12.3 Gallery Quota + +Upload attempts when at 500 cap return: + +```json +{ + "error": "gallery_full", + "message": "Gallery full (500/500). Delete items in Director's Palette to make room.", + "current": 500, + "limit": 500 +} +``` + +### 12.4 Video Expiry + +When listing cloud gallery items, videos with `expires_at` within 24 hours get flagged: + +```json +{ + "expiry_warning": "This video expires in 6 hours. Download locally to keep it." +} +``` + +--- + +## 13. Testing Strategy + +Following the existing backend testing patterns (integration-first, no mocks): + +### 13.1 FakePaletteSyncClient + +New file: `tests/fakes/fake_palette_sync_client.py` + +Provides canned responses for all `PaletteSyncClient` protocol methods. Supports +configurable scenarios: + +- `FakePaletteSyncClient(connected=True)` -- returns valid user, gallery, etc. +- `FakePaletteSyncClient(connected=False)` -- raises on all calls +- `FakePaletteSyncClient(gallery_full=True)` -- upload returns 409 +- `FakePaletteSyncClient(token_expired=True)` -- returns 401 + +### 13.2 Test Cases + +**Auth tests:** +- Connect with valid token -> connected=true, user info cached +- Connect with invalid token -> connected=false, error returned +- Disconnect -> clears key, status shows disconnected +- Token expiry on subsequent call -> auto-disconnect + +**Gallery tests:** +- List cloud gallery (paginated) +- Download cloud asset to local +- Upload local asset to cloud +- Upload when at 500 cap -> error +- Gallery operations when disconnected -> graceful error + +**Library tests:** +- List cloud characters (flattened from per-storyboard) +- List cloud styles + brands + presets +- List cloud references with category filter + +**Prompt enhancement tests:** +- Enhance with Palette connected -> uses Palette expander, source="palette" +- Enhance with Palette disconnected -> uses Gemini, source="gemini" +- Enhance with no Gemini key and disconnected -> error + +**Credits tests:** +- Get credits when connected -> balance + pricing +- Get credits when disconnected -> graceful error + +--- + +## 14. Migration from v1 Design + +This document replaces `2026-03-06-palette-desktop-integration-design.md`. Key changes: + +1. **Removed speculative features** not informed by the questionnaire (bidirectional job sync, inpainting, prompt library sync) +2. **Added concrete Palette schema details** (table names, column names, storage bucket, caps) +3. **Added prompt expander integration** as a first-class feature (GPT-4o-mini, 2x/3x levels, director styles) +4. **Clarified character model mismatch** (per-storyboard vs global) with flatten strategy +5. **Added video expiry handling** (7-day expiry with warnings) +6. **Specified all Palette routes to be built** (previously vague) +7. **Added upload details** (multipart, 50 MB, `/api/upload-file` endpoint) +8. **Clarified credits format** (balance in cents, per-model pricing: image 20c, video 40c) +9. **Removed assumptions about existing endpoints** (`/api/me` does not exist; user info from `auth.users.user_metadata`) +10. **Added CORS requirement** for Palette `/api/desktop/*` routes + +--- + +## 15. Implementation Order + +1. **Palette side first:** Build `/api/desktop/*` routes + CORS + API key auth for non-admin users +2. **Desktop backend:** Extend `PaletteSyncClient` protocol + impl + sync handler + routes +3. **Electron:** Register protocol handler + deep link + safeStorage flow +4. **Frontend:** Settings connection UI -> Gallery Cloud tab -> Library cloud sections -> Prompt enhance +5. **Last frame wiring:** Queue pipeline -> LTX handler -> Seedance handler +6. **Tests:** FakePaletteSyncClient + integration tests for all sync routes diff --git a/docs/plans/2026-03-06-palette-desktop-integration-design.md b/docs/plans/2026-03-06-palette-desktop-integration-design.md new file mode 100644 index 00000000..0b17fd23 --- /dev/null +++ b/docs/plans/2026-03-06-palette-desktop-integration-design.md @@ -0,0 +1,321 @@ +# Director's Desktop + Director's Palette Integration Design + +## Goal + +Transform Directors Desktop from a standalone local AI generation app into the **local creative powerhouse** paired with Director's Palette (web/mobile). Palette is where you plan, organize, and manage your creative library in the cloud. Desktop is where you render, edit, and experiment with unlimited local power. They share auth, gallery, library, and characters. + +## Architecture + +``` +Director's Palette (Web/Mobile) + Supabase Auth - PostgreSQL - Supabase Storage + Gallery - Library - Characters - Brands - Credits + Generation (Replicate) - Storyboards + | + | REST API (/api/desktop/*) + v +Director's Desktop (Electron) + Palette Sync Layer Local Engine + - Auth (3 methods) - GPU Generation + - Gallery browser - Local gallery (unlimited) + - Library/Characters - NLE Editor + - Credits display - Job Queue + - Push results back - Settings + + FastAPI Backend (localhost:8000) + + /api/sync/* routes for Palette connection +``` + +Key principle: Desktop works fully offline/standalone. Palette connection is additive. + +--- + +## Phase 1: Foundation (Auth + Sync Infrastructure) + +### Authentication — Three Methods + +1. **Browser login (default)** — Click "Sign In," Desktop opens browser to Palette's Supabase auth page. After login, token returns to Electron via deep link (`directorsdesktop://auth/callback`). Same UI as Palette (email/password + Google OAuth + sign up). + +2. **QR code pairing** — Desktop generates short-lived pairing code + QR. Scan from Palette mobile app (already logged in). Desktop receives auth token via polling. + +3. **API key** — Generate in Palette Settings > Developer > API Keys. Paste into Desktop settings. + +**Token storage:** Electron `safeStorage` API. Backend receives via `Authorization` header on sync routes. + +**Offline behavior:** Local generation, gallery, NLE all functional without auth. Cloud features appear only when connected. + +### Sync Protocol + +**New Palette API routes** (`/api/desktop/*`): + +| Route | Method | Purpose | +|-------|--------|---------| +| `/api/desktop/auth/validate` | POST | Validate Desktop auth token | +| `/api/desktop/gallery` | GET | Paginated gallery browse | +| `/api/desktop/gallery/download/:id` | GET | Download single asset | +| `/api/desktop/gallery/upload` | POST | Push local asset to cloud | +| `/api/desktop/library/characters` | GET | List characters + references | +| `/api/desktop/library/styles` | GET | List style guides + brands | +| `/api/desktop/library/references` | GET | List reference images | +| `/api/desktop/credits` | GET | Credit balance | +| `/api/desktop/prompts` | GET | Synced prompt library | +| `/api/desktop/send-job` | POST | Send generation job from Palette to Desktop | + +**New Desktop API routes** (`/api/sync/*`): + +| Route | Method | Purpose | +|-------|--------|---------| +| `/api/sync/connect` | POST | Store Palette auth token | +| `/api/sync/status` | GET | Connection status + user info | +| `/api/sync/receive-job` | POST | Accept generation job from Palette | + +--- + +## Phase 2: Gallery + Library + +### Gallery — Local + Cloud Unified View + +Two tabs: +- **Local** — Unlimited images/videos from local generations. Stored on disk. +- **Cloud** — Paginated browse of Palette gallery (500 image cap server-side). Pull on demand. + +Key behaviors: +- Cloud items show cloud badge. Click to download locally. +- Local items have "Push to Cloud" button (counts against 500 cap). +- Delete local ≠ delete cloud. Independent. +- Search/filter across both tabs. +- Model badge on each asset (already implemented). + +### Library — Characters, Styles, References + +**Characters panel:** +- Pull character list from Palette (name, role, reference images). +- Create local-only characters (name + reference images). +- Push individual local characters to Palette when ready. +- Pick a character when generating → references auto-attach. + +**Styles panel:** +- Browse Palette's style guides and brand visual identities. +- Create local style presets (reference image + description). +- Generate Style Guide Grid (3x3). + +**References panel:** +- Categorized reference images (people, places, props). +- `@mention` autocomplete in prompt field. +- Pull from Palette or upload locally. + +### Credits Display +- Live balance visible in sidebar. +- Deducted when using cloud generation (Replicate API). +- Local GPU generation is free (no credits needed). + +--- + +## Phase 3: Generation Upgrades + +### First Frame + Last Frame + +Two image slots in Playground: + +``` +[First Frame] [Last Frame] +Paste / Drop Paste / Drop +or Browse or Browse +``` + +**Input methods:** Ctrl+V paste from clipboard, drag & drop, file browse, or "Extract Frame" from existing video. + +**Backend wiring:** +- LTX local: First → `ImageConditioningInput(frame_idx=0)`, Last → `ImageConditioningInput(frame_idx=num_frames-1)`. Both can be set simultaneously. +- Seedance (Replicate): First → `image` param, Last → `last_frame` param. + +**New API fields:** +- `GenerateVideoRequest.lastFramePath: str | None = None` +- `QueueSubmitRequest` updated to pass last frame path through params. + +### Image Variations +- Add "Variations: 1-12" slider to image generation settings. +- Backend already supports it (clamped 1-12), UI just needs to expose it. +- Results displayed in a grid. + +### Social Media Aspect Ratio Presets +- 16:9 → "YouTube / Landscape" +- 9:16 → "TikTok / Reels / Shorts" +- 1:1 → "Instagram Square" +- 4:3 → "Standard" +- 4:5 → "Instagram Post" (new) + +### Prompt Enhancement Button +- Sparkle button next to prompt field. +- Uses Gemini (key already in settings) to enhance rough prompt into detailed cinematic prompt. +- New backend route: `POST /api/enhance-prompt` (generalize existing gap-prompt logic). + +### Frame Extraction from Player +- Right-click video in player → "Extract Frame" → saves to local gallery. +- "Use as First Frame" / "Use as Last Frame" context menu options. +- Uses existing `extract-video-frame` IPC handler. + +--- + +## Phase 4: Power Tools + +### Wildcards +- Define variation lists: `_outfit_` = ["red dress", "blue suit", "leather jacket"] +- Use in prompts: "A woman in _outfit_ walking through _location_" +- Generate all combinations or random selection. +- Parser logic from Palette's `wildcard/parser.ts`. + +### Prompt Library with Autocomplete +- Save and recall prompts with `@tag` autocomplete. +- Sort by most-used, recent, category. +- Sync with Palette prompt library when connected. + +### Contact Sheet Generation +- Select reference image → generate 3x3 grid of cinematic angles in one API call. +- Slice into 9 separate images. +- Add any to project timeline or gallery. + +### Style Guide Grids +- Reference image + style name → 3x3 grid showing style across diverse subjects. +- Store in project asset library. + +--- + +## Phase 5: Advanced Features + +### Send-to-Desktop from Palette Web +- Button in Palette: "Send to Desktop" on any generation prompt or storyboard shot. +- Desktop receives via `/api/sync/receive-job`. +- Job appears in Desktop queue with prompt + references pre-loaded. + +### Inpainting in Editor +- Annotation canvas overlay on video frames. +- Draw masks/arrows/labels → send to inpaint API. +- Non-destructive editing of generated shots. + +### Bidirectional Prompt/Job Sync +- Generation jobs started in Palette can be monitored in Desktop and vice versa. +- Shared job history when connected. + +### QR Code Pairing +- Full implementation with WebSocket handshake. +- One-time scan, persistent session. + +--- + +## Phase 6: Testing & Quality Assurance + +### Backend Integration Tests (Python — pytest) + +**Auth & sync tests:** +- Token validation (valid, expired, malformed) +- Connection status (connected, disconnected, token refresh) +- Auth token storage and retrieval + +**Gallery sync tests:** +- Cloud gallery pagination and filtering +- Asset download (success, not found, unauthorized) +- Asset upload/push (success, quota exceeded, duplicate) +- Local delete does not affect cloud +- Cloud browse while offline (graceful failure) + +**Library sync tests:** +- Character list fetch and caching +- Character reference image download +- Style guide fetch +- Reference image categorization +- Push local character to cloud + +**Generation pipeline tests:** +- First frame only generation (LTX + Seedance) +- Last frame only generation (LTX + Seedance) +- First + last frame simultaneous generation +- Last frame index calculation from num_frames +- Paste/clipboard image handling +- Image variations (1, 4, 12 count) +- Prompt enhancement route + +**Queue tests:** +- Receive job from Palette (valid, malformed) +- Job with character references attached +- Job with first/last frame images +- Credits check before cloud generation + +### Frontend Tests (TypeScript) + +**Component tests:** +- First/Last frame image slots (paste, drop, browse, clear) +- Gallery tab switching (local vs cloud) +- Character picker with reference previews +- Credits display (connected vs disconnected) +- Sidebar navigation states (auth vs no-auth) +- Image variation grid rendering +- Social media preset labels +- Prompt autocomplete dropdown + +**Integration tests:** +- Auth flow: browser login → token stored → sync routes work +- Gallery: browse cloud → download → appears in local +- Push: local asset → push to cloud → appears in cloud tab +- Generation: set first+last frame → generate → correct params sent + +### End-to-End Tests + +- Full auth → browse gallery → download → use as first frame → generate → push result workflow +- Offline mode: all local features work without Palette connection +- Reconnection: lose connection → queue local work → reconnect → sync + +### Palette API Tests (Next.js — in Director's Palette repo) + +- Desktop auth validation endpoint +- Gallery pagination with RLS (user sees only their data) +- Asset download with signed URLs +- Upload with quota enforcement +- Character/style/reference list endpoints +- Credits balance endpoint +- Send-job endpoint + +--- + +## Sidebar Layout + +``` +[Logo] Director's Desktop + +── CREATE ────────── + Playground (generate images/video) + Queue (job queue status) + +── EDIT ──────────── + Projects (NLE timeline projects) + +── LIBRARY ───────── + Gallery (local + cloud) + Characters (Palette + local) + Styles (style guides/brands) + References (categorized refs) + +── TOOLS ─────────── + Wildcards (prompt variations) + Prompt Library (saved prompts) + Contact Sheets (3x3 angle grids) + +── ACCOUNT ───────── + Credits: 4,250 (live balance) + Settings + Sign In / User +``` + +When not signed in: LIBRARY shows local-only items + "Sign in to sync" CTA. ACCOUNT shows Sign In button. Credits hidden. + +--- + +## Key Decisions + +1. **Desktop works standalone** — No auth required for local generation/editing. +2. **Cloud is additive** — Palette connection unlocks gallery, library, credits, sync. +3. **Delete local ≠ delete cloud** — One-way protection. +4. **Push is explicit** — User chooses what to upload to cloud. +5. **Credits only for cloud generation** — Local GPU is free. +6. **LTX supports last frame** — via `frame_idx` parameter on `ImageConditioningInput`. +7. **Three auth methods** — Browser login (default), QR pairing (mobile), API key (power users). diff --git a/docs/plans/2026-03-06-palette-integration-phase1-plan.md b/docs/plans/2026-03-06-palette-integration-phase1-plan.md new file mode 100644 index 00000000..19a4a0cd --- /dev/null +++ b/docs/plans/2026-03-06-palette-integration-phase1-plan.md @@ -0,0 +1,814 @@ +# Director's Desktop — Palette Integration Phase 1 Implementation Plan + +> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. + +**Goal:** Add Palette auth connection (API key to start), sync status route, and first/last frame video generation — the two highest-impact features that make Desktop feel integrated and powerful. + +**Architecture:** Desktop's FastAPI backend gets new `/api/sync/*` routes for Palette connection and a `PaletteSyncClient` service for calling Palette's API. The video generation pipeline gets `lastFramePath` support for both LTX local (via `frame_idx`) and Seedance (via Replicate `last_frame` param). Frontend gets first/last frame image slots with paste/drop/browse support. + +**Tech Stack:** Python 3.12+ (FastAPI, Pydantic, pytest), TypeScript (React 18, Electron), Supabase (Palette backend) + +--- + +## Part A: Palette Sync — API Key Auth + Status + +### Task 1: Palette API Key Setting + +**Files:** +- Modify: `backend/state/app_settings.py` +- Modify: `backend/handlers/settings_handler.py:66` +- Modify: `frontend/contexts/AppSettingsContext.tsx` +- Modify: `settings.json` +- Test: `backend/tests/test_settings.py` + +**Step 1: Write the failing test** + +Add to `backend/tests/test_settings.py`: + +```python +def test_palette_api_key_roundtrip(client, default_app_settings): + """Palette API key can be saved and is masked in responses.""" + resp = client.patch("/api/settings", json={"paletteApiKey": "dp_test_key_123"}) + assert resp.status_code == 200 + + resp = client.get("/api/settings") + data = resp.json() + assert data["hasPaletteApiKey"] is True + assert "dp_test_key_123" not in resp.text +``` + +**Step 2: Run test to verify it fails** + +Run: `cd backend && uv run pytest tests/test_settings.py::test_palette_api_key_roundtrip -v --tb=short` +Expected: FAIL — `hasPaletteApiKey` not in response + +**Step 3: Implement** + +In `backend/state/app_settings.py`, add to `AppSettings` class: +```python +palette_api_key: str = "" +``` + +In `backend/state/app_settings.py`, add to `SettingsResponse`: +```python +has_palette_api_key: bool = False +``` + +In `backend/state/app_settings.py`, in `to_settings_response()` method, add `palette_api_key` to the popped keys list and set `has_palette_api_key`. + +In `backend/handlers/settings_handler.py`, add `"palette_api_key"` to the key fields list (around line 66). + +In `frontend/contexts/AppSettingsContext.tsx`, add: +```typescript +hasPaletteApiKey: boolean // in interface +paletteApiKey: '' // in defaults +``` + +In `settings.json`, add: +```json +"palette_api_key": "" +``` + +**Step 4: Run test to verify it passes** + +Run: `cd backend && uv run pytest tests/test_settings.py::test_palette_api_key_roundtrip -v --tb=short` +Expected: PASS + +**Step 5: Commit** + +```bash +git add backend/state/app_settings.py backend/handlers/settings_handler.py frontend/contexts/AppSettingsContext.tsx settings.json backend/tests/test_settings.py +git commit -m "feat: add palette_api_key to app settings" +``` + +--- + +### Task 2: PaletteSyncClient Service Protocol + +**Files:** +- Create: `backend/services/palette_sync_client/palette_sync_client.py` +- Create: `backend/services/palette_sync_client/__init__.py` +- Modify: `backend/services/interfaces.py` + +**Step 1: Create the Protocol** + +```python +"""Protocol for communicating with Director's Palette cloud API.""" + +from __future__ import annotations + +from typing import Any, Protocol + + +class PaletteSyncClient(Protocol): + def validate_connection(self, *, api_key: str) -> dict[str, Any]: + """Validate API key and return user info. Raises on failure.""" + ... + + def get_credits(self, *, api_key: str) -> dict[str, Any]: + """Return credit balance for the authenticated user.""" + ... +``` + +**Step 2: Create `__init__.py`** + +```python +from services.palette_sync_client.palette_sync_client import PaletteSyncClient + +__all__ = ["PaletteSyncClient"] +``` + +**Step 3: Add to interfaces.py** + +Add import and `__all__` entry for `PaletteSyncClient`. + +**Step 4: Commit** + +```bash +git add backend/services/palette_sync_client/ backend/services/interfaces.py +git commit -m "feat: add PaletteSyncClient protocol" +``` + +--- + +### Task 3: Fake PaletteSyncClient + Test Infrastructure + +**Files:** +- Modify: `backend/tests/fakes/services.py` +- Modify: `backend/tests/conftest.py` + +**Step 1: Add FakePaletteSyncClient to fakes** + +```python +class FakePaletteSyncClient: + def __init__(self) -> None: + self.validate_calls: list[str] = [] + self.credits_calls: list[str] = [] + self.raise_on_validate: Exception | None = None + self.user_info: dict[str, Any] = {"id": "user-123", "email": "test@example.com", "name": "Test User"} + self.credits_info: dict[str, Any] = {"balance": 5000, "currency": "credits"} + + def validate_connection(self, *, api_key: str) -> dict[str, Any]: + self.validate_calls.append(api_key) + if self.raise_on_validate is not None: + raise self.raise_on_validate + return self.user_info + + def get_credits(self, *, api_key: str) -> dict[str, Any]: + self.credits_calls.append(api_key) + return self.credits_info +``` + +Add to `FakeServices` dataclass: +```python +palette_sync_client: FakePaletteSyncClient = field(default_factory=FakePaletteSyncClient) +``` + +Update `conftest.py` ServiceBundle construction to include `palette_sync_client=fake_services.palette_sync_client`. + +**Step 2: Commit** + +```bash +git add backend/tests/fakes/services.py backend/tests/conftest.py +git commit -m "feat: add FakePaletteSyncClient test double" +``` + +--- + +### Task 4: Sync Routes — Connect + Status + +**Files:** +- Create: `backend/_routes/sync.py` +- Create: `backend/handlers/sync_handler.py` +- Modify: `backend/app_factory.py` +- Modify: `backend/app_handler.py` +- Create: `backend/tests/test_sync.py` + +**Step 1: Write the failing tests** + +Create `backend/tests/test_sync.py`: + +```python +"""Tests for Palette sync routes.""" +from __future__ import annotations + +import pytest + + +class TestSyncStatus: + def test_disconnected_by_default(self, client): + resp = client.get("/api/sync/status") + assert resp.status_code == 200 + data = resp.json() + assert data["connected"] is False + assert data["user"] is None + + def test_connected_after_setting_api_key(self, client): + client.patch("/api/settings", json={"paletteApiKey": "dp_valid_key"}) + resp = client.get("/api/sync/status") + assert resp.status_code == 200 + data = resp.json() + assert data["connected"] is True + assert data["user"]["email"] == "test@example.com" + + def test_connection_fails_with_invalid_key(self, client, fake_services): + fake_services.palette_sync_client.raise_on_validate = RuntimeError("Invalid API key") + client.patch("/api/settings", json={"paletteApiKey": "dp_bad_key"}) + resp = client.get("/api/sync/status") + data = resp.json() + assert data["connected"] is False + + +class TestSyncCredits: + def test_credits_when_connected(self, client): + client.patch("/api/settings", json={"paletteApiKey": "dp_valid_key"}) + resp = client.get("/api/sync/credits") + assert resp.status_code == 200 + data = resp.json() + assert data["balance"] == 5000 + + def test_credits_when_disconnected(self, client): + resp = client.get("/api/sync/credits") + assert resp.status_code == 200 + data = resp.json() + assert data["balance"] is None + assert data["connected"] is False +``` + +**Step 2: Run tests to verify they fail** + +Run: `cd backend && uv run pytest tests/test_sync.py -v --tb=short` +Expected: FAIL — module not found + +**Step 3: Implement sync_handler.py** + +```python +"""Handler for Palette sync operations.""" +from __future__ import annotations + +from typing import Any + +from services.palette_sync_client.palette_sync_client import PaletteSyncClient +from state.app_state_types import AppState + + +class SyncHandler: + def __init__(self, state: AppState, palette_sync_client: PaletteSyncClient) -> None: + self._state = state + self._client = palette_sync_client + self._cached_user: dict[str, Any] | None = None + + def get_status(self) -> dict[str, Any]: + api_key = self._state.app_settings.palette_api_key + if not api_key: + return {"connected": False, "user": None} + try: + user = self._client.validate_connection(api_key=api_key) + self._cached_user = user + return {"connected": True, "user": user} + except Exception: + self._cached_user = None + return {"connected": False, "user": None} + + def get_credits(self) -> dict[str, Any]: + api_key = self._state.app_settings.palette_api_key + if not api_key: + return {"connected": False, "balance": None} + try: + credits = self._client.get_credits(api_key=api_key) + return {"connected": True, **credits} + except Exception: + return {"connected": False, "balance": None} +``` + +**Step 4: Implement sync routes** + +Create `backend/_routes/sync.py`: +```python +"""Palette sync routes.""" +from __future__ import annotations + +from fastapi import APIRouter +from state.deps import get_handler + +router = APIRouter(prefix="/api/sync", tags=["sync"]) + + +@router.get("/status") +def sync_status(): + return get_handler().sync.get_status() + + +@router.get("/credits") +def sync_credits(): + return get_handler().sync.get_credits() +``` + +**Step 5: Wire into AppHandler and app_factory** + +In `app_handler.py`: Add `PaletteSyncClient` to imports, `__init__` params, and create `self.sync = SyncHandler(...)`. + +In `app_factory.py`: Import and register `sync_router`. + +In `ServiceBundle`: Add `palette_sync_client` field. + +**Step 6: Run tests to verify they pass** + +Run: `cd backend && uv run pytest tests/test_sync.py -v --tb=short` +Expected: ALL PASS + +**Step 7: Run full test suite** + +Run: `cd backend && uv run pytest -v --tb=short` +Expected: ALL PASS (no regressions) + +**Step 8: Commit** + +```bash +git add backend/_routes/sync.py backend/handlers/sync_handler.py backend/app_factory.py backend/app_handler.py backend/tests/test_sync.py +git commit -m "feat: add /api/sync/status and /api/sync/credits routes" +``` + +--- + +## Part B: First Frame + Last Frame Video Generation + +### Task 5: Add lastFramePath to API Types + +**Files:** +- Modify: `backend/api_types.py` +- Test: `backend/tests/test_generation.py` + +**Step 1: Write the failing test** + +Add to `backend/tests/test_generation.py`: + +```python +def test_generate_video_request_accepts_last_frame_path(): + from api_types import GenerateVideoRequest + req = GenerateVideoRequest(prompt="test", lastFramePath="/path/to/last.png") + assert req.lastFramePath == "/path/to/last.png" + +def test_generate_video_request_last_frame_defaults_none(): + from api_types import GenerateVideoRequest + req = GenerateVideoRequest(prompt="test") + assert req.lastFramePath is None +``` + +**Step 2: Run to verify failure** + +Run: `cd backend && uv run pytest tests/test_generation.py::test_generate_video_request_accepts_last_frame_path -v --tb=short` +Expected: FAIL + +**Step 3: Add field to GenerateVideoRequest** + +In `backend/api_types.py`, add to `GenerateVideoRequest`: +```python +lastFramePath: str | None = None +``` + +**Step 4: Run to verify pass** + +Run: `cd backend && uv run pytest tests/test_generation.py::test_generate_video_request_accepts_last_frame_path tests/test_generation.py::test_generate_video_request_last_frame_defaults_none -v --tb=short` +Expected: PASS + +**Step 5: Commit** + +```bash +git add backend/api_types.py backend/tests/test_generation.py +git commit -m "feat: add lastFramePath to GenerateVideoRequest" +``` + +--- + +### Task 6: Wire Last Frame into LTX Local Pipeline + +**Files:** +- Modify: `backend/handlers/video_generation_handler.py` +- Test: `backend/tests/test_generation.py` + +**Step 1: Write the failing test** + +```python +def test_local_generation_with_last_frame(client, test_state, create_fake_model_files, make_test_image, tmp_path): + """Last frame image should be passed as ImageConditioningInput with frame_idx=last.""" + create_fake_model_files() + # Force local generation + test_state.state.app_settings.ltx_api_key = "" + + # Create a fake last-frame image + img_buf = make_test_image(512, 512) + img_path = tmp_path / "last_frame.png" + img_path.write_bytes(img_buf.read()) + + resp = client.post("/api/generate-video", json={ + "prompt": "A cat walks across the room", + "resolution": "540p", + "model": "fast", + "duration": "2", + "fps": "24", + "lastFramePath": str(img_path), + }) + assert resp.status_code == 200 + + # Verify the pipeline received last-frame conditioning + calls = test_state.fast_video_pipeline_class._singleton.generate_calls + assert len(calls) == 1 + images = calls[0]["images"] + # Should have at least one image with frame_idx > 0 + last_frame_images = [img for img in images if img.frame_idx > 0] + assert len(last_frame_images) == 1 +``` + +**Step 2: Run to verify failure** + +Run: `cd backend && uv run pytest tests/test_generation.py::test_local_generation_with_last_frame -v --tb=short` +Expected: FAIL — lastFramePath not handled + +**Step 3: Implement in video_generation_handler.py** + +In the `_generate_local()` method, after the existing first-frame image conditioning block, add: + +```python +if request.lastFramePath: + last_frame_image = Image.open(request.lastFramePath).convert("RGB") + # num_frames - 1 is the last frame index + images.append(ImageConditioningInput( + path=request.lastFramePath, + frame_idx=num_frames - 1, + strength=1.0, + )) +``` + +**Step 4: Run to verify pass** + +Run: `cd backend && uv run pytest tests/test_generation.py::test_local_generation_with_last_frame -v --tb=short` +Expected: PASS + +**Step 5: Run full suite** + +Run: `cd backend && uv run pytest -v --tb=short` +Expected: ALL PASS + +**Step 6: Commit** + +```bash +git add backend/handlers/video_generation_handler.py backend/tests/test_generation.py +git commit -m "feat: wire lastFramePath into LTX local pipeline via frame_idx" +``` + +--- + +### Task 7: Wire Last Frame into Seedance (Replicate) + +**Files:** +- Modify: `backend/services/video_api_client/video_api_client.py` +- Modify: `backend/services/video_api_client/replicate_video_client_impl.py` +- Modify: `backend/tests/fakes/services.py` +- Test: `backend/tests/test_video_api_client.py` + +**Step 1: Write the failing test** + +Add to `backend/tests/test_video_api_client.py`: + +```python +def test_seedance_with_last_frame(): + """Seedance should pass last_frame in input payload.""" + http = FakeHTTPClient() + # Queue prediction response (sync success) + http.queue("post", FakeResponse( + status_code=200, + json_payload={"status": "succeeded", "output": "https://example.com/video.mp4"}, + )) + # Queue video download + http.queue("get", FakeResponse(status_code=200, content=b"fake-video")) + + client = ReplicateVideoClientImpl(http=http) + result = client.generate_text_to_video( + api_key="test-key", + model="seedance-1.5-pro", + prompt="A cat", + duration=5, + resolution="720p", + aspect_ratio="16:9", + generate_audio=False, + last_frame_path="/tmp/last.png", + ) + + # Verify the POST payload included last_frame + post_call = http.calls[0] + payload = post_call.json_payload + assert "last_frame" in payload["input"] +``` + +**Step 2: Update Protocol to accept last_frame_path** + +In `video_api_client.py`: +```python +def generate_text_to_video( + self, *, api_key: str, model: str, prompt: str, + duration: int, resolution: str, aspect_ratio: str, + generate_audio: bool, last_frame_path: str | None = None, +) -> bytes: ... +``` + +**Step 3: Update ReplicateVideoClientImpl** + +Add `last_frame_path` parameter. For seedance, if provided, read the image, base64-encode it, and include as `last_frame` in the input payload (or as a data URI). + +**Step 4: Update FakeVideoAPIClient** + +Add `last_frame_path` parameter to `generate_text_to_video`. + +**Step 5: Run tests** + +Run: `cd backend && uv run pytest tests/test_video_api_client.py -v --tb=short` +Expected: ALL PASS + +**Step 6: Commit** + +```bash +git add backend/services/video_api_client/ backend/tests/fakes/services.py backend/tests/test_video_api_client.py +git commit -m "feat: add last_frame_path support to Seedance via Replicate" +``` + +--- + +### Task 8: Queue Submit with Last Frame + +**Files:** +- Modify: `backend/app_handler.py` (determine_slot and queue submit logic) +- Modify: `backend/_routes/queue.py` +- Test: `backend/tests/test_queue_routes.py` + +**Step 1: Write the failing test** + +```python +def test_queue_submit_video_with_last_frame(client): + resp = client.post("/api/queue/submit", json={ + "type": "video", + "prompt": "A cat walks", + "model": "ltx-fast", + "params": {"lastFramePath": "/tmp/last.png", "resolution": "540p"}, + }) + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "queued" + + # Verify params contain lastFramePath + status = client.get("/api/queue/status").json() + job = status["jobs"][0] + assert job["params"]["lastFramePath"] == "/tmp/last.png" +``` + +**Step 2: Implement** — The queue already passes arbitrary `params` through, so this should mostly work. Verify the queue worker passes `lastFramePath` to the video generation handler. + +**Step 3: Run tests** + +Run: `cd backend && uv run pytest tests/test_queue_routes.py -v --tb=short` +Expected: ALL PASS + +**Step 4: Commit** + +```bash +git add backend/ backend/tests/test_queue_routes.py +git commit -m "feat: queue submit supports lastFramePath param" +``` + +--- + +### Task 9: Frontend — First/Last Frame Image Slots + +**Files:** +- Create: `frontend/components/FrameSlot.tsx` +- Modify: `frontend/views/Playground.tsx` +- Modify: `frontend/hooks/use-generation.ts` + +**Step 1: Create FrameSlot component** + +A reusable image slot that supports: +- Paste (Ctrl+V / Cmd+V) +- Drag & drop +- Click to browse +- Thumbnail preview with X to clear +- Label ("First Frame" / "Last Frame") + +```typescript +interface FrameSlotProps { + label: string + imageUrl: string | null + onImageSet: (url: string | null, path: string | null) => void +} +``` + +**Step 2: Add to Playground** + +Add two FrameSlot components above the prompt area: +```tsx +
+ + +
+``` + +**Step 3: Wire into generation hook** + +Update `use-generation.ts` `generate()` to accept `lastFramePath` and pass it in the queue submit request params. + +**Step 4: Verify TypeScript compiles** + +Run: `cd D:/git/directors-desktop && npx tsc --noEmit` +Expected: clean + +**Step 5: Commit** + +```bash +git add frontend/components/FrameSlot.tsx frontend/views/Playground.tsx frontend/hooks/use-generation.ts +git commit -m "feat: add first/last frame image slots to Playground UI" +``` + +--- + +### Task 10: Frontend — Palette Connection UI in Settings + +**Files:** +- Modify: `frontend/components/SettingsModal.tsx` +- Modify: `frontend/contexts/AppSettingsContext.tsx` + +**Step 1: Add Palette API Key input to Settings** + +In SettingsModal, add a new "Director's Palette" section: +- API Key input field (password type, with show/hide toggle) +- Connection status indicator (green dot = connected, red = disconnected) +- Credits display when connected +- "Get API Key" link button + +**Step 2: Add sync status polling** + +In AppSettingsContext, add: +- `paletteConnected: boolean` +- `paletteUser: { email: string; name: string } | null` +- `paletteCredits: number | null` +- Poll `/api/sync/status` and `/api/sync/credits` every 60 seconds when key is set + +**Step 3: Verify TypeScript compiles** + +Run: `cd D:/git/directors-desktop && npx tsc --noEmit` +Expected: clean + +**Step 4: Commit** + +```bash +git add frontend/components/SettingsModal.tsx frontend/contexts/AppSettingsContext.tsx +git commit -m "feat: add Palette connection UI to Settings modal" +``` + +--- + +### Task 11: Image Variations Slider + +**Files:** +- Modify: `frontend/components/SettingsPanel.tsx` +- No backend changes needed (already supports 1-12 variations) + +**Step 1: Add variations slider** + +In SettingsPanel, when mode is `text-to-image`, add: +```tsx +
+ + onSettingsChange({...settings, variations: parseInt(e.target.value)})} /> + {settings.variations || 1} +
+``` + +**Step 2: Verify TypeScript compiles + commit** + +```bash +git add frontend/components/SettingsPanel.tsx +git commit -m "feat: expose image variations slider (1-12)" +``` + +--- + +### Task 12: Social Media Aspect Ratio Labels + +**Files:** +- Modify: `frontend/components/SettingsPanel.tsx` + +**Step 1: Update aspect ratio labels** + +Change the video aspect ratio options: +```tsx + + +``` + +For image aspect ratios, add labels and the new 4:5 option: +```tsx + + + + + + + +``` + +**Step 2: Commit** + +```bash +git add frontend/components/SettingsPanel.tsx +git commit -m "feat: add social media labels to aspect ratio presets" +``` + +--- + +### Task 13: Prompt Enhancement Button + +**Files:** +- Create: `backend/_routes/enhance_prompt.py` +- Create: `backend/handlers/enhance_prompt_handler.py` +- Modify: `backend/app_factory.py` +- Modify: `backend/app_handler.py` +- Create: `backend/tests/test_enhance_prompt.py` +- Modify: `frontend/views/Playground.tsx` + +**Step 1: Write the failing test** + +```python +def test_enhance_prompt_returns_enhanced_text(client, fake_services): + """Enhance prompt route should return an enhanced version of the input.""" + # Configure Gemini key + client.patch("/api/settings", json={"geminiApiKey": "test-gemini-key"}) + + # Queue a fake Gemini response + from tests.fakes.services import FakeResponse + fake_services.http.queue("post", FakeResponse( + status_code=200, + json_payload={ + "candidates": [{"content": {"parts": [{"text": "A cinematic shot of a majestic cat walking gracefully across a sun-drenched room, golden hour lighting, shallow depth of field"}]}}] + }, + )) + + resp = client.post("/api/enhance-prompt", json={ + "prompt": "cat walking in room", + "mode": "text-to-video", + }) + assert resp.status_code == 200 + data = resp.json() + assert "enhancedPrompt" in data + assert len(data["enhancedPrompt"]) > len("cat walking in room") +``` + +**Step 2: Implement handler + route** + +The handler calls Gemini API (same pattern as `suggest_gap_prompt_handler.py`) with a system prompt asking it to enhance the user's rough prompt into a detailed cinematic description. + +**Step 3: Add sparkle button to Playground prompt area** + +Next to the prompt textarea, add a small button with a Sparkles icon that calls `/api/enhance-prompt` and replaces the prompt text. + +**Step 4: Run tests + commit** + +```bash +git add backend/_routes/enhance_prompt.py backend/handlers/enhance_prompt_handler.py backend/app_factory.py backend/app_handler.py backend/tests/test_enhance_prompt.py frontend/views/Playground.tsx +git commit -m "feat: add prompt enhancement button (Gemini-powered)" +``` + +--- + +### Task 14: Full Test Suite Verification + +**Step 1: Run all backend tests** + +Run: `cd backend && uv run pytest -v --tb=short` +Expected: ALL PASS + +**Step 2: Run TypeScript type check** + +Run: `cd D:/git/directors-desktop && npx tsc --noEmit` +Expected: clean + +**Step 3: Run Python type check** + +Run: `cd backend && uv run pyright` +Expected: 0 errors + +**Step 4: Manual smoke test** + +Start the app: `npx pnpm dev` +- Verify Playground shows First Frame / Last Frame slots +- Verify paste (Ctrl+V) works in frame slots +- Verify Settings shows Palette API Key section +- Verify image variations slider appears in text-to-image mode +- Verify aspect ratio labels show platform names +- Verify prompt enhancement button appears + +**Step 5: Final commit** + +```bash +git add -A +git commit -m "feat: Phase 1 complete — Palette auth + first/last frame + generation upgrades" +``` diff --git a/docs/plans/2026-03-06-queue-and-seedance-design.md b/docs/plans/2026-03-06-queue-and-seedance-design.md new file mode 100644 index 00000000..f1af1354 --- /dev/null +++ b/docs/plans/2026-03-06-queue-and-seedance-design.md @@ -0,0 +1,181 @@ +# Design: Generation Queue + Seedance 1.5 Pro + +## Summary + +Add a persistent job queue to LTX Desktop so users can submit multiple generation requests that process sequentially. Add Seedance 1.5 Pro as a video model option via Replicate API. Allow local GPU jobs and API jobs to run in parallel. + +## Decisions + +| Decision | Choice | +|----------|--------| +| Queue type | Simple sequential, fire-and-forget | +| Persistence | Disk-persisted, survives restarts | +| Seedance UI placement | Same video generation view, model dropdown | +| Model selection scope | Global setting | +| Queue UI location | Existing generation results area | +| Parallelism | GPU slot + API slot can run concurrently | +| Seedance routing | Always API (Replicate), regardless of force_api | + +## Architecture + +``` +Frontend queue UI (existing results area) + | POST /api/queue/submit +Backend JobQueue (persistent JSON file) + | QueueWorker thread (picks next job) + |-- GPU slot (LTX video, ZIT images) + +-- API slot (Seedance, Replicate images) <-- parallel + | results written to outputs/ +Frontend polls GET /api/queue/status +``` + +## Backend Changes + +### Job Queue State (`backend/state/job_queue.py`) + +Jobs persisted to JSON file in app data directory alongside settings.json. + +```python +@dataclass +class QueueJob: + id: str + type: Literal["video", "image"] + model: str + params: dict[str, Any] + status: Literal["queued", "running", "complete", "error", "cancelled"] + slot: Literal["gpu", "api"] + progress: int # 0-100 + phase: str # "queued", "loading_model", "inference", etc. + result_paths: list[str] + error: str | None + created_at: str # ISO 8601 +``` + +On startup, any `running` jobs reset to `queued` (crash recovery). + +### Queue Worker (`backend/handlers/queue_worker.py`) + +Background thread started on app boot. Two concurrent slots: + +- **GPU slot**: local LTX video, local ZIT image generation +- **API slot**: Seedance (Replicate), Replicate image models, LTX API video + +Worker loop: +1. Check if GPU slot is free -> pick next queued GPU-type job +2. Check if API slot is free -> pick next queued API-type job +3. Both can run simultaneously +4. Sleep 500ms between checks + +Reuses existing handler logic internally (VideoGenerationHandler, ImageGenerationHandler). + +### Slot Assignment + +| Model | Slot | +|-------|------| +| ltx-fast (local GPU available) | gpu | +| ltx-fast (force_api or no GPU) | api | +| seedance-1.5-pro | api (always) | +| z-image-turbo (local GPU available) | gpu | +| z-image-turbo (force_api or no GPU) | api | +| nano-banana-2 | api (always) | + +### New Routes (`backend/_routes/queue.py`) + +- `POST /api/queue/submit` - Add job, returns `{id, status}` +- `GET /api/queue/status` - Returns all jobs with progress +- `POST /api/queue/cancel/{job_id}` - Cancel specific job +- `POST /api/queue/clear` - Remove completed/errored jobs + +### Seedance Video Client (`backend/services/video_api_client/`) + +New service directory following existing patterns: + +**Protocol** (`video_api_client.py`): +```python +class VideoAPIClient(Protocol): + def generate_text_to_video( + self, *, api_key: str, model: str, + prompt: str, duration: int, resolution: str, + aspect_ratio: str, generate_audio: bool, + ) -> bytes: ... +``` + +**Implementation** (`replicate_video_client_impl.py`): +- Model routing: `seedance-1.5-pro` -> `bytedance/seedance-1.5-pro` +- Input: `{prompt, duration, resolution, ratio, generate_audio}` +- Duration: 4-12 seconds +- Resolution: 480p, 720p +- Aspect ratios: 16:9, 9:16, 1:1, 4:3, 3:4, 21:9 +- Same Prefer: wait + polling pattern as image client +- Output: video bytes (mp4) + +### Settings Changes (`backend/state/app_settings.py`) + +Add: +- `video_model: str = "ltx-fast"` (choices: `"ltx-fast"`, `"seedance-1.5-pro"`) + +SettingsResponse: +- Add `video_model: str` + +### AppHandler Wiring + +- Add `video_api_client: VideoAPIClient` to ServiceBundle +- Wire `ReplicateVideoClientImpl` in `build_default_service_bundle` +- Pass to QueueWorker + +## Frontend Changes + +### Queue in Results Area + +The existing generation results area becomes a job list: +- Each job: status badge, prompt preview, progress bar (if running), result thumbnail (if complete) +- Generate button submits to queue and immediately re-enables +- Two progress bars can show simultaneously (GPU + API) +- Completed jobs show clickable results (same as current behavior) + +### Settings Modal + +- Add **Video Model** dropdown: LTX Fast | Seedance 1.5 Pro +- When Seedance selected, resolution options change to 480p/720p +- Duration range changes to 4-12s for Seedance + +### AppSettingsContext + +- Add `videoModel: string` (default: `"ltx-fast"`) + +### use-generation.ts + +- POST to `/api/queue/submit` instead of `/api/generate` or `/api/generate-image` +- Poll `/api/queue/status` for all job progress +- Handle multiple concurrent results + +## Persistence Format + +```json +{ + "jobs": [ + { + "id": "a1b2c3", + "type": "video", + "model": "seedance-1.5-pro", + "params": {"prompt": "...", "duration": 8, "resolution": "720p", "aspect_ratio": "16:9", "generate_audio": true}, + "status": "complete", + "slot": "api", + "progress": 100, + "phase": "complete", + "result_paths": ["/path/to/output.mp4"], + "error": null, + "created_at": "2026-03-06T02:30:00Z" + } + ] +} +``` + +## Testing Strategy + +- Fake queue worker for unit tests (no real GPU/API calls) +- Integration tests: submit job, verify status transitions +- Test parallel slot execution (GPU + API simultaneously) +- Test persistence: write queue, reload, verify recovery +- Test Seedance client with FakeHTTPClient (same pattern as image client tests) +- Test cancel mid-queue, cancel running job diff --git a/docs/plans/2026-03-06-queue-and-seedance-plan.md b/docs/plans/2026-03-06-queue-and-seedance-plan.md new file mode 100644 index 00000000..1aa52aa9 --- /dev/null +++ b/docs/plans/2026-03-06-queue-and-seedance-plan.md @@ -0,0 +1,1266 @@ +# Queue + Seedance 1.5 Pro Implementation Plan + +> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. + +**Goal:** Add a persistent job queue with dual GPU/API parallel slots, and Seedance 1.5 Pro video model via Replicate. + +**Architecture:** Jobs are submitted to a queue (persisted as JSON), processed by a background worker with two concurrent slots (GPU for local, API for Replicate). Seedance 1.5 Pro always routes through the API slot via Replicate. Frontend polls queue status and shows all jobs in the existing results area. + +**Tech Stack:** Python FastAPI backend, React TypeScript frontend, Replicate API, Pydantic models + +--- + +## Task 1: Video API Client Protocol + Replicate Implementation + +Create the VideoAPIClient service following the same pattern as ImageAPIClient. + +**Files:** +- Create: `backend/services/video_api_client/__init__.py` +- Create: `backend/services/video_api_client/video_api_client.py` +- Create: `backend/services/video_api_client/replicate_video_client_impl.py` +- Test: `backend/tests/test_video_api_client.py` + +**Step 1: Write the protocol** + +`backend/services/video_api_client/video_api_client.py`: +```python +"""Video API client protocol for cloud video generation.""" + +from __future__ import annotations + +from typing import Protocol + + +class VideoAPIClient(Protocol): + def generate_text_to_video( + self, + *, + api_key: str, + model: str, + prompt: str, + duration: int, + resolution: str, + aspect_ratio: str, + generate_audio: bool, + ) -> bytes: + ... +``` + +`backend/services/video_api_client/__init__.py`: +```python +from services.video_api_client.video_api_client import VideoAPIClient +from services.video_api_client.replicate_video_client_impl import ReplicateVideoClientImpl + +__all__ = ["VideoAPIClient", "ReplicateVideoClientImpl"] +``` + +**Step 2: Write failing tests for the Replicate video client** + +`backend/tests/test_video_api_client.py`: +```python +"""Tests for the Replicate video API client.""" + +from __future__ import annotations + +import pytest + +from tests.fakes.services import FakeHTTPClient, FakeResponse +from services.video_api_client.replicate_video_client_impl import ReplicateVideoClientImpl + + +def _make_client(http: FakeHTTPClient) -> ReplicateVideoClientImpl: + return ReplicateVideoClientImpl(http=http, api_base_url="https://test.replicate.com/v1") + + +def test_seedance_text_to_video_sync_success() -> None: + http = FakeHTTPClient() + client = _make_client(http) + + # Prediction returns succeeded immediately (Prefer: wait) + http.queue("post", FakeResponse( + status_code=201, + json_payload={ + "id": "pred-1", + "status": "succeeded", + "output": "https://example.com/video.mp4", + }, + )) + # Download the video + http.queue("get", FakeResponse(status_code=200, content=b"fake-mp4-bytes")) + + result = client.generate_text_to_video( + api_key="test-key", + model="seedance-1.5-pro", + prompt="a cat dancing", + duration=8, + resolution="720p", + aspect_ratio="16:9", + generate_audio=True, + ) + + assert result == b"fake-mp4-bytes" + # Verify the POST was sent to the correct model endpoint + assert "bytedance/seedance-1.5-pro" in http.calls[0].url + # Verify input payload + payload = http.calls[0].json_payload + assert payload is not None + assert payload["input"]["prompt"] == "a cat dancing" + assert payload["input"]["duration"] == 8 + assert payload["input"]["seed"] is not None + + +def test_seedance_text_to_video_polling_success() -> None: + http = FakeHTTPClient() + client = _make_client(http) + + # Prediction returns processing + http.queue("post", FakeResponse( + status_code=201, + json_payload={ + "id": "pred-2", + "status": "processing", + "urls": {"get": "https://test.replicate.com/v1/predictions/pred-2"}, + }, + )) + # First poll: still processing + http.queue("get", FakeResponse( + status_code=200, + json_payload={"id": "pred-2", "status": "processing"}, + )) + # Second poll: succeeded + http.queue("get", FakeResponse( + status_code=200, + json_payload={ + "id": "pred-2", + "status": "succeeded", + "output": "https://example.com/video2.mp4", + }, + )) + # Download + http.queue("get", FakeResponse(status_code=200, content=b"polled-video")) + + result = client.generate_text_to_video( + api_key="test-key", + model="seedance-1.5-pro", + prompt="a dog running", + duration=4, + resolution="480p", + aspect_ratio="9:16", + generate_audio=False, + ) + + assert result == b"polled-video" + + +def test_unknown_model_raises() -> None: + http = FakeHTTPClient() + client = _make_client(http) + + with pytest.raises(RuntimeError, match="Unknown video model"): + client.generate_text_to_video( + api_key="test-key", + model="nonexistent-model", + prompt="test", + duration=4, + resolution="720p", + aspect_ratio="16:9", + generate_audio=False, + ) + + +def test_prediction_failure_raises() -> None: + http = FakeHTTPClient() + client = _make_client(http) + + http.queue("post", FakeResponse( + status_code=201, + json_payload={ + "id": "pred-fail", + "status": "failed", + "error": "GPU OOM", + }, + )) + + with pytest.raises(RuntimeError, match="failed"): + client.generate_text_to_video( + api_key="test-key", + model="seedance-1.5-pro", + prompt="test", + duration=4, + resolution="720p", + aspect_ratio="16:9", + generate_audio=False, + ) +``` + +**Step 3: Run tests to verify they fail** + +Run: `cd backend && uv run pytest tests/test_video_api_client.py -v --tb=short` +Expected: FAIL (module not found) + +**Step 4: Implement ReplicateVideoClientImpl** + +`backend/services/video_api_client/replicate_video_client_impl.py`: +```python +"""Replicate API client implementation for cloud video generation (Seedance).""" + +from __future__ import annotations + +import time +from typing import Any, cast + +from services.http_client.http_client import HTTPClient +from services.services_utils import JSONValue + +REPLICATE_API_BASE_URL = "https://api.replicate.com/v1" + +_MODEL_ROUTES: dict[str, str] = { + "seedance-1.5-pro": "bytedance/seedance-1.5-pro", +} + +_POLL_INTERVAL_SECONDS = 2 +_POLL_TIMEOUT_SECONDS = 300 + + +class ReplicateVideoClientImpl: + def __init__(self, http: HTTPClient, *, api_base_url: str = REPLICATE_API_BASE_URL) -> None: + self._http = http + self._base_url = api_base_url.rstrip("/") + + def generate_text_to_video( + self, + *, + api_key: str, + model: str, + prompt: str, + duration: int, + resolution: str, + aspect_ratio: str, + generate_audio: bool, + ) -> bytes: + replicate_model = _MODEL_ROUTES.get(model) + if replicate_model is None: + raise RuntimeError(f"Unknown video model: {model}") + + input_payload = self._build_input( + prompt=prompt, + duration=duration, + resolution=resolution, + aspect_ratio=aspect_ratio, + generate_audio=generate_audio, + ) + + prediction = self._create_prediction( + api_key=api_key, + replicate_model=replicate_model, + input_payload=input_payload, + ) + + output_url = self._wait_for_output(api_key, prediction) + return self._download_video(api_key, output_url) + + @staticmethod + def _build_input( + *, + prompt: str, + duration: int, + resolution: str, + aspect_ratio: str, + generate_audio: bool, + ) -> dict[str, JSONValue]: + seed = int(time.time()) % 2_147_483_647 + return { + "prompt": prompt, + "duration": duration, + "resolution": resolution, + "aspect_ratio": aspect_ratio, + "generate_audio": generate_audio, + "seed": seed, + } + + def _create_prediction( + self, + *, + api_key: str, + replicate_model: str, + input_payload: dict[str, JSONValue], + ) -> dict[str, Any]: + url = f"{self._base_url}/models/{replicate_model}/predictions" + payload: dict[str, JSONValue] = {"input": input_payload} + + response = self._http.post( + url, + headers=self._headers(api_key, prefer_wait=True), + json_payload=payload, + timeout=300, + ) + if response.status_code not in (200, 201): + detail = response.text[:500] if response.text else "Unknown error" + raise RuntimeError(f"Replicate prediction failed ({response.status_code}): {detail}") + + return self._json_object(response.json(), context="create prediction") + + def _wait_for_output(self, api_key: str, prediction: dict[str, Any]) -> str: + status = prediction.get("status", "") + if status == "succeeded": + return self._extract_output_url(prediction) + + if status in ("failed", "canceled"): + error = prediction.get("error", "Unknown error") + raise RuntimeError(f"Replicate prediction {status}: {error}") + + poll_url = prediction.get("urls", {}).get("get") + if not isinstance(poll_url, str) or not poll_url: + prediction_id = prediction.get("id", "") + poll_url = f"{self._base_url}/predictions/{prediction_id}" + + deadline = time.monotonic() + _POLL_TIMEOUT_SECONDS + while time.monotonic() < deadline: + time.sleep(_POLL_INTERVAL_SECONDS) + resp = self._http.get(poll_url, headers=self._headers(api_key), timeout=30) + if resp.status_code != 200: + detail = resp.text[:500] if resp.text else "Unknown error" + raise RuntimeError(f"Replicate poll failed ({resp.status_code}): {detail}") + + data = self._json_object(resp.json(), context="poll") + poll_status = data.get("status", "") + if poll_status == "succeeded": + return self._extract_output_url(data) + if poll_status in ("failed", "canceled"): + error = data.get("error", "Unknown error") + raise RuntimeError(f"Replicate prediction {poll_status}: {error}") + + raise RuntimeError("Replicate video prediction timed out") + + def _download_video(self, api_key: str, url: str) -> bytes: + download = self._http.get(url, headers=self._headers(api_key), timeout=300) + if download.status_code != 200: + detail = download.text[:500] if download.text else "Unknown error" + raise RuntimeError(f"Replicate video download failed ({download.status_code}): {detail}") + if not download.content: + raise RuntimeError("Replicate video download returned empty body") + return download.content + + @staticmethod + def _headers(api_key: str, *, prefer_wait: bool = False) -> dict[str, str]: + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + } + if prefer_wait: + headers["Prefer"] = "wait" + return headers + + @staticmethod + def _extract_output_url(prediction: dict[str, Any]) -> str: + output = prediction.get("output") + if isinstance(output, str) and output: + return output + if isinstance(output, list) and output: + output_list = cast(list[object], output) + first = output_list[0] + if isinstance(first, str) and first: + return first + raise RuntimeError("Replicate response missing output URL") + + @staticmethod + def _json_object(payload: object, *, context: str) -> dict[str, Any]: + if isinstance(payload, dict): + return cast(dict[str, Any], payload) + raise RuntimeError(f"Unexpected Replicate {context} response format") +``` + +**Step 5: Run tests to verify they pass** + +Run: `cd backend && uv run pytest tests/test_video_api_client.py -v --tb=short` +Expected: All 4 tests PASS + +**Step 6: Wire into services/interfaces.py** + +Add to `backend/services/interfaces.py`: +- Import: `from services.video_api_client.video_api_client import VideoAPIClient` +- Add `"VideoAPIClient"` to `__all__` + +**Step 7: Add FakeVideoAPIClient to test fakes** + +Add to `backend/tests/fakes/services.py`: +```python +class FakeVideoAPIClient: + def __init__(self) -> None: + self.text_to_video_calls: list[dict[str, Any]] = [] + self.raise_on_text_to_video: Exception | None = None + self.text_to_video_result = b"fake-seedance-video" + + def generate_text_to_video( + self, + *, + api_key: str, + model: str, + prompt: str, + duration: int, + resolution: str, + aspect_ratio: str, + generate_audio: bool, + ) -> bytes: + self.text_to_video_calls.append({ + "api_key": api_key, + "model": model, + "prompt": prompt, + "duration": duration, + "resolution": resolution, + "aspect_ratio": aspect_ratio, + "generate_audio": generate_audio, + }) + if self.raise_on_text_to_video is not None: + raise self.raise_on_text_to_video + return self.text_to_video_result +``` + +Add field to `FakeServices`: +```python +video_api_client: FakeVideoAPIClient = field(default_factory=FakeVideoAPIClient) +``` + +**Step 8: Wire VideoAPIClient into AppHandler and ServiceBundle** + +Modify `backend/app_handler.py`: +- Add `VideoAPIClient` to imports from `services.interfaces` +- Add `video_api_client: VideoAPIClient` param to `AppHandler.__init__` and store as `self.video_api_client` +- Add `video_api_client: VideoAPIClient` field to `ServiceBundle` +- In `build_default_service_bundle`: import `ReplicateVideoClientImpl`, instantiate `ReplicateVideoClientImpl(http=http)`, add to bundle +- In `build_initial_state`: pass `video_api_client=bundle.video_api_client` + +Modify `backend/tests/conftest.py`: +- Add `video_api_client=fake_services.video_api_client` to `ServiceBundle(...)` constructor + +**Step 9: Run all tests** + +Run: `cd backend && uv run pytest -v --tb=short` +Expected: All tests PASS + +**Step 10: Commit** + +```bash +git add backend/services/video_api_client/ backend/tests/test_video_api_client.py +git add backend/services/interfaces.py backend/tests/fakes/services.py +git add backend/app_handler.py backend/tests/conftest.py +git commit -m "feat: add VideoAPIClient protocol + Replicate Seedance implementation" +``` + +--- + +## Task 2: Settings — Add video_model + +**Files:** +- Modify: `backend/state/app_settings.py` +- Modify: `backend/tests/test_settings.py` +- Modify: `settings.json` + +**Step 1: Add video_model to AppSettings** + +In `backend/state/app_settings.py`: +- Add to `AppSettings`: `video_model: str = "ltx-fast"` +- Add to `SettingsResponse`: `video_model: str = "ltx-fast"` + +In `settings.json`: add `"videoModel": "ltx-fast"` + +**Step 2: Add test** + +In `backend/tests/test_settings.py` add: +```python +def test_video_model_roundtrips(client, test_state): + resp = client.post("/api/settings", json={"videoModel": "seedance-1.5-pro"}) + assert resp.status_code == 200 + assert test_state.state.app_settings.video_model == "seedance-1.5-pro" + + get_resp = client.get("/api/settings") + assert get_resp.json()["videoModel"] == "seedance-1.5-pro" +``` + +**Step 3: Run tests** + +Run: `cd backend && uv run pytest tests/test_settings.py -v --tb=short` +Expected: PASS + +**Step 4: Commit** + +```bash +git add backend/state/app_settings.py backend/tests/test_settings.py settings.json +git commit -m "feat: add video_model setting (ltx-fast | seedance-1.5-pro)" +``` + +--- + +## Task 3: Job Queue State — QueueJob + JobQueue + +**Files:** +- Create: `backend/state/job_queue.py` +- Test: `backend/tests/test_job_queue.py` + +**Step 1: Write the test** + +`backend/tests/test_job_queue.py`: +```python +"""Tests for the persistent job queue.""" + +from __future__ import annotations + +import json +from pathlib import Path + +from state.job_queue import JobQueue, QueueJob + + +def test_submit_job_assigns_id_and_status(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + job = queue.submit( + job_type="video", + model="seedance-1.5-pro", + params={"prompt": "hello"}, + slot="api", + ) + assert job.id + assert job.status == "queued" + assert job.slot == "api" + assert job.progress == 0 + + +def test_get_all_jobs_returns_ordered(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + j1 = queue.submit(job_type="video", model="ltx-fast", params={}, slot="gpu") + j2 = queue.submit(job_type="image", model="z-image-turbo", params={}, slot="gpu") + jobs = queue.get_all_jobs() + assert [j.id for j in jobs] == [j1.id, j2.id] + + +def test_next_queued_for_slot(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + queue.submit(job_type="video", model="seedance-1.5-pro", params={}, slot="api") + queue.submit(job_type="video", model="ltx-fast", params={}, slot="gpu") + + gpu_job = queue.next_queued_for_slot("gpu") + assert gpu_job is not None + assert gpu_job.slot == "gpu" + + api_job = queue.next_queued_for_slot("api") + assert api_job is not None + assert api_job.slot == "api" + + +def test_update_job_status(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + job = queue.submit(job_type="video", model="ltx-fast", params={}, slot="gpu") + queue.update_job(job.id, status="running", progress=50, phase="inference") + updated = queue.get_job(job.id) + assert updated is not None + assert updated.status == "running" + assert updated.progress == 50 + assert updated.phase == "inference" + + +def test_cancel_job(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + job = queue.submit(job_type="video", model="ltx-fast", params={}, slot="gpu") + queue.cancel_job(job.id) + updated = queue.get_job(job.id) + assert updated is not None + assert updated.status == "cancelled" + + +def test_clear_finished_jobs(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + j1 = queue.submit(job_type="video", model="ltx-fast", params={}, slot="gpu") + j2 = queue.submit(job_type="video", model="ltx-fast", params={}, slot="gpu") + queue.update_job(j1.id, status="complete") + queue.clear_finished() + remaining = queue.get_all_jobs() + assert len(remaining) == 1 + assert remaining[0].id == j2.id + + +def test_persistence_survives_reload(tmp_path: Path) -> None: + path = tmp_path / "queue.json" + queue1 = JobQueue(persistence_path=path) + job = queue1.submit(job_type="video", model="ltx-fast", params={"prompt": "test"}, slot="gpu") + + queue2 = JobQueue(persistence_path=path) + loaded = queue2.get_job(job.id) + assert loaded is not None + assert loaded.params == {"prompt": "test"} + + +def test_running_jobs_reset_to_queued_on_load(tmp_path: Path) -> None: + path = tmp_path / "queue.json" + queue1 = JobQueue(persistence_path=path) + job = queue1.submit(job_type="video", model="ltx-fast", params={}, slot="gpu") + queue1.update_job(job.id, status="running") + + queue2 = JobQueue(persistence_path=path) + loaded = queue2.get_job(job.id) + assert loaded is not None + assert loaded.status == "queued" +``` + +**Step 2: Run tests to verify they fail** + +Run: `cd backend && uv run pytest tests/test_job_queue.py -v --tb=short` +Expected: FAIL (import error) + +**Step 3: Implement JobQueue** + +`backend/state/job_queue.py`: +```python +"""Persistent job queue for sequential generation processing.""" + +from __future__ import annotations + +import json +import uuid +from dataclasses import asdict, dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Literal + + +@dataclass +class QueueJob: + id: str + type: Literal["video", "image"] + model: str + params: dict[str, Any] + status: Literal["queued", "running", "complete", "error", "cancelled"] + slot: Literal["gpu", "api"] + progress: int = 0 + phase: str = "queued" + result_paths: list[str] = field(default_factory=list) + error: str | None = None + created_at: str = "" + + +class JobQueue: + def __init__(self, persistence_path: Path) -> None: + self._path = persistence_path + self._jobs: list[QueueJob] = [] + self._load() + + def submit( + self, + *, + job_type: str, + model: str, + params: dict[str, Any], + slot: str, + ) -> QueueJob: + job = QueueJob( + id=uuid.uuid4().hex[:8], + type=job_type, # type: ignore[arg-type] + model=model, + params=params, + status="queued", + slot=slot, # type: ignore[arg-type] + progress=0, + phase="queued", + result_paths=[], + error=None, + created_at=datetime.now(timezone.utc).isoformat(), + ) + self._jobs.append(job) + self._save() + return job + + def get_all_jobs(self) -> list[QueueJob]: + return list(self._jobs) + + def get_job(self, job_id: str) -> QueueJob | None: + for job in self._jobs: + if job.id == job_id: + return job + return None + + def next_queued_for_slot(self, slot: str) -> QueueJob | None: + for job in self._jobs: + if job.status == "queued" and job.slot == slot: + return job + return None + + def update_job( + self, + job_id: str, + *, + status: str | None = None, + progress: int | None = None, + phase: str | None = None, + result_paths: list[str] | None = None, + error: str | None = None, + ) -> None: + job = self.get_job(job_id) + if job is None: + return + if status is not None: + job.status = status # type: ignore[assignment] + if progress is not None: + job.progress = progress + if phase is not None: + job.phase = phase + if result_paths is not None: + job.result_paths = result_paths + if error is not None: + job.error = error + self._save() + + def cancel_job(self, job_id: str) -> None: + self.update_job(job_id, status="cancelled") + + def clear_finished(self) -> None: + self._jobs = [j for j in self._jobs if j.status not in ("complete", "error", "cancelled")] + self._save() + + def _save(self) -> None: + data = {"jobs": [asdict(j) for j in self._jobs]} + self._path.parent.mkdir(parents=True, exist_ok=True) + self._path.write_text(json.dumps(data, indent=2), encoding="utf-8") + + def _load(self) -> None: + if not self._path.exists(): + return + try: + raw = json.loads(self._path.read_text(encoding="utf-8")) + for item in raw.get("jobs", []): + job = QueueJob(**item) + if job.status == "running": + job.status = "queued" + job.progress = 0 + job.phase = "queued" + self._jobs.append(job) + except (json.JSONDecodeError, TypeError, KeyError): + self._jobs = [] +``` + +**Step 4: Run tests** + +Run: `cd backend && uv run pytest tests/test_job_queue.py -v --tb=short` +Expected: All 8 tests PASS + +**Step 5: Commit** + +```bash +git add backend/state/job_queue.py backend/tests/test_job_queue.py +git commit -m "feat: add persistent job queue with disk persistence and crash recovery" +``` + +--- + +## Task 4: Queue Worker + +**Files:** +- Create: `backend/handlers/queue_worker.py` +- Test: `backend/tests/test_queue_worker.py` + +**Step 1: Write the test** + +`backend/tests/test_queue_worker.py`: +```python +"""Tests for the queue worker.""" + +from __future__ import annotations + +from pathlib import Path +from unittest.mock import MagicMock + +from state.job_queue import JobQueue, QueueJob +from handlers.queue_worker import QueueWorker + + +class FakeJobExecutor: + def __init__(self) -> None: + self.executed_jobs: list[QueueJob] = [] + self.raise_on_execute: Exception | None = None + + def execute(self, job: QueueJob) -> list[str]: + self.executed_jobs.append(job) + if self.raise_on_execute is not None: + raise self.raise_on_execute + return ["/fake/output.mp4"] + + +def test_worker_processes_gpu_job(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + job = queue.submit(job_type="video", model="ltx-fast", params={"prompt": "test"}, slot="gpu") + + executor = FakeJobExecutor() + worker = QueueWorker(queue=queue, gpu_executor=executor, api_executor=FakeJobExecutor()) + worker.tick() + + assert len(executor.executed_jobs) == 1 + assert executor.executed_jobs[0].id == job.id + updated = queue.get_job(job.id) + assert updated is not None + assert updated.status == "complete" + assert updated.result_paths == ["/fake/output.mp4"] + + +def test_worker_processes_api_job(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + job = queue.submit(job_type="video", model="seedance-1.5-pro", params={"prompt": "test"}, slot="api") + + api_executor = FakeJobExecutor() + worker = QueueWorker(queue=queue, gpu_executor=FakeJobExecutor(), api_executor=api_executor) + worker.tick() + + assert len(api_executor.executed_jobs) == 1 + updated = queue.get_job(job.id) + assert updated is not None + assert updated.status == "complete" + + +def test_worker_handles_execution_error(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + job = queue.submit(job_type="video", model="ltx-fast", params={}, slot="gpu") + + executor = FakeJobExecutor() + executor.raise_on_execute = RuntimeError("GPU exploded") + worker = QueueWorker(queue=queue, gpu_executor=executor, api_executor=FakeJobExecutor()) + worker.tick() + + updated = queue.get_job(job.id) + assert updated is not None + assert updated.status == "error" + assert updated.error == "GPU exploded" + + +def test_worker_runs_gpu_and_api_in_parallel(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + gpu_job = queue.submit(job_type="video", model="ltx-fast", params={}, slot="gpu") + api_job = queue.submit(job_type="video", model="seedance-1.5-pro", params={}, slot="api") + + gpu_executor = FakeJobExecutor() + api_executor = FakeJobExecutor() + worker = QueueWorker(queue=queue, gpu_executor=gpu_executor, api_executor=api_executor) + worker.tick() + + # Both should have been picked up in one tick + assert len(gpu_executor.executed_jobs) == 1 + assert len(api_executor.executed_jobs) == 1 + + +def test_worker_skips_cancelled_job(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + job = queue.submit(job_type="video", model="ltx-fast", params={}, slot="gpu") + queue.cancel_job(job.id) + + executor = FakeJobExecutor() + worker = QueueWorker(queue=queue, gpu_executor=executor, api_executor=FakeJobExecutor()) + worker.tick() + + assert len(executor.executed_jobs) == 0 +``` + +**Step 2: Run tests to verify they fail** + +Run: `cd backend && uv run pytest tests/test_queue_worker.py -v --tb=short` +Expected: FAIL + +**Step 3: Implement QueueWorker** + +`backend/handlers/queue_worker.py`: +```python +"""Background queue worker that processes jobs from the job queue.""" + +from __future__ import annotations + +import logging +import threading +from typing import Protocol + +from state.job_queue import JobQueue, QueueJob + +logger = logging.getLogger(__name__) + + +class JobExecutor(Protocol): + def execute(self, job: QueueJob) -> list[str]: + ... + + +class QueueWorker: + def __init__( + self, + *, + queue: JobQueue, + gpu_executor: JobExecutor, + api_executor: JobExecutor, + ) -> None: + self._queue = queue + self._gpu_executor = gpu_executor + self._api_executor = api_executor + self._gpu_busy = False + self._api_busy = False + self._lock = threading.Lock() + + def tick(self) -> None: + """Process one round: pick up available jobs for each free slot.""" + gpu_job: QueueJob | None = None + api_job: QueueJob | None = None + + with self._lock: + if not self._gpu_busy: + gpu_job = self._queue.next_queued_for_slot("gpu") + if gpu_job is not None: + self._gpu_busy = True + self._queue.update_job(gpu_job.id, status="running", phase="starting") + + if not self._api_busy: + api_job = self._queue.next_queued_for_slot("api") + if api_job is not None: + self._api_busy = True + self._queue.update_job(api_job.id, status="running", phase="starting") + + threads: list[threading.Thread] = [] + + if gpu_job is not None: + t = threading.Thread(target=self._run_job, args=(gpu_job, self._gpu_executor, "gpu"), daemon=True) + threads.append(t) + t.start() + + if api_job is not None: + t = threading.Thread(target=self._run_job, args=(api_job, self._api_executor, "api"), daemon=True) + threads.append(t) + t.start() + + for t in threads: + t.join() + + def _run_job(self, job: QueueJob, executor: JobExecutor, slot: str) -> None: + try: + result_paths = executor.execute(job) + self._queue.update_job(job.id, status="complete", progress=100, phase="complete", result_paths=result_paths) + except Exception as exc: + logger.error("Job %s failed: %s", job.id, exc) + self._queue.update_job(job.id, status="error", error=str(exc)) + finally: + with self._lock: + if slot == "gpu": + self._gpu_busy = False + else: + self._api_busy = False +``` + +**Step 4: Run tests** + +Run: `cd backend && uv run pytest tests/test_queue_worker.py -v --tb=short` +Expected: All 5 tests PASS + +**Step 5: Commit** + +```bash +git add backend/handlers/queue_worker.py backend/tests/test_queue_worker.py +git commit -m "feat: add queue worker with dual GPU/API slot parallelism" +``` + +--- + +## Task 5: Queue API Routes + +**Files:** +- Create: `backend/_routes/queue.py` +- Modify: `backend/app_factory.py` +- Modify: `backend/app_handler.py` +- Test: `backend/tests/test_queue_routes.py` + +**Step 1: Write the tests** + +`backend/tests/test_queue_routes.py`: +```python +"""Tests for queue API routes.""" + +from __future__ import annotations + + +def test_submit_video_job(client): + resp = client.post("/api/queue/submit", json={ + "type": "video", + "model": "ltx-fast", + "params": {"prompt": "a cat", "duration": "6", "resolution": "720p", "aspectRatio": "16:9"}, + }) + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "queued" + assert "id" in data + + +def test_submit_image_job(client): + resp = client.post("/api/queue/submit", json={ + "type": "image", + "model": "z-image-turbo", + "params": {"prompt": "a dog", "width": 1024, "height": 1024}, + }) + assert resp.status_code == 200 + assert resp.json()["status"] == "queued" + + +def test_get_queue_status(client): + client.post("/api/queue/submit", json={ + "type": "video", + "model": "ltx-fast", + "params": {"prompt": "test"}, + }) + resp = client.get("/api/queue/status") + assert resp.status_code == 200 + jobs = resp.json()["jobs"] + assert len(jobs) == 1 + assert jobs[0]["status"] == "queued" + + +def test_cancel_job(client): + submit_resp = client.post("/api/queue/submit", json={ + "type": "video", + "model": "ltx-fast", + "params": {"prompt": "test"}, + }) + job_id = submit_resp.json()["id"] + cancel_resp = client.post(f"/api/queue/cancel/{job_id}") + assert cancel_resp.status_code == 200 + assert cancel_resp.json()["status"] == "cancelled" + + +def test_clear_finished_jobs(client): + submit_resp = client.post("/api/queue/submit", json={ + "type": "video", + "model": "ltx-fast", + "params": {"prompt": "test"}, + }) + job_id = submit_resp.json()["id"] + client.post(f"/api/queue/cancel/{job_id}") + client.post("/api/queue/clear") + status_resp = client.get("/api/queue/status") + assert len(status_resp.json()["jobs"]) == 0 +``` + +**Step 2: Add request/response types to api_types.py** + +Add to `backend/api_types.py`: +```python +class QueueSubmitRequest(BaseModel): + type: Literal["video", "image"] + model: str + params: dict[str, object] = {} + +class QueueJobResponse(BaseModel): + id: str + type: str + model: str + params: dict[str, object] + status: str + slot: str + progress: int + phase: str + result_paths: list[str] + error: str | None + created_at: str + +class QueueStatusResponse(BaseModel): + jobs: list[QueueJobResponse] + +class QueueSubmitResponse(BaseModel): + id: str + status: str +``` + +**Step 3: Implement slot assignment logic and wire queue into AppHandler** + +Add to `backend/app_handler.py`: +- Import `JobQueue` from `state.job_queue` +- In `AppHandler.__init__`: create `self.job_queue = JobQueue(persistence_path=config.settings_file.parent / "job_queue.json")` +- Add a method `determine_slot(model: str) -> str` that returns: + - `"api"` if model is `"seedance-1.5-pro"` or `"nano-banana-2"` + - `"api"` if `self.config.force_api_generations` is True + - `"gpu"` otherwise + +**Step 4: Create the route file** + +`backend/_routes/queue.py`: +```python +"""Route handlers for /api/queue/*.""" + +from __future__ import annotations + +from fastapi import APIRouter, Depends + +from api_types import QueueSubmitRequest, QueueSubmitResponse, QueueStatusResponse, QueueJobResponse +from state import get_state_service +from app_handler import AppHandler + +router = APIRouter(prefix="/api/queue", tags=["queue"]) + + +@router.post("/submit", response_model=QueueSubmitResponse) +def route_queue_submit( + req: QueueSubmitRequest, + handler: AppHandler = Depends(get_state_service), +) -> QueueSubmitResponse: + slot = handler.determine_slot(req.model) + job = handler.job_queue.submit( + job_type=req.type, + model=req.model, + params=dict(req.params), + slot=slot, + ) + return QueueSubmitResponse(id=job.id, status=job.status) + + +@router.get("/status", response_model=QueueStatusResponse) +def route_queue_status( + handler: AppHandler = Depends(get_state_service), +) -> QueueStatusResponse: + jobs = handler.job_queue.get_all_jobs() + return QueueStatusResponse(jobs=[ + QueueJobResponse( + id=j.id, type=j.type, model=j.model, params=dict(j.params), + status=j.status, slot=j.slot, progress=j.progress, phase=j.phase, + result_paths=j.result_paths, error=j.error, created_at=j.created_at, + ) + for j in jobs + ]) + + +@router.post("/cancel/{job_id}") +def route_queue_cancel( + job_id: str, + handler: AppHandler = Depends(get_state_service), +) -> QueueSubmitResponse: + handler.job_queue.cancel_job(job_id) + job = handler.job_queue.get_job(job_id) + status = job.status if job else "not_found" + return QueueSubmitResponse(id=job_id, status=status) + + +@router.post("/clear") +def route_queue_clear( + handler: AppHandler = Depends(get_state_service), +) -> QueueStatusResponse: + handler.job_queue.clear_finished() + return route_queue_status(handler) +``` + +**Step 5: Register router in app_factory.py** + +Add to `backend/app_factory.py`: +- Import: `from _routes.queue import router as queue_router` +- Add: `app.include_router(queue_router)` + +**Step 6: Run tests** + +Run: `cd backend && uv run pytest tests/test_queue_routes.py -v --tb=short` +Expected: All 5 tests PASS + +Run: `cd backend && uv run pytest -v --tb=short` +Expected: All tests PASS + +**Step 7: Commit** + +```bash +git add backend/_routes/queue.py backend/api_types.py backend/app_handler.py +git add backend/app_factory.py backend/tests/test_queue_routes.py +git commit -m "feat: add queue API routes (submit, status, cancel, clear)" +``` + +--- + +## Task 6: Frontend — Settings (video_model) + +**Files:** +- Modify: `frontend/contexts/AppSettingsContext.tsx` +- Modify: `frontend/components/SettingsModal.tsx` + +**Step 1: Add videoModel to AppSettingsContext** + +In `frontend/contexts/AppSettingsContext.tsx`: +- Add to `AppSettings` interface: `videoModel: string` +- Add to `DEFAULT_APP_SETTINGS`: `videoModel: 'ltx-fast'` +- Add to `normalizeAppSettings`: `videoModel: data.videoModel ?? DEFAULT_APP_SETTINGS.videoModel` + +**Step 2: Add Video Model dropdown to SettingsModal** + +In `frontend/components/SettingsModal.tsx`, add a "Video Model" select with options: +- `ltx-fast` — "LTX Fast" +- `seedance-1.5-pro` — "Seedance 1.5 Pro" + +Bound to `settings.videoModel`, save via `updateSettings({ videoModel: value })`. + +**Step 3: Commit** + +```bash +git add frontend/contexts/AppSettingsContext.tsx frontend/components/SettingsModal.tsx +git commit -m "feat: add video model selector (LTX Fast | Seedance 1.5 Pro) to settings" +``` + +--- + +## Task 7: Frontend — Queue UI + use-generation Refactor + +**Files:** +- Modify: `frontend/hooks/use-generation.ts` +- Modify: relevant view component that shows generation results + +**Step 1: Update use-generation.ts to submit to queue** + +Replace the direct `POST /api/generate` and `POST /api/generate-image` calls with `POST /api/queue/submit`. Replace progress polling from `/api/generation/progress` to `/api/queue/status`. + +The hook should: +1. Submit job to `/api/queue/submit` with `{type, model, params}` +2. Poll `/api/queue/status` every 500ms +3. Track all jobs (not just the latest) — expose `jobs` array +4. Keep the existing `generate()` / `generateImage()` API signatures for now, just route through queue + +**Step 2: Show queue status in results area** + +Update the results area component to render `jobs` from use-generation: +- Each job: status badge, prompt preview, progress bar (running), thumbnail/link (complete), error message (error) +- Multiple jobs can show simultaneously +- Generate button re-enables immediately after submission + +**Step 3: Commit** + +```bash +git add frontend/hooks/use-generation.ts frontend/components/ frontend/views/ +git commit -m "feat: frontend queue UI — submit to queue, poll status, show all jobs" +``` + +--- + +## Task 8: Run Full Test Suite + Type Checks + +**Step 1: Run backend tests** + +Run: `cd backend && uv run pytest -v --tb=short` +Expected: All tests PASS + +**Step 2: Run Python type checks** + +Run: `cd backend && uv run pyright` +Expected: No errors + +**Step 3: Run TypeScript type checks** + +Run: `npx pnpm typecheck:ts` +Expected: No errors + +**Step 4: Final commit** + +```bash +git add -A +git commit -m "chore: all tests and type checks passing for queue + seedance" +``` + +--- + +## Slot Assignment Summary + +| Model | Condition | Slot | +|-------|-----------|------| +| `ltx-fast` | local GPU available (`force_api=false`) | `gpu` | +| `ltx-fast` | `force_api=true` or no GPU | `api` | +| `seedance-1.5-pro` | always | `api` | +| `z-image-turbo` | local GPU available (`force_api=false`) | `gpu` | +| `z-image-turbo` | `force_api=true` or no GPU | `api` | +| `nano-banana-2` | always | `api` | diff --git a/docs/plans/2026-03-08-bulk-generation-design.md b/docs/plans/2026-03-08-bulk-generation-design.md new file mode 100644 index 00000000..0d120a87 --- /dev/null +++ b/docs/plans/2026-03-08-bulk-generation-design.md @@ -0,0 +1,270 @@ +# Bulk Generation Design + +**Date:** 2026-03-08 +**Status:** Approved + +## Overview + +Add bulk image generation (with/without LoRAs), bulk video generation, and image-to-video pipeline chaining to LTX Desktop. Supports local GPU and cloud API execution. Three input modes: manual list, CSV/JSON import, and grid sweep builder. + +## Architecture + +Extends the existing job queue system. A batch is a group of `QueueJob` entries sharing a `batch_id`. Server-side expansion keeps frontend thin. + +``` +Frontend BatchBuilder → POST /api/queue/submit-batch → BatchHandler.submit_batch() + → expands to N QueueJobs (list | sweep cartesian product | pipeline chain) + → QueueWorker.tick() dispatches per slot with dependency checking + → BatchReport generated on completion → sound + toast notification +``` + +## Data Model + +### QueueJob Extensions + +New fields on existing `QueueJob` dataclass: + +| Field | Type | Purpose | +|-------|------|---------| +| `batch_id` | `str \| None` | Groups jobs from same batch | +| `batch_index` | `int` | Position within batch (grid ordering) | +| `depends_on` | `str \| None` | Job ID that must complete first | +| `auto_params` | `dict[str, str]` | Template refs resolved at dispatch, e.g. `{"imagePath": "$dep.result_paths[0]"}` | +| `tags` | `list[str]` | Gallery filtering, e.g. `["batch:abc123", "sweep:lora_weight"]` | + +### API Types + +```python +class BatchSubmitRequest(BaseModel): + mode: Literal["list", "sweep", "pipeline"] + target: Literal["local", "cloud"] + jobs: list[BatchJobItem] | None = None # mode: list + sweep: SweepDefinition | None = None # mode: sweep + pipeline: PipelineDefinition | None = None # mode: pipeline + +class BatchJobItem(BaseModel): + type: Literal["video", "image"] + model: str + params: dict[str, object] = {} + +class SweepDefinition(BaseModel): + base_type: Literal["video", "image"] + base_model: str + base_params: dict[str, object] = {} + axes: list[SweepAxis] # 1-3 axes + +class SweepAxis(BaseModel): + param: str # e.g. "loraWeight", "prompt", "loraPath" + values: list[object] # e.g. [0.5, 0.75, 1.0] + mode: Literal["replace", "search_replace"] = "replace" + search: str | None = None # For search_replace mode + +class PipelineDefinition(BaseModel): + steps: list[PipelineStep] + +class PipelineStep(BaseModel): + type: Literal["video", "image"] + model: str + params: dict[str, object] = {} + auto_prompt: bool = False # Generate i2v motion prompt from previous step's image + +class BatchSubmitResponse(BaseModel): + batch_id: str + job_ids: list[str] + total_jobs: int + +class BatchStatusResponse(BaseModel): + batch_id: str + total: int + completed: int + failed: int + running: int + queued: int + jobs: list[QueueJobResponse] + report: BatchReport | None = None # Populated when batch fully resolved + +class BatchReport(BaseModel): + batch_id: str + total: int + succeeded: int + failed: int + cancelled: int + duration_seconds: float + avg_job_seconds: float + result_paths: list[str] # Ordered by batch_index + failed_indices: list[int] # Grid gaps + sweep_axes: list[str] | None = None +``` + +## Backend: BatchHandler + +New handler at `backend/handlers/batch_handler.py`. + +```python +class BatchHandler: + def submit_batch(self, request, queue) -> BatchSubmitResponse: + batch_id = generate_id() + slot = "api" if request.target == "cloud" else "gpu" + match request.mode: + case "list": jobs = self._expand_list(request.jobs, batch_id, slot) + case "sweep": jobs = self._expand_sweep(request.sweep, batch_id, slot) + case "pipeline": jobs = self._expand_pipeline(request.pipeline, batch_id, slot) + for job in jobs: + queue.submit(job) + return BatchSubmitResponse(batch_id=batch_id, job_ids=[j.id for j in jobs], total_jobs=len(jobs)) +``` + +- `_expand_list`: Maps each `BatchJobItem` to a `QueueJob` with shared `batch_id`. +- `_expand_sweep`: Computes cartesian product of axes. Each combination becomes a job with `batch_index` encoding grid position. +- `_expand_pipeline`: Creates chained jobs where step N has `depends_on` pointing to step N-1. For multi-image pipelines, each image spawns its own chain. + +## Backend: QueueWorker Changes + +### Dependency Checking + +In `_next_job_for_slot()`, before dispatching a job: + +```python +if job.depends_on: + dep = queue.get(job.depends_on) + if dep.status == "complete": + resolve_auto_params(job, dep) # Substitute $dep.result_paths[0] etc. + return job # Ready to run + elif dep.status == "error": + job.status = "error" + job.error = f"Upstream job {dep.id} failed" + continue # Skip, check next + else: + continue # Not ready yet, skip +else: + return job # No dependency, dispatch normally +``` + +### Batch Completion Detection + +After each tick, check if any batch just became fully resolved: + +```python +def _check_batch_completions(self): + for batch_id in queue.active_batch_ids(): + jobs = queue.jobs_for_batch(batch_id) + if all(j.status in ("complete", "error", "cancelled") for j in jobs): + if batch_id not in self._notified_batches: + self._notified_batches.add(batch_id) + self._emit_batch_complete(batch_id, jobs) +``` + +### Failure Handling + +- Individual job failure does NOT cancel the batch. +- Remaining jobs continue processing. +- `POST /api/queue/batch/{batch_id}/retry-failed` re-queues errored jobs with same params + fresh IDs. +- `POST /api/queue/batch/{batch_id}/cancel` cancels all `queued` jobs, aborts `running` if possible. + +## Backend: I2V Auto-Prompt Generation + +When a pipeline step has `auto_prompt: true`, the system generates a motion-focused video prompt from the previous step's image output. + +### Enhancement Mode + +New `i2v_motion` level in `EnhancePromptHandler`: + +```python +class EnhancePromptRequest(BaseModel): + prompt: str + level: Literal["standard", "creative", "i2v_motion"] = "standard" + image_path: str | None = None # Required for i2v_motion +``` + +### I2V Motion System Prompt + +> You are an expert cinematographer. Given an image, generate a video motion prompt. Rules: +> - Do NOT describe what the image already shows — describe what HAPPENS NEXT +> - Structure: subject motion, then camera movement, then environmental dynamics +> - Use specific cinematography terms: dolly, pan, tilt, orbit, truck, rack focus +> - Add motion intensity qualifiers: subtle, steady, dramatic +> - Animate empty regions (sky, water, foliage) to prevent frozen areas +> - Use present tense, single flowing paragraph, 4-6 sentences + +### Two-Stage Flow + +1. Caption the image via Gemini vision (or Palette enhance endpoint). +2. Feed caption + i2v_motion system prompt to LLM. +3. Inject generated motion prompt into the video job's params. + +Executed in `QueueWorker._resolve_auto_params()` when `auto_params["auto_prompt"]` is set. + +## Backend: Routes + +New route file `backend/_routes/batch.py`: + +| Method | Path | Purpose | +|--------|------|---------| +| `POST` | `/api/queue/submit-batch` | Submit a batch definition | +| `GET` | `/api/queue/batch/{batch_id}/status` | Batch aggregate status + report | +| `POST` | `/api/queue/batch/{batch_id}/retry-failed` | Re-queue failed jobs | +| `POST` | `/api/queue/batch/{batch_id}/cancel` | Cancel remaining jobs | + +## Batch Completion: Sound + Report + +### Sound + +Frontend plays a short completion chime when polling detects a batch fully resolved. Respects `batchSoundEnabled` in app settings (default: on). Also triggers an Electron native notification if the app is in background. + +### Report + +`BatchReport` is populated on the `BatchStatusResponse` when all jobs are resolved. Contains success/fail counts, wall-clock duration, average job time, all result paths ordered by `batch_index`, and failed indices for grid gap display. + +### Toast + +Frontend shows: **"Batch complete — 47/50 succeeded (3 failed) in 12m 34s"** with a "View Results" button that filters gallery to `batch:{batch_id}`. + +## Frontend: Batch Builder UI + +New modal or panel accessible from GenSpace. Three tabs: + +### Tab 1: Manual List + +Table with rows: prompt, type (image/video), model, LoRA, LoRA weight. "Add Row" button, duplicate, delete, drag-to-reorder. Each row inherits current GenSpace settings as defaults. + +### Tab 2: CSV/JSON Import + +File picker or paste textarea. Preview table with validation (highlight errors). Supported CSV columns: `prompt, type, model, lora_path, lora_weight, width, height, duration, fps, camera_motion, seed`. JSON supports `defaults` block + `jobs` array. + +### Tab 3: Grid Builder (Sweeps) + +Base prompt + settings at top. Up to 3 axis selectors: pick param, enter values (comma-separated or range syntax `0.3-1.0:8`). Live preview showing grid dimensions and total job count. + +### Pipeline Toggle + +Checkbox on any batch: "Also generate video from each image." When checked, each image job gets a chained video job with `auto_prompt: true`. User sets video model, duration, fps for the chained step. + +### Batch Queue Panel + +Existing queue panel gains batch grouping: collapsible headers showing "Batch: 12/50 complete." Per-batch actions: cancel, retry failed. + +## Gallery Integration + +Minimal changes. Jobs in a batch get tags (`batch:{batch_id}`, `sweep:{param}`). Gallery gains a filter dropdown for batch tags. Results sorted by `batch_index` to preserve grid order. + +## Per-Batch Execution Target + +User picks "Run locally" or "Run on cloud" per batch. Maps to slot assignment: `target: "local"` → `slot: "gpu"`, `target: "cloud"` → `slot: "api"`. Cloud requires Palette API key or Replicate API key. + +## Testing Strategy + +- Unit tests for sweep expansion (cartesian product correctness). +- Unit tests for dependency resolution in QueueWorker. +- Integration tests for batch submit → poll → completion flow using fake services. +- Test partial failure: job 3/10 fails, jobs 4-10 continue, report shows 9/10 with 1 failed. +- Test pipeline: image job completes → video job auto-dispatches with resolved params. +- Test i2v auto-prompt: mock enhance handler, verify motion prompt injected. +- CSV/JSON parsing tests with edge cases (empty fields, invalid values, missing headers). + +## Key Decisions + +1. **Server-side expansion** over client-side — testable, atomic, consistent. +2. **`depends_on` single-parent** over full DAG — sufficient for i2v chains without engine complexity. +3. **Tags for gallery** over separate batch gallery view — minimal frontend changes, uses existing gallery. +4. **Per-batch target** over per-job target — simpler UX, avoids confusing mixed-slot batches. +5. **Sound + toast + report** on completion — easy wins for UX polish. diff --git a/docs/plans/2026-03-08-bulk-generation-plan.md b/docs/plans/2026-03-08-bulk-generation-plan.md new file mode 100644 index 00000000..1124cd16 --- /dev/null +++ b/docs/plans/2026-03-08-bulk-generation-plan.md @@ -0,0 +1,1982 @@ +# Bulk Generation Implementation Plan + +> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. + +**Goal:** Add bulk image generation (with/without LoRAs), bulk video generation, LoRA weight sweeps, CSV/JSON import, grid sweep builder, image-to-video pipeline chaining with auto-prompt generation, and batch completion notifications. + +**Architecture:** Extends the existing job queue with `batch_id`, `depends_on`, and `auto_params` fields. New `BatchHandler` expands batch definitions server-side into individual `QueueJob` entries. `QueueWorker` gains dependency checking. Frontend adds a batch builder modal with three tabs (list, import, grid). Existing gallery gains batch tag filtering. + +**Tech Stack:** Python/FastAPI (backend), React/TypeScript (frontend), existing JobQueue persistence, Gemini API for i2v prompt generation. + +**Design doc:** `docs/plans/2026-03-08-bulk-generation-design.md` + +--- + +## Task 1: Extend QueueJob Dataclass with Batch Fields + +**Files:** +- Modify: `backend/state/job_queue.py:13-26` (QueueJob dataclass) +- Modify: `backend/state/job_queue.py:106-123` (persistence _save/_load) +- Test: `backend/tests/test_job_queue.py` + +**Step 1: Write the failing test** + +```python +# backend/tests/test_job_queue.py — add new test + +def test_submit_job_with_batch_fields(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + job = queue.submit( + job_type="image", + model="zit", + params={"prompt": "a cat"}, + slot="gpu", + batch_id="batch_001", + batch_index=3, + depends_on="job_abc", + tags=["batch:batch_001", "sweep:lora_weight"], + ) + assert job.batch_id == "batch_001" + assert job.batch_index == 3 + assert job.depends_on == "job_abc" + assert job.tags == ["batch:batch_001", "sweep:lora_weight"] + + # Verify persistence round-trip + queue2 = JobQueue(persistence_path=tmp_path / "queue.json") + loaded = queue2.get(job.id) + assert loaded is not None + assert loaded.batch_id == "batch_001" + assert loaded.batch_index == 3 + assert loaded.depends_on == "job_abc" + assert loaded.tags == ["batch:batch_001", "sweep:lora_weight"] +``` + +**Step 2: Run test to verify it fails** + +Run: `cd backend && uv run pytest tests/test_job_queue.py::test_submit_job_with_batch_fields -v --tb=short` +Expected: FAIL — `submit()` doesn't accept `batch_id`, `batch_index`, `depends_on`, `tags` params + +**Step 3: Add batch fields to QueueJob dataclass** + +In `backend/state/job_queue.py`, extend the `QueueJob` dataclass (line 13-26): + +```python +@dataclass +class QueueJob: + id: str + type: Literal["video", "image"] + model: str + params: dict[str, Any] + status: Literal["queued", "running", "complete", "error", "cancelled"] + slot: Literal["gpu", "api"] + progress: int = 0 + phase: str = "queued" + result_paths: list[str] = field(default_factory=list) + error: str | None = None + created_at: str = "" + # Batch fields + batch_id: str | None = None + batch_index: int = 0 + depends_on: str | None = None + auto_params: dict[str, str] = field(default_factory=dict) + tags: list[str] = field(default_factory=list) +``` + +**Step 4: Update `submit()` to accept new fields** + +In `backend/state/job_queue.py`, update the `submit()` method signature to accept optional batch fields and pass them through to the QueueJob constructor. + +**Step 5: Update `_load()` for backwards compatibility** + +In the `_load()` method, when deserializing old jobs that lack batch fields, provide defaults: + +```python +batch_id=d.get("batch_id"), +batch_index=d.get("batch_index", 0), +depends_on=d.get("depends_on"), +auto_params=d.get("auto_params", {}), +tags=d.get("tags", []), +``` + +**Step 6: Run test to verify it passes** + +Run: `cd backend && uv run pytest tests/test_job_queue.py::test_submit_job_with_batch_fields -v --tb=short` +Expected: PASS + +**Step 7: Run all existing queue tests to verify no regressions** + +Run: `cd backend && uv run pytest tests/test_job_queue.py -v --tb=short` +Expected: All PASS + +**Step 8: Commit** + +```bash +git add backend/state/job_queue.py backend/tests/test_job_queue.py +git commit -m "feat: add batch_id, depends_on, tags fields to QueueJob" +``` + +--- + +## Task 2: Add JobQueue Helper Methods for Batches + +**Files:** +- Modify: `backend/state/job_queue.py` +- Test: `backend/tests/test_job_queue.py` + +**Step 1: Write the failing tests** + +```python +def test_jobs_for_batch(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + queue.submit(job_type="image", model="zit", params={}, slot="gpu", batch_id="b1", batch_index=0) + queue.submit(job_type="image", model="zit", params={}, slot="gpu", batch_id="b1", batch_index=1) + queue.submit(job_type="image", model="zit", params={}, slot="gpu", batch_id="b2", batch_index=0) + queue.submit(job_type="video", model="fast", params={}, slot="gpu") # No batch + + batch_jobs = queue.jobs_for_batch("b1") + assert len(batch_jobs) == 2 + assert all(j.batch_id == "b1" for j in batch_jobs) + assert [j.batch_index for j in batch_jobs] == [0, 1] + + +def test_active_batch_ids(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + queue.submit(job_type="image", model="zit", params={}, slot="gpu", batch_id="b1") + queue.submit(job_type="image", model="zit", params={}, slot="gpu", batch_id="b2") + queue.submit(job_type="video", model="fast", params={}, slot="gpu") + + ids = queue.active_batch_ids() + assert set(ids) == {"b1", "b2"} + + +def test_active_batch_ids_excludes_fully_resolved(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + job = queue.submit(job_type="image", model="zit", params={}, slot="gpu", batch_id="b1") + queue.update_job(job.id, status="complete", result_paths=["/out.png"]) + + ids = queue.active_batch_ids() + assert ids == [] # b1 is fully resolved +``` + +**Step 2: Run tests to verify they fail** + +Run: `cd backend && uv run pytest tests/test_job_queue.py::test_jobs_for_batch tests/test_job_queue.py::test_active_batch_ids tests/test_job_queue.py::test_active_batch_ids_excludes_fully_resolved -v --tb=short` +Expected: FAIL — methods don't exist + +**Step 3: Implement helper methods** + +Add to `JobQueue` class: + +```python +def jobs_for_batch(self, batch_id: str) -> list[QueueJob]: + return sorted( + [j for j in self._jobs if j.batch_id == batch_id], + key=lambda j: j.batch_index, + ) + +def active_batch_ids(self) -> list[str]: + batch_ids: set[str] = set() + for job in self._jobs: + if job.batch_id and job.status in ("queued", "running"): + batch_ids.add(job.batch_id) + return sorted(batch_ids) + +def get(self, job_id: str) -> QueueJob | None: + for job in self._jobs: + if job.id == job_id: + return job + return None +``` + +Note: Check if `get()` already exists. If so, skip adding it. + +**Step 4: Run tests to verify they pass** + +Run: `cd backend && uv run pytest tests/test_job_queue.py -v --tb=short` +Expected: All PASS + +**Step 5: Commit** + +```bash +git add backend/state/job_queue.py backend/tests/test_job_queue.py +git commit -m "feat: add jobs_for_batch, active_batch_ids, get helpers to JobQueue" +``` + +--- + +## Task 3: Add Batch API Types + +**Files:** +- Modify: `backend/api_types.py:321-324` (after existing queue types) +- Test: `backend/tests/test_batch.py` (NEW) + +**Step 1: Write the failing test** + +```python +# backend/tests/test_batch.py + +from backend.api_types import ( + BatchJobItem, + BatchSubmitRequest, + BatchSubmitResponse, + BatchStatusResponse, + BatchReport, + SweepAxis, + SweepDefinition, + PipelineStep, + PipelineDefinition, +) + + +def test_batch_submit_request_list_mode() -> None: + req = BatchSubmitRequest( + mode="list", + target="local", + jobs=[ + BatchJobItem(type="image", model="zit", params={"prompt": "a cat"}), + BatchJobItem(type="image", model="zit", params={"prompt": "a dog"}), + ], + ) + assert req.mode == "list" + assert len(req.jobs) == 2 + + +def test_batch_submit_request_sweep_mode() -> None: + req = BatchSubmitRequest( + mode="sweep", + target="cloud", + sweep=SweepDefinition( + base_type="image", + base_model="zit", + base_params={"prompt": "a cat", "width": 1024, "height": 1024}, + axes=[ + SweepAxis(param="loraWeight", values=[0.5, 0.75, 1.0]), + SweepAxis(param="prompt", values=["a cat", "a dog"], mode="search_replace", search="a cat"), + ], + ), + ) + assert req.sweep is not None + assert len(req.sweep.axes) == 2 + + +def test_batch_submit_request_pipeline_mode() -> None: + req = BatchSubmitRequest( + mode="pipeline", + target="local", + pipeline=PipelineDefinition( + steps=[ + PipelineStep(type="image", model="zit", params={"prompt": "a landscape"}), + PipelineStep(type="video", model="fast", params={}, auto_prompt=True), + ], + ), + ) + assert req.pipeline is not None + assert req.pipeline.steps[1].auto_prompt is True + + +def test_batch_report_model() -> None: + report = BatchReport( + batch_id="abc123", + total=10, + succeeded=8, + failed=2, + cancelled=0, + duration_seconds=120.5, + avg_job_seconds=12.05, + result_paths=["/out/1.png", "/out/2.png"], + failed_indices=[3, 7], + sweep_axes=["loraWeight"], + ) + assert report.succeeded + report.failed == report.total +``` + +**Step 2: Run tests to verify they fail** + +Run: `cd backend && uv run pytest tests/test_batch.py -v --tb=short` +Expected: FAIL — imports don't exist + +**Step 3: Add batch types to api_types.py** + +Add after the existing `QueueSubmitRequest` (around line 324): + +```python +# --- Batch Generation Types --- + +class BatchJobItem(BaseModel): + type: Literal["video", "image"] + model: str + params: dict[str, object] = {} + +class SweepAxis(BaseModel): + param: str + values: list[object] + mode: Literal["replace", "search_replace"] = "replace" + search: str | None = None + +class SweepDefinition(BaseModel): + base_type: Literal["video", "image"] + base_model: str + base_params: dict[str, object] = {} + axes: list[SweepAxis] + +class PipelineStep(BaseModel): + type: Literal["video", "image"] + model: str + params: dict[str, object] = {} + auto_prompt: bool = False + +class PipelineDefinition(BaseModel): + steps: list[PipelineStep] + +class BatchSubmitRequest(BaseModel): + mode: Literal["list", "sweep", "pipeline"] + target: Literal["local", "cloud"] + jobs: list[BatchJobItem] | None = None + sweep: SweepDefinition | None = None + pipeline: PipelineDefinition | None = None + +class BatchSubmitResponse(BaseModel): + batch_id: str + job_ids: list[str] + total_jobs: int + +class BatchReport(BaseModel): + batch_id: str + total: int + succeeded: int + failed: int + cancelled: int + duration_seconds: float + avg_job_seconds: float + result_paths: list[str] + failed_indices: list[int] + sweep_axes: list[str] | None = None + +class BatchStatusResponse(BaseModel): + batch_id: str + total: int + completed: int + failed: int + running: int + queued: int + jobs: list[QueueJobResponse] + report: BatchReport | None = None +``` + +**Step 4: Run tests to verify they pass** + +Run: `cd backend && uv run pytest tests/test_batch.py -v --tb=short` +Expected: All PASS + +**Step 5: Commit** + +```bash +git add backend/api_types.py backend/tests/test_batch.py +git commit -m "feat: add batch generation API types" +``` + +--- + +## Task 4: Implement BatchHandler — List Mode + +**Files:** +- Create: `backend/handlers/batch_handler.py` +- Test: `backend/tests/test_batch.py` + +**Step 1: Write the failing test** + +```python +# backend/tests/test_batch.py — add + +from backend.handlers.batch_handler import BatchHandler +from backend.state.job_queue import JobQueue + + +def test_batch_handler_expand_list(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + handler = BatchHandler() + + request = BatchSubmitRequest( + mode="list", + target="local", + jobs=[ + BatchJobItem(type="image", model="zit", params={"prompt": "a cat"}), + BatchJobItem(type="video", model="fast", params={"prompt": "a dog running"}), + ], + ) + response = handler.submit_batch(request, queue) + + assert response.total_jobs == 2 + assert len(response.job_ids) == 2 + + jobs = queue.jobs_for_batch(response.batch_id) + assert len(jobs) == 2 + assert jobs[0].type == "image" + assert jobs[0].slot == "gpu" + assert jobs[0].batch_index == 0 + assert jobs[0].tags == [f"batch:{response.batch_id}"] + assert jobs[1].type == "video" + assert jobs[1].batch_index == 1 + + +def test_batch_handler_list_cloud_target(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + handler = BatchHandler() + + request = BatchSubmitRequest( + mode="list", + target="cloud", + jobs=[BatchJobItem(type="image", model="zit", params={"prompt": "a cat"})], + ) + response = handler.submit_batch(request, queue) + job = queue.get(response.job_ids[0]) + assert job is not None + assert job.slot == "api" +``` + +**Step 2: Run test to verify it fails** + +Run: `cd backend && uv run pytest tests/test_batch.py::test_batch_handler_expand_list tests/test_batch.py::test_batch_handler_list_cloud_target -v --tb=short` +Expected: FAIL — module doesn't exist + +**Step 3: Implement BatchHandler with list mode** + +Create `backend/handlers/batch_handler.py`: + +```python +from __future__ import annotations + +import uuid + +from backend.api_types import ( + BatchJobItem, + BatchSubmitRequest, + BatchSubmitResponse, +) +from backend.state.job_queue import JobQueue + + +class BatchHandler: + def submit_batch(self, request: BatchSubmitRequest, queue: JobQueue) -> BatchSubmitResponse: + batch_id = uuid.uuid4().hex[:8] + slot = "api" if request.target == "cloud" else "gpu" + + if request.mode == "list": + jobs = self._expand_list(request.jobs or [], batch_id, slot) + elif request.mode == "sweep": + raise NotImplementedError("sweep mode not yet implemented") + elif request.mode == "pipeline": + raise NotImplementedError("pipeline mode not yet implemented") + else: + raise ValueError(f"Unknown batch mode: {request.mode}") + + job_ids: list[str] = [] + for job_def in jobs: + job = queue.submit( + job_type=job_def["type"], + model=job_def["model"], + params=job_def["params"], + slot=slot, + batch_id=batch_id, + batch_index=job_def["batch_index"], + tags=[f"batch:{batch_id}"], + ) + job_ids.append(job.id) + + return BatchSubmitResponse(batch_id=batch_id, job_ids=job_ids, total_jobs=len(job_ids)) + + def _expand_list( + self, items: list[BatchJobItem], batch_id: str, slot: str + ) -> list[dict[str, object]]: + return [ + { + "type": item.type, + "model": item.model, + "params": item.params, + "batch_index": i, + } + for i, item in enumerate(items) + ] +``` + +**Step 4: Run tests to verify they pass** + +Run: `cd backend && uv run pytest tests/test_batch.py -v --tb=short` +Expected: All PASS + +**Step 5: Commit** + +```bash +git add backend/handlers/batch_handler.py backend/tests/test_batch.py +git commit -m "feat: implement BatchHandler with list mode expansion" +``` + +--- + +## Task 5: Implement BatchHandler — Sweep Mode + +**Files:** +- Modify: `backend/handlers/batch_handler.py` +- Test: `backend/tests/test_batch.py` + +**Step 1: Write the failing tests** + +```python +def test_batch_handler_sweep_single_axis(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + handler = BatchHandler() + + request = BatchSubmitRequest( + mode="sweep", + target="local", + sweep=SweepDefinition( + base_type="image", + base_model="zit", + base_params={"prompt": "a cat", "width": 1024, "height": 1024}, + axes=[SweepAxis(param="loraWeight", values=[0.5, 0.75, 1.0])], + ), + ) + response = handler.submit_batch(request, queue) + + assert response.total_jobs == 3 + jobs = queue.jobs_for_batch(response.batch_id) + assert jobs[0].params["loraWeight"] == 0.5 + assert jobs[1].params["loraWeight"] == 0.75 + assert jobs[2].params["loraWeight"] == 1.0 + # All share base params + assert all(j.params["prompt"] == "a cat" for j in jobs) + assert "sweep:loraWeight" in jobs[0].tags + + +def test_batch_handler_sweep_two_axes_cartesian(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + handler = BatchHandler() + + request = BatchSubmitRequest( + mode="sweep", + target="local", + sweep=SweepDefinition( + base_type="image", + base_model="zit", + base_params={"prompt": "a cat", "width": 1024}, + axes=[ + SweepAxis(param="loraWeight", values=[0.5, 1.0]), + SweepAxis(param="numSteps", values=[4, 8]), + ], + ), + ) + response = handler.submit_batch(request, queue) + + assert response.total_jobs == 4 # 2 x 2 cartesian product + jobs = queue.jobs_for_batch(response.batch_id) + combos = [(j.params["loraWeight"], j.params["numSteps"]) for j in jobs] + assert (0.5, 4) in combos + assert (0.5, 8) in combos + assert (1.0, 4) in combos + assert (1.0, 8) in combos + + +def test_batch_handler_sweep_search_replace(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + handler = BatchHandler() + + request = BatchSubmitRequest( + mode="sweep", + target="local", + sweep=SweepDefinition( + base_type="image", + base_model="zit", + base_params={"prompt": "a cute cat in a garden"}, + axes=[ + SweepAxis( + param="prompt", + values=["cat", "dog", "horse"], + mode="search_replace", + search="cat", + ), + ], + ), + ) + response = handler.submit_batch(request, queue) + + assert response.total_jobs == 3 + jobs = queue.jobs_for_batch(response.batch_id) + assert jobs[0].params["prompt"] == "a cute cat in a garden" + assert jobs[1].params["prompt"] == "a cute dog in a garden" + assert jobs[2].params["prompt"] == "a cute horse in a garden" +``` + +**Step 2: Run tests to verify they fail** + +Run: `cd backend && uv run pytest tests/test_batch.py::test_batch_handler_sweep_single_axis tests/test_batch.py::test_batch_handler_sweep_two_axes_cartesian tests/test_batch.py::test_batch_handler_sweep_search_replace -v --tb=short` +Expected: FAIL — NotImplementedError + +**Step 3: Implement sweep expansion** + +Add to `BatchHandler`: + +```python +import itertools + +def _expand_sweep( + self, sweep: SweepDefinition, batch_id: str, slot: str +) -> list[dict[str, object]]: + # Build value lists per axis + axis_values: list[list[tuple[str, object]]] = [] + for axis in sweep.axes: + pairs: list[tuple[str, object]] = [] + for val in axis.values: + pairs.append((axis.param, val)) + axis_values.append(pairs) + + # Cartesian product + combos = list(itertools.product(*axis_values)) + results: list[dict[str, object]] = [] + for i, combo in enumerate(combos): + params = dict(sweep.base_params) + for param_name, value in combo: + axis_def = next(a for a in sweep.axes if a.param == param_name) + if axis_def.mode == "search_replace" and axis_def.search and param_name in params: + current = str(params[param_name]) + params[param_name] = current.replace(axis_def.search, str(value)) + else: + params[param_name] = value + results.append({ + "type": sweep.base_type, + "model": sweep.base_model, + "params": params, + "batch_index": i, + }) + + return results +``` + +Update `submit_batch()` to pass sweep tags and call `_expand_sweep`. Add sweep axis names to tags: `[f"batch:{batch_id}"] + [f"sweep:{a.param}" for a in sweep.axes]`. + +**Step 4: Run tests to verify they pass** + +Run: `cd backend && uv run pytest tests/test_batch.py -v --tb=short` +Expected: All PASS + +**Step 5: Commit** + +```bash +git add backend/handlers/batch_handler.py backend/tests/test_batch.py +git commit -m "feat: implement sweep mode with cartesian product and search/replace" +``` + +--- + +## Task 6: Implement BatchHandler — Pipeline Mode + +**Files:** +- Modify: `backend/handlers/batch_handler.py` +- Test: `backend/tests/test_batch.py` + +**Step 1: Write the failing tests** + +```python +def test_batch_handler_pipeline_creates_chained_jobs(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + handler = BatchHandler() + + request = BatchSubmitRequest( + mode="pipeline", + target="local", + pipeline=PipelineDefinition( + steps=[ + PipelineStep(type="image", model="zit", params={"prompt": "a landscape"}), + PipelineStep(type="video", model="fast", params={"duration": "4"}, auto_prompt=True), + ], + ), + ) + response = handler.submit_batch(request, queue) + + assert response.total_jobs == 2 + jobs = queue.jobs_for_batch(response.batch_id) + img_job = jobs[0] + vid_job = jobs[1] + + assert img_job.type == "image" + assert img_job.depends_on is None + assert vid_job.type == "video" + assert vid_job.depends_on == img_job.id + assert vid_job.auto_params == {"imagePath": "$dep.result_paths[0]", "auto_prompt": "true"} + + +def test_batch_handler_pipeline_three_steps(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + handler = BatchHandler() + + request = BatchSubmitRequest( + mode="pipeline", + target="local", + pipeline=PipelineDefinition( + steps=[ + PipelineStep(type="image", model="zit", params={"prompt": "frame 1"}), + PipelineStep(type="video", model="fast", params={}, auto_prompt=True), + PipelineStep(type="video", model="pro", params={}, auto_prompt=False), + ], + ), + ) + response = handler.submit_batch(request, queue) + + assert response.total_jobs == 3 + jobs = queue.jobs_for_batch(response.batch_id) + assert jobs[0].depends_on is None + assert jobs[1].depends_on == jobs[0].id + assert jobs[2].depends_on == jobs[1].id + assert jobs[2].auto_params == {"imagePath": "$dep.result_paths[0]"} # No auto_prompt +``` + +**Step 2: Run tests to verify they fail** + +Run: `cd backend && uv run pytest tests/test_batch.py::test_batch_handler_pipeline_creates_chained_jobs tests/test_batch.py::test_batch_handler_pipeline_three_steps -v --tb=short` +Expected: FAIL — NotImplementedError + +**Step 3: Implement pipeline expansion** + +Add to `BatchHandler`: + +```python +def _expand_pipeline( + self, pipeline: PipelineDefinition, batch_id: str, slot: str +) -> list[dict[str, object]]: + """Create chained jobs. Each step depends_on the previous step's job ID.""" + results: list[dict[str, object]] = [] + prev_job_id: str | None = None + + for i, step in enumerate(pipeline.steps): + job_id = uuid.uuid4().hex[:8] + auto_params: dict[str, str] = {} + if prev_job_id is not None: + auto_params["imagePath"] = "$dep.result_paths[0]" + if step.auto_prompt: + auto_params["auto_prompt"] = "true" + + results.append({ + "type": step.type, + "model": step.model, + "params": dict(step.params), + "batch_index": i, + "job_id": job_id, + "depends_on": prev_job_id, + "auto_params": auto_params, + }) + prev_job_id = job_id + + return results +``` + +Update `submit_batch()` to pass `depends_on`, `auto_params`, and pre-generated `job_id` through to `queue.submit()`. This requires `queue.submit()` to accept an optional `job_id` override (modify `JobQueue.submit()` to accept `job_id: str | None = None` and use it if provided instead of generating a new one). + +**Step 4: Run tests to verify they pass** + +Run: `cd backend && uv run pytest tests/test_batch.py -v --tb=short` +Expected: All PASS + +**Step 5: Commit** + +```bash +git add backend/handlers/batch_handler.py backend/state/job_queue.py backend/tests/test_batch.py +git commit -m "feat: implement pipeline mode with depends_on chaining" +``` + +--- + +## Task 7: QueueWorker Dependency Checking + +**Files:** +- Modify: `backend/handlers/queue_worker.py:34-62` (tick method) +- Test: `backend/tests/test_queue_worker.py` + +**Step 1: Write the failing tests** + +```python +# backend/tests/test_queue_worker.py — add + +def test_worker_skips_job_with_unresolved_dependency(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + parent = queue.submit(job_type="image", model="zit", params={"prompt": "a cat"}, slot="gpu") + child = queue.submit( + job_type="video", model="fast", params={}, slot="gpu", + depends_on=parent.id, + auto_params={"imagePath": "$dep.result_paths[0]"}, + ) + + executor = FakeExecutor(result_paths=[]) + worker = QueueWorker(queue=queue, gpu_executor=executor, api_executor=executor) + worker.tick() + + # Parent should be running, child still queued (dependency not met) + assert queue.get(parent.id).status == "running" + assert queue.get(child.id).status == "queued" + + +def test_worker_dispatches_dependent_job_after_parent_completes(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + parent = queue.submit(job_type="image", model="zit", params={"prompt": "a cat"}, slot="gpu") + child = queue.submit( + job_type="video", model="fast", params={}, slot="gpu", + depends_on=parent.id, + auto_params={"imagePath": "$dep.result_paths[0]"}, + ) + + # Simulate parent completing + queue.update_job(parent.id, status="complete", result_paths=["/out/cat.png"]) + + executor = FakeExecutor(result_paths=["/out/video.mp4"]) + worker = QueueWorker(queue=queue, gpu_executor=executor, api_executor=executor) + worker.tick() + + # Child should now be dispatched with resolved params + assert queue.get(child.id).status in ("running", "complete") + + +def test_worker_fails_dependent_job_when_parent_errors(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + parent = queue.submit(job_type="image", model="zit", params={"prompt": "a cat"}, slot="gpu") + child = queue.submit( + job_type="video", model="fast", params={}, slot="gpu", + depends_on=parent.id, + ) + + # Simulate parent failing + queue.update_job(parent.id, status="error", error="GPU OOM") + + executor = FakeExecutor(result_paths=[]) + worker = QueueWorker(queue=queue, gpu_executor=executor, api_executor=executor) + worker.tick() + + child_job = queue.get(child.id) + assert child_job.status == "error" + assert "Upstream job" in child_job.error +``` + +Note: You may need to create a `FakeExecutor` class in the test file or in `tests/fakes/` that implements the `JobExecutor` protocol. It should store result_paths and return them from `execute()`. + +**Step 2: Run tests to verify they fail** + +Run: `cd backend && uv run pytest tests/test_queue_worker.py::test_worker_skips_job_with_unresolved_dependency tests/test_queue_worker.py::test_worker_dispatches_dependent_job_after_parent_completes tests/test_queue_worker.py::test_worker_fails_dependent_job_when_parent_errors -v --tb=short` +Expected: FAIL + +**Step 3: Modify QueueWorker to check dependencies** + +In `backend/handlers/queue_worker.py`, modify the job selection logic in `tick()`. Replace direct `queue.next_queued_for_slot(slot)` with a new method that checks `depends_on`: + +```python +def _next_ready_job(self, slot: str) -> QueueJob | None: + for job in self._queue.queued_jobs_for_slot(slot): + if job.depends_on is None: + return job + dep = self._queue.get(job.depends_on) + if dep is None: + return job # Dependency missing, run anyway + if dep.status == "complete": + self._resolve_auto_params(job, dep) + return job + if dep.status in ("error", "cancelled"): + self._queue.update_job( + job.id, + status="error", + error=f"Upstream job {dep.id} failed: {dep.error or dep.status}", + ) + continue + # dep still queued/running — skip this job for now + continue + return None + +def _resolve_auto_params(self, job: QueueJob, dep: QueueJob) -> None: + for key, template in job.auto_params.items(): + if key == "auto_prompt": + continue # Handled by i2v prompt generation later + if template == "$dep.result_paths[0]" and dep.result_paths: + job.params[key] = dep.result_paths[0] +``` + +Note: `queued_jobs_for_slot(slot)` may need to be added to `JobQueue` — it returns all jobs with `status == "queued"` and matching `slot`, ordered by creation. + +**Step 4: Run tests to verify they pass** + +Run: `cd backend && uv run pytest tests/test_queue_worker.py -v --tb=short` +Expected: All PASS + +**Step 5: Run all backend tests** + +Run: `cd backend && uv run pytest -v --tb=short` +Expected: All PASS + +**Step 6: Commit** + +```bash +git add backend/handlers/queue_worker.py backend/state/job_queue.py backend/tests/test_queue_worker.py +git commit -m "feat: add dependency checking to QueueWorker dispatch" +``` + +--- + +## Task 8: Batch Completion Detection + +**Files:** +- Modify: `backend/handlers/queue_worker.py` +- Test: `backend/tests/test_queue_worker.py` + +**Step 1: Write the failing test** + +```python +def test_worker_detects_batch_completion(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + j1 = queue.submit(job_type="image", model="zit", params={}, slot="gpu", batch_id="b1", batch_index=0) + j2 = queue.submit(job_type="image", model="zit", params={}, slot="gpu", batch_id="b1", batch_index=1) + + completed_batches: list[str] = [] + def on_batch_complete(batch_id: str, jobs: list[QueueJob]) -> None: + completed_batches.append(batch_id) + + executor = FakeExecutor(result_paths=["/out.png"]) + worker = QueueWorker(queue=queue, gpu_executor=executor, api_executor=executor, on_batch_complete=on_batch_complete) + + # Complete both jobs + queue.update_job(j1.id, status="complete", result_paths=["/out/1.png"]) + queue.update_job(j2.id, status="complete", result_paths=["/out/2.png"]) + + worker.tick() + + assert completed_batches == ["b1"] + + # Second tick should NOT re-notify + worker.tick() + assert completed_batches == ["b1"] +``` + +**Step 2: Run test to verify it fails** + +Run: `cd backend && uv run pytest tests/test_queue_worker.py::test_worker_detects_batch_completion -v --tb=short` +Expected: FAIL — on_batch_complete param doesn't exist + +**Step 3: Add batch completion checking to QueueWorker** + +Add `on_batch_complete` callback to `__init__`, track `_notified_batches: set[str]`, and call `_check_batch_completions()` at end of `tick()`: + +```python +def _check_batch_completions(self) -> None: + for batch_id in self._queue.active_batch_ids(): + # active_batch_ids only returns batches with queued/running jobs, so skip + pass + # Instead, check ALL batch ids that have jobs + seen: set[str] = set() + for job in self._queue.all_jobs(): + if job.batch_id and job.batch_id not in self._notified_batches: + seen.add(job.batch_id) + for batch_id in seen: + jobs = self._queue.jobs_for_batch(batch_id) + if all(j.status in ("complete", "error", "cancelled") for j in jobs): + self._notified_batches.add(batch_id) + if self._on_batch_complete: + self._on_batch_complete(batch_id, jobs) +``` + +Note: Need to add `all_jobs()` method to JobQueue if not present (returns `list(self._jobs)`). + +**Step 4: Run tests to verify they pass** + +Run: `cd backend && uv run pytest tests/test_queue_worker.py -v --tb=short` +Expected: All PASS + +**Step 5: Commit** + +```bash +git add backend/handlers/queue_worker.py backend/state/job_queue.py backend/tests/test_queue_worker.py +git commit -m "feat: add batch completion detection to QueueWorker" +``` + +--- + +## Task 9: Batch Status & Report Endpoint + +**Files:** +- Modify: `backend/handlers/batch_handler.py` +- Create: `backend/_routes/batch.py` +- Modify: `backend/app_factory.py:79-96` (router registration) +- Modify: `backend/app_handler.py:247-278` (add BatchHandler) +- Test: `backend/tests/test_batch.py` + +**Step 1: Write the failing integration test** + +```python +# backend/tests/test_batch.py — add integration tests using TestClient + +import pytest +from starlette.testclient import TestClient + + +def test_batch_submit_and_status_integration(client: TestClient) -> None: + # Submit a batch + resp = client.post("/api/queue/submit-batch", json={ + "mode": "list", + "target": "local", + "jobs": [ + {"type": "image", "model": "zit", "params": {"prompt": "cat"}}, + {"type": "image", "model": "zit", "params": {"prompt": "dog"}}, + ], + }) + assert resp.status_code == 200 + data = resp.json() + batch_id = data["batch_id"] + assert data["total_jobs"] == 2 + + # Check batch status + resp = client.get(f"/api/queue/batch/{batch_id}/status") + assert resp.status_code == 200 + status = resp.json() + assert status["batch_id"] == batch_id + assert status["total"] == 2 + assert status["queued"] == 2 + assert status["report"] is None # Not yet complete + + +def test_batch_cancel_integration(client: TestClient) -> None: + resp = client.post("/api/queue/submit-batch", json={ + "mode": "list", + "target": "local", + "jobs": [ + {"type": "image", "model": "zit", "params": {"prompt": "cat"}}, + {"type": "image", "model": "zit", "params": {"prompt": "dog"}}, + ], + }) + batch_id = resp.json()["batch_id"] + + resp = client.post(f"/api/queue/batch/{batch_id}/cancel") + assert resp.status_code == 200 + + resp = client.get(f"/api/queue/batch/{batch_id}/status") + status = resp.json() + assert status["cancelled"] == 2 + + +def test_batch_retry_failed_integration(client: TestClient) -> None: + resp = client.post("/api/queue/submit-batch", json={ + "mode": "list", + "target": "local", + "jobs": [ + {"type": "image", "model": "zit", "params": {"prompt": "cat"}}, + ], + }) + batch_id = resp.json()["batch_id"] + job_id = resp.json()["job_ids"][0] + + # Simulate failure via direct queue manipulation (in real test, use handler) + # This needs the test_state fixture to access the queue + # For now, test the route exists and returns 200 + resp = client.post(f"/api/queue/batch/{batch_id}/retry-failed") + assert resp.status_code == 200 +``` + +Note: These tests require the `client` fixture from `conftest.py`. You may need to add `batch_handler` to `AppHandler.__init__()` and register the batch router in `app_factory.py`. + +**Step 2: Run tests to verify they fail** + +Run: `cd backend && uv run pytest tests/test_batch.py::test_batch_submit_and_status_integration -v --tb=short` +Expected: FAIL — route doesn't exist + +**Step 3: Add batch_status and batch report methods to BatchHandler** + +```python +def get_batch_status(self, batch_id: str, queue: JobQueue) -> BatchStatusResponse: + jobs = queue.jobs_for_batch(batch_id) + if not jobs: + raise HTTPError(status_code=404, detail=f"Batch {batch_id} not found") + + completed = sum(1 for j in jobs if j.status == "complete") + failed = sum(1 for j in jobs if j.status == "error") + running = sum(1 for j in jobs if j.status == "running") + cancelled = sum(1 for j in jobs if j.status == "cancelled") + queued = sum(1 for j in jobs if j.status == "queued") + + report = None + if queued == 0 and running == 0: + report = self._build_report(batch_id, jobs) + + return BatchStatusResponse( + batch_id=batch_id, + total=len(jobs), + completed=completed, + failed=failed, + running=running, + queued=queued, + jobs=[self._job_to_response(j) for j in jobs], + report=report, + ) + +def cancel_batch(self, batch_id: str, queue: JobQueue) -> None: + for job in queue.jobs_for_batch(batch_id): + if job.status == "queued": + queue.update_job(job.id, status="cancelled") + +def retry_failed(self, batch_id: str, queue: JobQueue) -> BatchSubmitResponse: + failed_jobs = [j for j in queue.jobs_for_batch(batch_id) if j.status == "error"] + new_ids: list[str] = [] + for job in failed_jobs: + new_job = queue.submit( + job_type=job.type, model=job.model, params=job.params, + slot=job.slot, batch_id=batch_id, batch_index=job.batch_index, + tags=job.tags, + ) + new_ids.append(new_job.id) + return BatchSubmitResponse(batch_id=batch_id, job_ids=new_ids, total_jobs=len(new_ids)) +``` + +**Step 4: Create batch routes** + +Create `backend/_routes/batch.py`: + +```python +from fastapi import APIRouter, Depends +from backend.api_types import BatchSubmitRequest, BatchSubmitResponse, BatchStatusResponse +from backend.state.deps import get_state_service + +batch_router = APIRouter(prefix="/api/queue", tags=["batch"]) + +@batch_router.post("/submit-batch", response_model=BatchSubmitResponse) +def submit_batch(request: BatchSubmitRequest, handler=Depends(get_state_service)): + return handler.batch.submit_batch(request, handler.job_queue) + +@batch_router.get("/batch/{batch_id}/status", response_model=BatchStatusResponse) +def batch_status(batch_id: str, handler=Depends(get_state_service)): + return handler.batch.get_batch_status(batch_id, handler.job_queue) + +@batch_router.post("/batch/{batch_id}/cancel") +def batch_cancel(batch_id: str, handler=Depends(get_state_service)): + handler.batch.cancel_batch(batch_id, handler.job_queue) + return {"status": "cancelled"} + +@batch_router.post("/batch/{batch_id}/retry-failed", response_model=BatchSubmitResponse) +def batch_retry(batch_id: str, handler=Depends(get_state_service)): + return handler.batch.retry_failed(batch_id, handler.job_queue) +``` + +**Step 5: Wire BatchHandler into AppHandler** + +In `backend/app_handler.py`, add `from backend.handlers.batch_handler import BatchHandler` and `self.batch = BatchHandler()` after the other handler instantiations (around line 267). + +**Step 6: Register batch router in app_factory.py** + +In `backend/app_factory.py`, add `from backend._routes.batch import batch_router` and `app.include_router(batch_router)` alongside the other router registrations (around line 89). + +**Step 7: Run tests to verify they pass** + +Run: `cd backend && uv run pytest tests/test_batch.py -v --tb=short` +Expected: All PASS + +**Step 8: Run all backend tests** + +Run: `cd backend && uv run pytest -v --tb=short` +Expected: All PASS + +**Step 9: Commit** + +```bash +git add backend/_routes/batch.py backend/handlers/batch_handler.py backend/app_handler.py backend/app_factory.py backend/tests/test_batch.py +git commit -m "feat: add batch routes — submit, status, cancel, retry-failed" +``` + +--- + +## Task 10: I2V Motion Prompt Generation + +**Files:** +- Modify: `backend/handlers/enhance_prompt_handler.py:102-141` +- Test: `backend/tests/test_enhance_prompt.py` + +**Step 1: Write the failing test** + +```python +# backend/tests/test_enhance_prompt.py — add + +def test_enhance_prompt_i2v_motion_mode(client: TestClient, test_state: AppHandler) -> None: + test_state.app_state.app_settings.gemini_api_key = "test-key" + + # Queue a fake Gemini response for image caption + motion prompt + fake_http = test_state.services.http_client # FakeHTTPClient + # First call: image caption response + fake_http.queue_post_response(FakeResponse( + status_code=200, + json_payload={ + "candidates": [{"content": {"parts": [{"text": "A serene mountain landscape at golden hour with a lake in the foreground."}]}}] + }, + )) + # Second call: motion prompt generation + fake_http.queue_post_response(FakeResponse( + status_code=200, + json_payload={ + "candidates": [{"content": {"parts": [{"text": "The camera slowly pans across the mountain range as golden light shifts. Gentle ripples spread across the lake surface while distant birds glide overhead."}]}}] + }, + )) + + resp = client.post("/api/prompt/enhance", json={ + "prompt": "", + "level": "i2v_motion", + "image_path": "/path/to/landscape.png", + }) + assert resp.status_code == 200 + result = resp.json() + assert "camera" in result["enhanced_prompt"].lower() or "pan" in result["enhanced_prompt"].lower() +``` + +Note: Adjust based on actual enhance prompt route path and response schema. Check `backend/_routes/` for the enhance prompt route. + +**Step 2: Run test to verify it fails** + +Run: `cd backend && uv run pytest tests/test_enhance_prompt.py::test_enhance_prompt_i2v_motion_mode -v --tb=short` +Expected: FAIL — `i2v_motion` level not recognized + +**Step 3: Add i2v_motion mode to EnhancePromptHandler** + +In `backend/handlers/enhance_prompt_handler.py`, add a new branch in the enhance method for `level == "i2v_motion"`. Add a dedicated method `_enhance_i2v_motion()`: + +```python +I2V_MOTION_SYSTEM_PROMPT = """You are an expert cinematographer. Given a description of a still image, generate a video motion prompt that describes what happens next. + +Rules: +- Do NOT describe what the image already shows — describe what HAPPENS NEXT +- Structure: subject motion → camera movement → environmental dynamics +- Use specific cinematography terms: dolly, pan, tilt, orbit, truck, rack focus, crane +- Add motion intensity qualifiers: subtle, steady, dramatic, gentle +- Animate empty regions (sky, water, foliage) to prevent frozen areas +- Use present tense, single flowing paragraph, 4-6 sentences +- Focus on plausible, physics-grounded motion""" + +async def _enhance_i2v_motion(self, image_path: str) -> str: + # Step 1: Caption the image + caption = await self._caption_image(image_path) + # Step 2: Generate motion prompt from caption + motion_prompt = await self._generate_motion_prompt(caption) + return motion_prompt +``` + +The implementation depends on the existing Gemini calling pattern (lines 102-141). Follow the same HTTP client pattern but with the i2v system prompt. + +**Step 4: Update EnhancePromptRequest to accept image_path** + +In `api_types.py`, add `image_path: str | None = None` to the enhance prompt request model. + +**Step 5: Run tests to verify they pass** + +Run: `cd backend && uv run pytest tests/test_enhance_prompt.py -v --tb=short` +Expected: All PASS + +**Step 6: Commit** + +```bash +git add backend/handlers/enhance_prompt_handler.py backend/api_types.py backend/tests/test_enhance_prompt.py +git commit -m "feat: add i2v_motion prompt generation mode" +``` + +--- + +## Task 11: Wire I2V Auto-Prompt into QueueWorker + +**Files:** +- Modify: `backend/handlers/queue_worker.py` +- Test: `backend/tests/test_queue_worker.py` + +**Step 1: Write the failing test** + +```python +def test_worker_generates_i2v_prompt_for_auto_prompt_job(tmp_path: Path) -> None: + queue = JobQueue(persistence_path=tmp_path / "queue.json") + + # Parent image job already complete + parent = queue.submit(job_type="image", model="zit", params={"prompt": "a landscape"}, slot="gpu", batch_id="b1") + queue.update_job(parent.id, status="complete", result_paths=["/out/landscape.png"]) + + # Child video job with auto_prompt + child = queue.submit( + job_type="video", model="fast", params={"duration": "4"}, slot="gpu", + batch_id="b1", depends_on=parent.id, + auto_params={"imagePath": "$dep.result_paths[0]", "auto_prompt": "true"}, + ) + + # Create a fake enhance handler that returns a canned motion prompt + fake_enhance = FakeEnhanceHandler(result="The camera pans across mountains.") + executor = FakeExecutor(result_paths=["/out/video.mp4"]) + worker = QueueWorker( + queue=queue, gpu_executor=executor, api_executor=executor, + enhance_handler=fake_enhance, + ) + worker.tick() + + # Verify the child job got the auto-generated prompt + child_job = queue.get(child.id) + assert child_job.params.get("prompt") == "The camera pans across mountains." + assert child_job.params.get("imagePath") == "/out/landscape.png" +``` + +Note: `FakeEnhanceHandler` is a simple class with an `enhance_i2v_motion(image_path)` method returning a canned string. + +**Step 2: Run test to verify it fails** + +Run: `cd backend && uv run pytest tests/test_queue_worker.py::test_worker_generates_i2v_prompt_for_auto_prompt_job -v --tb=short` +Expected: FAIL + +**Step 3: Add auto-prompt resolution to QueueWorker** + +In `_resolve_auto_params()`, add handling for `auto_prompt`: + +```python +def _resolve_auto_params(self, job: QueueJob, dep: QueueJob) -> None: + for key, template in list(job.auto_params.items()): + if template == "$dep.result_paths[0]" and dep.result_paths: + job.params[key] = dep.result_paths[0] + + if job.auto_params.get("auto_prompt") == "true" and self._enhance_handler: + image_path = job.params.get("imagePath", dep.result_paths[0] if dep.result_paths else "") + if image_path: + motion_prompt = self._enhance_handler.enhance_i2v_motion(str(image_path)) + job.params["prompt"] = motion_prompt +``` + +Add `enhance_handler` as an optional `__init__` parameter. + +**Step 4: Run tests to verify they pass** + +Run: `cd backend && uv run pytest tests/test_queue_worker.py -v --tb=short` +Expected: All PASS + +**Step 5: Commit** + +```bash +git add backend/handlers/queue_worker.py backend/tests/test_queue_worker.py +git commit -m "feat: wire i2v auto-prompt generation into QueueWorker" +``` + +--- + +## Task 12: Add batchSoundEnabled Setting + +**Files:** +- Modify: `backend/state/app_settings.py:62-91` +- Test: `backend/tests/test_settings.py` + +**Step 1: Write the failing test** + +```python +def test_batch_sound_enabled_default(client: TestClient) -> None: + resp = client.get("/api/settings") + assert resp.status_code == 200 + data = resp.json() + assert data["batchSoundEnabled"] is True # Default on +``` + +**Step 2: Run test to verify it fails** + +Run: `cd backend && uv run pytest tests/test_settings.py::test_batch_sound_enabled_default -v --tb=short` +Expected: FAIL — field doesn't exist + +**Step 3: Add field to AppSettings** + +In `backend/state/app_settings.py`, add to `AppSettings` class: + +```python +batch_sound_enabled: bool = True +``` + +Add corresponding field to `SettingsResponse`: + +```python +batch_sound_enabled: bool +``` + +Update `to_settings_response()` to include it. + +**Step 4: Run tests to verify they pass** + +Run: `cd backend && uv run pytest tests/test_settings.py -v --tb=short` +Expected: All PASS + +**Step 5: Commit** + +```bash +git add backend/state/app_settings.py backend/tests/test_settings.py +git commit -m "feat: add batchSoundEnabled setting (default true)" +``` + +--- + +## Task 13: Extend QueueJobResponse with Batch Fields + +**Files:** +- Modify: `backend/api_types.py:240-251` +- Modify: `backend/_routes/queue.py:29-41` (status mapping) +- Test: `backend/tests/test_batch.py` + +**Step 1: Write the failing test** + +```python +def test_queue_status_includes_batch_fields(client: TestClient) -> None: + # Submit a batch + resp = client.post("/api/queue/submit-batch", json={ + "mode": "list", + "target": "local", + "jobs": [{"type": "image", "model": "zit", "params": {"prompt": "cat"}}], + }) + batch_id = resp.json()["batch_id"] + + # Check regular queue status + resp = client.get("/api/queue/status") + data = resp.json() + job = data["jobs"][0] + assert job["batch_id"] == batch_id + assert job["batch_index"] == 0 + assert "batch:" in job["tags"][0] +``` + +**Step 2: Run test to verify it fails** + +Run: `cd backend && uv run pytest tests/test_batch.py::test_queue_status_includes_batch_fields -v --tb=short` +Expected: FAIL — fields not in response + +**Step 3: Add batch fields to QueueJobResponse** + +In `backend/api_types.py`, extend `QueueJobResponse`: + +```python +class QueueJobResponse(BaseModel): + id: str + type: str + model: str + params: dict[str, object] = {} + status: str + slot: str + progress: int + phase: str + result_paths: list[str] = [] + error: str | None = None + created_at: str = "" + # Batch fields + batch_id: str | None = None + batch_index: int = 0 + tags: list[str] = [] +``` + +Update the queue status route mapping in `backend/_routes/queue.py` to include the new fields when converting `QueueJob` to `QueueJobResponse`. + +**Step 4: Run tests to verify they pass** + +Run: `cd backend && uv run pytest tests/test_batch.py -v --tb=short` +Expected: All PASS + +**Step 5: Commit** + +```bash +git add backend/api_types.py backend/_routes/queue.py backend/tests/test_batch.py +git commit -m "feat: include batch fields in QueueJobResponse" +``` + +--- + +## Task 14: Frontend — Batch Types and API Client + +**Files:** +- Modify: `frontend/hooks/use-generation.ts:5-17` (QueueJob interface) +- Create: `frontend/lib/batch-api.ts` +- Create: `frontend/types/batch.ts` + +**Step 1: Add batch types** + +Create `frontend/types/batch.ts`: + +```typescript +export interface BatchJobItem { + type: 'video' | 'image' + model: string + params: Record +} + +export interface SweepAxis { + param: string + values: unknown[] + mode: 'replace' | 'search_replace' + search?: string +} + +export interface SweepDefinition { + base_type: 'video' | 'image' + base_model: string + base_params: Record + axes: SweepAxis[] +} + +export interface PipelineStep { + type: 'video' | 'image' + model: string + params: Record + auto_prompt: boolean +} + +export interface PipelineDefinition { + steps: PipelineStep[] +} + +export interface BatchSubmitRequest { + mode: 'list' | 'sweep' | 'pipeline' + target: 'local' | 'cloud' + jobs?: BatchJobItem[] + sweep?: SweepDefinition + pipeline?: PipelineDefinition +} + +export interface BatchSubmitResponse { + batch_id: string + job_ids: string[] + total_jobs: number +} + +export interface BatchReport { + batch_id: string + total: number + succeeded: number + failed: number + cancelled: number + duration_seconds: number + avg_job_seconds: number + result_paths: string[] + failed_indices: number[] + sweep_axes: string[] | null +} + +export interface BatchStatusResponse { + batch_id: string + total: number + completed: number + failed: number + running: number + queued: number + jobs: QueueJob[] + report: BatchReport | null +} +``` + +**Step 2: Create batch API client** + +Create `frontend/lib/batch-api.ts`: + +```typescript +import type { BatchSubmitRequest, BatchSubmitResponse, BatchStatusResponse } from '@/types/batch' + +const getBaseUrl = async (): Promise => { + if (window.electronAPI) { + return await window.electronAPI.getBackendUrl() + } + return 'http://localhost:8000' +} + +export async function submitBatch(request: BatchSubmitRequest): Promise { + const base = await getBaseUrl() + const resp = await fetch(`${base}/api/queue/submit-batch`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify(request), + }) + if (!resp.ok) throw new Error(`Batch submit failed: ${resp.status}`) + return resp.json() +} + +export async function getBatchStatus(batchId: string): Promise { + const base = await getBaseUrl() + const resp = await fetch(`${base}/api/queue/batch/${batchId}/status`) + if (!resp.ok) throw new Error(`Batch status failed: ${resp.status}`) + return resp.json() +} + +export async function cancelBatch(batchId: string): Promise { + const base = await getBaseUrl() + await fetch(`${base}/api/queue/batch/${batchId}/cancel`, { method: 'POST' }) +} + +export async function retryFailedBatch(batchId: string): Promise { + const base = await getBaseUrl() + const resp = await fetch(`${base}/api/queue/batch/${batchId}/retry-failed`, { method: 'POST' }) + if (!resp.ok) throw new Error(`Batch retry failed: ${resp.status}`) + return resp.json() +} +``` + +**Step 3: Update QueueJob interface in use-generation.ts** + +Add batch fields to the existing `QueueJob` interface: + +```typescript +export interface QueueJob { + // ... existing fields + batch_id: string | null + batch_index: number + tags: string[] +} +``` + +**Step 4: Commit** + +```bash +git add frontend/types/batch.ts frontend/lib/batch-api.ts frontend/hooks/use-generation.ts +git commit -m "feat: add frontend batch types and API client" +``` + +--- + +## Task 15: Frontend — useBatch Hook + +**Files:** +- Create: `frontend/hooks/use-batch.ts` + +**Step 1: Create the hook** + +```typescript +import { useState, useRef, useCallback, useEffect } from 'react' +import type { BatchSubmitRequest, BatchStatusResponse, BatchReport } from '@/types/batch' +import { submitBatch, getBatchStatus, cancelBatch, retryFailedBatch } from '@/lib/batch-api' + +export interface UseBatchReturn { + activeBatchId: string | null + batchStatus: BatchStatusResponse | null + batchReport: BatchReport | null + isRunning: boolean + submit: (request: BatchSubmitRequest) => Promise + cancel: () => Promise + retryFailed: () => Promise + reset: () => void +} + +export function useBatch(): UseBatchReturn { + const [activeBatchId, setActiveBatchId] = useState(null) + const [batchStatus, setBatchStatus] = useState(null) + const [batchReport, setBatchReport] = useState(null) + const pollRef = useRef | null>(null) + + const stopPolling = useCallback(() => { + if (pollRef.current) { + clearInterval(pollRef.current) + pollRef.current = null + } + }, []) + + const startPolling = useCallback((batchId: string) => { + stopPolling() + pollRef.current = setInterval(async () => { + try { + const status = await getBatchStatus(batchId) + setBatchStatus(status) + if (status.report) { + setBatchReport(status.report) + stopPolling() + // Play completion sound + playCompletionSound() + } + } catch { + // Ignore polling errors + } + }, 1000) // Poll every 1s for batches (less aggressive than single job) + }, [stopPolling]) + + const submit = useCallback(async (request: BatchSubmitRequest) => { + const response = await submitBatch(request) + setActiveBatchId(response.batch_id) + setBatchReport(null) + startPolling(response.batch_id) + }, [startPolling]) + + const cancel = useCallback(async () => { + if (activeBatchId) { + await cancelBatch(activeBatchId) + } + }, [activeBatchId]) + + const retryFailed = useCallback(async () => { + if (activeBatchId) { + await retryFailedBatch(activeBatchId) + startPolling(activeBatchId) + } + }, [activeBatchId, startPolling]) + + const reset = useCallback(() => { + stopPolling() + setActiveBatchId(null) + setBatchStatus(null) + setBatchReport(null) + }, [stopPolling]) + + useEffect(() => stopPolling, [stopPolling]) + + const isRunning = batchStatus !== null && batchStatus.report === null + + return { activeBatchId, batchStatus, batchReport, isRunning, submit, cancel, retryFailed, reset } +} + +function playCompletionSound(): void { + try { + const audio = new Audio('/sounds/batch-complete.mp3') + audio.volume = 0.5 + audio.play().catch(() => {}) // Ignore autoplay restrictions + } catch { + // Sound not critical + } +} +``` + +**Step 2: Commit** + +```bash +git add frontend/hooks/use-batch.ts +git commit -m "feat: add useBatch hook with polling and completion sound" +``` + +--- + +## Task 16: Frontend — Batch Builder Modal (List Tab) + +**Files:** +- Create: `frontend/components/BatchBuilderModal.tsx` +- Modify: `frontend/views/GenSpace.tsx` (add batch builder button) + +**Step 1: Create BatchBuilderModal with List tab** + +Create `frontend/components/BatchBuilderModal.tsx` with: +- Modal overlay with three tabs: List, Import, Grid +- List tab: table with prompt, type, model, LoRA fields +- Add Row / Delete Row / Duplicate Row buttons +- Per-batch target selector (Local / Cloud) +- Pipeline toggle: "Also generate video from each image" +- Submit button that calls `useBatch().submit()` + +Keep the component focused — the List tab only. Import and Grid tabs render "Coming soon" placeholders for now (implemented in Tasks 17-18). + +Follow the existing modal patterns in the codebase (check `frontend/components/SettingsModal.tsx` for the pattern). + +Use Tailwind classes matching the app's existing dark theme. Refer to the global design standards in the CLAUDE.md for OKLCH colors and component patterns. + +**Step 2: Add "Batch" button to GenSpace** + +In `frontend/views/GenSpace.tsx`, add a button near the generate button that opens the BatchBuilderModal. Icon: grid/layers icon from Lucide. + +**Step 3: Commit** + +```bash +git add frontend/components/BatchBuilderModal.tsx frontend/views/GenSpace.tsx +git commit -m "feat: add batch builder modal with list tab" +``` + +--- + +## Task 17: Frontend — Batch Builder Import Tab + +**Files:** +- Modify: `frontend/components/BatchBuilderModal.tsx` +- Create: `frontend/lib/batch-import.ts` + +**Step 1: Create CSV/JSON parser** + +Create `frontend/lib/batch-import.ts`: + +```typescript +import type { BatchJobItem } from '@/types/batch' + +export function parseCSV(text: string): BatchJobItem[] { + const lines = text.trim().split('\n') + if (lines.length < 2) return [] + const headers = lines[0].split(',').map(h => h.trim().toLowerCase()) + const promptIdx = headers.indexOf('prompt') + if (promptIdx === -1) throw new Error('CSV must have a "prompt" column') + + return lines.slice(1).map(line => { + const cols = parseCSVLine(line) + const params: Record = {} + headers.forEach((h, i) => { + if (h !== 'type' && h !== 'model' && cols[i]?.trim()) { + params[h] = inferType(cols[i].trim()) + } + }) + return { + type: (cols[headers.indexOf('type')]?.trim() as 'video' | 'image') || 'image', + model: cols[headers.indexOf('model')]?.trim() || 'zit', + params, + } + }) +} + +export function parseJSON(text: string): BatchJobItem[] { + const data = JSON.parse(text) + const defaults = data.defaults || {} + return (data.jobs || []).map((job: Record) => ({ + type: job.type || defaults.type || 'image', + model: job.model || defaults.model || 'zit', + params: { ...defaults, ...job.params, prompt: job.prompt || '' }, + })) +} +``` + +Include a `parseCSVLine` helper that handles quoted fields with commas. Include an `inferType` helper that converts numeric strings to numbers. + +**Step 2: Wire into Import tab** + +Add textarea for pasting, file upload button, preview table, validation error display. + +**Step 3: Commit** + +```bash +git add frontend/lib/batch-import.ts frontend/components/BatchBuilderModal.tsx +git commit -m "feat: add CSV/JSON import tab to batch builder" +``` + +--- + +## Task 18: Frontend — Batch Builder Grid Tab (Sweeps) + +**Files:** +- Modify: `frontend/components/BatchBuilderModal.tsx` + +**Step 1: Build Grid tab UI** + +Add to the Grid tab: +- Base prompt + settings section (inherits from current GenSpace settings) +- Up to 3 axis rows, each with: + - Param selector dropdown (loraWeight, loraPath, prompt, numSteps, seed, cameraMotion, model) + - Values input field (comma-separated, or range syntax `start-end:count`) + - Remove axis button +- "Add Axis" button +- Live preview: "{X} x {Y} x {Z} = {total} jobs" with estimated time +- Range parser: `0.3-1.0:8` → `[0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]` + +**Step 2: Wire to sweep mode submission** + +When user clicks Generate, build a `BatchSubmitRequest` with `mode: "sweep"` and the configured axes. Call `useBatch().submit()`. + +**Step 3: Commit** + +```bash +git add frontend/components/BatchBuilderModal.tsx +git commit -m "feat: add grid sweep builder tab to batch builder" +``` + +--- + +## Task 19: Frontend — Batch Queue Panel + +**Files:** +- Modify: Wherever the queue/progress panel is rendered (likely in `GenSpace.tsx` or a queue component) + +**Step 1: Add batch grouping to queue display** + +When rendering the jobs list from `useGeneration().jobs`, group jobs by `batch_id`: +- Non-batch jobs render as today +- Batch jobs render under a collapsible header: "Batch {batch_id}: {completed}/{total} complete" +- Header shows aggregate progress bar +- Expand to see individual job rows +- Per-batch actions: Cancel Batch, Retry Failed (visible when batch has errors) + +**Step 2: Add completion toast** + +When `useBatch().batchReport` becomes non-null, show a toast notification: +- "Batch complete — {succeeded}/{total} succeeded ({failed} failed) in {duration}" +- "View Results" button that navigates to Gallery with `?batch={batch_id}` filter + +**Step 3: Commit** + +```bash +git add frontend/components/ frontend/views/GenSpace.tsx +git commit -m "feat: add batch grouping and completion toast to queue panel" +``` + +--- + +## Task 20: Frontend — Gallery Batch Filtering + +**Files:** +- Modify: `frontend/views/Gallery.tsx` + +**Step 1: Add batch tag filter** + +Read `tags` from gallery items (already available via queue job results). Add a filter dropdown or pill selector that shows available batch tags. When selected, filter gallery items to those matching the tag. + +Support URL param `?batch={batch_id}` for direct linking from completion toast. + +**Step 2: Commit** + +```bash +git add frontend/views/Gallery.tsx +git commit -m "feat: add batch tag filtering to gallery" +``` + +--- + +## Task 21: Add Completion Sound Asset + +**Files:** +- Create: `public/sounds/batch-complete.mp3` + +**Step 1: Source or generate a short completion chime** + +Use a royalty-free chime sound (~1 second, pleasant, not jarring). Place at `public/sounds/batch-complete.mp3`. + +If generating: use a simple ascending two-note chime (C5 → E5, ~0.8s, sine wave with decay). + +**Step 2: Commit** + +```bash +git add public/sounds/batch-complete.mp3 +git commit -m "feat: add batch completion sound asset" +``` + +--- + +## Task 22: Typecheck and Full Test Pass + +**Files:** All modified files + +**Step 1: Run Python typecheck** + +Run: `cd backend && uv run pyright` +Expected: No new errors. Fix any type issues introduced. + +**Step 2: Run all backend tests** + +Run: `cd backend && uv sync --frozen --extra test --extra dev && uv run pytest -v --tb=short` +Expected: All PASS + +**Step 3: Run TypeScript typecheck** + +Run: `pnpm typecheck:ts` +Expected: No errors. Fix any TS issues. + +**Step 4: Run frontend build** + +Run: `pnpm build:frontend` +Expected: Build succeeds. + +**Step 5: Fix any issues found, commit** + +```bash +git add -A +git commit -m "chore: fix typecheck and test issues from bulk generation" +``` + +--- + +## Summary + +| Task | Component | Commits | +|------|-----------|---------| +| 1-2 | QueueJob batch fields + helpers | 2 | +| 3 | Batch API types | 1 | +| 4-6 | BatchHandler (list, sweep, pipeline) | 3 | +| 7-8 | QueueWorker dependency + completion | 2 | +| 9 | Batch routes + wiring | 1 | +| 10-11 | I2V auto-prompt generation | 2 | +| 12 | batchSoundEnabled setting | 1 | +| 13 | QueueJobResponse batch fields | 1 | +| 14-15 | Frontend types, API client, hook | 2 | +| 16-18 | Batch builder modal (3 tabs) | 3 | +| 19 | Queue panel batch grouping | 1 | +| 20 | Gallery batch filtering | 1 | +| 21 | Sound asset | 1 | +| 22 | Typecheck + test pass | 1 | +| **Total** | | **22 commits** | diff --git a/docs/plans/2026-03-09-gpu-optimizations-r2-storage.md b/docs/plans/2026-03-09-gpu-optimizations-r2-storage.md new file mode 100644 index 00000000..95b024cb --- /dev/null +++ b/docs/plans/2026-03-09-gpu-optimizations-r2-storage.md @@ -0,0 +1,957 @@ +# GPU Optimizations + R2 Storage Implementation Plan + +> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. + +**Goal:** Dramatically improve GPU generation speed and VRAM efficiency, add R2 cloud upload, commit+push pending security work. + +**Architecture:** Add a `gpu_optimizations` service module that monkey-patches the loaded LTX transformer at pipeline creation time. FFN chunking reduces peak VRAM by splitting feedforward along the sequence dimension. TeaCache wraps the denoising function to skip redundant transformer passes. R2 upload uses boto3 S3-compatible client post-generation. All optimizations are toggleable via AppSettings. + +**Tech Stack:** PyTorch, ltx_core (monkey-patching), boto3 (R2/S3), numpy (TeaCache polynomial) + +--- + +### Task 1: Commit + push security fixes and README + +These changes already exist in the working tree from a previous session. + +**Step 1: Review staged changes** + +Run: `cd D:/git/directors-desktop && git status` + +**Step 2: Commit security fixes** + +Run: +```bash +git add backend/services/palette_sync_client/palette_sync_client_impl.py electron/main.ts electron/ipc/file-handlers.ts README.md +git commit -m "fix: externalize Supabase credentials, redact auth tokens in logs, update README with new features" +``` + +**Step 3: Push** + +Run: `git push origin main` + +--- + +### Task 2: Add FFN Chunked Feedforward optimization + +Reduces peak VRAM by 8x in feedforward layers. Mathematically identical output. + +**Files:** +- Create: `backend/services/gpu_optimizations/__init__.py` +- Create: `backend/services/gpu_optimizations/ffn_chunking.py` +- Test: `backend/tests/test_ffn_chunking.py` + +**Step 1: Create the module** + +Create `backend/services/gpu_optimizations/__init__.py` (empty file). + +Create `backend/services/gpu_optimizations/ffn_chunking.py`: + +```python +"""Chunked feedforward optimization for LTX transformer. + +Splits FeedForward.forward along the sequence dimension (dim=1) to reduce +peak VRAM. Output is mathematically identical to unchunked forward — +FeedForward is pointwise along the sequence dimension so chunking is lossless. + +Reference: RandomInternetPreson/ComfyUI_LTX-2_VRAM_Memory_Management (V3.1) +""" + +from __future__ import annotations + +import logging +from collections.abc import Callable +from typing import Any + +import torch + +logger = logging.getLogger(__name__) + +_MIN_SEQ_PER_CHUNK = 100 # skip chunking for short sequences + + +def _make_chunked_forward( + original_forward: Callable[[torch.Tensor], torch.Tensor], + num_chunks: int, +) -> Callable[[torch.Tensor], torch.Tensor]: + """Return a drop-in replacement for FeedForward.forward that chunks along dim=1.""" + + def chunked_forward(x: torch.Tensor) -> torch.Tensor: + if x.dim() != 3: + return original_forward(x) + + seq_len = x.shape[1] + if seq_len < num_chunks * _MIN_SEQ_PER_CHUNK: + return original_forward(x) + + chunk_size = (seq_len + num_chunks - 1) // num_chunks + outputs: list[torch.Tensor] = [] + for start in range(0, seq_len, chunk_size): + end = min(start + chunk_size, seq_len) + outputs.append(original_forward(x[:, start:end, :])) + return torch.cat(outputs, dim=1) + + return chunked_forward + + +def patch_ffn_chunking(model: torch.nn.Module, num_chunks: int = 8) -> int: + """Monkey-patch all FeedForward modules in *model* to use chunked forward. + + Returns the number of modules patched. + """ + patched = 0 + for name, module in model.named_modules(): + if not hasattr(module, "net"): + continue + if not isinstance(module.net, torch.nn.Sequential): + continue + # Match FeedForward modules (they live at .ff and .audio_ff in each block) + if not (name.endswith(".ff") or name.endswith(".audio_ff")): + continue + + original = module.forward + module.forward = _make_chunked_forward(original, num_chunks) # type: ignore[assignment] + patched += 1 + + if patched: + logger.info("FFN chunking: patched %d feedforward modules (chunks=%d)", patched, num_chunks) + return patched +``` + +**Step 2: Write test** + +Create `backend/tests/test_ffn_chunking.py`: + +```python +"""Tests for FFN chunked feedforward optimization.""" + +from __future__ import annotations + +import torch + +from services.gpu_optimizations.ffn_chunking import _make_chunked_forward, patch_ffn_chunking + + +class _FakeFeedForward(torch.nn.Module): + def __init__(self, dim: int) -> None: + super().__init__() + self.net = torch.nn.Sequential( + torch.nn.Linear(dim, dim * 4), + torch.nn.GELU(), + torch.nn.Linear(dim * 4, dim), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.net(x) + + +class _FakeTransformerBlock(torch.nn.Module): + def __init__(self, dim: int) -> None: + super().__init__() + self.ff = _FakeFeedForward(dim) + self.audio_ff = _FakeFeedForward(dim) + + +class _FakeTransformer(torch.nn.Module): + def __init__(self, dim: int, num_blocks: int) -> None: + super().__init__() + self.blocks = torch.nn.ModuleList([_FakeTransformerBlock(dim) for _ in range(num_blocks)]) + + +def test_chunked_forward_matches_original() -> None: + dim = 32 + ff = _FakeFeedForward(dim) + x = torch.randn(1, 2000, dim) + + original_out = ff(x) + chunked_fn = _make_chunked_forward(ff.forward, num_chunks=4) + chunked_out = chunked_fn(x) + + assert torch.allclose(original_out, chunked_out, atol=1e-5) + + +def test_chunked_forward_skips_short_sequences() -> None: + dim = 16 + ff = _FakeFeedForward(dim) + x = torch.randn(1, 50, dim) # too short to chunk + + original_out = ff(x) + chunked_fn = _make_chunked_forward(ff.forward, num_chunks=4) + chunked_out = chunked_fn(x) + + assert torch.allclose(original_out, chunked_out, atol=1e-5) + + +def test_patch_ffn_chunking_patches_correct_modules() -> None: + model = _FakeTransformer(dim=16, num_blocks=3) + count = patch_ffn_chunking(model, num_chunks=4) + assert count == 6 # 3 blocks x 2 (ff + audio_ff) + + +def test_patch_ffn_chunking_zero_when_no_match() -> None: + model = torch.nn.Linear(16, 16) + count = patch_ffn_chunking(model, num_chunks=4) + assert count == 0 +``` + +**Step 3: Run tests** + +Run: `cd backend && uv run pytest tests/test_ffn_chunking.py -v --tb=short` +Expected: 4 tests PASS + +**Step 4: Commit** + +```bash +git add backend/services/gpu_optimizations/ backend/tests/test_ffn_chunking.py +git commit -m "feat: add FFN chunked feedforward to reduce peak VRAM by up to 8x" +``` + +--- + +### Task 3: Add TeaCache optimization + +Caches transformer residuals between denoising steps to skip redundant computation. 1.6-2.1x speedup. + +**Files:** +- Create: `backend/services/gpu_optimizations/tea_cache.py` +- Test: `backend/tests/test_tea_cache.py` + +**Step 1: Create TeaCache module** + +Create `backend/services/gpu_optimizations/tea_cache.py`: + +```python +"""TeaCache: Timestep-Aware Caching for diffusion denoising loops. + +Wraps a denoising function to skip transformer forward passes when the +timestep embedding hasn't changed significantly from the previous step. +First and last steps are always computed. + +Reference: ali-vilab/TeaCache (TeaCache4LTX-Video) +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass, replace +from typing import Any + +import numpy as np +import torch + +logger = logging.getLogger(__name__) + +# Polynomial fitted to LTX-Video noise schedule for rescaling relative L1 distance +_RESCALE_COEFFICIENTS = [2.14700694e+01, -1.28016453e+01, 2.31279151e+00, 7.92487521e-01, 9.69274326e-03] +_rescale_poly = np.poly1d(_RESCALE_COEFFICIENTS) + + +@dataclass +class TeaCacheState: + """Mutable state held across denoising steps.""" + accumulated_distance: float = 0.0 + previous_residual: torch.Tensor | None = None + previous_modulated_input: torch.Tensor | None = None + step_count: int = 0 + skipped: int = 0 + computed: int = 0 + + +def wrap_denoise_fn_with_tea_cache( + denoise_fn: Any, + num_steps: int, + threshold: float, +) -> Any: + """Wrap a denoising function with TeaCache. + + The wrapped function has the same signature as the original: + denoise_fn(video_state, audio_state, sigmas, step_index) + -> (denoised_video, denoised_audio) + + When the relative L1 distance of the timestep-modulated input is below + *threshold*, the previous residual is reused instead of calling the + transformer. + + Args: + denoise_fn: Original denoising function. + num_steps: Total number of denoising steps (len(sigmas) - 1). + threshold: Caching threshold. 0 disables. 0.03 = balanced. 0.05 = aggressive. + """ + if threshold <= 0: + return denoise_fn + + state = TeaCacheState() + + def cached_denoise( + video_state: Any, + audio_state: Any, + sigmas: torch.Tensor, + step_index: int, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Always compute first and last steps + if step_index == 0 or step_index == num_steps - 1: + should_compute = True + state.accumulated_distance = 0.0 + elif state.previous_modulated_input is not None: + # Estimate change using video_state latent as proxy for modulated input + current = video_state.latent + prev = state.previous_modulated_input + rel_l1 = ((current - prev).abs().mean() / prev.abs().mean().clamp(min=1e-8)).item() + rescaled = float(_rescale_poly(rel_l1)) + state.accumulated_distance += rescaled + + if state.accumulated_distance < threshold: + should_compute = False + else: + should_compute = True + state.accumulated_distance = 0.0 + else: + should_compute = True + state.accumulated_distance = 0.0 + + state.previous_modulated_input = video_state.latent.clone() + state.step_count += 1 + + if not should_compute and state.previous_residual is not None: + # Reuse cached residual + cached_video = video_state.latent + state.previous_residual + state.skipped += 1 + return cached_video, audio_state.latent + else: + # Full computation + original_latent = video_state.latent.clone() + denoised_video, denoised_audio = denoise_fn(video_state, audio_state, sigmas, step_index) + state.previous_residual = denoised_video - original_latent + state.computed += 1 + return denoised_video, denoised_audio + + cached_denoise._tea_cache_state = state # type: ignore[attr-defined] + return cached_denoise +``` + +**Step 2: Write test** + +Create `backend/tests/test_tea_cache.py`: + +```python +"""Tests for TeaCache denoising loop caching.""" + +from __future__ import annotations + +from dataclasses import dataclass + +import torch + +from services.gpu_optimizations.tea_cache import ( + TeaCacheState, + wrap_denoise_fn_with_tea_cache, +) + + +@dataclass +class FakeLatentState: + latent: torch.Tensor + denoise_mask: torch.Tensor + clean_latent: torch.Tensor + + +def _make_fake_denoise(): + call_count = [0] + + def denoise_fn(video_state, audio_state, sigmas, step_index): + call_count[0] += 1 + return video_state.latent * 0.9, audio_state.latent * 0.9 + + return denoise_fn, call_count + + +def test_tea_cache_disabled_when_threshold_zero() -> None: + original, call_count = _make_fake_denoise() + wrapped = wrap_denoise_fn_with_tea_cache(original, num_steps=10, threshold=0.0) + assert wrapped is original # no wrapping + + +def test_tea_cache_always_computes_first_step() -> None: + original, call_count = _make_fake_denoise() + wrapped = wrap_denoise_fn_with_tea_cache(original, num_steps=10, threshold=0.05) + latent = torch.randn(1, 100, 64) + vs = FakeLatentState(latent=latent, denoise_mask=torch.ones(1), clean_latent=latent) + _v, _a = wrapped(vs, vs, torch.linspace(1, 0, 11), 0) + assert call_count[0] == 1 + + +def test_tea_cache_always_computes_last_step() -> None: + original, call_count = _make_fake_denoise() + wrapped = wrap_denoise_fn_with_tea_cache(original, num_steps=10, threshold=0.05) + latent = torch.randn(1, 100, 64) + vs = FakeLatentState(latent=latent, denoise_mask=torch.ones(1), clean_latent=latent) + _v, _a = wrapped(vs, vs, torch.linspace(1, 0, 11), 0) + _v, _a = wrapped(vs, vs, torch.linspace(1, 0, 11), 9) + assert call_count[0] == 2 # both first and last always computed + + +def test_tea_cache_skips_similar_steps() -> None: + original, call_count = _make_fake_denoise() + wrapped = wrap_denoise_fn_with_tea_cache(original, num_steps=10, threshold=100.0) + latent = torch.ones(1, 100, 64) + vs = FakeLatentState(latent=latent, denoise_mask=torch.ones(1), clean_latent=latent) + # Step 0: always computed + wrapped(vs, vs, torch.linspace(1, 0, 11), 0) + # Step 1: very high threshold means this should be skipped + wrapped(vs, vs, torch.linspace(1, 0, 11), 1) + tea_state: TeaCacheState = wrapped._tea_cache_state + assert tea_state.skipped >= 1 +``` + +**Step 3: Run tests** + +Run: `cd backend && uv run pytest tests/test_tea_cache.py -v --tb=short` +Expected: 4 tests PASS + +**Step 4: Commit** + +```bash +git add backend/services/gpu_optimizations/tea_cache.py backend/tests/test_tea_cache.py +git commit -m "feat: add TeaCache timestep-aware caching for 1.6-2.1x speedup" +``` + +--- + +### Task 4: Integrate optimizations into pipeline lifecycle + +Wire FFN chunking and TeaCache into the existing pipeline creation and generation flow. + +**Files:** +- Modify: `backend/state/app_settings.py` (add settings) +- Modify: `backend/handlers/pipelines_handler.py` (apply FFN chunking at load time) +- Modify: `backend/services/ltx_pipeline_common.py` (wrap denoising loop with TeaCache) +- Modify: `backend/services/fast_video_pipeline/ltx_fast_video_pipeline.py` (accept tea_cache_threshold) + +**Step 1: Add settings fields** + +In `backend/state/app_settings.py`, add to `AppSettings` class after `batch_sound_enabled`: + +```python + ffn_chunk_count: int = 8 + tea_cache_threshold: float = 0.0 +``` + +Add validator: +```python + @field_validator("ffn_chunk_count", mode="before") + @classmethod + def _clamp_ffn_chunk_count(cls, value: Any) -> int: + return _clamp_int(value, minimum=0, maximum=32, default=8) +``` + +Also add to `SettingsResponse`: +```python + ffn_chunk_count: int = 8 + tea_cache_threshold: float = 0.0 +``` + +**Step 2: Apply FFN chunking in PipelinesHandler** + +In `backend/handlers/pipelines_handler.py`, add import at top: + +```python +from services.gpu_optimizations.ffn_chunking import patch_ffn_chunking +``` + +In `_create_video_pipeline` method, after `self._compile_if_enabled(state)` on line 139, add FFN chunking: + +```python + def _create_video_pipeline(self, model_type: VideoPipelineModelType) -> VideoPipelineState: + # ... existing code creating pipeline and state ... + state = self._compile_if_enabled(state) + + # Apply FFN chunking if enabled and torch.compile is not active + chunk_count = self.state.app_settings.ffn_chunk_count + if chunk_count > 0 and not state.is_compiled: + transformer = state.pipeline.pipeline.model_ledger.transformer() + patch_ffn_chunking(transformer, num_chunks=chunk_count) + + return state +``` + +**Step 3: Wire TeaCache into DistilledNativePipeline** + +In `backend/services/ltx_pipeline_common.py`, add import: + +```python +from services.gpu_optimizations.tea_cache import wrap_denoise_fn_with_tea_cache +``` + +In `DistilledNativePipeline.__init__`, add a `tea_cache_threshold` parameter: + +```python + def __init__( + self, + checkpoint_path: str, + gemma_root: str | None, + device: torch.device | None = None, + fp8transformer: bool = False, + tea_cache_threshold: float = 0.0, + ) -> None: + # ... existing init code ... + self._tea_cache_threshold = tea_cache_threshold +``` + +In `DistilledNativePipeline.__call__`, wrap the denoising_loop function before it's called: + +After the `denoising_loop` closure is defined (around line 147), add: + +```python + denoising_loop_fn = denoising_loop + if self._tea_cache_threshold > 0: + num_steps = len(sigmas) - 1 + cached_denoise = wrap_denoise_fn_with_tea_cache( + simple_denoising_func(video_context=video_context, audio_context=audio_context, transformer=transformer), + num_steps=num_steps, + threshold=self._tea_cache_threshold, + ) + + def tea_cache_loop( + sigmas: torch.Tensor, + video_state: LatentState, + audio_state: LatentState, + stepper: EulerDiffusionStep, + ) -> tuple[LatentState, LatentState]: + return euler_denoising_loop( + sigmas=sigmas, + video_state=video_state, + audio_state=audio_state, + stepper=stepper, + denoise_fn=cached_denoise, + ) + + denoising_loop_fn = tea_cache_loop +``` + +Then use `denoising_loop_fn` instead of `denoising_loop` in the `denoise_audio_video` call. + +**Step 4: Run existing tests** + +Run: `cd backend && uv run pytest -v --tb=short` +Expected: All existing tests still pass (optimizations only activate with real GPU models) + +**Step 5: Commit** + +```bash +git add backend/state/app_settings.py backend/handlers/pipelines_handler.py backend/services/ltx_pipeline_common.py backend/services/fast_video_pipeline/ltx_fast_video_pipeline.py +git commit -m "feat: wire FFN chunking and TeaCache into pipeline lifecycle" +``` + +--- + +### Task 5: Add VRAM cleanup after every generation + +Prevents post-heavy-load GPU degradation (stalling at 15% after long generations). + +**Files:** +- Modify: `backend/handlers/queue_worker.py` (add cleanup after job completion) +- Modify: `backend/services/gpu_cleaner/torch_cleaner.py` (add aggressive cleanup) + +**Step 1: Add aggressive cleanup method** + +In `backend/services/gpu_cleaner/torch_cleaner.py`: + +```python +class TorchCleaner: + def __init__(self, device: str | torch.device = "cpu") -> None: + self._device = device + + def cleanup(self) -> None: + empty_device_cache(self._device) + gc.collect() + + def deep_cleanup(self) -> None: + """Aggressive cleanup for after heavy GPU workloads.""" + gc.collect() + empty_device_cache(self._device) + gc.collect() + empty_device_cache(self._device) + if str(self._device) != "cpu" and torch.cuda.is_available(): + torch.cuda.synchronize() +``` + +**Step 2: Add cleanup in QueueWorker after job completion** + +In `backend/handlers/queue_worker.py`, the `_run_job` method calls `executor.execute(job)` in a try/finally. Add a `gpu_cleaner` parameter to QueueWorker and call cleanup after GPU jobs: + +Add to `__init__`: +```python + def __init__( + self, + *, + queue: JobQueue, + gpu_executor: JobExecutor, + api_executor: JobExecutor, + gpu_cleaner: GpuCleaner | None = None, + on_batch_complete: Callable[[str, list[QueueJob]], None] | None = None, + enhance_handler: EnhancePromptProvider | None = None, + ) -> None: + # ... existing ... + self._gpu_cleaner = gpu_cleaner +``` + +In `_run_job`, add cleanup after GPU jobs complete: +```python + def _run_job(self, job: QueueJob, executor: JobExecutor, slot: str) -> None: + try: + result_paths = executor.execute(job) + self._queue.update_job(job.id, status="complete", progress=100, phase="complete", result_paths=result_paths) + except Exception as exc: + logger.error("Job %s failed: %s", job.id, exc) + self._queue.update_job(job.id, status="error", error=str(exc)) + finally: + if slot == "gpu" and self._gpu_cleaner is not None: + try: + self._gpu_cleaner.deep_cleanup() + except Exception: + pass + with self._lock: + if slot == "gpu": + self._gpu_busy = False + else: + self._api_busy = False +``` + +**Step 3: Wire gpu_cleaner in AppHandler** (where QueueWorker is constructed) + +Find where `QueueWorker` is instantiated in `backend/app_handler.py` and pass the existing `gpu_cleaner` service. + +**Step 4: Run tests** + +Run: `cd backend && uv run pytest tests/test_queue_worker.py -v --tb=short` +Expected: PASS (existing tests use fakes, new param is optional) + +**Step 5: Commit** + +```bash +git add backend/services/gpu_cleaner/torch_cleaner.py backend/handlers/queue_worker.py backend/app_handler.py +git commit -m "feat: aggressive VRAM cleanup after every GPU generation" +``` + +--- + +### Task 6: R2 Storage upload + +Upload generated videos/images to Cloudflare R2 after generation. + +**Files:** +- Create: `backend/services/r2_client/__init__.py` +- Create: `backend/services/r2_client/r2_client.py` (Protocol) +- Create: `backend/services/r2_client/r2_client_impl.py` (boto3 implementation) +- Modify: `backend/handlers/job_executors.py` (upload after completion) +- Modify: `backend/state/app_settings.py` (R2 settings) +- Test: `backend/tests/test_r2_upload.py` + +**Step 1: Add boto3 dependency** + +In `backend/pyproject.toml`, add to dependencies: +```toml + "boto3>=1.34.0", +``` + +Run: `cd backend && uv sync --frozen --extra test --extra dev` (or `uv lock` if lock needs updating) + +**Step 2: Add R2 settings** + +In `backend/state/app_settings.py`, add to `AppSettings`: +```python + r2_access_key_id: str = "" + r2_secret_access_key: str = "" + r2_endpoint: str = "" + r2_bucket: str = "" + r2_public_url: str = "" + auto_upload_to_r2: bool = False +``` + +Add to `SettingsResponse`: +```python + has_r2_credentials: bool = False + auto_upload_to_r2: bool = False +``` + +Update `to_settings_response` to handle R2 fields: +```python + r2_key = data.pop("r2_access_key_id", "") + data.pop("r2_secret_access_key", "") + data.pop("r2_endpoint", "") + data.pop("r2_bucket", "") + data.pop("r2_public_url", "") + data["has_r2_credentials"] = bool(r2_key) +``` + +**Step 3: Create R2 client protocol** + +Create `backend/services/r2_client/__init__.py` (empty). + +Create `backend/services/r2_client/r2_client.py`: + +```python +"""Protocol for R2/S3 compatible object storage.""" + +from __future__ import annotations + +from typing import Protocol + + +class R2Client(Protocol): + def upload_file( + self, *, local_path: str, remote_key: str, content_type: str, + ) -> str: + """Upload a local file. Returns the public URL.""" + ... + + def is_configured(self) -> bool: + """Return True if R2 credentials are set.""" + ... +``` + +**Step 4: Create R2 client implementation** + +Create `backend/services/r2_client/r2_client_impl.py`: + +```python +"""Cloudflare R2 client using boto3 S3-compatible API.""" + +from __future__ import annotations + +import logging +import mimetypes +from pathlib import Path + +logger = logging.getLogger(__name__) + + +class R2ClientImpl: + def __init__( + self, + access_key_id: str, + secret_access_key: str, + endpoint: str, + bucket: str, + public_url: str, + ) -> None: + self._access_key_id = access_key_id + self._secret_access_key = secret_access_key + self._endpoint = endpoint + self._bucket = bucket + self._public_url = public_url.rstrip("/") + + def is_configured(self) -> bool: + return bool(self._access_key_id and self._secret_access_key and self._endpoint and self._bucket) + + def upload_file(self, *, local_path: str, remote_key: str, content_type: str) -> str: + if not self.is_configured(): + raise RuntimeError("R2 credentials not configured") + + import boto3 + + s3 = boto3.client( + "s3", + endpoint_url=self._endpoint, + aws_access_key_id=self._access_key_id, + aws_secret_access_key=self._secret_access_key, + ) + + content_type = content_type or mimetypes.guess_type(local_path)[0] or "application/octet-stream" + s3.upload_file( + local_path, + self._bucket, + remote_key, + ExtraArgs={"ContentType": content_type}, + ) + + public_url = f"{self._public_url}/{remote_key}" + logger.info("Uploaded %s -> %s", Path(local_path).name, public_url) + return public_url +``` + +**Step 5: Hook upload into job executors** + +In `backend/handlers/job_executors.py`, add upload after GPU job completion. + +Add to `GpuJobExecutor.__init__`: +```python + def __init__(self, handler: AppHandler) -> None: + self._handler = handler + + def _try_upload_to_r2(self, job: QueueJob, result_paths: list[str]) -> None: + """Upload results to R2 if configured.""" + settings = self._handler.state.app_settings + if not settings.auto_upload_to_r2: + return + if not (settings.r2_access_key_id and settings.r2_endpoint): + return + + from services.r2_client.r2_client_impl import R2ClientImpl + + client = R2ClientImpl( + access_key_id=settings.r2_access_key_id, + secret_access_key=settings.r2_secret_access_key, + endpoint=settings.r2_endpoint, + bucket=settings.r2_bucket, + public_url=settings.r2_public_url, + ) + + for path in result_paths: + try: + ext = Path(path).suffix + content_type = "video/mp4" if ext == ".mp4" else "image/png" + remote_key = f"videos/{job.id}{ext}" + client.upload_file(local_path=path, remote_key=remote_key, content_type=content_type) + except Exception as exc: + logger.warning("R2 upload failed for %s: %s", path, exc) +``` + +Add `from pathlib import Path` import and call `self._try_upload_to_r2(job, result)` after each execute in `GpuJobExecutor.execute()`: + +```python + def execute(self, job: QueueJob) -> list[str]: + syncer = _ProgressSyncer(self._handler, job.id) + syncer.start() + try: + if job.type == "image": + result = self._execute_image(job) + elif job.type == "video": + result = self._execute_video(job) + else: + raise ValueError(f"Unknown job type: {job.type}") + self._try_upload_to_r2(job, result) + return result + finally: + syncer.stop() +``` + +**Step 6: Write test** + +Create `backend/tests/test_r2_upload.py`: + +```python +"""Tests for R2 upload integration.""" + +from __future__ import annotations + + +def test_r2_client_is_configured_when_credentials_present() -> None: + from services.r2_client.r2_client_impl import R2ClientImpl + + client = R2ClientImpl( + access_key_id="test", + secret_access_key="test", + endpoint="https://example.com", + bucket="test-bucket", + public_url="https://pub.example.com", + ) + assert client.is_configured() is True + + +def test_r2_client_not_configured_when_empty() -> None: + from services.r2_client.r2_client_impl import R2ClientImpl + + client = R2ClientImpl( + access_key_id="", + secret_access_key="", + endpoint="", + bucket="", + public_url="", + ) + assert client.is_configured() is False +``` + +**Step 7: Run tests** + +Run: `cd backend && uv run pytest tests/test_r2_upload.py -v --tb=short` +Expected: PASS + +**Step 8: Commit** + +```bash +git add backend/services/r2_client/ backend/handlers/job_executors.py backend/state/app_settings.py backend/pyproject.toml backend/tests/test_r2_upload.py +git commit -m "feat: add R2 cloud storage upload for generated videos/images" +``` + +--- + +### Task 7: Run typecheck and full test suite + +**Step 1: TypeScript typecheck** + +Run: `cd D:/git/directors-desktop && pnpm typecheck:ts` +Expected: PASS + +**Step 2: Python typecheck** + +Run: `cd D:/git/directors-desktop && pnpm typecheck:py` +Expected: PASS (may need type: ignore for monkey-patches) + +**Step 3: Full test suite** + +Run: `cd backend && uv run pytest -v --tb=short` +Expected: All tests PASS + +**Step 4: Fix any failures, commit fixes** + +--- + +### Task 8: Benchmark with optimizations enabled + +**Step 1: Start backend with optimizations** + +Set `ffn_chunk_count: 8` and `tea_cache_threshold: 0.03` in settings. + +**Step 2: Run 512p benchmark suite** + +Test 512p at 2s, 5s, 8s, 10s. Compare against baseline: + +| Test | Baseline | Target | +|------|----------|--------| +| 512p 2s | 37s | ~20-25s | +| 512p 5s | 84s | ~50-60s | +| 512p 8s | 100s | ~60-70s | +| 512p 10s | 651s | <200s (if FFN chunking fixes the cliff) | + +**Step 3: Test 1080p 2s** (previously crashed with OOM) + +If FFN chunking works, this should no longer OOM. + +**Step 4: Document results in `docs/performance-report.md`** + +**Step 5: Commit benchmark results** + +--- + +### Task 9: SageAttention version bump (optional, if time permits) + +**Files:** +- Modify: `backend/pyproject.toml` +- Modify: `backend/install_sageattention.bat` + +**Step 1: Bump version** + +In `pyproject.toml`, change: +``` +"sageattention>=1.0.0; sys_platform != 'darwin'", +``` +to: +``` +"sageattention>=2.0.0; sys_platform != 'darwin'", +``` + +**Step 2: Update install script** + +Update `install_sageattention.bat` to clone the latest v2 branch. + +**Step 3: Test that sageattention imports correctly** + +Run: `cd backend && uv run python -c "import sageattention; print(sageattention.__version__)"` + +**Step 4: Commit** + +```bash +git add backend/pyproject.toml backend/install_sageattention.bat +git commit -m "chore: bump sageattention to v2 for improved attention performance" +``` diff --git a/docs/plans/feat-image-editing-img2img.md b/docs/plans/feat-image-editing-img2img.md new file mode 100644 index 00000000..384bbf81 --- /dev/null +++ b/docs/plans/feat-image-editing-img2img.md @@ -0,0 +1,449 @@ +# feat: Local Image Editing (img2img) via Z-Image-Turbo + +**Date:** 2026-03-08 +**Type:** Enhancement +**Complexity:** Medium — reuses existing model weights, GPU lifecycle, and queue system + +--- + +## Overview + +Add local image editing to Directors Desktop using `ZImageImg2ImgPipeline` from diffusers. Users select an image, write an editing prompt, adjust a strength slider, and get back an edited image. This reuses the **same Z-Image-Turbo model weights** already downloaded — no new model, no API costs. + +Primary use case: edit frames before animating them into video. + +``` +User has an image (generated or imported) + → Clicks "Edit" or drops an image into the editor + → Writes a prompt: "make the sky a dramatic sunset" + → Adjusts strength: 0.65 (moderate change) + → Gets back the edited image + → Animates it into video +``` + +## Technical Foundation + +`ZImageImg2ImgPipeline` shares the same model architecture and weights as `ZImagePipeline` (text-to-image). The difference is purely in the pipeline class — img2img encodes the source image into latents, adds noise at a level determined by `strength`, then denoises from there. + +**Key parameters:** +- `image`: PIL Image input +- `strength`: 0.0 (no change) to 1.0 (ignore input entirely). Default: **0.65** +- `guidance_scale`: Must be **0.0** for Turbo models +- `num_inference_steps`: **9** (yields 8 DiT forward passes) +- `torch_dtype`: `torch.bfloat16` + +**Requires:** `diffusers >= 0.37.0` + +**VRAM:** Same as text-to-image (~16GB). No additional memory for the source image encoding. + +**LoRA:** Works identically — same `ZImageLoraLoaderMixin` base class. + +--- + +## Phase 1: Backend Pipeline + API + +### 1a. Extend the protocol + +**`backend/services/image_generation_pipeline/image_generation_pipeline.py`** + +Add `img2img` method to `ImageGenerationPipeline` protocol: + +```python +def img2img( + self, + prompt: str, + image: PILImageType, + strength: float, + height: int, + width: int, + guidance_scale: float, + num_inference_steps: int, + seed: int, +) -> ImagePipelineOutputLike: ... +``` + +### 1b. Implement img2img in ZIT pipeline + +**`backend/services/image_generation_pipeline/zit_image_generation_pipeline.py`** + +Create a second internal pipeline instance using `ZImageImg2ImgPipeline`. Both share the same model components (transformer, VAE, text encoder), so we don't double VRAM usage. The img2img pipeline wraps the same underlying model. + +```python +from diffusers import ZImagePipeline, ZImageImg2ImgPipeline + +class ZitImageGenerationPipeline: + def __init__(self, model_path: str, device: str | None = None) -> None: + self.pipeline = ZImagePipeline.from_pretrained( + model_path, torch_dtype=torch.bfloat16, + ) + # Create img2img pipeline sharing the same model components + self._img2img_pipeline = ZImageImg2ImgPipeline( + scheduler=self.pipeline.scheduler, + vae=self.pipeline.vae, + text_encoder=self.pipeline.text_encoder, + tokenizer=self.pipeline.tokenizer, + transformer=self.pipeline.transformer, + ) + if device is not None: + self.to(device) + + @torch.inference_mode() + def img2img( + self, + prompt: str, + image: PILImageType, + strength: float, + height: int, + width: int, + guidance_scale: float, + num_inference_steps: int, + seed: int, + ) -> ImagePipelineOutputLike: + generator = torch.Generator( + device=self._resolve_generator_device() + ).manual_seed(seed) + output = self._img2img_pipeline( + prompt=prompt, + image=image, + strength=strength, + height=height, + width=width, + guidance_scale=0.0, # Always 0.0 for Turbo + num_inference_steps=num_inference_steps, + generator=generator, + output_type="pil", + return_dict=True, + ) + return self._normalize_output(output) +``` + +**Important:** The `to()` method must also move the img2img pipeline. Since they share components, calling `self.pipeline.enable_model_cpu_offload()` should cover both, but verify this. If not, call it on `self._img2img_pipeline` too. + +### 1c. Extend API types + +**`backend/api_types.py`** + +Add fields to `GenerateImageRequest`: + +```python +class GenerateImageRequest(BaseModel): + prompt: NonEmptyPrompt + width: int = 1024 + height: int = 1024 + numSteps: int = 4 + numImages: int = 1 + loraPath: str | None = None + loraWeight: float = 1.0 + # New for img2img: + sourceImagePath: str | None = None + strength: float = 0.65 +``` + +### 1d. Update image generation handler + +**`backend/handlers/image_generation_handler.py`** + +In `generate_image()`, branch on `sourceImagePath`: + +```python +from PIL import Image + +def generate_image(self, ..., source_image_path: str | None = None, strength: float = 0.65) -> list[str]: + # ... existing setup (load pipeline, LoRA, etc.) ... + + source_image: PILImageType | None = None + if source_image_path: + source_image = Image.open(source_image_path).convert("RGB") + # Snap source dimensions to 16-multiples if using source dimensions + width = (source_image.width // 16) * 16 + height = (source_image.height // 16) * 16 + source_image = source_image.resize((width, height), Image.LANCZOS) + + for i in range(num_images): + if source_image is not None: + result = zit.img2img( + prompt=prompt, + image=source_image, + strength=strength, + height=height, + width=width, + guidance_scale=0.0, + num_inference_steps=num_inference_steps, + seed=seed + i, + ) + else: + result = zit.generate( + prompt=prompt, + height=height, + width=width, + guidance_scale=0.0, + num_inference_steps=num_inference_steps, + seed=seed + i, + ) + # ... save output (use prefix "zit_edit_" for edits) ... +``` + +Update the `generate()` entry point to pass `source_image_path` and `strength` from `req`. + +### 1e. Update job executor + +**`backend/handlers/job_executors.py`** + +In `_execute_image()`, pass `sourceImagePath` and `strength` from `job.params` into `GenerateImageRequest`. + +### 1f. Update gallery handler + +**`backend/handlers/gallery_handler.py`** + +Add `"zit_edit_"` prefix mapping: + +```python +_MODEL_PREFIXES: list[tuple[str, str]] = [ + ("zit_edit_", "zit-edit"), # New + ("zit_image_", "zit"), + # ... existing ... +] +``` + +--- + +## Phase 2: Frontend — Edit Flow + +### 2a. Add `editImage` to the generation hook + +**`frontend/hooks/use-generation.ts`** + +Add new method: + +```typescript +const editImage = useCallback(async ( + prompt: string, + sourceImagePath: string, + settings: GenerationSettings, + strength: number = 0.65, +) => { + setState(prev => ({ + ...prev, + isGenerating: true, + progress: 0, + statusMessage: 'Editing image...', + videoUrl: null, videoPath: null, imageUrl: null, imageUrls: [], + error: null, + })) + + try { + const backendUrl = await window.electronAPI.getBackendUrl() + const response = await fetch(`${backendUrl}/api/queue/submit`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + type: 'image', + model: appSettings.imageModel || 'z-image-turbo', + params: { + prompt, + sourceImagePath, + strength, + width: 0, // 0 = use source dimensions + height: 0, + numSteps: settings.imageSteps || 8, + numImages: 1, + ...(settings.loraPath ? { loraPath: settings.loraPath, loraWeight: settings.loraWeight ?? 1.0 } : {}), + }, + }), + }) + // ... same submit + poll pattern as generateImage ... + } catch (error) { /* ... */ } +}, [appSettings.imageModel, startPolling]) +``` + +Add `editImage` to the returned interface. + +### 2b. Integrate into GenSpace + +**`frontend/views/GenSpace.tsx`** + +Add state for the edit source image: + +```typescript +const [editSourceImage, setEditSourceImage] = useState<{ url: string; path: string } | null>(null) +``` + +When `editSourceImage` is set and mode is `image`: +- Show the source image preview above the prompt +- Show a strength slider (0.0–1.0, default 0.65) +- Change prompt placeholder to "Describe your edit..." +- Change the generate button label to "Edit" +- Hide aspect ratio / resolution controls (use source dimensions) +- On submit, call `editImage()` instead of `generateImage()` + +Add a clear button (X) to remove the source image and return to text-to-image. + +### 2c. Wire the Edit button + +**`frontend/components/ImageResult.tsx`** + +The Edit button already exists (line ~130) but has no handler. Add an `onEdit` prop: + +```typescript +interface ImageResultProps { + // ... existing ... + onEdit?: (imagePath: string) => void +} +``` + +Wire `onClick` to call `onEdit` with the image file path. In GenSpace, pass a callback that sets `editSourceImage`. + +### 2d. Add strength slider to SettingsPanel + +**`frontend/components/SettingsPanel.tsx`** + +When in image mode and a source image is present, show: + +```typescript +{/* Strength slider — only visible during image editing */} +
+ + setStrength(Number(e.target.value) / 100)} + /> +
+ Subtle + {Math.round(strength * 100)}% + Heavy +
+
+``` + +Add `strength` to `GenerationSettings` interface. + +### 2e. Add phase message for editing + +**`frontend/hooks/use-generation.ts`** + +```typescript +case 'encoding_image': + return 'Encoding source image...' +``` + +--- + +## Phase 3: Polish + +### 3a. Before/after comparison + +**`frontend/components/ImageResult.tsx`** + +When the result was generated from a source image, show a "Compare" toggle: +- On click/hold, swap the displayed image to the source +- Release returns to the edited result +- Simple state toggle, no complex slider needed for MVP + +### 3b. Edit from video frames + +**`frontend/contexts/ProjectContext.tsx`** + +The `genSpaceEditImageUrl` mechanism already exists for sending frames from VideoEditor to GenSpace. Update the receiver in GenSpace to support routing to image-edit mode: + +```typescript +// When receiving an edit request from VideoEditor +if (genSpaceEditImageUrl) { + setEditSourceImage({ url: genSpaceEditImageUrl, path: fileUrlToPath(genSpaceEditImageUrl) }) + setMode('image') +} +``` + +### 3c. Edit from gallery + +Add an "Edit" action to gallery items (both inline gallery in GenSpace and the full Gallery view). Clicking it navigates to GenSpace with the image pre-loaded as edit source. + +### 3d. Sequential edits + +After an edit completes, the result becomes available as a new source. The "Edit" button on the result should re-populate the source image with the current output, enabling iterative refinement. + +--- + +## What NOT to build (yet) + +- **Inpainting / masking** — requires canvas drawing tools, mask pipeline, significant UI work. Future phase. +- **ControlNet** — `ZImageControlNetPipeline` exists in diffusers but adds model complexity. Future phase. +- **API-backed img2img** — Replicate supports img2img but adds upload/download overhead. Local-only for now. +- **Batch editing** — editing multiple images with the same prompt. Low demand, adds complexity. + +--- + +## Acceptance Criteria + +- [x] User can click "Edit" on a generated image and it becomes the source for img2img +- [ ] User can drop/browse an external image file as edit source +- [x] Strength slider appears when source image is set (default 0.65) +- [x] Prompt describes the desired edit; generation uses img2img pipeline +- [x] Output image matches source dimensions (snapped to 16-multiples) +- [x] LoRA works with img2img (same load/unload behavior) +- [x] Edited images appear in gallery with "zit-edit" model tag +- [x] Progress phases report correctly (Preparing GPU → Loading model → Encoding image → Generating → Complete) +- [x] No new model download required — reuses existing ZIT weights +- [x] Backend tests pass with fake img2img method +- [x] TypeScript and Pyright typechecks pass + +## Quality Gate + +- [x] `pnpm typecheck` passes +- [x] `pnpm backend:test` passes (343+ tests) +- [ ] Manual test: generate image → edit it → verify output is visually modified +- [ ] Manual test: import external image → edit → verify dimensions preserved +- [ ] Manual test: edit with LoRA active → verify style applied + +--- + +## Files to Modify + +### Backend (7 files) +| File | Change | +|------|--------| +| `services/image_generation_pipeline/image_generation_pipeline.py` | Add `img2img()` to protocol | +| `services/image_generation_pipeline/zit_image_generation_pipeline.py` | Implement `img2img()` with `ZImageImg2ImgPipeline` sharing model components | +| `api_types.py` | Add `sourceImagePath`, `strength` to `GenerateImageRequest` | +| `handlers/image_generation_handler.py` | Branch on `sourceImagePath`, load/resize source image, call `img2img` | +| `handlers/job_executors.py` | Pass `sourceImagePath`, `strength` from job params | +| `handlers/gallery_handler.py` | Add `"zit_edit_"` prefix mapping | +| `tests/fakes/services.py` | Add `img2img()` to `FakeImageGenerationPipeline` | + +### Frontend (5 files) +| File | Change | +|------|--------| +| `hooks/use-generation.ts` | Add `editImage()` method, `encoding_image` phase message | +| `views/GenSpace.tsx` | Add `editSourceImage` state, edit mode UI, strength slider, wire Edit button | +| `components/ImageResult.tsx` | Wire `onEdit` prop to Edit button | +| `components/SettingsPanel.tsx` | Add `strength` to `GenerationSettings`, show slider in edit mode | +| `contexts/ProjectContext.tsx` | Support image-edit routing from VideoEditor frames | + +### No changes needed +- Queue system (already supports arbitrary params) +- Queue worker (delegates to executor) +- Pipeline lifecycle / GPU management (img2img reuses same model) +- Device management / CPU offload (shared components) +- LoRA loading (same mixin base class) + +--- + +## Dependency Check + +Verify diffusers version in `backend/pyproject.toml`: + +```bash +grep diffusers backend/pyproject.toml +``` + +If < 0.37.0, update to `diffusers >= 0.37.0` and re-lock with `uv lock`. + +--- + +## References + +- HuggingFace Z-Image docs: https://huggingface.co/docs/diffusers/en/api/pipelines/z_image +- Model card: https://huggingface.co/Tongyi-MAI/Z-Image-Turbo +- Existing ZIT pipeline: `backend/services/image_generation_pipeline/zit_image_generation_pipeline.py` +- ImageResult Edit button (placeholder): `frontend/components/ImageResult.tsx:130-134` +- GenSpace edit routing: `frontend/views/GenSpace.tsx:986-993` diff --git a/docs/superpowers/plans/2026-03-23-quantized-models.md b/docs/superpowers/plans/2026-03-23-quantized-models.md new file mode 100644 index 00000000..12ca62ec --- /dev/null +++ b/docs/superpowers/plans/2026-03-23-quantized-models.md @@ -0,0 +1,2025 @@ +# Quantized Video Model Support Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Support GGUF, NF4, and FP8 checkpoint model formats so Directors Desktop runs on 24GB GPUs. + +**Architecture:** New `ModelScanner` service (Protocol + Fake pattern) scans a user-specified folder for model files and detects their format from file metadata. New `GGUFFastVideoPipeline` and `NF4FastVideoPipeline` classes implement the existing `FastVideoPipeline` protocol. `PipelinesHandler._create_video_pipeline()` reads the selected model from settings and picks the right pipeline class. Frontend gets a new "Models" tab in SettingsModal and a `ModelGuideDialog` popup with GPU-based recommendations and HuggingFace download links. + +**Tech Stack:** Python (FastAPI, pydantic, diffusers GGUF quantizer, bitsandbytes NF4), TypeScript (React, Tailwind CSS), gguf Python package (new dependency) + +**Spec:** `docs/superpowers/specs/2026-03-22-quantized-models-design.md` + +--- + +### Task 1: Add `gguf` dependency and new settings fields + +**Context:** Before anything else, we need the `gguf` package for reading GGUF file metadata, and two new settings fields (`custom_video_model_path`, `selected_video_model`) that the rest of the system depends on. + +**Files:** +- Modify: `backend/pyproject.toml` — add `gguf>=0.10.0` to dependencies +- Modify: `backend/state/app_settings.py` — add 2 new fields to `AppSettings` and `SettingsResponse`, update `to_settings_response()` +- Modify: `backend/api_types.py` — add `DetectedModel`, `ModelFormatInfo`, `SelectModelRequest`, response types + +- [ ] **Step 1: Add gguf dependency** + +In `backend/pyproject.toml`, add `"gguf>=0.10.0"` to the `dependencies` list. + +- [ ] **Step 2: Install the dependency** + +Run: `cd backend && uv sync` +Expected: Resolves and installs gguf package + +- [ ] **Step 3: Add settings fields to AppSettings** + +In `backend/state/app_settings.py`, add these two fields to `AppSettings` class (after `civitai_api_key`): + +```python +custom_video_model_path: str = "" +selected_video_model: str = "" +``` + +- [ ] **Step 4: Add settings fields to SettingsResponse** + +In `backend/state/app_settings.py`, add these two fields to `SettingsResponse` class (after `has_civitai_api_key`): + +```python +custom_video_model_path: str = "" +selected_video_model: str = "" +``` + +- [ ] **Step 5: Update to_settings_response()** + +In `backend/state/app_settings.py`, the `to_settings_response()` function currently pops API keys and replaces them with `has_*` booleans. The new fields are NOT sensitive — they should pass through as-is. Since `SettingsResponse` uses `extra="ignore"` via `SettingsBaseModel`, and both fields are simple strings with the same names, they will pass through `model_validate(data)` automatically. **No code changes needed to `to_settings_response()` itself** — the fields are already in `data` from `model_dump()` and `SettingsResponse` has matching fields. + +Verify this works: `cd backend && uv run python -c "from state.app_settings import AppSettings, to_settings_response; s = AppSettings(custom_video_model_path='/test', selected_video_model='model.gguf'); r = to_settings_response(s); print(r.custom_video_model_path, r.selected_video_model)"` + +Expected: `/test model.gguf` + +- [ ] **Step 6: Add new types to api_types.py** + +At the end of `backend/api_types.py` (before the final empty line), add: + +```python +# ============================================================ +# Video Model Scanner Types +# ============================================================ + + +class DetectedModel(BaseModel): + filename: str + path: str + format: str # "bf16" | "fp8" | "gguf" | "nf4" + quant_type: str | None = None + size_bytes: int + size_gb: float + is_distilled: bool + display_name: str + + +class ModelFormatInfo(BaseModel): + id: str + name: str + size_gb: float + min_vram_gb: int + quality_tier: str + needs_distilled_lora: bool + download_url: str + description: str + + +class DistilledLoraInfo(BaseModel): + name: str + size_gb: float + download_url: str + description: str + + +class VideoModelScanResponse(BaseModel): + models: list[DetectedModel] + distilled_lora_found: bool + + +class VideoModelGuideResponse(BaseModel): + gpu_name: str | None + vram_gb: int | None + recommended_format: str + formats: list[ModelFormatInfo] + distilled_lora: DistilledLoraInfo + + +class SelectModelRequest(BaseModel): + model: str +``` + +- [ ] **Step 7: Run typecheck to verify** + +Run: `cd backend && uv run pyright` +Expected: `0 errors, 0 warnings, 0 informations` + +- [ ] **Step 8: Commit** + +```bash +cd D:/git/directors-desktop +git add backend/pyproject.toml backend/uv.lock backend/state/app_settings.py backend/api_types.py +git commit -m "feat(quantized-models): add gguf dependency, settings fields, and API types" +``` + +--- + +### Task 2: Create ModelScanner service (Protocol + Impl + Fake) + +**Context:** The `ModelScanner` service scans a folder for video model files and returns structured `DetectedModel` results. It follows the codebase's Protocol + Impl + Fake pattern (see `services/palette_sync_client/` for reference). Detection uses file metadata, NOT file size heuristics. + +**Files:** +- Create: `backend/services/model_scanner/__init__.py` +- Create: `backend/services/model_scanner/model_scanner.py` — Protocol +- Create: `backend/services/model_scanner/model_scanner_impl.py` — Real implementation +- Create: `backend/services/model_scanner/model_guide_data.py` — Static format metadata +- Modify: `backend/services/interfaces.py` — re-export ModelScanner +- Modify: `backend/tests/fakes/services.py` — add FakeModelScanner +- Create: `backend/tests/test_model_scanner.py` — tests + +- [ ] **Step 1: Write the failing test** + +Create `backend/tests/test_model_scanner.py`: + +```python +"""Tests for the ModelScanner service.""" + +from __future__ import annotations + +import struct +from pathlib import Path + +import pytest + +from services.model_scanner.model_scanner_impl import ModelScannerImpl + + +class TestScanVideoModels: + def test_empty_folder_returns_empty_list(self, tmp_path: Path) -> None: + scanner = ModelScannerImpl() + result = scanner.scan_video_models(tmp_path) + assert result == [] + + def test_nonexistent_folder_returns_empty_list(self, tmp_path: Path) -> None: + scanner = ModelScannerImpl() + result = scanner.scan_video_models(tmp_path / "nonexistent") + assert result == [] + + def test_detects_gguf_file(self, tmp_path: Path) -> None: + # Create a minimal GGUF file with magic bytes and version + gguf_path = tmp_path / "model-Q4_K.gguf" + _write_minimal_gguf(gguf_path) + + scanner = ModelScannerImpl() + result = scanner.scan_video_models(tmp_path) + + assert len(result) == 1 + assert result[0].filename == "model-Q4_K.gguf" + assert result[0].format == "gguf" + assert result[0].is_distilled is False + + def test_detects_safetensors_file(self, tmp_path: Path) -> None: + st_path = tmp_path / "model.safetensors" + st_path.write_bytes(b"\x00" * 4096) # minimal file + + scanner = ModelScannerImpl() + result = scanner.scan_video_models(tmp_path) + + assert len(result) == 1 + assert result[0].filename == "model.safetensors" + assert result[0].format in ("bf16", "fp8") + + def test_detects_nf4_folder(self, tmp_path: Path) -> None: + nf4_dir = tmp_path / "my-nf4-model" + nf4_dir.mkdir() + (nf4_dir / "model.safetensors").write_bytes(b"\x00" * 1024) + (nf4_dir / "quantize_config.json").write_text('{"quant_type": "nf4"}') + + scanner = ModelScannerImpl() + result = scanner.scan_video_models(tmp_path) + + assert len(result) == 1 + assert result[0].format == "nf4" + assert result[0].quant_type == "nf4" + + def test_skips_corrupt_gguf_file(self, tmp_path: Path) -> None: + bad_path = tmp_path / "corrupt.gguf" + bad_path.write_bytes(b"not a gguf file") + + scanner = ModelScannerImpl() + result = scanner.scan_video_models(tmp_path) + + assert len(result) == 0 + + def test_skips_non_model_files(self, tmp_path: Path) -> None: + (tmp_path / "readme.txt").write_text("hello") + (tmp_path / "config.json").write_text("{}") + + scanner = ModelScannerImpl() + result = scanner.scan_video_models(tmp_path) + + assert len(result) == 0 + + def test_multiple_models_in_folder(self, tmp_path: Path) -> None: + # GGUF file + _write_minimal_gguf(tmp_path / "model-q4.gguf") + # Safetensors file + (tmp_path / "model-bf16.safetensors").write_bytes(b"\x00" * 4096) + + scanner = ModelScannerImpl() + result = scanner.scan_video_models(tmp_path) + + assert len(result) == 2 + formats = {m.format for m in result} + assert "gguf" in formats + + +def _write_minimal_gguf(path: Path) -> None: + """Write a minimal valid GGUF file header (magic + version + tensor/kv counts).""" + with open(path, "wb") as f: + f.write(b"GGUF") # magic + f.write(struct.pack(" list[DetectedModel]: ... +``` + +- [ ] **Step 4: Create model_guide_data.py** + +Create `backend/services/model_scanner/model_guide_data.py`: +```python +"""Static metadata about available video model formats and download URLs.""" + +from __future__ import annotations + +from api_types import DistilledLoraInfo, ModelFormatInfo + +MODEL_FORMATS: list[ModelFormatInfo] = [ + ModelFormatInfo( + id="bf16", + name="BF16 (Full Precision)", + size_gb=43, + min_vram_gb=32, + quality_tier="Best", + needs_distilled_lora=False, + download_url="https://huggingface.co/Lightricks/LTX-Video-2.3-22b-distilled", + description="Best quality. Requires 32GB+ VRAM. Auto-downloaded by default.", + ), + ModelFormatInfo( + id="fp8", + name="FP8 Distilled Checkpoint", + size_gb=22, + min_vram_gb=20, + quality_tier="Excellent", + needs_distilled_lora=False, + download_url="https://huggingface.co/Lightricks/LTX-Video-2.3-22b-distilled", + description="Excellent quality, smaller file. Good for 20-31GB VRAM GPUs.", + ), + ModelFormatInfo( + id="gguf_q8", + name="GGUF Q8", + size_gb=22, + min_vram_gb=18, + quality_tier="Excellent", + needs_distilled_lora=True, + download_url="https://huggingface.co/city96/LTX-Video-2.3-22b-0.9.7-dev-gguf", + description="Excellent quality quantized model. Needs distilled LoRA.", + ), + ModelFormatInfo( + id="gguf_q5k", + name="GGUF Q5_K", + size_gb=15, + min_vram_gb=13, + quality_tier="Very Good", + needs_distilled_lora=True, + download_url="https://huggingface.co/city96/LTX-Video-2.3-22b-0.9.7-dev-gguf", + description="Very good quality, balanced size. Good for 16-19GB VRAM GPUs.", + ), + ModelFormatInfo( + id="gguf_q4k", + name="GGUF Q4_K", + size_gb=12, + min_vram_gb=10, + quality_tier="Good", + needs_distilled_lora=True, + download_url="https://huggingface.co/city96/LTX-Video-2.3-22b-0.9.7-dev-gguf", + description="Good quality, smallest file. Good for 10-15GB VRAM GPUs.", + ), + ModelFormatInfo( + id="nf4", + name="NF4 (4-bit BitsAndBytes)", + size_gb=12, + min_vram_gb=10, + quality_tier="Good", + needs_distilled_lora=True, + download_url="https://huggingface.co/Lightricks/LTX-Video-2.3-22b-distilled", + description="4-bit quantization via BitsAndBytes. Good for 10-15GB VRAM GPUs.", + ), +] + +DISTILLED_LORA_INFO = DistilledLoraInfo( + name="LTX 2.3 Distilled LoRA", + size_gb=0.5, + download_url="https://huggingface.co/Lightricks/LTX-Video-2.3-22b-distilled", + description="Required for GGUF and NF4 models to enable fast distilled generation.", +) + + +def recommend_format(vram_gb: int | None) -> str: + """Return the recommended format ID based on available VRAM.""" + if vram_gb is None: + return "bf16" + if vram_gb >= 32: + return "bf16" + if vram_gb >= 20: + return "fp8" + if vram_gb >= 16: + return "gguf_q5k" + if vram_gb >= 10: + return "gguf_q4k" + return "api_only" +``` + +- [ ] **Step 5: Create ModelScannerImpl** + +Create `backend/services/model_scanner/model_scanner_impl.py`: +```python +"""Real implementation of ModelScanner that reads file metadata.""" + +from __future__ import annotations + +import json +import logging +import struct +from pathlib import Path + +from api_types import DetectedModel + +logger = logging.getLogger(__name__) + +# Known distilled checkpoint filenames (these don't need distilled LoRA) +_KNOWN_DISTILLED = { + "ltx-video-2.3-22b-distilled.safetensors", + "ltx-video-2.3-22b-distilled-fp8.safetensors", +} + +_GGUF_MAGIC = b"GGUF" + + +class ModelScannerImpl: + def scan_video_models(self, folder: Path) -> list[DetectedModel]: + """Scan folder for supported video model files.""" + if not folder.exists(): + return [] + + models: list[DetectedModel] = [] + + for entry in sorted(folder.iterdir()): + try: + if entry.is_file() and entry.suffix == ".gguf": + model = self._scan_gguf(entry) + if model is not None: + models.append(model) + elif entry.is_file() and entry.suffix == ".safetensors": + models.append(self._scan_safetensors(entry)) + elif entry.is_dir(): + model = self._scan_nf4_folder(entry) + if model is not None: + models.append(model) + except Exception: + logger.warning("Failed to scan model file %s", entry, exc_info=True) + + return models + + def _scan_gguf(self, path: Path) -> DetectedModel | None: + """Read GGUF file header to verify validity and extract quant type.""" + with open(path, "rb") as f: + magic = f.read(4) + if magic != _GGUF_MAGIC: + logger.warning("Skipping %s: invalid GGUF magic bytes", path.name) + return None + + version_bytes = f.read(4) + if len(version_bytes) < 4: + return None + struct.unpack(" DetectedModel: + """Scan a safetensors file. Detect FP8 from companion config or header.""" + size_bytes = path.stat().st_size + size_gb = round(size_bytes / (1024**3), 1) + is_distilled = path.name.lower() in _KNOWN_DISTILLED + fmt = self._detect_safetensors_format(path) + + return DetectedModel( + filename=path.name, + path=str(path), + format=fmt, + quant_type="fp8" if fmt == "fp8" else None, + size_bytes=size_bytes, + size_gb=size_gb, + is_distilled=is_distilled, + display_name=f"LTX 2.3 — {fmt.upper()} ({size_gb} GB)", + ) + + def _scan_nf4_folder(self, folder: Path) -> DetectedModel | None: + """Check if a subfolder contains an NF4 quantized model.""" + config_path = folder / "quantize_config.json" + if not config_path.exists(): + return None + + try: + config = json.loads(config_path.read_text()) + except (json.JSONDecodeError, OSError): + return None + + quant_type = config.get("quant_type") + if quant_type != "nf4": + return None + + size_bytes = sum(f.stat().st_size for f in folder.rglob("*") if f.is_file()) + size_gb = round(size_bytes / (1024**3), 1) + + return DetectedModel( + filename=folder.name, + path=str(folder), + format="nf4", + quant_type="nf4", + size_bytes=size_bytes, + size_gb=size_gb, + is_distilled=False, + display_name=f"LTX 2.3 — NF4 ({size_gb} GB)", + ) + + def _detect_safetensors_format(self, path: Path) -> str: + """Detect whether a safetensors file is BF16 or FP8. + + Checks companion config.json first, then falls back to reading + the safetensors header for dtype metadata. + """ + # Check companion config.json + config_path = path.parent / "config.json" + if config_path.exists(): + try: + config = json.loads(config_path.read_text()) + dtype = config.get("torch_dtype", "") + if "float8" in dtype or "fp8" in dtype: + return "fp8" + except (json.JSONDecodeError, OSError): + pass + + # Check safetensors header for dtype info + try: + with open(path, "rb") as f: + header_size_bytes = f.read(8) + if len(header_size_bytes) == 8: + header_size = struct.unpack(" str | None: + """Extract quantization type from common GGUF naming conventions.""" + name_upper = filename.upper() + for qt in ("Q8_0", "Q6_K", "Q5_K", "Q5_1", "Q5_0", "Q4_K", "Q4_1", "Q4_0", "Q3_K", "Q2_K"): + if qt in name_upper: + return qt + return None + + @staticmethod + def _gguf_display_name(filename: str, quant_type: str | None, size_gb: float) -> str: + qt = quant_type or "unknown quant" + return f"LTX 2.3 — GGUF {qt} ({size_gb} GB)" +``` + +- [ ] **Step 6: Add re-export in interfaces.py** + +In `backend/services/interfaces.py`, add this import at the top (after the PaletteSyncClient import): + +```python +from services.model_scanner.model_scanner import ModelScanner +``` + +And add `"ModelScanner"` to the `__all__` list. + +- [ ] **Step 7: Add FakeModelScanner to test fakes** + +In `backend/tests/fakes/services.py`, add `DetectedModel` to the imports from `api_types`, then add this class (before the `FakeServices` dataclass at the bottom): + +```python +class FakeModelScanner: + def __init__(self) -> None: + self._models: list[DetectedModel] = [] + + def set_models(self, models: list[DetectedModel]) -> None: + self._models = models + + def scan_video_models(self, folder: Path) -> list[DetectedModel]: + return self._models +``` + +Also add `model_scanner: FakeModelScanner = field(default_factory=FakeModelScanner)` to the `FakeServices` dataclass. + +- [ ] **Step 8: Run tests to verify they pass** + +Run: `cd backend && uv run pytest tests/test_model_scanner.py -v --tb=short` +Expected: All 7 tests PASS + +- [ ] **Step 9: Run pyright** + +Run: `cd backend && uv run pyright` +Expected: `0 errors, 0 warnings, 0 informations` + +- [ ] **Step 10: Commit** + +```bash +cd D:/git/directors-desktop +git add backend/services/model_scanner/ backend/services/interfaces.py backend/tests/fakes/services.py backend/tests/test_model_scanner.py +git commit -m "feat(quantized-models): add ModelScanner service with Protocol, impl, fake, and tests" +``` + +--- + +### Task 3: Add model guide recommendation logic and tests + +**Context:** The `recommend_format()` function maps GPU VRAM to a recommended format. This is simple logic but critical for the user experience — it's what tells users with a 3090 to download GGUF Q5_K. + +**Files:** +- Already created: `backend/services/model_scanner/model_guide_data.py` (in Task 2) +- Create: `backend/tests/test_model_guide.py` + +- [ ] **Step 1: Write the tests** + +Create `backend/tests/test_model_guide.py`: + +```python +"""Tests for model guide recommendation logic.""" + +from __future__ import annotations + +from services.model_scanner.model_guide_data import MODEL_FORMATS, DISTILLED_LORA_INFO, recommend_format + + +class TestRecommendFormat: + def test_48gb_recommends_bf16(self) -> None: + assert recommend_format(48) == "bf16" + + def test_32gb_recommends_bf16(self) -> None: + assert recommend_format(32) == "bf16" + + def test_24gb_recommends_fp8(self) -> None: + assert recommend_format(24) == "fp8" + + def test_20gb_recommends_fp8(self) -> None: + assert recommend_format(20) == "fp8" + + def test_16gb_recommends_gguf_q5k(self) -> None: + assert recommend_format(16) == "gguf_q5k" + + def test_12gb_recommends_gguf_q4k(self) -> None: + assert recommend_format(12) == "gguf_q4k" + + def test_10gb_recommends_gguf_q4k(self) -> None: + assert recommend_format(10) == "gguf_q4k" + + def test_8gb_recommends_api_only(self) -> None: + assert recommend_format(8) == "api_only" + + def test_none_vram_defaults_to_bf16(self) -> None: + assert recommend_format(None) == "bf16" + + +class TestModelFormatsData: + def test_all_formats_have_required_fields(self) -> None: + for fmt in MODEL_FORMATS: + assert fmt.id + assert fmt.name + assert fmt.min_vram_gb > 0 + assert fmt.download_url.startswith("https://") + + def test_distilled_lora_info_has_url(self) -> None: + assert DISTILLED_LORA_INFO.download_url.startswith("https://") + assert DISTILLED_LORA_INFO.size_gb > 0 +``` + +- [ ] **Step 2: Run tests** + +Run: `cd backend && uv run pytest tests/test_model_guide.py -v --tb=short` +Expected: All tests PASS + +- [ ] **Step 3: Commit** + +```bash +cd D:/git/directors-desktop +git add backend/tests/test_model_guide.py +git commit -m "test(quantized-models): add model guide recommendation tests" +``` + +--- + +### Task 4: Wire ModelScanner into handler + routes + integration tests + +**Context:** Now connect the scanner to the API layer. Add `scan_video_models()`, `select_video_model()`, and `video_model_guide()` to `ModelsHandler`, add thin routes, and wire `ModelScannerImpl` into `ServiceBundle` and `AppHandler`. The test infrastructure (`conftest.py`) needs to wire `FakeModelScanner`. + +**Files:** +- Modify: `backend/handlers/models_handler.py` — add 3 new methods +- Modify: `backend/_routes/models.py` — add 3 new routes +- Modify: `backend/app_handler.py` — add `ModelScanner` to `AppHandler.__init__()` and `ServiceBundle`, wire `ModelScannerImpl` in `build_default_service_bundle()` +- Modify: `backend/tests/conftest.py` — wire `FakeModelScanner` in test `ServiceBundle` +- Create: `backend/tests/test_model_selection.py` — integration tests + +- [ ] **Step 1: Write the failing integration test** + +Create `backend/tests/test_model_selection.py`: + +```python +"""Integration tests for video model scan, select, and guide endpoints.""" + +from __future__ import annotations + +import struct +from pathlib import Path + +import pytest + +from app_handler import AppHandler + + +def _write_minimal_gguf(path: Path) -> None: + with open(path, "wb") as f: + f.write(b"GGUF") + f.write(struct.pack(" None: + resp = client.get("/api/models/video/scan") + assert resp.status_code == 200 + data = resp.json() + assert data["models"] == [] + assert data["distilled_lora_found"] is False + + +class TestVideoModelSelect: + def test_select_nonexistent_model_returns_400(self, client) -> None: + resp = client.post("/api/models/video/select", json={"model": "nonexistent.gguf"}) + assert resp.status_code == 400 + + def test_select_valid_model_updates_settings(self, client, test_state: AppHandler) -> None: + # Create a model file in the models dir + models_dir = test_state.config.models_dir + gguf_path = models_dir / "test-model.gguf" + _write_minimal_gguf(gguf_path) + + # First set the custom path to models_dir + client.post("/api/settings", json={"customVideoModelPath": str(models_dir)}) + + resp = client.post("/api/models/video/select", json={"model": "test-model.gguf"}) + assert resp.status_code == 200 + + # Verify settings updated + assert test_state.state.app_settings.selected_video_model == "test-model.gguf" + + +class TestVideoModelGuide: + def test_guide_returns_formats_and_recommendation(self, client) -> None: + resp = client.get("/api/models/video/guide") + assert resp.status_code == 200 + data = resp.json() + assert "formats" in data + assert "recommended_format" in data + assert "distilled_lora" in data + assert len(data["formats"]) > 0 +``` + +- [ ] **Step 2: Run test to verify it fails** + +Run: `cd backend && uv run pytest tests/test_model_selection.py -v --tb=short` +Expected: FAIL — routes don't exist yet + +- [ ] **Step 3: Update ModelsHandler** + +In `backend/handlers/models_handler.py`, update the imports: + +```python +from __future__ import annotations + +import logging +from pathlib import Path +from threading import RLock +from typing import TYPE_CHECKING + +from api_types import ( + DetectedModel, + ModelFileStatus, + ModelInfo, + ModelsStatusResponse, + TextEncoderStatus, + VideoModelGuideResponse, + VideoModelScanResponse, +) +from handlers.base import StateHandlerBase, with_state_lock +from runtime_config.model_download_specs import MODEL_FILE_ORDER, resolve_required_model_types +from services.model_scanner.model_guide_data import DISTILLED_LORA_INFO, MODEL_FORMATS, recommend_format +from state.app_state_types import AppState, AvailableFiles + +if TYPE_CHECKING: + from runtime_config.runtime_config import RuntimeConfig + from services.gpu_info.gpu_info import GpuInfo + from services.model_scanner.model_scanner import ModelScanner + +logger = logging.getLogger(__name__) +``` + +Update `ModelsHandler.__init__()` to accept `model_scanner`: + +```python +class ModelsHandler(StateHandlerBase): + def __init__( + self, + state: AppState, + lock: RLock, + config: RuntimeConfig, + model_scanner: ModelScanner, + gpu_info_service: GpuInfo, + ) -> None: + super().__init__(state, lock) + self._config = config + self._model_scanner = model_scanner + self._gpu_info = gpu_info_service +``` + +Add these three methods to `ModelsHandler` (after `get_models_status()`): + +```python + def _get_video_models_dir(self) -> Path: + """Get the folder to scan for video models.""" + custom = self.state.app_settings.custom_video_model_path + if custom: + return Path(custom) + return self._config.models_dir + + def scan_video_models(self) -> VideoModelScanResponse: + """Scan the configured folder for video model files.""" + folder = self._get_video_models_dir() + models = self._model_scanner.scan_video_models(folder) + + # Check for distilled LoRA in the same folder or loras/ subfolder + lora_found = self._check_distilled_lora(folder) + + return VideoModelScanResponse( + models=models, + distilled_lora_found=lora_found, + ) + + def select_video_model(self, filename: str) -> dict[str, str]: + """Select a video model by filename. Validates it exists and no generation running.""" + from _routes._errors import HTTPError + from state.app_state_types import GenerationRunning, GpuSlot + + # Guard: don't swap model while generating (matches _ensure_no_running_generation pattern) + match self.state.gpu_slot: + case GpuSlot(generation=GenerationRunning()): + raise HTTPError(409, "Cannot change model while generation is running") + + folder = self._get_video_models_dir() + target = folder / filename + + # Check file or folder exists + if not target.exists(): + raise HTTPError(400, f"Model file not found: {filename}") + + self.state.app_settings.selected_video_model = filename + return {"status": "ok", "selected": filename} + + def video_model_guide(self) -> VideoModelGuideResponse: + """Return GPU info and format recommendations for the model guide UI.""" + gpu_name: str | None = None + vram_gb: int | None = None + try: + gpu_name = self._gpu_info.get_device_name() + vram_gb = self._gpu_info.get_vram_total_gb() + except Exception: + logger.warning("Could not get GPU info for model guide", exc_info=True) + + return VideoModelGuideResponse( + gpu_name=gpu_name, + vram_gb=vram_gb, + recommended_format=recommend_format(vram_gb), + formats=MODEL_FORMATS, + distilled_lora=DISTILLED_LORA_INFO, + ) + + @staticmethod + def _check_distilled_lora(folder: Path) -> bool: + """Check if a distilled LoRA file exists in the folder or loras/ subfolder.""" + if not folder.exists(): + return False + + search_dirs = [folder] + loras_sub = folder / "loras" + if loras_sub.exists(): + search_dirs.append(loras_sub) + + try: + for d in search_dirs: + for f in d.iterdir(): + if f.is_file() and "distill" in f.name.lower() and f.suffix in (".safetensors", ".bin"): + return True + except OSError: + pass + return False +``` + +- [ ] **Step 4: Add routes** + +In `backend/_routes/models.py`, add these imports at the top: + +```python +from api_types import ( + DownloadProgressResponse, + ModelDownloadRequest, + ModelDownloadStartResponse, + ModelInfo, + ModelsStatusResponse, + SelectModelRequest, + TextEncoderDownloadResponse, + VideoModelGuideResponse, + VideoModelScanResponse, +) +``` + +Add these routes at the end of the file: + +```python +@router.get("/models/video/scan", response_model=VideoModelScanResponse) +def route_scan_video_models(handler: AppHandler = Depends(get_state_service)) -> VideoModelScanResponse: + return handler.models.scan_video_models() + + +@router.post("/models/video/select") +def route_select_video_model( + req: SelectModelRequest, + handler: AppHandler = Depends(get_state_service), +) -> dict[str, str]: + result = handler.models.select_video_model(req.model) + handler.settings.save_settings() + return result + + +@router.get("/models/video/guide", response_model=VideoModelGuideResponse) +def route_video_model_guide(handler: AppHandler = Depends(get_state_service)) -> VideoModelGuideResponse: + return handler.models.video_model_guide() +``` + +- [ ] **Step 5: Update AppHandler and ServiceBundle** + +In `backend/app_handler.py`: + +Add to imports: +```python +from services.model_scanner.model_scanner import ModelScanner +``` + +Add `model_scanner: ModelScanner` parameter to `AppHandler.__init__()` (after `ic_lora_model_downloader`): +```python + model_scanner: ModelScanner, +``` + +Store it: +```python + self.model_scanner = model_scanner +``` + +Update the `self.models = ModelsHandler(...)` call to pass scanner and gpu_info: +```python + self.models = ModelsHandler( + state=self.state, + lock=self._lock, + config=config, + model_scanner=model_scanner, + gpu_info_service=gpu_info, + ) +``` + +Add `model_scanner: ModelScanner` to the `ServiceBundle` dataclass (after `ic_lora_model_downloader`): +```python + model_scanner: ModelScanner +``` + +Update `build_default_service_bundle()` — add import: +```python + from services.model_scanner.model_scanner_impl import ModelScannerImpl +``` + +Add to the returned `ServiceBundle(...)`: +```python + model_scanner=ModelScannerImpl(), +``` + +Update `build_initial_state()` — pass `model_scanner=bundle.model_scanner` to `AppHandler(...)`. + +- [ ] **Step 6: Update conftest.py** + +In `backend/tests/conftest.py`, add `model_scanner=fake_services.model_scanner` to the `ServiceBundle(...)` constructor call. + +- [ ] **Step 7: Run integration tests** + +Run: `cd backend && uv run pytest tests/test_model_selection.py -v --tb=short` +Expected: All 4 tests PASS + +- [ ] **Step 8: Run full test suite + pyright** + +Run: `cd backend && uv run pyright && uv run pytest -v --tb=short` +Expected: pyright clean, all tests pass + +- [ ] **Step 9: Commit** + +```bash +cd D:/git/directors-desktop +git add backend/handlers/models_handler.py backend/_routes/models.py backend/app_handler.py backend/tests/conftest.py backend/tests/test_model_selection.py +git commit -m "feat(quantized-models): wire ModelScanner to handler, routes, and integration tests" +``` + +--- + +### Task 5: Add Models tab to SettingsModal + +**Context:** Add a new "Models" tab to the existing SettingsModal. This tab shows: (1) a dropdown of detected model files, (2) the model folder path with Change/Open/Scan buttons, (3) GPU info summary, (4) a "Model Guide" button, and (5) distilled LoRA status. The tab calls the 3 new API endpoints from Task 4. + +**Files:** +- Modify: `frontend/components/SettingsModal.tsx` — add `models` tab +- Modify: `frontend/contexts/AppSettingsContext.tsx` — if needed, add new settings fields to TypeScript types + +- [ ] **Step 1: Add TypeScript types for new settings** + +In `frontend/contexts/AppSettingsContext.tsx`: + +Add these fields to the `AppSettings` interface (after `hasCivitaiApiKey`): + +```typescript +customVideoModelPath: string +selectedVideoModel: string +``` + +Add these default values to the `DEFAULT_APP_SETTINGS` constant (after `hasCivitaiApiKey: false`): + +```typescript +customVideoModelPath: '', +selectedVideoModel: '', +``` + +- [ ] **Step 2: Add Models tab to SettingsModal** + +In `frontend/components/SettingsModal.tsx`: + +Update the `TabId` type: +```typescript +type TabId = 'general' | 'apiKeys' | 'inference' | 'models' | 'promptEnhancer' | 'about' +``` + +Add `Cpu` icon to the lucide imports (for GPU info display). + +Add state variables for the Models tab (inside the `SettingsModal` component, near the other state): + +```typescript +const [videoModels, setVideoModels] = useState([]) +const [distilledLoraFound, setDistilledLoraFound] = useState(false) +const [modelScanning, setModelScanning] = useState(false) +const [gpuInfo, setGpuInfo] = useState<{ name: string | null; vram: number | null } | null>(null) +const [showModelGuide, setShowModelGuide] = useState(false) +``` + +Add a `useEffect` to load model data when the Models tab is active: + +```typescript +useEffect(() => { + if (!isOpen || activeTab !== 'models') return + let cancelled = false + const load = async () => { + try { + const backendUrl = await window.electronAPI.getBackendUrl() + const [scanRes, guideRes] = await Promise.all([ + fetch(`${backendUrl}/api/models/video/scan`), + fetch(`${backendUrl}/api/models/video/guide`), + ]) + if (cancelled) return + if (scanRes.ok) { + const data = await scanRes.json() + setVideoModels(data.models) + setDistilledLoraFound(data.distilled_lora_found) + } + if (guideRes.ok) { + const guide = await guideRes.json() + setGpuInfo({ name: guide.gpu_name, vram: guide.vram_gb }) + } + } catch (err) { + logger.error('Failed to load model data', err) + } + } + load() + return () => { cancelled = true } +}, [isOpen, activeTab]) +``` + +Add the Models tab button to the tab bar (between Inference and About): +```tsx + +``` + +Add the Models tab content (inside the tab content switch): + +```tsx +{activeTab === 'models' && ( +
+ {/* GPU Info */} + {gpuInfo && ( +
+
+ + + {gpuInfo.name || 'Unknown GPU'} + {gpuInfo.vram && ` — ${gpuInfo.vram} GB VRAM`} + +
+
+ )} + + {/* Video Model Selection */} +
+ + +
+ + {/* Model Folder */} +
+ +
+ + + + +
+
+ + {/* Distilled LoRA Status */} + {settings.selectedVideoModel && !distilledLoraFound && ( +
+

+ + This model may need a distilled LoRA for fast generation. + Check the Model Guide for download links. +

+
+ )} + + {/* Model Guide Button */} + +
+)} +``` + +- [ ] **Step 3: Run TypeScript typecheck** + +Run: `pnpm typecheck:ts` +Expected: No errors + +- [ ] **Step 4: Commit** + +```bash +cd D:/git/directors-desktop +git add frontend/components/SettingsModal.tsx frontend/contexts/AppSettingsContext.tsx +git commit -m "feat(quantized-models): add Models tab to SettingsModal" +``` + +--- + +### Task 6: Create ModelGuideDialog component + +**Context:** A standalone modal dialog that shows GPU-based recommendations and download links for each model format. Fetches data from `GET /api/models/video/guide`. Design follows the CLAUDE.md design standards (OKLCH colors, rounded corners, hover states, etc). + +**Files:** +- Create: `frontend/components/ModelGuideDialog.tsx` +- Modify: `frontend/components/SettingsModal.tsx` — import and render ModelGuideDialog + +- [ ] **Step 1: Create ModelGuideDialog.tsx** + +Create `frontend/components/ModelGuideDialog.tsx` with the full dialog implementation. It should: + +- Fetch from `/api/models/video/guide` on mount +- Show GPU name + VRAM at the top +- Grid of format cards (responsive: 1 col on small, 2 on medium+) +- Each card: format name, size, quality tier badge, "Recommended" pill if it matches recommended_format, "Download" button that opens URL in browser via `window.electronAPI.openExternal` or `window.open` +- Distilled LoRA callout at the bottom (amber background) with its own download button +- Model folder path display +- Close button + +Style with Tailwind using the project's dark theme (`bg-zinc-900`, `border-zinc-700`, `text-zinc-300`, purple accents for recommended items). + +```tsx +import { AlertCircle, Download, ExternalLink, Monitor, X } from 'lucide-react' +import React, { useEffect, useState } from 'react' +import { Button } from './ui/button' +import { logger } from '../lib/logger' + +interface ModelFormat { + id: string + name: string + size_gb: number + min_vram_gb: number + quality_tier: string + needs_distilled_lora: boolean + download_url: string + description: string +} + +interface DistilledLora { + name: string + size_gb: number + download_url: string + description: string +} + +interface GuideData { + gpu_name: string | null + vram_gb: number | null + recommended_format: string + formats: ModelFormat[] + distilled_lora: DistilledLora +} + +interface ModelGuideDialogProps { + isOpen: boolean + onClose: () => void +} + +const QUALITY_COLORS: Record = { + 'Best': 'bg-green-500/20 text-green-400 border-green-500/30', + 'Excellent': 'bg-blue-500/20 text-blue-400 border-blue-500/30', + 'Very Good': 'bg-purple-500/20 text-purple-400 border-purple-500/30', + 'Good': 'bg-amber-500/20 text-amber-400 border-amber-500/30', +} + +export function ModelGuideDialog({ isOpen, onClose }: ModelGuideDialogProps) { + const [data, setData] = useState(null) + const [loading, setLoading] = useState(true) + + useEffect(() => { + if (!isOpen) return + let cancelled = false + setLoading(true) + const load = async () => { + try { + const backendUrl = await window.electronAPI.getBackendUrl() + const res = await fetch(`${backendUrl}/api/models/video/guide`) + if (!cancelled && res.ok) { + setData(await res.json()) + } + } catch (err) { + logger.error('Failed to load model guide', err) + } finally { + if (!cancelled) setLoading(false) + } + } + load() + return () => { cancelled = true } + }, [isOpen]) + + if (!isOpen) return null + + const openUrl = (url: string) => { + window.open(url, '_blank') + } + + return ( +
+
+ {/* Header */} +
+

Video Model Guide

+ +
+ +
+ {loading ? ( +
Loading...
+ ) : data ? ( + <> + {/* GPU Info Banner */} +
+ +
+

+ {data.gpu_name || 'GPU not detected'} +

+

+ {data.vram_gb + ? `${data.vram_gb} GB VRAM available` + : 'VRAM could not be determined'} +

+
+
+ + {/* Format Cards */} +
+ {data.formats.map((fmt) => { + const isRecommended = fmt.id === data.recommended_format + return ( +
+ {isRecommended && ( + + Recommended + + )} +

{fmt.name}

+
+ {fmt.size_gb} GB + · + + {fmt.quality_tier} + + · + ≥{fmt.min_vram_gb} GB VRAM +
+

{fmt.description}

+ +
+ ) + })} +
+ + {/* Distilled LoRA Notice */} +
+
+ +
+

{data.distilled_lora.name}

+

{data.distilled_lora.description}

+ +
+
+
+ + {/* Instructions */} +
+

Setup Instructions

+
    +
  1. Download the model file for your GPU from the links above
  2. +
  3. Go to Settings → Models and set your model folder
  4. +
  5. If using GGUF or NF4, also download the distilled LoRA
  6. +
  7. Select your model from the dropdown
  8. +
  9. Generate!
  10. +
+
+ + ) : ( +
Failed to load model guide data.
+ )} +
+
+
+ ) +} +``` + +- [ ] **Step 2: Import and render in SettingsModal** + +In `frontend/components/SettingsModal.tsx`, add import: +```typescript +import { ModelGuideDialog } from './ModelGuideDialog' +``` + +Render the dialog at the end of the SettingsModal component's return, just before the closing ``: +```tsx + setShowModelGuide(false)} /> +``` + +- [ ] **Step 3: Run TypeScript typecheck** + +Run: `pnpm typecheck:ts` +Expected: No errors + +- [ ] **Step 4: Commit** + +```bash +cd D:/git/directors-desktop +git add frontend/components/ModelGuideDialog.tsx frontend/components/SettingsModal.tsx +git commit -m "feat(quantized-models): add ModelGuideDialog popup with GPU recommendations" +``` + +--- + +### Task 7: Update PipelinesHandler to support multiple pipeline classes + +**Context:** Currently `PipelinesHandler._create_video_pipeline()` always uses `self._fast_video_pipeline_class`. We need it to check `selected_video_model` from settings, determine the file format, and use the appropriate pipeline class. For now, we wire the logic but use the existing `LTXFastVideoPipeline` for all safetensors files. The actual GGUF and NF4 pipeline classes will be added in Tasks 8 and 9. + +**Files:** +- Modify: `backend/handlers/pipelines_handler.py` — update `_create_video_pipeline()` to route by format +- Modify: `backend/app_handler.py` — pass additional pipeline classes to PipelinesHandler + +- [ ] **Step 1: Update PipelinesHandler constructor** + +In `backend/handlers/pipelines_handler.py`, add `gguf_video_pipeline_class` and `nf4_video_pipeline_class` parameters to `__init__()`: + +```python + def __init__( + self, + state: AppState, + lock: RLock, + text_handler: TextHandler, + gpu_cleaner: GpuCleaner, + fast_video_pipeline_class: type[FastVideoPipeline], + gguf_video_pipeline_class: type[FastVideoPipeline] | None, + nf4_video_pipeline_class: type[FastVideoPipeline] | None, + image_generation_pipeline_class: type[ImageGenerationPipeline], + # ... rest unchanged + ) -> None: + # ... existing code ... + self._gguf_video_pipeline_class = gguf_video_pipeline_class + self._nf4_video_pipeline_class = nf4_video_pipeline_class +``` + +- [ ] **Step 2: Update _create_video_pipeline() to route by format** + +In `_create_video_pipeline()`, replace the line that calls `self._fast_video_pipeline_class.create(...)` with format-aware routing: + +```python + def _create_video_pipeline( + self, + model_type: VideoPipelineModelType, + lora_path: str | None = None, + lora_weight: float = 1.0, + ) -> VideoPipelineState: + gemma_root = self._text_handler.resolve_gemma_root() + + # Determine checkpoint path and pipeline class based on selected model + selected = self.state.app_settings.selected_video_model + custom_dir = self.state.app_settings.custom_video_model_path + pipeline_class = self._fast_video_pipeline_class + + if selected: + base_dir = Path(custom_dir) if custom_dir else self._config.models_dir + model_path = base_dir / selected + checkpoint_path = str(model_path) + + # Validate model file/folder still exists on disk + if not model_path.exists(): + raise FileNotFoundError( + f"Selected model not found: {checkpoint_path}. " + "Go to Settings → Models to select a different model." + ) + + if selected.endswith(".gguf") and self._gguf_video_pipeline_class is not None: + pipeline_class = self._gguf_video_pipeline_class + elif model_path.is_dir() and self._nf4_video_pipeline_class is not None: + # NF4 models are folders + pipeline_class = self._nf4_video_pipeline_class + # else: safetensors files use default pipeline + else: + checkpoint_path = str(self._config.model_path("checkpoint")) + + upsampler_path = str(self._config.model_path("upsampler")) + + pipeline = pipeline_class.create( + checkpoint_path, + gemma_root, + upsampler_path, + self._device, + lora_path=lora_path, + lora_weight=lora_weight, + ) + + state = VideoPipelineState( + pipeline=pipeline, + warmth=VideoPipelineWarmth.COLD, + is_compiled=False, + lora_path=lora_path, + ) + state = self._compile_if_enabled(state) + + # Apply FFN chunking if enabled and torch.compile is not active + chunk_count = self.state.app_settings.ffn_chunk_count + if chunk_count > 0 and not state.is_compiled: + try: + transformer: torch.nn.Module = state.pipeline.pipeline.model_ledger.transformer() # type: ignore[union-attr] + patch_ffn_chunking(transformer, num_chunks=chunk_count) # pyright: ignore[reportUnknownArgumentType] + except AttributeError: + logger.debug("FFN chunking skipped — pipeline has no model_ledger") + + # Install TeaCache denoising loop patch + tea_threshold = self.state.app_settings.tea_cache_threshold + try: + install_tea_cache_patch(tea_threshold) + except (ImportError, AttributeError): + logger.debug("TeaCache skipped — ltx_pipelines not available") + + return state +``` + +Add `from pathlib import Path` to the imports. + +- [ ] **Step 3: Update AppHandler to pass new pipeline class params** + +In `backend/app_handler.py`, update the `self.pipelines = PipelinesHandler(...)` call to include: + +```python + gguf_video_pipeline_class=None, # Will be set in Task 8 + nf4_video_pipeline_class=None, # Will be set in Task 9 +``` + +- [ ] **Step 4: Run full test suite + pyright** + +Run: `cd backend && uv run pyright && uv run pytest -v --tb=short` +Expected: pyright clean, all tests pass + +- [ ] **Step 5: Commit** + +```bash +cd D:/git/directors-desktop +git add backend/handlers/pipelines_handler.py backend/app_handler.py +git commit -m "feat(quantized-models): update PipelinesHandler to route by model format" +``` + +--- + +### Task 8: Create GGUF Video Pipeline + +**Context:** This is the pipeline class that loads GGUF quantized LTX models using the `diffusers` GGUF quantizer infrastructure. It implements the `FastVideoPipeline` protocol so it can be swapped in by PipelinesHandler. This is the most complex task — it requires integrating the diffusers GGUF loading with the ltx_pipelines inference code. + +**Note:** This pipeline cannot be fully tested without real GGUF model files and a GPU. The implementation is based on the diffusers GGUF quantizer API. Manual testing required. + +**Files:** +- Create: `backend/services/fast_video_pipeline/gguf_fast_video_pipeline.py` +- Modify: `backend/app_handler.py` — wire the GGUF pipeline class + +- [ ] **Step 1: Create the GGUF pipeline class** + +Create `backend/services/fast_video_pipeline/gguf_fast_video_pipeline.py`: + +```python +"""GGUF quantized LTX video pipeline. + +Loads LTX-Video transformer weights from a GGUF file using the diffusers +GGUF quantizer, while loading VAE, text encoder, and upsampler normally. +""" + +from __future__ import annotations + +import logging +import os +from collections.abc import Iterator +from typing import Any, Final, cast + +import torch + +from api_types import ImageConditioningInput +from services.ltx_pipeline_common import default_tiling_config, encode_video_output, video_chunks_number +from services.services_utils import AudioOrNone, TilingConfigType + +logger = logging.getLogger(__name__) + + +class GGUFFastVideoPipeline: + """FastVideoPipeline implementation for GGUF quantized models. + + Uses diffusers GGUF quantizer to load quantized transformer weights. + Falls back to error with install instructions if gguf package missing. + """ + + pipeline_kind: Final = "fast" + + @staticmethod + def create( + checkpoint_path: str, + gemma_root: str | None, + upsampler_path: str, + device: torch.device, + lora_path: str | None = None, + lora_weight: float = 1.0, + ) -> "GGUFFastVideoPipeline": + return GGUFFastVideoPipeline( + checkpoint_path=checkpoint_path, + gemma_root=gemma_root, + upsampler_path=upsampler_path, + device=device, + lora_path=lora_path, + lora_weight=lora_weight, + ) + + def __init__( + self, + checkpoint_path: str, + gemma_root: str | None, + upsampler_path: str, + device: torch.device, + lora_path: str | None = None, + lora_weight: float = 1.0, + ) -> None: + try: + import gguf # noqa: F401 + except ImportError: + raise RuntimeError( + "GGUF model support requires the 'gguf' package. " + "Install it with: pip install gguf>=0.10.0" + ) from None + + # TODO: Implement GGUF model loading using diffusers GGUFQuantizer. + # This requires understanding the exact diffusers API for loading + # LTX-Video transformer from GGUF format. The implementation will: + # + # 1. Load the GGUF file using diffusers.quantizers.gguf + # 2. Build the LTX transformer model architecture + # 3. Load quantized weights into the transformer + # 4. Load VAE, text encoder, and upsampler normally via ltx_pipelines + # 5. Assemble into a pipeline that matches DistilledPipeline's interface + # + # For now, raise NotImplementedError until we can test with real model files. + raise NotImplementedError( + "GGUF pipeline loading is not yet fully implemented. " + "This requires testing with real GGUF model files to validate the " + "diffusers GGUF quantizer integration with ltx_pipelines." + ) + + def _run_inference( + self, + prompt: str, + seed: int, + height: int, + width: int, + num_frames: int, + frame_rate: float, + images: list[ImageConditioningInput], + tiling_config: TilingConfigType, + ) -> tuple[torch.Tensor | Iterator[torch.Tensor], AudioOrNone]: + raise NotImplementedError + + @torch.inference_mode() + def generate( + self, + prompt: str, + seed: int, + height: int, + width: int, + num_frames: int, + frame_rate: float, + images: list[ImageConditioningInput], + output_path: str, + ) -> None: + tiling_config = default_tiling_config() + video, audio = self._run_inference( + prompt=prompt, seed=seed, height=height, width=width, + num_frames=num_frames, frame_rate=frame_rate, images=images, + tiling_config=tiling_config, + ) + chunks = video_chunks_number(num_frames, tiling_config) + encode_video_output(video=video, audio=audio, fps=int(frame_rate), + output_path=output_path, video_chunks_number_value=chunks) + + @torch.inference_mode() + def warmup(self, output_path: str) -> None: + try: + self.generate( + prompt="test warmup", seed=42, height=256, width=384, + num_frames=9, frame_rate=8, images=[], output_path=output_path, + ) + finally: + if os.path.exists(output_path): + os.unlink(output_path) + + def compile_transformer(self) -> None: + logger.info("Skipping torch.compile for GGUF pipeline — not supported with quantized weights") +``` + +- [ ] **Step 2: Wire in AppHandler** + +In `backend/app_handler.py`, in `build_default_service_bundle()`, add import: +```python + from services.fast_video_pipeline.gguf_fast_video_pipeline import GGUFFastVideoPipeline +``` + +In the `self.pipelines = PipelinesHandler(...)` call in `AppHandler.__init__`, update: +```python + gguf_video_pipeline_class=gguf_video_pipeline_class, +``` + +Add `gguf_video_pipeline_class: type[FastVideoPipeline] | None` to `AppHandler.__init__()` params and `ServiceBundle`. + +In `build_default_service_bundle()`, add to the returned bundle: +```python + gguf_video_pipeline_class=GGUFFastVideoPipeline, +``` + +- [ ] **Step 3: Run pyright** + +Run: `cd backend && uv run pyright` +Expected: `0 errors` + +- [ ] **Step 4: Commit** + +```bash +cd D:/git/directors-desktop +git add backend/services/fast_video_pipeline/gguf_fast_video_pipeline.py backend/app_handler.py +git commit -m "feat(quantized-models): add GGUF video pipeline scaffold (NotImplementedError until tested with real models)" +``` + +--- + +### Task 9: Create NF4 Video Pipeline + +**Context:** NF4 pipeline using BitsAndBytes 4-bit quantization, following the same pattern as the existing FLUX Klein pipeline. Same situation as GGUF — scaffold with NotImplementedError until tested with real models. + +**Files:** +- Create: `backend/services/fast_video_pipeline/nf4_fast_video_pipeline.py` +- Modify: `backend/app_handler.py` — wire the NF4 pipeline class + +- [ ] **Step 1: Create the NF4 pipeline class** + +Create `backend/services/fast_video_pipeline/nf4_fast_video_pipeline.py`: + +```python +"""NF4 (4-bit BitsAndBytes) quantized LTX video pipeline. + +Uses BitsAndBytes NF4 quantization to load the LTX transformer at 4-bit +precision, following the same pattern as FluxKleinImagePipeline. +""" + +from __future__ import annotations + +import logging +import os +from collections.abc import Iterator +from typing import Any, Final, cast + +import torch + +from api_types import ImageConditioningInput +from services.ltx_pipeline_common import default_tiling_config, encode_video_output, video_chunks_number +from services.services_utils import AudioOrNone, TilingConfigType + +logger = logging.getLogger(__name__) + + +class NF4FastVideoPipeline: + """FastVideoPipeline implementation for NF4 quantized models. + + Uses BitsAndBytes 4-bit quantization (same approach as FLUX Klein). + """ + + pipeline_kind: Final = "fast" + + @staticmethod + def create( + checkpoint_path: str, + gemma_root: str | None, + upsampler_path: str, + device: torch.device, + lora_path: str | None = None, + lora_weight: float = 1.0, + ) -> "NF4FastVideoPipeline": + return NF4FastVideoPipeline( + checkpoint_path=checkpoint_path, + gemma_root=gemma_root, + upsampler_path=upsampler_path, + device=device, + lora_path=lora_path, + lora_weight=lora_weight, + ) + + def __init__( + self, + checkpoint_path: str, + gemma_root: str | None, + upsampler_path: str, + device: torch.device, + lora_path: str | None = None, + lora_weight: float = 1.0, + ) -> None: + # TODO: Implement NF4 model loading using BitsAndBytes. + # Pattern from FLUX Klein pipeline (flux_klein_pipeline.py): + # + # 1. Create BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4") + # 2. Load LTX transformer with quantization config + # 3. Load VAE and text encoder normally + # 4. Enable model CPU offload for text encoder + # 5. Assemble pipeline matching DistilledPipeline interface + # + # Requires testing with real NF4 quantized LTX model files. + raise NotImplementedError( + "NF4 pipeline loading is not yet fully implemented. " + "This requires testing with real NF4 quantized model files to validate " + "the BitsAndBytes integration with ltx_pipelines." + ) + + def _run_inference( + self, + prompt: str, + seed: int, + height: int, + width: int, + num_frames: int, + frame_rate: float, + images: list[ImageConditioningInput], + tiling_config: TilingConfigType, + ) -> tuple[torch.Tensor | Iterator[torch.Tensor], AudioOrNone]: + raise NotImplementedError + + @torch.inference_mode() + def generate( + self, + prompt: str, + seed: int, + height: int, + width: int, + num_frames: int, + frame_rate: float, + images: list[ImageConditioningInput], + output_path: str, + ) -> None: + tiling_config = default_tiling_config() + video, audio = self._run_inference( + prompt=prompt, seed=seed, height=height, width=width, + num_frames=num_frames, frame_rate=frame_rate, images=images, + tiling_config=tiling_config, + ) + chunks = video_chunks_number(num_frames, tiling_config) + encode_video_output(video=video, audio=audio, fps=int(frame_rate), + output_path=output_path, video_chunks_number_value=chunks) + + @torch.inference_mode() + def warmup(self, output_path: str) -> None: + try: + self.generate( + prompt="test warmup", seed=42, height=256, width=384, + num_frames=9, frame_rate=8, images=[], output_path=output_path, + ) + finally: + if os.path.exists(output_path): + os.unlink(output_path) + + def compile_transformer(self) -> None: + logger.info("Skipping torch.compile for NF4 pipeline — not supported with quantized weights") +``` + +- [ ] **Step 2: Wire in AppHandler** + +In `backend/app_handler.py`, in `build_default_service_bundle()`, add import: +```python + from services.fast_video_pipeline.nf4_fast_video_pipeline import NF4FastVideoPipeline +``` + +Add `nf4_video_pipeline_class` to `AppHandler.__init__()` params and `ServiceBundle`. + +Update `self.pipelines = PipelinesHandler(...)` call: +```python + nf4_video_pipeline_class=nf4_video_pipeline_class, +``` + +In `build_default_service_bundle()`, add to returned bundle: +```python + nf4_video_pipeline_class=NF4FastVideoPipeline, +``` + +- [ ] **Step 3: Run pyright + full tests** + +Run: `cd backend && uv run pyright && uv run pytest -v --tb=short` +Expected: pyright clean, all tests pass + +- [ ] **Step 4: Commit** + +```bash +cd D:/git/directors-desktop +git add backend/services/fast_video_pipeline/nf4_fast_video_pipeline.py backend/app_handler.py +git commit -m "feat(quantized-models): add NF4 video pipeline scaffold" +``` + +--- + +### Task 10: Add README section for custom video models + +**Context:** Add a "Custom Video Models" section to the README explaining what formats are supported, how to download, and how to set up. Keep it brief and user-friendly. + +**Files:** +- Modify: `README.md` + +- [ ] **Step 1: Read the current README** + +Read `README.md` to find the right place to add the section. + +- [ ] **Step 2: Add Custom Video Models section** + +Add this section to the README (after the main features/usage section, before any contribution or license section): + +```markdown +## Custom Video Models + +Directors Desktop supports multiple LTX 2.3 model formats, so you can run on GPUs with less VRAM. + +| Your GPU VRAM | Recommended Format | File Size | +|---------------|-------------------|-----------| +| 32 GB+ | BF16 (auto-downloaded) | ~43 GB | +| 20-31 GB | [FP8 Checkpoint](https://huggingface.co/Lightricks/LTX-Video-2.3-22b-distilled) | ~22 GB | +| 16-19 GB | [GGUF Q5_K](https://huggingface.co/city96/LTX-Video-2.3-22b-0.9.7-dev-gguf) | ~15 GB | +| 10-15 GB | [GGUF Q4_K](https://huggingface.co/city96/LTX-Video-2.3-22b-0.9.7-dev-gguf) | ~12 GB | + +### Setup + +1. Download the model file for your GPU from the links above +2. Open **Settings → Models** and set your model folder +3. If using GGUF or NF4, also download the [distilled LoRA](https://huggingface.co/Lightricks/LTX-Video-2.3-22b-distilled) +4. Select your model from the dropdown +5. Generate! + +The app also has a built-in **Model Guide** (Settings → Models → Open Model Guide) that detects your GPU and recommends the best format automatically. +``` + +- [ ] **Step 3: Commit** + +```bash +cd D:/git/directors-desktop +git add README.md +git commit -m "docs(quantized-models): add Custom Video Models section to README" +``` + +--- + +## Summary of Tasks + +| Task | Description | Dependencies | +|------|-------------|-------------| +| 1 | Add gguf dependency, settings fields, API types | None | +| 2 | Create ModelScanner service (Protocol + Impl + Fake + Tests) | Task 1 | +| 3 | Add model guide recommendation logic and tests | Task 2 | +| 4 | Wire scanner to handler + routes + integration tests | Tasks 2, 3 | +| 5 | Add Models tab to SettingsModal | Task 4 | +| 6 | Create ModelGuideDialog component | Task 5 | +| 7 | Update PipelinesHandler for format routing | Task 4 | +| 8 | Create GGUF Video Pipeline scaffold | Task 7 | +| 9 | Create NF4 Video Pipeline scaffold | Task 7 | +| 10 | Add README section | None (can run in parallel) | + +**Note on Tasks 8 & 9:** The GGUF and NF4 pipeline classes are scaffolded with `NotImplementedError`. Completing the actual model loading requires downloading real quantized model files and testing on GPU hardware. The scaffold ensures the routing, settings, and UI all work end-to-end — when the pipeline `__init__` is implemented, everything else is already wired up. diff --git a/docs/superpowers/specs/2026-03-22-quantized-models-design.md b/docs/superpowers/specs/2026-03-22-quantized-models-design.md new file mode 100644 index 00000000..9def0051 --- /dev/null +++ b/docs/superpowers/specs/2026-03-22-quantized-models-design.md @@ -0,0 +1,362 @@ +# Quantized Video Model Support — Design Spec + +## Goal + +Let Directors Desktop run on 24GB GPUs (RTX 3090, 4070 Ti Super, etc.) by supporting quantized LTX 2.3 model formats (GGUF, NF4, FP8 checkpoints) alongside the existing BF16. Make it dead simple for users to understand what they need, where to get it, and how to set it up. + +## Architecture + +Three new components: + +1. **Model Scanner Service** (backend) — New service with Protocol + Fake that scans a folder for model files, detects format/quant type from file metadata (not file size heuristics), and returns structured results +2. **Extended Pipeline System** (backend) — New `FastVideoPipeline` implementations for GGUF and NF4 formats. The existing `PipelinesHandler` already accepts `type[FastVideoPipeline]` — we extend `_create_video_pipeline()` to pick the right class based on the selected model's format. The `FastVideoPipeline.create()` signature stays unchanged; format-specific config (quant type, BnB config) is passed via the checkpoint path pointing to the right file. +3. **Model Setup UI** (frontend) — New "Models" tab in SettingsModal + a "Model Guide" popup that recommends what to download based on GPU, with instructions in-app, in a popup, and in the GitHub README + +## Supported Formats + +| Format | Extension | Typical Size | Min VRAM | Quality | Speed | Needs Distilled LoRA? | +|--------|-----------|-------------|----------|---------|-------|-----------------------| +| BF16 (current default) | `.safetensors` | ~43 GB | 32 GB | Best | Baseline | No (already distilled) | +| FP8 Checkpoint | `.safetensors` | ~22 GB | ~20 GB | Excellent | ~Same | No (distilled version available) | +| GGUF Q8 | `.gguf` | ~22 GB | ~18 GB | Excellent | Slightly slower | Yes | +| GGUF Q5_K | `.gguf` | ~15 GB | ~13 GB | Very Good | Slightly slower | Yes | +| GGUF Q4_K | `.gguf` | ~12 GB | ~10 GB | Good | Slightly slower | Yes | +| NF4 (4-bit) | `.safetensors` + bnb config | ~12 GB | ~10 GB | Good | Slightly slower | Yes | + +## User Experience + +### Model Guide Popup + +Triggered on: +- First app launch (after initial model download completes) +- User clicks "Model Guide" button in Models tab +- No usable video model detected in configured path + +The popup is a friendly, visual guide — NOT a wall of text. It: + +1. **Detects GPU** — Shows "You have: NVIDIA RTX 3090 (24 GB VRAM)" at the top +2. **Recommends a format** — Based on VRAM: + - 32+ GB → "You can run the full BF16 model. You're all set!" + - 20-31 GB → "We recommend the FP8 checkpoint for your GPU" + - 16-19 GB → "We recommend GGUF Q8 or Q5_K for your GPU" + - 10-15 GB → "We recommend GGUF Q4_K or NF4 for your GPU" + - <10 GB → "Your GPU doesn't have enough VRAM for local generation. Use API mode instead." +3. **Shows download cards** — Each format as a card with: + - Format name + file size + - Quality rating (stars or bar) + - "Download from HuggingFace" button (opens browser to exact file URL) + - Recommended badge on the best fit for their GPU +4. **Shows where to put it** — Current model path displayed, with "Change Folder" button +5. **Distilled LoRA notice** — If they pick GGUF/NF4: "This model also needs the distilled LoRA file. Download it here: [link]" with its own download button + +### Models Tab in Settings + +New tab in SettingsModal (between "Inference" and "About"): + +``` +[General] [API Keys] [Inference] [Models] [About] +``` + +Contents: +- **Video Model** section: + - Dropdown: lists all detected model files in the configured path with format/size info + - Current selection highlighted, e.g. "LTX-2.3-Q4_K.gguf (12 GB) — GGUF Q4" + - Model folder path with "Change" and "Open Folder" buttons + - "Scan for Models" button (re-scans the folder) + - "Model Guide" button (opens the popup) +- **Distilled LoRA** section (only shown when active model needs it): + - Status: "Found" (green) or "Not found — required for this model" (red with download link) + - Path display +- **GPU Info** section: + - GPU name, VRAM total, VRAM currently used + - Estimated VRAM usage for selected model + +### README Section + +New section in GitHub README: + +```markdown +## Custom Video Models + +Directors Desktop supports multiple LTX 2.3 model formats. Pick the one +that fits your GPU: + +| Your GPU VRAM | Recommended Format | Download | +|---------------|-------------------|----------| +| 32 GB+ | BF16 (default, auto-downloaded) | Included | +| 20-31 GB | FP8 Checkpoint | [HuggingFace link] | +| 16-19 GB | GGUF Q8 or Q5_K | [HuggingFace link] | +| 10-15 GB | GGUF Q4_K | [HuggingFace link] | + +### Setup +1. Download the model file for your GPU +2. Open Settings → Models → Change model folder (or drop it in the default folder) +3. If using GGUF or NF4, also download the distilled LoRA: [link] +4. Select your model from the dropdown +5. Generate! +``` + +## Backend Design + +### New Types (`backend/api_types.py`) + +```python +class DetectedModel(TypedDict): + filename: str # e.g. "LTX-2.3-Q4_K.gguf" + path: str # Absolute path + format: str # "bf16" | "fp8" | "gguf" | "nf4" + quant_type: str | None # e.g. "Q4_K", "Q8_0", None for bf16/fp8 + size_bytes: int + size_gb: float + is_distilled: bool + display_name: str # Human-friendly, e.g. "LTX 2.3 — GGUF Q4_K (12 GB)" + +class ModelFormatInfo(TypedDict): + id: str + name: str + size_gb: float + min_vram_gb: int + quality_tier: str # "Best" | "Excellent" | "Very Good" | "Good" + needs_distilled_lora: bool + download_url: str + description: str +``` + +### Model Scanner Service (`backend/services/model_scanner/`) + +New service with Protocol + real + fake implementations (follows codebase pattern): + +**Protocol** (`model_scanner.py`): +```python +class ModelScanner(Protocol): + def scan_video_models(self, folder: Path) -> list[DetectedModel]: ... +``` + +**Implementation** (`model_scanner_impl.py`): +```python +class ModelScannerImpl: + def scan_video_models(self, folder: Path) -> list[DetectedModel]: + """Scan folder for supported video model files.""" +``` + +**Fake** (`tests/fakes/services.py`): +```python +class FakeModelScanner: + def __init__(self) -> None: + self._models: list[DetectedModel] = [] + def set_models(self, models: list[DetectedModel]) -> None: + self._models = models + def scan_video_models(self, folder: Path) -> list[DetectedModel]: + return self._models +``` + +**Detection logic** (NO file size heuristics — uses file metadata): +- `.gguf` extension → read GGUF file header for quant type metadata (using `gguf` Python reader or raw struct parsing) +- `.safetensors` with companion `config.json` or `model_index.json` containing dtype info → BF16 or FP8 +- `.safetensors` without companion config → read safetensors header metadata for dtype field +- Folder with `quantize_config.json` containing `"quant_type": "nf4"` → NF4 +- **Corrupt/unreadable files**: skip with warning log, don't crash the scan + +**Distilled LoRA detection**: +- Scans the same folder (and a `loras/` subfolder) for files matching known distilled LoRA filenames +- Returns `is_distilled: True` on the model if it's a known distilled checkpoint (the default BF16 and official FP8 distilled are pre-distilled) +- For GGUF/NF4 models: `is_distilled` is always `False` (they need the LoRA) + +### Pipeline Changes + +**No changes to `FastVideoPipeline` Protocol.** The `create()` signature stays the same: +```python +create(checkpoint_path, gemma_root, upsampler_path, device, lora_path, lora_weight) +``` + +Each new pipeline class handles its format-specific loading internally. The `checkpoint_path` points to the actual file (`.gguf` or `.safetensors`), and the pipeline detects the format from the file extension. + +**GGUF Pipeline** (`backend/services/fast_video_pipeline/gguf_fast_video_pipeline.py`): +- Implements `FastVideoPipeline` protocol +- Uses `diffusers` GGUF quantizer infrastructure to load transformer weights +- Loads VAE and text encoder normally (small enough for any GPU) +- Handles distilled LoRA injection when `lora_path` is provided +- Implements `generate()`, `warmup()`, `compile_transformer()` matching existing interface + +**NF4 Pipeline** (`backend/services/fast_video_pipeline/nf4_fast_video_pipeline.py`): +- Implements `FastVideoPipeline` protocol +- Uses BitsAndBytes `BnbQuantizationConfig` with `load_in_4bit=True, bnb_4bit_quant_type="nf4"` +- Same pattern as existing FLUX Klein pipeline (`flux_klein_pipeline.py`) +- Handles distilled LoRA injection + +### PipelinesHandler Changes + +`_create_video_pipeline()` currently hardcodes `self._fast_video_pipeline_class`. Updated to: + +1. Read `selected_video_model` from `AppSettings` +2. If set, determine format from the selected model's file extension +3. Pick the right pipeline class: `.gguf` → `GGUFFastVideoPipeline`, NF4 folder → `NF4FastVideoPipeline`, else → existing `LTXFastVideoPipeline` +4. Call `.create()` with the selected model's path as `checkpoint_path` + +The pipeline classes are injected via `ServiceBundle` (like the existing pattern), so tests can provide fakes. + +### Settings Changes (`backend/state/app_settings.py`) + +New fields: +```python +custom_video_model_path: str # User-chosen folder for custom models, empty = use default models dir +selected_video_model: str # Filename of selected model, empty = use default BF16 checkpoint +``` + +Also update `SettingsResponse` and `to_settings_response()` to include these new fields. + +### Interaction with existing `video_model` setting + +`video_model: str = "ltx-fast"` remains as-is — it selects the model *type* (fast vs pro mode). +`selected_video_model` selects which *checkpoint file* to use for that model type. +They are independent: `video_model` picks the mode, `selected_video_model` picks the weights. + +### New API Endpoints + +``` +GET /api/models/video/scan → { models: DetectedModel[], distilled_lora_found: bool } +POST /api/models/video/select → { model: str } (filename — validated against scan results) +GET /api/models/video/guide → { gpu_name, vram_gb, recommended_format, formats: ModelFormatInfo[], distilled_lora: DistilledLoraInfo } +``` + +**Validation**: `POST /select` checks that the filename exists in the configured folder. Returns 400 if not found. + +**Generation guard**: If a generation is running, `POST /select` returns 409 Conflict (uses existing `_ensure_no_running_generation()` pattern). + +### Model Guide Data (`backend/services/model_scanner/model_guide_data.py`) + +Static config file (not hardcoded in handler logic) containing format metadata: +```python +MODEL_FORMATS: list[ModelFormatInfo] = [ + { + "id": "bf16", + "name": "BF16 (Full Precision)", + "size_gb": 43, + "min_vram_gb": 32, + "quality_tier": "Best", + "needs_distilled_lora": False, + "download_url": "https://huggingface.co/Lightricks/LTX-Video-2.3-22b-distilled/...", + "description": "Best quality. Requires 32GB+ VRAM." + }, + ... +] + +DISTILLED_LORA_INFO = { + "name": "LTX 2.3 Distilled LoRA", + "size_gb": 0.5, + "download_url": "https://huggingface.co/...", + "description": "Required for GGUF and NF4 models to generate quickly." +} +``` + +This is a data file, easily updated when new models release — not buried in handler code. + +### Routes (`backend/_routes/models.py`) + +Thin routes delegating to handler (follows existing pattern): +```python +@router.get("/video/scan") +def scan_video_models(handler = Depends(get_state_service)): + return handler.models.scan_video_models() + +@router.post("/video/select") +def select_video_model(body: SelectModelRequest, handler = Depends(get_state_service)): + return handler.models.select_video_model(body.model) + +@router.get("/video/guide") +def video_model_guide(handler = Depends(get_state_service)): + return handler.models.video_model_guide() +``` +``` + +## Frontend Design + +### Model Guide Popup Component + +`frontend/components/ModelGuideDialog.tsx` + +Visual design: +- Modal dialog (same style as SettingsModal) +- Top banner: GPU detection result with icon +- Grid of format cards (2 columns), each card shows: + - Format name (bold) + - File size badge + - Quality bar (5 segments, filled proportionally) + - Speed bar (5 segments) + - "Recommended" pill badge (green, on best match) + - "Download" button (opens HuggingFace URL in browser) +- Bottom section: model folder path + change button +- Distilled LoRA callout box (yellow/amber) when applicable + +### Models Tab Component + +Added to `SettingsModal.tsx` as new tab: + +- Model dropdown (select from scanned models) +- Folder path display + Change/Open/Scan buttons +- GPU info summary +- "Open Model Guide" button +- Distilled LoRA status indicator + +## Error Handling + +- **Corrupt model files**: `ModelScannerImpl` catches read errors, logs warning, skips the file +- **Missing distilled LoRA**: UI shows amber warning with download link; generation still attempted (may be slow but won't crash) +- **Model selected but deleted from disk**: `PipelinesHandler` catches FileNotFoundError at load time, returns clear error to frontend; frontend shows "Model file not found — please re-scan" message +- **Selecting model during generation**: Returns 409 Conflict, frontend disables the dropdown while generating +- **GGUF package not installed**: `GGUFFastVideoPipeline.create()` catches ImportError, raises RuntimeError with "Install gguf package: pip install gguf>=0.10.0" + +## Testing Strategy + +### Backend Tests + +- `test_model_scanner.py` — scan folder with mixed model files (real temp files with correct headers), verify format detection; test corrupt file handling; test empty folder; test distilled LoRA detection +- `test_model_guide.py` — verify GPU VRAM → recommended format mapping logic +- `test_model_selection.py` — integration test via TestClient: scan → select → verify settings updated; select nonexistent model → 400; select during generation → 409 + +Pipeline integration tests for GGUF/NF4 are deferred — they require actual model files and GPU. Manual testing against real quantized models will validate these. + +### Frontend + +No frontend tests currently exist in the project. The Model Guide and Models tab should be manually tested. + +## File Map + +### New Files +- `backend/services/model_scanner/model_scanner.py` — ModelScanner Protocol +- `backend/services/model_scanner/model_scanner_impl.py` — Real implementation +- `backend/services/model_scanner/__init__.py` +- `backend/services/model_scanner/model_guide_data.py` — Static format metadata config +- `backend/services/fast_video_pipeline/gguf_fast_video_pipeline.py` — GGUF pipeline +- `backend/services/fast_video_pipeline/nf4_fast_video_pipeline.py` — NF4 pipeline +- `backend/tests/test_model_scanner.py` — Model scanning tests +- `backend/tests/test_model_guide.py` — Guide recommendation tests +- `backend/tests/test_model_selection.py` — Selection integration tests +- `frontend/components/ModelGuideDialog.tsx` — Model Guide popup + +### Modified Files +- `backend/api_types.py` — Add `DetectedModel`, `ModelFormatInfo`, `SelectModelRequest` types +- `backend/state/app_settings.py` — Add `custom_video_model_path`, `selected_video_model`; update `SettingsResponse` and `to_settings_response()` +- `backend/handlers/models_handler.py` — Add `scan_video_models()`, `select_video_model()`, `video_model_guide()` methods +- `backend/_routes/models.py` — Add 3 new routes under `/video/` +- `backend/app_handler.py` — Wire `ModelScannerImpl` into `ServiceBundle`, pass to `ModelsHandler` +- `backend/handlers/pipelines_handler.py` — Update `_create_video_pipeline()` to check `selected_video_model` setting and pick pipeline class by format +- `backend/services/interfaces.py` — Re-export `ModelScanner` +- `backend/tests/fakes/services.py` — Add `FakeModelScanner` +- `backend/tests/conftest.py` — Wire `FakeModelScanner` into test `ServiceBundle` +- `frontend/components/SettingsModal.tsx` — Add Models tab, Model Guide button +- `README.md` — Add Custom Video Models section + +## Dependencies + +- `gguf>=0.10.0` — needed to read GGUF file metadata and load quantized weights. Must be added to `backend/pyproject.toml`. +- `diffusers` GGUF quantizer — already installed, provides `GGUFQuantizer`, `GGUFLinear`, dequantization functions +- `bitsandbytes` — already installed, provides NF4 quantization (used by FLUX Klein) + +## Out of Scope + +- Auto-downloading quantized models (user downloads manually) +- Converting between formats +- Quantizing models locally +- Supporting non-LTX video models +- Image model quantization changes (FLUX Klein NF4 already works) diff --git a/electron-builder.yml b/electron-builder.yml index 94753cc6..3d22141a 100644 --- a/electron-builder.yml +++ b/electron-builder.yml @@ -1,5 +1,5 @@ appId: com.lightricks.ltx-desktop -productName: LTX Desktop +productName: Director's Desktop copyright: Copyright © 2026 Lightricks directories: @@ -55,7 +55,7 @@ nsis: installerHeaderIcon: resources/icon.ico createDesktopShortcut: true createStartMenuShortcut: true - shortcutName: LTX Desktop + shortcutName: Director's Desktop mac: hardenedRuntime: true diff --git a/electron/csp.ts b/electron/csp.ts index 9d78dd9f..95701186 100644 --- a/electron/csp.ts +++ b/electron/csp.ts @@ -11,8 +11,8 @@ export function setupCSP(): void { "style-src 'self' 'unsafe-inline' https://fonts.googleapis.com", "font-src 'self' https://fonts.gstatic.com", "connect-src 'self' http://localhost:* http://127.0.0.1:* ws://localhost:* ws://127.0.0.1:*", - "img-src 'self' data: blob: file:", - "media-src 'self' blob: file:", + "img-src 'self' data: blob: file: http://localhost:* http://127.0.0.1:*", + "media-src 'self' blob: file: http://localhost:* http://127.0.0.1:*", "object-src 'none'", "base-uri 'self'", "form-action 'self'", @@ -24,8 +24,8 @@ export function setupCSP(): void { "style-src 'self' https://fonts.googleapis.com", "font-src 'self' https://fonts.gstatic.com", "connect-src 'self' http://localhost:* http://127.0.0.1:* ws://localhost:* ws://127.0.0.1:*", - "img-src 'self' data: blob: file:", - "media-src 'self' blob: file:", + "img-src 'self' data: blob: file: http://localhost:* http://127.0.0.1:*", + "media-src 'self' blob: file: http://localhost:* http://127.0.0.1:*", "object-src 'none'", "base-uri 'self'", "form-action 'self'", diff --git a/electron/ipc/file-handlers.ts b/electron/ipc/file-handlers.ts index db099739..90cb035b 100644 --- a/electron/ipc/file-handlers.ts +++ b/electron/ipc/file-handlers.ts @@ -1,4 +1,4 @@ -import { ipcMain, dialog } from 'electron' +import { ipcMain, dialog, BrowserWindow, session } from 'electron' import path from 'path' import fs from 'fs' import { getAllowedRoots } from '../config' @@ -70,9 +70,190 @@ export function registerFileHandlers(): void { return true }) - ipcMain.handle('open-fal-api-key-page', async () => { + ipcMain.handle('open-replicate-api-key-page', async () => { const { shell } = await import('electron') - await shell.openExternal('https://fal.ai/dashboard/keys') + await shell.openExternal('https://replicate.com/account/api-tokens') + return true + }) + + ipcMain.handle('open-palette-login-page', async () => { + const PALETTE_URL = 'https://directorspal.com' + const mainWindow = getMainWindow() + + // Create a dedicated login session so we don't pollute the main session + const loginSession = session.fromPartition('palette-login') + + const loginWindow = new BrowserWindow({ + width: 500, + height: 700, + parent: mainWindow ?? undefined, + modal: true, + show: false, + webPreferences: { + session: loginSession, + nodeIntegration: false, + contextIsolation: true, + }, + backgroundColor: '#1a1a1a', + title: "Sign In to Director's Palette", + }) + + loginWindow.setMenuBarVisibility(false) + loginWindow.once('ready-to-show', () => loginWindow.show()) + + // Poll for the Supabase session from cookies or localStorage. + // @supabase/ssr stores auth tokens as cookies on the Palette domain, + // either as a single cookie or chunked (sb--auth-token.0, .1, etc.) + const checkForToken = async (): Promise => { + try { + const allCookies = await loginSession.cookies.get({}) + + logger.info(`[palette-login] Found ${allCookies.length} cookies`) + + // Look for Supabase SSR auth cookies (sb--auth-token or chunked .0, .1, ...) + const baseCookies = allCookies.filter(c => + c.name.startsWith('sb-') && c.name.includes('-auth-token') + ) + + if (baseCookies.length > 0) { + // Check for a single (non-chunked) cookie first + const single = baseCookies.find(c => + c.name.match(/^sb-[^.]+$/) && c.name.endsWith('-auth-token') + ) + let cookieValue: string | null = null + + if (single?.value) { + cookieValue = single.value + } else { + // Reassemble chunked cookies: sb--auth-token.0, .1, .2, ... + const chunks = baseCookies + .filter(c => /\.\d+$/.test(c.name)) + .sort((a, b) => { + const aIdx = parseInt(a.name.split('.').pop() || '0', 10) + const bIdx = parseInt(b.name.split('.').pop() || '0', 10) + return aIdx - bIdx + }) + + if (chunks.length > 0) { + cookieValue = chunks.map(c => c.value).join('') + } + } + + if (cookieValue) { + try { + const decoded = decodeURIComponent(cookieValue) + const parsed = JSON.parse(decoded) + if (parsed.access_token) { + logger.info('[palette-login] Token found in Supabase SSR cookie') + return parsed.access_token as string + } + } catch { + // Try base64 decode + try { + const decoded = Buffer.from(cookieValue, 'base64').toString('utf-8') + const parsed = JSON.parse(decoded) + if (parsed.access_token) { + logger.info('[palette-login] Token found in base64 cookie') + return parsed.access_token as string + } + } catch { /* not parseable */ } + } + } + } + + // Fallback: check localStorage (older Supabase clients store tokens there) + const token = await loginWindow.webContents.executeJavaScript(` + (function() { + try { + for (let i = 0; i < localStorage.length; i++) { + const key = localStorage.key(i); + if (key && key.startsWith('sb-') && key.endsWith('-auth-token')) { + const raw = localStorage.getItem(key); + if (raw) { + try { + const parsed = JSON.parse(raw); + if (parsed.access_token) return parsed.access_token; + } catch(e) {} + } + } + } + } catch(e) {} + return null; + })() + `).catch(() => null) + + return token as string | null + } catch { + return null + } + } + + let pollTimer: ReturnType | null = null + let resolved = false + + const onTokenFound = (token: string) => { + if (resolved) return + resolved = true + if (pollTimer) clearInterval(pollTimer) + logger.info('[palette-login] Token captured successfully') + if (mainWindow) { + mainWindow.webContents.send('palette-auth-callback', { token }) + } + loginWindow.close() + } + + // Start polling after each navigation completes + loginWindow.webContents.on('did-finish-load', () => { + const url = loginWindow.webContents.getURL() + logger.info(`[palette-login] Navigated to: ${url}`) + + // Start polling aggressively on any Palette domain page + if (url.startsWith(PALETTE_URL)) { + if (pollTimer) clearInterval(pollTimer) + pollTimer = setInterval(async () => { + const token = await checkForToken() + if (token) onTokenFound(token) + }, 500) + // Also check immediately + void checkForToken().then(t => { if (t) onTokenFound(t) }) + } + }) + + // Also check after any redirect (with delay for cookie to be set) + loginWindow.webContents.on('did-navigate', async (_event, url) => { + logger.info(`[palette-login] did-navigate: ${url}`) + // Check immediately + let token = await checkForToken() + if (token) { onTokenFound(token); return } + // Cookies may not be set yet — retry after a short delay + await new Promise(r => setTimeout(r, 1000)) + token = await checkForToken() + if (token) onTokenFound(token) + }) + + // Check on in-page navigation too (SPA redirects) + loginWindow.webContents.on('did-navigate-in-page', async (_event, url) => { + logger.info(`[palette-login] did-navigate-in-page: ${url}`) + let token = await checkForToken() + if (token) { onTokenFound(token); return } + await new Promise(r => setTimeout(r, 1000)) + token = await checkForToken() + if (token) onTokenFound(token) + }) + + loginWindow.on('closed', () => { + if (pollTimer) clearInterval(pollTimer) + // Clear the login session cookies + loginSession.cookies.flushStore().catch(() => {}) + }) + + await loginWindow.loadURL(`${PALETTE_URL}/auth/signin`) + return true + }) + + ipcMain.handle('open-palette-api-key-page', async () => { + const { shell } = await import('electron') + await shell.openExternal('https://directorspal.com/settings/api-keys') return true }) diff --git a/electron/main.ts b/electron/main.ts index c871f7fa..fcf77516 100644 --- a/electron/main.ts +++ b/electron/main.ts @@ -1,5 +1,5 @@ import './app-paths' -import { app } from 'electron' +import { app, protocol, net } from 'electron' import { setupCSP } from './csp' import { registerExportHandlers } from './export/export-handler' import { stopExportProcess } from './export/ffmpeg-utils' @@ -12,6 +12,44 @@ import { stopPythonBackend } from './python-backend' import { initAutoUpdater } from './updater' import { createWindow, getMainWindow } from './window' import { sendAnalyticsEvent } from './analytics' +import { logger } from './logger' + +// Register directorsdesktop:// protocol for auth callbacks +if (process.defaultApp) { + // Dev mode: pass the app path so Electron can find us + if (process.argv.length >= 2) { + app.setAsDefaultProtocolClient('directorsdesktop', process.execPath, [ + path.resolve(process.argv[1]), + ]) + } +} else { + app.setAsDefaultProtocolClient('directorsdesktop') +} + +import path from 'path' + +/** Extract auth token from a directorsdesktop:// deep link URL. */ +function handleDeepLink(url: string): void { + // Redact query params (may contain auth tokens) + const safeUrl = url.split('?')[0] + (url.includes('?') ? '?' : '') + logger.info(`[deep-link] Received: ${safeUrl}`) + try { + const parsed = new URL(url) + // Expected: directorsdesktop://auth/callback?token=XXX + if (parsed.hostname === 'auth' && parsed.pathname === '/callback') { + const token = parsed.searchParams.get('token') + if (token) { + const mainWindow = getMainWindow() + if (mainWindow) { + mainWindow.webContents.send('palette-auth-callback', { token }) + logger.info('[deep-link] Auth token forwarded to renderer') + } + } + } + } catch (err) { + logger.error(`[deep-link] Failed to parse URL: ${err}`) + } +} const gotLock = app.requestSingleInstanceLock() @@ -26,7 +64,7 @@ if (!gotLock) { registerExportHandlers() registerVideoProcessingHandlers() - app.on('second-instance', () => { + app.on('second-instance', (_event, commandLine) => { const mainWindow = getMainWindow() if (mainWindow) { if (mainWindow.isMinimized()) { @@ -36,15 +74,29 @@ if (!gotLock) { mainWindow.show() } mainWindow.focus() - return - } - if (app.isReady()) { + } else if (app.isReady()) { createWindow() } + + // On Windows, the deep link URL is in the command line args + const deepLinkUrl = commandLine.find((arg) => arg.startsWith('directorsdesktop://')) + if (deepLinkUrl) { + handleDeepLink(deepLinkUrl) + } + }) + + // macOS: handle protocol URL via open-url event + app.on('open-url', (_event, url) => { + handleDeepLink(url) }) app.whenReady().then(async () => { setupCSP() + + // Allow file:// URLs to load when the page is served from http://localhost (dev mode). + // Without this, Chromium blocks file:// resources on http:// origins. + protocol.handle('file', (request) => net.fetch(request)) + createWindow() initAutoUpdater() // Python setup + backend start are now driven by the renderer via IPC diff --git a/electron/preload.ts b/electron/preload.ts index dcb3d5f5..3bdbb1e9 100644 --- a/electron/preload.ts +++ b/electron/preload.ts @@ -30,7 +30,9 @@ contextBridge.exposeInMainWorld('electronAPI', { // Open specific app pages / folders openLtxApiKeyPage: (): Promise => ipcRenderer.invoke('open-ltx-api-key-page'), - openFalApiKeyPage: (): Promise => ipcRenderer.invoke('open-fal-api-key-page'), + openReplicateApiKeyPage: (): Promise => ipcRenderer.invoke('open-replicate-api-key-page'), + openPaletteLoginPage: (): Promise => ipcRenderer.invoke('open-palette-login-page'), + openPaletteApiKeyPage: (): Promise => ipcRenderer.invoke('open-palette-api-key-page'), openParentFolderOfFile: (filePath: string): Promise => ipcRenderer.invoke('open-parent-folder-of-file', filePath), // Reveal a specific file in the OS file manager (Explorer/Finder) @@ -117,6 +119,15 @@ contextBridge.exposeInMainWorld('electronAPI', { sendAnalyticsEvent: (eventName: string, extraDetails?: Record | null): Promise => ipcRenderer.invoke('send-analytics-event', eventName, extraDetails), + // Deep link auth callback (from directorsdesktop:// protocol) + onPaletteAuthCallback: (cb: (data: { token: string }) => void) => { + const listener = (_: unknown, data: { token: string }) => cb(data) + ipcRenderer.on('palette-auth-callback', listener) + return () => { + ipcRenderer.removeListener('palette-auth-callback', listener) + } + }, + // Platform info platform: process.platform, }) @@ -182,6 +193,7 @@ declare global { getAnalyticsState: () => Promise<{ analyticsEnabled: boolean; installationId: string }> setAnalyticsEnabled: (enabled: boolean) => Promise sendAnalyticsEvent: (eventName: string, extraDetails?: Record | null) => Promise + onPaletteAuthCallback: (cb: (data: { token: string }) => void) => (() => void) platform: string } } diff --git a/electron/python-backend.ts b/electron/python-backend.ts index e5409f9c..c98015e5 100644 --- a/electron/python-backend.ts +++ b/electron/python-backend.ts @@ -222,6 +222,7 @@ export async function startPythonBackend(): Promise { env: { ...process.env, PYTHONUNBUFFERED: '1', + PYTHONNOUSERSITE: '1', LTX_PORT: String(PYTHON_PORT), LTX_LOG_FILE: getCurrentLogFilename(), LTX_APP_DATA_DIR: getAppDataDir(), diff --git a/electron/updater.ts b/electron/updater.ts index b3f33aaf..510b1087 100644 --- a/electron/updater.ts +++ b/electron/updater.ts @@ -13,40 +13,26 @@ export function initAutoUpdater( autoUpdater.allowPrerelease = true } - // On Windows, don't auto-install — we need to pre-download python-embed first. - // On macOS, python is bundled in the DMG so auto-install is fine. - if (process.platform === 'win32') { - autoUpdater.autoInstallOnAppQuit = false - } - + // Windows/Linux: best-effort pre-download of python-embed so the new version + // doesn't have to download it on first launch. This doesn't block the update — + // autoInstallOnAppQuit stays true (the default) so the update installs whenever + // the user naturally quits, whether or not the pre-download has finished. autoUpdater.on('update-downloaded', async (info: UpdateDownloadedEvent) => { - if (process.platform !== 'win32') { - // macOS: python is bundled, just install normally - autoUpdater.quitAndInstall(false, true) - return - } + if (process.platform === 'darwin') return - // Windows: pre-download python-embed if deps changed before restarting const newVersion = info.version - logger.info( `[updater] Update downloaded: v${newVersion}, checking python deps...`) + logger.info( `[updater] Update downloaded: v${newVersion}, pre-downloading python deps...`) try { const didDownload = await preDownloadPythonForUpdate(newVersion, (progress) => { - // Forward progress to renderer so it can show a "Preparing update..." UI getMainWindow()?.webContents.send('python-update-progress', progress) }) - - if (didDownload) { - logger.info( '[updater] Python pre-download complete, installing update...') - } else { - logger.info( '[updater] No python changes needed, installing update...') - } + logger.info( didDownload + ? '[updater] Python pre-download complete' + : '[updater] No python changes needed') } catch (err) { - // Pre-download failed — install anyway; the app will download at next launch - logger.error( `[updater] Python pre-download failed, proceeding with update: ${err}`) + logger.error( `[updater] Python pre-download failed: ${err}`) } - - autoUpdater.quitAndInstall(false, true) }) const update = () => { diff --git a/frontend/App.tsx b/frontend/App.tsx index 425217c5..489a8ea9 100644 --- a/frontend/App.tsx +++ b/frontend/App.tsx @@ -9,6 +9,12 @@ import { logger } from './lib/logger' import { Home } from './views/Home' import { Project } from './views/Project' import { Playground } from './views/Playground' +import { Gallery } from './views/Gallery' +import { Characters } from './views/Characters' +import { Styles } from './views/Styles' +import { References } from './views/References' +import { Wildcards } from './views/Wildcards' +import { PromptLibrary } from './views/PromptLibrary' import { LaunchGate } from './components/FirstRunSetup' import { PythonSetup } from './components/PythonSetup' import { SettingsModal, type SettingsTabId } from './components/SettingsModal' @@ -22,7 +28,7 @@ type RequiredModelsGateState = 'checking' | 'missing' | 'ready' function AppContent() { const { currentView } = useProjects() const { status, processStatus, isLoading: backendLoading, error: backendError } = useBackend() - const { settings, saveLtxApiKey, saveFalApiKey, forceApiGenerations, isLoaded, runtimePolicyLoaded } = useAppSettings() + const { settings, saveLtxApiKey, saveReplicateApiKey, forceApiGenerations, isLoaded, runtimePolicyLoaded } = useAppSettings() const [pythonReady, setPythonReady] = useState(null) const [backendStarted, setBackendStarted] = useState(false) @@ -36,7 +42,7 @@ function AppContent() { const setupCompletionInFlightRef = useRef | null>(null) type ApiGatewayRequest = { - requiredKeys: Array<'ltx' | 'fal'> + requiredKeys: Array<'ltx' | 'replicate'> title: string description: string blocking?: boolean @@ -299,16 +305,16 @@ function AppContent() { getKeyLabel: 'Get LTX API key', }, { - keyType: 'fal', - title: 'FAL AI', - description: 'Required to generate images with Z Image Turbo.', - required: apiGatewayRequest.requiredKeys.includes('fal'), - isConfigured: settings.hasFalApiKey, - inputLabel: 'FAL AI API key', - placeholder: 'Enter your FAL AI API key...', - onSave: saveFalApiKey, - onGetKey: () => window.electronAPI.openFalApiKeyPage(), - getKeyLabel: 'Get FAL API key', + keyType: 'replicate', + title: 'Replicate', + description: 'Required for cloud image generation.', + required: apiGatewayRequest.requiredKeys.includes('replicate'), + isConfigured: settings.hasReplicateApiKey, + inputLabel: 'Replicate API key', + placeholder: 'Enter your Replicate API key...', + onSave: saveReplicateApiKey, + onGetKey: () => window.electronAPI.openReplicateApiKeyPage(), + getKeyLabel: 'Get Replicate API key', }, ] @@ -321,9 +327,9 @@ function AppContent() { apiGatewayRequest, isForcedFirstRun, saveApiKeyForFirstRun, - saveFalApiKey, + saveReplicateApiKey, saveLtxApiKey, - settings.hasFalApiKey, + settings.hasReplicateApiKey, settings.hasLtxApiKey, ]) @@ -372,7 +378,7 @@ function AppContent() {
-

Starting LTX Desktop...

+

Starting Director's Desktop...

Initializing the inference engine

@@ -431,6 +437,18 @@ function AppContent() { return case 'playground': return + case 'gallery': + return + case 'characters': + return + case 'styles': + return + case 'references': + return + case 'wildcards': + return + case 'prompt-library': + return default: return } diff --git a/frontend/components/ApiGatewayModal.tsx b/frontend/components/ApiGatewayModal.tsx index 4713c917..fb2d74a3 100644 --- a/frontend/components/ApiGatewayModal.tsx +++ b/frontend/components/ApiGatewayModal.tsx @@ -2,7 +2,7 @@ import { useEffect, useMemo, useState } from 'react' import { KeyRound, X, Zap } from 'lucide-react' import { ApiKeyHelperRow, LtxApiKeyInput } from './LtxApiKeyInput' -export type ApiKeyType = 'ltx' | 'fal' +export type ApiKeyType = 'ltx' | 'replicate' export interface ApiGatewaySection { keyType: ApiKeyType @@ -32,7 +32,7 @@ const KEY_TYPE_META: Record>({ ltx: '', fal: '' }) - const [isSaving, setIsSaving] = useState>({ ltx: false, fal: false }) - const [errors, setErrors] = useState>({ ltx: null, fal: null }) + const [values, setValues] = useState>({ ltx: '', replicate: '' }) + const [isSaving, setIsSaving] = useState>({ ltx: false, replicate: false }) + const [errors, setErrors] = useState>({ ltx: null, replicate: null }) useEffect(() => { if (!isOpen) return - setValues({ ltx: '', fal: '' }) - setIsSaving({ ltx: false, fal: false }) - setErrors({ ltx: null, fal: null }) + setValues({ ltx: '', replicate: '' }) + setIsSaving({ ltx: false, replicate: false }) + setErrors({ ltx: null, replicate: null }) }, [isOpen]) const allRequiredConfigured = useMemo(() => { diff --git a/frontend/components/BatchBuilderModal.tsx b/frontend/components/BatchBuilderModal.tsx new file mode 100644 index 00000000..2471f336 --- /dev/null +++ b/frontend/components/BatchBuilderModal.tsx @@ -0,0 +1,476 @@ +import { useState, useRef } from 'react' +import { X, Plus, Trash2, Copy, Upload, Grid3X3, List, FileText, Play, AlertCircle } from 'lucide-react' +import type { BatchSubmitRequest, BatchJobItem, SweepAxis } from '@/types/batch' +import { parseCSV, parseJSON, parseRange } from '@/lib/batch-import' +import { useBatch } from '@/hooks/use-batch' + +interface BatchBuilderModalProps { + isOpen: boolean + onClose: () => void +} + +type TabId = 'list' | 'import' | 'grid' + +interface ListRow { + id: string + type: 'image' | 'video' + model: string + prompt: string + loraPath: string + loraWeight: number +} + +interface GridAxis { + id: string + param: string + valuesInput: string +} + +const PARAM_OPTIONS = [ + { value: 'loraWeight', label: 'LoRA Weight' }, + { value: 'loraPath', label: 'LoRA Path' }, + { value: 'prompt', label: 'Prompt' }, + { value: 'numSteps', label: 'Steps' }, + { value: 'seed', label: 'Seed' }, + { value: 'model', label: 'Model' }, +] + +let rowIdCounter = 0 +function nextRowId(): string { + return `row_${++rowIdCounter}` +} + +export function BatchBuilderModal({ isOpen, onClose }: BatchBuilderModalProps) { + const [activeTab, setActiveTab] = useState('list') + const [target, setTarget] = useState<'local' | 'cloud'>('local') + const [pipelineEnabled, setPipelineEnabled] = useState(false) + const batch = useBatch() + + // List tab state + const [rows, setRows] = useState([ + { id: nextRowId(), type: 'image', model: 'flux-klein-9b', prompt: '', loraPath: '', loraWeight: 1.0 }, + ]) + + // Import tab state + const [importText, setImportText] = useState('') + const [importError, setImportError] = useState(null) + const [importedItems, setImportedItems] = useState([]) + const fileInputRef = useRef(null) + + // Grid tab state + const [gridBasePrompt, setGridBasePrompt] = useState('') + const [gridBaseModel, setGridBaseModel] = useState('z-image-turbo') + const [gridAxes, setGridAxes] = useState([ + { id: nextRowId(), param: 'loraWeight', valuesInput: '0.3-1.0:8' }, + ]) + + if (!isOpen) return null + + const addRow = () => { + setRows(prev => [...prev, { + id: nextRowId(), type: 'image', model: 'flux-klein-9b', prompt: '', loraPath: '', loraWeight: 1.0, + }]) + } + + const removeRow = (id: string) => { + setRows(prev => prev.filter(r => r.id !== id)) + } + + const duplicateRow = (id: string) => { + setRows(prev => { + const idx = prev.findIndex(r => r.id === id) + if (idx < 0) return prev + const copy = { ...prev[idx], id: nextRowId() } + const next = [...prev] + next.splice(idx + 1, 0, copy) + return next + }) + } + + const updateRow = (id: string, field: keyof ListRow, value: string | number) => { + setRows(prev => prev.map(r => r.id === id ? { ...r, [field]: value } : r)) + } + + const handleImportParse = (text: string) => { + setImportText(text) + setImportError(null) + setImportedItems([]) + if (!text.trim()) return + try { + const items = text.trim().startsWith('{') || text.trim().startsWith('[') + ? parseJSON(text) + : parseCSV(text) + setImportedItems(items) + } catch (err) { + setImportError(err instanceof Error ? err.message : 'Parse error') + } + } + + const handleFileUpload = (e: React.ChangeEvent) => { + const file = e.target.files?.[0] + if (!file) return + const reader = new FileReader() + reader.onload = () => handleImportParse(reader.result as string) + reader.readAsText(file) + } + + const addAxis = () => { + if (gridAxes.length >= 3) return + setGridAxes(prev => [...prev, { id: nextRowId(), param: 'numSteps', valuesInput: '4, 8, 12' }]) + } + + const removeAxis = (id: string) => { + setGridAxes(prev => prev.filter(a => a.id !== id)) + } + + const updateAxis = (id: string, field: keyof GridAxis, value: string) => { + setGridAxes(prev => prev.map(a => a.id === id ? { ...a, [field]: value } : a)) + } + + const getGridTotalJobs = (): number => { + return gridAxes.reduce((total, axis) => { + const values = parseRange(axis.valuesInput) + return total * Math.max(values.length, 1) + }, 1) + } + + const handleSubmit = async () => { + let request: BatchSubmitRequest + + if (activeTab === 'list') { + const jobs: BatchJobItem[] = rows.filter(r => r.prompt.trim()).map(r => ({ + type: r.type, + model: r.model, + params: { + prompt: r.prompt, + ...(r.loraPath ? { loraPath: r.loraPath, loraWeight: r.loraWeight } : {}), + }, + })) + if (pipelineEnabled) { + request = { + mode: 'pipeline', + target, + pipeline: { + steps: jobs.flatMap(j => [ + { type: 'image' as const, model: j.model, params: j.params, auto_prompt: false }, + { type: 'video' as const, model: 'ltx-fast', params: {}, auto_prompt: true }, + ]), + }, + } + } else { + request = { mode: 'list', target, jobs } + } + } else if (activeTab === 'import') { + request = { mode: 'list', target, jobs: importedItems } + } else { + const axes: SweepAxis[] = gridAxes.map(a => ({ + param: a.param, + values: a.param === 'prompt' + ? a.valuesInput.split(',').map(v => v.trim()) + : parseRange(a.valuesInput), + mode: a.param === 'prompt' ? 'search_replace' as const : 'replace' as const, + ...(a.param === 'prompt' ? { search: gridBasePrompt.split(' ')[0] } : {}), + })) + request = { + mode: 'sweep', + target, + sweep: { + base_type: 'image', + base_model: gridBaseModel, + base_params: { prompt: gridBasePrompt }, + axes, + }, + } + } + + await batch.submit(request) + onClose() + } + + const tabs: { id: TabId; label: string; icon: React.ReactNode }[] = [ + { id: 'list', label: 'List', icon: }, + { id: 'import', label: 'Import', icon: }, + { id: 'grid', label: 'Grid Sweep', icon: }, + ] + + return ( +
+
+ {/* Header */} +
+

Batch Generation

+ +
+ + {/* Tabs + Target */} +
+
+ {tabs.map(tab => ( + + ))} +
+
+ Target: + +
+
+ + {/* Tab Content */} +
+ {activeTab === 'list' && ( +
+ {/* Table header */} +
+ Type + Model + Prompt + LoRA Path + LoRA Wt + +
+ {rows.map(row => ( +
+ + updateRow(row.id, 'model', e.target.value)} + className="text-sm rounded-lg px-2 py-1.5 border" + style={{ background: 'oklch(0.22 0.025 290)', borderColor: 'oklch(0.32 0.03 290)', color: 'oklch(0.92 0.02 290)' }} + /> + updateRow(row.id, 'prompt', e.target.value)} + placeholder="Enter prompt..." + className="text-sm rounded-lg px-2 py-1.5 border" + style={{ background: 'oklch(0.22 0.025 290)', borderColor: 'oklch(0.32 0.03 290)', color: 'oklch(0.92 0.02 290)' }} + /> + updateRow(row.id, 'loraPath', e.target.value)} + placeholder="Optional" + className="text-sm rounded-lg px-2 py-1.5 border" + style={{ background: 'oklch(0.22 0.025 290)', borderColor: 'oklch(0.32 0.03 290)', color: 'oklch(0.92 0.02 290)' }} + /> + updateRow(row.id, 'loraWeight', Number(e.target.value))} + step={0.1} + min={0} + max={2} + className="text-sm rounded-lg px-2 py-1.5 border" + style={{ background: 'oklch(0.22 0.025 290)', borderColor: 'oklch(0.32 0.03 290)', color: 'oklch(0.92 0.02 290)' }} + /> +
+ + +
+
+ ))} + + + {/* Pipeline toggle */} + +
+ )} + + {activeTab === 'import' && ( +
+
+ + + + Or paste below + +
+