diff --git a/.env.example b/.env.example
index 4dcaae1..e2ff555 100644
--- a/.env.example
+++ b/.env.example
@@ -6,7 +6,7 @@
# --- Application ---
APP_ENV=development
# Set to true for development only. Production MUST use false.
-APP_DEBUG=false
+APP_DEBUG=true
APP_HOST=0.0.0.0
APP_PORT=8000
# SECURITY: Change this to a random secret key in production!
@@ -59,31 +59,66 @@ OLLAMA_MODEL=llama3
# --- Embedding ---
# Provider: local (HuggingFace) | api (OpenAI) | mock
EMBEDDING_PROVIDER=local
-EMBEDDING_MODEL=BAAI/bge-m3
+EMBEDDING_MODEL=Qwen/Qwen3-Embedding-8B
EMBEDDING_API_KEY=
-RERANKER_MODEL=BAAI/bge-reranker-v2-m3
+RERANKER_MODEL=tomaarsen/Qwen3-Reranker-8B-seq-cls
# --- OCR ---
# PaddleOCR language: ch (Chinese+English) | en (English only)
OCR_LANG=ch
-# --- PDF Parsing ---
+# --- PDF Parsing / MinerU ---
# Parser selection: auto (pdfplumber first, fallback to MinerU) | mineru | pdfplumber
-PDF_PARSER=auto
+PDF_PARSER=mineru
# MinerU independent API service URL
MINERU_API_URL=http://localhost:8010
# MinerU backend: pipeline | hybrid-auto-engine | vlm-auto-engine
MINERU_BACKEND=pipeline
# Timeout per PDF in seconds
-MINERU_TIMEOUT=300
+MINERU_TIMEOUT=8000
+# Auto start/stop MinerU subprocess (true = Omelette manages MinerU lifecycle)
+MINERU_AUTO_MANAGE=true
+# Conda environment name for MinerU (used with conda run)
+MINERU_CONDA_ENV=mineru
+# Stop MinerU after N seconds idle (0 = never auto-stop)
+MINERU_TTL_SECONDS=600
+# MinerU startup timeout in seconds
+MINERU_STARTUP_TIMEOUT=120
+# GPU IDs for MinerU (empty = inherit CUDA_VISIBLE_DEVICES)
+MINERU_GPU_IDS=
# --- GPU ---
# Comma-separated GPU IDs for OCR/embedding tasks
-CUDA_VISIBLE_DEVICES=0,3
+CUDA_VISIBLE_DEVICES=
+
+# Auto-unload GPU models after N seconds idle (0 = never auto-unload)
+MODEL_TTL_SECONDS=300
+# TTL check interval in seconds
+MODEL_TTL_CHECK_INTERVAL=30
+
+# GPU preset mode: conservative | balanced | aggressive
+# conservative: batch=1, parallel=1, safe for small VRAM / debugging
+# balanced: batch=8/16, auto parallel, good default
+# aggressive: batch=32/50, parallel=GPU*2, max throughput (32G+ VRAM)
+GPU_MODE=balanced
+
+# Per-service overrides (0 = follow GPU_MODE preset)
+# EMBED_BATCH_SIZE=0
+# RERANK_BATCH_SIZE=0
+
+# Pin models to specific GPU index (-1 = auto-select by free memory)
+# EMBED_GPU_ID=-1
+# RERANK_GPU_ID=-1
+
+# Comma-separated GPU IDs for OCR workers (empty = use all visible GPUs)
+# OCR_GPU_IDS=
+
+# Max parallel OCR tasks. 0=auto (GPU count, or GPU*2 in aggressive mode)
+# OCR_PARALLEL_LIMIT=0
# --- Network Proxy ---
-HTTP_PROXY=http://127.0.0.1:20171/
-HTTPS_PROXY=http://127.0.0.1:20171/
+# HTTP_PROXY=http://your-proxy:port
+# HTTPS_PROXY=http://your-proxy:port
# --- HuggingFace Mirror ---
# For users in China, set to https://hf-mirror.com to speed up model downloads
diff --git a/README.md b/README.md
index fa2a83e..7bcd6a2 100644
--- a/README.md
+++ b/README.md
@@ -60,7 +60,7 @@ Omelette automates the full research literature pipeline — from keyword manage
Multi-channel download via Unpaywall, arXiv, and direct URL fallback strategies.
**📝 OCR Processing**
- Native text extraction with PaddleOCR GPU fallback for scanned documents.
+  PDF parsing via MinerU (auto-managed subprocess) or pdfplumber native extraction, with PaddleOCR GPU fallback for scanned documents.
**🧠 RAG Knowledge Base**
LlamaIndex engine with ChromaDB, GPU-aware embeddings, hybrid retrieval, and cited answers.
@@ -69,7 +69,10 @@ Omelette automates the full research literature pipeline — from keyword manage
Summarization, citation generation (GB/T 7714, APA, MLA), review outlines, and gap analysis.
**🔄 LangGraph Pipeline**
- Pipeline orchestration with human-in-the-loop interrupt and resume.
+ Pipeline orchestration with HITL interrupt/resume and persistent checkpointing.
+
+ **⚡ GPU Resource Management**
+ TTL-based auto-unload for GPU models, MinerU subprocess auto-management, monitoring API, and exit cleanup watchdog.
**🔗 MCP Integration**
Model Context Protocol server for AI IDE clients (Cursor, Claude Code, etc.).
@@ -103,7 +106,7 @@ Keywords ─→ Search ─→ Dedup ─→ Crawler ─→ OCR ─→ RAG ─→
| **RAG** | LlamaIndex with GPU-aware embeddings |
| **LLM** | LangChain (OpenAI, Anthropic, Aliyun, Volcengine, Ollama) |
| **Orchestration** | LangGraph with HITL interrupt/resume |
-| **OCR** | pdfplumber (native) + PaddleOCR (scanned, optional) |
+| **OCR** | MinerU (auto-managed) + pdfplumber (native) + PaddleOCR (scanned) |
| **MCP** | Model Context Protocol server |
| **Docs** | VitePress (bilingual EN/ZH) |
@@ -147,6 +150,10 @@ cp .env.example .env
| `ALIYUN_API_KEY` | Aliyun Bailian API key |
| `VOLCENGINE_API_KEY` | Volcengine Doubao API key |
| `SEMANTIC_SCHOLAR_API_KEY` | Optional; increases Semantic Scholar rate limit |
+| `GPU_MODE` | GPU preset: `conservative`, `balanced` (default), `aggressive` |
+| `MODEL_TTL_SECONDS` | Auto-unload GPU models after N seconds idle (default: 300) |
+| `MINERU_AUTO_MANAGE` | Auto start/stop MinerU subprocess (default: true) |
+| `PDF_PARSER` | `auto`, `mineru`, or `pdfplumber` |
See [`.env.example`](.env.example) for the full list.
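+
+A GPU-enabled setup might combine these like so (values mirror the `.env.example` defaults):
+
+```bash
+GPU_MODE=balanced        # conservative | balanced | aggressive
+MODEL_TTL_SECONDS=300    # unload idle GPU models after 5 minutes
+MINERU_AUTO_MANAGE=true  # let Omelette start/stop the MinerU subprocess
+PDF_PARSER=mineru        # auto | mineru | pdfplumber
+```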
@@ -156,10 +163,31 @@ See [`.env.example`](.env.example) for the full list.
```bash
cd backend
+
+# Run database migrations
+alembic upgrade head
+
+# Start server
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
```
-### 4. Start frontend
+On startup, the backend automatically:
+- Writes a PID file to `DATA_DIR/omelette.pid`
+- Starts a GPU model TTL monitor (auto-unloads idle models)
+- If `MINERU_AUTO_MANAGE=true`, manages MinerU subprocess lifecycle
+- Registers cleanup handlers (`atexit` + `SIGHUP`) so GPU resources are released even if the process exits unexpectedly (see the sketch below)
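+
+A minimal sketch of that registration (illustrative only; `release_gpu_resources` is a hypothetical stand-in for the real unload logic):
+
+```python
+import atexit
+import signal
+
+def release_gpu_resources() -> None:
+    ...  # hypothetical: unload embedding/reranker models, stop the MinerU subprocess
+
+atexit.register(release_gpu_resources)                                    # normal exit
+signal.signal(signal.SIGHUP, lambda sig, frame: release_gpu_resources())  # terminal hangup
+```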
+
+### 4. (Optional) GPU watchdog
+
+For extra safety against `kill -9` or crashes, run the external watchdog:
+
+```bash
+python backend/scripts/gpu_watchdog.py --daemon
+```
+
+The watchdog monitors the Omelette process and cleans up GPU resources if it terminates abnormally.
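+
+Conceptually, the watchdog is a small polling loop over the PID file. A sketch of the idea (paths and the cleanup helper are illustrative, not the real implementation):
+
+```python
+import os
+import time
+from pathlib import Path
+
+PID_FILE = Path("data/omelette.pid")  # illustrative; actual location is DATA_DIR/omelette.pid
+
+def cleanup_gpu() -> None:
+    ...  # hypothetical: terminate orphaned MinerU/OCR workers still holding VRAM
+
+def pid_alive(pid: int) -> bool:
+    try:
+        os.kill(pid, 0)  # signal 0 probes existence without sending anything
+    except ProcessLookupError:
+        return False
+    return True
+
+pid = int(PID_FILE.read_text().strip())
+while pid_alive(pid):
+    time.sleep(5)
+cleanup_gpu()  # the monitored process is gone; release GPU resources
+```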
+
+### 5. Start frontend
```bash
cd frontend
@@ -169,13 +197,19 @@ npm run dev
Open [http://localhost:3000](http://localhost:3000) in your browser.
-### 5. (Optional) OCR & Embeddings
+### 6. (Optional) MinerU setup
+
+If using MinerU for PDF parsing (`PDF_PARSER=mineru`):
```bash
-cd backend
-pip install -e ".[ocr,ml]"
+# Create a separate conda env for MinerU
+conda create -n mineru python=3.10
+conda activate mineru
+pip install "magic-pdf[full]"  # quoted so zsh does not glob the brackets
```
+Set `MINERU_CONDA_ENV=mineru` in `.env`. Omelette will auto-start MinerU when needed.
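+
+The managed lifecycle is driven by the MinerU knobs in `.env` (defaults shown, from `.env.example`):
+
+```bash
+MINERU_API_URL=http://localhost:8010  # where the MinerU API listens
+MINERU_CONDA_ENV=mineru               # env name used with `conda run`
+MINERU_TTL_SECONDS=600                # stop MinerU after 10 min idle (0 = never)
+MINERU_STARTUP_TIMEOUT=120            # max seconds to wait for startup
+```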
+
> **Troubleshooting:** If you get `ModuleNotFoundError: No module named 'fastapi'`, ensure the conda environment is activated: `conda activate omelette`.
## 📂 Project Layout
@@ -194,7 +228,8 @@ omelette/
│ │ └── main.py # App entry, lifespan, CORS
│ ├── mcp_server.py # MCP (Model Context Protocol) server
│ ├── alembic/ # Database migrations
-│ ├── tests/ # pytest-asyncio tests (178 tests)
+│ ├── scripts/ # Utilities (gpu_watchdog.py)
+│ ├── tests/ # pytest-asyncio tests (526 tests)
│ └── pyproject.toml # Python dependencies
├── frontend/ # React SPA
│ └── src/
@@ -230,7 +265,7 @@ make dev # Start both backend and frontend
### Running Tests
```bash
-# Backend (178 tests)
+# Backend (526 tests)
cd backend && pytest tests/ -v
# Frontend unit tests (28 tests — Vitest + Testing Library + MSW)
@@ -269,6 +304,8 @@ REST APIs under `/api/v1/`:
| `GET/POST /subscriptions` | Subscription management |
| `GET/POST /settings` | Settings and health |
| `GET /settings/health` | Health check |
+| `GET /gpu/status` | GPU model and memory status |
+| `POST /gpu/unload` | Manually unload GPU models |
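+
+For example, to inspect and free GPU memory from the command line (the `data` payload of `/gpu/status` carries `models`, `mineru`, and per-device `gpu_memory` with `total_mb`/`used_mb`/`free_mb`):
+
+```bash
+curl http://localhost:8000/api/v1/gpu/status
+curl -X POST http://localhost:8000/api/v1/gpu/unload   # returns {"unloaded": [...]} in `data`
+```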
MCP server: `/mcp` (WebSocket/SSE for AI IDE clients)
diff --git a/README_zh.md b/README_zh.md
index e9b9b11..2cfacd9 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -60,7 +60,7 @@ Omelette 覆盖科研文献全流程自动化 — 从关键词管理、多源检
Unpaywall、arXiv、直链多通道下载,智能回退策略。
**📝 OCR 解析**
- pdfplumber 原生文本提取,PaddleOCR GPU 加速处理扫描件。
+ MinerU(自动管理子进程)或 pdfplumber 原生提取,PaddleOCR GPU 加速处理扫描件。
**🧠 RAG 知识库**
LlamaIndex 引擎,ChromaDB 向量存储,GPU 感知嵌入,混合检索,带引用回答。
@@ -69,7 +69,10 @@ Omelette 覆盖科研文献全流程自动化 — 从关键词管理、多源检
论文摘要、引用生成(GB/T 7714、APA、MLA)、综述提纲、缺口分析。
**🔄 LangGraph 流水线**
- 流水线编排,支持人机协同中断与恢复。
+ 流水线编排,支持人机协同中断/恢复与持久化检查点。
+
+ **⚡ GPU 资源管理**
+ TTL 自动卸载 GPU 模型、MinerU 子进程自动管理、监控 API、退出清理看门狗。
**🔗 MCP 集成**
Model Context Protocol 服务端,面向 AI IDE 客户端(Cursor、Claude Code 等)。
@@ -103,7 +106,7 @@ Keywords ─→ Search ─→ Dedup ─→ Crawler ─→ OCR ─→ RAG ─→
| **RAG** | LlamaIndex,GPU 感知嵌入 |
| **LLM** | LangChain(OpenAI、Anthropic、阿里云、火山引擎、Ollama) |
| **编排** | LangGraph,支持人机协同中断与恢复 |
-| **OCR** | pdfplumber(原生)+ PaddleOCR(扫描件,可选) |
+| **OCR** | MinerU(自动管理)+ pdfplumber(原生)+ PaddleOCR(扫描件) |
| **MCP** | Model Context Protocol 服务端 |
| **文档** | VitePress(中英双语) |
@@ -156,10 +159,31 @@ cp .env.example .env
```bash
cd backend
+
+# 执行数据库迁移
+alembic upgrade head
+
+# 启动服务
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
```
-### 4. 启动前端
+启动时后端自动完成以下操作:
+- 写入 PID 文件到 `DATA_DIR/omelette.pid`
+- 启动 GPU 模型 TTL 监控(自动卸载空闲模型)
+- 若 `MINERU_AUTO_MANAGE=true`,自动管理 MinerU 子进程生命周期
+- 注册清理钩子(`atexit` + `SIGHUP`),即使进程意外退出也会释放 GPU 资源(见下方示意)
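+
+该注册过程的最小示意(仅为说明,`release_gpu_resources` 为示意用的占位函数):
+
+```python
+import atexit
+import signal
+
+def release_gpu_resources() -> None:
+    ...  # 示意:卸载嵌入/重排模型,停止 MinerU 子进程
+
+atexit.register(release_gpu_resources)                                    # 正常退出
+signal.signal(signal.SIGHUP, lambda sig, frame: release_gpu_resources())  # 终端挂断
+```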
+
+### 4.(可选)GPU 看门狗
+
+为防止 `kill -9` 或崩溃导致资源泄漏,可运行外部看门狗:
+
+```bash
+python backend/scripts/gpu_watchdog.py --daemon
+```
+
+看门狗会监控 Omelette 进程,在其异常终止后自动清理 GPU 资源。
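+
+看门狗本质上是对 PID 文件的轮询循环,以下为思路示意(路径与清理函数均为示意,并非真实实现):
+
+```python
+import os
+import time
+from pathlib import Path
+
+PID_FILE = Path("data/omelette.pid")  # 示意;实际位置为 DATA_DIR/omelette.pid
+
+def cleanup_gpu() -> None:
+    ...  # 示意:清理仍占用显存的孤儿 MinerU/OCR 进程
+
+def pid_alive(pid: int) -> bool:
+    try:
+        os.kill(pid, 0)  # 信号 0 仅探测进程是否存在
+    except ProcessLookupError:
+        return False
+    return True
+
+pid = int(PID_FILE.read_text().strip())
+while pid_alive(pid):
+    time.sleep(5)
+cleanup_gpu()  # 被监控进程已退出,释放 GPU 资源
+```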
+
+### 5. 启动前端
```bash
cd frontend
@@ -169,13 +193,19 @@ npm run dev
在浏览器中打开 [http://localhost:3000](http://localhost:3000)。
-### 5.(可选)OCR 与嵌入
+### 6.(可选)MinerU 配置
+
+若使用 MinerU 解析 PDF(`PDF_PARSER=mineru`):
```bash
-cd backend
-pip install -e ".[ocr,ml]"
+# 为 MinerU 创建独立 conda 环境
+conda create -n mineru python=3.10
+conda activate mineru
+pip install "magic-pdf[full]"  # 加引号避免 zsh 解析方括号
```
+在 `.env` 中设置 `MINERU_CONDA_ENV=mineru`,Omelette 将在需要时自动启动 MinerU。
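+
+托管生命周期由 `.env` 中的 MinerU 配置项控制(以下为 `.env.example` 的默认值):
+
+```bash
+MINERU_API_URL=http://localhost:8010  # MinerU API 监听地址
+MINERU_CONDA_ENV=mineru               # 配合 conda run 使用的环境名
+MINERU_TTL_SECONDS=600                # 空闲 10 分钟后自动停止(0 = 永不)
+MINERU_STARTUP_TIMEOUT=120            # 启动超时秒数
+```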
+
> **常见问题:** 若出现 `ModuleNotFoundError: No module named 'fastapi'`,请确认已激活 conda 环境:`conda activate omelette`。
## 📂 项目结构
@@ -194,7 +224,8 @@ omelette/
│ │ └── main.py # App entry, lifespan, CORS
│ ├── mcp_server.py # MCP (Model Context Protocol) server
│ ├── alembic/ # Database migrations
-│ ├── tests/ # pytest-asyncio 测试(178 个)
+│ ├── scripts/ # 工具脚本(gpu_watchdog.py)
+│ ├── tests/ # pytest-asyncio 测试(526 个)
│ └── pyproject.toml # Python dependencies
├── frontend/ # React SPA
│ └── src/
@@ -230,7 +261,7 @@ make dev # Start both backend and frontend
### 运行测试
```bash
-# 后端(178 个测试)
+# 后端(526 个测试)
cd backend && pytest tests/ -v
# 前端单元测试(28 个测试 — Vitest + Testing Library + MSW)
@@ -266,6 +297,8 @@ REST API 位于 `/api/v1/` 下:
| `GET/POST /subscriptions` | 订阅管理 |
| `GET/POST /settings` | 设置与健康状态 |
| `GET /settings/health` | 健康检查 |
+| `GET /gpu/status` | GPU 模型与显存状态 |
+| `POST /gpu/unload` | 手动卸载 GPU 模型 |
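+
+例如,通过命令行查看并释放显存(`/gpu/status` 的 `data` 字段包含 `models`、`mineru` 及各卡 `gpu_memory`,含 `total_mb`/`used_mb`/`free_mb`):
+
+```bash
+curl http://localhost:8000/api/v1/gpu/status
+curl -X POST http://localhost:8000/api/v1/gpu/unload   # data 中返回 {"unloaded": [...]}
+```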
MCP 服务端:`/mcp`(WebSocket/SSE,面向 AI IDE 客户端)
diff --git a/backend/alembic/versions/a1b2c3d4e5f6_add_composite_indexes.py b/backend/alembic/versions/a1b2c3d4e5f6_add_composite_indexes.py
new file mode 100644
index 0000000..8f68f51
--- /dev/null
+++ b/backend/alembic/versions/a1b2c3d4e5f6_add_composite_indexes.py
@@ -0,0 +1,26 @@
+"""add composite indexes for paper and task tables
+
+Revision ID: a1b2c3d4e5f6
+Revises: f2bee250c39f
+Create Date: 2026-03-18 10:00:00.000000
+
+"""
+
+from collections.abc import Sequence
+
+from alembic import op
+
+revision: str = "a1b2c3d4e5f6"
+down_revision: str | None = "f2bee250c39f"
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None
+
+
+def upgrade() -> None:
+ op.create_index("ix_paper_project_status", "papers", ["project_id", "status"])
+ op.create_index("ix_task_project_status", "tasks", ["project_id", "status"])
+
+
+def downgrade() -> None:
+ op.drop_index("ix_task_project_status", table_name="tasks")
+ op.drop_index("ix_paper_project_status", table_name="papers")
diff --git a/backend/alembic/versions/cb8130e58f92_add_paper_project_doi_unique_constraint.py b/backend/alembic/versions/cb8130e58f92_add_paper_project_doi_unique_constraint.py
new file mode 100644
index 0000000..543e824
--- /dev/null
+++ b/backend/alembic/versions/cb8130e58f92_add_paper_project_doi_unique_constraint.py
@@ -0,0 +1,29 @@
+"""add paper project_doi unique constraint
+
+Revision ID: cb8130e58f92
+Revises: a1b2c3d4e5f6
+Create Date: 2026-03-18 22:54:13.519198
+
+"""
+
+from collections.abc import Sequence
+
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision: str = "cb8130e58f92"
+down_revision: str | Sequence[str] | None = "a1b2c3d4e5f6"
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None
+
+
+def upgrade() -> None:
+ """Upgrade schema."""
+ with op.batch_alter_table("papers", schema=None) as batch_op:
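+        # SQLite cannot add a constraint to an existing table in place; batch mode copies and rebuilds it.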
+ batch_op.create_unique_constraint("uq_paper_project_doi", ["project_id", "doi"])
+
+
+def downgrade() -> None:
+ """Downgrade schema."""
+ with op.batch_alter_table("papers", schema=None) as batch_op:
+ batch_op.drop_constraint("uq_paper_project_doi", type_="unique")
diff --git a/backend/alembic/versions/e7a9b1c3d5f7_add_keyword_parent_id_index.py b/backend/alembic/versions/e7a9b1c3d5f7_add_keyword_parent_id_index.py
new file mode 100644
index 0000000..2e5a0fe
--- /dev/null
+++ b/backend/alembic/versions/e7a9b1c3d5f7_add_keyword_parent_id_index.py
@@ -0,0 +1,24 @@
+"""add keyword parent_id index
+
+Revision ID: e7a9b1c3d5f7
+Revises: cb8130e58f92
+Create Date: 2026-03-18 12:00:00.000000
+
+"""
+
+from collections.abc import Sequence
+
+from alembic import op
+
+revision: str = "e7a9b1c3d5f7"
+down_revision: str | Sequence[str] | None = "cb8130e58f92"
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None
+
+
+def upgrade() -> None:
+ op.create_index(op.f("ix_keywords_parent_id"), "keywords", ["parent_id"], unique=False)
+
+
+def downgrade() -> None:
+ op.drop_index(op.f("ix_keywords_parent_id"), table_name="keywords")
diff --git a/backend/app/api/deps.py b/backend/app/api/deps.py
index 0672d7a..4abd3c8 100644
--- a/backend/app/api/deps.py
+++ b/backend/app/api/deps.py
@@ -1,13 +1,15 @@
"""Shared FastAPI dependencies for dependency injection."""
+from __future__ import annotations
+
from collections.abc import AsyncGenerator
from fastapi import Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
-from app.database import get_session
+from app.database import Base, get_session
from app.models import Project
-from app.services.llm_client import LLMClient, get_llm_client
+from app.services.llm.client import LLMClient, get_llm_client
async def get_db() -> AsyncGenerator[AsyncSession, None]:
@@ -15,12 +17,27 @@ async def get_db() -> AsyncGenerator[AsyncSession, None]:
yield session
+async def get_or_404[T: Base](
+ db: AsyncSession,
+ model: type[T],
+ resource_id: int,
+ *,
+ project_id: int | None = None,
+ detail: str = "Resource not found",
+) -> T:
+ """Fetch a model instance by primary key, raising 404 if missing or project mismatch."""
+ obj = await db.get(model, resource_id)
+ if not obj:
+ raise HTTPException(status_code=404, detail=detail)
+ obj_project_id = getattr(obj, "project_id", None)
+ if project_id is not None and obj_project_id is not None and obj_project_id != project_id:
+ raise HTTPException(status_code=404, detail=detail)
+ return obj
+
+
async def get_project_or_404(project_id: int, db: AsyncSession) -> Project:
- """Fetch project by ID. Raises HTTPException 404 if not found. Use when project_id comes from body/query."""
- project = await db.get(Project, project_id)
- if not project:
- raise HTTPException(status_code=404, detail="Project not found")
- return project
+ """Fetch project by ID. Raises HTTPException 404 if not found."""
+ return await get_or_404(db, Project, project_id, detail="Project not found")
async def get_project(
diff --git a/backend/app/api/v1/__init__.py b/backend/app/api/v1/__init__.py
index a3b3ea4..439e01c 100644
--- a/backend/app/api/v1/__init__.py
+++ b/backend/app/api/v1/__init__.py
@@ -7,6 +7,7 @@
conversations,
crawler,
dedup,
+ gpu,
keywords,
ocr,
papers,
@@ -41,3 +42,4 @@
api_router.include_router(chat.router)
api_router.include_router(rewrite.router)
api_router.include_router(pipelines.router)
+api_router.include_router(gpu.router)
diff --git a/backend/app/api/v1/chat.py b/backend/app/api/v1/chat.py
index 70edf2c..e662035 100644
--- a/backend/app/api/v1/chat.py
+++ b/backend/app/api/v1/chat.py
@@ -7,21 +7,18 @@
import uuid
from collections.abc import Callable
-from fastapi import APIRouter, Depends
+from fastapi import APIRouter, Depends, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_db
+from app.middleware.rate_limit import limiter
from app.pipelines.chat.graph import create_chat_pipeline
-from app.pipelines.chat.stream_writer import (
- format_done,
- format_error,
- format_finish,
- format_start,
-)
+from app.pipelines.chat.stream_writer import format_done, format_finish, format_start
from app.schemas.common import ApiResponse
from app.schemas.conversation import ChatStreamRequest
+from app.utils.sse import format_sse_error
logger = logging.getLogger(__name__)
@@ -83,6 +80,8 @@ async def _stream_chat(
"tool_mode": request.tool_mode,
"conversation_id": request.conversation_id,
"model": request.model or "",
+ "rag_top_k": request.rag_top_k,
+ "use_reranker": request.use_reranker,
}
async for event in pipeline.astream(
@@ -95,19 +94,21 @@ async def _stream_chat(
yield format_finish()
except Exception as e:
logger.exception("Chat stream error")
- yield format_error(str(e))
+ yield format_sse_error(str(e), code=500)
finally:
yield format_done()
-@router.post("/stream")
+@router.post("/stream", summary="Stream chat completion")
+@limiter.limit("30/minute")
async def chat_stream(
- request: ChatStreamRequest,
+ request: Request,
+ body: ChatStreamRequest,
db: AsyncSession = Depends(get_db),
):
"""Data Stream Protocol (Vercel AI SDK 5.0) chat endpoint."""
return StreamingResponse(
- _stream_chat(request, db),
+ _stream_chat(body, db),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
@@ -118,7 +119,7 @@ async def chat_stream(
)
-@router.post("/complete", response_model=ApiResponse[CompletionResponse])
+@router.post("/complete", response_model=ApiResponse[CompletionResponse], summary="Autocomplete suggestion")
async def complete(request: CompletionRequest):
"""Return a short text completion suggestion for autocomplete."""
from app.services.completion_service import CompletionService
diff --git a/backend/app/api/v1/conversations.py b/backend/app/api/v1/conversations.py
index ffbb48b..3c0a929 100644
--- a/backend/app/api/v1/conversations.py
+++ b/backend/app/api/v1/conversations.py
@@ -1,7 +1,7 @@
"""Conversation CRUD API endpoints."""
from fastapi import APIRouter, Depends, HTTPException
-from sqlalchemy import func, select
+from sqlalchemy import func, select, text
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
@@ -19,7 +19,7 @@
router = APIRouter(prefix="/conversations", tags=["conversations"])
-@router.get("", response_model=ApiResponse[PaginatedData[ConversationListSchema]])
+@router.get("", response_model=ApiResponse[PaginatedData[ConversationListSchema]], summary="List conversations")
async def list_conversations(
page: int = 1,
page_size: int = 20,
@@ -27,15 +27,16 @@ async def list_conversations(
db: AsyncSession = Depends(get_db),
):
"""List conversations, newest first."""
- stmt = select(Conversation).order_by(Conversation.updated_at.desc())
-
+ kb_filter = None
if knowledge_base_id is not None:
- stmt = stmt.where(
- Conversation.knowledge_base_ids.isnot(None),
- )
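+        # SQLite json_each() expands the JSON array so membership is filtered in SQL rather than in Python.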
+ kb_filter = text(
+ "EXISTS (SELECT 1 FROM json_each(conversations.knowledge_base_ids) WHERE value = :kb_id)"
+ ).bindparams(kb_id=knowledge_base_id)
- count_stmt = select(func.count()).select_from(stmt.subquery())
- total = (await db.execute(count_stmt)).scalar_one()
+ count_base = select(func.count(Conversation.id))
+ if kb_filter is not None:
+ count_base = count_base.where(kb_filter)
+ total = (await db.execute(count_base)).scalar_one()
msg_count_sq = (
select(func.count(Message.id))
@@ -60,31 +61,25 @@ async def list_conversations(
.offset((page - 1) * page_size)
.limit(page_size)
)
- if knowledge_base_id is not None:
- detail_stmt = detail_stmt.where(Conversation.knowledge_base_ids.isnot(None))
+ if kb_filter is not None:
+ detail_stmt = detail_stmt.where(kb_filter)
detail_result = await db.execute(detail_stmt)
- items = []
- for conv, msg_count, last_msg_content in detail_result.all():
- if knowledge_base_id is not None:
- kb_ids = conv.knowledge_base_ids or []
- if knowledge_base_id not in kb_ids:
- continue
-
- items.append(
- ConversationListSchema(
- id=conv.id,
- title=conv.title,
- knowledge_base_ids=conv.knowledge_base_ids,
- model=conv.model,
- tool_mode=conv.tool_mode,
- created_at=conv.created_at,
- updated_at=conv.updated_at,
- message_count=msg_count or 0,
- last_message_preview=(last_msg_content[:100] if last_msg_content else ""),
- )
+ items = [
+ ConversationListSchema(
+ id=conv.id,
+ title=conv.title,
+ knowledge_base_ids=conv.knowledge_base_ids,
+ model=conv.model,
+ tool_mode=conv.tool_mode,
+ created_at=conv.created_at,
+ updated_at=conv.updated_at,
+ message_count=msg_count or 0,
+ last_message_preview=(last_msg_content[:100] if last_msg_content else ""),
)
+ for conv, msg_count, last_msg_content in detail_result.all()
+ ]
total_pages = (total + page_size - 1) // page_size if total > 0 else 1
@@ -99,7 +94,7 @@ async def list_conversations(
)
-@router.post("", response_model=ApiResponse[ConversationSchema])
+@router.post("", response_model=ApiResponse[ConversationSchema], summary="Create conversation")
async def create_conversation(
body: ConversationCreateSchema,
db: AsyncSession = Depends(get_db),
@@ -112,7 +107,7 @@ async def create_conversation(
tool_mode=body.tool_mode,
)
db.add(conv)
- await db.commit()
+ await db.flush()
result = await db.execute(
select(Conversation).where(Conversation.id == conv.id).options(selectinload(Conversation.messages))
@@ -121,7 +116,7 @@ async def create_conversation(
return ApiResponse(data=ConversationSchema.model_validate(conv))
-@router.get("/{conversation_id}", response_model=ApiResponse[ConversationSchema])
+@router.get("/{conversation_id}", response_model=ApiResponse[ConversationSchema], summary="Get conversation")
async def get_conversation(
conversation_id: int,
db: AsyncSession = Depends(get_db),
@@ -136,7 +131,7 @@ async def get_conversation(
return ApiResponse(data=ConversationSchema.model_validate(conv))
-@router.put("/{conversation_id}", response_model=ApiResponse[ConversationSchema])
+@router.put("/{conversation_id}", response_model=ApiResponse[ConversationSchema], summary="Update conversation")
async def update_conversation(
conversation_id: int,
body: ConversationUpdateSchema,
@@ -151,8 +146,6 @@ async def update_conversation(
for field, value in body.model_dump(exclude_none=True).items():
setattr(conv, field, value)
- await db.commit()
-
result2 = await db.execute(
select(Conversation).where(Conversation.id == conversation_id).options(selectinload(Conversation.messages))
)
@@ -160,7 +153,7 @@ async def update_conversation(
return ApiResponse(data=ConversationSchema.model_validate(conv))
-@router.delete("/{conversation_id}", response_model=ApiResponse[dict])
+@router.delete("/{conversation_id}", response_model=ApiResponse[dict], summary="Delete conversation")
async def delete_conversation(
conversation_id: int,
db: AsyncSession = Depends(get_db),
@@ -172,5 +165,4 @@ async def delete_conversation(
raise HTTPException(status_code=404, detail="Conversation not found")
await db.delete(conv)
- await db.commit()
return ApiResponse(data={"deleted": True, "id": conversation_id})
diff --git a/backend/app/api/v1/crawler.py b/backend/app/api/v1/crawler.py
index e251ad8..1667c3c 100644
--- a/backend/app/api/v1/crawler.py
+++ b/backend/app/api/v1/crawler.py
@@ -1,5 +1,7 @@
"""PDF crawler API endpoints."""
+from typing import Literal
+
from fastapi import APIRouter, Depends
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
@@ -12,10 +14,10 @@
router = APIRouter(prefix="/projects/{project_id}/crawl", tags=["crawler"])
-@router.post("/start", response_model=ApiResponse[dict])
+@router.post("/start", response_model=ApiResponse[dict], summary="Start PDF download crawl")
async def start_crawl(
project_id: int,
- priority: str = "high",
+ priority: Literal["high", "low"] = "low",
max_papers: int = 50,
db: AsyncSession = Depends(get_db),
project: Project = Depends(get_project),
@@ -52,7 +54,7 @@ async def start_crawl(
return ApiResponse(data=download_results)
-@router.get("/stats", response_model=ApiResponse[dict])
+@router.get("/stats", response_model=ApiResponse[dict], summary="Get crawl statistics")
async def crawl_stats(
project_id: int,
db: AsyncSession = Depends(get_db),
diff --git a/backend/app/api/v1/dedup.py b/backend/app/api/v1/dedup.py
index c32aa24..90dd204 100644
--- a/backend/app/api/v1/dedup.py
+++ b/backend/app/api/v1/dedup.py
@@ -2,17 +2,19 @@
import logging
from pathlib import Path
+from typing import Literal
-from fastapi import APIRouter, Depends, HTTPException, Query
+from fastapi import APIRouter, Depends, HTTPException, Query, Request
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_db, get_llm, get_project
from app.config import settings
+from app.middleware.rate_limit import limiter
from app.models import Paper, PaperStatus, Project
-from app.schemas.common import ApiResponse
+from app.schemas.common import ApiResponse, PaginatedData, PaginationParams
from app.schemas.knowledge_base import AutoResolveRequest, ResolveConflictRequest
from app.services.dedup_service import DedupService
-from app.services.llm_client import LLMClient
+from app.services.llm.client import LLMClient
from app.services.pdf_metadata import extract_metadata
logger = logging.getLogger(__name__)
@@ -20,10 +22,12 @@
router = APIRouter(prefix="/projects/{project_id}/dedup", tags=["dedup"])
-@router.post("/run", response_model=ApiResponse[dict])
+@router.post("/run", response_model=ApiResponse[dict], summary="Run deduplication pipeline")
+@limiter.limit("5/minute")
async def run_dedup(
+ request: Request,
project_id: int,
- strategy: str = "full",
+ strategy: Literal["full", "doi_only", "title_only"] = "full",
db: AsyncSession = Depends(get_db),
llm: LLMClient = Depends(get_llm),
project: Project = Depends(get_project),
@@ -41,19 +45,33 @@ async def run_dedup(
return ApiResponse(data=result)
-@router.get("/candidates", response_model=ApiResponse[list[dict]])
+@router.get("/candidates", response_model=ApiResponse[PaginatedData[dict]], summary="List dedup candidates")
async def list_dedup_candidates(
project_id: int,
+ pagination: PaginationParams = Depends(),
db: AsyncSession = Depends(get_db),
project: Project = Depends(get_project),
):
- """List potential duplicate pairs for manual review."""
+ """List potential duplicate pairs for manual review with pagination."""
+ page, page_size = pagination.page, pagination.page_size
service = DedupService(db)
candidates = await service.find_llm_dedup_candidates(project_id)
- return ApiResponse(data=candidates)
+ total = len(candidates)
+ start = (page - 1) * page_size
+ end = start + page_size
+ items = candidates[start:end]
+ return ApiResponse(
+ data=PaginatedData(
+ items=items,
+ total=total,
+ page=page,
+ page_size=page_size,
+ total_pages=(total + page_size - 1) // page_size if total else 1,
+ )
+ )
-@router.post("/verify", response_model=ApiResponse[dict])
+@router.post("/verify", response_model=ApiResponse[dict], summary="Verify duplicate with LLM")
async def verify_duplicate(
project_id: int,
paper_a_id: int = Query(..., description="First paper ID"),
@@ -68,7 +86,7 @@ async def verify_duplicate(
return ApiResponse(data=result)
-@router.post("/resolve", response_model=ApiResponse[dict])
+@router.post("/resolve", response_model=ApiResponse[dict], summary="Resolve upload conflict")
async def resolve_conflict(
project_id: int,
body: ResolveConflictRequest,
@@ -141,7 +159,7 @@ async def resolve_conflict(
raise HTTPException(status_code=400, detail=f"Invalid action: {body.action}")
-@router.post("/auto-resolve", response_model=ApiResponse[list[dict]])
+@router.post("/auto-resolve", response_model=ApiResponse[list[dict]], summary="Auto-resolve conflicts with LLM")
async def auto_resolve_conflicts(
project_id: int,
body: AutoResolveRequest,
@@ -185,56 +203,14 @@ async def auto_resolve_conflicts(
new_metadata = await extract_metadata(pdf_path, fallback_title="Untitled")
- if not llm:
- resolutions.append(
- {
- "conflict_id": conflict_id,
- "action": "keep_new",
- "reason": "LLM not available, defaulting to keep_new",
- }
- )
- continue
-
- prompt = f"""Two papers may be duplicates. Decide the best resolution:
-
-Existing paper (in DB):
-- ID: {old_paper.id}
-- Title: {old_paper.title}
-- DOI: {old_paper.doi or "N/A"}
-- Year: {old_paper.year}
-- Journal: {old_paper.journal}
-
-New upload:
-- Title: {new_metadata.title}
-- DOI: {new_metadata.doi or "N/A"}
-- Year: {new_metadata.year}
-- Journal: {new_metadata.journal}
-
-Return JSON: {{"action": "keep_old"|"keep_new"|"merge", "reason": "..."}}
-- keep_old: existing is better, discard new
-- keep_new: new is better or different work, add new
-- merge: combine metadata, add as new paper"""
-
- try:
- result = await llm.chat_json(
- messages=[
- {"role": "system", "content": "You are a deduplication expert. Return valid JSON only."},
- {"role": "user", "content": prompt},
- ],
- task_type="dedup_resolve",
- )
- action = result.get("action", "keep_new")
- if action not in ("keep_old", "keep_new", "merge"):
- action = "keep_new"
- resolutions.append(
- {
- "conflict_id": conflict_id,
- "action": action,
- "reason": result.get("reason", ""),
- }
- )
- except Exception as e:
- logger.warning("LLM auto-resolve failed for %s: %s", conflict_id, e)
- resolutions.append({"conflict_id": conflict_id, "action": "keep_new", "reason": f"Error: {e}"})
+ dedup_svc = DedupService(db, llm)
+ resolution = await dedup_svc.resolve_conflict(
+ old_paper=old_paper,
+ new_title=new_metadata.title,
+ new_doi=new_metadata.doi,
+ new_year=new_metadata.year,
+ new_journal=new_metadata.journal,
+ )
+ resolutions.append({"conflict_id": conflict_id, **resolution})
return ApiResponse(data=resolutions)
diff --git a/backend/app/api/v1/gpu.py b/backend/app/api/v1/gpu.py
new file mode 100644
index 0000000..d2fff21
--- /dev/null
+++ b/backend/app/api/v1/gpu.py
@@ -0,0 +1,70 @@
+"""GPU resource monitoring and management API."""
+
+from __future__ import annotations
+
+import logging
+
+from fastapi import APIRouter
+
+from app.schemas.common import ApiResponse
+
+logger = logging.getLogger(__name__)
+
+router = APIRouter(prefix="/gpu", tags=["gpu"])
+
+
+def _get_gpu_memory() -> list[dict]:
+ """Query GPU memory info via torch.cuda (returns empty list if unavailable)."""
+ try:
+ import torch
+
+ if not torch.cuda.is_available():
+ return []
+
+ import os
+
+ cuda_ids = os.environ.get("CUDA_VISIBLE_DEVICES", "")
+ physical_ids = [int(x.strip()) for x in cuda_ids.split(",") if x.strip()] if cuda_ids else []
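+        # torch device indices are logical (already remapped by CUDA_VISIBLE_DEVICES); report physical IDs instead.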
+
+ result = []
+ for idx in range(torch.cuda.device_count()):
+ free, total = torch.cuda.mem_get_info(idx)
+ used = total - free
+ gpu_id = physical_ids[idx] if idx < len(physical_ids) else idx
+ result.append(
+ {
+ "gpu_id": gpu_id,
+ "total_mb": round(total / (1024 * 1024)),
+ "used_mb": round(used / (1024 * 1024)),
+ "free_mb": round(free / (1024 * 1024)),
+ }
+ )
+ return result
+ except (ImportError, RuntimeError):
+ return []
+
+
+@router.get("/status", summary="Get GPU status")
+async def gpu_status():
+ """Return loaded GPU models, MinerU status, and GPU memory usage."""
+ from app.services.gpu_model_manager import gpu_model_manager
+ from app.services.mineru_process_manager import mineru_process_manager
+
+ return ApiResponse(
+ data={
+ "models": gpu_model_manager.get_status(),
+ "mineru": mineru_process_manager.get_status(),
+ "gpu_memory": _get_gpu_memory(),
+ }
+ )
+
+
+@router.post("/unload", summary="Unload GPU models")
+async def gpu_unload():
+ """Immediately unload all GPU models and release VRAM."""
+ from app.services.gpu_model_manager import gpu_model_manager
+
+ names = list(gpu_model_manager.loaded_model_names)
+ gpu_model_manager.unload_all()
+ logger.info("Manual unload: released models %s", names)
+ return ApiResponse(data={"unloaded": names})
diff --git a/backend/app/api/v1/keywords.py b/backend/app/api/v1/keywords.py
index 45710e9..ae7ffa3 100644
--- a/backend/app/api/v1/keywords.py
+++ b/backend/app/api/v1/keywords.py
@@ -1,36 +1,51 @@
"""Keyword management API endpoints."""
-from fastapi import APIRouter, Depends, HTTPException
-from sqlalchemy import select
+from fastapi import APIRouter, Depends
+from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
-from app.api.deps import get_db, get_llm, get_project
+from app.api.deps import get_db, get_llm, get_or_404, get_project
from app.models import Keyword, Project
-from app.schemas.common import ApiResponse
+from app.schemas.common import ApiResponse, KeywordPaginationParams, PaginatedData
from app.schemas.keyword import KeywordCreate, KeywordExpandRequest, KeywordExpandResponse, KeywordRead, KeywordUpdate
from app.services.keyword_service import KeywordService
-from app.services.llm_client import LLMClient
+from app.services.llm.client import LLMClient
router = APIRouter(prefix="/projects/{project_id}/keywords", tags=["keywords"])
-@router.get("", response_model=ApiResponse[list[KeywordRead]])
+@router.get("", response_model=ApiResponse[PaginatedData[KeywordRead]], summary="List keywords")
async def list_keywords(
project_id: int,
+ pagination: KeywordPaginationParams = Depends(),
level: int | None = None,
db: AsyncSession = Depends(get_db),
project: Project = Depends(get_project),
):
- stmt = select(Keyword).where(Keyword.project_id == project_id)
+ page, page_size = pagination.page, pagination.page_size
+ base = select(Keyword).where(Keyword.project_id == project_id)
if level is not None:
- stmt = stmt.where(Keyword.level == level)
- stmt = stmt.order_by(Keyword.level, Keyword.id)
- result = await db.execute(stmt)
- keywords = result.scalars().all()
- return ApiResponse(data=[KeywordRead.model_validate(k) for k in keywords])
+ base = base.where(Keyword.level == level)
+ count = (await db.execute(select(func.count()).select_from(base.subquery()))).scalar_one()
+ items = (
+ (await db.execute(base.order_by(Keyword.level, Keyword.id).offset((page - 1) * page_size).limit(page_size)))
+ .scalars()
+ .all()
+ )
+
+ return ApiResponse(
+ data=PaginatedData(
+ items=[KeywordRead.model_validate(k) for k in items],
+ total=count,
+ page=page,
+ page_size=page_size,
+ total_pages=(count + page_size - 1) // page_size or 1,
+ )
+ )
-@router.post("", response_model=ApiResponse[KeywordRead], status_code=201)
+
+@router.post("", response_model=ApiResponse[KeywordRead], status_code=201, summary="Create keyword")
async def create_keyword(
project_id: int,
body: KeywordCreate,
@@ -44,7 +59,7 @@ async def create_keyword(
return ApiResponse(code=201, message="Keyword created", data=KeywordRead.model_validate(keyword))
-@router.post("/bulk", response_model=ApiResponse[dict])
+@router.post("/bulk", response_model=ApiResponse[dict], summary="Bulk create keywords")
async def bulk_create_keywords(
project_id: int,
keywords: list[KeywordCreate],
@@ -60,7 +75,7 @@ async def bulk_create_keywords(
return ApiResponse(data={"created": created})
-@router.get("/search-formula", response_model=ApiResponse[dict])
+@router.get("/search-formula", response_model=ApiResponse[dict], summary="Generate boolean search formula")
async def generate_search_formula(
project_id: int,
database: str = "wos",
@@ -74,7 +89,7 @@ async def generate_search_formula(
return ApiResponse(data=result)
-@router.put("/{keyword_id}", response_model=ApiResponse[KeywordRead])
+@router.put("/{keyword_id}", response_model=ApiResponse[KeywordRead], summary="Update keyword")
async def update_keyword(
project_id: int,
keyword_id: int,
@@ -82,9 +97,7 @@ async def update_keyword(
db: AsyncSession = Depends(get_db),
project: Project = Depends(get_project),
):
- keyword = await db.get(Keyword, keyword_id)
- if not keyword or keyword.project_id != project_id:
- raise HTTPException(status_code=404, detail="Keyword not found")
+ keyword = await get_or_404(db, Keyword, keyword_id, project_id=project_id, detail="Keyword not found")
for key, value in body.model_dump(exclude_unset=True).items():
setattr(keyword, key, value)
await db.flush()
@@ -92,21 +105,19 @@ async def update_keyword(
return ApiResponse(data=KeywordRead.model_validate(keyword))
-@router.delete("/{keyword_id}", response_model=ApiResponse)
+@router.delete("/{keyword_id}", response_model=ApiResponse, summary="Delete keyword")
async def delete_keyword(
project_id: int,
keyword_id: int,
db: AsyncSession = Depends(get_db),
project: Project = Depends(get_project),
):
- keyword = await db.get(Keyword, keyword_id)
- if not keyword or keyword.project_id != project_id:
- raise HTTPException(status_code=404, detail="Keyword not found")
+ keyword = await get_or_404(db, Keyword, keyword_id, project_id=project_id, detail="Keyword not found")
await db.delete(keyword)
return ApiResponse(message="Keyword deleted")
-@router.post("/expand", response_model=ApiResponse[KeywordExpandResponse])
+@router.post("/expand", response_model=ApiResponse[KeywordExpandResponse], summary="Expand keywords with LLM")
async def expand_keywords(
project_id: int,
body: KeywordExpandRequest,
@@ -115,26 +126,17 @@ async def expand_keywords(
project: Project = Depends(get_project),
):
"""Use LLM to expand seed keywords with synonyms and related terms."""
-
- prompt = (
- f"Given these seed keywords in the field of scientific research: {body.seed_terms}\n"
- f"Language: {body.language}\n"
- f"Generate up to {body.max_results} related terms including synonyms, abbreviations, "
- "alternate names, and cross-disciplinary terms.\n"
- 'Return JSON: {"expanded_terms": [{"term": "...", "term_zh": "...", "relation": "synonym|abbreviation|related"}]}'
- )
-
- result = await llm.chat_json(
- messages=[
- {"role": "system", "content": "You are a scientific terminology expert. Return valid JSON only."},
- {"role": "user", "content": prompt},
- ],
- task_type="keyword_expand",
+ svc = KeywordService(db, llm)
+ expanded = await svc.expand_keywords_with_llm(
+ project_id=project_id,
+ seed_terms=body.seed_terms,
+ language=body.language,
+ max_results=body.max_results,
)
return ApiResponse(
data=KeywordExpandResponse(
- expanded_terms=result.get("expanded_terms", []),
- source=f"llm:{llm.provider}",
+ expanded_terms=expanded,
+ source=f"llm:{llm.provider}" if llm else "none",
)
)
diff --git a/backend/app/api/v1/ocr.py b/backend/app/api/v1/ocr.py
index b1722b2..209a6a3 100644
--- a/backend/app/api/v1/ocr.py
+++ b/backend/app/api/v1/ocr.py
@@ -1,12 +1,14 @@
"""OCR processing API endpoints."""
+import asyncio
import logging
-from fastapi import APIRouter, Depends
+from fastapi import APIRouter, Depends, Request
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_db, get_project
+from app.middleware.rate_limit import limiter
from app.models import Paper, PaperChunk, PaperStatus, Project
from app.schemas.common import ApiResponse
from app.services.ocr_service import OCRService
@@ -16,8 +18,10 @@
router = APIRouter(prefix="/projects/{project_id}/ocr", tags=["ocr"])
-@router.post("/process", response_model=ApiResponse[dict])
+@router.post("/process", response_model=ApiResponse[dict], summary="Run OCR on PDFs")
+@limiter.limit("5/minute")
async def process_ocr(
+ request: Request,
project_id: int,
paper_ids: list[int] | None = None,
force_ocr: bool = False,
@@ -41,49 +45,47 @@ async def process_ocr(
if not papers:
return ApiResponse(data={"processed": 0, "failed": 0, "total": 0, "message": "No papers to process"})
- service = OCRService(use_gpu=use_gpu)
processed = 0
failed = 0
- for paper in papers:
- if not paper.pdf_path:
- failed += 1
- continue
-
- try:
- ocr_result = service.process_pdf(paper.pdf_path, force_ocr=force_ocr)
-
- if ocr_result.get("error"):
+ with OCRService(use_gpu=use_gpu) as service:
+ for paper in papers:
+ if not paper.pdf_path:
failed += 1
continue
- # Save OCR result
- service.save_result(paper.id, ocr_result)
-
- # Create chunks and store in DB
- chunks = service.chunk_text(ocr_result["pages"])
- for chunk_data in chunks:
- chunk = PaperChunk(
- paper_id=paper.id,
- chunk_type=chunk_data["chunk_type"],
- content=chunk_data["content"],
- page_number=chunk_data.get("page_number"),
- chunk_index=chunk_data["chunk_index"],
- token_count=chunk_data.get("token_count", 0),
- )
- db.add(chunk)
-
- paper.status = PaperStatus.OCR_COMPLETE
- processed += 1
- except Exception as e:
- logger.error("OCR failed for paper %s: %s", paper.id, e)
- failed += 1
+ try:
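+                # Run the blocking OCR in a worker thread so the event loop stays responsive.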
+ ocr_result = await asyncio.to_thread(service.process_pdf, paper.pdf_path, force_ocr=force_ocr)
+
+ if ocr_result.get("error"):
+ failed += 1
+ continue
+
+ service.save_result(paper.id, ocr_result)
+
+ chunks = service.chunk_text(ocr_result["pages"])
+ for chunk_data in chunks:
+ chunk = PaperChunk(
+ paper_id=paper.id,
+ chunk_type=chunk_data["chunk_type"],
+ content=chunk_data["content"],
+ page_number=chunk_data.get("page_number"),
+ chunk_index=chunk_data["chunk_index"],
+ token_count=chunk_data.get("token_count", 0),
+ )
+ db.add(chunk)
+
+ paper.status = PaperStatus.OCR_COMPLETE
+ processed += 1
+ except Exception as e:
+ logger.error("OCR failed for paper %s: %s", paper.id, e)
+ failed += 1
await db.flush()
return ApiResponse(data={"processed": processed, "failed": failed, "total": len(papers)})
-@router.get("/stats", response_model=ApiResponse[dict])
+@router.get("/stats", response_model=ApiResponse[dict], summary="Get OCR statistics")
async def ocr_stats(
project_id: int,
db: AsyncSession = Depends(get_db),
diff --git a/backend/app/api/v1/papers.py b/backend/app/api/v1/papers.py
index 777fbcb..a632b9d 100644
--- a/backend/app/api/v1/papers.py
+++ b/backend/app/api/v1/papers.py
@@ -7,20 +7,21 @@
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
-from app.api.deps import get_db, get_project
+from app.api.deps import get_db, get_or_404, get_project
from app.config import settings
from app.models import Paper, Project
-from app.schemas.common import ApiResponse, PaginatedData
-from app.schemas.paper import PaperBulkImport, PaperCreate, PaperRead, PaperUpdate
+from app.models.chunk import PaperChunk
+from app.schemas.chunk import ChunkRead
+from app.schemas.common import ApiResponse, PaginatedData, PaginationParams
+from app.schemas.paper import PaperBatchDeleteRequest, PaperBulkImport, PaperCreate, PaperRead, PaperUpdate
router = APIRouter(tags=["papers"])
-@router.get("", response_model=ApiResponse[PaginatedData[PaperRead]])
+@router.get("", response_model=ApiResponse[PaginatedData[PaperRead]], summary="List papers with filters")
async def list_papers(
project_id: int,
- page: int = 1,
- page_size: int = 20,
+ pagination: PaginationParams = Depends(),
status: str | None = None,
year: int | None = None,
q: str | None = Query(default=None, description="Search title/abstract"),
@@ -29,6 +30,7 @@ async def list_papers(
db: AsyncSession = Depends(get_db),
project: Project = Depends(get_project),
):
+ page, page_size = pagination.page, pagination.page_size
base = select(Paper).where(Paper.project_id == project_id)
count_base = select(func.count(Paper.id)).where(Paper.project_id == project_id)
@@ -64,7 +66,7 @@ async def list_papers(
)
-@router.post("", response_model=ApiResponse[PaperRead], status_code=201)
+@router.post("", response_model=ApiResponse[PaperRead], status_code=201, summary="Create paper")
async def create_paper(
project_id: int,
body: PaperCreate,
@@ -78,7 +80,7 @@ async def create_paper(
return ApiResponse(code=201, message="Paper created", data=PaperRead.model_validate(paper))
-@router.post("/bulk", response_model=ApiResponse[dict])
+@router.post("/bulk", response_model=ApiResponse[dict], summary="Bulk import papers")
async def bulk_import_papers(
project_id: int,
body: PaperBulkImport,
@@ -102,20 +104,38 @@ async def bulk_import_papers(
return ApiResponse(data={"created": created, "skipped": skipped, "total": len(body.papers)})
-@router.get("/{paper_id}", response_model=ApiResponse[PaperRead])
+@router.post("/batch-delete", response_model=ApiResponse[dict], summary="Batch delete papers")
+async def batch_delete_papers(
+ project_id: int,
+ body: PaperBatchDeleteRequest,
+ db: AsyncSession = Depends(get_db),
+ project: Project = Depends(get_project),
+):
+ """Delete multiple papers at once."""
+ stmt = select(Paper).where(
+ Paper.project_id == project_id,
+ Paper.id.in_(body.paper_ids),
+ )
+ result = await db.execute(stmt)
+ papers = list(result.scalars().all())
+ for paper in papers:
+ await db.delete(paper)
+ await db.flush()
+ return ApiResponse(data={"deleted": len(papers), "requested": len(body.paper_ids)})
+
+
+@router.get("/{paper_id}", response_model=ApiResponse[PaperRead], summary="Get paper by ID")
async def get_paper(
project_id: int,
paper_id: int,
db: AsyncSession = Depends(get_db),
project: Project = Depends(get_project),
):
- paper = await db.get(Paper, paper_id)
- if not paper or paper.project_id != project_id:
- raise HTTPException(status_code=404, detail="Paper not found")
+ paper = await get_or_404(db, Paper, paper_id, project_id=project_id, detail="Paper not found")
return ApiResponse(data=PaperRead.model_validate(paper))
-@router.put("/{paper_id}", response_model=ApiResponse[PaperRead])
+@router.put("/{paper_id}", response_model=ApiResponse[PaperRead], summary="Update paper")
async def update_paper(
project_id: int,
paper_id: int,
@@ -123,9 +143,7 @@ async def update_paper(
db: AsyncSession = Depends(get_db),
project: Project = Depends(get_project),
):
- paper = await db.get(Paper, paper_id)
- if not paper or paper.project_id != project_id:
- raise HTTPException(status_code=404, detail="Paper not found")
+ paper = await get_or_404(db, Paper, paper_id, project_id=project_id, detail="Paper not found")
for key, value in body.model_dump(exclude_unset=True).items():
setattr(paper, key, value)
await db.flush()
@@ -133,21 +151,19 @@ async def update_paper(
return ApiResponse(data=PaperRead.model_validate(paper))
-@router.delete("/{paper_id}", response_model=ApiResponse)
+@router.delete("/{paper_id}", response_model=ApiResponse, summary="Delete paper")
async def delete_paper(
project_id: int,
paper_id: int,
db: AsyncSession = Depends(get_db),
project: Project = Depends(get_project),
):
- paper = await db.get(Paper, paper_id)
- if not paper or paper.project_id != project_id:
- raise HTTPException(status_code=404, detail="Paper not found")
+ paper = await get_or_404(db, Paper, paper_id, project_id=project_id, detail="Paper not found")
await db.delete(paper)
return ApiResponse(message="Paper deleted")
-@router.get("/{paper_id}/pdf")
+@router.get("/{paper_id}/pdf", summary="Serve PDF file")
async def serve_pdf(
project_id: int,
paper_id: int,
@@ -155,9 +171,7 @@ async def serve_pdf(
project: Project = Depends(get_project),
):
"""Serve the PDF file for a paper."""
- paper = await db.get(Paper, paper_id)
- if not paper or paper.project_id != project_id:
- raise HTTPException(status_code=404, detail="Paper not found")
+ paper = await get_or_404(db, Paper, paper_id, project_id=project_id, detail="Paper not found")
if not paper.pdf_path or not Path(paper.pdf_path).exists():
raise HTTPException(status_code=404, detail="PDF file not available")
@@ -174,7 +188,42 @@ async def serve_pdf(
return FileResponse(str(pdf_path), media_type="application/pdf", filename=f"{paper.title[:80]}.pdf")
-@router.get("/{paper_id}/citation-graph", response_model=ApiResponse)
+@router.get("/{paper_id}/chunks", response_model=ApiResponse[PaginatedData[ChunkRead]], summary="List paper chunks")
+async def list_paper_chunks(
+ project_id: int,
+ paper_id: int,
+    page: int = Query(default=1, ge=1),
+ page_size: int = Query(default=50, ge=1, le=200),
+ chunk_type: str | None = Query(default=None, description="Filter by chunk type"),
+ db: AsyncSession = Depends(get_db),
+ project: Project = Depends(get_project),
+):
+ """List chunks for a specific paper."""
+ await get_or_404(db, Paper, paper_id, project_id=project_id, detail="Paper not found")
+
+ base = select(PaperChunk).where(PaperChunk.paper_id == paper_id)
+ count_base = select(func.count(PaperChunk.id)).where(PaperChunk.paper_id == paper_id)
+
+ if chunk_type:
+ base = base.where(PaperChunk.chunk_type == chunk_type)
+ count_base = count_base.where(PaperChunk.chunk_type == chunk_type)
+
+ total = (await db.execute(count_base)).scalar() or 0
+ base = base.order_by(PaperChunk.chunk_index).offset((page - 1) * page_size).limit(page_size)
+ chunks = (await db.execute(base)).scalars().all()
+
+ return ApiResponse(
+ data=PaginatedData(
+ items=[ChunkRead.model_validate(c) for c in chunks],
+ total=total,
+ page=page,
+ page_size=page_size,
+ total_pages=(total + page_size - 1) // page_size if total else 1,
+ )
+ )
+
+
+@router.get("/{paper_id}/citation-graph", response_model=ApiResponse, summary="Get citation graph")
async def get_citation_graph(
project_id: int,
paper_id: int,
diff --git a/backend/app/api/v1/pipelines.py b/backend/app/api/v1/pipelines.py
index 4a2301f..ad628bc 100644
--- a/backend/app/api/v1/pipelines.py
+++ b/backend/app/api/v1/pipelines.py
@@ -3,13 +3,19 @@
import asyncio
import logging
import uuid
+from typing import Literal
-from fastapi import APIRouter, Depends, HTTPException
+from fastapi import APIRouter, Depends, HTTPException, Query, Request, WebSocket, WebSocketDisconnect
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_db, get_project_or_404
+from app.config import settings
+from app.middleware.rate_limit import limiter
+from app.models.task import Task, TaskStatus, TaskType
+from app.pipelines.cancellation import clear_cancelled, mark_cancelled
from app.schemas.common import ApiResponse
+from app.websocket.manager import pipeline_manager
logger = logging.getLogger(__name__)
@@ -30,12 +36,43 @@ class UploadPipelineRequest(BaseModel):
pdf_paths: list[str]
+class ResolvedConflict(BaseModel):
+ conflict_id: str
+ action: Literal["keep_old", "keep_new", "merge", "skip"]
+ merged_paper: dict | None = None
+ new_paper: dict | None = None
+
+
class ResumeRequest(BaseModel):
- resolved_conflicts: list[dict] = []
+ resolved_conflicts: list[ResolvedConflict] = []
+
+
+@router.get("", response_model=ApiResponse[list[dict]], summary="List pipelines")
+async def list_pipelines(
+ status: str | None = None,
+):
+    """List pipelines currently tracked in memory, optionally filtered by status."""
+ data = []
+ for thread_id, task in _running_tasks.items():
+ if status and task["status"] != status:
+ continue
+ data.append(
+ {
+ "thread_id": thread_id,
+ "status": task["status"],
+ "task_id": task.get("task_id"),
+ }
+ )
+ return ApiResponse(data=data)
-@router.post("/search", response_model=ApiResponse[dict])
-async def start_search_pipeline(body: SearchPipelineRequest, db: AsyncSession = Depends(get_db)):
+@router.post("/search", response_model=ApiResponse[dict], summary="Start search pipeline")
+@limiter.limit("10/minute")
+async def start_search_pipeline(
+ request: Request,
+ body: SearchPipelineRequest,
+ db: AsyncSession = Depends(get_db),
+):
"""Start a keyword-search pipeline: search → dedup → crawl → OCR → index."""
await get_project_or_404(body.project_id, db)
@@ -66,7 +103,24 @@ async def start_search_pipeline(body: SearchPipelineRequest, db: AsyncSession =
config = {"configurable": {"thread_id": thread_id}}
- _running_tasks[thread_id] = {"status": "running", "pipeline": pipeline, "config": config}
+ task_record = Task(
+ project_id=body.project_id,
+ task_type=TaskType.SEARCH,
+ status=TaskStatus.RUNNING,
+ progress=0,
+ total=100,
+ result={"thread_id": thread_id, "pipeline_type": "search"},
+ )
+ db.add(task_record)
+ await db.flush()
+
+ _running_tasks[thread_id] = {
+ "status": "running",
+ "pipeline": pipeline,
+ "config": config,
+ "task_id": task_record.id,
+ "project_id": body.project_id,
+ }
async def _run():
try:
@@ -78,24 +132,49 @@ async def _run():
else:
_running_tasks[thread_id]["status"] = "completed"
_running_tasks[thread_id]["result"] = result
+ await pipeline_manager.broadcast_to_room(
+ thread_id,
+ {
+ "type": "status",
+ "status": _running_tasks[thread_id]["status"],
+ "stage": result.get("stage", ""),
+ "progress": result.get("progress", 0),
+ },
+ )
+ except asyncio.CancelledError:
+ _running_tasks[thread_id]["status"] = "cancelled"
+ await pipeline_manager.broadcast_to_room(thread_id, {"type": "status", "status": "cancelled"})
except Exception as e:
logger.error("Pipeline %s failed: %s", thread_id, e)
_running_tasks[thread_id]["status"] = "failed"
_running_tasks[thread_id]["error"] = str(e)
+ await pipeline_manager.broadcast_to_room(thread_id, {"type": "error", "message": str(e)})
+ finally:
+ s = _running_tasks.get(thread_id, {}).get("status")
+ if s in ("completed", "failed", "cancelled"):
+ clear_cancelled(thread_id)
+ _running_tasks.pop(thread_id, None)
- asyncio.create_task(_run())
+ task_ref = asyncio.create_task(_run())
+ _running_tasks[thread_id]["asyncio_task"] = task_ref
return ApiResponse(
data={
"thread_id": thread_id,
"status": "running",
"project_id": body.project_id,
+ "task_id": task_record.id,
}
)
-@router.post("/upload", response_model=ApiResponse[dict])
-async def start_upload_pipeline(body: UploadPipelineRequest, db: AsyncSession = Depends(get_db)):
+@router.post("/upload", response_model=ApiResponse[dict], summary="Start upload pipeline")
+@limiter.limit("10/minute")
+async def start_upload_pipeline(
+ request: Request,
+ body: UploadPipelineRequest,
+ db: AsyncSession = Depends(get_db),
+):
"""Start a PDF-upload pipeline: extract → dedup → OCR → index."""
from pathlib import Path as _Path
@@ -107,7 +186,7 @@ async def start_upload_pipeline(body: UploadPipelineRequest, db: AsyncSession =
safe_paths: list[str] = []
for p in body.pdf_paths:
resolved = _Path(p).resolve()
- if not str(resolved).startswith(str(allowed_root)):
+ if not resolved.is_relative_to(allowed_root):
raise HTTPException(status_code=400, detail=f"Path not within allowed directory: {p}")
safe_paths.append(str(resolved))
@@ -133,7 +212,25 @@ async def start_upload_pipeline(body: UploadPipelineRequest, db: AsyncSession =
}
config = {"configurable": {"thread_id": thread_id}}
- _running_tasks[thread_id] = {"status": "running", "pipeline": pipeline, "config": config}
+
+ task_record = Task(
+ project_id=body.project_id,
+ task_type=TaskType.OCR,
+ status=TaskStatus.RUNNING,
+ progress=0,
+ total=100,
+ result={"thread_id": thread_id, "pipeline_type": "upload"},
+ )
+ db.add(task_record)
+ await db.flush()
+
+ _running_tasks[thread_id] = {
+ "status": "running",
+ "pipeline": pipeline,
+ "config": config,
+ "task_id": task_record.id,
+ "project_id": body.project_id,
+ }
async def _run():
try:
@@ -145,23 +242,38 @@ async def _run():
else:
_running_tasks[thread_id]["status"] = "completed"
_running_tasks[thread_id]["result"] = result
+ await pipeline_manager.broadcast_to_room(
+ thread_id,
+ {
+ "type": "status",
+ "status": _running_tasks[thread_id]["status"],
+ "stage": result.get("stage", ""),
+ "progress": result.get("progress", 0),
+ },
+ )
+ except asyncio.CancelledError:
+ _running_tasks[thread_id]["status"] = "cancelled"
+ await pipeline_manager.broadcast_to_room(thread_id, {"type": "status", "status": "cancelled"})
except Exception as e:
logger.error("Pipeline %s failed: %s", thread_id, e)
_running_tasks[thread_id]["status"] = "failed"
_running_tasks[thread_id]["error"] = str(e)
+ await pipeline_manager.broadcast_to_room(thread_id, {"type": "error", "message": str(e)})
- asyncio.create_task(_run())
+ task_ref = asyncio.create_task(_run())
+ _running_tasks[thread_id]["asyncio_task"] = task_ref
return ApiResponse(
data={
"thread_id": thread_id,
"status": "running",
"project_id": body.project_id,
+ "task_id": task_record.id,
}
)
-@router.get("/{thread_id}/status", response_model=ApiResponse[dict])
+@router.get("/{thread_id}/status", response_model=ApiResponse[dict], summary="Get pipeline status")
async def get_pipeline_status(thread_id: str):
"""Get pipeline execution status."""
task = _running_tasks.get(thread_id)
@@ -200,7 +312,7 @@ async def get_pipeline_status(thread_id: str):
return ApiResponse(data=data)
-@router.post("/{thread_id}/resume", response_model=ApiResponse[dict])
+@router.post("/{thread_id}/resume", response_model=ApiResponse[dict], summary="Resume pipeline")
async def resume_pipeline(thread_id: str, body: ResumeRequest):
"""Resume an interrupted pipeline with resolved conflicts."""
from langgraph.types import Command
@@ -208,6 +320,8 @@ async def resume_pipeline(thread_id: str, body: ResumeRequest):
task = _running_tasks.get(thread_id)
if not task:
raise HTTPException(status_code=404, detail="Pipeline not found")
+ if task["status"] == "cancelled":
+ raise HTTPException(status_code=400, detail="Pipeline was cancelled, cannot resume")
if task["status"] != "interrupted":
raise HTTPException(status_code=400, detail=f"Pipeline is {task['status']}, not interrupted")
@@ -215,10 +329,12 @@ async def resume_pipeline(thread_id: str, body: ResumeRequest):
config = task["config"]
task["status"] = "running"
+ raw_conflicts = [rc.model_dump() for rc in body.resolved_conflicts]
+
async def _resume():
try:
result = await pipeline.ainvoke(
- Command(resume=body.resolved_conflicts),
+ Command(resume=raw_conflicts),
config=config,
)
snapshot = pipeline.get_state(config)
@@ -227,22 +343,71 @@ async def _resume():
else:
task["status"] = "completed"
task["result"] = result
+ except asyncio.CancelledError:
+ task["status"] = "cancelled"
except Exception as e:
logger.error("Pipeline resume %s failed: %s", thread_id, e)
task["status"] = "failed"
task["error"] = str(e)
+ finally:
+ s = task.get("status")
+ if s in ("completed", "failed", "cancelled"):
+ clear_cancelled(thread_id)
+ _running_tasks.pop(thread_id, None)
- asyncio.create_task(_resume())
+ task_ref = asyncio.create_task(_resume())
+ task["asyncio_task"] = task_ref
return ApiResponse(data={"thread_id": thread_id, "status": "running"})
-@router.post("/{thread_id}/cancel", response_model=ApiResponse[dict])
+@router.post("/{thread_id}/cancel", response_model=ApiResponse[dict], summary="Cancel pipeline")
async def cancel_pipeline(thread_id: str):
- """Cancel a running pipeline."""
+ """Cancel a running or interrupted pipeline."""
task = _running_tasks.get(thread_id)
if not task:
raise HTTPException(status_code=404, detail="Pipeline not found")
+ if task["status"] in ("completed", "failed"):
+ raise HTTPException(status_code=400, detail=f"Pipeline already {task['status']}")
+ mark_cancelled(thread_id)
task["status"] = "cancelled"
+
+ asyncio_task = task.get("asyncio_task")
+ if asyncio_task and not asyncio_task.done():
+ asyncio_task.cancel()
+ else:
+        # Interrupted pipeline: no running task to cancel; clean up now so the entry doesn't leak
+ clear_cancelled(thread_id)
+ _running_tasks.pop(thread_id, None)
+
+ await pipeline_manager.broadcast_to_room(thread_id, {"type": "status", "status": "cancelled"})
return ApiResponse(data={"thread_id": thread_id, "status": "cancelled"})
+
+
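The cancel endpoint pairs `asyncio_task.cancel()` with a cooperative flag (`mark_cancelled` / `clear_cancelled`) whose implementation sits outside this diff. A minimal sketch of what those helpers could look like, inferred from the call sites; `is_cancelled` is hypothetical, added only to show how pipeline nodes would consume the flag:

    # Sketch only: the real helpers are imported from elsewhere in the codebase.
    _cancelled_threads: set[str] = set()

    def mark_cancelled(thread_id: str) -> None:
        """Flag a pipeline so long-running nodes can bail out between steps."""
        _cancelled_threads.add(thread_id)

    def is_cancelled(thread_id: str) -> bool:
        """Polled by pipeline nodes between stages (hypothetical consumer API)."""
        return thread_id in _cancelled_threads

    def clear_cancelled(thread_id: str) -> None:
        """Drop the flag once the pipeline has fully wound down."""
        _cancelled_threads.discard(thread_id)

`Task.cancel()` only interrupts at await points, so a flag like this is what lets CPU-bound stages (OCR, embedding) stop between items rather than mid-batch.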
+@router.websocket("/{thread_id}/ws")
+async def pipeline_status_websocket(
+ websocket: WebSocket,
+ thread_id: str,
+ api_key: str | None = Query(default=None),
+):
+ """WebSocket endpoint for real-time pipeline status updates."""
+ if settings.api_secret_key and api_key != settings.api_secret_key:
+ await websocket.close(code=4008)
+ return
+
+ await pipeline_manager.connect(websocket, thread_id)
+ try:
+ task = _running_tasks.get(thread_id)
+ if task:
+ await websocket.send_json(
+ {
+ "type": "status",
+ "status": task["status"],
+ "thread_id": thread_id,
+ }
+ )
+ while True:
+ await websocket.receive_text()
+ except WebSocketDisconnect:
+ pipeline_manager.disconnect(websocket, thread_id)
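`pipeline_manager` is exercised through three calls (`connect`, `disconnect`, `broadcast_to_room`) but defined outside this diff. A room-keyed broadcaster satisfying those call sites can be as small as the following sketch, which is not the project's actual code:

    from collections import defaultdict

    from fastapi import WebSocket

    class PipelineConnectionManager:
        """Fan JSON status events out to every socket watching a thread_id."""

        def __init__(self) -> None:
            self._rooms: dict[str, set[WebSocket]] = defaultdict(set)

        async def connect(self, ws: WebSocket, room: str) -> None:
            await ws.accept()
            self._rooms[room].add(ws)

        def disconnect(self, ws: WebSocket, room: str) -> None:
            self._rooms[room].discard(ws)

        async def broadcast_to_room(self, room: str, payload: dict) -> None:
            # Iterate over a copy: a failed send mutates the room via disconnect().
            for ws in list(self._rooms[room]):
                try:
                    await ws.send_json(payload)
                except Exception:
                    self.disconnect(ws, room)

    pipeline_manager = PipelineConnectionManager()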
diff --git a/backend/app/api/v1/projects.py b/backend/app/api/v1/projects.py
index 4fa0446..262f48b 100644
--- a/backend/app/api/v1/projects.py
+++ b/backend/app/api/v1/projects.py
@@ -1,24 +1,43 @@
"""Project CRUD API endpoints."""
-from fastapi import APIRouter, Depends, HTTPException
+from fastapi import APIRouter, Depends
+from pydantic import BaseModel
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
-from app.api.deps import get_db
-from app.models import Keyword, Paper, Project
-from app.schemas.common import ApiResponse, PaginatedData
-from app.schemas.project import ProjectCreate, ProjectRead, ProjectUpdate
+from app.api.deps import get_db, get_or_404
+from app.models import Keyword, Paper, Project, Subscription
+from app.schemas.common import ApiResponse, PaginatedData, PaginationParams
+from app.schemas.project import (
+ KeywordImportItem,
+ PaperImportItem,
+ ProjectCreate,
+ ProjectRead,
+ ProjectUpdate,
+ SubscriptionImportItem,
+)
from app.services.pipeline_service import PipelineService
router = APIRouter(tags=["projects"])
-@router.get("", response_model=ApiResponse[PaginatedData[ProjectRead]])
+class ProjectImportRequest(BaseModel):
+ """Request body for project import."""
+
+ name: str
+ description: str = ""
+ domain: str = ""
+ papers: list[PaperImportItem] = []
+ keywords: list[KeywordImportItem] = []
+ subscriptions: list[SubscriptionImportItem] = []
+
+
+@router.get("", response_model=ApiResponse[PaginatedData[ProjectRead]], summary="List all projects")
async def list_projects(
- page: int = 1,
- page_size: int = 20,
+ pagination: PaginationParams = Depends(),
db: AsyncSession = Depends(get_db),
):
+ page, page_size = pagination.page, pagination.page_size
total_stmt = select(func.count(Project.id))
total = (await db.execute(total_stmt)).scalar() or 0
@@ -71,7 +90,7 @@ async def list_projects(
)
-@router.post("", response_model=ApiResponse[ProjectRead], status_code=201)
+@router.post("", response_model=ApiResponse[ProjectRead], status_code=201, summary="Create a project")
async def create_project(body: ProjectCreate, db: AsyncSession = Depends(get_db)):
project = Project(**body.model_dump())
db.add(project)
@@ -92,11 +111,9 @@ async def create_project(body: ProjectCreate, db: AsyncSession = Depends(get_db)
)
-@router.get("/{project_id}", response_model=ApiResponse[ProjectRead])
+@router.get("/{project_id}", response_model=ApiResponse[ProjectRead], summary="Get project by ID")
async def get_project(project_id: int, db: AsyncSession = Depends(get_db)):
- project = await db.get(Project, project_id)
- if not project:
- raise HTTPException(status_code=404, detail="Project not found")
+ project = await get_or_404(db, Project, project_id, detail="Project not found")
paper_count = (await db.execute(select(func.count(Paper.id)).where(Paper.project_id == project_id))).scalar() or 0
kw_count = (await db.execute(select(func.count(Keyword.id)).where(Keyword.project_id == project_id))).scalar() or 0
return ApiResponse(
@@ -114,11 +131,9 @@ async def get_project(project_id: int, db: AsyncSession = Depends(get_db)):
)
-@router.put("/{project_id}", response_model=ApiResponse[ProjectRead])
+@router.put("/{project_id}", response_model=ApiResponse[ProjectRead], summary="Update project")
async def update_project(project_id: int, body: ProjectUpdate, db: AsyncSession = Depends(get_db)):
- project = await db.get(Project, project_id)
- if not project:
- raise HTTPException(status_code=404, detail="Project not found")
+ project = await get_or_404(db, Project, project_id, detail="Project not found")
for key, value in body.model_dump(exclude_unset=True).items():
setattr(project, key, value)
await db.flush()
@@ -140,35 +155,119 @@ async def update_project(project_id: int, body: ProjectUpdate, db: AsyncSession
)
-@router.delete("/{project_id}", response_model=ApiResponse)
+@router.delete("/{project_id}", response_model=ApiResponse, summary="Delete project")
async def delete_project(project_id: int, db: AsyncSession = Depends(get_db)):
- project = await db.get(Project, project_id)
- if not project:
- raise HTTPException(status_code=404, detail="Project not found")
+ project = await get_or_404(db, Project, project_id, detail="Project not found")
await db.delete(project)
return ApiResponse(message="Project deleted")
-@router.post("/{project_id}/pipeline/run", response_model=ApiResponse[dict])
+@router.get("/{project_id}/export", response_model=ApiResponse[dict], summary="Export project as JSON")
+async def export_project(project_id: int, db: AsyncSession = Depends(get_db)):
+ """Export project data as JSON (papers, keywords, subscriptions)."""
+ project = await get_or_404(db, Project, project_id, detail="Project not found")
+
+ papers = (await db.execute(select(Paper).where(Paper.project_id == project_id))).scalars().all()
+ keywords = (await db.execute(select(Keyword).where(Keyword.project_id == project_id))).scalars().all()
+ subs = (await db.execute(select(Subscription).where(Subscription.project_id == project_id))).scalars().all()
+
+ return ApiResponse(
+ data={
+ "name": project.name,
+ "description": project.description,
+ "domain": project.domain,
+ "papers": [
+ {
+ "title": p.title,
+ "abstract": p.abstract,
+ "doi": p.doi,
+ "authors": p.authors,
+ "year": p.year,
+ "journal": p.journal,
+ "source": p.source,
+ "pdf_url": p.pdf_url,
+ "status": p.status,
+ "citation_count": p.citation_count,
+ }
+ for p in papers
+ ],
+ "keywords": [
+ {"term": k.term, "term_en": k.term_en, "level": k.level, "category": k.category, "synonyms": k.synonyms}
+ for k in keywords
+ ],
+ "subscriptions": [
+ {
+ "name": s.name,
+ "query": s.query,
+ "sources": s.sources,
+ "frequency": s.frequency,
+ "max_results": s.max_results,
+ }
+ for s in subs
+ ],
+ }
+ )
+
+
+@router.post("/import", response_model=ApiResponse[ProjectRead], status_code=201, summary="Import project from JSON")
+async def import_project(body: ProjectImportRequest, db: AsyncSession = Depends(get_db)):
+ """Import a previously exported project."""
+ project = Project(name=body.name, description=body.description, domain=body.domain)
+ db.add(project)
+ await db.flush()
+
+ paper_cols = {c.name for c in Paper.__table__.columns} - {"id", "project_id", "created_at", "updated_at"}
+ kw_cols = {c.name for c in Keyword.__table__.columns} - {"id", "project_id", "created_at"}
+ sub_cols = {c.name for c in Subscription.__table__.columns} - {"id", "project_id", "created_at", "updated_at"}
+
+ for pd in body.papers:
+ db.add(Paper(project_id=project.id, **{k: v for k, v in pd.model_dump().items() if k in paper_cols}))
+
+ for kd in body.keywords:
+ db.add(Keyword(project_id=project.id, **{k: v for k, v in kd.model_dump().items() if k in kw_cols}))
+
+ for sd in body.subscriptions:
+ db.add(Subscription(project_id=project.id, **{k: v for k, v in sd.model_dump().items() if k in sub_cols}))
+
+ await db.flush()
+ await db.refresh(project)
+
+ paper_count = (await db.execute(select(func.count(Paper.id)).where(Paper.project_id == project.id))).scalar() or 0
+ kw_count = (await db.execute(select(func.count(Keyword.id)).where(Keyword.project_id == project.id))).scalar() or 0
+
+ return ApiResponse(
+ code=201,
+ message="Project imported",
+ data=ProjectRead(
+ id=project.id,
+ name=project.name,
+ description=project.description,
+ domain=project.domain,
+ settings=project.settings,
+ created_at=project.created_at,
+ updated_at=project.updated_at,
+ paper_count=paper_count,
+ keyword_count=kw_count,
+ ),
+ )
+
+
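`/export` and `/import` are designed as a round trip: the export payload's keys match `ProjectImportRequest`, and the column-set filtering above deliberately drops `id`, `project_id`, and timestamps on re-import. A client-side usage sketch; the base path and auth header name are assumptions about the deployment:

    import asyncio

    import httpx

    async def clone_project(project_id: int) -> int:
        headers = {"X-API-Key": "..."}  # header name assumed
        async with httpx.AsyncClient(base_url="http://localhost:8000/api/v1") as client:
            r = await client.get(f"/projects/{project_id}/export", headers=headers)
            exported = r.json()["data"]
            exported["name"] += " (copy)"
            r = await client.post("/projects/import", json=exported, headers=headers)
            return r.json()["data"]["id"]

    print(asyncio.run(clone_project(1)))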
+@router.post("/{project_id}/pipeline/run", response_model=ApiResponse[dict], summary="Run crawl-OCR-index pipeline")
async def run_pipeline(project_id: int, db: AsyncSession = Depends(get_db)):
"""Trigger the crawl → OCR → index pipeline for all pending papers."""
- project = await db.get(Project, project_id)
- if not project:
- raise HTTPException(status_code=404, detail="Project not found")
+ await get_or_404(db, Project, project_id, detail="Project not found")
svc = PipelineService(db)
result = await svc.process_project_pending(project_id)
return ApiResponse(data=result)
-@router.post("/{project_id}/pipeline/paper/{paper_id}", response_model=ApiResponse[dict])
+@router.post(
+ "/{project_id}/pipeline/paper/{paper_id}", response_model=ApiResponse[dict], summary="Run pipeline for single paper"
+)
async def run_paper_pipeline(project_id: int, paper_id: int, db: AsyncSession = Depends(get_db)):
"""Trigger the pipeline for a single paper."""
- project = await db.get(Project, project_id)
- if not project:
- raise HTTPException(status_code=404, detail="Project not found")
- paper = await db.get(Paper, paper_id)
- if not paper or paper.project_id != project_id:
- raise HTTPException(status_code=404, detail="Paper not found in this project")
+ await get_or_404(db, Project, project_id, detail="Project not found")
+ await get_or_404(db, Paper, paper_id, project_id=project_id, detail="Paper not found in this project")
svc = PipelineService(db)
result = await svc.process_paper(paper_id)
return ApiResponse(data=result)
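`get_or_404` (imported from `app.api.deps`) replaces the repeated `db.get` + `HTTPException(404)` boilerplate here and in `tasks.py`. Its definition is not part of this diff; judging from the call sites, it behaves roughly like this sketch:

    from fastapi import HTTPException

    async def get_or_404(db, model, obj_id, *, project_id=None, detail="Not found"):
        """Fetch a row by primary key or raise 404; optionally scope to a project."""
        obj = await db.get(model, obj_id)
        if obj is None or (project_id is not None and obj.project_id != project_id):
            raise HTTPException(status_code=404, detail=detail)
        return obj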
diff --git a/backend/app/api/v1/rag.py b/backend/app/api/v1/rag.py
index 607d914..d3fe099 100644
--- a/backend/app/api/v1/rag.py
+++ b/backend/app/api/v1/rag.py
@@ -4,27 +4,30 @@
import json
import logging
-from fastapi import APIRouter, Depends
+from fastapi import APIRouter, Depends, Request
from fastapi.responses import StreamingResponse
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
-from app.api.deps import get_db, get_llm
-from app.models import Paper, PaperStatus
+from app.api.deps import get_db, get_llm, get_project
+from app.middleware.rate_limit import limiter
+from app.models import Paper, PaperStatus, Project
from app.schemas.common import ApiResponse
-from app.services.llm_client import LLMClient
+from app.services.llm.client import LLMClient
from app.services.rag_service import RAGService
+from app.utils.sse import format_sse_error
logger = logging.getLogger(__name__)
+
router = APIRouter(prefix="/projects/{project_id}/rag", tags=["rag"])
class RAGQueryRequest(BaseModel):
question: str
- top_k: int = 10
+ top_k: int = Field(default=10, ge=1, le=50)
use_reranker: bool = True
include_sources: bool = True
@@ -39,11 +42,12 @@ def get_rag_service(llm: LLMClient = Depends(get_llm)) -> RAGService:
return RAGService(llm=llm)
-@router.post("/query", response_model=ApiResponse[RAGQueryResponse])
+@router.post("/query", response_model=ApiResponse[RAGQueryResponse], summary="RAG query over literature")
async def rag_query(
project_id: int,
body: RAGQueryRequest,
rag: RAGService = Depends(get_rag_service),
+ project: Project = Depends(get_project),
):
"""Answer a question using RAG over the project's indexed literature."""
result = await rag.query(
@@ -56,8 +60,10 @@ async def rag_query(
return ApiResponse(data=RAGQueryResponse(**result))
-@router.post("/index", response_model=ApiResponse[dict])
+@router.post("/index", response_model=ApiResponse[dict], summary="Build vector index")
+@limiter.limit("5/minute")
async def build_index(
+ request: Request,
project_id: int,
db: AsyncSession = Depends(get_db),
rag: RAGService = Depends(get_rag_service),
@@ -95,7 +101,14 @@ async def build_index(
}
)
- index_result = await rag.index_chunks(project_id=project_id, chunks=chunks_to_index)
+ try:
+ index_result = await rag.index_chunks(project_id=project_id, chunks=chunks_to_index)
+ except RuntimeError as exc:
+ if "CUDA out of memory" not in str(exc):
+ raise
+ logger.warning("CUDA OOM during indexing, reloading model on best GPU and retrying")
+ rag._reload_embed_model()
+ index_result = await rag.index_chunks(project_id=project_id, chunks=chunks_to_index)
# Update paper status to INDEXED
for paper in papers:
@@ -110,11 +123,12 @@ async def build_index(
)
-@router.post("/index/stream")
+@router.post("/index/stream", summary="Build index with SSE progress")
async def build_index_stream(
project_id: int,
db: AsyncSession = Depends(get_db),
rag: RAGService = Depends(get_rag_service),
+ project: Project = Depends(get_project),
):
"""SSE streaming rebuild — sends progress events so the UI stays responsive."""
@@ -191,7 +205,7 @@ def on_progress(stage: str, percent: int) -> None:
)
except Exception as exc:
logger.exception("SSE index build failed")
- yield _sse("error", {"message": str(exc)})
+ yield format_sse_error(str(exc), code=500)
return StreamingResponse(
_generate(),
@@ -204,15 +218,23 @@ def on_progress(stage: str, percent: int) -> None:
)
-@router.get("/stats", response_model=ApiResponse[dict])
-async def index_stats(project_id: int, rag: RAGService = Depends(get_rag_service)):
+@router.get("/stats", response_model=ApiResponse[dict], summary="Get index statistics")
+async def index_stats(
+ project_id: int,
+ rag: RAGService = Depends(get_rag_service),
+ project: Project = Depends(get_project),
+):
"""Return indexing statistics."""
stats = await rag.get_stats(project_id=project_id)
return ApiResponse(data=stats)
-@router.delete("/index", response_model=ApiResponse[dict])
-async def delete_index(project_id: int, rag: RAGService = Depends(get_rag_service)):
+@router.delete("/index", response_model=ApiResponse[dict], summary="Delete vector index")
+async def delete_index(
+ project_id: int,
+ rag: RAGService = Depends(get_rag_service),
+ project: Project = Depends(get_project),
+):
"""Delete the vector index for the project."""
result = await rag.delete_index(project_id=project_id)
return ApiResponse(data=result)
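The OOM handling in `build_index` is a retry-once pattern: catch only the CUDA-OOM `RuntimeError`, relocate the embedding model, and re-run; a second failure propagates. If more call sites need it, it factors out cleanly. A sketch, with `reload_fn` standing in for `rag._reload_embed_model`:

    from collections.abc import Awaitable, Callable
    from typing import TypeVar

    T = TypeVar("T")

    async def retry_once_on_cuda_oom(
        run: Callable[[], Awaitable[T]],
        reload_fn: Callable[[], None],
    ) -> T:
        try:
            return await run()
        except RuntimeError as exc:
            if "CUDA out of memory" not in str(exc):
                raise
            reload_fn()  # re-place the model, e.g. on the GPU with most free VRAM
            return await run()  # a second OOM propagates to the caller

    # Call site equivalent to the inline version above:
    # index_result = await retry_once_on_cuda_oom(
    #     lambda: rag.index_chunks(project_id=project_id, chunks=chunks_to_index),
    #     rag._reload_embed_model,
    # )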
diff --git a/backend/app/api/v1/rewrite.py b/backend/app/api/v1/rewrite.py
index 5b40fd2..a22bfe0 100644
--- a/backend/app/api/v1/rewrite.py
+++ b/backend/app/api/v1/rewrite.py
@@ -7,41 +7,24 @@
import logging
from typing import Literal
+import httpx
from fastapi import APIRouter, Depends
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, field_validator, model_validator
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_db
+from app.config import settings
+from app.prompts.rewrite import REWRITE_PROMPTS
from app.services.llm.client import get_llm_client
from app.services.user_settings_service import UserSettingsService
+from app.utils.sse import format_sse_error
logger = logging.getLogger(__name__)
-router = APIRouter(prefix="/chat", tags=["rewrite"])
+router = APIRouter(prefix="/chat", tags=["chat"])
-_rewrite_semaphore = asyncio.Semaphore(3)
-
-REWRITE_PROMPTS: dict[str, str] = {
- "simplify": (
- "Rewrite the following academic text in plain, accessible language. "
- "Keep the core meaning and key concepts intact, but make it understandable "
- "to a general audience. Output only the rewritten text, no explanations."
- ),
- "academic": (
- "Rewrite the following text in formal academic style. "
- "Use precise terminology, passive voice where appropriate, and proper "
- "academic conventions. Maintain the original meaning. Output only the rewritten text."
- ),
- "translate_en": (
- "Translate the following text into English. "
- "Preserve academic terminology and the original meaning. "
- "Output only the translation, no explanations."
- ),
- "translate_zh": ("将以下文本翻译为中文。保留学术术语和原意。仅输出翻译结果,不要添加解释。"),
-}
-
-REWRITE_TIMEOUT = 30.0
+_rewrite_semaphore = asyncio.Semaphore(settings.rewrite_semaphore_limit)
class RewriteRequest(BaseModel):
@@ -85,12 +68,15 @@ async def _stream_rewrite(request: RewriteRequest, db: AsyncSession):
full_text = ""
try:
- async with asyncio.timeout(REWRITE_TIMEOUT):
+ async with asyncio.timeout(settings.rewrite_timeout):
async for token in llm.chat_stream(messages, temperature=0.3, task_type="rewrite"):
full_text += token
yield _sse("rewrite_delta", {"delta": token})
except TimeoutError:
- yield _sse("error", {"code": "timeout", "message": "Rewrite timed out after 30s"})
+ yield format_sse_error(
+ f"Rewrite timed out after {settings.rewrite_timeout}s",
+ code=408,
+ )
return
yield _sse("rewrite_end", {"full_text": full_text})
@@ -98,12 +84,12 @@ async def _stream_rewrite(request: RewriteRequest, db: AsyncSession):
except asyncio.CancelledError:
logger.info("Rewrite stream cancelled by client")
return
- except Exception as e:
+ except (httpx.HTTPError, ValueError, RuntimeError) as e:
logger.exception("Rewrite stream error")
- yield _sse("error", {"code": "rewrite_error", "message": str(e)})
+ yield format_sse_error(str(e), code=500)
-@router.post("/rewrite")
+@router.post("/rewrite", summary="Stream excerpt rewrite")
async def rewrite_stream(
request: RewriteRequest,
db: AsyncSession = Depends(get_db),
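Both this file and `rag.py` now route stream errors through `app.utils.sse.format_sse_error` instead of ad-hoc `_sse("error", ...)` payloads, so clients see one error shape everywhere. The helper itself is not in the diff; from its call sites (`format_sse_error(msg, code=408)`), it plausibly renders a standard SSE error frame along these lines:

    import json

    def format_sse_error(message: str, code: int = 500) -> str:
        """Render a single SSE 'error' event with a JSON body (sketch)."""
        data = json.dumps({"code": code, "message": message}, ensure_ascii=False)
        return f"event: error\ndata: {data}\n\n"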
diff --git a/backend/app/api/v1/search.py b/backend/app/api/v1/search.py
index e2c8f66..569a36d 100644
--- a/backend/app/api/v1/search.py
+++ b/backend/app/api/v1/search.py
@@ -1,10 +1,12 @@
"""Literature search API endpoints — multi-source federated search."""
-from fastapi import APIRouter, Depends, HTTPException
+from fastapi import APIRouter, Depends, HTTPException, Request
+from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_db, get_project
+from app.middleware.rate_limit import limiter
from app.models import Keyword, Paper, Project
from app.schemas.common import ApiResponse
from app.services.search_service import SearchService
@@ -12,17 +14,29 @@
router = APIRouter(prefix="/projects/{project_id}/search", tags=["search"])
-@router.post("/execute", response_model=ApiResponse[dict])
+class SearchExecuteRequest(BaseModel):
+ """Request body for federated search execution."""
+
+ query: str = Field(default="", description="Search query; if empty, built from project keywords")
+ sources: list[str] | None = Field(default=None, description="Search sources to use")
+ max_results: int = Field(default=100, ge=1, le=500, description="Maximum results per source")
+ auto_import: bool = Field(default=False, description="Import results into project")
+
+
+@router.post("/execute", response_model=ApiResponse[dict], summary="Execute federated search")
+@limiter.limit("10/minute")
async def execute_search(
+ request: Request,
project_id: int,
- query: str = "",
- sources: list[str] | None = None,
- max_results: int = 100,
- auto_import: bool = False,
+ body: SearchExecuteRequest,
db: AsyncSession = Depends(get_db),
project: Project = Depends(get_project),
):
"""Execute federated search. If auto_import=True, import results to project."""
+ query = body.query
+ sources = body.sources
+ max_results = body.max_results
+ auto_import = body.auto_import
# If no query, build from project keywords
if not query:
@@ -38,7 +52,7 @@ async def execute_search(
)
service = SearchService()
- results = await service.search(query, sources, max_results)
+ results = await service.search(query, sources=sources, max_results=max_results)
# Optionally auto-import results
if auto_import and results["papers"]:
@@ -79,8 +93,8 @@ async def execute_search(
return ApiResponse(data=results)
-@router.get("/sources", response_model=ApiResponse[list[dict]])
-async def list_search_sources():
+@router.get("/sources", response_model=ApiResponse[list[dict]], summary="List search sources")
+async def list_search_sources(project: Project = Depends(get_project)):
"""Return available search sources and their status."""
return ApiResponse(
data=[
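Moving `/execute` from bare query parameters to a Pydantic body gets server-side validation (`max_results` clamped to 1-500) and a natural encoding for the list-valued `sources`. Equivalent client call; the base path and auth header are assumptions:

    import httpx

    resp = httpx.post(
        "http://localhost:8000/api/v1/projects/1/search/execute",  # path assumed
        headers={"X-API-Key": "..."},                              # header name assumed
        json={
            "query": "diffusion models for protein design",
            "sources": ["arxiv", "openalex"],
            "max_results": 50,
            "auto_import": True,
        },
        timeout=120,
    )
    print(resp.json()["data"]["papers"][:3])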
diff --git a/backend/app/api/v1/settings_api.py b/backend/app/api/v1/settings_api.py
index 73cc9c1..8d4abce 100644
--- a/backend/app/api/v1/settings_api.py
+++ b/backend/app/api/v1/settings_api.py
@@ -1,6 +1,6 @@
"""Application settings API — CRUD, model listing, and connection testing."""
-from fastapi import APIRouter, Depends
+from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_db
@@ -11,7 +11,7 @@
router = APIRouter(prefix="/settings", tags=["settings"])
-@router.get("", response_model=ApiResponse[SettingsSchema])
+@router.get("", response_model=ApiResponse[SettingsSchema], summary="Get settings")
async def get_settings(db: AsyncSession = Depends(get_db)):
"""Return merged settings (DB overrides .env); API keys are masked."""
svc = UserSettingsService(db)
@@ -19,7 +19,7 @@ async def get_settings(db: AsyncSession = Depends(get_db)):
return ApiResponse(data=merged)
-@router.put("", response_model=ApiResponse[SettingsSchema])
+@router.put("", response_model=ApiResponse[SettingsSchema], summary="Update settings")
async def put_settings(
payload: SettingsUpdateSchema,
db: AsyncSession = Depends(get_db),
@@ -31,13 +31,13 @@ async def put_settings(
return ApiResponse(data=merged)
-@router.get("/models", response_model=ApiResponse[list[ProviderModelInfo]])
+@router.get("/models", response_model=ApiResponse[list[ProviderModelInfo]], summary="List available models")
async def list_models():
"""Return available LLM providers and their model lists."""
return ApiResponse(data=get_available_models())
-@router.post("/test-connection", response_model=ApiResponse[dict])
+@router.post("/test-connection", response_model=ApiResponse[dict], summary="Test LLM connection")
async def test_connection(db: AsyncSession = Depends(get_db)):
"""Test the current LLM configuration by sending a simple prompt."""
svc = UserSettingsService(db)
@@ -53,14 +53,10 @@ async def test_connection(db: AsyncSession = Depends(get_db)):
)
return ApiResponse(data={"success": True, "response": response[:200]})
except Exception as e:
- return ApiResponse(
- code=500,
- message="Connection test failed",
- data={"success": False, "error": str(e)},
- )
+ raise HTTPException(status_code=502, detail=f"Connection test failed: {e}") from e
-@router.get("/health", response_model=ApiResponse[dict])
+@router.get("/health", response_model=ApiResponse[dict], summary="Health check")
async def health_check():
"""Simple health check endpoint."""
return ApiResponse(data={"status": "healthy", "version": "0.1.0"})
diff --git a/backend/app/api/v1/subscription.py b/backend/app/api/v1/subscription.py
index 3c8ee73..f161c13 100644
--- a/backend/app/api/v1/subscription.py
+++ b/backend/app/api/v1/subscription.py
@@ -3,12 +3,12 @@
from datetime import datetime, timedelta
from fastapi import APIRouter, Depends, HTTPException, Query
-from sqlalchemy import select
+from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_db, get_project
from app.models import Project, Subscription
-from app.schemas.common import ApiResponse
+from app.schemas.common import ApiResponse, PaginatedData, PaginationParams
from app.schemas.subscription import (
SubscriptionCreate,
SubscriptionRead,
@@ -20,26 +20,30 @@
router = APIRouter(prefix="/projects/{project_id}/subscriptions", tags=["subscriptions"])
-@router.get("/feeds", response_model=ApiResponse[list[dict]])
+@router.get("/feeds", response_model=ApiResponse[list[dict]], summary="List common RSS feeds")
async def list_common_feeds():
"""Return common academic RSS feed templates."""
return ApiResponse(data=SubscriptionService.get_common_feeds())
-@router.post("/check-rss", response_model=ApiResponse[dict])
+@router.post("/check-rss", response_model=ApiResponse[dict], summary="Check RSS feed for entries")
async def check_rss(
project_id: int,
feed_url: str = Query(..., description="RSS/Atom feed URL"),
since_days: int = Query(7, ge=1, le=365),
+ project: Project = Depends(get_project),
):
"""Check an RSS feed for new entries since the given number of days."""
service = SubscriptionService()
since = datetime.now() - timedelta(days=since_days)
- entries = await service.check_rss_feed(feed_url, since)
+ try:
+ entries = await service.check_rss_feed(feed_url, since)
+ except ValueError as e:
+ raise HTTPException(status_code=400, detail=str(e)) from e
return ApiResponse(data={"entries": entries, "count": len(entries)})
-@router.post("/check-updates", response_model=ApiResponse[dict])
+@router.post("/check-updates", response_model=ApiResponse[dict], summary="Check API for new papers")
async def check_updates(
project_id: int,
query: str = Query(""),
@@ -47,6 +51,7 @@ async def check_updates(
since_days: int = Query(7, ge=1, le=365),
max_results: int = Query(50, ge=1, le=200),
db: AsyncSession = Depends(get_db),
+ project: Project = Depends(get_project),
):
"""Check for new papers via API search (incremental update)."""
service = SubscriptionService()
@@ -54,19 +59,38 @@ async def check_updates(
return ApiResponse(data=result)
-@router.get("", response_model=ApiResponse[list[SubscriptionRead]])
+@router.get("", response_model=ApiResponse[PaginatedData[SubscriptionRead]], summary="List subscriptions")
async def list_subscriptions(
project_id: int,
+ pagination: PaginationParams = Depends(),
db: AsyncSession = Depends(get_db),
project: Project = Depends(get_project),
):
- """List all subscriptions for a project."""
- result = await db.execute(select(Subscription).where(Subscription.project_id == project_id))
+ """List subscriptions for a project with pagination."""
+ page, page_size = pagination.page, pagination.page_size
+ count_stmt = select(func.count(Subscription.id)).where(Subscription.project_id == project_id)
+ total = (await db.execute(count_stmt)).scalar() or 0
+ stmt = (
+ select(Subscription)
+ .where(Subscription.project_id == project_id)
+ .order_by(Subscription.created_at.desc())
+ .offset((page - 1) * page_size)
+ .limit(page_size)
+ )
+ result = await db.execute(stmt)
subs = result.scalars().all()
- return ApiResponse(data=[SubscriptionRead.model_validate(s) for s in subs])
+ return ApiResponse(
+ data=PaginatedData(
+ items=[SubscriptionRead.model_validate(s) for s in subs],
+ total=total,
+ page=page,
+ page_size=page_size,
+ total_pages=(total + page_size - 1) // page_size if total else 1,
+ )
+ )
-@router.post("", response_model=ApiResponse[SubscriptionRead], status_code=201)
+@router.post("", response_model=ApiResponse[SubscriptionRead], status_code=201, summary="Create subscription")
async def create_subscription(
project_id: int,
body: SubscriptionCreate,
@@ -81,7 +105,7 @@ async def create_subscription(
return ApiResponse(code=201, message="Subscription created", data=SubscriptionRead.model_validate(sub))
-@router.get("/{sub_id}", response_model=ApiResponse[SubscriptionRead])
+@router.get("/{sub_id}", response_model=ApiResponse[SubscriptionRead], summary="Get subscription by ID")
async def get_subscription(
project_id: int,
sub_id: int,
@@ -97,7 +121,7 @@ async def get_subscription(
return ApiResponse(data=SubscriptionRead.model_validate(sub))
-@router.put("/{sub_id}", response_model=ApiResponse[SubscriptionRead])
+@router.put("/{sub_id}", response_model=ApiResponse[SubscriptionRead], summary="Update subscription")
async def update_subscription(
project_id: int,
sub_id: int,
@@ -119,7 +143,7 @@ async def update_subscription(
return ApiResponse(data=SubscriptionRead.model_validate(sub))
-@router.delete("/{sub_id}", response_model=ApiResponse[None])
+@router.delete("/{sub_id}", response_model=ApiResponse[None], summary="Delete subscription")
async def delete_subscription(
project_id: int,
sub_id: int,
@@ -136,15 +160,20 @@ async def delete_subscription(
return ApiResponse(message="Subscription deleted", data=None)
-@router.post("/{sub_id}/trigger", response_model=ApiResponse[SubscriptionRunResult])
+@router.post(
+ "/{sub_id}/trigger", response_model=ApiResponse[SubscriptionRunResult], summary="Trigger subscription update"
+)
async def trigger_subscription(
project_id: int,
sub_id: int,
since_days: int = Query(7, ge=1, le=365),
+ auto_import: bool = Query(False, description="Auto-import new papers into project"),
db: AsyncSession = Depends(get_db),
project: Project = Depends(get_project),
):
"""Manually trigger a subscription update (check API for new papers)."""
+ from app.models import Paper
+
sub = (
await db.execute(select(Subscription).where(Subscription.id == sub_id, Subscription.project_id == project_id))
).scalar_one_or_none()
@@ -160,6 +189,23 @@ async def trigger_subscription(
new_papers = result.get("new_papers", [])
total_found = result.get("total_found", 0)
sources_checked = result.get("sources_checked", {})
+
+ imported_count = 0
+ if auto_import and new_papers:
+ for paper_data in new_papers:
+ paper = Paper(
+ project_id=project_id,
+ title=paper_data.get("title", "Untitled"),
+ abstract=paper_data.get("abstract", ""),
+ doi=paper_data.get("doi"),
+ authors=paper_data.get("authors"),
+ year=paper_data.get("year"),
+ source=paper_data.get("source", "subscription"),
+ pdf_url=paper_data.get("pdf_url", ""),
+ )
+ db.add(paper)
+ imported_count += 1
+
sub.last_run_at = datetime.now()
sub.total_found = total_found
await db.flush()
@@ -169,5 +215,6 @@ async def trigger_subscription(
new_papers=len(new_papers),
total_checked=total_found,
sources_searched=list(sources_checked.keys()) if sources_checked else [],
+ imported=imported_count,
)
)
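With the new `auto_import` flag on `/trigger`, `imported` mirrors `new_papers` when the flag is set and stays 0 otherwise. Usage sketch, with URL and header assumed as above:

    import httpx

    resp = httpx.post(
        "http://localhost:8000/api/v1/projects/1/subscriptions/3/trigger",
        params={"since_days": 14, "auto_import": True},
        headers={"X-API-Key": "..."},
    )
    print(resp.json()["data"])
    # -> {"new_papers": N, "total_checked": M, "sources_searched": [...], "imported": N}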
diff --git a/backend/app/api/v1/tasks.py b/backend/app/api/v1/tasks.py
index 5e76fa9..9bded5a 100644
--- a/backend/app/api/v1/tasks.py
+++ b/backend/app/api/v1/tasks.py
@@ -1,21 +1,19 @@
"""Task status and management API endpoints."""
from fastapi import APIRouter, Depends, HTTPException
-from sqlalchemy import select
+from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
-from app.api.deps import get_db
+from app.api.deps import get_db, get_or_404
from app.models import Task
-from app.schemas.common import ApiResponse
+from app.schemas.common import ApiResponse, PaginatedData
router = APIRouter(prefix="/tasks", tags=["tasks"])
-@router.get("/{task_id}", response_model=ApiResponse[dict])
+@router.get("/{task_id}", response_model=ApiResponse[dict], summary="Get task by ID")
async def get_task(task_id: int, db: AsyncSession = Depends(get_db)):
- task = await db.get(Task, task_id)
- if not task:
- raise HTTPException(status_code=404, detail="Task not found")
+ task = await get_or_404(db, Task, task_id, detail="Task not found")
return ApiResponse(
data={
"id": task.id,
@@ -34,41 +32,49 @@ async def get_task(task_id: int, db: AsyncSession = Depends(get_db)):
)
-@router.get("", response_model=ApiResponse[list[dict]])
+@router.get("", response_model=ApiResponse[PaginatedData[dict]], summary="List tasks")
async def list_tasks(
project_id: int | None = None,
status: str | None = None,
- limit: int = 50,
+ page: int = 1,
+ page_size: int = 50,
db: AsyncSession = Depends(get_db),
):
- stmt = select(Task).order_by(Task.created_at.desc()).limit(limit)
+ base = select(Task)
if project_id:
- stmt = stmt.where(Task.project_id == project_id)
+ base = base.where(Task.project_id == project_id)
if status:
- stmt = stmt.where(Task.status == status)
- result = await db.execute(stmt)
+ base = base.where(Task.status == status)
+
+ total = (await db.execute(select(func.count()).select_from(base.subquery()))).scalar_one()
+ result = await db.execute(base.order_by(Task.created_at.desc()).offset((page - 1) * page_size).limit(page_size))
tasks = result.scalars().all()
+
return ApiResponse(
- data=[
- {
- "id": t.id,
- "project_id": t.project_id,
- "task_type": t.task_type,
- "status": t.status,
- "progress": t.progress,
- "total": t.total,
- "created_at": t.created_at.isoformat() if t.created_at else None,
- }
- for t in tasks
- ]
+ data=PaginatedData(
+ items=[
+ {
+ "id": t.id,
+ "project_id": t.project_id,
+ "task_type": t.task_type,
+ "status": t.status,
+ "progress": t.progress,
+ "total": t.total,
+ "created_at": t.created_at.isoformat() if t.created_at else None,
+ }
+ for t in tasks
+ ],
+ total=total,
+ page=page,
+ page_size=page_size,
+ total_pages=(total + page_size - 1) // page_size or 1,
+ )
)
-@router.post("/{task_id}/cancel", response_model=ApiResponse)
+@router.post("/{task_id}/cancel", response_model=ApiResponse, summary="Cancel task")
async def cancel_task(task_id: int, db: AsyncSession = Depends(get_db)):
- task = await db.get(Task, task_id)
- if not task:
- raise HTTPException(status_code=404, detail="Task not found")
+ task = await get_or_404(db, Task, task_id, detail="Task not found")
if task.status in ("completed", "failed", "cancelled"):
raise HTTPException(status_code=400, detail=f"Cannot cancel task in {task.status} state")
task.status = "cancelled"
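Counting through `base.subquery()` keeps the count query and the page query on the same filtered statement, so the two cannot drift apart as filters accumulate. Roughly what SQLAlchemy emits for the pattern above:

    from sqlalchemy import func, select

    base = select(Task).where(Task.project_id == 1, Task.status == "running")
    count_stmt = select(func.count()).select_from(base.subquery())
    # str(count_stmt) renders approximately as:
    #   SELECT count(*) FROM (SELECT ... FROM tasks WHERE ...) AS anon_1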
diff --git a/backend/app/api/v1/upload.py b/backend/app/api/v1/upload.py
index 2ca0487..bdd902e 100644
--- a/backend/app/api/v1/upload.py
+++ b/backend/app/api/v1/upload.py
@@ -6,12 +6,13 @@
from difflib import SequenceMatcher
from pathlib import Path
-from fastapi import APIRouter, Depends, File, HTTPException, Query, UploadFile
+from fastapi import APIRouter, Depends, File, HTTPException, Query, Request, UploadFile
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_db, get_project
from app.config import settings
+from app.middleware.rate_limit import limiter
from app.models import Paper, PaperStatus, Project
from app.schemas.common import ApiResponse
from app.schemas.knowledge_base import DedupConflictPair, NewPaperData, UploadResult
@@ -24,12 +25,11 @@
router = APIRouter(tags=["papers"])
-MAX_FILE_SIZE_MB = 50
-TITLE_SIMILARITY_THRESHOLD = 0.85
-
-@router.post("/upload", response_model=ApiResponse[UploadResult])
+@router.post("/upload", response_model=ApiResponse[UploadResult], summary="Upload PDF files")
+@limiter.limit("5/minute")
async def upload_pdfs(
+ request: Request,
project_id: int,
files: list[UploadFile] = File(...),
db: AsyncSession = Depends(get_db),
@@ -41,7 +41,7 @@ async def upload_pdfs(
project_pdf_dir = pdf_dir / str(project_id)
project_pdf_dir.mkdir(parents=True, exist_ok=True)
- max_bytes = MAX_FILE_SIZE_MB * 1024 * 1024
+ max_bytes = settings.max_upload_size_mb * 1024 * 1024
papers: list[NewPaperData] = []
conflicts: list[DedupConflictPair] = []
new_paper_objects: list[Paper] = []
@@ -66,7 +66,7 @@ async def upload_pdfs(
if len(content) > max_bytes:
raise HTTPException(
status_code=413,
- detail=f"File {upload_file.filename} exceeds {MAX_FILE_SIZE_MB}MB limit",
+ detail=f"File {upload_file.filename} exceeds {settings.max_upload_size_mb}MB limit",
)
safe_filename = Path(upload_file.filename or "upload.pdf").name.replace("..", "")
@@ -98,7 +98,7 @@ async def upload_pdfs(
norm_new = DedupService.normalize_title(metadata.title)
if norm_existing and norm_new:
sim = SequenceMatcher(None, norm_existing, norm_new).ratio()
- if sim >= TITLE_SIMILARITY_THRESHOLD:
+ if sim >= settings.title_similarity_threshold:
conflict_id = f"{existing.id}:{saved_name}"
conflicts.append(
DedupConflictPair(
@@ -156,7 +156,7 @@ async def upload_pdfs(
)
-@router.post("/process", response_model=ApiResponse[dict])
+@router.post("/process", response_model=ApiResponse[dict], summary="Trigger OCR and RAG indexing")
async def process_papers(
project_id: int,
paper_ids: list[int] | None = Query(default=None),
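The upload dedup gate now reads its threshold from `settings.title_similarity_threshold` (default 0.85). For intuition on where 0.85 sits, `difflib.SequenceMatcher` over normalized titles behaves as below; `DedupService.normalize_title` itself is not shown and is assumed to lowercase and strip punctuation:

    from difflib import SequenceMatcher

    a = "attention is all you need"
    b = "attention is all you need (extended version)"
    c = "atention is all you need"  # single-character typo

    print(SequenceMatcher(None, a, b).ratio())  # ~0.72 -> below 0.85, treated as distinct
    print(SequenceMatcher(None, a, c).ratio())  # ~0.98 -> raised as a dedup conflict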
diff --git a/backend/app/api/v1/writing.py b/backend/app/api/v1/writing.py
index 92488c5..8097379 100644
--- a/backend/app/api/v1/writing.py
+++ b/backend/app/api/v1/writing.py
@@ -1,14 +1,15 @@
"""Writing assistance API endpoints."""
-from fastapi import APIRouter, Depends
+from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_db, get_llm, get_project
+from app.middleware.rate_limit import limiter
from app.models import Project
from app.schemas.common import ApiResponse
-from app.services.llm_client import LLMClient
+from app.services.llm.client import LLMClient
from app.services.rag_service import RAGService
from app.services.writing_service import WritingService
@@ -57,8 +58,10 @@ def get_writing_service(
return WritingService(db=db, llm=llm, rag=rag)
-@router.post("/assist", response_model=ApiResponse[WritingAssistResponse])
+@router.post("/assist", response_model=ApiResponse[WritingAssistResponse], summary="AI writing assistance")
+@limiter.limit("10/minute")
async def writing_assist(
+ request: Request,
project_id: int,
body: WritingAssistRequest,
db: AsyncSession = Depends(get_db),
@@ -86,10 +89,9 @@ async def writing_assist(
result = await svc.analyze_gaps(project_id=project_id, research_topic=topic)
content = result["analysis"]
else:
- return ApiResponse(
- code=400,
- message=f"Unknown task: {body.task}. Use summarize, cite, review_outline, or gap_analysis.",
- data=WritingAssistResponse(content="", citations=[], suggestions=[]),
+ raise HTTPException(
+ status_code=400,
+ detail=f"Unknown task: {body.task}. Use summarize, cite, review_outline, or gap_analysis.",
)
return ApiResponse(
@@ -101,7 +103,7 @@ async def writing_assist(
)
-@router.post("/summarize", response_model=ApiResponse[dict])
+@router.post("/summarize", response_model=ApiResponse[dict], summary="Summarize papers")
async def summarize_papers(
project_id: int,
body: SummarizeRequest,
@@ -117,7 +119,7 @@ async def summarize_papers(
return ApiResponse(data={"summaries": summaries})
-@router.post("/citations", response_model=ApiResponse[dict])
+@router.post("/citations", response_model=ApiResponse[dict], summary="Generate citations")
async def generate_citations(
project_id: int,
body: CitationsRequest,
@@ -133,7 +135,7 @@ async def generate_citations(
return ApiResponse(data={"citations": citations, "style": body.style})
-@router.post("/review-outline", response_model=ApiResponse[dict])
+@router.post("/review-outline", response_model=ApiResponse[dict], summary="Generate review outline")
async def generate_review_outline(
project_id: int,
body: ReviewOutlineRequest,
@@ -150,7 +152,7 @@ async def generate_review_outline(
return ApiResponse(data=result)
-@router.post("/gap-analysis", response_model=ApiResponse[dict])
+@router.post("/gap-analysis", response_model=ApiResponse[dict], summary="Analyze research gaps")
async def analyze_gaps(
project_id: int,
body: GapAnalysisRequest,
@@ -173,8 +175,10 @@ class ReviewDraftRequest(BaseModel):
language: str = Field(default="zh", pattern=r"^(zh|en)$")
-@router.post("/review-draft/stream")
+@router.post("/review-draft/stream", summary="Stream literature review draft")
+@limiter.limit("10/minute")
async def stream_review_draft(
+ request: Request,
project_id: int,
body: ReviewDraftRequest,
svc: WritingService = Depends(get_writing_service),
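Several routes in this diff gain a seemingly unused `request: Request` parameter next to `@limiter.limit(...)`. That is a slowapi requirement: the decorator inspects the endpoint signature for the `Request` object to derive the rate-limit key. Minimal self-contained pattern (the project's actual wiring lives in `app.middleware.rate_limit`, not shown here):

    from fastapi import FastAPI, Request
    from slowapi import Limiter, _rate_limit_exceeded_handler
    from slowapi.errors import RateLimitExceeded
    from slowapi.util import get_remote_address

    limiter = Limiter(key_func=get_remote_address)
    app = FastAPI()
    app.state.limiter = limiter
    app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

    @app.post("/expensive")
    @limiter.limit("10/minute")
    async def expensive(request: Request):  # Request param is mandatory for slowapi
        return {"ok": True}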
diff --git a/backend/app/config.py b/backend/app/config.py
index a44b69c..aa8ca3d 100644
--- a/backend/app/config.py
+++ b/backend/app/config.py
@@ -1,15 +1,44 @@
"""Application configuration using Pydantic Settings."""
import os
+from enum import StrEnum
from pathlib import Path
from typing import Literal
-from pydantic import Field
+from pydantic import Field, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
+class GpuMode(StrEnum):
+ CONSERVATIVE = "conservative"
+ BALANCED = "balanced"
+ AGGRESSIVE = "aggressive"
+
+
+GPU_MODE_PRESETS: dict[GpuMode, dict[str, int]] = {
+ GpuMode.CONSERVATIVE: {
+ "ocr_parallel_limit": 1,
+ "embed_batch_size": 1,
+ "rerank_batch_size": 1,
+ "reranker_concurrency_limit": 1,
+ },
+ GpuMode.BALANCED: {
+ "ocr_parallel_limit": 0,
+ "embed_batch_size": 8,
+ "rerank_batch_size": 16,
+ "reranker_concurrency_limit": 1,
+ },
+ GpuMode.AGGRESSIVE: {
+ "ocr_parallel_limit": 0,
+ "embed_batch_size": 32,
+ "rerank_batch_size": 50,
+ "reranker_concurrency_limit": 2,
+ },
+}
+
+
class Settings(BaseSettings):
model_config = SettingsConfigDict(
env_file=str(PROJECT_ROOT / ".env"),
@@ -65,28 +94,80 @@ class Settings(BaseSettings):
# Embedding
embedding_provider: str = "local" # local | api | mock
- embedding_model: str = "BAAI/bge-m3"
+ embedding_model: str = "Qwen/Qwen3-Embedding-0.6B"
embedding_api_key: str = ""
- reranker_model: str = "BAAI/bge-reranker-v2-m3"
+ reranker_model: str = "tomaarsen/Qwen3-Reranker-0.6B-seq-cls"
# OCR
ocr_lang: str = "ch" # PaddleOCR language: ch (Chinese+English) | en (English only)
# PDF Parsing / MinerU
- pdf_parser: str = "auto" # auto | mineru | pdfplumber
+ pdf_parser: str = "mineru" # auto | mineru | pdfplumber
mineru_api_url: str = "http://localhost:8010"
mineru_backend: str = "pipeline" # pipeline | hybrid-auto-engine | vlm-auto-engine
- mineru_timeout: int = 300
+ mineru_timeout: int = 8000
+ mineru_auto_manage: bool = Field(default=True, description="Auto start/stop MinerU subprocess")
+ mineru_conda_env: str = Field(default="mineru", description="Conda env name for MinerU")
+ mineru_ttl_seconds: int = Field(default=600, ge=0, description="Stop MinerU after N seconds idle. 0=disable")
+ mineru_startup_timeout: int = Field(default=120, ge=10, le=600, description="MinerU startup timeout")
+ mineru_gpu_ids: str = Field(default="", description="GPU IDs for MinerU. Empty=inherit cuda_visible_devices")
+
+ # Semantic Scholar API
+ s2_api_base: str = "https://api.semanticscholar.org/graph/v1"
+ s2_timeout: int = Field(default=15, ge=1, le=60)
+ s2_max_per_request: int = Field(default=50, ge=1, le=100)
+
+ # Upload
+ title_similarity_threshold: float = Field(default=0.85, ge=0.0, le=1.0)
+
+ # Rewrite
+ rewrite_timeout: float = Field(default=30.0, ge=5.0, le=120.0)
# Dedup thresholds
dedup_title_hard_threshold: float = 0.90
dedup_title_llm_threshold: float = 0.80
+ # App
+ app_version: str = "0.1.0"
+
+ # Concurrency limits
+ max_upload_size_mb: int = Field(default=50, ge=1, le=500)
+ rate_limit: str = Field(default="120/minute", description="API rate limit")
+ clean_semaphore_limit: int = Field(default=3, ge=1)
+ rewrite_semaphore_limit: int = Field(default=3, ge=1)
+ llm_parallel_limit: int = Field(default=5, ge=1, description="Max parallel LLM calls for batch operations")
+ ocr_parallel_limit: int = Field(
+ default=0,
+ ge=0,
+ le=16,
+ description="Max parallel OCR tasks. 0=auto (GPU count or 1 for CPU)",
+ )
+
+ # RAG retrieval
+ rag_default_top_k: int = Field(default=10, ge=1, le=100, description="Default retrieval top-k")
+ rag_oversample_factor: int = Field(default=3, ge=1, le=10, description="Multiplier for oversampling before rerank")
+ rag_mmr_threshold: float = Field(
+ default=0.5, ge=0.0, le=1.0, description="MMR diversity threshold (0=max diversity, 1=max relevance)"
+ )
+ reranker_concurrency_limit: int = Field(default=1, ge=1, le=4, description="Max concurrent reranker calls")
+
# LangGraph
langgraph_checkpoint_dir: str = ""
+ pipeline_checkpoint_db: str = "" # SQLite checkpoint DB path; defaults to {data_dir}/pipeline_checkpoints.db
+ pid_file: str = "" # PID file path; defaults to {data_dir}/omelette.pid
# GPU
- cuda_visible_devices: str = "0,3"
+ cuda_visible_devices: str = "" # Empty = use all available GPUs
+ model_ttl_seconds: int = Field(
+ default=300, ge=0, description="Auto-unload GPU models after N seconds idle. 0=disable"
+ )
+ model_ttl_check_interval: int = Field(default=30, ge=5, le=300, description="TTL check interval in seconds")
+ gpu_mode: GpuMode = Field(default=GpuMode.BALANCED, description="GPU preset: conservative/balanced/aggressive")
+ embed_batch_size: int = Field(default=0, ge=0, le=128, description="Embedding batch size. 0=follow GPU_MODE")
+ rerank_batch_size: int = Field(default=0, ge=0, le=128, description="Reranker internal top_n. 0=follow GPU_MODE")
+ embed_gpu_id: int = Field(default=-1, ge=-1, le=15, description="Pin embedding to GPU N. -1=auto select")
+ rerank_gpu_id: int = Field(default=-1, ge=-1, le=15, description="Pin reranker to GPU N. -1=auto select")
+ ocr_gpu_ids: str = Field(default="", description="Comma-separated GPU IDs for OCR. Empty=all")
# Network Proxy
http_proxy: str = ""
@@ -103,6 +184,20 @@ class Settings(BaseSettings):
frontend_url: str = "http://localhost:3000"
cors_origins: str = "http://localhost:3000,http://0.0.0.0:3000"
+ @model_validator(mode="after")
+ def _apply_gpu_mode_defaults(self) -> "Settings":
+ """Fill zero-valued GPU params from the active GPU_MODE preset."""
+ preset = GPU_MODE_PRESETS.get(self.gpu_mode, GPU_MODE_PRESETS[GpuMode.BALANCED])
+ if self.embed_batch_size == 0:
+ self.embed_batch_size = preset["embed_batch_size"]
+ if self.rerank_batch_size == 0:
+ self.rerank_batch_size = preset["rerank_batch_size"]
+ if self.ocr_parallel_limit == 0:
+ self.ocr_parallel_limit = preset["ocr_parallel_limit"]
+ if self.reranker_concurrency_limit == 1 and preset["reranker_concurrency_limit"] != 1:
+ self.reranker_concurrency_limit = preset["reranker_concurrency_limit"]
+ return self
+
def __init__(self, **kwargs):
super().__init__(**kwargs)
if not self.pdf_dir:
@@ -113,6 +208,10 @@ def __init__(self, **kwargs):
self.chroma_db_dir = f"{self.data_dir}/chroma_db"
if not self.langgraph_checkpoint_dir:
self.langgraph_checkpoint_dir = f"{self.data_dir}/langgraph_checkpoints"
+ if not self.pipeline_checkpoint_db:
+ self.pipeline_checkpoint_db = f"{self.data_dir}/pipeline_checkpoints.db"
+ if not self.pid_file:
+ self.pid_file = f"{self.data_dir}/omelette.pid"
@property
def cors_origin_list(self) -> list[str]:
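In `_apply_gpu_mode_defaults`, `0` means "follow the preset", so any explicit non-zero value survives validation. A worked example, assuming the corresponding env vars are unset:

    s = Settings(gpu_mode="aggressive", embed_batch_size=0, rerank_batch_size=16)
    assert s.embed_batch_size == 32   # 0 resolved from the aggressive preset
    assert s.rerank_batch_size == 16  # explicit override preserved

One caveat worth knowing: `reranker_concurrency_limit` uses `1` as its unset sentinel, so an explicitly configured `1` under `aggressive` still resolves to the preset's `2`.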
diff --git a/backend/app/main.py b/backend/app/main.py
index 682deea..56011ab 100644
--- a/backend/app/main.py
+++ b/backend/app/main.py
@@ -1,15 +1,22 @@
"""Omelette — Scientific Literature Lifecycle Management System."""
+import atexit
+import contextlib
import logging
+import os
+import signal
+import sys
from contextlib import asynccontextmanager
+from pathlib import Path
-from fastapi import FastAPI, Request
+from fastapi import FastAPI, HTTPException, Request
+from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from app.api.v1 import api_router
from app.config import settings
-from app.database import init_db
+from app.database import engine, init_db
from app.middleware.auth import ApiKeyMiddleware
from app.middleware.rate_limit import setup_rate_limiting
from app.schemas.common import ApiResponse
@@ -20,22 +27,109 @@
)
logger = logging.getLogger("omelette")
+_cleanup_done = False
+
+
+def _sync_cleanup() -> None:
+ """Synchronous cleanup: release GPU models and kill MinerU processes.
+
+ Safe to call multiple times (idempotent via _cleanup_done flag).
+ Registered with atexit and called from signal handlers.
+ """
+ global _cleanup_done # noqa: PLW0603
+ if _cleanup_done:
+ return
+ _cleanup_done = True
+
+ logger.info("Running sync cleanup (atexit / signal)")
+ try:
+ from app.services.gpu_model_manager import gpu_model_manager
+
+ gpu_model_manager.unload_all()
+ except Exception:
+ logger.warning("GPU model cleanup failed", exc_info=True)
+
+ try:
+ from app.services.mineru_process_manager import mineru_process_manager
+
+ mineru_process_manager.stop_sync()
+ except Exception:
+ logger.warning("MinerU cleanup failed", exc_info=True)
+
+ with contextlib.suppress(OSError):
+ Path(settings.pid_file).unlink(missing_ok=True)
+
+
+def _handle_sighup(signum: int, frame: object) -> None:
+ """Handle terminal close (SIGHUP): cleanup and exit."""
+ logger.info("Received SIGHUP — cleaning up and exiting")
+ _sync_cleanup()
+ sys.exit(0)
+
@asynccontextmanager
async def lifespan(app: FastAPI):
- logger.info("Starting Omelette v0.1.0 ...")
+ from app.pipelines.graphs import set_checkpointer
+ from app.services.gpu_model_manager import gpu_model_manager
+ from app.services.mineru_process_manager import mineru_process_manager
+
+ logger.info("Starting Omelette v%s ...", settings.app_version)
if settings.app_env == "production" and settings.app_secret_key == "change-me-to-a-random-secret-key":
logger.warning("SECURITY: Using default secret key in production! Set APP_SECRET_KEY in .env")
await init_db()
logger.info("Database initialized")
+
+ # Write PID file
+ pid_path = Path(settings.pid_file)
+ pid_path.parent.mkdir(parents=True, exist_ok=True)
+ pid_path.write_text(str(os.getpid()))
+ logger.info("PID file: %s (pid=%d)", pid_path, os.getpid())
+
+ # Register safety nets
+ atexit.register(_sync_cleanup)
+ with contextlib.suppress(OSError, ValueError):
+ signal.signal(signal.SIGHUP, _handle_sighup)
+
+ # Pipeline checkpoint persistence (AsyncSqliteSaver)
+ checkpoint_cm = None
+ try:
+ from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
+
+ db_path = settings.pipeline_checkpoint_db
+ Path(db_path).parent.mkdir(parents=True, exist_ok=True)
+ cm = AsyncSqliteSaver.from_conn_string(db_path)
+ saver = await cm.__aenter__()
+ checkpoint_cm = cm
+ set_checkpointer(saver)
+ logger.info("Pipeline checkpoint DB: %s", db_path)
+ except Exception as e:
+ logger.warning("AsyncSqliteSaver unavailable, using MemorySaver: %s", e)
+ set_checkpointer(None)
+
+ await gpu_model_manager.start()
+ await mineru_process_manager.start()
yield
logger.info("Shutting down Omelette")
+ await mineru_process_manager.stop()
+ await gpu_model_manager.stop()
+ if checkpoint_cm is not None:
+ try:
+ await checkpoint_cm.__aexit__(None, None, None)
+ except Exception as e:
+ logger.warning("Checkpoint saver teardown: %s", e)
+ set_checkpointer(None)
+ await engine.dispose()
+
+ with contextlib.suppress(OSError):
+ Path(settings.pid_file).unlink(missing_ok=True)
+ global _cleanup_done # noqa: PLW0603
+ _cleanup_done = True
app = FastAPI(
title="Omelette API",
description="Scientific Literature Lifecycle Management System / 科研文献全生命周期管理系统",
- version="0.1.0",
+ version=settings.app_version,
lifespan=lifespan,
docs_url="/docs",
redoc_url="/redoc",
@@ -48,12 +142,38 @@ async def lifespan(app: FastAPI):
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
+ expose_headers=["X-Request-ID", "X-Process-Time"],
+ max_age=600,
)
setup_rate_limiting(app)
app.include_router(api_router)
+@app.exception_handler(HTTPException)
+async def http_exception_handler(request: Request, exc: HTTPException):
+ """Wrap HTTPException in ApiResponse format for consistent frontend handling."""
+ return JSONResponse(
+ status_code=exc.status_code,
+ content={"code": exc.status_code, "message": exc.detail, "data": None},
+ )
+
+
+@app.exception_handler(RequestValidationError)
+async def validation_exception_handler(request: Request, exc: RequestValidationError):
+ """Wrap Pydantic validation errors in ApiResponse format."""
+ errors = []
+ for err in exc.errors():
+ clean = {k: v for k, v in err.items() if k != "ctx"}
+ if "ctx" in err:
+ clean["ctx"] = {k: str(v) for k, v in err["ctx"].items()}
+ errors.append(clean)
+ return JSONResponse(
+ status_code=422,
+ content={"code": 422, "message": "Validation error", "data": errors},
+ )
+
+
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
"""Return sanitised error in production, full detail in debug mode."""
@@ -72,8 +192,14 @@ async def global_exception_handler(request: Request, exc: Exception):
mcp_app = mcp_server.streamable_http_app()
app.mount("/mcp", mcp_app)
logger.info("MCP server mounted at /mcp")
-except Exception as e:
- logger.warning("MCP server mount failed: %s", e)
+except Exception:
+ logger.error("MCP server mount failed", exc_info=True)
+
+
+@app.get("/health")
+async def health():
+ """Health check endpoint — exempt from API key authentication."""
+ return ApiResponse(data={"status": "ok"})
@app.get("/")
@@ -81,7 +207,7 @@ async def root():
return ApiResponse(
data={
"name": "Omelette",
- "version": "0.1.0",
+ "version": settings.app_version,
"description": "Scientific Literature Lifecycle Management System",
"docs": "/docs",
}
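The lifespan enters `AsyncSqliteSaver.from_conn_string(...)` manually via `__aenter__`/`__aexit__` because the saver has to stay open across the `yield`. An `AsyncExitStack` expresses the same lifetime with guaranteed teardown; a sketch of the equivalent shape, with the fallback-to-MemorySaver handling elided:

    from contextlib import AsyncExitStack, asynccontextmanager

    from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver

    @asynccontextmanager
    async def lifespan(app):
        async with AsyncExitStack() as stack:
            saver = await stack.enter_async_context(
                AsyncSqliteSaver.from_conn_string(settings.pipeline_checkpoint_db)
            )
            set_checkpointer(saver)
            yield
            # the stack unwinds here, closing the saver even if a shutdown step raises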
diff --git a/backend/app/mcp_server.py b/backend/app/mcp_server.py
index 792b1ec..8e233b7 100644
--- a/backend/app/mcp_server.py
+++ b/backend/app/mcp_server.py
@@ -72,8 +72,11 @@ async def search_knowledge_base(query: str, kb_id: int, top_k: int = 5) -> str:
Args:
query: The search question or keywords
kb_id: Knowledge base ID (use list_knowledge_bases to find IDs)
- top_k: Number of result chunks to return (default 5)
+ top_k: Number of result chunks to return (default 5, max 50)
"""
+ if top_k < 1 or top_k > 50:
+ return "Error: top_k must be between 1 and 50."
+
from app.services.rag_service import RAGService
rag = RAGService()
@@ -195,6 +198,13 @@ async def add_paper_by_doi(doi: str, kb_id: int) -> str:
doi: The paper's DOI
kb_id: Target knowledge base ID
"""
+ from app.services.url_validator import validate_doi
+
+ try:
+ validate_doi(doi)
+ except ValueError as e:
+ return f"Error: {e}"
+
from sqlalchemy import select
async with get_session() as db:
@@ -289,6 +299,9 @@ async def get_paper_summary(paper_id: int, summary_type: str = "abstract") -> st
if not paper:
return f"Error: Paper {paper_id} not found."
+ if summary_type not in ("abstract", "llm"):
+ return f"Error: Unknown summary type '{summary_type}'. Use 'abstract' or 'llm'."
+
if summary_type == "abstract":
return f"""## Paper Summary
@@ -315,8 +328,11 @@ async def search_papers_by_keyword(query: str, sources: str = "", max_results: i
Args:
query: Search keywords
sources: Comma-separated data sources (semantic_scholar,openalex,arxiv,crossref). Empty = all.
- max_results: Maximum number of results (default 20)
+ max_results: Maximum number of results (default 20, max 100)
"""
+ if max_results < 1 or max_results > 100:
+ return "Error: max_results must be between 1 and 100."
+
from app.services.search_service import SearchService
source_list = [s.strip() for s in sources.split(",") if s.strip()] if sources else None
@@ -350,6 +366,118 @@ async def search_papers_by_keyword(query: str, sources: str = "", max_results: i
return "\n".join(lines)
+@mcp.tool()
+async def summarize_papers(kb_id: int, paper_ids: list[int] | None = None, language: str = "en") -> str:
+ """Summarize papers in a knowledge base.
+
+ Args:
+ kb_id: Knowledge base ID
+ paper_ids: Optional list of specific paper IDs to summarize. If empty, summarizes all.
+ language: Output language (en/zh)
+ """
+ from app.services.writing_service import WritingService
+
+ svc = WritingService()
+ result = await svc.summarize(project_id=kb_id, paper_ids=paper_ids, language=language)
+ return f"## Summary\n\n{result.get('content', 'No summary generated.')}"
+
+
+@mcp.tool()
+async def generate_review_outline(kb_id: int, topic: str, language: str = "en") -> str:
+ """Generate a literature review outline based on papers in a knowledge base.
+
+ Args:
+ kb_id: Knowledge base ID
+ topic: Research topic for the review
+ language: Output language (en/zh)
+ """
+ from app.services.writing_service import WritingService
+
+ svc = WritingService()
+ result = await svc.generate_review_outline(project_id=kb_id, topic=topic, language=language)
+ return f"## Review Outline\n\n{result.get('outline', 'No outline generated.')}"
+
+
+@mcp.tool()
+async def analyze_gaps(kb_id: int, research_topic: str) -> str:
+ """Analyze research gaps in the literature of a knowledge base.
+
+ Args:
+ kb_id: Knowledge base ID
+ research_topic: The research topic to analyze gaps for
+ """
+ from app.services.writing_service import WritingService
+
+ svc = WritingService()
+ result = await svc.analyze_gaps(project_id=kb_id, research_topic=research_topic)
+ return f"## Gap Analysis\n\n{result.get('analysis', 'No gap analysis generated.')}"
+
+
+@mcp.tool()
+async def manage_keywords(kb_id: int, action: str = "list", term: str = "", language: str = "en") -> str:
+ """Manage keywords for a knowledge base — list, add, expand, or delete.
+
+ Args:
+ kb_id: Knowledge base ID
+ action: One of: list, add, expand, delete
+ term: Keyword term (required for add/expand/delete)
+ language: Language for keyword expansion (en/zh)
+ """
+ if action not in ("list", "add", "expand", "delete"):
+ return "Error: action must be one of: list, add, expand, delete."
+
+ from sqlalchemy import select
+
+ from app.models.keyword import Keyword
+
+ if action == "list":
+ async with get_session() as db:
+ stmt = select(Keyword).where(Keyword.project_id == kb_id).order_by(Keyword.level, Keyword.term)
+ result = await db.execute(stmt)
+ keywords = result.scalars().all()
+ if not keywords:
+ return "No keywords found in this knowledge base."
+ lines = ["## Keywords\n", "| Term | EN | Level | Category |", "|---|---|---|---|"]
+ for kw in keywords:
+ lines.append(f"| {kw.term} | {kw.term_en} | {kw.level} | {kw.category} |")
+ return "\n".join(lines)
+
+ if not term:
+ return f"Error: 'term' is required for action '{action}'."
+
+ if action == "add":
+ async with get_session() as db:
+ kw = Keyword(project_id=kb_id, term=term, level=1)
+ db.add(kw)
+ await db.flush()
+ return f"Added keyword: {term}"
+
+ if action == "expand":
+ from app.services.keyword_service import KeywordService
+
+ svc = KeywordService()
+ result = await svc.expand_keywords([term], language=language)
+ expanded = result.get("expanded_terms", [])
+ if not expanded:
+ return "No expanded terms found."
+ lines = [f"## Expanded from: {term}\n"]
+ for et in expanded:
+ lines.append(f"- {et.get('term', '')} ({et.get('relation', '')})")
+ return "\n".join(lines)
+
+ if action == "delete":
+ async with get_session() as db:
+ stmt = select(Keyword).where(Keyword.project_id == kb_id, Keyword.term == term)
+ result = await db.execute(stmt)
+ kw = result.scalar_one_or_none()
+ if not kw:
+ return f"Keyword '{term}' not found."
+ await db.delete(kw)
+ return f"Deleted keyword: {term}"
+
+ return "Unknown action."
+
+
# ==================== RESOURCES ====================
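
For the new writing and keyword tools, a minimal client sketch using the official `mcp` Python SDK against the `/mcp` mount (URL and IDs are illustrative):

```python
# Minimal MCP client sketch (assumes the official `mcp` Python SDK; the /mcp
# path matches the streamable-http mount in main.py, host/port illustrative).
import asyncio

from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client


async def main() -> None:
    async with streamablehttp_client("http://localhost:8000/mcp") as (read, write, _):
        async with ClientSession(read, write) as session:
            await session.initialize()
            result = await session.call_tool(
                "manage_keywords",
                arguments={"kb_id": 1, "action": "list"},
            )
            print(result.content)


asyncio.run(main())
```
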
diff --git a/backend/app/middleware/auth.py b/backend/app/middleware/auth.py
index f98c9c7..9cf1a2a 100644
--- a/backend/app/middleware/auth.py
+++ b/backend/app/middleware/auth.py
@@ -10,7 +10,7 @@
logger = logging.getLogger(__name__)
-EXEMPT_PATHS = frozenset({"/", "/health", "/docs", "/openapi.json", "/redoc"})
+EXEMPT_PATHS = frozenset({"/", "/health", "/api/v1/settings/health", "/docs", "/openapi.json", "/redoc"})
EXEMPT_PREFIXES = ("/mcp",)
@@ -26,7 +26,7 @@ async def dispatch(self, request: Request, call_next) -> Response:
if path in EXEMPT_PATHS or any(path.startswith(p) for p in EXEMPT_PREFIXES):
return await call_next(request)
- api_key = request.headers.get("X-API-Key") or request.query_params.get("api_key")
+ api_key = request.headers.get("X-API-Key")
if api_key != settings.api_secret_key:
return JSONResponse(
status_code=401,
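
With the query-parameter fallback removed, keys no longer end up in access logs, referrers, or browser history; clients must send the header:

```python
# Header-only authentication after this change (endpoint path illustrative).
import httpx

resp = httpx.get(
    "http://localhost:8000/api/v1/papers",
    headers={"X-API-Key": "change-me"},  # must equal API_SECRET_KEY
)
print(resp.status_code)  # 401 without the header, even with ?api_key=...
```
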
diff --git a/backend/app/middleware/rate_limit.py b/backend/app/middleware/rate_limit.py
index 220a42a..c5086e4 100644
--- a/backend/app/middleware/rate_limit.py
+++ b/backend/app/middleware/rate_limit.py
@@ -8,12 +8,15 @@
from slowapi.middleware import SlowAPIMiddleware
from slowapi.util import get_remote_address
+from app.config import settings
+
logger = logging.getLogger(__name__)
limiter = Limiter(
key_func=get_remote_address,
- default_limits=["120/minute"],
+ default_limits=[settings.rate_limit],
storage_uri="memory://",
+ enabled=settings.app_env != "testing",
)
@@ -22,4 +25,4 @@ def setup_rate_limiting(app: FastAPI) -> None:
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
app.add_middleware(SlowAPIMiddleware)
- logger.info("Rate limiting enabled (default: 120/min)")
+ logger.info("Rate limiting enabled (default: %s)", settings.rate_limit)
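
This assumes `settings.rate_limit` and `settings.app_env` exist on the config object; a sketch of the expected shape (field names inferred from the code above, defaults illustrative):

```python
# Assumed settings fields (pydantic-settings style). slowapi accepts limit
# strings such as "120/minute", "5/second", or "1000/hour".
from pydantic_settings import BaseSettings


class Settings(BaseSettings):
    rate_limit: str = "120/minute"
    app_env: str = "development"  # the limiter is disabled when set to "testing"
```
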
diff --git a/backend/app/models/keyword.py b/backend/app/models/keyword.py
index c92a437..d21b3fc 100644
--- a/backend/app/models/keyword.py
+++ b/backend/app/models/keyword.py
@@ -17,7 +17,7 @@ class Keyword(Base):
term_en: Mapped[str] = mapped_column(String(500), default="")
level: Mapped[int] = mapped_column(Integer, default=1) # 1=core, 2=sub-domain, 3=expanded
category: Mapped[str] = mapped_column(String(100), default="")
- parent_id: Mapped[int | None] = mapped_column(Integer, ForeignKey("keywords.id"), default=None)
+ parent_id: Mapped[int | None] = mapped_column(Integer, ForeignKey("keywords.id"), default=None, index=True)
synonyms: Mapped[str] = mapped_column(Text, default="")
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
diff --git a/backend/app/models/paper.py b/backend/app/models/paper.py
index 23e5e5d..55aa28b 100644
--- a/backend/app/models/paper.py
+++ b/backend/app/models/paper.py
@@ -3,7 +3,7 @@
from datetime import datetime
from enum import StrEnum
-from sqlalchemy import JSON, DateTime, ForeignKey, Integer, String, Text, func
+from sqlalchemy import JSON, DateTime, ForeignKey, Index, Integer, String, Text, UniqueConstraint, func
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database import Base
@@ -20,6 +20,10 @@ class PaperStatus(StrEnum):
class Paper(Base):
__tablename__ = "papers"
+ __table_args__ = (
+ Index("ix_paper_project_status", "project_id", "status"),
+ UniqueConstraint("project_id", "doi", name="uq_paper_project_doi"),
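+        # Note: rows with a NULL doi bypass this constraint (SQL treats NULLs as distinct).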
+ )
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
project_id: Mapped[int] = mapped_column(Integer, ForeignKey("projects.id"), nullable=False, index=True)
diff --git a/backend/app/models/task.py b/backend/app/models/task.py
index 0a42012..dbe1ce5 100644
--- a/backend/app/models/task.py
+++ b/backend/app/models/task.py
@@ -3,7 +3,7 @@
from datetime import datetime
from enum import StrEnum
-from sqlalchemy import JSON, DateTime, ForeignKey, Integer, String, Text, func
+from sqlalchemy import JSON, DateTime, ForeignKey, Index, Integer, String, Text, func
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database import Base
@@ -28,6 +28,7 @@ class TaskType(StrEnum):
class Task(Base):
__tablename__ = "tasks"
+ __table_args__ = (Index("ix_task_project_status", "project_id", "status"),)
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
project_id: Mapped[int] = mapped_column(Integer, ForeignKey("projects.id"), nullable=False, index=True)
diff --git a/backend/app/pipelines/cancellation.py b/backend/app/pipelines/cancellation.py
new file mode 100644
index 0000000..2b362c4
--- /dev/null
+++ b/backend/app/pipelines/cancellation.py
@@ -0,0 +1,18 @@
+"""Shared cancellation state for pipeline execution."""
+
+_cancelled: dict[str, bool] = {}
+
+
+def is_cancelled(thread_id: str) -> bool:
+ """Check if pipeline has been cancelled via the API."""
+ return _cancelled.get(thread_id, False)
+
+
+def mark_cancelled(thread_id: str) -> None:
+ """Mark a pipeline as cancelled."""
+ _cancelled[thread_id] = True
+
+
+def clear_cancelled(thread_id: str) -> None:
+ """Clear cancellation flag for a pipeline (call when pipeline ends)."""
+ _cancelled.pop(thread_id, None)
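
Intended flow, as a sketch (route path and runner are hypothetical): an API handler marks the thread, nodes poll the flag, and the runner always clears it when the run ends:

```python
# Cancellation flow sketch (route path and runner are hypothetical).
from fastapi import FastAPI

from app.pipelines.cancellation import clear_cancelled, is_cancelled, mark_cancelled

app = FastAPI()


@app.post("/api/v1/pipelines/{thread_id}/cancel")
async def cancel_pipeline(thread_id: str) -> dict:
    mark_cancelled(thread_id)
    return {"status": "cancelling"}


async def run_pipeline(thread_id: str, steps: list) -> None:
    try:
        for step in steps:
            if is_cancelled(thread_id):  # each node checks before doing work
                break
            await step(thread_id)
    finally:
        clear_cancelled(thread_id)  # always drop the flag when the run ends
```
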
diff --git a/backend/app/pipelines/chat/nodes.py b/backend/app/pipelines/chat/nodes.py
index d4ad3e7..f53f6fe 100644
--- a/backend/app/pipelines/chat/nodes.py
+++ b/backend/app/pipelines/chat/nodes.py
@@ -15,45 +15,23 @@
from langchain_core.runnables import RunnableConfig
from langgraph.config import get_stream_writer
+from app.config import settings
from app.pipelines.chat.config_helpers import (
get_chat_db,
get_chat_llm,
get_chat_rag,
)
from app.pipelines.chat.state import ChatMessageDict, ChatState, CitationDict
+from app.prompts.chat import (
+ CHAT_FALLBACK_SYSTEM,
+ CHAT_QA_SYSTEM,
+ CHAT_TOOL_MODE_PROMPTS,
+ EXCERPT_CLEAN_SYSTEM,
+)
logger = logging.getLogger(__name__)
-TOOL_MODE_PROMPTS = {
- "qa": (
- "You are a scientific research assistant. Answer the question based on the provided context. "
- "Use inline citations like [1], [2] to reference source papers. "
- "If the context doesn't contain enough information, say so honestly."
- ),
- "citation_lookup": (
- "You are a citation finder. Given the user's text, identify and list the most relevant "
- "references from the provided context. Format as a numbered list with paper titles, authors, "
- "and brief explanations of relevance. Keep your own commentary minimal."
- ),
- "review_outline": (
- "You are a literature review expert. Based on the provided context, generate a structured "
- "review outline with sections, subsections, and key points. Use citations like [1], [2] "
- "to reference sources. Suggest a logical flow and highlight key themes."
- ),
- "gap_analysis": (
- "You are a research gap analyst. Based on the provided literature context, identify "
- "research gaps, unexplored areas, and potential future directions. Cite existing work "
- "using [1], [2] format. Be specific about what has been studied and what remains open."
- ),
-}
-
-EXCERPT_CLEAN_PROMPT = (
- "Clean up the following text extracted from an academic PDF. "
- "Fix OCR errors, add missing spaces between words, restore formatting. "
- "Keep the original meaning intact. Output only the cleaned text, nothing else."
-)
-
-_clean_semaphore = asyncio.Semaphore(3)
+_clean_semaphore = asyncio.Semaphore(settings.clean_semaphore_limit)
def _emit_thinking(
@@ -114,14 +92,7 @@ async def understand_node(state: ChatState, config: RunnableConfig) -> dict[str,
# Build system prompt
kb_ids = state.get("knowledge_base_ids", [])
tool_mode = state.get("tool_mode", "qa")
- if kb_ids:
- system_prompt = TOOL_MODE_PROMPTS.get(tool_mode, TOOL_MODE_PROMPTS["qa"])
- else:
- system_prompt = (
- "You are a helpful scientific research assistant. "
- "Answer questions clearly and accurately. "
- "If you don't know the answer, say so honestly."
- )
+ system_prompt = CHAT_TOOL_MODE_PROMPTS.get(tool_mode, CHAT_QA_SYSTEM) if kb_ids else CHAT_FALLBACK_SYSTEM
_emit_thinking(
writer,
@@ -158,7 +129,11 @@ async def retrieve_node(state: ChatState, config: RunnableConfig) -> dict[str, A
)
top_k = state.get("rag_top_k") or 10
- tasks = [rag.retrieve_only(project_id=kb_id, question=state["message"], top_k=top_k) for kb_id in kb_ids]
+ use_reranker = state.get("use_reranker", False)
+ tasks = [
+ rag.retrieve_only(project_id=kb_id, question=state["message"], top_k=top_k, use_reranker=use_reranker)
+ for kb_id in kb_ids
+ ]
results = await asyncio.gather(*tasks, return_exceptions=True)
all_sources: list[dict[str, Any]] = []
@@ -258,7 +233,7 @@ async def _clean_single_excerpt(llm, excerpt: str) -> str:
return excerpt
async with _clean_semaphore:
messages = [
- {"role": "system", "content": EXCERPT_CLEAN_PROMPT},
+ {"role": "system", "content": EXCERPT_CLEAN_SYSTEM},
{"role": "user", "content": excerpt},
]
result = ""
@@ -427,7 +402,7 @@ async def persist_node(state: ChatState, config: RunnableConfig) -> dict[str, An
)
db.add(user_msg)
db.add(assistant_msg)
- await db.commit()
+ await db.flush()
citation_count = len(state.get("citations") or [])
_emit_thinking(
diff --git a/backend/app/pipelines/graphs.py b/backend/app/pipelines/graphs.py
index 2ef4b11..f3e4497 100644
--- a/backend/app/pipelines/graphs.py
+++ b/backend/app/pipelines/graphs.py
@@ -4,6 +4,7 @@
import logging
+from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, StateGraph
@@ -23,16 +24,22 @@
logger = logging.getLogger(__name__)
_memory_saver = MemorySaver()
+_checkpoint_saver: BaseCheckpointSaver | None = None
+
+
+def set_checkpointer(saver: BaseCheckpointSaver | None) -> None:
+ """Set the persistent checkpointer (called from lifespan). None restores MemorySaver fallback."""
+ global _checkpoint_saver
+ _checkpoint_saver = saver
def _get_checkpointer():
"""Return a checkpointer for pipeline state persistence.
- AsyncSqliteSaver.from_conn_string returns an async context manager,
- not a direct BaseCheckpointSaver instance. Use MemorySaver which is
- sufficient for single-process deployments with in-memory task tracking.
+ Uses AsyncSqliteSaver when available (set via set_checkpointer in lifespan),
+ otherwise falls back to MemorySaver for single-process in-memory persistence.
"""
- return _memory_saver
+ return _checkpoint_saver if _checkpoint_saver is not None else _memory_saver
def _route_after_dedup(state: PipelineState) -> str:
diff --git a/backend/app/pipelines/nodes.py b/backend/app/pipelines/nodes.py
index dea9545..6cffde4 100644
--- a/backend/app/pipelines/nodes.py
+++ b/backend/app/pipelines/nodes.py
@@ -12,8 +12,19 @@
logger = logging.getLogger(__name__)
+def _is_cancelled(state: PipelineState) -> bool:
+ """Check if pipeline has been cancelled via the API."""
+ from app.pipelines.cancellation import is_cancelled
+
+ thread_id = state.get("thread_id", "")
+ return is_cancelled(thread_id) or state.get("cancelled", False)
+
+
async def search_node(state: PipelineState) -> dict[str, Any]:
"""Run multi-source federated search."""
+ if _is_cancelled(state):
+ return {"stage": "cancelled", "cancelled": True}
+
from app.services.search_service import SearchService
params = state.get("params", {})
@@ -34,6 +45,9 @@ async def search_node(state: PipelineState) -> dict[str, Any]:
async def extract_metadata_node(state: PipelineState) -> dict[str, Any]:
"""Extract metadata from uploaded PDF files."""
+ if _is_cancelled(state):
+ return {"stage": "cancelled", "cancelled": True}
+
from app.services.pdf_metadata import extract_metadata
params = state.get("params", {})
@@ -57,6 +71,9 @@ async def extract_metadata_node(state: PipelineState) -> dict[str, Any]:
async def dedup_node(state: PipelineState) -> dict[str, Any]:
"""Check for duplicates against existing papers in the knowledge base."""
+ if _is_cancelled(state):
+ return {"stage": "cancelled", "cancelled": True}
+
from sqlalchemy import select
from app.config import settings
@@ -149,20 +166,29 @@ async def hitl_dedup_node(state: PipelineState) -> dict[str, Any]:
async def apply_resolution_node(state: PipelineState) -> dict[str, Any]:
"""Apply conflict resolutions and merge clean papers for import."""
+ if _is_cancelled(state):
+ return {"stage": "cancelled", "cancelled": True}
+
resolved = state.get("resolved_conflicts", [])
clean_papers = list(state.get("papers", []))
for res in resolved:
action = res.get("action", "skip")
- new_paper = res.get("new_paper", {})
+ new_paper = res.get("new_paper") or {}
+ merged_paper = res.get("merged_paper") or {}
if action == "keep_new" and new_paper:
clean_papers.append(new_paper)
+ elif action == "merge" and merged_paper:
+ clean_papers.append(merged_paper)
return {"papers": clean_papers, "stage": "resolved"}
async def import_node(state: PipelineState) -> dict[str, Any]:
"""Import clean papers into the database."""
+ if _is_cancelled(state):
+ return {"stage": "cancelled", "cancelled": True}
+
from app.database import async_session_factory
from app.models import Paper
@@ -193,6 +219,9 @@ async def import_node(state: PipelineState) -> dict[str, Any]:
async def crawl_node(state: PipelineState) -> dict[str, Any]:
"""Download PDFs for papers that have pdf_url but no pdf_path."""
+ if _is_cancelled(state):
+ return {"stage": "cancelled", "cancelled": True}
+
from sqlalchemy import select
from app.database import async_session_factory
@@ -241,6 +270,9 @@ async def ocr_node(state: PipelineState) -> dict[str, Any]:
Uses MinerU (if available) for deep parsing with formula/table/figure
recognition, falling back to pdfplumber + PaddleOCR.
"""
+ if _is_cancelled(state):
+ return {"stage": "cancelled", "cancelled": True}
+
from sqlalchemy import select
from app.database import async_session_factory
@@ -258,43 +290,43 @@ async def ocr_node(state: PipelineState) -> dict[str, Any]:
Paper.pdf_path != "",
)
papers = (await db.execute(stmt)).scalars().all()
- ocr = OCRService(use_gpu=True)
- for paper in papers:
- if state.get("cancelled"):
- break
- try:
- result = await ocr.process_pdf_async(paper.pdf_path)
- if result.get("error"):
- paper.status = PaperStatus.ERROR
- continue
+ with OCRService(use_gpu=True) as ocr:
+ for paper in papers:
+            if _is_cancelled(state):
+ break
+ try:
+ result = await ocr.process_pdf_async(paper.pdf_path)
+ if result.get("error"):
+ paper.status = PaperStatus.ERROR
+ continue
- if result.get("method") == "mineru":
- chunks = ocr.chunk_mineru_markdown(result["md_content"], chunk_size=1024, overlap=100)
- else:
- pages = result.get("pages", [])
- chunks = ocr.chunk_text(pages, chunk_size=1024, overlap=100)
-
- for chunk_data in chunks:
- db.add(
- PaperChunk(
- paper_id=paper.id,
- content=chunk_data["content"],
- page_number=chunk_data.get("page_number", 0),
- chunk_index=chunk_data["chunk_index"],
- chunk_type=chunk_data.get("chunk_type", "text"),
- section=chunk_data.get("section", ""),
- token_count=chunk_data.get("token_count", 0),
- has_formula=chunk_data.get("has_formula", False),
- figure_path=chunk_data.get("figure_path", ""),
+ if result.get("method") == "mineru":
+ chunks = ocr.chunk_mineru_markdown(result["md_content"], chunk_size=1024, overlap=100)
+ else:
+ pages = result.get("pages", [])
+ chunks = ocr.chunk_text(pages, chunk_size=1024, overlap=100)
+
+ for chunk_data in chunks:
+ db.add(
+ PaperChunk(
+ paper_id=paper.id,
+ content=chunk_data["content"],
+ page_number=chunk_data.get("page_number", 0),
+ chunk_index=chunk_data["chunk_index"],
+ chunk_type=chunk_data.get("chunk_type", "text"),
+ section=chunk_data.get("section", ""),
+ token_count=chunk_data.get("token_count", 0),
+ has_formula=chunk_data.get("has_formula", False),
+ figure_path=chunk_data.get("figure_path", ""),
+ )
)
- )
- paper.status = PaperStatus.OCR_COMPLETE
- processed += 1
- except Exception as e:
- logger.warning("OCR failed for paper %d: %s", paper.id, e)
- paper.status = PaperStatus.ERROR
+ paper.status = PaperStatus.OCR_COMPLETE
+ processed += 1
+ except Exception as e:
+ logger.warning("OCR failed for paper %d: %s", paper.id, e)
+ paper.status = PaperStatus.ERROR
await db.commit()
return {
@@ -306,6 +338,9 @@ async def ocr_node(state: PipelineState) -> dict[str, Any]:
async def index_node(state: PipelineState) -> dict[str, Any]:
"""Index OCR-processed papers into the RAG vector store."""
+ if _is_cancelled(state):
+ return {"stage": "cancelled", "cancelled": True}
+
from sqlalchemy import select
from app.database import async_session_factory
diff --git a/backend/app/prompts/__init__.py b/backend/app/prompts/__init__.py
new file mode 100644
index 0000000..741b798
--- /dev/null
+++ b/backend/app/prompts/__init__.py
@@ -0,0 +1,52 @@
+"""Centralized LLM prompt management for all Omelette backend services."""
+
+from app.prompts.chat import (
+ CHAT_CITATION_SYSTEM,
+ CHAT_FALLBACK_SYSTEM,
+ CHAT_GAP_SYSTEM,
+ CHAT_OUTLINE_SYSTEM,
+ CHAT_QA_SYSTEM,
+ CHAT_TOOL_MODE_PROMPTS,
+ EXCERPT_CLEAN_SYSTEM,
+)
+from app.prompts.completion import COMPLETION_SYSTEM
+from app.prompts.dedup import DEDUP_RESOLVE_SYSTEM, DEDUP_VERIFY_SYSTEM
+from app.prompts.keyword import KEYWORD_EXPAND_SYSTEM
+from app.prompts.rag import RAG_ANSWER_SYSTEM
+from app.prompts.rewrite import (
+ REWRITE_ACADEMIC,
+ REWRITE_PROMPTS,
+ REWRITE_SIMPLIFY,
+ REWRITE_TRANSLATE_EN,
+ REWRITE_TRANSLATE_ZH,
+)
+from app.prompts.writing import (
+ WRITING_GAP_SYSTEM,
+ WRITING_OUTLINE_SYSTEM,
+ WRITING_SECTION_SYSTEM,
+ WRITING_SUMMARIZE_SYSTEM,
+)
+
+__all__ = [
+ "CHAT_CITATION_SYSTEM",
+ "CHAT_FALLBACK_SYSTEM",
+ "CHAT_GAP_SYSTEM",
+ "CHAT_OUTLINE_SYSTEM",
+ "CHAT_QA_SYSTEM",
+ "CHAT_TOOL_MODE_PROMPTS",
+ "COMPLETION_SYSTEM",
+ "DEDUP_RESOLVE_SYSTEM",
+ "DEDUP_VERIFY_SYSTEM",
+ "EXCERPT_CLEAN_SYSTEM",
+ "KEYWORD_EXPAND_SYSTEM",
+ "RAG_ANSWER_SYSTEM",
+ "REWRITE_ACADEMIC",
+ "REWRITE_PROMPTS",
+ "REWRITE_SIMPLIFY",
+ "REWRITE_TRANSLATE_EN",
+ "REWRITE_TRANSLATE_ZH",
+ "WRITING_GAP_SYSTEM",
+ "WRITING_OUTLINE_SYSTEM",
+ "WRITING_SECTION_SYSTEM",
+ "WRITING_SUMMARIZE_SYSTEM",
+]
diff --git a/backend/app/prompts/chat.py b/backend/app/prompts/chat.py
new file mode 100644
index 0000000..ee31681
--- /dev/null
+++ b/backend/app/prompts/chat.py
@@ -0,0 +1,49 @@
+"""Chat pipeline system prompts."""
+
+CHAT_QA_SYSTEM = (
+ "You are a scientific research assistant. Answer the question based on the provided context. "
+ "Use inline citations like [1], [2] to reference source papers. "
+ "If the context doesn't contain enough information, say so honestly. "
+ "Structure your answer with clear paragraphs. "
+ "Respond in the same language as the user's question."
+)
+
+CHAT_CITATION_SYSTEM = (
+ "You are a citation finder. Given the user's text, identify and list the most relevant "
+ "references from the provided context. Format as a numbered list with paper titles, authors, "
+ "and brief explanations of relevance. Include DOI when available. "
+ "Keep your own commentary minimal."
+)
+
+CHAT_OUTLINE_SYSTEM = (
+ "You are a literature review expert. Based on the provided context, generate a structured "
+ "review outline with sections, subsections, and key points. Use markdown headers for sections. "
+ "Use citations like [1], [2] to reference sources. Suggest a logical flow and highlight key themes."
+)
+
+CHAT_GAP_SYSTEM = (
+ "You are a research gap analyst. Based on the provided literature context, identify "
+ "research gaps, unexplored areas, and potential future directions. Cite existing work "
+ "using [1], [2] format. Organize by theme, not by individual papers. "
+ "Be specific about what has been studied and what remains open."
+)
+
+CHAT_FALLBACK_SYSTEM = (
+ "You are a scientific research assistant specializing in academic literature analysis. "
+ "Answer questions clearly and accurately based on your knowledge. "
+ "When the user's question is outside your expertise or you are uncertain, say so honestly. "
+ "Respond in the same language as the user's question."
+)
+
+EXCERPT_CLEAN_SYSTEM = (
+ "Clean up the following text extracted from an academic PDF. "
+ "Fix OCR errors, add missing spaces between words, restore formatting. "
+ "Keep the original meaning intact. Output only the cleaned text, nothing else."
+)
+
+CHAT_TOOL_MODE_PROMPTS: dict[str, str] = {
+ "qa": CHAT_QA_SYSTEM,
+ "citation_lookup": CHAT_CITATION_SYSTEM,
+ "review_outline": CHAT_OUTLINE_SYSTEM,
+ "gap_analysis": CHAT_GAP_SYSTEM,
+}
diff --git a/backend/app/prompts/completion.py b/backend/app/prompts/completion.py
new file mode 100644
index 0000000..bef1a5e
--- /dev/null
+++ b/backend/app/prompts/completion.py
@@ -0,0 +1,8 @@
+"""Writing completion system prompts."""
+
+COMPLETION_SYSTEM = (
+ "You are a scientific writing assistant. Predict and complete the user's text. "
+ "Return only the completion (do not repeat the user's input), max 50 characters. "
+ "If you cannot reasonably predict, return an empty string. "
+ "Return plain text only — no quotes, explanations, or formatting."
+)
diff --git a/backend/app/prompts/dedup.py b/backend/app/prompts/dedup.py
new file mode 100644
index 0000000..e23ee3c
--- /dev/null
+++ b/backend/app/prompts/dedup.py
@@ -0,0 +1,13 @@
+"""Deduplication system prompts."""
+
+DEDUP_VERIFY_SYSTEM = (
+ "You are a scientific literature deduplication expert. "
+ "Compare papers carefully based on title, authors, DOI, and journal. "
+ "Return valid JSON only."
+)
+
+DEDUP_RESOLVE_SYSTEM = (
+ "You are a scientific literature deduplication expert. "
+ "Determine the best resolution for duplicate candidates. "
+ "Return valid JSON only."
+)
diff --git a/backend/app/prompts/keyword.py b/backend/app/prompts/keyword.py
new file mode 100644
index 0000000..ff84af1
--- /dev/null
+++ b/backend/app/prompts/keyword.py
@@ -0,0 +1,8 @@
+"""Keyword expansion system prompts."""
+
+KEYWORD_EXPAND_SYSTEM = (
+ "You are a scientific terminology expert. "
+ "Generate related terms including synonyms, abbreviations, technical variants, "
+ "and cross-disciplinary application terms. "
+ "Return valid JSON only."
+)
diff --git a/backend/app/prompts/rag.py b/backend/app/prompts/rag.py
new file mode 100644
index 0000000..006ff07
--- /dev/null
+++ b/backend/app/prompts/rag.py
@@ -0,0 +1,8 @@
+"""RAG knowledge base system prompts."""
+
+RAG_ANSWER_SYSTEM = (
+ "You are a scientific research assistant. "
+ "Answer questions based strictly on the provided context. "
+ "Cite sources accurately using the format provided. "
+ "Respond in the same language as the user's question."
+)
diff --git a/backend/app/prompts/rewrite.py b/backend/app/prompts/rewrite.py
new file mode 100644
index 0000000..42b13a1
--- /dev/null
+++ b/backend/app/prompts/rewrite.py
@@ -0,0 +1,32 @@
+"""Text rewrite and translation system prompts."""
+
+REWRITE_SIMPLIFY = (
+ "Rewrite the following academic text in plain, accessible language. "
+ "Keep the core meaning and key concepts intact, but make it understandable "
+ "to a general audience. Output only the rewritten text, no explanations."
+)
+
+REWRITE_ACADEMIC = (
+ "Rewrite the following text in formal academic style. "
+ "Use precise terminology, passive voice where appropriate, and proper "
+ "academic conventions. Maintain the original meaning. Output only the rewritten text."
+)
+
+REWRITE_TRANSLATE_EN = (
+ "Translate the following text into English. "
+ "Preserve academic terminology and the original meaning. "
+ "Output only the translation, no explanations."
+)
+
+REWRITE_TRANSLATE_ZH = (
+ "Translate the following text into Chinese. "
+ "Preserve academic terminology and the original meaning. "
+ "Output only the translation, no explanations."
+)
+
+REWRITE_PROMPTS: dict[str, str] = {
+ "simplify": REWRITE_SIMPLIFY,
+ "academic": REWRITE_ACADEMIC,
+ "translate_en": REWRITE_TRANSLATE_EN,
+ "translate_zh": REWRITE_TRANSLATE_ZH,
+}
diff --git a/backend/app/prompts/writing.py b/backend/app/prompts/writing.py
new file mode 100644
index 0000000..e8babee
--- /dev/null
+++ b/backend/app/prompts/writing.py
@@ -0,0 +1,26 @@
+"""Writing assistant system prompts."""
+
+WRITING_SECTION_SYSTEM = (
+ "You are an academic review writing expert. Write a review paragraph for the given section. "
+ "Requirements: "
+ "1. Use academic language with clear logic. "
+ "2. Use [1][2] format for citations at appropriate positions. "
+ "3. Every citation must correspond to a provided reference — do not fabricate. "
+ "4. Paragraph length: 200-400 words."
+)
+
+WRITING_SUMMARIZE_SYSTEM = (
+ "You are a scientific paper analyst. Provide structured, accurate summaries. "
+ "Focus on empirical findings and methodology. "
+ "Do not hallucinate information not present in the provided metadata."
+)
+
+WRITING_OUTLINE_SYSTEM = (
+ "You are a scientific writing expert. Generate well-structured review outlines "
+ "organized by research themes with clear section hierarchy."
+)
+
+WRITING_GAP_SYSTEM = (
+ "You are a research gap analyst. Identify unexplored areas and innovation opportunities "
+ "based on the provided literature."
+)
diff --git a/backend/app/schemas/__init__.py b/backend/app/schemas/__init__.py
index 5dc46e7..dd54771 100644
--- a/backend/app/schemas/__init__.py
+++ b/backend/app/schemas/__init__.py
@@ -1,6 +1,7 @@
"""Pydantic schemas for API request/response validation."""
-from app.schemas.common import ApiResponse, PaginatedData, PaginationParams, TaskResponse
+from app.schemas.common import ApiResponse, KeywordPaginationParams, PaginatedData, PaginationParams, TaskResponse
+from app.schemas.conversation import ChatStreamRequest, ConversationCreateSchema, ConversationUpdateSchema
from app.schemas.keyword import (
KeywordCreate,
KeywordExpandRequest,
@@ -8,11 +9,14 @@
KeywordRead,
KeywordUpdate,
)
-from app.schemas.paper import PaperBulkImport, PaperCreate, PaperRead, PaperUpdate
+from app.schemas.llm import LLMConfig, ProviderModelInfo, SettingsSchema, SettingsUpdateSchema
+from app.schemas.paper import PaperBatchDeleteRequest, PaperBulkImport, PaperCreate, PaperRead, PaperUpdate
from app.schemas.project import ProjectCreate, ProjectRead, ProjectUpdate
+from app.schemas.subscription import SubscriptionCreate, SubscriptionRead, SubscriptionUpdate
__all__ = [
"ApiResponse",
+ "KeywordPaginationParams",
"PaginatedData",
"PaginationParams",
"TaskResponse",
@@ -23,9 +27,20 @@
"PaperRead",
"PaperUpdate",
"PaperBulkImport",
+ "PaperBatchDeleteRequest",
"KeywordCreate",
"KeywordRead",
"KeywordUpdate",
"KeywordExpandRequest",
"KeywordExpandResponse",
+ "ConversationCreateSchema",
+ "ConversationUpdateSchema",
+ "ChatStreamRequest",
+ "SubscriptionCreate",
+ "SubscriptionRead",
+ "SubscriptionUpdate",
+ "LLMConfig",
+ "ProviderModelInfo",
+ "SettingsSchema",
+ "SettingsUpdateSchema",
]
diff --git a/backend/app/schemas/chunk.py b/backend/app/schemas/chunk.py
new file mode 100644
index 0000000..a5105ef
--- /dev/null
+++ b/backend/app/schemas/chunk.py
@@ -0,0 +1,21 @@
+"""Pydantic schemas for PaperChunk."""
+
+from datetime import datetime
+
+from pydantic import BaseModel
+
+
+class ChunkRead(BaseModel):
+ id: int
+ paper_id: int
+ chunk_type: str
+ content: str
+ section: str
+ page_number: int | None
+ chunk_index: int
+ token_count: int
+ has_formula: bool
+ figure_path: str
+ created_at: datetime
+
+ model_config = {"from_attributes": True}
diff --git a/backend/app/schemas/common.py b/backend/app/schemas/common.py
index 1fe6986..01c3fa5 100644
--- a/backend/app/schemas/common.py
+++ b/backend/app/schemas/common.py
@@ -3,11 +3,36 @@
from datetime import UTC, datetime
from typing import Generic, TypeVar
+from fastapi import Query
from pydantic import BaseModel, Field
T = TypeVar("T")
+class PaginationParams:
+ """FastAPI dependency for pagination (page, page_size)."""
+
+ def __init__(
+ self,
+        page: int = Query(1, ge=1, description="Page number"),
+        page_size: int = Query(20, ge=1, le=100, description="Items per page"),
+ ):
+ self.page = page
+ self.page_size = page_size
+
+
+class KeywordPaginationParams(PaginationParams):
+ """Pagination for keywords (page_size default 50 for backward compatibility)."""
+
+ def __init__(
+ self,
+        page: int = Query(1, ge=1, description="Page number"),
+        page_size: int = Query(50, ge=1, le=100, description="Items per page"),
+ ):
+ self.page = page
+ self.page_size = page_size
+
+
class ApiResponse(BaseModel, Generic[T]):
code: int = 200
message: str = "success"
@@ -23,11 +48,6 @@ class PaginatedData(BaseModel, Generic[T]):
total_pages: int = 1
-class PaginationParams(BaseModel):
- page: int = Field(default=1, ge=1)
- page_size: int = Field(default=20, ge=1, le=100)
-
-
class TaskResponse(BaseModel):
task_id: int
status: str
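
Moving PaginationParams from a BaseModel to a plain class with Query defaults lets it act as a FastAPI dependency, so page/page_size arrive as validated query parameters:

```python
# Usage sketch: FastAPI resolves plain classes via Depends() (router and path
# are illustrative).
from fastapi import APIRouter, Depends

from app.schemas.common import PaginationParams

router = APIRouter()


@router.get("/papers")
async def list_papers(p: PaginationParams = Depends()) -> dict:
    return {"page": p.page, "page_size": p.page_size}  # ge/le enforced by Query
```
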
diff --git a/backend/app/schemas/conversation.py b/backend/app/schemas/conversation.py
index 587a016..2c8e429 100644
--- a/backend/app/schemas/conversation.py
+++ b/backend/app/schemas/conversation.py
@@ -1,6 +1,7 @@
"""Schemas for conversations and messages."""
from datetime import datetime
+from typing import Literal
from pydantic import BaseModel, Field
@@ -44,21 +45,23 @@ class ConversationListSchema(BaseModel):
class ConversationCreateSchema(BaseModel):
- title: str = ""
+ title: str = Field(default="", max_length=500)
knowledge_base_ids: list[int] | None = None
model: str = ""
- tool_mode: str = "qa"
+ tool_mode: Literal["qa", "citation_lookup", "review_outline", "gap_analysis"] = "qa"
class ConversationUpdateSchema(BaseModel):
- title: str | None = None
+ title: str | None = Field(default=None, max_length=500)
model: str | None = None
- tool_mode: str | None = None
+ tool_mode: Literal["qa", "citation_lookup", "review_outline", "gap_analysis"] | None = None
class ChatStreamRequest(BaseModel):
conversation_id: int | None = None
- knowledge_base_ids: list[int] = Field(default_factory=list)
+ knowledge_base_ids: list[int] = Field(default_factory=list, max_length=20)
model: str | None = None
- tool_mode: str = "qa"
+ tool_mode: Literal["qa", "citation_lookup", "review_outline", "gap_analysis"] = "qa"
message: str = Field(min_length=1)
+ rag_top_k: int = Field(default=10, ge=1, le=50, description="RAG retrieval top-k")
+ use_reranker: bool = Field(default=False, description="Apply reranker to retrieved nodes")
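
An example request body exercising the two new fields (the stream endpoint path is outside this diff):

```python
# Example ChatStreamRequest payload; rag_top_k is clamped to 1-50 and
# use_reranker is forwarded to rag.retrieve_only() in the retrieve node.
payload = {
    "conversation_id": None,
    "knowledge_base_ids": [1, 2],  # at most 20 IDs
    "tool_mode": "qa",  # one of the four Literal values
    "message": "Which papers compare neural rerankers?",
    "rag_top_k": 10,
    "use_reranker": True,
}
```
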
diff --git a/backend/app/schemas/keyword.py b/backend/app/schemas/keyword.py
index 577e840..3e49dd8 100644
--- a/backend/app/schemas/keyword.py
+++ b/backend/app/schemas/keyword.py
@@ -40,9 +40,9 @@ class KeywordRead(BaseModel):
class KeywordExpandRequest(BaseModel):
- seed_terms: list[str]
- language: str = "en"
- max_results: int = 20
+ seed_terms: list[str] = Field(..., max_length=50)
+ language: str = Field(default="en", max_length=10)
+ max_results: int = Field(default=20, ge=1, le=100)
class KeywordExpandResponse(BaseModel):
diff --git a/backend/app/schemas/knowledge_base.py b/backend/app/schemas/knowledge_base.py
index 75ce2c6..f73c716 100644
--- a/backend/app/schemas/knowledge_base.py
+++ b/backend/app/schemas/knowledge_base.py
@@ -1,13 +1,15 @@
"""Pydantic schemas for knowledge base and PDF upload operations."""
-from pydantic import BaseModel
+from typing import Literal
+
+from pydantic import BaseModel, Field
from app.schemas.paper import PaperRead
class NewPaperData(BaseModel):
- title: str
- abstract: str = ""
+ title: str = Field(..., max_length=2000)
+ abstract: str = Field(default="", max_length=50000)
authors: list[dict[str, str]] | None = None
doi: str | None = None
year: int | None = None
@@ -32,7 +34,7 @@ class UploadResult(BaseModel):
class ResolveConflictRequest(BaseModel):
conflict_id: str
- action: str # "keep_old" | "keep_new" | "merge" | "skip"
+ action: Literal["keep_old", "keep_new", "merge", "skip"]
merged_paper: dict | None = None
diff --git a/backend/app/schemas/llm.py b/backend/app/schemas/llm.py
index 4b21d7b..01dd4d9 100644
--- a/backend/app/schemas/llm.py
+++ b/backend/app/schemas/llm.py
@@ -66,19 +66,19 @@ class SettingsUpdateSchema(BaseModel):
llm_temperature: float | None = Field(default=None, ge=0.0, le=2.0)
llm_max_tokens: int | None = Field(default=None, ge=1, le=128000)
- openai_api_key: str | None = None
- openai_model: str | None = None
+ openai_api_key: str | None = Field(default=None, max_length=500)
+ openai_model: str | None = Field(default=None, max_length=200)
- anthropic_api_key: str | None = None
- anthropic_model: str | None = None
+ anthropic_api_key: str | None = Field(default=None, max_length=500)
+ anthropic_model: str | None = Field(default=None, max_length=200)
- aliyun_api_key: str | None = None
- aliyun_base_url: str | None = None
- aliyun_model: str | None = None
+ aliyun_api_key: str | None = Field(default=None, max_length=500)
+ aliyun_base_url: str | None = Field(default=None, max_length=500)
+ aliyun_model: str | None = Field(default=None, max_length=200)
- volcengine_api_key: str | None = None
- volcengine_base_url: str | None = None
- volcengine_model: str | None = None
+ volcengine_api_key: str | None = Field(default=None, max_length=500)
+ volcengine_base_url: str | None = Field(default=None, max_length=500)
+ volcengine_model: str | None = Field(default=None, max_length=200)
- ollama_base_url: str | None = None
- ollama_model: str | None = None
+ ollama_base_url: str | None = Field(default=None, max_length=500)
+ ollama_model: str | None = Field(default=None, max_length=200)
diff --git a/backend/app/schemas/paper.py b/backend/app/schemas/paper.py
index 1ebe315..97e432a 100644
--- a/backend/app/schemas/paper.py
+++ b/backend/app/schemas/paper.py
@@ -1,33 +1,34 @@
"""Pydantic schemas for Paper operations."""
from datetime import datetime
+from typing import Literal
from pydantic import BaseModel, Field
class PaperCreate(BaseModel):
doi: str | None = None
- title: str = Field(..., min_length=1)
- abstract: str = ""
+ title: str = Field(..., min_length=1, max_length=2000)
+ abstract: str = Field(default="", max_length=50000)
authors: list[dict[str, str]] | None = None
- journal: str = ""
- year: int | None = None
- citation_count: int = 0
- source: str = ""
- source_id: str = ""
- pdf_url: str = ""
+ journal: str = Field(default="", max_length=500)
+ year: int | None = Field(default=None, ge=1800, le=2100)
+ citation_count: int = Field(default=0, ge=0)
+ source: str = Field(default="", max_length=200)
+ source_id: str = Field(default="", max_length=500)
+ pdf_url: str = Field(default="", max_length=5000)
tags: list[str] | None = None
class PaperUpdate(BaseModel):
- title: str | None = None
- abstract: str | None = None
+ title: str | None = Field(default=None, max_length=2000)
+ abstract: str | None = Field(default=None, max_length=50000)
authors: list[dict[str, str]] | None = None
- journal: str | None = None
- year: int | None = None
+ journal: str | None = Field(default=None, max_length=500)
+ year: int | None = Field(default=None, ge=1800, le=2100)
tags: list[str] | None = None
notes: str | None = None
- status: str | None = None
+ status: Literal["pending", "metadata_only", "pdf_downloaded", "ocr_complete", "indexed", "error"] | None = None
class PaperRead(BaseModel):
@@ -54,4 +55,8 @@ class PaperRead(BaseModel):
class PaperBulkImport(BaseModel):
- papers: list[PaperCreate]
+ papers: list[PaperCreate] = Field(..., max_length=500)
+
+
+class PaperBatchDeleteRequest(BaseModel):
+ paper_ids: list[int] = Field(..., min_length=1, max_length=500)
diff --git a/backend/app/schemas/project.py b/backend/app/schemas/project.py
index 6eb1930..1b60292 100644
--- a/backend/app/schemas/project.py
+++ b/backend/app/schemas/project.py
@@ -6,6 +6,41 @@
from pydantic import BaseModel, Field
+class PaperImportItem(BaseModel):
+ """Schema for a single paper in project import."""
+
+ title: str = ""
+ abstract: str = ""
+ doi: str | None = None
+ authors: list | None = None
+ year: int | None = None
+ journal: str = ""
+ source: str = ""
+ pdf_url: str = ""
+ status: str = ""
+ citation_count: int = 0
+
+
+class KeywordImportItem(BaseModel):
+ """Schema for a single keyword in project import."""
+
+ term: str = Field(..., min_length=1)
+ term_en: str = ""
+ level: int = 1
+ category: str = ""
+ synonyms: str = ""
+
+
+class SubscriptionImportItem(BaseModel):
+ """Schema for a single subscription in project import."""
+
+ name: str = Field(..., min_length=1)
+ query: str = ""
+ sources: list[str] = Field(default_factory=list)
+ frequency: str = "weekly"
+ max_results: int = 50
+
+
class ProjectCreate(BaseModel):
name: str = Field(..., min_length=1, max_length=255)
description: str = ""
diff --git a/backend/app/schemas/subscription.py b/backend/app/schemas/subscription.py
index 7982ea2..133ba86 100644
--- a/backend/app/schemas/subscription.py
+++ b/backend/app/schemas/subscription.py
@@ -1,24 +1,25 @@
"""Subscription schemas for request/response."""
from datetime import datetime
+from typing import Literal
from pydantic import BaseModel, Field
class SubscriptionCreate(BaseModel):
- name: str = Field(..., min_length=1)
- query: str = ""
+ name: str = Field(..., min_length=1, max_length=500)
+ query: str = Field(default="", max_length=2000)
sources: list[str] = []
- frequency: str = "weekly"
+ frequency: Literal["daily", "weekly", "monthly"] = "weekly"
max_results: int = Field(50, ge=1, le=200)
class SubscriptionUpdate(BaseModel):
- name: str | None = None
- query: str | None = None
+ name: str | None = Field(default=None, max_length=500)
+ query: str | None = Field(default=None, max_length=2000)
sources: list[str] | None = None
- frequency: str | None = None
- max_results: int | None = None
+ frequency: Literal["daily", "weekly", "monthly"] | None = None
+ max_results: int | None = Field(default=None, ge=1, le=200)
is_active: bool | None = None
@@ -43,3 +44,4 @@ class SubscriptionRunResult(BaseModel):
new_papers: int
total_checked: int
sources_searched: list[str]
+ imported: int = 0
diff --git a/backend/app/services/citation_graph_service.py b/backend/app/services/citation_graph_service.py
index 8c42167..b5af48f 100644
--- a/backend/app/services/citation_graph_service.py
+++ b/backend/app/services/citation_graph_service.py
@@ -6,6 +6,7 @@
from typing import Any
import httpx
+from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
@@ -14,10 +15,10 @@
logger = logging.getLogger(__name__)
-S2_API_BASE = "https://api.semanticscholar.org/graph/v1"
S2_FIELDS = "title,year,citationCount,externalIds,authors"
-S2_TIMEOUT = 15
-S2_MAX_PER_REQUEST = 50
+
+# Error messages (extracted for maintainability)
+CITATION_NOT_FOUND = "Citation data unavailable: this paper is not indexed by Semantic Scholar"
class CitationGraphService:
@@ -37,16 +38,11 @@ async def get_citation_graph(
"""Return {nodes, edges, center_id} for a paper's citation network."""
paper = await self._db.get(Paper, paper_id)
if not paper or paper.project_id != project_id:
- return {"nodes": [], "edges": [], "center_id": None, "error": "Paper not found"}
+ raise HTTPException(status_code=404, detail="Paper not found")
s2_id = await self._resolve_s2_id(paper)
if not s2_id:
- return {
- "nodes": [],
- "edges": [],
- "center_id": None,
- "error": "无法获取引用数据:Semantic Scholar 未收录此论文",
- }
+ raise HTTPException(status_code=502, detail=CITATION_NOT_FOUND)
local_source_ids = await self._get_local_source_ids(project_id)
@@ -64,7 +60,7 @@ async def get_citation_graph(
}
nodes[s2_id] = center_node
- citations = await self._fetch_s2_list(f"{S2_API_BASE}/paper/{s2_id}/citations", max_nodes // 2)
+ citations = await self._fetch_s2_list(f"{settings.s2_api_base}/paper/{s2_id}/citations", max_nodes // 2)
for item in citations:
cited_paper = item.get("citingPaper", {})
cid = cited_paper.get("paperId")
@@ -76,7 +72,9 @@ async def get_citation_graph(
break
if len(nodes) < max_nodes:
- references = await self._fetch_s2_list(f"{S2_API_BASE}/paper/{s2_id}/references", max_nodes - len(nodes))
+ references = await self._fetch_s2_list(
+ f"{settings.s2_api_base}/paper/{s2_id}/references", max_nodes - len(nodes)
+ )
for item in references:
ref_paper = item.get("citedPaper", {})
rid = ref_paper.get("paperId")
@@ -100,7 +98,7 @@ async def _resolve_s2_id(self, paper: Paper) -> str | None:
if paper.doi:
try:
- data = await self._fetch_s2_json(f"{S2_API_BASE}/paper/DOI:{paper.doi}?fields=paperId")
+ data = await self._fetch_s2_json(f"{settings.s2_api_base}/paper/DOI:{paper.doi}?fields=paperId")
if pid := data.get("paperId"):
return pid
except Exception:
@@ -109,7 +107,7 @@ async def _resolve_s2_id(self, paper: Paper) -> str | None:
if paper.title:
try:
data = await self._fetch_s2_json(
- f"{S2_API_BASE}/paper/search",
+ f"{settings.s2_api_base}/paper/search",
params={"query": paper.title[:200], "limit": "1", "fields": "paperId"},
)
papers = data.get("data", [])
@@ -147,7 +145,7 @@ def _make_node(self, s2_paper: dict, local_ids: set[str]) -> dict:
async def _fetch_s2_list(self, url: str, limit: int) -> list[dict]:
"""Fetch paginated list from S2 citations/references endpoint."""
- actual_limit = min(limit, S2_MAX_PER_REQUEST)
+ actual_limit = min(limit, settings.s2_max_per_request)
try:
data = await self._fetch_s2_json(url, params={"fields": S2_FIELDS, "limit": str(actual_limit)})
return data.get("data", [])
@@ -160,7 +158,7 @@ async def _fetch_s2_json(self, url: str, params: dict | None = None) -> dict:
if settings.semantic_scholar_api_key:
headers["x-api-key"] = settings.semantic_scholar_api_key
- async with httpx.AsyncClient(timeout=S2_TIMEOUT) as client:
+ async with httpx.AsyncClient(timeout=settings.s2_timeout) as client:
resp = await client.get(url, headers=headers, params=params)
if resp.status_code == 429:
logger.warning("S2 API rate limited")
diff --git a/backend/app/services/completion_service.py b/backend/app/services/completion_service.py
index 864f2bf..085c942 100644
--- a/backend/app/services/completion_service.py
+++ b/backend/app/services/completion_service.py
@@ -4,17 +4,11 @@
import logging
+from app.prompts.completion import COMPLETION_SYSTEM
from app.services.llm.client import LLMClient, get_llm_client
logger = logging.getLogger(__name__)
-COMPLETION_SYSTEM_PROMPT = (
- "你是一个科研写作助手。根据用户已输入的文本,预测并补全后续内容。\n"
- "只返回补全的部分(不要重复用户已输入的内容),最多50个字符。\n"
- "如果无法合理预测,返回空字符串。\n"
- "不要添加任何解释、引号或格式标记,只返回纯文本补全内容。"
-)
-
class CompletionService:
"""Generates short text completions for chat input autocomplete."""
@@ -38,7 +32,7 @@ async def complete(
return {"completion": "", "confidence": 0.0}
messages: list[dict[str, str]] = [
- {"role": "system", "content": COMPLETION_SYSTEM_PROMPT},
+ {"role": "system", "content": COMPLETION_SYSTEM},
]
if recent_messages:
diff --git a/backend/app/services/crawler_service.py b/backend/app/services/crawler_service.py
index 59b477e..2567d32 100644
--- a/backend/app/services/crawler_service.py
+++ b/backend/app/services/crawler_service.py
@@ -1,5 +1,6 @@
"""PDF crawler service — download papers via Unpaywall, arXiv, and direct URLs."""
+import asyncio
import hashlib
import logging
from pathlib import Path
@@ -80,6 +81,13 @@ def _build_unpaywall_url(self, doi: str) -> str:
async def _download_pdf(self, url: str, paper: Paper) -> dict:
"""Download a PDF from a URL and save to disk."""
+ from app.services.url_validator import validate_url_safe
+
+ try:
+ await asyncio.to_thread(validate_url_safe, url)
+ except ValueError as e:
+ return {"success": False, "error": f"URL blocked: {e}"}
+
proxy = _get_proxy()
timeout = httpx.Timeout(60.0, connect=15.0)
@@ -93,6 +101,10 @@ async def _download_pdf(self, url: str, paper: Paper) -> dict:
pdf_url = best_oa.get("url_for_pdf") or best_oa.get("url") if best_oa else None
if not pdf_url:
return {"success": False, "error": "No open access PDF found"}
+ try:
+ await asyncio.to_thread(validate_url_safe, pdf_url)
+ except ValueError as e:
+ return {"success": False, "error": f"Resolved URL blocked: {e}"}
url = pdf_url
# Download the actual PDF
@@ -132,8 +144,6 @@ def _get_file_path(self, paper: Paper) -> Path:
async def batch_download(self, papers: list[Paper], max_concurrent: int = 5) -> dict:
"""Download PDFs for multiple papers with concurrency control."""
- import asyncio
-
semaphore = asyncio.Semaphore(max_concurrent)
results = {"success": 0, "failed": 0, "skipped": 0, "details": []}
diff --git a/backend/app/services/dedup_service.py b/backend/app/services/dedup_service.py
index 7a25a74..a4694a0 100644
--- a/backend/app/services/dedup_service.py
+++ b/backend/app/services/dedup_service.py
@@ -9,7 +9,8 @@
from sqlalchemy.ext.asyncio import AsyncSession
from app.models import Paper, PaperStatus
-from app.services.llm_client import LLMClient
+from app.prompts.dedup import DEDUP_RESOLVE_SYSTEM, DEDUP_VERIFY_SYSTEM
+from app.services.llm.client import LLMClient
logger = logging.getLogger(__name__)
@@ -219,12 +220,57 @@ async def llm_verify_duplicate(self, paper_a_id: int, paper_b_id: int) -> dict:
result = await self.llm.chat_json(
messages=[
- {
- "role": "system",
- "content": "You are a scientific literature deduplication expert. Return valid JSON only.",
- },
+ {"role": "system", "content": DEDUP_VERIFY_SYSTEM},
{"role": "user", "content": prompt},
],
task_type="dedup_check",
)
return result
+
+ async def resolve_conflict(
+ self,
+ old_paper: Paper,
+ new_title: str,
+ new_doi: str | None,
+ new_year: int | None,
+ new_journal: str | None,
+ ) -> dict:
+ """Use LLM to decide how to resolve a duplicate conflict."""
+ if not self.llm:
+ return {"action": "keep_new", "reason": "LLM not available, defaulting to keep_new"}
+
+ prompt = f"""Two papers may be duplicates. Decide the best resolution:
+
+Existing paper (in DB):
+- ID: {old_paper.id}
+- Title: {old_paper.title}
+- DOI: {old_paper.doi or "N/A"}
+- Year: {old_paper.year}
+- Journal: {old_paper.journal}
+
+New upload:
+- Title: {new_title}
+- DOI: {new_doi or "N/A"}
+- Year: {new_year}
+- Journal: {new_journal}
+
+Return JSON: {{"action": "keep_old"|"keep_new"|"merge", "reason": "..."}}
+- keep_old: existing is better, discard new
+- keep_new: new is better or different work, add new
+- merge: combine metadata, add as new paper"""
+
+ try:
+ result = await self.llm.chat_json(
+ messages=[
+ {"role": "system", "content": DEDUP_RESOLVE_SYSTEM},
+ {"role": "user", "content": prompt},
+ ],
+ task_type="dedup_resolve",
+ )
+ action = result.get("action", "keep_new")
+ if action not in ("keep_old", "keep_new", "merge"):
+ action = "keep_new"
+ return {"action": action, "reason": result.get("reason", "")}
+ except Exception as e:
+ logger.warning("LLM auto-resolve failed: %s", e)
+ return {"action": "keep_new", "reason": f"Error: {e}"}
diff --git a/backend/app/services/embedding_service.py b/backend/app/services/embedding_service.py
index f146c7c..fcae7cd 100644
--- a/backend/app/services/embedding_service.py
+++ b/backend/app/services/embedding_service.py
@@ -13,7 +13,6 @@
logger = logging.getLogger(__name__)
-_cached_embed_model: BaseEmbedding | None = None
_env_injected = False
@@ -37,21 +36,34 @@ def _inject_hf_env() -> None:
logger.info("Using HuggingFace mirror: %s", settings.hf_endpoint)
-def detect_gpu() -> tuple[bool, int, str]:
- """Detect GPU availability. Returns (has_gpu, device_count, device_string)."""
+def detect_gpu(*, pinned_gpu_id: int = -1) -> tuple[bool, int, str]:
+ """Detect GPU availability and pick the best device.
+
+ Args:
+ pinned_gpu_id: If >= 0, skip auto-detection and return ``cuda:N``.
+
+ Returns (has_gpu, device_count, device_string) where device_string is
+ ``"cuda:N"`` (best/pinned device) or ``"cpu"``.
+ """
try:
import torch
if torch.cuda.is_available():
count = torch.cuda.device_count()
if count > 0:
+ if 0 <= pinned_gpu_id < count:
+ device = f"cuda:{pinned_gpu_id}"
+ logger.info("GPU pinned: %s (of %d device(s))", device, count)
+ return True, count, device
devices_env = os.environ.get("CUDA_VISIBLE_DEVICES", settings.cuda_visible_devices)
+ best_device = _pick_best_gpu(count)
logger.info(
- "GPU detected: %d device(s), CUDA_VISIBLE_DEVICES=%s",
+ "GPU detected: %d device(s), CUDA_VISIBLE_DEVICES=%s, selected=%s",
count,
devices_env,
+ best_device,
)
- return True, count, "cuda"
+ return True, count, best_device
logger.info("No CUDA GPU available, using CPU")
return False, 0, "cpu"
except ImportError:
@@ -59,6 +71,49 @@ def detect_gpu() -> tuple[bool, int, str]:
return False, 0, "cpu"
+def _make_api_loader(model_name: str):
+ """Return a callable that builds API embedding (avoids lambda for ruff E731)."""
+
+ def _load() -> BaseEmbedding:
+ return _build_api_embedding(model_name)
+
+ return _load
+
+
+def _make_local_loader(model_name: str):
+ """Return a callable that builds local embedding (avoids lambda for ruff E731)."""
+
+ def _load() -> BaseEmbedding:
+ return _build_local_embedding(model_name)
+
+ return _load
+
+
+def _pick_best_gpu(device_count: int) -> str:
+ """Select the CUDA device with the most free memory."""
+ if device_count <= 1:
+ return "cuda:0"
+ try:
+ import torch
+
+ best_idx = 0
+ best_free = 0
+ for idx in range(device_count):
+ free, _total = torch.cuda.mem_get_info(idx)
+ if free > best_free:
+ best_free = free
+ best_idx = idx
+ logger.info(
+ "GPU selection: device cuda:%d has %.1f GiB free (best of %d)",
+ best_idx,
+ best_free / (1024**3),
+ device_count,
+ )
+ return f"cuda:{best_idx}"
+ except Exception:
+ return "cuda:0"
+
+
def get_embedding_model(
*,
provider: str | None = None,
@@ -71,37 +126,55 @@ def get_embedding_model(
- "local": HuggingFaceEmbedding with GPU auto-detection
- "api": OpenAIEmbedding (works with any OpenAI-compatible endpoint)
- "mock": Deterministic mock for tests
+
+ Local models are managed by :class:`GPUModelManager` which provides
+ TTL-based auto-unloading.
"""
- global _cached_embed_model
- if _cached_embed_model is not None and not force_reload:
- return _cached_embed_model
+ from app.services.gpu_model_manager import gpu_model_manager
prov = provider or getattr(settings, "embedding_provider", "local")
name = model_name or settings.embedding_model
if prov == "mock":
- model = _build_mock_embedding()
+ loader = _build_mock_embedding
+ device = "cpu"
elif prov == "api":
- model = _build_api_embedding(name)
+ loader = _make_api_loader(name)
+ device = "cpu"
else:
- model = _build_local_embedding(name)
+ _, _, device = detect_gpu(pinned_gpu_id=settings.embed_gpu_id)
+ loader = _make_local_loader(name)
+
+ return gpu_model_manager.acquire(
+ "embedding",
+ loader,
+ model_name=name,
+ device=device,
+ force_reload=force_reload,
+ )
+
+
+def _cleanup_gpu_memory() -> None:
+ """Force garbage collection and release cached GPU memory."""
+ from app.services.gpu_utils import release_gpu_memory
- _cached_embed_model = model
- return model
+ release_gpu_memory(caller="embedding_service")
def _build_local_embedding(model_name: str) -> BaseEmbedding:
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
_inject_hf_env()
+ _cleanup_gpu_memory()
- has_gpu, _count, device = detect_gpu()
- logger.info("Loading local embedding model=%s device=%s", model_name, device)
+ has_gpu, _count, device = detect_gpu(pinned_gpu_id=settings.embed_gpu_id)
+ batch_size = settings.embed_batch_size
+ logger.info("Loading local embedding model=%s device=%s batch_size=%d", model_name, device, batch_size)
return HuggingFaceEmbedding(
model_name=model_name,
device=device,
- embed_batch_size=32 if has_gpu else 8,
+ embed_batch_size=batch_size,
)
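
Device selection after this change, as a quick sketch: a pinned ID wins, otherwise the device with the most free memory is picked:

```python
# detect_gpu is keyword-only; EMBED_GPU_ID=-1 keeps auto-selection by free
# memory, while >= 0 pins the device (values here are illustrative).
from app.services.embedding_service import detect_gpu

has_gpu, count, device = detect_gpu(pinned_gpu_id=-1)  # e.g. (True, 2, "cuda:1")
print(has_gpu, count, device)  # falls back to (False, 0, "cpu") without CUDA
```
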
diff --git a/backend/app/services/gpu_model_manager.py b/backend/app/services/gpu_model_manager.py
new file mode 100644
index 0000000..416fba2
--- /dev/null
+++ b/backend/app/services/gpu_model_manager.py
@@ -0,0 +1,177 @@
+"""GPU model lifecycle manager with TTL-based auto-unloading.
+
+Uses threading locks so that ``acquire`` / ``unload`` work from both sync
+and async code. Only the background TTL sweep needs the event loop.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import contextlib
+import logging
+import threading
+import time
+from collections.abc import Callable
+from dataclasses import dataclass, field
+from typing import Any
+
+from app.config import settings
+from app.services.gpu_utils import release_gpu_memory
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class _ModelEntry:
+ model: Any
+ last_used_at: float = field(default_factory=time.monotonic)
+ model_name: str = ""
+ device: str = ""
+
+
+class GPUModelManager:
+ """Manages GPU model lifecycle with TTL-based auto-unloading.
+
+ Models are loaded on-demand via ``acquire()`` and automatically unloaded
+ after ``model_ttl_seconds`` of inactivity. Set ``model_ttl_seconds=0``
+ to disable auto-unloading (models persist for the process lifetime).
+ """
+
+ def __init__(
+ self,
+ ttl_seconds: int | None = None,
+ check_interval: int | None = None,
+ ):
+ self._ttl = ttl_seconds if ttl_seconds is not None else settings.model_ttl_seconds
+ self._interval = check_interval if check_interval is not None else settings.model_ttl_check_interval
+ self._models: dict[str, _ModelEntry] = {}
+ self._locks: dict[str, threading.Lock] = {}
+ self._global_lock = threading.Lock()
+ self._cleanup_task: asyncio.Task[None] | None = None
+
+ # -- lifecycle --------------------------------------------------------
+
+ async def start(self) -> None:
+ """Start the background TTL cleanup loop (requires a running event loop)."""
+ if self._ttl > 0 and self._cleanup_task is None:
+ self._cleanup_task = asyncio.create_task(self._cleanup_loop())
+ logger.info("GPU model manager started (TTL=%ds, interval=%ds)", self._ttl, self._interval)
+
+ async def stop(self) -> None:
+ """Cancel the cleanup loop and unload all models."""
+ if self._cleanup_task is not None:
+ self._cleanup_task.cancel()
+ with contextlib.suppress(asyncio.CancelledError):
+ await self._cleanup_task
+ self._cleanup_task = None
+ self.unload_all()
+ logger.info("GPU model manager stopped")
+
+ # -- model access (sync-safe) -----------------------------------------
+
+ def _get_lock(self, name: str) -> threading.Lock:
+ with self._global_lock:
+ if name not in self._locks:
+ self._locks[name] = threading.Lock()
+ return self._locks[name]
+
+ def acquire(
+ self,
+ name: str,
+ loader_fn: Callable[[], Any],
+ *,
+ model_name: str = "",
+ device: str = "",
+ force_reload: bool = False,
+ ) -> Any:
+ """Return a cached model or load it on demand (thread-safe, sync).
+
+        Concurrent callers for the same *name* serialize on a per-name lock,
+        so the loader runs once per (re)load rather than once per caller.
+ """
+ lock = self._get_lock(name)
+ with lock:
+ entry = self._models.get(name)
+
+ if entry is not None and not force_reload:
+ entry.last_used_at = time.monotonic()
+ return entry.model
+
+ if entry is not None:
+ self._do_unload(name, entry)
+
+ model = loader_fn()
+ self._models[name] = _ModelEntry(
+ model=model,
+ model_name=model_name,
+ device=device,
+ )
+ logger.info("Loaded GPU model %r (model=%s, device=%s)", name, model_name, device)
+ return model
+
+ def touch(self, name: str) -> None:
+ """Update the last-used timestamp for a loaded model."""
+ entry = self._models.get(name)
+ if entry is not None:
+ entry.last_used_at = time.monotonic()
+
+ def unload(self, name: str) -> None:
+ """Unload a single model by name."""
+ lock = self._get_lock(name)
+ with lock:
+ entry = self._models.pop(name, None)
+ if entry is not None:
+ self._do_unload(name, entry)
+
+ def unload_all(self) -> None:
+ """Unload all managed models."""
+ names = list(self._models.keys())
+ for name in names:
+ self.unload(name)
+
+ def is_loaded(self, name: str) -> bool:
+ return name in self._models
+
+ # -- internals --------------------------------------------------------
+
+ def _do_unload(self, name: str, entry: _ModelEntry) -> None:
+ logger.info("Unloading GPU model %r", name)
+ del entry.model
+ release_gpu_memory(caller=f"gpu_model_manager:{name}")
+
+ async def _cleanup_loop(self) -> None:
+ """Periodically check for idle models and unload them."""
+ while True:
+ await asyncio.sleep(self._interval)
+ now = time.monotonic()
+            # Snapshot the dict: acquire()/unload() mutate it from other threads.
+            expired = [name for name, entry in list(self._models.items()) if (now - entry.last_used_at) > self._ttl]
+ for name in expired:
+ logger.info("TTL expired for model %r, unloading", name)
+ self.unload(name)
+
+ # -- status -----------------------------------------------------------
+
+ def get_status(self) -> list[dict[str, Any]]:
+ """Return status information for all managed models."""
+ now = time.monotonic()
+ result = []
+        for name, entry in list(self._models.items()):
+ idle = now - entry.last_used_at
+ result.append(
+ {
+ "name": name,
+ "model_name": entry.model_name,
+ "loaded": True,
+ "device": entry.device,
+ "idle_seconds": round(idle, 1),
+ "ttl_remaining_seconds": max(0, round(self._ttl - idle, 1)) if self._ttl > 0 else None,
+ }
+ )
+ return result
+
+ @property
+ def loaded_model_names(self) -> list[str]:
+ return list(self._models.keys())
+
+
+gpu_model_manager = GPUModelManager()
diff --git a/backend/app/services/gpu_utils.py b/backend/app/services/gpu_utils.py
new file mode 100644
index 0000000..3f286c6
--- /dev/null
+++ b/backend/app/services/gpu_utils.py
@@ -0,0 +1,19 @@
+"""GPU memory utilities — shared cleanup logic for embedding, OCR, and model manager."""
+
+import gc
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def release_gpu_memory(caller: str = "") -> None:
+ """Force garbage collection and release cached GPU memory."""
+ gc.collect()
+ try:
+ import torch
+
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ logger.info("%s: released GPU memory", caller or "gpu_utils")
+ except ImportError:
+ pass
diff --git a/backend/app/services/keyword_service.py b/backend/app/services/keyword_service.py
index 20e8089..c60469a 100644
--- a/backend/app/services/keyword_service.py
+++ b/backend/app/services/keyword_service.py
@@ -6,7 +6,8 @@
from sqlalchemy.ext.asyncio import AsyncSession
from app.models import Keyword
-from app.services.llm_client import LLMClient
+from app.prompts.keyword import KEYWORD_EXPAND_SYSTEM
+from app.services.llm.client import LLMClient
logger = logging.getLogger(__name__)
@@ -62,7 +63,7 @@ async def expand_keywords_with_llm(
try:
result = await self.llm.chat_json(
messages=[
- {"role": "system", "content": "You are a scientific terminology expert. Return valid JSON only."},
+ {"role": "system", "content": KEYWORD_EXPAND_SYSTEM},
{"role": "user", "content": prompt},
],
task_type="keyword_expand",
diff --git a/backend/app/services/llm_client.py b/backend/app/services/llm_client.py
deleted file mode 100644
index 37c9c9c..0000000
--- a/backend/app/services/llm_client.py
+++ /dev/null
@@ -1,5 +0,0 @@
-"""Backward-compatibility shim — imports redirect to app.services.llm."""
-
-from app.services.llm.client import LLMClient, get_llm_client
-
-__all__ = ["LLMClient", "get_llm_client"]
diff --git a/backend/app/services/mineru_process_manager.py b/backend/app/services/mineru_process_manager.py
new file mode 100644
index 0000000..979c429
--- /dev/null
+++ b/backend/app/services/mineru_process_manager.py
@@ -0,0 +1,331 @@
+"""MinerU subprocess lifecycle manager with TTL-based auto-stop."""
+
+from __future__ import annotations
+
+import asyncio
+import contextlib
+import logging
+import shutil
+import signal
+import subprocess
+import time
+from typing import Any
+
+import httpx
+
+from app.config import settings
+
+logger = logging.getLogger(__name__)
+
+
+class MinerUProcessManager:
+ """Manages a MinerU FastAPI subprocess, starting it on demand and
+ stopping it after a configurable idle period.
+
+ When ``mineru_auto_manage`` is ``False`` the manager is a no-op —
+ callers should use the existing ``MinerUClient`` / health-check flow.
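+
+    Intended call pattern, as wired up in ``ocr_service`` (sketch)::
+
+        if await mineru_process_manager.ensure_running():
+            result = await mineru_client.parse_pdf(pdf_path)
+            mineru_process_manager.touch()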
+ """
+
+ def __init__(self) -> None:
+ self._process: subprocess.Popen[bytes] | None = None
+ self._lock = asyncio.Lock()
+ self._last_used_at: float = 0.0
+ self._cleanup_task: asyncio.Task[None] | None = None
+ self._is_external: bool = False
+
+ # -- lifecycle --------------------------------------------------------
+
+ async def start(self) -> None:
+ """Start the background TTL watcher (does NOT start MinerU yet)."""
+ if not settings.mineru_auto_manage:
+ logger.info("MinerU auto-manage disabled")
+ return
+ ttl = settings.mineru_ttl_seconds
+ if ttl > 0 and self._cleanup_task is None:
+ self._cleanup_task = asyncio.create_task(self._cleanup_loop())
+ logger.info("MinerU process manager started (TTL=%ds)", ttl)
+
+ async def stop(self) -> None:
+ """Cancel the watcher and kill all MinerU processes (owned + external)."""
+ if self._cleanup_task is not None:
+ self._cleanup_task.cancel()
+ with contextlib.suppress(asyncio.CancelledError):
+ await self._cleanup_task
+ self._cleanup_task = None
+ await self._kill_process()
+ self.kill_external_by_port()
+ logger.info("MinerU process manager stopped")
+
+ def stop_sync(self) -> None:
+ """Synchronous cleanup for atexit — kill subprocess and external MinerU."""
+ if self._process is not None and self._process.poll() is None:
+ pid = self._process.pid
+ try:
+ self._process.send_signal(signal.SIGTERM)
+ self._process.wait(timeout=5)
+ except (subprocess.TimeoutExpired, OSError, ProcessLookupError):
+ try:
+ self._process.kill()
+ self._process.wait(timeout=3)
+ except (OSError, ProcessLookupError):
+ pass
+ logger.info("Sync cleanup: stopped MinerU subprocess pid=%d", pid)
+ self._process = None
+ self.kill_external_by_port()
+
+ # -- public API -------------------------------------------------------
+
+ async def ensure_running(self) -> bool:
+ """Make sure MinerU is reachable. Returns ``True`` on success.
+
+ 1. If an external process already serves the port → use it.
+ 2. Otherwise start a subprocess via ``conda run``.
+ 3. Poll ``/docs`` until healthy or timeout.
+ """
+ if not settings.mineru_auto_manage:
+ return False
+
+ async with self._lock:
+ if await self._health_check():
+ self._touch()
+ if self._process is None:
+ self._is_external = True
+ return True
+
+ self._is_external = False
+ if not self._start_subprocess():
+ return False
+
+ ok = await self._wait_healthy(settings.mineru_startup_timeout)
+ if ok:
+ self._touch()
+ else:
+ logger.warning("MinerU failed to become healthy within %ds", settings.mineru_startup_timeout)
+ await self._kill_process()
+ return ok
+
+ def touch(self) -> None:
+ """Update idle timer (call after every MinerU request)."""
+ self._touch()
+
+ async def shutdown_mineru(self) -> None:
+ """Immediately stop the managed subprocess."""
+ async with self._lock:
+ await self._kill_process()
+
+ def get_status(self) -> dict[str, Any]:
+ now = time.monotonic()
+ if self._process is not None and self._process.poll() is None:
+ idle = now - self._last_used_at if self._last_used_at else 0
+ ttl = settings.mineru_ttl_seconds
+ return {
+ "status": "running",
+ "pid": self._process.pid,
+ "port": self._port,
+ "idle_seconds": round(idle, 1),
+ "ttl_remaining_seconds": max(0, round(ttl - idle, 1)) if ttl > 0 else None,
+ }
+ if self._is_external:
+ return {"status": "external", "pid": None, "port": self._port}
+ return {"status": "stopped", "pid": None, "port": self._port}
+
+ # -- internals --------------------------------------------------------
+
+ @property
+ def _port(self) -> int:
+ url = settings.mineru_api_url.rstrip("/")
+ try:
+ return int(url.rsplit(":", 1)[-1])
+ except (ValueError, IndexError):
+ return 8010
+
+ @property
+ def _host(self) -> str:
+        url = settings.mineru_api_url.rstrip("/")
+        # Strip the scheme first, then the port, so host-only URLs also work.
+        host = url.split("//")[-1].rsplit(":", 1)[0]
+        return host or "0.0.0.0"
+
+ def _touch(self) -> None:
+ self._last_used_at = time.monotonic()
+
+ async def _health_check(self) -> bool:
+ try:
+ async with httpx.AsyncClient(timeout=5) as client:
+ resp = await client.get(f"{settings.mineru_api_url.rstrip('/')}/docs")
+ return resp.status_code == 200
+ except Exception:
+ return False
+
+ def _start_subprocess(self) -> bool:
+ conda_path = shutil.which("conda")
+ if not conda_path:
+ logger.warning("conda not found on PATH, cannot auto-start MinerU")
+ return False
+
+ gpu_ids = settings.mineru_gpu_ids or settings.cuda_visible_devices
+ env_name = settings.mineru_conda_env
+
+ cmd = [
+ conda_path,
+ "run",
+ "-n",
+ env_name,
+ "python",
+ "-m",
+ "mineru.cli.fast_api",
+ "--host",
+ self._host,
+ "--port",
+ str(self._port),
+ ]
+
+ import os
+
+ env = os.environ.copy()
+ if gpu_ids:
+ env["CUDA_VISIBLE_DEVICES"] = gpu_ids
+
+ try:
+ self._process = subprocess.Popen(
+ cmd,
+ env=env,
+ stdout=subprocess.DEVNULL,
+ stderr=subprocess.PIPE,
+ )
+ logger.info(
+ "Started MinerU subprocess pid=%d (env=%s, gpu=%s, port=%d)",
+ self._process.pid,
+ env_name,
+ gpu_ids,
+ self._port,
+ )
+ return True
+ except (OSError, FileNotFoundError) as exc:
+ logger.warning("Failed to start MinerU subprocess: %s", exc)
+ return False
+
+ async def _wait_healthy(self, timeout: int) -> bool:
+ deadline = time.monotonic() + timeout
+ interval = 2.0
+ while time.monotonic() < deadline:
+ if self._process is not None and self._process.poll() is not None:
+ stderr_data = await asyncio.to_thread(self._process.stderr.read)
+ stderr = (stderr_data or b"").decode(errors="replace")[:500]
+ logger.warning("MinerU process exited early (code=%s): %s", self._process.returncode, stderr)
+ return False
+ if await self._health_check():
+ return True
+ await asyncio.sleep(interval)
+ interval = min(interval * 1.5, 10.0)
+ return False
+
+ async def _kill_process(self) -> None:
+ if self._process is None:
+ return
+ if self._process.poll() is not None:
+ self._process = None
+ return
+
+ pid = self._process.pid
+ logger.info("Stopping MinerU subprocess pid=%d", pid)
+ try:
+ self._process.send_signal(signal.SIGTERM)
+ try:
+ await asyncio.to_thread(self._process.wait, 10)
+ except subprocess.TimeoutExpired:
+ logger.warning("MinerU pid=%d did not exit after SIGTERM, sending SIGKILL", pid)
+ self._process.kill()
+ await asyncio.to_thread(self._process.wait, 5)
+ except (OSError, ProcessLookupError):
+ pass
+ finally:
+ self._process = None
+
+ def kill_external_by_port(self) -> None:
+ """Find and kill the process listening on the MinerU port (sync)."""
+ import os
+
+ port = self._port
+ my_pid = os.getpid()
+ target_pid = self._find_pid_by_port(port)
+ if target_pid is None or target_pid == my_pid:
+ return
+ if not self._is_mineru_process(target_pid):
+ logger.info("Port %d held by non-MinerU process (pid=%d), skipping", port, target_pid)
+ return
+ try:
+ os.kill(target_pid, signal.SIGTERM)
+ logger.info("Sent SIGTERM to external MinerU pid=%d (port=%d)", target_pid, port)
+ except (OSError, ProcessLookupError) as exc:
+ logger.warning("Failed to kill external MinerU pid=%d: %s", target_pid, exc)
+
+ @staticmethod
+ def _find_pid_by_port(port: int) -> int | None:
+ """Find PID listening on a TCP port using /proc or lsof."""
+ import os
+
+ try:
+ with open("/proc/net/tcp") as f:
+ hex_port = f":{port:04X}"
+ for line in f:
+ parts = line.strip().split()
+ if len(parts) >= 10 and hex_port in parts[1] and parts[3] == "0A":
+ inode = parts[9]
+ for pid_dir in os.listdir("/proc"):
+ if not pid_dir.isdigit():
+ continue
+ try:
+ fd_dir = f"/proc/{pid_dir}/fd"
+ for fd in os.listdir(fd_dir):
+ link = os.readlink(f"{fd_dir}/{fd}")
+ if f"socket:[{inode}]" in link:
+ return int(pid_dir)
+ except (OSError, PermissionError):
+ continue
+ except (OSError, PermissionError):
+ pass
+
+ import shutil
+
+ lsof_path = shutil.which("lsof")
+ if lsof_path:
+ try:
+ result = subprocess.run(
+ [lsof_path, "-ti", f":{port}"],
+ capture_output=True,
+ text=True,
+ timeout=5,
+ )
+ if result.returncode == 0 and result.stdout.strip():
+ return int(result.stdout.strip().split("\n")[0])
+ except (subprocess.TimeoutExpired, ValueError, OSError):
+ pass
+ return None
+
+ @staticmethod
+ def _is_mineru_process(pid: int) -> bool:
+ """Check if a PID is a MinerU process by reading its cmdline."""
+ try:
+ with open(f"/proc/{pid}/cmdline", "rb") as f:
+ cmdline = f.read().decode(errors="replace").lower()
+ return "mineru" in cmdline
+ except (OSError, PermissionError):
+ return False
+
+ async def _cleanup_loop(self) -> None:
+ ttl = settings.mineru_ttl_seconds
+ interval = max(ttl // 4, 30)
+ while True:
+ await asyncio.sleep(interval)
+ if self._process is None or self._is_external:
+ continue
+ if self._process.poll() is not None:
+ logger.info("MinerU subprocess exited unexpectedly")
+ self._process = None
+ continue
+ idle = time.monotonic() - self._last_used_at
+ if self._last_used_at > 0 and idle > ttl:
+ logger.info("MinerU idle for %.0fs (TTL=%ds), stopping", idle, ttl)
+ await self._kill_process()
+
+
+mineru_process_manager = MinerUProcessManager()
diff --git a/backend/app/services/ocr_service.py b/backend/app/services/ocr_service.py
index 5f28ec4..eb6aaf1 100644
--- a/backend/app/services/ocr_service.py
+++ b/backend/app/services/ocr_service.py
@@ -10,11 +10,13 @@
import json
import logging
import re
+import tempfile
from pathlib import Path
import pdfplumber
from app.config import settings
+from app.services.gpu_utils import release_gpu_memory
logger = logging.getLogger(__name__)
@@ -31,6 +33,25 @@ def __init__(self, use_gpu: bool = True, gpu_id: int = 0):
self.output_dir = Path(settings.ocr_output_dir)
self.output_dir.mkdir(parents=True, exist_ok=True)
+ def close(self) -> None:
+ """Release PaddleOCR model and free GPU memory."""
+ if self._paddle_ocr is not None:
+ del self._paddle_ocr
+ self._paddle_ocr = None
+ if self._marker_converter is not None:
+ del self._marker_converter
+ self._marker_converter = None
+ release_gpu_memory(caller="OCRService")
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, *exc):
+ self.close()
+
+    def __del__(self):
+        # May run during interpreter teardown or before __init__ finished; never raise.
+        try:
+            self.close()
+        except Exception:
+            pass
+
def extract_text_native(self, pdf_path: str) -> list[dict]:
"""Extract text from native (non-scanned) PDF using pdfplumber."""
pages = []
@@ -145,17 +166,16 @@ def extract_text_ocr(self, pdf_path: str) -> list[dict]:
else:
import fitz
- pdf_doc = fitz.open(pdf_path)
result = []
- for page_num in range(len(pdf_doc)):
- page = pdf_doc[page_num]
- pix = page.get_pixmap(dpi=150)
- img_path = f"/tmp/omelette_ocr_page_{page_num}.png"
- pix.save(img_path)
- page_result = ocr.ocr(img_path, cls=False)
- result.append(page_result[0] if page_result else [])
- Path(img_path).unlink(missing_ok=True)
- pdf_doc.close()
+ with fitz.open(pdf_path) as pdf_doc:
+ for page_num in range(len(pdf_doc)):
+ page = pdf_doc[page_num]
+ pix = page.get_pixmap(dpi=150)
+                    # Unique temp name: parallel OCR workers must not clobber each other's page images.
+                    with tempfile.NamedTemporaryFile(prefix="omelette_ocr_", suffix=f"_page_{page_num}.png", delete=False) as tmp:
+                        img_path = tmp.name
+ pix.save(img_path)
+ page_result = ocr.ocr(img_path, cls=False)
+ result.append(page_result[0] if page_result else [])
+ Path(img_path).unlink(missing_ok=True)
for i, page_result in enumerate(result):
text_lines = []
@@ -188,15 +208,26 @@ async def _extract_with_mineru(self, pdf_path: str) -> dict | None:
return None
from app.services.mineru_client import MinerUClient
+ from app.services.mineru_process_manager import mineru_process_manager
+
+ if settings.mineru_auto_manage:
+ ok = await mineru_process_manager.ensure_running()
+ if not ok:
+ logger.info("MinerU auto-start failed, falling back to pdfplumber")
+ return None
if self._mineru_client is None:
self._mineru_client = MinerUClient()
- if not await self._mineru_client.health_check():
+ if not settings.mineru_auto_manage and not await self._mineru_client.health_check():
logger.info("MinerU service not available, skipping")
return None
result = await self._mineru_client.parse_pdf(pdf_path)
+
+ if settings.mineru_auto_manage:
+ mineru_process_manager.touch()
+
if result.get("error"):
logger.warning("MinerU failed for %s: %s", pdf_path, result["error"])
return None
diff --git a/backend/app/services/paper_processor.py b/backend/app/services/paper_processor.py
index 1f7a4ee..ff7c303 100644
--- a/backend/app/services/paper_processor.py
+++ b/backend/app/services/paper_processor.py
@@ -2,15 +2,24 @@
Designed to run as a fire-and-forget ``asyncio.create_task`` so the upload
API can return immediately while processing continues in the background.
+
+GPU parallelisation:
+ - Multiple PDFs are OCR-ed concurrently via ``asyncio.gather``.
+ - Each worker gets a distinct ``gpu_id`` (round-robin) so that all
+ visible GPUs are utilised.
+ - DB writes and RAG indexing remain serial (ChromaDB limitation).
"""
from __future__ import annotations
+import asyncio
import logging
+import time
from sqlalchemy import select
from sqlalchemy.orm import selectinload
+from app.config import GpuMode, settings
from app.database import async_session_factory
from app.models import Paper, PaperStatus
from app.models.chunk import PaperChunk
@@ -20,6 +29,47 @@
logger = logging.getLogger(__name__)
+def _detect_gpu_count() -> int:
+ """Return the number of CUDA devices visible to this process (0 = CPU-only)."""
+ try:
+ import torch
+
+ if torch.cuda.is_available():
+ return torch.cuda.device_count()
+ except ImportError:
+ pass
+ return 0
+
+
+def _parse_ocr_gpu_ids(gpu_count: int) -> list[int]:
+ """Parse OCR_GPU_IDS into a list of valid indices.
+
+    Empty string → all GPUs ``[0 .. gpu_count-1]``; CPU-only (``gpu_count == 0``)
+    → the pseudo-id ``[0]``. Out-of-range indices are dropped.
+ """
+ raw = settings.ocr_gpu_ids.strip()
+ if not raw or gpu_count == 0:
+ return list(range(max(gpu_count, 1)))
+ ids = []
+ for tok in raw.split(","):
+ tok = tok.strip()
+ if tok.isdigit():
+ idx = int(tok)
+ if idx < gpu_count:
+ ids.append(idx)
+ return ids or list(range(gpu_count))
+
+
+def _resolve_parallel_limit(gpu_count: int) -> int:
+ """Determine how many OCR tasks may run concurrently."""
+ configured = settings.ocr_parallel_limit
+ if configured > 0:
+ return configured
+ base = max(gpu_count, 1)
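+    # e.g. 2 visible GPUs → 2 concurrent OCR tasks (balanced) or 4 (aggressive).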
+ if settings.gpu_mode == GpuMode.AGGRESSIVE:
+ return base * 2
+ return base
+
+
async def process_papers_background(
project_id: int,
paper_ids: list[int],
@@ -32,7 +82,16 @@ async def process_papers_background(
async def _process_papers(project_id: int, paper_ids: list[int]) -> None:
- ocr = OCRService(use_gpu=True)
+ gpu_count = _detect_gpu_count()
+ parallel_limit = _resolve_parallel_limit(gpu_count)
+ use_gpu = gpu_count > 0
+
+ logger.info(
+ "Paper processing: %d papers, %d GPU(s), parallel_limit=%d",
+ len(paper_ids),
+ gpu_count,
+ parallel_limit,
+ )
async with async_session_factory() as db:
stmt = select(Paper).where(
@@ -42,31 +101,62 @@ async def _process_papers(project_id: int, paper_ids: list[int]) -> None:
papers = list((await db.execute(stmt)).scalars().all())
ocr_done_ids: list[int] = []
+ papers_to_ocr: list[Paper] = []
for paper in papers:
if paper.status not in (PaperStatus.PDF_DOWNLOADED, PaperStatus.ERROR):
if paper.status in (PaperStatus.OCR_COMPLETE, PaperStatus.INDEXED):
ocr_done_ids.append(paper.id)
continue
-
if not paper.pdf_path:
paper.status = PaperStatus.ERROR
continue
-
- try:
- result = await ocr.process_pdf_async(paper.pdf_path)
+ papers_to_ocr.append(paper)
+
+ if papers_to_ocr:
+ semaphore = asyncio.Semaphore(parallel_limit)
+ ocr_gpus = _parse_ocr_gpu_ids(gpu_count)
+
+ async def _ocr_one(paper: Paper, worker_id: int) -> tuple[Paper, dict | None]:
+ gpu_id = ocr_gpus[worker_id % len(ocr_gpus)] if use_gpu else 0
+ with OCRService(use_gpu=use_gpu, gpu_id=gpu_id) as ocr:
+ async with semaphore:
+ try:
+ t0 = time.monotonic()
+ result = await ocr.process_pdf_async(paper.pdf_path)
+ elapsed = time.monotonic() - t0
+ logger.info(
+ "OCR worker %d (gpu=%d) finished paper %d in %.1fs",
+ worker_id,
+ gpu_id,
+ paper.id,
+ elapsed,
+ )
+ return paper, result
+ except Exception:
+ logger.exception("OCR failed for paper %d (worker %d)", paper.id, worker_id)
+ return paper, None
+
+ tasks = [_ocr_one(paper, i) for i, paper in enumerate(papers_to_ocr)]
+ results = await asyncio.gather(*tasks)
+
+ for paper, result in results:
+ if result is None:
+ paper.status = PaperStatus.ERROR
+ continue
if result.get("error"):
paper.status = PaperStatus.ERROR
logger.warning("OCR error for paper %d: %s", paper.id, result.get("error"))
continue
- ocr.save_result(paper.id, result)
+            # One CPU-side helper handles save + chunking; no need for three instances.
+            cpu_ocr = OCRService(use_gpu=False)
+            cpu_ocr.save_result(paper.id, result)
             if result.get("method") == "mineru":
-                chunks = ocr.chunk_mineru_markdown(result["md_content"])
+                chunks = cpu_ocr.chunk_mineru_markdown(result["md_content"])
             else:
-                chunks = ocr.chunk_text(result.get("pages", []))
+                chunks = cpu_ocr.chunk_text(result.get("pages", []))
+
for chunk_data in chunks:
db.add(
PaperChunk(
@@ -84,9 +174,6 @@ async def _process_papers(project_id: int, paper_ids: list[int]) -> None:
paper.status = PaperStatus.OCR_COMPLETE
ocr_done_ids.append(paper.id)
logger.info("OCR complete for paper %d (%s)", paper.id, paper.title[:40])
- except Exception:
- paper.status = PaperStatus.ERROR
- logger.exception("OCR failed for paper %d", paper.id)
await db.flush()
diff --git a/backend/app/services/pdf_metadata.py b/backend/app/services/pdf_metadata.py
index bbeede0..c6c5675 100644
--- a/backend/app/services/pdf_metadata.py
+++ b/backend/app/services/pdf_metadata.py
@@ -9,6 +9,7 @@
from __future__ import annotations
+import asyncio
import logging
import re
from pathlib import Path
@@ -36,7 +37,7 @@ async def extract_metadata(
fallback_title: str = "Untitled",
) -> NewPaperData:
"""Extract metadata from *pdf_path*, optionally enriching via Crossref."""
- local = _extract_local(pdf_path, fallback_title)
+ local = await asyncio.to_thread(_extract_local, pdf_path, fallback_title)
if local.doi:
enriched = await _crossref_lookup(local.doi)
diff --git a/backend/app/services/pipeline_service.py b/backend/app/services/pipeline_service.py
index 4df3cc5..88412f7 100644
--- a/backend/app/services/pipeline_service.py
+++ b/backend/app/services/pipeline_service.py
@@ -77,25 +77,38 @@ async def _download(self, paper: Paper) -> dict:
async def _ocr(self, paper: Paper) -> dict:
try:
- ocr = OCRService(use_gpu=True)
- result = ocr.process_pdf(paper.pdf_path)
+ with OCRService(use_gpu=True) as ocr:
+ result = await ocr.process_pdf_async(paper.pdf_path)
if result.get("error"):
paper.status = PaperStatus.ERROR
return {"success": False, "reason": result["error"]}
- pages = result.get("pages", [])
chunks = []
- for page in pages:
- if page.get("text", "").strip():
- chunks.append(
- {
- "paper_id": paper.id,
- "content": page["text"],
- "page_number": page.get("page_number", 0),
- "chunk_index": len(chunks),
- }
- )
+ if result.get("method") == "mineru":
+ mineru_chunks = ocr.chunk_mineru_markdown(result["md_content"])
+ for i, c in enumerate(mineru_chunks):
+ if c.get("content", "").strip():
+ chunks.append(
+ {
+ "paper_id": paper.id,
+ "content": c["content"],
+ "page_number": c.get("page_number", 1),
+ "chunk_index": i,
+ }
+ )
+ else:
+ pages = result.get("pages", [])
+ for page in pages:
+ if page.get("text", "").strip():
+ chunks.append(
+ {
+ "paper_id": paper.id,
+ "content": page["text"],
+ "page_number": page.get("page_number", 0),
+ "chunk_index": len(chunks),
+ }
+ )
for chunk_data in chunks:
chunk = PaperChunk(**chunk_data)
diff --git a/backend/app/services/rag_service.py b/backend/app/services/rag_service.py
index a1af81b..be6b43d 100644
--- a/backend/app/services/rag_service.py
+++ b/backend/app/services/rag_service.py
@@ -10,7 +10,9 @@
from __future__ import annotations
+import asyncio
import logging
+import time
from collections.abc import Callable
from pathlib import Path
from typing import TYPE_CHECKING
@@ -23,7 +25,8 @@
from llama_index.core.schema import Document, NodeRelationship, RelatedNodeInfo, TextNode
from app.config import settings
-from app.services.llm_client import LLMClient
+from app.prompts.rag import RAG_ANSWER_SYSTEM
+from app.services.llm.client import LLMClient
if TYPE_CHECKING:
from llama_index.core.embeddings import BaseEmbedding
@@ -34,6 +37,8 @@
class RAGService:
"""LlamaIndex-powered RAG service with ChromaDB vector store."""
+ _COUNT_CACHE_TTL = 60.0
+
def __init__(
self,
llm: LLMClient | None = None,
@@ -44,6 +49,21 @@ def __init__(
self.llm = llm
self._chroma_client = chroma_client
self._embed_model = embed_model
+ self._count_cache: dict[int, tuple[int, float]] = {}
+
+ async def _get_count(self, project_id: int) -> int:
+ """Get collection count with caching and async wrapping."""
+ now = time.monotonic()
+ cached = self._count_cache.get(project_id)
+ if cached and now - cached[1] < self._COUNT_CACHE_TTL:
+ return cached[0]
+ collection = self._get_collection(project_id)
+ count = await asyncio.to_thread(collection.count)
+ self._count_cache[project_id] = (count, now)
+ return count
+
+ def _invalidate_count(self, project_id: int) -> None:
+ self._count_cache.pop(project_id, None)
def _get_chroma_client(self) -> chromadb.ClientAPI:
if self._chroma_client is None:
@@ -58,7 +78,12 @@ def _get_chroma_client(self) -> chromadb.ClientAPI:
def _get_collection(self, project_id: int) -> chromadb.Collection:
return self._get_chroma_client().get_or_create_collection(
name=f"project_{project_id}",
- metadata={"hnsw:space": "cosine"},
+ metadata={
+ "hnsw:space": "cosine",
+ "hnsw:construction_ef": 200,
+ "hnsw:search_ef": 100,
+ "hnsw:M": 32,
+ },
)
def _ensure_embed_model(self) -> BaseEmbedding:
@@ -69,6 +94,15 @@ def _ensure_embed_model(self) -> BaseEmbedding:
LlamaSettings.embed_model = self._embed_model
return self._embed_model
+ def _reload_embed_model(self) -> BaseEmbedding:
+ """Force-reload the embedding model onto the best available GPU."""
+ from app.services.embedding_service import _cleanup_gpu_memory, get_embedding_model
+
+ _cleanup_gpu_memory()
+ self._embed_model = get_embedding_model(force_reload=True)
+ LlamaSettings.embed_model = self._embed_model
+ return self._embed_model
+
def _get_vector_store(self, project_id: int):
from llama_index.vector_stores.chroma import ChromaVectorStore
@@ -104,6 +138,9 @@ async def index_chunks(
if on_progress:
on_progress("loading_model", 0)
+ from app.services.embedding_service import _cleanup_gpu_memory
+
+ _cleanup_gpu_memory()
index = self._get_index(project_id)
if on_progress:
@@ -132,8 +169,6 @@ async def index_chunks(
node.relationships[NodeRelationship.SOURCE] = RelatedNodeInfo(node_id=ref_doc_id)
nodes.append(node)
- import asyncio
-
total = len(nodes)
indexed = 0
for i in range(0, total, batch_size):
@@ -144,6 +179,7 @@ async def index_chunks(
pct = 10 + int(90 * indexed / total)
on_progress("indexing", min(pct, 99))
+ self._invalidate_count(project_id)
return {"indexed": total, "collection": f"project_{project_id}"}
async def index_documents(
@@ -161,10 +197,9 @@ async def index_documents(
splitter = SentenceSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
nodes = splitter.get_nodes_from_documents(documents)
- import asyncio
-
index = self._get_index(project_id)
await asyncio.to_thread(index.insert_nodes, nodes)
+ self._invalidate_count(project_id)
return {"indexed": len(nodes), "collection": f"project_{project_id}"}
@staticmethod
@@ -184,7 +219,7 @@ def _get_adjacent_chunks(
chunk_index: int,
window: int = 1,
) -> tuple[str, str]:
- """Fetch adjacent chunks for context expansion.
+ """Fetch adjacent chunks for context expansion (single node).
Returns ``(prev_text, next_text)`` so the caller can assemble
``[prev] \\n [main] \\n [next]`` in the correct order.
@@ -204,9 +239,60 @@ def _get_adjacent_chunks(
docs = result.get("documents") or []
next_text = "\n".join(d for d in docs if d)
except Exception:
- pass
+ logger.debug("Adjacent chunk fetch failed for paper %d chunk %d", paper_id, chunk_index, exc_info=True)
return prev_text, next_text
+ async def _get_adjacent_chunks_batch(
+ self,
+ collection: chromadb.Collection,
+ nodes: list,
+ ) -> list[tuple[str, str]]:
+ """Batch-fetch adjacent chunks for all nodes in a single ChromaDB call.
+
+ Returns a list of (prev_text, next_text) tuples aligned with *nodes*.
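+
+        Assumes chunks were indexed under deterministic
+        ``paper_{pid}_chunk_{idx}`` IDs, so neighbours resolve via direct
+        ID lookups instead of per-node queries.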
+ """
+ all_ids: set[str] = set()
+ node_adj: list[tuple[str | None, str | None]] = []
+
+ for n in nodes:
+ node = n.node if hasattr(n, "node") else n
+ meta = node.metadata or {}
+ pid = meta.get("paper_id")
+ cidx = meta.get("chunk_index")
+ if pid is None or cidx is None:
+ node_adj.append((None, None))
+ continue
+ prev_id = f"paper_{pid}_chunk_{cidx - 1}"
+ next_id = f"paper_{pid}_chunk_{cidx + 1}"
+ all_ids.update([prev_id, next_id])
+ node_adj.append((prev_id, next_id))
+
+ id_to_doc: dict[str, str] = {}
+ if all_ids:
+ try:
+ result = await asyncio.to_thread(collection.get, ids=list(all_ids), include=["documents"])
+ for doc_id, doc in zip(result.get("ids") or [], result.get("documents") or []):
+ if doc:
+ id_to_doc[doc_id] = doc
+ except Exception:
+ logger.debug("Batch adjacent chunk fetch failed", exc_info=True)
+
+ return [
+ (id_to_doc.get(prev_id, ""), id_to_doc.get(next_id, "")) if prev_id is not None else ("", "")
+ for prev_id, next_id in node_adj
+ ]
+
+ def _build_retriever(self, index: VectorStoreIndex, fetch_k: int, count: int):
+ """Build a retriever, optionally with MMR mode."""
+ effective_k = min(fetch_k, count)
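+        # mmr_threshold is the MMR relevance/diversity trade-off: values near 1
+        # favour relevance, values near 0 favour diversity among results.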
+ if settings.rag_mmr_threshold > 0:
+ return index.as_retriever(
+ similarity_top_k=effective_k,
+ vector_store_query_mode="mmr",
+ vector_store_kwargs={"mmr_threshold": settings.rag_mmr_threshold},
+ )
+ return index.as_retriever(similarity_top_k=effective_k)
+
async def query(
self,
project_id: int,
@@ -216,42 +302,40 @@ async def query(
include_sources: bool = True,
) -> dict:
"""Query the knowledge base and generate an answer with citations."""
- collection = self._get_collection(project_id)
- if collection.count() == 0:
+ count = await self._get_count(project_id)
+ if count == 0:
return {
"answer": "No documents have been indexed yet. Please process and index papers first.",
"sources": [],
"confidence": 0.0,
}
- import asyncio
-
+ collection = self._get_collection(project_id)
index = self._get_index(project_id)
- retriever = index.as_retriever(similarity_top_k=min(top_k, collection.count()))
+
+ oversample = settings.rag_oversample_factor if use_reranker else 1
+ fetch_k = top_k * oversample
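+        # e.g. top_k=10, oversample=3 → retrieve 30 candidates, rerank, keep 10.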
+ retriever = self._build_retriever(index, fetch_k, count)
retrieved_nodes = await asyncio.to_thread(retriever.retrieve, question)
if not retrieved_nodes:
return {"answer": "No relevant documents found.", "sources": [], "confidence": 0.0}
+ if use_reranker and retrieved_nodes:
+ from app.services.reranker_service import rerank_nodes
+
+ retrieved_nodes = await rerank_nodes(retrieved_nodes, question, top_n=top_k)
+
+ adj_results = await self._get_adjacent_chunks_batch(collection, retrieved_nodes)
+
contexts = []
sources = []
- for node_with_score in retrieved_nodes:
+ for node_with_score, (prev_text, next_text) in zip(retrieved_nodes, adj_results):
node = node_with_score.node
meta = node.metadata or {}
score = node_with_score.score or 0.0
text = node.get_content()
- paper_id = meta.get("paper_id")
- chunk_idx = meta.get("chunk_index")
- prev_text, next_text = "", ""
- if paper_id is not None and chunk_idx is not None:
- prev_text, next_text = await asyncio.to_thread(
- self._get_adjacent_chunks,
- collection,
- paper_id,
- chunk_idx,
- )
-
parts = [p for p in [prev_text, text, next_text] if p]
full_context = "\n".join(parts)
@@ -260,7 +344,7 @@ async def query(
)
sources.append(
{
- "paper_id": paper_id,
+ "paper_id": meta.get("paper_id"),
"paper_title": meta.get("paper_title", ""),
"page_number": meta.get("page_number"),
"chunk_type": meta.get("chunk_type", "text"),
@@ -289,45 +373,45 @@ async def retrieve_only(
project_id: int,
question: str,
top_k: int = 10,
+ use_reranker: bool = False,
) -> list[dict]:
"""Retrieve relevant chunks without LLM generation.
Designed for the Chat Pipeline where the LLM call happens
downstream in the generate node, avoiding a redundant call here.
"""
- collection = self._get_collection(project_id)
- if collection.count() == 0:
+ count = await self._get_count(project_id)
+ if count == 0:
return []
- import asyncio
-
+ collection = self._get_collection(project_id)
index = self._get_index(project_id)
- retriever = index.as_retriever(similarity_top_k=min(top_k, collection.count()))
+
+ oversample = settings.rag_oversample_factor if use_reranker else 1
+ fetch_k = top_k * oversample
+ retriever = self._build_retriever(index, fetch_k, count)
retrieved_nodes = await asyncio.to_thread(retriever.retrieve, question)
+ if use_reranker and retrieved_nodes:
+ from app.services.reranker_service import rerank_nodes
+
+ retrieved_nodes = await rerank_nodes(retrieved_nodes, question, top_n=top_k)
+
+ adj_results = await self._get_adjacent_chunks_batch(collection, retrieved_nodes)
+
sources: list[dict] = []
- for node_with_score in retrieved_nodes:
+ for node_with_score, (prev_text, next_text) in zip(retrieved_nodes, adj_results):
node = node_with_score.node
meta = node.metadata or {}
score = node_with_score.score or 0.0
text = node.get_content()
- paper_id = meta.get("paper_id")
- chunk_idx = meta.get("chunk_index")
- prev_text, next_text = "", ""
- if paper_id is not None and chunk_idx is not None:
- prev_text, next_text = await asyncio.to_thread(
- self._get_adjacent_chunks,
- collection,
- paper_id,
- chunk_idx,
- )
parts = [p for p in [prev_text, text, next_text] if p]
full_context = "\n".join(parts)
sources.append(
{
- "paper_id": paper_id,
+ "paper_id": meta.get("paper_id"),
"paper_title": meta.get("paper_title", ""),
"page_number": meta.get("page_number"),
"chunk_type": meta.get("chunk_type", "text"),
@@ -349,14 +433,7 @@ async def _generate_answer(self, question: str, context: str) -> str:
)
return await self.llm.chat(
messages=[
- {
- "role": "system",
- "content": (
- "You are a scientific research assistant. "
- "Answer questions based strictly on the provided context. "
- "Cite sources accurately."
- ),
- },
+ {"role": "system", "content": RAG_ANSWER_SYSTEM},
{"role": "user", "content": prompt},
],
temperature=0.3,
@@ -368,7 +445,8 @@ async def delete_index(self, project_id: int) -> dict:
client = self._get_chroma_client()
name = f"project_{project_id}"
try:
- client.delete_collection(name)
+ await asyncio.to_thread(client.delete_collection, name)
+ self._invalidate_count(project_id)
return {"deleted": True, "collection": name}
except ValueError:
return {"deleted": False, "message": "Collection not found"}
@@ -377,7 +455,8 @@ async def delete_paper(self, project_id: int, paper_id: int) -> dict:
"""Delete all chunks for a single paper from the index."""
collection = self._get_collection(project_id)
try:
- collection.delete(where={"paper_id": paper_id})
+ await asyncio.to_thread(collection.delete, where={"paper_id": paper_id})
+ self._invalidate_count(project_id)
return {"deleted": True, "paper_id": paper_id}
except Exception as e:
logger.warning("Failed to delete paper %d from index: %s", paper_id, e)
@@ -386,10 +465,11 @@ async def delete_paper(self, project_id: int, paper_id: int) -> dict:
async def get_stats(self, project_id: int) -> dict:
"""Get index statistics for a project."""
try:
- collection = self._get_collection(project_id)
+ count = await self._get_count(project_id)
return {
- "total_chunks": collection.count(),
+ "total_chunks": count,
"collection_name": f"project_{project_id}",
}
except Exception:
+ logger.warning("Failed to get stats for project %d", project_id, exc_info=True)
return {"total_chunks": 0, "collection_name": f"project_{project_id}"}
diff --git a/backend/app/services/reranker_service.py b/backend/app/services/reranker_service.py
new file mode 100644
index 0000000..47abb60
--- /dev/null
+++ b/backend/app/services/reranker_service.py
@@ -0,0 +1,86 @@
+"""Reranker model loading, caching, and async-safe inference."""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+from typing import TYPE_CHECKING
+
+from app.config import settings
+
+if TYPE_CHECKING:
+ from llama_index.core.schema import NodeWithScore
+
+logger = logging.getLogger(__name__)
+
+_reranker_semaphore: asyncio.Semaphore | None = None
+
+
+def _get_semaphore() -> asyncio.Semaphore:
+ global _reranker_semaphore
+ if _reranker_semaphore is None:
+ _reranker_semaphore = asyncio.Semaphore(settings.reranker_concurrency_limit)
+ return _reranker_semaphore
+
+
+def _build_reranker(model_name: str):
+ """Build a SentenceTransformerRerank instance (heavy, runs on GPU)."""
+ from llama_index.postprocessor.sbert_rerank import SentenceTransformerRerank
+
+ from app.services.embedding_service import _inject_hf_env, detect_gpu
+
+ _inject_hf_env()
+
+ _has_gpu, _count, device = detect_gpu(pinned_gpu_id=settings.rerank_gpu_id)
+    # rerank_batch_size doubles as the rerank pool ceiling: the model scores up
+    # to this many candidates, and rerank_nodes() trims to the caller's top_n.
+    pool_size = settings.rerank_batch_size
+    logger.info("Loading reranker model=%s device=%s pool_size=%d", model_name, device, pool_size)
+    return SentenceTransformerRerank(
+        model=model_name,
+        top_n=pool_size,
+ device=device,
+ keep_retrieval_score=True,
+ )
+
+
+def get_reranker(*, model_name: str | None = None):
+ """Return a cached reranker via GPUModelManager (TTL-managed)."""
+ from app.services.embedding_service import detect_gpu
+ from app.services.gpu_model_manager import gpu_model_manager
+
+ name = model_name or settings.reranker_model
+ _, _, device = detect_gpu(pinned_gpu_id=settings.rerank_gpu_id)
+ return gpu_model_manager.acquire(
+ "reranker",
+ lambda: _build_reranker(name),
+ model_name=name,
+ device=device,
+ )
+
+
+async def rerank_nodes(
+ nodes: list[NodeWithScore],
+ query: str,
+ top_n: int,
+) -> list[NodeWithScore]:
+ """Apply reranker with concurrency control and graceful fallback.
+
+ Uses a semaphore to serialize GPU inference and falls back to
+ the original node order on any failure.
+ """
+ if not nodes:
+ return []
+ try:
+ from llama_index.core.schema import QueryBundle
+
+ reranker = get_reranker()
+ query_bundle = QueryBundle(query_str=query)
+ async with _get_semaphore():
+ reranked = await asyncio.to_thread(
+ reranker.postprocess_nodes,
+ nodes,
+ query_bundle=query_bundle,
+ )
+ return reranked[:top_n]
+ except (ImportError, OSError, RuntimeError):
+ logger.warning("Reranking failed, returning original nodes", exc_info=True)
+ return nodes[:top_n]
diff --git a/backend/app/services/subscription_service.py b/backend/app/services/subscription_service.py
index 11ce2e5..54eb2da 100644
--- a/backend/app/services/subscription_service.py
+++ b/backend/app/services/subscription_service.py
@@ -1,5 +1,6 @@
"""Incremental subscription service — scheduled literature updates via API and RSS."""
+import asyncio
import logging
from datetime import datetime
@@ -20,12 +21,19 @@ def __init__(self):
async def check_rss_feed(self, feed_url: str, since: datetime | None = None) -> list[dict]:
"""Parse an RSS/Atom feed and return new entries since the given date."""
+ from app.services.url_validator import validate_url_safe
+
+ try:
+ await asyncio.to_thread(validate_url_safe, feed_url)
+ except ValueError as e:
+ raise ValueError(f"Feed URL blocked (SSRF): {e}") from e
+
proxy = settings.http_proxy or None
async with httpx.AsyncClient(proxy=proxy, timeout=30.0) as client:
resp = await client.get(feed_url)
resp.raise_for_status()
- feed = feedparser.parse(resp.text)
+ feed = await asyncio.to_thread(feedparser.parse, resp.text)
entries = []
for entry in feed.entries:
diff --git a/backend/app/services/url_validator.py b/backend/app/services/url_validator.py
new file mode 100644
index 0000000..ed73d44
--- /dev/null
+++ b/backend/app/services/url_validator.py
@@ -0,0 +1,58 @@
+"""URL and DOI validation utilities for SSRF prevention."""
+
+import ipaddress
+import re
+import socket
+from urllib.parse import urlparse
+
+BLOCKED_HOSTNAMES = frozenset(
+ {
+ "metadata.google.internal",
+ "metadata.amazonaws.com",
+ }
+)
+
+DOI_PATTERN = re.compile(r"^10\.\d{4,9}/[-._;()/:A-Za-z0-9]+$")
+
+
+def validate_url_safe(url: str) -> str:
+ """Validate URL is safe for server-side fetch.
+
+ Blocks private IPs, loopback, link-local, reserved, multicast,
+ and known cloud metadata hostnames.
+
+ Raises ValueError if the URL is unsafe.
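+
+    Example of the failure mode it blocks (sketch)::
+
+        validate_url_safe("http://169.254.169.254/latest/meta-data")  # ValueError
+
+    Note: DNS is resolved here, separately from the eventual fetch, so
+    DNS-rebinding between check and fetch is out of scope.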
+ """
+ parsed = urlparse(url)
+ if parsed.scheme not in ("http", "https"):
+ raise ValueError(f"Unsupported scheme: {parsed.scheme}")
+
+ hostname = parsed.hostname
+ if not hostname:
+ raise ValueError("Invalid URL: no hostname")
+
+ if hostname in BLOCKED_HOSTNAMES:
+ raise ValueError(f"Blocked hostname: {hostname}")
+
+ try:
+ addrinfos = socket.getaddrinfo(hostname, None)
+ except socket.gaierror as e:
+ raise ValueError(f"DNS resolution failed for {hostname}: {e}") from e
+
+ for info in addrinfos:
+ ip_str = info[4][0]
+ try:
+ ip = ipaddress.ip_address(ip_str)
+ except ValueError:
+ continue
+ if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved or ip.is_multicast:
+ raise ValueError(f"Blocked: {hostname} resolves to private/reserved address {ip_str}")
+
+ return url
+
+
+def validate_doi(doi: str) -> str:
+ """Validate DOI format. Raises ValueError if invalid."""
+ if not DOI_PATTERN.match(doi):
+ raise ValueError(f"Invalid DOI format: {doi}")
+ return doi
diff --git a/backend/app/services/writing_service.py b/backend/app/services/writing_service.py
index 4ec1ef8..962d9fe 100644
--- a/backend/app/services/writing_service.py
+++ b/backend/app/services/writing_service.py
@@ -1,5 +1,6 @@
"""Writing assistance service — summarize, cite, outline, gap analysis, literature review."""
+import asyncio
import json
import logging
import re
@@ -8,9 +9,17 @@
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
+from app.config import settings as app_settings
from app.models import Paper
-from app.services.llm_client import LLMClient
+from app.prompts.writing import (
+ WRITING_GAP_SYSTEM,
+ WRITING_OUTLINE_SYSTEM,
+ WRITING_SECTION_SYSTEM,
+ WRITING_SUMMARIZE_SYSTEM,
+)
+from app.services.llm.client import LLMClient
from app.services.rag_service import RAGService
+from app.utils.sse import format_sse_error
logger = logging.getLogger(__name__)
@@ -20,14 +29,6 @@
"thematic": "主题性综述 (thematic review):按研究主题分组对比,突出异同",
}
-SECTION_SYSTEM_PROMPT = """\
-你是一位学术综述写作专家。请为以下章节撰写综述段落。
-要求:
-1. 使用学术语言,逻辑清晰
-2. 在适当位置使用 [1][2] 格式引用
-3. 每个引用必须对应提供的文献,不得捏造
-4. 段落长度 200-400 字"""
-
class WritingService:
def __init__(self, db: AsyncSession, llm: LLMClient, rag: RAGService | None = None):
@@ -35,18 +36,15 @@ def __init__(self, db: AsyncSession, llm: LLMClient, rag: RAGService | None = No
self.llm = llm
self.rag = rag
+ _summarize_semaphore = asyncio.Semaphore(app_settings.llm_parallel_limit)
+
async def summarize_papers(self, paper_ids: list[int], language: str = "en") -> list[dict]:
- """Generate summaries for selected papers."""
+ """Generate summaries for selected papers (parallelized with semaphore)."""
stmt = select(Paper).where(Paper.id.in_(paper_ids))
result = await self.db.execute(stmt)
papers = {p.id: p for p in result.scalars().all()}
- summaries = []
- for paper_id in paper_ids:
- paper = papers.get(paper_id)
- if not paper:
- continue
-
+ async def _summarize_one(paper: Paper) -> dict:
prompt = f"""Summarize this scientific paper in {language}:
Title: {paper.title}
Abstract: {paper.abstract}
@@ -59,27 +57,21 @@ async def summarize_papers(self, paper_ids: list[int], language: str = "en") ->
3. Innovation points
4. Limitations (if apparent from abstract)"""
- summary = await self.llm.chat(
- messages=[
- {
- "role": "system",
- "content": "You are a scientific paper analyst. Provide concise, accurate summaries.",
- },
- {"role": "user", "content": prompt},
- ],
- temperature=0.3,
- task_type="summarize",
- )
+ async with self._summarize_semaphore:
+ summary = await self.llm.chat(
+ messages=[
+ {"role": "system", "content": WRITING_SUMMARIZE_SYSTEM},
+ {"role": "user", "content": prompt},
+ ],
+ temperature=0.3,
+ task_type="summarize",
+ )
- summaries.append(
- {
- "paper_id": paper.id,
- "title": paper.title,
- "summary": summary,
- }
- )
+ return {"paper_id": paper.id, "title": paper.title, "summary": summary}
- return summaries
+        tasks = [_summarize_one(papers[pid]) for pid in paper_ids if pid in papers]
+        results = await asyncio.gather(*tasks, return_exceptions=True)
+        for r in results:
+            if isinstance(r, BaseException):
+                logger.warning("Paper summary failed: %s", r)
+        return [r for r in results if isinstance(r, dict)]
async def generate_citations(self, paper_ids: list[int], style: str = "gb_t_7714") -> list[dict]:
"""Generate formatted citations for papers."""
@@ -150,10 +142,7 @@ async def generate_review_outline(self, project_id: int, topic: str, language: s
outline = await self.llm.chat(
messages=[
- {
- "role": "system",
- "content": "You are a scientific writing expert. Generate well-structured review outlines.",
- },
+ {"role": "system", "content": WRITING_OUTLINE_SYSTEM},
{"role": "user", "content": prompt},
],
temperature=0.5,
@@ -188,10 +177,7 @@ async def analyze_gaps(self, project_id: int, research_topic: str) -> dict:
analysis = await self.llm.chat(
messages=[
- {
- "role": "system",
- "content": "You are a research gap analyst. Identify unexplored areas and innovation opportunities.",
- },
+ {"role": "system", "content": WRITING_GAP_SYSTEM},
{"role": "user", "content": prompt},
],
temperature=0.5,
@@ -222,7 +208,7 @@ async def generate_literature_review(
papers = result.scalars().all()
if not papers:
- yield _sse("error", {"message": "知识库中暂无文献,请先添加文献后再生成综述"})
+ yield format_sse_error("知识库中暂无文献,请先添加文献后再生成综述", code=400)
return
yield _sse("progress", {"step": "outline", "message": "正在生成综述提纲..."})
@@ -270,7 +256,7 @@ async def generate_literature_review(
async for chunk in self.llm.chat_stream(
messages=[
- {"role": "system", "content": SECTION_SYSTEM_PROMPT},
+ {"role": "system", "content": WRITING_SECTION_SYSTEM},
{"role": "user", "content": prompt},
],
temperature=0.5,
@@ -314,10 +300,7 @@ async def _generate_review_outline_for_draft(
return await self.llm.chat(
messages=[
- {
- "role": "system",
- "content": "You are a scientific writing expert. Generate well-structured review outlines.",
- },
+ {"role": "system", "content": WRITING_OUTLINE_SYSTEM},
{"role": "user", "content": prompt},
],
temperature=0.5,
diff --git a/backend/app/utils/sse.py b/backend/app/utils/sse.py
new file mode 100644
index 0000000..6b12b51
--- /dev/null
+++ b/backend/app/utils/sse.py
@@ -0,0 +1,11 @@
+"""SSE (Server-Sent Events) formatting utilities."""
+
+import json
+
+
+def format_sse_error(message: str, code: int = 500) -> str:
+ """Format a standardized SSE error event.
+
+ Unified format: event: error\\ndata: {"code": status_code, "message": error_msg}\\n\\n
+ """
+ return f"event: error\ndata: {json.dumps({'code': code, 'message': message})}\n\n"
diff --git a/backend/app/websocket/__init__.py b/backend/app/websocket/__init__.py
new file mode 100644
index 0000000..b997b06
--- /dev/null
+++ b/backend/app/websocket/__init__.py
@@ -0,0 +1,5 @@
+"""WebSocket connection management for real-time pipeline status updates."""
+
+from app.websocket.manager import PipelineConnectionManager, pipeline_manager
+
+__all__ = ["PipelineConnectionManager", "pipeline_manager"]
diff --git a/backend/app/websocket/manager.py b/backend/app/websocket/manager.py
new file mode 100644
index 0000000..69d3ae2
--- /dev/null
+++ b/backend/app/websocket/manager.py
@@ -0,0 +1,43 @@
+"""Room-based WebSocket connection manager for pipeline status broadcasts."""
+
+import logging
+from collections import defaultdict
+
+from fastapi import WebSocket
+
+logger = logging.getLogger(__name__)
+
+
+class PipelineConnectionManager:
+ """Manages WebSocket connections grouped by pipeline thread_id (room)."""
+
+ def __init__(self) -> None:
+ self.rooms: dict[str, set[WebSocket]] = defaultdict(set)
+
+ async def connect(self, websocket: WebSocket, thread_id: str) -> None:
+ await websocket.accept()
+ self.rooms[thread_id].add(websocket)
+ logger.debug("WS connected to room %s (%d clients)", thread_id, len(self.rooms[thread_id]))
+
+ def disconnect(self, websocket: WebSocket, thread_id: str) -> None:
+ if thread_id in self.rooms:
+ self.rooms[thread_id].discard(websocket)
+ if not self.rooms[thread_id]:
+ del self.rooms[thread_id]
+
+ async def broadcast_to_room(self, thread_id: str, message: dict) -> None:
+ if thread_id not in self.rooms:
+ return
+ dead: list[WebSocket] = []
+ for conn in list(self.rooms[thread_id]):
+ try:
+ await conn.send_json(message)
+ except Exception:
+ dead.append(conn)
+ for conn in dead:
+ self.rooms[thread_id].discard(conn)
+ if thread_id in self.rooms and not self.rooms[thread_id]:
+ del self.rooms[thread_id]
+
+
+pipeline_manager = PipelineConnectionManager()
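+
+# Sketch of the intended FastAPI wiring (route path is an assumption):
+#
+#     from fastapi import WebSocketDisconnect
+#
+#     @app.websocket("/ws/pipeline/{thread_id}")
+#     async def pipeline_ws(ws: WebSocket, thread_id: str) -> None:
+#         await pipeline_manager.connect(ws, thread_id)
+#         try:
+#             while True:
+#                 await ws.receive_text()  # keepalive only; updates are pushed
+#         except WebSocketDisconnect:
+#             pipeline_manager.disconnect(ws, thread_id)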
diff --git a/backend/conftest.py b/backend/conftest.py
index 93c815f..1c09e03 100644
--- a/backend/conftest.py
+++ b/backend/conftest.py
@@ -3,6 +3,9 @@
import os
import tempfile
+import pytest
+from sqlalchemy import UniqueConstraint
+
_test_data_dir = tempfile.mkdtemp(prefix="omelette_test_")
_test_db_path = os.path.join(_test_data_dir, "test_omelette.db")
@@ -10,3 +13,61 @@
os.environ.setdefault("LLM_PROVIDER", "mock")
os.environ.setdefault("DATABASE_URL", f"sqlite:///{_test_db_path}")
os.environ.setdefault("DATA_DIR", _test_data_dir)
+
+REAL_LLM_AVAILABLE = os.environ.get("LLM_PROVIDER", "mock") != "mock"
+
+real_llm = pytest.mark.skipif(
+ not REAL_LLM_AVAILABLE,
+ reason="Real LLM not configured (set LLM_PROVIDER=volcengine)",
+)
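+
+# Usage sketch: decorate tests that must hit a live provider.
+#
+#     @real_llm
+#     async def test_summarize_with_real_llm(client): ...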
+
+
+def remove_paper_doi_unique_constraint():
+ """Remove (project_id, doi) unique constraint so tests can insert duplicate DOIs for dedup."""
+ from app.database import Base
+
+ table = Base.metadata.tables.get("papers")
+ if table is not None:
+ for c in list(table.constraints):
+ if isinstance(c, UniqueConstraint) and getattr(c, "name", None) == "uq_paper_project_doi":
+ table.constraints.discard(c)
+ break
+
+
+# ---------------------------------------------------------------------------
+# Shared fixtures (for tests that need DB + HTTP client)
+# Tests with local fixtures of the same name will use their own (no override).
+# ---------------------------------------------------------------------------
+
+
+@pytest.fixture
+async def setup_db():
+ """Create tables before each test, drop after. Request explicitly or use local override."""
+ from app.database import Base, engine
+
+ remove_paper_doi_unique_constraint()
+ async with engine.begin() as conn:
+ await conn.run_sync(Base.metadata.create_all)
+ yield
+ async with engine.begin() as conn:
+ await conn.run_sync(Base.metadata.drop_all)
+
+
+@pytest.fixture
+async def client():
+ """Async HTTP client for in-process testing."""
+ from httpx import ASGITransport, AsyncClient
+
+ from app.main import app
+
+ transport = ASGITransport(app=app)
+ async with AsyncClient(transport=transport, base_url="http://test") as ac:
+ yield ac
+
+
+@pytest.fixture
+async def project_id(client):
+ """Create a project and return its ID. Depends on client."""
+ resp = await client.post("/api/v1/projects", json={"name": "Test Project", "domain": "optics"})
+ assert resp.status_code == 201
+ return resp.json()["data"]["id"]
diff --git a/backend/pyproject.toml b/backend/pyproject.toml
index 45d2c95..73551bb 100644
--- a/backend/pyproject.toml
+++ b/backend/pyproject.toml
@@ -19,7 +19,6 @@ dependencies = [
"pydantic-settings>=2.7.0",
"python-dotenv>=1.0.0",
"httpx>=0.28.0",
- "aiohttp>=3.11.0",
"chromadb>=0.6.0",
"openai>=1.60.0",
"pdfplumber>=0.11.0",
@@ -38,6 +37,7 @@ dependencies = [
"langchain-ollama>=0.3",
"llama-index-core>=0.12",
"llama-index-vector-stores-chroma>=0.4",
+ "huggingface-hub>=0.28",
"llama-index-embeddings-huggingface>=0.5",
"llama-index-embeddings-openai>=0.4",
"mcp>=1.26",
@@ -62,6 +62,8 @@ ocr = [
ml = [
"sentence-transformers>=4.0.0",
"torch>=2.6.0",
+ "transformers>=4.51.0",
+ "llama-index-postprocessor-sbert-rerank>=0.4.0",
]
[build-system]
@@ -97,6 +99,10 @@ indent-style = "space"
testpaths = ["tests"]
asyncio_mode = "auto"
addopts = "-v --tb=short"
+markers = [
+ "real_llm: marks tests requiring real LLM (deselect with -m 'not real_llm')",
+ "e2e: marks end-to-end tests requiring a live server (deselect with -m 'not e2e')",
+]
[tool.mypy]
python_version = "3.12"
diff --git a/backend/scripts/gpu_watchdog.py b/backend/scripts/gpu_watchdog.py
new file mode 100755
index 0000000..b8fc999
--- /dev/null
+++ b/backend/scripts/gpu_watchdog.py
@@ -0,0 +1,178 @@
+#!/usr/bin/env python3
+"""GPU resource watchdog — monitors Omelette and cleans up GPU resources on exit.
+
+Runs as an independent process. When the monitored Omelette process dies
+(including kill -9, OOM, crash), this script kills MinerU and clears GPU caches.
+
+Usage:
+ python scripts/gpu_watchdog.py # foreground
+ python scripts/gpu_watchdog.py --daemon # background (detach)
+ python scripts/gpu_watchdog.py --pid-file /path.pid # custom PID file
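+
+Assumes the Omelette process writes its own PID to the PID file at startup,
+e.g. (sketch): Path("./data/omelette.pid").write_text(str(os.getpid())).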
+"""
+
+from __future__ import annotations
+
+import argparse
+import logging
+import os
+import signal
+import subprocess
+import sys
+import time
+from pathlib import Path
+
+logging.basicConfig(
+ level=logging.INFO,
+ format="%(asctime)s | %(levelname)-8s | gpu_watchdog | %(message)s",
+)
+logger = logging.getLogger("gpu_watchdog")
+
+
+def pid_alive(pid: int) -> bool:
+ try:
+ os.kill(pid, 0)
+ return True
+ except (OSError, ProcessLookupError):
+ return False
+
+
+def find_pid_by_port(port: int) -> int | None:
+ """Find PID listening on a TCP port."""
+ try:
+ with open("/proc/net/tcp") as f:
+ hex_port = f":{port:04X}"
+ for line in f:
+ parts = line.strip().split()
+ if len(parts) >= 10 and hex_port in parts[1] and parts[3] == "0A":
+ inode = parts[9]
+ for pid_dir in os.listdir("/proc"):
+ if not pid_dir.isdigit():
+ continue
+ try:
+ fd_dir = f"/proc/{pid_dir}/fd"
+ for fd in os.listdir(fd_dir):
+ link = os.readlink(f"{fd_dir}/{fd}")
+ if f"socket:[{inode}]" in link:
+ return int(pid_dir)
+ except (OSError, PermissionError):
+ continue
+ except (OSError, PermissionError):
+ pass
+
+ try:
+ result = subprocess.run(
+ ["lsof", "-ti", f":{port}"],
+ capture_output=True,
+ text=True,
+ timeout=5,
+ )
+ if result.returncode == 0 and result.stdout.strip():
+ return int(result.stdout.strip().split("\n")[0])
+ except (subprocess.TimeoutExpired, ValueError, OSError, FileNotFoundError):
+ pass
+ return None
+
+
+def is_mineru_process(pid: int) -> bool:
+ try:
+ with open(f"/proc/{pid}/cmdline", "rb") as f:
+ return "mineru" in f.read().decode(errors="replace").lower()
+ except (OSError, PermissionError):
+ return False
+
+
+def kill_mineru(port: int) -> None:
+ pid = find_pid_by_port(port)
+ if pid and is_mineru_process(pid):
+ logger.info("Killing MinerU pid=%d on port %d", pid, port)
+ try:
+ os.kill(pid, signal.SIGTERM)
+ for _ in range(10):
+ time.sleep(1)
+ if not pid_alive(pid):
+ logger.info("MinerU pid=%d terminated", pid)
+ return
+ os.kill(pid, signal.SIGKILL)
+ logger.info("Force-killed MinerU pid=%d", pid)
+ except (OSError, ProcessLookupError):
+ pass
+ else:
+ logger.info("No MinerU found on port %d", port)
+
+
+def cleanup(pid_file: Path, mineru_port: int) -> None:
+ logger.info("Omelette process gone — running cleanup")
+ kill_mineru(mineru_port)
+ if pid_file.exists():
+ try:
+ pid_file.unlink()
+ logger.info("Removed PID file: %s", pid_file)
+ except OSError:
+ pass
+ logger.info("Cleanup complete")
+
+
+def wait_for_pid_file(pid_file: Path, timeout: int = 60) -> int | None:
+ """Wait for PID file to appear and return the PID."""
+ deadline = time.monotonic() + timeout
+ while time.monotonic() < deadline:
+ if pid_file.exists():
+ try:
+ pid = int(pid_file.read_text().strip())
+ if pid_alive(pid):
+ return pid
+ except (ValueError, OSError):
+ pass
+ time.sleep(2)
+ return None
+
+
+def daemonize() -> None:
+ """Double-fork to detach from terminal."""
+ if os.fork() > 0:
+ sys.exit(0)
+ os.setsid()
+ if os.fork() > 0:
+ sys.exit(0)
+ devnull_r = open(os.devnull) # noqa: SIM115
+ devnull_w = open(os.devnull, "w") # noqa: SIM115
+ sys.stdin = devnull_r
+ sys.stdout = devnull_w
+ sys.stderr = devnull_w
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(description="GPU watchdog for Omelette")
+ parser.add_argument("--pid-file", default="./data/omelette.pid", help="Path to PID file")
+ parser.add_argument("--interval", type=int, default=5, help="Check interval in seconds")
+ parser.add_argument("--mineru-port", type=int, default=8010, help="MinerU port")
+ parser.add_argument("--daemon", action="store_true", help="Run as daemon")
+ args = parser.parse_args()
+
+ pid_file = Path(args.pid_file)
+
+ if args.daemon:
+ daemonize()
+
+ logger.info(
+ "GPU watchdog started (pid_file=%s, interval=%ds, mineru_port=%d)", pid_file, args.interval, args.mineru_port
+ )
+
+ target_pid = wait_for_pid_file(pid_file, timeout=120)
+ if target_pid is None:
+ logger.warning("No Omelette process found within timeout, exiting")
+ return
+ logger.info("Monitoring Omelette pid=%d", target_pid)
+
+ try:
+ while True:
+ time.sleep(args.interval)
+ if not pid_alive(target_pid):
+ cleanup(pid_file, args.mineru_port)
+ return
+ except KeyboardInterrupt:
+ logger.info("Watchdog stopped by user")
+
+
+if __name__ == "__main__":
+ main()
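+
+# Companion sketch (an assumption about the app side, not part of this script):
+# the watchdog can only attach if the monitored process writes the PID file it
+# polls for. A minimal version, to be called once at Omelette startup:
+#
+#     import atexit
+#     import os
+#     from pathlib import Path
+#
+#     def write_pid_file(path: str = "./data/omelette.pid") -> None:
+#         pid_file = Path(path)
+#         pid_file.parent.mkdir(parents=True, exist_ok=True)
+#         pid_file.write_text(str(os.getpid()))
+#         atexit.register(pid_file.unlink, missing_ok=True)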
diff --git a/backend/tests/test_api_chat_rag_writing.py b/backend/tests/test_api_chat_rag_writing.py
new file mode 100644
index 0000000..8492c6a
--- /dev/null
+++ b/backend/tests/test_api_chat_rag_writing.py
@@ -0,0 +1,579 @@
+"""Comprehensive API tests for Chat, RAG, Writing, Completion, and Rewrite modules."""
+
+from __future__ import annotations
+
+import json
+
+import chromadb
+import pytest
+from httpx import ASGITransport, AsyncClient
+from llama_index.core.embeddings import MockEmbedding
+
+from app.api.v1.rag import get_rag_service
+from app.database import Base, async_session_factory, engine
+from app.main import app
+from app.models import Paper, PaperChunk, PaperStatus, Project
+from app.services.rag_service import RAGService
+
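+# MockEmbedding yields fixed-size deterministic vectors, so these tests run
+# without downloading models or touching the GPU.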
+MOCK_EMBED = MockEmbedding(embed_dim=128)
+
+
+# ---------------------------------------------------------------------------
+# Fixtures
+# ---------------------------------------------------------------------------
+
+
+@pytest.fixture(autouse=True)
+async def setup_db():
+ async with engine.begin() as conn:
+ await conn.run_sync(Base.metadata.create_all)
+ yield
+ async with engine.begin() as conn:
+ await conn.run_sync(Base.metadata.drop_all)
+
+
+@pytest.fixture
+async def client():
+ transport = ASGITransport(app=app)
+ async with AsyncClient(transport=transport, base_url="http://test") as ac:
+ yield ac
+
+
+@pytest.fixture
+def rag_service():
+ """RAGService with ephemeral ChromaDB and mock embedding for fast tests."""
+ return RAGService(
+ chroma_client=chromadb.EphemeralClient(),
+ embed_model=MOCK_EMBED,
+ )
+
+
+@pytest.fixture(autouse=True)
+def override_rag_dependency(rag_service):
+ """Override RAG dependency to use ephemeral ChromaDB."""
+ app.dependency_overrides[get_rag_service] = lambda: rag_service
+ yield
+ app.dependency_overrides.pop(get_rag_service, None)
+
+
+@pytest.fixture(autouse=True)
+def mock_chat_services(monkeypatch):
+ """Mock _init_services so Chat stream uses mock LLM/RAG without DB lookups."""
+ import app.api.v1.chat as chat_module
+ from app.services.llm.client import LLMClient
+
+ async def _mock_init_services(db):
+ from app.services.rag_service import RAGService
+
+ llm = LLMClient(provider="mock")
+ rag = RAGService(llm=llm, embed_model=MockEmbedding(embed_dim=128))
+ return {"llm": llm, "rag": rag}
+
+ monkeypatch.setattr(chat_module, "_init_services", _mock_init_services)
+
+
+@pytest.fixture
+async def project_with_chunks():
+ """Create a project with OCR-complete papers and chunks for RAG tests."""
+ async with async_session_factory() as session:
+ project = Project(name="RAG Test Project", domain="optics")
+ session.add(project)
+ await session.flush()
+
+ paper = Paper(
+ project_id=project.id,
+ title="Super-Resolution Microscopy Review",
+ abstract="A review of super-resolution techniques.",
+ journal="Nature",
+ year=2023,
+ status=PaperStatus.OCR_COMPLETE,
+ )
+ session.add(paper)
+ await session.flush()
+
+ chunk1 = PaperChunk(
+ paper_id=paper.id,
+ content="Super-resolution microscopy enables imaging beyond the diffraction limit.",
+ chunk_type="text",
+ page_number=1,
+ chunk_index=0,
+ )
+ chunk2 = PaperChunk(
+ paper_id=paper.id,
+ content="STED and STORM are two major techniques for nanoscale imaging.",
+ chunk_type="text",
+ page_number=2,
+ chunk_index=1,
+ )
+ session.add(chunk1)
+ session.add(chunk2)
+ await session.commit()
+ return project.id
+
+
+@pytest.fixture
+async def project_with_papers():
+ """Create a project with papers for writing tests."""
+ async with async_session_factory() as session:
+ project = Project(name="Writing Test Project", domain="optics")
+ session.add(project)
+ await session.flush()
+
+ paper1 = Paper(
+ project_id=project.id,
+ title="Super-Resolution Microscopy",
+ abstract="A comprehensive review of super-resolution techniques.",
+ journal="Nature",
+ year=2023,
+ authors=[{"name": "Alice Smith"}, {"name": "Bob Jones"}],
+ citation_count=100,
+ status=PaperStatus.INDEXED,
+ )
+ paper2 = Paper(
+ project_id=project.id,
+ title="STED Imaging Methods",
+ abstract="Stimulated emission depletion microscopy for nanoscale imaging.",
+ journal="Science",
+ year=2022,
+ authors=[{"name": "Carol Lee"}],
+ doi="10.1234/test",
+ citation_count=50,
+ status=PaperStatus.INDEXED,
+ )
+ session.add(paper1)
+ session.add(paper2)
+ await session.flush()
+ paper_ids = [paper1.id, paper2.id]
+ await session.commit()
+ return project.id, paper_ids
+
+
+# ---------------------------------------------------------------------------
+# Chat API tests
+# ---------------------------------------------------------------------------
+
+
+class TestChatStream:
+ """Tests for POST /api/v1/chat/stream (SSE)."""
+
+ @pytest.mark.asyncio
+ async def test_stream_returns_sse(self, client: AsyncClient):
+ resp = await client.post(
+ "/api/v1/chat/stream",
+ json={"message": "Hello", "knowledge_base_ids": []},
+ )
+ assert resp.status_code == 200
+ assert resp.headers["content-type"].startswith("text/event-stream")
+
+ text = resp.text
+ lines = [line for line in text.split("\n") if line.startswith("data: ")]
+
+ event_types = []
+ for line in lines:
+ payload = line.removeprefix("data: ").strip()
+ if payload == "[DONE]":
+ event_types.append("[DONE]")
+ continue
+ try:
+ parsed = json.loads(payload)
+ event_types.append(parsed.get("type", "unknown"))
+ except json.JSONDecodeError:
+ pass
+
+ assert "start" in event_types
+ assert "text-delta" in event_types
+ assert "finish" in event_types
+ assert "[DONE]" in event_types
+
+ @pytest.mark.asyncio
+ async def test_stream_with_rag_top_k_and_use_reranker(self, client: AsyncClient):
+ """Chat stream accepts rag_top_k (1-50) and use_reranker."""
+ resp = await client.post(
+ "/api/v1/chat/stream",
+ json={
+ "message": "What is super-resolution?",
+ "knowledge_base_ids": [1],
+ "rag_top_k": 15,
+ "use_reranker": True,
+ },
+ )
+ assert resp.status_code == 200
+ assert resp.headers["content-type"].startswith("text/event-stream")
+ assert "data:" in resp.text
+
+ @pytest.mark.asyncio
+ async def test_stream_rag_top_k_validation_min_fails(self, client: AsyncClient):
+ """rag_top_k=0 should fail validation."""
+ resp = await client.post(
+ "/api/v1/chat/stream",
+ json={"message": "Hello", "knowledge_base_ids": [], "rag_top_k": 0},
+ )
+ assert resp.status_code == 422
+
+ @pytest.mark.asyncio
+ async def test_stream_rag_top_k_validation_max_fails(self, client: AsyncClient):
+ """rag_top_k=51 should fail validation."""
+ resp = await client.post(
+ "/api/v1/chat/stream",
+ json={"message": "Hello", "knowledge_base_ids": [], "rag_top_k": 51},
+ )
+ assert resp.status_code == 422
+
+ @pytest.mark.asyncio
+ async def test_stream_message_required(self, client: AsyncClient):
+ resp = await client.post(
+ "/api/v1/chat/stream",
+ json={"message": "", "knowledge_base_ids": []},
+ )
+ assert resp.status_code == 422
+
+
+class TestChatComplete:
+ """Tests for POST /api/v1/chat/complete (Completion)."""
+
+ @pytest.mark.asyncio
+ async def test_complete_success(self, client: AsyncClient):
+ resp = await client.post(
+ "/api/v1/chat/complete",
+ json={
+ "prefix": "深度学习在自然语言处理领域",
+ "knowledge_base_ids": [],
+ },
+ )
+ assert resp.status_code == 200
+ data = resp.json()["data"]
+ assert "completion" in data
+ assert "confidence" in data
+
+ @pytest.mark.asyncio
+ async def test_complete_prefix_too_short_fails(self, client: AsyncClient):
+ resp = await client.post(
+ "/api/v1/chat/complete",
+ json={"prefix": "short"},
+ )
+ assert resp.status_code == 422
+
+
+# ---------------------------------------------------------------------------
+# RAG API tests
+# ---------------------------------------------------------------------------
+
+
+class TestRAGQuery:
+ """Tests for POST /api/v1/projects/{project_id}/rag/query."""
+
+ @pytest.mark.asyncio
+ async def test_query_empty_index(self, client: AsyncClient, project_with_chunks: int):
+ resp = await client.post(
+ f"/api/v1/projects/{project_with_chunks}/rag/query",
+ json={"question": "What is super-resolution?", "top_k": 5},
+ )
+ assert resp.status_code == 200
+ body = resp.json()
+ assert body["code"] == 200
+ assert "answer" in body["data"]
+ assert "sources" in body["data"]
+ assert "confidence" in body["data"]
+
+ @pytest.mark.asyncio
+ async def test_query_with_use_reranker(self, client: AsyncClient, project_with_chunks: int):
+ resp = await client.post(
+ f"/api/v1/projects/{project_with_chunks}/rag/query",
+ json={
+ "question": "What is super-resolution?",
+ "top_k": 5,
+ "use_reranker": True,
+ },
+ )
+ assert resp.status_code == 200
+ body = resp.json()
+ assert "answer" in body["data"]
+
+ @pytest.mark.asyncio
+ async def test_query_without_reranker(self, client: AsyncClient, project_with_chunks: int):
+ resp = await client.post(
+ f"/api/v1/projects/{project_with_chunks}/rag/query",
+ json={
+ "question": "What is super-resolution?",
+ "top_k": 5,
+ "use_reranker": False,
+ },
+ )
+ assert resp.status_code == 200
+ body = resp.json()
+ assert "answer" in body["data"]
+
+ @pytest.mark.asyncio
+ async def test_query_top_k_validation_min_fails(self, client: AsyncClient, project_with_chunks: int):
+ """top_k=0 should fail validation."""
+ resp = await client.post(
+ f"/api/v1/projects/{project_with_chunks}/rag/query",
+ json={"question": "test", "top_k": 0},
+ )
+ assert resp.status_code == 422
+
+ @pytest.mark.asyncio
+ async def test_query_top_k_validation_max_fails(self, client: AsyncClient, project_with_chunks: int):
+ """top_k=51 should fail validation."""
+ resp = await client.post(
+ f"/api/v1/projects/{project_with_chunks}/rag/query",
+ json={"question": "test", "top_k": 51},
+ )
+ assert resp.status_code == 422
+
+ @pytest.mark.asyncio
+ async def test_query_after_index(self, client: AsyncClient, project_with_chunks: int):
+ await client.post(f"/api/v1/projects/{project_with_chunks}/rag/index")
+
+ resp = await client.post(
+ f"/api/v1/projects/{project_with_chunks}/rag/query",
+ json={"question": "What is super-resolution microscopy?", "top_k": 5},
+ )
+ assert resp.status_code == 200
+ body = resp.json()
+ assert body["code"] == 200
+ assert "answer" in body["data"]
+ assert "sources" in body["data"]
+ assert "confidence" in body["data"]
+
+
+class TestRAGIndex:
+ """Tests for POST /api/v1/projects/{project_id}/rag/index."""
+
+ @pytest.mark.asyncio
+ async def test_build_index(self, client: AsyncClient, project_with_chunks: int):
+ resp = await client.post(f"/api/v1/projects/{project_with_chunks}/rag/index")
+ assert resp.status_code == 200
+ body = resp.json()
+ assert body["code"] == 200
+ assert "indexed" in body["data"]
+ assert body["data"]["indexed"] >= 0
+
+
+class TestRAGIndexStream:
+ """Tests for POST /api/v1/projects/{project_id}/rag/index/stream (SSE)."""
+
+ @pytest.mark.asyncio
+ async def test_index_stream_returns_sse(self, client: AsyncClient, project_with_chunks: int):
+ resp = await client.post(f"/api/v1/projects/{project_with_chunks}/rag/index/stream")
+ assert resp.status_code == 200
+ assert resp.headers.get("content-type", "").startswith("text/event-stream")
+
+ text = resp.text
+ assert "event:" in text
+ assert "data:" in text
+
+        # Should emit at least one named event (progress and/or complete)
+ lines = text.split("\n")
+ event_lines = [line for line in lines if line.startswith("event:")]
+ assert len(event_lines) >= 1
+
+
+class TestRAGStats:
+ """Tests for GET /api/v1/projects/{project_id}/rag/stats."""
+
+ @pytest.mark.asyncio
+ async def test_stats(self, client: AsyncClient, project_with_chunks: int):
+ resp = await client.get(f"/api/v1/projects/{project_with_chunks}/rag/stats")
+ assert resp.status_code == 200
+ body = resp.json()
+ assert "total_chunks" in body["data"]
+ assert "collection_name" in body["data"]
+
+
+class TestRAGDeleteIndex:
+ """Tests for DELETE /api/v1/projects/{project_id}/rag/index."""
+
+ @pytest.mark.asyncio
+ async def test_delete_index(self, client: AsyncClient, project_with_chunks: int):
+ resp = await client.delete(f"/api/v1/projects/{project_with_chunks}/rag/index")
+ assert resp.status_code == 200
+ body = resp.json()
+ assert "deleted" in body["data"]
+
+
+# ---------------------------------------------------------------------------
+# Writing API tests
+# ---------------------------------------------------------------------------
+
+
+class TestWritingSummarize:
+ """Tests for POST /api/v1/projects/{project_id}/writing/summarize."""
+
+ @pytest.mark.asyncio
+ async def test_summarize(self, client: AsyncClient, project_with_papers):
+ project_id, paper_ids = project_with_papers
+ resp = await client.post(
+ f"/api/v1/projects/{project_id}/writing/summarize",
+ json={"paper_ids": paper_ids, "language": "en"},
+ )
+ assert resp.status_code == 200
+ body = resp.json()
+ assert body["code"] == 200
+ assert "summaries" in body["data"]
+ assert len(body["data"]["summaries"]) == 2
+
+
+class TestWritingCitations:
+ """Tests for POST /api/v1/projects/{project_id}/writing/citations."""
+
+ @pytest.mark.asyncio
+ async def test_citations(self, client: AsyncClient, project_with_papers):
+ project_id, paper_ids = project_with_papers
+ resp = await client.post(
+ f"/api/v1/projects/{project_id}/writing/citations",
+ json={"paper_ids": paper_ids, "style": "gb_t_7714"},
+ )
+ assert resp.status_code == 200
+ body = resp.json()
+ assert body["code"] == 200
+ assert "citations" in body["data"]
+ assert body["data"]["style"] == "gb_t_7714"
+ assert len(body["data"]["citations"]) == 2
+
+
+class TestWritingReviewOutline:
+ """Tests for POST /api/v1/projects/{project_id}/writing/review-outline."""
+
+ @pytest.mark.asyncio
+ async def test_review_outline(self, client: AsyncClient, project_with_papers):
+ project_id, _ = project_with_papers
+ resp = await client.post(
+ f"/api/v1/projects/{project_id}/writing/review-outline",
+ json={"topic": "Super-resolution imaging", "language": "en"},
+ )
+ assert resp.status_code == 200
+ body = resp.json()
+ assert body["code"] == 200
+ assert "outline" in body["data"]
+ assert "paper_count" in body["data"]
+
+
+class TestWritingGapAnalysis:
+ """Tests for POST /api/v1/projects/{project_id}/writing/gap-analysis."""
+
+ @pytest.mark.asyncio
+ async def test_gap_analysis(self, client: AsyncClient, project_with_papers):
+ project_id, _ = project_with_papers
+ resp = await client.post(
+ f"/api/v1/projects/{project_id}/writing/gap-analysis",
+ json={"research_topic": "Nanoscale microscopy"},
+ )
+ assert resp.status_code == 200
+ body = resp.json()
+ assert body["code"] == 200
+ assert "analysis" in body["data"]
+ assert "papers_analyzed" in body["data"]
+
+
+class TestWritingReviewDraftStream:
+ """Tests for POST /api/v1/projects/{project_id}/writing/review-draft/stream (SSE)."""
+
+ @pytest.mark.asyncio
+ async def test_review_draft_stream_returns_sse(self, client: AsyncClient, project_with_papers):
+ project_id, _ = project_with_papers
+ resp = await client.post(
+ f"/api/v1/projects/{project_id}/writing/review-draft/stream",
+ json={
+ "topic": "Super-resolution microscopy",
+ "style": "narrative",
+ "citation_format": "numbered",
+ "language": "en",
+ },
+ )
+ assert resp.status_code == 200
+ assert resp.headers.get("content-type", "").startswith("text/event-stream")
+
+ text = resp.text
+ assert "event:" in text
+ assert "data:" in text
+
+ @pytest.mark.asyncio
+ async def test_review_draft_stream_invalid_style_fails(self, client: AsyncClient, project_with_papers):
+ project_id, _ = project_with_papers
+ resp = await client.post(
+ f"/api/v1/projects/{project_id}/writing/review-draft/stream",
+ json={"topic": "test", "style": "invalid_style"},
+ )
+ assert resp.status_code == 422
+
+
+# ---------------------------------------------------------------------------
+# Rewrite API tests (POST /api/v1/chat/rewrite)
+# ---------------------------------------------------------------------------
+
+
+class TestRewrite:
+ """Tests for POST /api/v1/chat/rewrite (SSE)."""
+
+ @pytest.mark.asyncio
+ async def test_rewrite_stream_returns_sse(self, client: AsyncClient):
+ resp = await client.post(
+ "/api/v1/chat/rewrite",
+ json={
+ "excerpt": "This is a sample excerpt to simplify for testing.",
+ "style": "simplify",
+ },
+ )
+ assert resp.status_code == 200
+ assert resp.headers.get("content-type", "").startswith("text/event-stream")
+
+ text = resp.text
+ assert "event:" in text
+ assert "data:" in text
+
+ # Parse SSE events
+ lines = text.split("\n")
+ event_types = []
+ for line in lines:
+ if line.startswith("event:"):
+ event_types.append(line.replace("event:", "").strip())
+
+ assert "rewrite_delta" in event_types or "rewrite_end" in event_types or "error" in event_types
+
+ @pytest.mark.asyncio
+ async def test_rewrite_academic_style(self, client: AsyncClient):
+ resp = await client.post(
+ "/api/v1/chat/rewrite",
+ json={
+ "excerpt": "This is a simple sentence.",
+ "style": "academic",
+ },
+ )
+ assert resp.status_code == 200
+ assert "data:" in resp.text
+
+ @pytest.mark.asyncio
+ async def test_rewrite_excerpt_too_long_fails(self, client: AsyncClient):
+ resp = await client.post(
+ "/api/v1/chat/rewrite",
+ json={
+ "excerpt": "x" * 2001,
+ "style": "simplify",
+ },
+ )
+ assert resp.status_code == 422
+
+ @pytest.mark.asyncio
+ async def test_rewrite_custom_requires_prompt(self, client: AsyncClient):
+ resp = await client.post(
+ "/api/v1/chat/rewrite",
+ json={
+ "excerpt": "Sample text",
+ "style": "custom",
+ },
+ )
+ assert resp.status_code == 422
+
+ @pytest.mark.asyncio
+ async def test_rewrite_custom_with_prompt(self, client: AsyncClient):
+ resp = await client.post(
+ "/api/v1/chat/rewrite",
+ json={
+ "excerpt": "Sample text to rewrite.",
+ "style": "custom",
+ "custom_prompt": "Rewrite this in a formal tone.",
+ },
+ )
+ assert resp.status_code == 200
+ assert "data:" in resp.text
diff --git a/backend/tests/test_api_convos_subs_tasks_settings.py b/backend/tests/test_api_convos_subs_tasks_settings.py
new file mode 100644
index 0000000..03153f3
--- /dev/null
+++ b/backend/tests/test_api_convos_subs_tasks_settings.py
@@ -0,0 +1,669 @@
+"""Comprehensive API tests for Conversations, Subscriptions, Tasks, Settings, and Pipelines."""
+
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+from httpx import ASGITransport, AsyncClient
+
+from app.database import Base, async_session_factory, engine
+from app.main import app
+from app.models import Message, Project, Task
+
+
+@pytest.fixture(autouse=True)
+async def setup_db():
+ async with engine.begin() as conn:
+ await conn.run_sync(Base.metadata.create_all)
+ yield
+ async with engine.begin() as conn:
+ await conn.run_sync(Base.metadata.drop_all)
+
+
+@pytest.fixture
+async def client():
+ transport = ASGITransport(app=app)
+ async with AsyncClient(transport=transport, base_url="http://test") as ac:
+ yield ac
+
+
+@pytest.fixture
+async def project(setup_db):
+ async with async_session_factory() as db:
+ p = Project(name="Test Project", description="For API tests")
+ db.add(p)
+ await db.commit()
+ await db.refresh(p)
+ return p
+
+
+# ── Conversations ──
+
+
+class TestConversationsAPI:
+ """Tests for /api/v1/conversations."""
+
+ @pytest.mark.asyncio
+ async def test_list_conversations_empty(self, client):
+ resp = await client.get("/api/v1/conversations")
+ assert resp.status_code == 200
+ data = resp.json()["data"]
+ assert data["items"] == []
+ assert data["total"] == 0
+ assert "page" in data
+ assert "page_size" in data
+
+ @pytest.mark.asyncio
+ async def test_list_conversations_paginated(self, client):
+ for i in range(5):
+ await client.post("/api/v1/conversations", json={"title": f"Conv {i}"})
+ resp = await client.get("/api/v1/conversations", params={"page": 1, "page_size": 2})
+ assert resp.status_code == 200
+ data = resp.json()["data"]
+ assert len(data["items"]) == 2
+ assert data["total"] == 5
+ assert data["page"] == 1
+ assert data["page_size"] == 2
+ assert data["total_pages"] == 3
+
+ @pytest.mark.asyncio
+ async def test_list_conversations_filter_by_knowledge_base_id(self, client):
+ await client.post(
+ "/api/v1/conversations",
+ json={"title": "KB1", "knowledge_base_ids": [1, 2]},
+ )
+ await client.post(
+ "/api/v1/conversations",
+ json={"title": "KB2", "knowledge_base_ids": [3, 4]},
+ )
+ resp = await client.get("/api/v1/conversations", params={"knowledge_base_id": 1})
+ assert resp.status_code == 200
+ data = resp.json()["data"]
+ assert data["total"] == 1
+ assert data["items"][0]["knowledge_base_ids"] == [1, 2]
+
+ @pytest.mark.asyncio
+ async def test_create_conversation(self, client):
+ resp = await client.post(
+ "/api/v1/conversations",
+ json={
+ "title": "New Chat",
+ "knowledge_base_ids": [1, 2],
+ "model": "gpt-4o",
+ "tool_mode": "citation_lookup",
+ },
+ )
+ assert resp.status_code == 200
+ data = resp.json()["data"]
+ assert data["title"] == "New Chat"
+ assert data["knowledge_base_ids"] == [1, 2]
+ assert data["model"] == "gpt-4o"
+ assert data["tool_mode"] == "citation_lookup"
+ assert data["messages"] == []
+ assert "id" in data
+ assert "created_at" in data
+
+ @pytest.mark.asyncio
+ async def test_create_conversation_default_title(self, client):
+ resp = await client.post("/api/v1/conversations", json={})
+ assert resp.status_code == 200
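+        # "新对话" is the backend's default conversation title ("New Conversation").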
+ assert resp.json()["data"]["title"] == "新对话"
+
+ @pytest.mark.asyncio
+ async def test_get_conversation_with_messages(self, client):
+ create_resp = await client.post(
+ "/api/v1/conversations",
+ json={"title": "With Messages", "knowledge_base_ids": [1]},
+ )
+ conv_id = create_resp.json()["data"]["id"]
+ async with async_session_factory() as db:
+ msg = Message(
+ conversation_id=conv_id,
+ role="user",
+ content="Hello",
+ )
+ db.add(msg)
+ await db.commit()
+
+ resp = await client.get(f"/api/v1/conversations/{conv_id}")
+ assert resp.status_code == 200
+ data = resp.json()["data"]
+ assert data["title"] == "With Messages"
+ assert len(data["messages"]) == 1
+ assert data["messages"][0]["role"] == "user"
+ assert data["messages"][0]["content"] == "Hello"
+
+ @pytest.mark.asyncio
+ async def test_get_conversation_not_found(self, client):
+ resp = await client.get("/api/v1/conversations/99999")
+ assert resp.status_code == 404
+
+ @pytest.mark.asyncio
+ async def test_update_conversation(self, client):
+ create_resp = await client.post(
+ "/api/v1/conversations",
+ json={"title": "Old", "tool_mode": "qa"},
+ )
+ conv_id = create_resp.json()["data"]["id"]
+ resp = await client.put(
+ f"/api/v1/conversations/{conv_id}",
+ json={"title": "Updated", "tool_mode": "review_outline"},
+ )
+ assert resp.status_code == 200
+ data = resp.json()["data"]
+ assert data["title"] == "Updated"
+ assert data["tool_mode"] == "review_outline"
+
+ @pytest.mark.asyncio
+ async def test_update_conversation_not_found(self, client):
+ resp = await client.put(
+ "/api/v1/conversations/99999",
+ json={"title": "X"},
+ )
+ assert resp.status_code == 404
+
+ @pytest.mark.asyncio
+ async def test_delete_conversation(self, client):
+ create_resp = await client.post(
+ "/api/v1/conversations",
+ json={"title": "To Delete"},
+ )
+ conv_id = create_resp.json()["data"]["id"]
+ resp = await client.delete(f"/api/v1/conversations/{conv_id}")
+ assert resp.status_code == 200
+ assert resp.json()["data"]["deleted"] is True
+ assert resp.json()["data"]["id"] == conv_id
+
+ resp2 = await client.get(f"/api/v1/conversations/{conv_id}")
+ assert resp2.status_code == 404
+
+ @pytest.mark.asyncio
+ async def test_delete_conversation_not_found(self, client):
+ resp = await client.delete("/api/v1/conversations/99999")
+ assert resp.status_code == 404
+
+
+# ── Subscriptions ──
+
+
+class TestSubscriptionsAPI:
+ """Tests for /api/v1/projects/{project_id}/subscriptions."""
+
+ @pytest.mark.asyncio
+ async def test_list_subscriptions_empty(self, client, project):
+ resp = await client.get(f"/api/v1/projects/{project.id}/subscriptions")
+ assert resp.status_code == 200
+ data = resp.json()["data"]
+ assert data["items"] == []
+ assert data["total"] == 0
+
+ @pytest.mark.asyncio
+ async def test_create_subscription_api_type(self, client, project):
+ resp = await client.post(
+ f"/api/v1/projects/{project.id}/subscriptions",
+ json={
+ "name": "API Sub",
+ "query": "machine learning",
+ "sources": ["semantic_scholar", "arxiv"],
+ "frequency": "weekly",
+ "max_results": 50,
+ },
+ )
+ assert resp.status_code == 201
+ data = resp.json()["data"]
+ assert data["name"] == "API Sub"
+ assert data["query"] == "machine learning"
+ assert data["sources"] == ["semantic_scholar", "arxiv"]
+ assert data["frequency"] == "weekly"
+ assert data["max_results"] == 50
+ assert data["project_id"] == project.id
+ assert data["is_active"] is True
+
+ @pytest.mark.asyncio
+ async def test_create_subscription_minimal(self, client, project):
+ resp = await client.post(
+ f"/api/v1/projects/{project.id}/subscriptions",
+ json={"name": "Minimal Sub"},
+ )
+ assert resp.status_code == 201
+ data = resp.json()["data"]
+ assert data["name"] == "Minimal Sub"
+ assert data["query"] == ""
+ assert data["sources"] == []
+ assert data["frequency"] == "weekly"
+ assert data["max_results"] == 50
+
+ @pytest.mark.asyncio
+ async def test_get_subscription(self, client, project):
+ create_resp = await client.post(
+ f"/api/v1/projects/{project.id}/subscriptions",
+ json={"name": "Get Me"},
+ )
+ sub_id = create_resp.json()["data"]["id"]
+ resp = await client.get(f"/api/v1/projects/{project.id}/subscriptions/{sub_id}")
+ assert resp.status_code == 200
+ assert resp.json()["data"]["name"] == "Get Me"
+
+ @pytest.mark.asyncio
+ async def test_get_subscription_not_found(self, client, project):
+ resp = await client.get(f"/api/v1/projects/{project.id}/subscriptions/99999")
+ assert resp.status_code == 404
+
+ @pytest.mark.asyncio
+ async def test_get_subscription_wrong_project(self, client, project):
+ create_resp = await client.post(
+ f"/api/v1/projects/{project.id}/subscriptions",
+ json={"name": "Sub"},
+ )
+ sub_id = create_resp.json()["data"]["id"]
+ resp = await client.get(f"/api/v1/projects/99999/subscriptions/{sub_id}")
+ assert resp.status_code == 404
+
+ @pytest.mark.asyncio
+ async def test_update_subscription(self, client, project):
+ create_resp = await client.post(
+ f"/api/v1/projects/{project.id}/subscriptions",
+ json={"name": "Old", "query": "old query"},
+ )
+ sub_id = create_resp.json()["data"]["id"]
+ resp = await client.put(
+ f"/api/v1/projects/{project.id}/subscriptions/{sub_id}",
+ json={"name": "New Name", "query": "new query", "is_active": False},
+ )
+ assert resp.status_code == 200
+ data = resp.json()["data"]
+ assert data["name"] == "New Name"
+ assert data["query"] == "new query"
+ assert data["is_active"] is False
+
+ @pytest.mark.asyncio
+ async def test_delete_subscription(self, client, project):
+ create_resp = await client.post(
+ f"/api/v1/projects/{project.id}/subscriptions",
+ json={"name": "To Delete"},
+ )
+ sub_id = create_resp.json()["data"]["id"]
+ resp = await client.delete(f"/api/v1/projects/{project.id}/subscriptions/{sub_id}")
+ assert resp.status_code == 200
+
+ resp2 = await client.get(f"/api/v1/projects/{project.id}/subscriptions/{sub_id}")
+ assert resp2.status_code == 404
+
+ @pytest.mark.asyncio
+ async def test_trigger_subscription(self, client, project):
+ create_resp = await client.post(
+ f"/api/v1/projects/{project.id}/subscriptions",
+ json={"name": "Trigger Sub", "query": "test", "max_results": 10},
+ )
+ sub_id = create_resp.json()["data"]["id"]
+ resp = await client.post(
+ f"/api/v1/projects/{project.id}/subscriptions/{sub_id}/trigger",
+ params={"since_days": 7},
+ )
+ assert resp.status_code == 200
+ data = resp.json()["data"]
+ assert "new_papers" in data
+ assert "total_checked" in data
+ assert "sources_searched" in data
+
+ @pytest.mark.asyncio
+ async def test_check_rss(self, client, project):
+ mock_rss = """
+
Crossref abstract",
+        }
+    }
+
+    async def mock_get(*args, **kwargs):
+        resp = MagicMock()
+        resp.status_code = 200
+        resp.json.return_value = crossref_response
+        return resp
+
+    mock_client = MagicMock()
+    mock_client.__aenter__ = AsyncMock(return_value=mock_client)
+    mock_client.__aexit__ = AsyncMock(return_value=None)
+    mock_client.get = AsyncMock(side_effect=mock_get)
+
+    with (
+        patch("app.services.pdf_metadata.fitz.open", return_value=mock_doc),
+        patch("app.services.pdf_metadata.httpx.AsyncClient", return_value=mock_client),
+    ):
+        result = await pdf_metadata.extract_metadata(pdf_path, fallback_title="Untitled")
+
+    assert result.title == "Crossref Title"
+    assert result.authors == [{"name": "Crossref Author"}]
+    assert result.journal == "Crossref Journal"
+    assert result.year == 2023
+    assert result.abstract == "Crossref abstract"
+    assert result.pdf_path == str(pdf_path)
+
+
+# ---------------------------------------------------------------------------
+# test_lookup_crossref_failure
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_lookup_crossref_failure(tmp_path):
+    """Mock httpx raising; should fall back to local metadata."""
+    pdf_path = tmp_path / "crossref_fail.pdf"
+    pdf_path.write_bytes(b"%PDF-1.4 minimal")
+
+    mock_doc = MagicMock()
+    mock_doc.metadata = {
+        "title": "Local Only Title",
+        "author": "Local Author",
+        "subject": "10.9999/crossref-fail",
+        "creationDate": "",
+    }
+    mock_doc.page_count = 1
+    mock_doc.__iter__ = lambda self: iter([])
+
+    mock_client = MagicMock()
+    mock_client.__aenter__ = AsyncMock(return_value=mock_client)
+    mock_client.__aexit__ = AsyncMock(return_value=None)
+    mock_client.get = AsyncMock(side_effect=Exception("Network error"))
+
+    with (
+        patch("app.services.pdf_metadata.fitz.open", return_value=mock_doc),
+        patch("app.services.pdf_metadata.httpx.AsyncClient", return_value=mock_client),
+    ):
+        result = await pdf_metadata.extract_metadata(pdf_path, fallback_title="Untitled")
+
+    assert result.title == "Local Only Title"
+    assert result.authors == [{"name": "Local Author"}]
+    assert result.doi == "10.9999/crossref-fail"
+    assert result.pdf_path == str(pdf_path)
diff --git a/backend/tests/test_pipeline_real_pdf.py b/backend/tests/test_pipeline_real_pdf.py
new file mode 100644
index 0000000..bd4387d
--- /dev/null
+++ b/backend/tests/test_pipeline_real_pdf.py
@@ -0,0 +1,218 @@
+"""Pipeline integration tests with real PDF files.
+
+Requires test PDFs at /data0/djx/omelette_pdf_test/ (skipped otherwise).
+These tests exercise the upload pipeline with real metadata extraction.
+"""
+
+import os
+from pathlib import Path
+
+import pytest
+from httpx import ASGITransport, AsyncClient
+from langgraph.checkpoint.memory import MemorySaver
+from langgraph.types import Command
+
+from app.database import Base, async_session_factory, engine
+from app.models import Paper, Project
+from app.pipelines.graphs import create_upload_pipeline
+from app.pipelines.state import PipelineState
+
+PDF_TEST_DIR = os.environ.get("E2E_PDF_DIR", "/data0/djx/omelette_pdf_test")
+PDF_DIR_EXISTS = os.path.isdir(PDF_TEST_DIR)
+
+pytestmark = pytest.mark.skipif(not PDF_DIR_EXISTS, reason=f"Test PDF directory not available: {PDF_TEST_DIR}")
+
+
+def _smallest_pdf() -> str:
+    """Find the smallest PDF in the test directory."""
+    pdfs = sorted(Path(PDF_TEST_DIR).glob("*.pdf"), key=lambda p: p.stat().st_size)
+    if not pdfs:
+        pytest.skip("No PDFs found in test directory")
+    return str(pdfs[0])
+
+
+@pytest.fixture(autouse=True)
+async def setup_db():
+    async with engine.begin() as conn:
+        await conn.run_sync(Base.metadata.create_all)
+    yield
+    async with engine.begin() as conn:
+        await conn.run_sync(Base.metadata.drop_all)
+
+
+@pytest.fixture
+async def project():
+    async with async_session_factory() as db:
+        p = Project(name="pdf-test-kb", description="for real PDF testing")
+        db.add(p)
+        await db.commit()
+        await db.refresh(p)
+        return p
+
+
+@pytest.fixture
+def test_client():
+    from app.main import app
+
+    return AsyncClient(transport=ASGITransport(app=app), base_url="http://test")
+
+
+# ── Upload pipeline with real PDF ──
+
+
+async def test_upload_pipeline_real_pdf(project):
+    """Upload pipeline with a real PDF should extract metadata and import the paper."""
+    pdf_path = _smallest_pdf()
+
+    saver = MemorySaver()
+    graph = create_upload_pipeline(checkpointer=saver)
+
+    initial: PipelineState = {
+        "project_id": project.id,
+        "thread_id": "test_upload_real",
+        "pipeline_type": "upload",
+        "params": {"pdf_paths": [pdf_path]},
+        "papers": [],
+        "conflicts": [],
+        "resolved_conflicts": [],
+        "progress": 0,
+        "total": 100,
+        "stage": "starting",
+        "error": None,
+        "cancelled": False,
+        "result": {},
+    }
+
+    config = {"configurable": {"thread_id": "test_upload_real"}}
+    result = await graph.ainvoke(initial, config=config)
+
+    assert result["progress"] == 100
+    assert result.get("error") is None
+
+    papers = result.get("papers", [])
+    assert len(papers) >= 1
+    first = papers[0]
+    assert first.get("title"), "Extracted title should not be empty"
+
+
+# ── HITL interrupt → resume flow ──
+
+
+async def test_upload_hitl_interrupt_resume(project):
+    """When an uploaded PDF has the same title as an existing paper,
+    the dedup node should trigger HITL. Resuming with 'skip' should complete."""
+    pdf_path = _smallest_pdf()
+
+    from app.services.pdf_metadata import extract_metadata
+
+    meta = await extract_metadata(Path(pdf_path), fallback_title="fallback")
+
+    async with async_session_factory() as db:
+        existing = Paper(
+            project_id=project.id,
+            title=meta.title,
+            doi=meta.doi or "",
+            source="manual",
+        )
+        db.add(existing)
+        await db.commit()
+
+    saver = MemorySaver()
+    graph = create_upload_pipeline(checkpointer=saver)
+
+    initial: PipelineState = {
+        "project_id": project.id,
+        "thread_id": "test_hitl_resume",
+        "pipeline_type": "upload",
+        "params": {"pdf_paths": [pdf_path]},
+        "papers": [],
+        "conflicts": [],
+        "resolved_conflicts": [],
+        "progress": 0,
+        "total": 100,
+        "stage": "starting",
+        "error": None,
+        "cancelled": False,
+        "result": {},
+    }
+
+    config = {"configurable": {"thread_id": "test_hitl_resume"}}
+    await graph.ainvoke(initial, config=config)
+
+    snapshot = graph.get_state(config)
+    assert "hitl_dedup" in snapshot.next, f"Expected HITL interrupt, got {snapshot.next}"
+    conflicts = snapshot.values.get("conflicts", [])
+    assert len(conflicts) >= 1
+
+    result = await graph.ainvoke(
+        Command(resume=[{"action": "skip", "new_paper": {}}]),
+        config=config,
+    )
+    assert result["progress"] == 100
+    assert result["stage"] in ("index", "import")
+
+
+# ── Pipeline API path safety ──
+
+
+async def test_pipeline_path_traversal_rejected(test_client, project):
+    """Paths outside pdf_dir should be rejected with 400."""
+    resp = await test_client.post(
+        "/api/v1/pipelines/upload",
+        json={
+            "project_id": project.id,
+            "pdf_paths": ["/etc/passwd"],
+        },
+    )
+    assert resp.status_code == 400
+    assert "not within allowed directory" in resp.json().get("message", "")
+
+
+async def test_pipeline_path_dot_dot_rejected(test_client, project):
+    """Paths with '..' that resolve outside pdf_dir should be rejected."""
+    resp = await test_client.post(
+        "/api/v1/pipelines/upload",
+        json={
+            "project_id": project.id,
+            "pdf_paths": [f"{PDF_TEST_DIR}/../../etc/passwd"],
+        },
+    )
+    assert resp.status_code == 400
+
+
+# ── Pipeline list endpoint ──
+
+
+async def test_pipeline_list_includes_started(test_client, project, monkeypatch):
+    """After starting a pipeline, GET /pipelines should list it."""
+    from app.api.v1 import pipelines
+
+    pipelines._running_tasks.clear()
+
+    async def mock_search(self, query="", sources=None, max_results=100):
+        return {"papers": [], "total": 0}
+
+    from app.services import search_service
+
+    monkeypatch.setattr(search_service.SearchService, "search", mock_search)
+
+    resp = await test_client.post(
+        "/api/v1/pipelines/search",
+        json={
+            "project_id": project.id,
+            "query": "test",
+            "max_results": 5,
+        },
+    )
+    assert resp.status_code == 200
+
+    import asyncio
+
+    await asyncio.sleep(0.5)
+
+    list_resp = await test_client.get("/api/v1/pipelines")
+    assert list_resp.status_code == 200
+    data = list_resp.json()["data"]
+    assert len(data) >= 1
+
+    pipelines._running_tasks.clear()
diff --git a/backend/tests/test_pipelines.py b/backend/tests/test_pipelines.py
index 19fcfad..4c29ff2 100644
--- a/backend/tests/test_pipelines.py
+++ b/backend/tests/test_pipelines.py
@@ -416,10 +416,7 @@ async def mock_search(self, query="", sources=None, max_results=100):
 
     thread_id = data["thread_id"]
 
-    import asyncio
-
-    await asyncio.sleep(1)
-
+    # Get status immediately before pipeline completes and removes itself
     resp2 = await client.get(f"/api/v1/pipelines/{thread_id}/status")
     assert resp2.status_code == 200
 
diff --git a/backend/tests/test_reranker_service.py b/backend/tests/test_reranker_service.py
new file mode 100644
index 0000000..41d89f2
--- /dev/null
+++ b/backend/tests/test_reranker_service.py
@@ -0,0 +1,93 @@
+"""Tests for reranker_service — model loading, caching, and async inference."""
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+
+@pytest.fixture(autouse=True)
+def _reset_reranker_cache():
+    from app.services.gpu_model_manager import gpu_model_manager
+
+    gpu_model_manager.unload("reranker")
+    yield
+    gpu_model_manager.unload("reranker")
+
+
+class TestGetReranker:
+    @patch("app.services.reranker_service._build_reranker")
+    def test_returns_cached_instance(self, mock_build):
+        from app.services.reranker_service import get_reranker
+
+        sentinel = MagicMock()
+        mock_build.return_value = sentinel
+        result = get_reranker()
+        assert result is sentinel
+        mock_build.assert_called_once()
+
+    @patch("app.services.reranker_service._build_reranker")
+    def test_caching_returns_same_instance(self, mock_build):
+        from app.services.reranker_service import get_reranker
+
+        sentinel = MagicMock()
+        mock_build.return_value = sentinel
+        r1 = get_reranker()
+        r2 = get_reranker()
+        assert r1 is r2
+        mock_build.assert_called_once()
+
+    @patch("app.services.reranker_service._build_reranker")
+    def test_custom_model_name(self, mock_build):
+        from app.services.reranker_service import get_reranker
+
+        mock_build.return_value = MagicMock()
+        get_reranker(model_name="custom/reranker")
+        mock_build.assert_called_with("custom/reranker")
+
+
+class TestRerankNodes:
+    @pytest.mark.asyncio
+    async def test_empty_nodes_returns_empty(self):
+        from app.services.reranker_service import rerank_nodes
+
+        result = await rerank_nodes([], "test query", top_n=5)
+        assert result == []
+
+    @pytest.mark.asyncio
@patch("app.services.reranker_service.get_reranker") + async def test_rerank_returns_top_n(self, mock_get_reranker): + from app.services.reranker_service import rerank_nodes + + mock_node_1 = MagicMock() + mock_node_1.score = 0.9 + mock_node_2 = MagicMock() + mock_node_2.score = 0.5 + mock_node_3 = MagicMock() + mock_node_3.score = 0.7 + + mock_reranker = MagicMock() + mock_reranker.postprocess_nodes.return_value = [mock_node_1, mock_node_3, mock_node_2] + mock_get_reranker.return_value = mock_reranker + + result = await rerank_nodes([mock_node_1, mock_node_2, mock_node_3], "query", top_n=2) + assert len(result) == 2 + assert result[0] is mock_node_1 + + @pytest.mark.asyncio + @patch("app.services.reranker_service.get_reranker", side_effect=ImportError("no model")) + async def test_fallback_on_import_error(self, _mock): + from app.services.reranker_service import rerank_nodes + + nodes = [MagicMock() for _ in range(5)] + result = await rerank_nodes(nodes, "query", top_n=3) + assert len(result) == 3 + assert result == nodes[:3] + + @pytest.mark.asyncio + @patch("app.services.reranker_service.get_reranker", side_effect=RuntimeError("GPU error")) + async def test_fallback_on_runtime_error(self, _mock): + from app.services.reranker_service import rerank_nodes + + nodes = [MagicMock() for _ in range(4)] + result = await rerank_nodes(nodes, "query", top_n=2) + assert len(result) == 2 diff --git a/backend/tests/test_search.py b/backend/tests/test_search.py index 011a519..b6a2bce 100644 --- a/backend/tests/test_search.py +++ b/backend/tests/test_search.py @@ -414,7 +414,7 @@ async def mock_search(*args, **kwargs): resp = await client.post( f"/api/v1/projects/{project_id}/search/execute", - params={"query": "machine learning"}, + json={"query": "machine learning"}, ) assert resp.status_code == 200 body = resp.json() @@ -431,16 +431,16 @@ async def test_execute_search_no_query_no_keywords(client: AsyncClient): resp = await client.post( f"/api/v1/projects/{project_id}/search/execute", - params={"query": ""}, + json={"query": ""}, ) assert resp.status_code == 400 - assert "no keywords" in resp.json()["detail"].lower() + assert "no keywords" in resp.json()["message"].lower() @pytest.mark.asyncio async def test_execute_search_nonexistent_project(client: AsyncClient): resp = await client.post( "/api/v1/projects/99999/search/execute", - params={"query": "test"}, + json={"query": "test"}, ) assert resp.status_code == 404 diff --git a/backend/tests/test_subscription.py b/backend/tests/test_subscription.py index d5396ca..ef94531 100644 --- a/backend/tests/test_subscription.py +++ b/backend/tests/test_subscription.py @@ -6,8 +6,9 @@ import pytest from httpx import ASGITransport, AsyncClient -from app.database import Base, engine +from app.database import Base, async_session_factory, engine from app.main import app +from app.models import Project from app.services.subscription_service import SubscriptionService @@ -20,6 +21,16 @@ async def setup_db(): await conn.run_sync(Base.metadata.drop_all) +@pytest.fixture +async def project(setup_db): + async with async_session_factory() as db: + p = Project(name="Test Project", description="For subscription tests") + db.add(p) + await db.commit() + await db.refresh(p) + return p + + @pytest.fixture async def client(): transport = ASGITransport(app=app) @@ -88,7 +99,10 @@ async def test_check_rss_feed(self, mock_rss_xml): mock_resp.text = mock_rss_xml mock_resp.raise_for_status = MagicMock() - with patch("httpx.AsyncClient") as mock_client_cls: + with ( + 
patch("app.services.url_validator.validate_url_safe", return_value="https://example.com/feed.xml"), + patch("httpx.AsyncClient") as mock_client_cls, + ): mock_client = AsyncMock() mock_client.__aenter__ = AsyncMock(return_value=mock_client) mock_client.__aexit__ = AsyncMock(return_value=None) @@ -108,7 +122,10 @@ async def test_check_rss_feed_with_since_filter(self, mock_rss_xml): mock_resp.text = mock_rss_xml mock_resp.raise_for_status = MagicMock() - with patch("httpx.AsyncClient") as mock_client_cls: + with ( + patch("app.services.url_validator.validate_url_safe", return_value="https://example.com/feed.xml"), + patch("httpx.AsyncClient") as mock_client_cls, + ): mock_client = AsyncMock() mock_client.__aenter__ = AsyncMock(return_value=mock_client) mock_client.__aexit__ = AsyncMock(return_value=None) @@ -135,12 +152,15 @@ async def test_list_common_feeds(self, client): assert len(body["data"]) >= 4 @pytest.mark.asyncio - async def test_check_rss_mock(self, client, mock_rss_xml): + async def test_check_rss_mock(self, client, project, mock_rss_xml): mock_resp = MagicMock() mock_resp.text = mock_rss_xml mock_resp.raise_for_status = MagicMock() - with patch("httpx.AsyncClient") as mock_client_cls: + with ( + patch("app.services.url_validator.validate_url_safe", return_value="https://example.com/feed.xml"), + patch("httpx.AsyncClient") as mock_client_cls, + ): mock_client = AsyncMock() mock_client.__aenter__ = AsyncMock(return_value=mock_client) mock_client.__aexit__ = AsyncMock(return_value=None) @@ -148,7 +168,7 @@ async def test_check_rss_mock(self, client, mock_rss_xml): mock_client_cls.return_value = mock_client resp = await client.post( - "/api/v1/projects/1/subscriptions/check-rss", + f"/api/v1/projects/{project.id}/subscriptions/check-rss", params={"feed_url": "https://example.com/feed.xml", "since_days": 7}, ) assert resp.status_code == 200 diff --git a/backend/tests/test_url_validator.py b/backend/tests/test_url_validator.py new file mode 100644 index 0000000..de4e559 --- /dev/null +++ b/backend/tests/test_url_validator.py @@ -0,0 +1,65 @@ +"""Tests for app.services.url_validator.""" + +import pytest + +from app.services.url_validator import validate_doi, validate_url_safe + + +def test_valid_https_url(): + result = validate_url_safe("https://8.8.8.8/") + assert result == "https://8.8.8.8/" + + +def test_valid_http_url(): + result = validate_url_safe("http://1.1.1.1/") + assert result == "http://1.1.1.1/" + + +def test_ftp_scheme_rejected(): + with pytest.raises(ValueError, match="Unsupported scheme: ftp"): + validate_url_safe("ftp://example.com/file") + + +def test_no_scheme_rejected(): + with pytest.raises(ValueError, match="Unsupported scheme"): + validate_url_safe("example.com/path") + + +def test_private_ip_rejected(): + with pytest.raises(ValueError, match="Blocked.*private"): + validate_url_safe("http://192.168.1.1/") + + +def test_loopback_rejected(): + with pytest.raises(ValueError, match="Blocked.*private"): + validate_url_safe("http://127.0.0.1/") + + +def test_metadata_google_rejected(): + with pytest.raises(ValueError, match="Blocked hostname"): + validate_url_safe("http://metadata.google.internal/") + + +def test_metadata_aws_rejected(): + with pytest.raises(ValueError, match="Blocked hostname"): + validate_url_safe("http://metadata.amazonaws.com/") + + +def test_valid_doi(): + result = validate_doi("10.1038/nature12373") + assert result == "10.1038/nature12373" + + +def test_valid_doi_with_special_chars(): + result = validate_doi("10.1000/xyz123") + assert result == 
"10.1000/xyz123" + + +def test_invalid_doi_no_prefix(): + with pytest.raises(ValueError, match="Invalid DOI format"): + validate_doi("not-a-doi") + + +def test_invalid_doi_wrong_format(): + with pytest.raises(ValueError, match="Invalid DOI format"): + validate_doi("11.1234/abc") diff --git a/backend/tests/test_writing.py b/backend/tests/test_writing.py index 9a7561a..26796af 100644 --- a/backend/tests/test_writing.py +++ b/backend/tests/test_writing.py @@ -68,7 +68,7 @@ async def project_with_papers(): @pytest.mark.asyncio async def test_summarize_papers(project_with_papers): - from app.services.llm_client import LLMClient + from app.services.llm.client import LLMClient from app.services.writing_service import WritingService project_id, paper_ids = project_with_papers @@ -84,7 +84,7 @@ async def test_summarize_papers(project_with_papers): @pytest.mark.asyncio async def test_generate_citations_gb_t_7714(project_with_papers): - from app.services.llm_client import LLMClient + from app.services.llm.client import LLMClient from app.services.writing_service import WritingService project_id, paper_ids = project_with_papers @@ -103,7 +103,7 @@ async def test_generate_citations_gb_t_7714(project_with_papers): @pytest.mark.asyncio async def test_generate_citations_apa(project_with_papers): - from app.services.llm_client import LLMClient + from app.services.llm.client import LLMClient from app.services.writing_service import WritingService project_id, paper_ids = project_with_papers @@ -119,7 +119,7 @@ async def test_generate_citations_apa(project_with_papers): @pytest.mark.asyncio async def test_generate_citations_mla(project_with_papers): - from app.services.llm_client import LLMClient + from app.services.llm.client import LLMClient from app.services.writing_service import WritingService project_id, paper_ids = project_with_papers @@ -135,7 +135,7 @@ async def test_generate_citations_mla(project_with_papers): @pytest.mark.asyncio async def test_generate_review_outline(project_with_papers): - from app.services.llm_client import LLMClient + from app.services.llm.client import LLMClient from app.services.writing_service import WritingService project_id, paper_ids = project_with_papers @@ -156,7 +156,7 @@ async def test_generate_review_outline(project_with_papers): @pytest.mark.asyncio async def test_analyze_gaps(project_with_papers): - from app.services.llm_client import LLMClient + from app.services.llm.client import LLMClient from app.services.writing_service import WritingService project_id, paper_ids = project_with_papers @@ -300,7 +300,7 @@ async def test_assist_unknown_task(client: AsyncClient, project_with_papers): f"/api/v1/projects/{project_id}/writing/assist", json={"task": "unknown_task"}, ) - assert resp.status_code == 200 # We return 400 in data + assert resp.status_code == 400 body = resp.json() assert body["code"] == 400 assert "Unknown task" in body["message"] diff --git a/docs/api-endpoints.md b/docs/api-endpoints.md new file mode 100644 index 0000000..999dce2 --- /dev/null +++ b/docs/api-endpoints.md @@ -0,0 +1,249 @@ +# Omelette API Endpoints Reference + +This document lists all API v1 endpoints exposed by the Omelette backend. Endpoints are grouped by module. Base URL: `/api/v1`. 
+
+**Legend:**
+- 🤖 Involves LLM calls
+- 📄 Involves file I/O (upload, download, PDF processing, vector store)
+- 🔄 SSE streaming response
+
+---
+
+## Summary by Module
+
+| Module | Endpoints | 🤖 LLM | 📄 File I/O | 🔄 SSE |
+|--------|-----------|--------|-------------|--------|
+| Projects | 7 | 0 | 2 | 0 |
+| Papers | 8 | 0 | 2 | 0 |
+| Upload | 2 | 0 | 2 | 0 |
+| Keywords | 7 | 2 | 0 | 0 |
+| Search | 2 | 0 | 0 | 0 |
+| Dedup | 5 | 4 | 2 | 0 |
+| Crawler | 2 | 0 | 1 | 0 |
+| OCR | 2 | 0 | 1 | 0 |
+| Subscriptions | 9 | 0 | 0 | 0 |
+| RAG | 5 | 1 | 4 | 1 |
+| Writing | 6 | 5 | 0 | 1 |
+| Tasks | 3 | 0 | 0 | 0 |
+| Settings | 5 | 1 | 0 | 0 |
+| Conversations | 5 | 0 | 0 | 0 |
+| Chat | 2 | 2 | 0 | 1 |
+| Rewrite | 1 | 1 | 0 | 1 |
+| Pipelines | 5 | 0 | 2 | 0 |
+| **Total** | **76** | **16** | **14** | **4** |
+
+---
+
+## Projects
+
+| Method | Path | Description | Params | Flags |
+|--------|------|-------------|--------|-------|
+| GET | `/api/v1/projects` | List projects with pagination | `page`, `page_size` | |
+| POST | `/api/v1/projects` | Create a new project | Body: `ProjectCreate` (name, description, domain, settings) | |
+| GET | `/api/v1/projects/{project_id}` | Get project by ID | `project_id` | |
+| PUT | `/api/v1/projects/{project_id}` | Update project | `project_id`, Body: `ProjectUpdate` | |
+| DELETE | `/api/v1/projects/{project_id}` | Delete project | `project_id` | |
+| POST | `/api/v1/projects/{project_id}/pipeline/run` | Trigger crawl → OCR → index pipeline for all pending papers | `project_id` | 📄 |
+| POST | `/api/v1/projects/{project_id}/pipeline/paper/{paper_id}` | Trigger pipeline for a single paper | `project_id`, `paper_id` | 📄 |
+
+---
+
+## Papers
+
+| Method | Path | Description | Params | Flags |
+|--------|------|-------------|--------|-------|
+| GET | `/api/v1/projects/{project_id}/papers` | List papers with filters and pagination | `project_id`, `page`, `page_size`, `status`, `year`, `q`, `sort_by`, `order` | |
+| POST | `/api/v1/projects/{project_id}/papers` | Create a paper | `project_id`, Body: `PaperCreate` | |
+| POST | `/api/v1/projects/{project_id}/papers/bulk` | Bulk import papers | `project_id`, Body: `PaperBulkImport` (papers[]) | |
+| GET | `/api/v1/projects/{project_id}/papers/{paper_id}` | Get paper by ID | `project_id`, `paper_id` | |
+| PUT | `/api/v1/projects/{project_id}/papers/{paper_id}` | Update paper | `project_id`, `paper_id`, Body: `PaperUpdate` | |
+| DELETE | `/api/v1/projects/{project_id}/papers/{paper_id}` | Delete paper | `project_id`, `paper_id` | |
+| GET | `/api/v1/projects/{project_id}/papers/{paper_id}/pdf` | Serve PDF file | `project_id`, `paper_id` | 📄 |
+| GET | `/api/v1/projects/{project_id}/papers/{paper_id}/citation-graph` | Get citation relationship graph via Semantic Scholar | `project_id`, `paper_id`, `depth`, `max_nodes` | |
+
+---
+
+## Upload (Papers)
+
+| Method | Path | Description | Params | Flags |
+|--------|------|-------------|--------|-------|
+| POST | `/api/v1/projects/{project_id}/papers/upload` | Upload PDFs, extract metadata, run dedup check | `project_id`, `files` (multipart) | 📄 |
+| POST | `/api/v1/projects/{project_id}/papers/process` | Trigger OCR + RAG indexing for papers | `project_id`, `paper_ids` (optional) | 📄 |
+
+---
+
+## Keywords
+
+| Method | Path | Description | Params | Flags |
+|--------|------|-------------|--------|-------|
+| GET | `/api/v1/projects/{project_id}/keywords` | List keywords with pagination | `project_id`, `page`, `page_size`, `level` | |
+| POST | `/api/v1/projects/{project_id}/keywords` | Create keyword | `project_id`, Body: `KeywordCreate` | |
+| POST | `/api/v1/projects/{project_id}/keywords/bulk` | Bulk create keywords | `project_id`, Body: `KeywordCreate[]` | |
+| GET | `/api/v1/projects/{project_id}/keywords/search-formula` | Generate boolean search formula from project keywords | `project_id`, `database` | 🤖 |
+| PUT | `/api/v1/projects/{project_id}/keywords/{keyword_id}` | Update keyword | `project_id`, `keyword_id`, Body: `KeywordUpdate` | |
+| DELETE | `/api/v1/projects/{project_id}/keywords/{keyword_id}` | Delete keyword | `project_id`, `keyword_id` | |
+| POST | `/api/v1/projects/{project_id}/keywords/expand` | Expand seed keywords with synonyms via LLM | `project_id`, Body: `KeywordExpandRequest` | 🤖 |
+
+---
+
+## Search
+
+| Method | Path | Description | Params | Flags |
+|--------|------|-------------|--------|-------|
+| POST | `/api/v1/projects/{project_id}/search/execute` | Execute federated search (Semantic Scholar, OpenAlex, arXiv, Crossref) | `project_id`, `query`, `sources`, `max_results`, `auto_import` | |
+| GET | `/api/v1/projects/{project_id}/search/sources` | List available search sources and status | `project_id` | |
+
+---
+
+## Dedup
+
+| Method | Path | Description | Params | Flags |
+|--------|------|-------------|--------|-------|
+| POST | `/api/v1/projects/{project_id}/dedup/run` | Run deduplication pipeline | `project_id`, `strategy` (full, doi_only, title_only) | 🤖 |
+| GET | `/api/v1/projects/{project_id}/dedup/candidates` | List potential duplicate pairs for manual review | `project_id` | 🤖 |
+| POST | `/api/v1/projects/{project_id}/dedup/verify` | Use LLM to verify if two papers are duplicates | `project_id`, `paper_a_id`, `paper_b_id` | 🤖 |
+| POST | `/api/v1/projects/{project_id}/dedup/resolve` | Resolve upload conflict (keep_old, keep_new, merge, skip) | `project_id`, Body: `ResolveConflictRequest` | 📄 |
+| POST | `/api/v1/projects/{project_id}/dedup/auto-resolve` | Use LLM to suggest resolution for conflict pairs | `project_id`, Body: `AutoResolveRequest` | 🤖 📄 |
+
+---
+
+## Crawler
+
+| Method | Path | Description | Params | Flags |
+|--------|------|-------------|--------|-------|
+| POST | `/api/v1/projects/{project_id}/crawl/start` | Start PDF download for papers needing PDFs | `project_id`, `priority`, `max_papers` | 📄 |
+| GET | `/api/v1/projects/{project_id}/crawl/stats` | Return download statistics for project | `project_id` | |
+
+---
+
+## OCR
+
+| Method | Path | Description | Params | Flags |
+|--------|------|-------------|--------|-------|
+| POST | `/api/v1/projects/{project_id}/ocr/process` | Run OCR/text extraction on downloaded PDFs | `project_id`, `paper_ids`, `force_ocr`, `use_gpu` | 📄 |
+| GET | `/api/v1/projects/{project_id}/ocr/stats` | Return OCR processing statistics | `project_id` | |
+
+---
+
+## Subscriptions
+
+| Method | Path | Description | Params | Flags |
+|--------|------|-------------|--------|-------|
+| GET | `/api/v1/projects/{project_id}/subscriptions/feeds` | List common academic RSS feed templates | `project_id` | |
+| POST | `/api/v1/projects/{project_id}/subscriptions/check-rss` | Check RSS feed for new entries | `project_id`, `feed_url`, `since_days` | |
+| POST | `/api/v1/projects/{project_id}/subscriptions/check-updates` | Check for new papers via API search | `project_id`, `query`, `sources`, `since_days`, `max_results` | |
+| GET | `/api/v1/projects/{project_id}/subscriptions` | List subscriptions for project | `project_id` | |
+
+---
+
+## Search
+
+| Method | Path | Description | Params | Flags |
+|--------|------|-------------|--------|-------|
+| POST | `/api/v1/projects/{project_id}/search/execute` | Execute federated search (Semantic Scholar, OpenAlex, arXiv, Crossref) | `project_id`, `query`, `sources`, `max_results`, `auto_import` | |
+| GET | `/api/v1/projects/{project_id}/search/sources` | List available search sources and status | `project_id` | |
+
+---
+
+## Dedup
+
+| Method | Path | Description | Params | Flags |
+|--------|------|-------------|--------|-------|
+| POST | `/api/v1/projects/{project_id}/dedup/run` | Run deduplication pipeline | `project_id`, `strategy` (full, doi_only, title_only) | 🤖 |
+| GET | `/api/v1/projects/{project_id}/dedup/candidates` | List potential duplicate pairs for manual review | `project_id` | 🤖 |
+| POST | `/api/v1/projects/{project_id}/dedup/verify` | Use LLM to verify if two papers are duplicates | `project_id`, `paper_a_id`, `paper_b_id` | 🤖 |
+| POST | `/api/v1/projects/{project_id}/dedup/resolve` | Resolve upload conflict (keep_old, keep_new, merge, skip) | `project_id`, Body: `ResolveConflictRequest` | 📄 |
+| POST | `/api/v1/projects/{project_id}/dedup/auto-resolve` | Use LLM to suggest resolution for conflict pairs | `project_id`, Body: `AutoResolveRequest` | 🤖 📄 |
+
+---
+
+## Crawler
+
+| Method | Path | Description | Params | Flags |
+|--------|------|-------------|--------|-------|
+| POST | `/api/v1/projects/{project_id}/crawl/start` | Start PDF download for papers needing PDFs | `project_id`, `priority`, `max_papers` | 📄 |
+| GET | `/api/v1/projects/{project_id}/crawl/stats` | Return download statistics for project | `project_id` | |
+
+---
+
+## OCR
+
+| Method | Path | Description | Params | Flags |
+|--------|------|-------------|--------|-------|
+| POST | `/api/v1/projects/{project_id}/ocr/process` | Run OCR/text extraction on downloaded PDFs | `project_id`, `paper_ids`, `force_ocr`, `use_gpu` | 📄 |
+| GET | `/api/v1/projects/{project_id}/ocr/stats` | Return OCR processing statistics | `project_id` | |
+
+---
+
+## Subscriptions
+
+| Method | Path | Description | Params | Flags |
+|--------|------|-------------|--------|-------|
+| GET | `/api/v1/projects/{project_id}/subscriptions/feeds` | List common academic RSS feed templates | `project_id` | |
+| POST | `/api/v1/projects/{project_id}/subscriptions/check-rss` | Check RSS feed for new entries | `project_id`, `feed_url`, `since_days` | |
+| POST | `/api/v1/projects/{project_id}/subscriptions/check-updates` | Check for new papers via API search | `project_id`, `query`, `sources`, `since_days`, `max_results` | |
+| GET | `/api/v1/projects/{project_id}/subscriptions` | List subscriptions for project | `project_id` | |
+| POST | `/api/v1/projects/{project_id}/subscriptions` | Create subscription | `project_id`, Body: `SubscriptionCreate` | |
+| GET | `/api/v1/projects/{project_id}/subscriptions/{sub_id}` | Get subscription by ID | `project_id`, `sub_id` | |
+| PUT | `/api/v1/projects/{project_id}/subscriptions/{sub_id}` | Update subscription | `project_id`, `sub_id`, Body: `SubscriptionUpdate` | |
+| DELETE | `/api/v1/projects/{project_id}/subscriptions/{sub_id}` | Delete subscription | `project_id`, `sub_id` | |
+| POST | `/api/v1/projects/{project_id}/subscriptions/{sub_id}/trigger` | Manually trigger subscription update | `project_id`, `sub_id`, `since_days` | |
+
+---
+
+## RAG
+
+| Method | Path | Description | Params | Flags |
+|--------|------|-------------|--------|-------|
+| POST | `/api/v1/projects/{project_id}/rag/query` | Answer question using RAG over indexed literature | `project_id`, Body: `RAGQueryRequest` (question, top_k, use_reranker, include_sources) | 🤖 |
+| POST | `/api/v1/projects/{project_id}/rag/index` | Build or rebuild vector index for processed papers | `project_id` | 📄 |
+| POST | `/api/v1/projects/{project_id}/rag/index/stream` | SSE streaming index rebuild with progress events | `project_id` | 📄 🔄 |
+| GET | `/api/v1/projects/{project_id}/rag/stats` | Return indexing statistics | `project_id` | |
+| DELETE | `/api/v1/projects/{project_id}/rag/index` | Delete vector index for project | `project_id` | 📄 |
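+
+A minimal sketch of a RAG query using the `RAGQueryRequest` fields listed above; the question text and parameter values are illustrative:
+
+```python
+import requests
+
+BASE = "http://localhost:8000/api/v1"
+PROJECT_ID = "p-123"  # hypothetical project ID
+
+# Ask a question over the indexed literature. Field names come from the
+# RAGQueryRequest row in the table above; values are illustrative.
+resp = requests.post(
+    f"{BASE}/projects/{PROJECT_ID}/rag/query",
+    json={
+        "question": "Which evaluation metrics do the indexed papers use?",
+        "top_k": 5,
+        "use_reranker": True,
+        "include_sources": True,
+    },
+)
+resp.raise_for_status()
+print(resp.json())
+```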
+
+---
+
+## Writing
+
+| Method | Path | Description | Params | Flags |
+|--------|------|-------------|--------|-------|
+| POST | `/api/v1/projects/{project_id}/writing/assist` | AI writing assistance (summarize, cite, outline, gap analysis) | `project_id`, Body: `WritingAssistRequest` | 🤖 |
+| POST | `/api/v1/projects/{project_id}/writing/summarize` | Generate summaries for selected papers | `project_id`, Body: `SummarizeRequest` | 🤖 |
+| POST | `/api/v1/projects/{project_id}/writing/citations` | Generate formatted citations | `project_id`, Body: `CitationsRequest` | |
+| POST | `/api/v1/projects/{project_id}/writing/review-outline` | Generate literature review outline | `project_id`, Body: `ReviewOutlineRequest` | 🤖 |
+| POST | `/api/v1/projects/{project_id}/writing/gap-analysis` | Analyze research gaps | `project_id`, Body: `GapAnalysisRequest` | 🤖 |
+| POST | `/api/v1/projects/{project_id}/writing/review-draft/stream` | Stream literature review draft via SSE | `project_id`, Body: `ReviewDraftRequest` | 🤖 🔄 |
+
+---
+
+## Tasks
+
+| Method | Path | Description | Params | Flags |
+|--------|------|-------------|--------|-------|
+| GET | `/api/v1/tasks/{task_id}` | Get task status and details | `task_id` | |
+| GET | `/api/v1/tasks` | List tasks with pagination | `project_id`, `status`, `page`, `page_size` | |
+| POST | `/api/v1/tasks/{task_id}/cancel` | Cancel a running task | `task_id` | |
+
+---
+
+## Settings
+
+| Method | Path | Description | Params | Flags |
+|--------|------|-------------|--------|-------|
+| GET | `/api/v1/settings` | Get merged settings (DB overrides .env); API keys masked | | |
+| PUT | `/api/v1/settings` | Update user settings and persist to DB | Body: `SettingsUpdateSchema` | |
+| GET | `/api/v1/settings/models` | List available LLM providers and models | | |
+| POST | `/api/v1/settings/test-connection` | Test LLM configuration with simple prompt | | 🤖 |
+| GET | `/api/v1/settings/health` | Simple health check | | |
+
+---
+
+## Conversations
+
+| Method | Path | Description | Params | Flags |
+|--------|------|-------------|--------|-------|
+| GET | `/api/v1/conversations` | List conversations, newest first | `page`, `page_size`, `knowledge_base_id` | |
+| POST | `/api/v1/conversations` | Create new conversation | Body: `ConversationCreateSchema` | |
+| GET | `/api/v1/conversations/{conversation_id}` | Get conversation with all messages | `conversation_id` | |
+| PUT | `/api/v1/conversations/{conversation_id}` | Update conversation title or settings | `conversation_id`, Body: `ConversationUpdateSchema` | |
+| DELETE | `/api/v1/conversations/{conversation_id}` | Delete conversation and messages | `conversation_id` | |
+
+---
+
+## Chat
+
+| Method | Path | Description | Params | Flags |
+|--------|------|-------------|--------|-------|
+| POST | `/api/v1/chat/stream` | Data Stream Protocol (Vercel AI SDK 5.0) chat endpoint | Body: `ChatStreamRequest` | 🤖 🔄 |
+| POST | `/api/v1/chat/complete` | Short text completion for autocomplete | Body: `CompletionRequest` | 🤖 |
+
+---
+
+## Rewrite
+
+| Method | Path | Description | Params | Flags |
+|--------|------|-------------|--------|-------|
+| POST | `/api/v1/chat/rewrite` | SSE streaming excerpt rewrite (simplify, academic, translate, custom) | Body: `RewriteRequest` | 🤖 🔄 |
+
+---
+
+## Pipelines
+
+| Method | Path | Description | Params | Flags |
+|--------|------|-------------|--------|-------|
+| POST | `/api/v1/pipelines/search` | Start keyword-search pipeline (search → dedup → crawl → OCR → index) | Body: `SearchPipelineRequest` | 📄 |
+| POST | `/api/v1/pipelines/upload` | Start PDF-upload pipeline (extract → dedup → OCR → index) | Body: `UploadPipelineRequest` | 📄 |
+| GET | `/api/v1/pipelines/{thread_id}/status` | Get pipeline execution status | `thread_id` | |
+| POST | `/api/v1/pipelines/{thread_id}/resume` | Resume interrupted pipeline with resolved conflicts | `thread_id`, Body: `ResumeRequest` | |
+| POST | `/api/v1/pipelines/{thread_id}/cancel` | Cancel running pipeline | `thread_id` | |
+
+---
+
+## Authentication
+
+The API uses **optional API key authentication** via `API_SECRET_KEY` (configured in `.env`).
+
+- **When `API_SECRET_KEY` is set:** All requests must include the key via:
+  - Header: `X-API-Key: