diff --git a/backend/app/api/v1/chat.py b/backend/app/api/v1/chat.py index 5628dc7..d9fd987 100644 --- a/backend/app/api/v1/chat.py +++ b/backend/app/api/v1/chat.py @@ -1,339 +1,89 @@ -"""Chat streaming API — SSE endpoint for real-time AI responses with citations.""" +"""Chat streaming API — Data Stream Protocol endpoint (Vercel AI SDK 5.0).""" from __future__ import annotations -import asyncio import json import logging -import time import uuid +from collections.abc import Callable from fastapi import APIRouter, Depends from fastapi.responses import StreamingResponse -from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import selectinload from app.api.deps import get_db -from app.models.conversation import Conversation -from app.models.message import Message -from app.models.paper import Paper +from app.pipelines.chat.graph import create_chat_pipeline +from app.pipelines.chat.stream_writer import ( + format_done, + format_error, + format_finish, + format_start, +) from app.schemas.conversation import ChatStreamRequest -from app.services.llm.client import LLMClient, get_llm_client -from app.services.rag_service import RAGService -from app.services.user_settings_service import UserSettingsService logger = logging.getLogger(__name__) router = APIRouter(prefix="/chat", tags=["chat"]) -TOOL_MODE_PROMPTS = { - "qa": ( - "You are a scientific research assistant. Answer the question based on the provided context. " - "Use inline citations like [1], [2] to reference source papers. " - "If the context doesn't contain enough information, say so honestly." - ), - "citation_lookup": ( - "You are a citation finder. Given the user's text, identify and list the most relevant " - "references from the provided context. Format as a numbered list with paper titles, authors, " - "and brief explanations of relevance. Keep your own commentary minimal." - ), - "review_outline": ( - "You are a literature review expert. Based on the provided context, generate a structured " - "review outline with sections, subsections, and key points. Use citations like [1], [2] " - "to reference sources. Suggest a logical flow and highlight key themes." - ), - "gap_analysis": ( - "You are a research gap analyst. Based on the provided literature context, identify " - "research gaps, unexplored areas, and potential future directions. Cite existing work " - "using [1], [2] format. Be specific about what has been studied and what remains open." - ), -} - -EXCERPT_CLEAN_PROMPT = ( - "Clean up the following text extracted from an academic PDF. " - "Fix OCR errors, add missing spaces between words, restore formatting. " - "Keep the original meaning intact. Output only the cleaned text, nothing else." 
-) - -_clean_semaphore = asyncio.Semaphore(3) - -async def _clean_excerpt(llm: LLMClient, excerpt: str) -> str: - """Use LLM to clean OCR-extracted text.""" - if not excerpt or len(excerpt) < 20: - return excerpt - async with _clean_semaphore: - messages = [ - {"role": "system", "content": EXCERPT_CLEAN_PROMPT}, - {"role": "user", "content": excerpt}, - ] - result = "" - try: - async with asyncio.timeout(10.0): - async for token in llm.chat_stream(messages, temperature=0.1, task_type="clean"): - result += token - except TimeoutError: - logger.warning("Excerpt cleaning timed out, using original") - return excerpt - return result if result.strip() else excerpt +async def _init_services(db: AsyncSession) -> dict: + """Create LLM + RAG services from user settings.""" + from app.services.llm.client import get_llm_client + from app.services.rag_service import RAGService + from app.services.user_settings_service import UserSettingsService - -async def _get_rag_service_for_chat(db: AsyncSession) -> tuple[RAGService, LLMClient]: svc = UserSettingsService(db) - config = await svc.get_merged_llm_config() - llm = get_llm_client(config=config) - - from llama_index.core.embeddings import MockEmbedding + llm_config = await svc.get_merged_llm_config() + llm = get_llm_client(config=llm_config) - from app.services.embedding_service import get_embedding_model + if llm_config.provider == "mock": + from llama_index.core.embeddings import MockEmbedding - embed = MockEmbedding(embed_dim=128) if config.provider == "mock" else get_embedding_model() - rag = RAGService(llm=llm, embed_model=embed) - return rag, llm + embed = MockEmbedding(embed_dim=128) + else: + from app.services.embedding_service import get_embedding_model + embed = get_embedding_model() -def _thinking(step: str, label: str, status: str = "running", **kwargs) -> str: - data = {"step": step, "label": label, "status": status, **kwargs} - return _sse("thinking_step", data) + rag = RAGService(llm=llm, embed_model=embed) + return {"llm": llm, "rag": rag} async def _stream_chat( request: ChatStreamRequest, db: AsyncSession, + init_services: Callable = _init_services, ): - """Generator that yields SSE events for the chat stream.""" - message_id = str(uuid.uuid4())[:12] - t0 = time.monotonic() - - yield _sse("message_start", {"message_id": message_id}) + """Yield Data Stream Protocol SSE events from the LangGraph chat pipeline.""" + msg_id = f"msg_{uuid.uuid4().hex}" + yield format_start(msg_id) try: - yield _thinking("understand", "Understanding query", detail=f"Analyzing '{request.message[:40]}...'") - - rag, llm = await _get_rag_service_for_chat(db) - - yield _thinking( - "understand", - "Understanding query", - status="done", - duration_ms=int((time.monotonic() - t0) * 1000), - summary="Ready", - ) - - all_sources = [] - all_contexts = [] - citations = [] - - if request.knowledge_base_ids: - t_retrieve = time.monotonic() - yield _thinking( - "retrieve", - "Searching knowledge base", - detail=f"Searching in {len(request.knowledge_base_ids)} knowledge base(s)...", - ) - - rag_tasks = [ - rag.query( - project_id=kb_id, - question=request.message, - top_k=5, - include_sources=True, - ) - for kb_id in request.knowledge_base_ids - ] - results = await asyncio.gather(*rag_tasks, return_exceptions=True) - - for result in results: - if isinstance(result, Exception): - logger.warning("RAG query failed for a KB: %s", result) - continue - if result.get("sources"): - all_sources.extend(result["sources"]) - for src in result["sources"]: - all_contexts.append( - 
f"[Source: {src.get('paper_title', 'Unknown')}, " - f"p.{src.get('page_number', '?')}]\n{src.get('excerpt', '')}" - ) - - yield _thinking( - "retrieve", - "Searching knowledge base", - status="done", - duration_ms=int((time.monotonic() - t_retrieve) * 1000), - summary=f"Found {len(all_sources)} relevant sources", - ) - - t_rank = time.monotonic() - yield _thinking("rank", "Analyzing citations", detail="Evaluating citation relevance...") - - paper_ids = list({pid for pid in (src.get("paper_id") for src in all_sources) if pid is not None}) - papers_by_id: dict[int, Paper] = {} - if paper_ids: - result = await db.execute(select(Paper).where(Paper.id.in_(paper_ids))) - papers_by_id = {p.id: p for p in result.scalars().all()} - - for i, src in enumerate(all_sources, 1): - paper = papers_by_id.get(src.get("paper_id")) if src.get("paper_id") else None - citation = { - "index": i, - "paper_id": src.get("paper_id"), - "paper_title": src.get("paper_title", ""), - "page_number": src.get("page_number"), - "excerpt": src.get("excerpt", ""), - "relevance_score": src.get("relevance_score", 0), - "chunk_type": src.get("chunk_type", "text"), - "authors": paper.authors if paper else None, - "year": paper.year if paper else None, - "doi": paper.doi if paper else None, - } - citations.append(citation) - yield _sse("citation", citation) - - high_relevance = sum(1 for c in citations if c.get("relevance_score", 0) > 0.6) - yield _thinking( - "rank", - "Analyzing citations", - status="done", - duration_ms=int((time.monotonic() - t_rank) * 1000), - summary=f"Selected {high_relevance} high-relevance citations (>60%)", - ) - - excerpts_to_clean = [(i, c["excerpt"]) for i, c in enumerate(citations) if c.get("excerpt")] - if excerpts_to_clean: - t_clean = time.monotonic() - yield _thinking( - "clean", - "Cleaning citation text", - detail=f"Improving readability of {len(excerpts_to_clean)} citations in parallel...", - ) - - clean_tasks = [_clean_excerpt(llm, excerpt) for _, excerpt in excerpts_to_clean] - cleaned_results = await asyncio.gather(*clean_tasks, return_exceptions=True) - - enhanced_count = 0 - for (idx, _original), cleaned in zip(excerpts_to_clean, cleaned_results): - if isinstance(cleaned, str) and cleaned.strip() and cleaned != citations[idx]["excerpt"]: - citations[idx]["excerpt"] = cleaned - enhanced_count += 1 - yield _sse( - "citation_enhanced", - { - "index": citations[idx]["index"], - "cleaned_excerpt": cleaned, - }, - ) - - yield _thinking( - "clean", - "Cleaning citation text", - status="done", - duration_ms=int((time.monotonic() - t_clean) * 1000), - summary=f"Enhanced {enhanced_count} citations", - ) - - history_messages = [] - conversation_id = request.conversation_id - - if conversation_id: - result = await db.execute( - select(Conversation) - .where(Conversation.id == conversation_id) - .options(selectinload(Conversation.messages)) - ) - conv = result.scalar_one_or_none() - if conv: - for msg in conv.messages[-10:]: - history_messages.append({"role": msg.role, "content": msg.content}) - - if request.knowledge_base_ids: - system_prompt = TOOL_MODE_PROMPTS.get(request.tool_mode, TOOL_MODE_PROMPTS["qa"]) - context_text = "\n\n---\n\n".join(all_contexts) if all_contexts else "No relevant documents found." - user_content = f"Context:\n{context_text}\n\nQuestion: {request.message}" - else: - system_prompt = ( - "You are a helpful scientific research assistant. " - "Answer questions clearly and accurately. " - "If you don't know the answer, say so honestly." 
- ) - user_content = request.message - - messages = [ - {"role": "system", "content": system_prompt}, - *history_messages, - {"role": "user", "content": user_content}, - ] - - t_gen = time.monotonic() - yield _thinking( - "generate", "Generating answer", detail=f"Generating answer based on {len(citations)} citations..." - ) - - full_response = "" - async for token in llm.chat_stream(messages, temperature=0.3, task_type="chat"): - full_response += token - yield _sse("text_delta", {"delta": token}) - - yield _thinking( - "generate", - "Generating answer", - status="done", - duration_ms=int((time.monotonic() - t_gen) * 1000), - summary=f"Generated {len(full_response)} characters", - ) - - if not conversation_id: - title = request.message[:50] + ("..." if len(request.message) > 50 else "") - conv = Conversation( - title=title, - knowledge_base_ids=request.knowledge_base_ids, - model=request.model or "", - tool_mode=request.tool_mode, - ) - db.add(conv) - await db.flush() - conversation_id = conv.id - - user_msg = Message( - conversation_id=conversation_id, - role="user", - content=request.message, - ) - assistant_msg = Message( - conversation_id=conversation_id, - role="assistant", - content=full_response, - citations=citations if citations else None, - ) - db.add(user_msg) - db.add(assistant_msg) - await db.commit() - - total_ms = int((time.monotonic() - t0) * 1000) - yield _thinking( - "complete", - "Complete", - status="done", - duration_ms=total_ms, - summary=f"Total {total_ms / 1000:.1f}s, cited {len(citations)} sources", - ) - - yield _sse( - "message_end", - { - "message_id": message_id, - "conversation_id": conversation_id, - "finish_reason": "stop", - }, - ) - + services = await init_services(db) + config = {"configurable": {"db": db, "_services": services}} + + pipeline = create_chat_pipeline() + initial_state = { + "message": request.message, + "knowledge_base_ids": request.knowledge_base_ids, + "tool_mode": request.tool_mode, + "conversation_id": request.conversation_id, + "model": request.model or "", + } + + async for event in pipeline.astream( + initial_state, + config=config, + stream_mode="custom", + ): + yield f"data: {json.dumps(event, ensure_ascii=False)}\n\n" + + yield format_finish() except Exception as e: logger.exception("Chat stream error") - yield _sse("error", {"code": "stream_error", "message": str(e)}) - - -def _sse(event: str, data: dict) -> str: - return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n" + yield format_error(str(e)) + finally: + yield format_done() @router.post("/stream") @@ -341,7 +91,7 @@ async def chat_stream( request: ChatStreamRequest, db: AsyncSession = Depends(get_db), ): - """SSE streaming chat endpoint — sends token-level events.""" + """Data Stream Protocol (Vercel AI SDK 5.0) chat endpoint.""" return StreamingResponse( _stream_chat(request, db), media_type="text/event-stream", @@ -349,5 +99,6 @@ async def chat_stream( "Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no", + "X-Vercel-AI-UI-Message-Stream": "v1", }, ) diff --git a/backend/app/pipelines/chat/__init__.py b/backend/app/pipelines/chat/__init__.py new file mode 100644 index 0000000..308eae5 --- /dev/null +++ b/backend/app/pipelines/chat/__init__.py @@ -0,0 +1,5 @@ +"""Chat pipeline — LangGraph StateGraph for streaming chat with RAG.""" + +from app.pipelines.chat.graph import create_chat_pipeline + +__all__ = ["create_chat_pipeline"] diff --git a/backend/app/pipelines/chat/config_helpers.py 
b/backend/app/pipelines/chat/config_helpers.py
new file mode 100644
index 0000000..a2d45eb
--- /dev/null
+++ b/backend/app/pipelines/chat/config_helpers.py
@@ -0,0 +1,49 @@
+"""Helpers for accessing request-scoped services from LangGraph config.
+
+LangGraph shallow-copies the config dict for each node, but nested mutable
+objects remain shared. We store a ``_services`` dict in ``configurable``
+so that the endpoint layer can create ``llm``/``rag`` once and every node
+reads them through the same reference.
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+    from langchain_core.runnables import RunnableConfig
+    from sqlalchemy.ext.asyncio import AsyncSession
+
+    from app.services.llm.client import LLMClient
+    from app.services.rag_service import RAGService
+
+
+def _services(config: RunnableConfig) -> dict[str, Any]:
+    """Return the shared mutable services dict, creating it if needed."""
+    cfg = config["configurable"]
+    if "_services" not in cfg:
+        cfg["_services"] = {}
+    return cfg["_services"]
+
+
+def get_chat_db(config: RunnableConfig) -> AsyncSession:
+    return config["configurable"]["db"]
+
+
+def get_chat_llm(config: RunnableConfig) -> LLMClient:
+    return _services(config)["llm"]
+
+
+def get_chat_rag(config: RunnableConfig) -> RAGService:
+    return _services(config)["rag"]
+
+
+def set_chat_services(config: RunnableConfig, *, llm: Any, rag: Any) -> None:
+    """Share LLM and RAG with downstream nodes (used by the endpoint layer and tests)."""
+    svc = _services(config)
+    svc["llm"] = llm
+    svc["rag"] = rag
+
+
+def get_configurable(config: RunnableConfig) -> dict[str, Any]:
+    return config["configurable"]
diff --git a/backend/app/pipelines/chat/graph.py b/backend/app/pipelines/chat/graph.py
new file mode 100644
index 0000000..3338ee5
--- /dev/null
+++ b/backend/app/pipelines/chat/graph.py
@@ -0,0 +1,57 @@
+"""Chat pipeline graph definition.
+
+Flow:
+    understand → [has KB?] → retrieve → rank → clean → generate → persist
+        └─ no KB ──────────────────────→ generate → persist
+"""
+
+from __future__ import annotations
+
+from langgraph.graph import END, StateGraph
+
+from app.pipelines.chat.nodes import (
+    clean_node,
+    generate_node,
+    persist_node,
+    rank_node,
+    retrieve_node,
+    understand_node,
+)
+from app.pipelines.chat.state import ChatState
+
+
+def _route_after_understand(state: ChatState) -> str:
+    """Skip RAG nodes when no knowledge bases are selected."""
+    kb_ids = state.get("knowledge_base_ids", [])
+    if kb_ids:
+        return "retrieve"
+    return "generate"
+
+
+def create_chat_pipeline():
+    """Compile the chat StateGraph.
+
+    No checkpointer — chat streams are stateless one-shot invocations.
+ """ + graph = StateGraph(ChatState) + + graph.add_node("understand", understand_node) + graph.add_node("retrieve", retrieve_node) + graph.add_node("rank", rank_node) + graph.add_node("clean", clean_node) + graph.add_node("generate", generate_node) + graph.add_node("persist", persist_node) + + graph.set_entry_point("understand") + graph.add_conditional_edges( + "understand", + _route_after_understand, + {"retrieve": "retrieve", "generate": "generate"}, + ) + graph.add_edge("retrieve", "rank") + graph.add_edge("rank", "clean") + graph.add_edge("clean", "generate") + graph.add_edge("generate", "persist") + graph.add_edge("persist", END) + + return graph.compile() diff --git a/backend/app/pipelines/chat/nodes.py b/backend/app/pipelines/chat/nodes.py new file mode 100644 index 0000000..7bfd42e --- /dev/null +++ b/backend/app/pipelines/chat/nodes.py @@ -0,0 +1,445 @@ +"""Chat pipeline node implementations. + +Each node receives ``ChatState`` + ``RunnableConfig`` and returns a partial +state update. Custom SSE events are emitted via ``get_stream_writer()``. +""" + +from __future__ import annotations + +import asyncio +import logging +import time +import uuid +from typing import Any + +from langchain_core.runnables import RunnableConfig +from langgraph.config import get_stream_writer + +from app.pipelines.chat.config_helpers import ( + get_chat_db, + get_chat_llm, + get_chat_rag, +) +from app.pipelines.chat.state import ChatMessageDict, ChatState, CitationDict + +logger = logging.getLogger(__name__) + +TOOL_MODE_PROMPTS = { + "qa": ( + "You are a scientific research assistant. Answer the question based on the provided context. " + "Use inline citations like [1], [2] to reference source papers. " + "If the context doesn't contain enough information, say so honestly." + ), + "citation_lookup": ( + "You are a citation finder. Given the user's text, identify and list the most relevant " + "references from the provided context. Format as a numbered list with paper titles, authors, " + "and brief explanations of relevance. Keep your own commentary minimal." + ), + "review_outline": ( + "You are a literature review expert. Based on the provided context, generate a structured " + "review outline with sections, subsections, and key points. Use citations like [1], [2] " + "to reference sources. Suggest a logical flow and highlight key themes." + ), + "gap_analysis": ( + "You are a research gap analyst. Based on the provided literature context, identify " + "research gaps, unexplored areas, and potential future directions. Cite existing work " + "using [1], [2] format. Be specific about what has been studied and what remains open." + ), +} + +EXCERPT_CLEAN_PROMPT = ( + "Clean up the following text extracted from an academic PDF. " + "Fix OCR errors, add missing spaces between words, restore formatting. " + "Keep the original meaning intact. Output only the cleaned text, nothing else." +) + +_clean_semaphore = asyncio.Semaphore(3) + + +def _emit_thinking( + writer, + step: str, + label: str, + status: str = "running", + **kwargs: Any, +) -> None: + writer( + { + "type": "data-thinking", + "data": {"step": step, "label": label, "status": status, **kwargs}, + } + ) + + +# --------------------------------------------------------------------------- +# Node: understand +# --------------------------------------------------------------------------- + + +async def understand_node(state: ChatState, config: RunnableConfig) -> dict[str, Any]: + """Load conversation history and build system prompt. 
+ + LLM and RAG services are already initialized in the endpoint layer + and available via ``get_chat_llm``/``get_chat_rag``. + """ + writer = get_stream_writer() + t0 = time.monotonic() + + _emit_thinking( + writer, + "understand", + "Understanding query", + detail=f"Analyzing '{state['message'][:40]}...'", + ) + + db = get_chat_db(config) + + # Load conversation history + history_messages: list[ChatMessageDict] = [] + conv_id = state.get("conversation_id") + if conv_id: + from sqlalchemy import select + from sqlalchemy.orm import selectinload + + from app.models.conversation import Conversation + + result = await db.execute( + select(Conversation).where(Conversation.id == conv_id).options(selectinload(Conversation.messages)) + ) + conv = result.scalar_one_or_none() + if conv: + for msg in conv.messages[-10:]: + history_messages.append({"role": msg.role, "content": msg.content}) + + # Build system prompt + kb_ids = state.get("knowledge_base_ids", []) + tool_mode = state.get("tool_mode", "qa") + if kb_ids: + system_prompt = TOOL_MODE_PROMPTS.get(tool_mode, TOOL_MODE_PROMPTS["qa"]) + else: + system_prompt = ( + "You are a helpful scientific research assistant. " + "Answer questions clearly and accurately. " + "If you don't know the answer, say so honestly." + ) + + _emit_thinking( + writer, + "understand", + "Understanding query", + status="done", + duration_ms=int((time.monotonic() - t0) * 1000), + summary="Ready", + ) + + return { + "history_messages": history_messages, + "system_prompt": system_prompt, + } + + +# --------------------------------------------------------------------------- +# Node: retrieve +# --------------------------------------------------------------------------- + + +async def retrieve_node(state: ChatState, config: RunnableConfig) -> dict[str, Any]: + """Run parallel RAG queries across knowledge bases.""" + writer = get_stream_writer() + t0 = time.monotonic() + rag = get_chat_rag(config) + kb_ids = state.get("knowledge_base_ids", []) + + _emit_thinking( + writer, + "retrieve", + "Searching knowledge base", + detail=f"Searching in {len(kb_ids)} knowledge base(s)...", + ) + + tasks = [rag.query(project_id=kb_id, question=state["message"], top_k=5, include_sources=True) for kb_id in kb_ids] + results = await asyncio.gather(*tasks, return_exceptions=True) + + all_sources: list[dict[str, Any]] = [] + all_contexts: list[str] = [] + for result in results: + if isinstance(result, Exception): + logger.warning("RAG query failed for a KB: %s", result) + continue + if result.get("sources"): + all_sources.extend(result["sources"]) + for src in result["sources"]: + all_contexts.append( + f"[Source: {src.get('paper_title', 'Unknown')}, " + f"p.{src.get('page_number', '?')}]\n{src.get('excerpt', '')}" + ) + + _emit_thinking( + writer, + "retrieve", + "Searching knowledge base", + status="done", + duration_ms=int((time.monotonic() - t0) * 1000), + summary=f"Found {len(all_sources)} relevant sources", + ) + + return {"rag_results": all_sources, "all_contexts": all_contexts} + + +# --------------------------------------------------------------------------- +# Node: rank +# --------------------------------------------------------------------------- + + +async def rank_node(state: ChatState, config: RunnableConfig) -> dict[str, Any]: + """Build citation list from RAG results, batch-loading Paper metadata.""" + writer = get_stream_writer() + t0 = time.monotonic() + db = get_chat_db(config) + + _emit_thinking(writer, "rank", "Analyzing citations", detail="Evaluating citation 
relevance...") + + all_sources = state.get("rag_results", []) + + from sqlalchemy import select + + from app.models.paper import Paper + + paper_ids = list({pid for pid in (src.get("paper_id") for src in all_sources) if pid is not None}) + papers_by_id: dict[int, Any] = {} + if paper_ids: + result = await db.execute(select(Paper).where(Paper.id.in_(paper_ids))) + papers_by_id = {p.id: p for p in result.scalars().all()} + + citations: list[CitationDict] = [] + for i, src in enumerate(all_sources, 1): + paper = papers_by_id.get(src.get("paper_id")) if src.get("paper_id") else None + cit: CitationDict = { + "index": i, + "paper_id": src.get("paper_id"), + "paper_title": src.get("paper_title", ""), + "page_number": src.get("page_number"), + "excerpt": src.get("excerpt", ""), + "relevance_score": src.get("relevance_score", 0), + "chunk_type": src.get("chunk_type", "text"), + "authors": paper.authors if paper else None, + "year": paper.year if paper else None, + "doi": paper.doi if paper else None, + } + citations.append(cit) + writer( + { + "type": "data-citation", + "id": f"cit-{i}", + "data": cit, + } + ) + + high_relevance = sum(1 for c in citations if (c.get("relevance_score") or 0) > 0.6) + _emit_thinking( + writer, + "rank", + "Analyzing citations", + status="done", + duration_ms=int((time.monotonic() - t0) * 1000), + summary=f"Selected {high_relevance} high-relevance citations (>60%)", + ) + + return {"citations": citations} + + +# --------------------------------------------------------------------------- +# Node: clean +# --------------------------------------------------------------------------- + + +async def _clean_single_excerpt(llm, excerpt: str) -> str: + """Use LLM to clean OCR-extracted text with timeout and semaphore.""" + if not excerpt or len(excerpt) < 20: + return excerpt + async with _clean_semaphore: + messages = [ + {"role": "system", "content": EXCERPT_CLEAN_PROMPT}, + {"role": "user", "content": excerpt}, + ] + result = "" + try: + async with asyncio.timeout(10.0): + async for token in llm.chat_stream(messages, temperature=0.1, task_type="clean"): + result += token + except TimeoutError: + logger.warning("Excerpt cleaning timed out, using original") + return excerpt + return result if result.strip() else excerpt + + +async def clean_node(state: ChatState, config: RunnableConfig) -> dict[str, Any]: + """Clean citation excerpts in parallel using LLM.""" + writer = get_stream_writer() + t0 = time.monotonic() + llm = get_chat_llm(config) + citations = list(state.get("citations", [])) + + excerpts_to_clean = [(i, c["excerpt"]) for i, c in enumerate(citations) if c.get("excerpt")] + if not excerpts_to_clean: + _emit_thinking(writer, "clean", "Cleaning citation text", status="done", duration_ms=0, summary="No excerpts") + return {"citations": citations} + + _emit_thinking( + writer, + "clean", + "Cleaning citation text", + detail=f"Improving readability of {len(excerpts_to_clean)} citations in parallel...", + ) + + clean_tasks = [_clean_single_excerpt(llm, excerpt) for _, excerpt in excerpts_to_clean] + cleaned_results = await asyncio.gather(*clean_tasks, return_exceptions=True) + + enhanced_count = 0 + for (idx, _original), cleaned in zip(excerpts_to_clean, cleaned_results): + if isinstance(cleaned, str) and cleaned.strip() and cleaned != citations[idx]["excerpt"]: + citations[idx]["excerpt"] = cleaned + enhanced_count += 1 + # Re-emit citation with same id → AI SDK reconciliation updates the Part + writer( + { + "type": "data-citation", + "id": 
f"cit-{citations[idx]['index']}", + "data": citations[idx], + } + ) + + _emit_thinking( + writer, + "clean", + "Cleaning citation text", + status="done", + duration_ms=int((time.monotonic() - t0) * 1000), + summary=f"Enhanced {enhanced_count} citations", + ) + + return {"citations": citations} + + +# --------------------------------------------------------------------------- +# Node: generate +# --------------------------------------------------------------------------- + + +async def generate_node(state: ChatState, config: RunnableConfig) -> dict[str, Any]: + """Stream LLM response token by token.""" + writer = get_stream_writer() + t0 = time.monotonic() + llm = get_chat_llm(config) + + # Build final messages + system_prompt = state.get("system_prompt", "") + history = state.get("history_messages", []) + kb_ids = state.get("knowledge_base_ids", []) + all_contexts = state.get("all_contexts", []) + + if kb_ids: + context_text = "\n\n---\n\n".join(all_contexts) if all_contexts else "No relevant documents found." + user_content = f"Context:\n{context_text}\n\nQuestion: {state['message']}" + else: + user_content = state["message"] + + messages: list[ChatMessageDict] = [ + {"role": "system", "content": system_prompt}, + *history, + {"role": "user", "content": user_content}, + ] + + citations = state.get("citations", []) + _emit_thinking( + writer, + "generate", + "Generating answer", + detail=f"Generating answer based on {len(citations)} citations...", + ) + + text_id = f"text_{uuid.uuid4().hex}" + writer({"type": "text-start", "id": text_id}) + + full_response = "" + try: + async for token in llm.chat_stream(messages, temperature=0.3, task_type="chat"): + full_response += token + writer({"type": "text-delta", "id": text_id, "delta": token}) + except Exception: + logger.exception("LLM streaming error during generate") + writer({"type": "text-end", "id": text_id}) + writer({"type": "error", "errorText": "LLM generation failed"}) + return {"assistant_content": full_response, "error": "LLM generation failed"} + + writer({"type": "text-end", "id": text_id}) + + _emit_thinking( + writer, + "generate", + "Generating answer", + status="done", + duration_ms=int((time.monotonic() - t0) * 1000), + summary=f"Generated {len(full_response)} characters", + ) + + return {"assistant_content": full_response, "full_messages": messages} + + +# --------------------------------------------------------------------------- +# Node: persist +# --------------------------------------------------------------------------- + + +async def persist_node(state: ChatState, config: RunnableConfig) -> dict[str, Any]: + """Save conversation and messages to the database.""" + writer = get_stream_writer() + db = get_chat_db(config) + + try: + from app.models.conversation import Conversation + from app.models.message import Message + + conversation_id = state.get("conversation_id") + + if not conversation_id: + title = state["message"][:50] + ("..." 
if len(state["message"]) > 50 else "") + conv = Conversation( + title=title, + knowledge_base_ids=state.get("knowledge_base_ids") or [], + model=state.get("model", ""), + tool_mode=state.get("tool_mode", "qa"), + ) + db.add(conv) + await db.flush() + conversation_id = conv.id + + citations = state.get("citations") + user_msg = Message( + conversation_id=conversation_id, + role="user", + content=state["message"], + ) + assistant_msg = Message( + conversation_id=conversation_id, + role="assistant", + content=state.get("assistant_content", ""), + citations=citations if citations else None, + ) + db.add(user_msg) + db.add(assistant_msg) + await db.commit() + + writer( + { + "type": "data-conversation", + "data": {"conversation_id": conversation_id}, + } + ) + + return {"new_conversation_id": conversation_id} + + except Exception as e: + logger.exception("Failed to persist conversation") + _emit_thinking(writer, "persist", "Saving", status="error", detail=str(e)) + return {"error": f"persist failed: {e}"} diff --git a/backend/app/pipelines/chat/state.py b/backend/app/pipelines/chat/state.py new file mode 100644 index 0000000..244ce10 --- /dev/null +++ b/backend/app/pipelines/chat/state.py @@ -0,0 +1,51 @@ +"""Chat pipeline state definition.""" + +from __future__ import annotations + +from typing import Any, Literal, TypedDict + + +class CitationDict(TypedDict, total=False): + """Typed citation matching the frontend Citation interface.""" + + index: int + paper_id: int | None + paper_title: str + chunk_type: str + page_number: int | None + relevance_score: float + excerpt: str + authors: list[str] | None + year: int | None + doi: str | None + + +class ChatMessageDict(TypedDict): + """LLM message format.""" + + role: Literal["system", "user", "assistant"] + content: str + + +class ChatState(TypedDict, total=False): + """LangGraph chat pipeline state.""" + + # --- Input (from request) --- + message: str + knowledge_base_ids: list[int] + tool_mode: str + conversation_id: int | None + model: str + + # --- Intermediate (between nodes) --- + rag_results: list[dict[str, Any]] + citations: list[CitationDict] + all_contexts: list[str] + history_messages: list[ChatMessageDict] + system_prompt: str + full_messages: list[ChatMessageDict] + + # --- Output --- + assistant_content: str + new_conversation_id: int | None + error: str | None diff --git a/backend/app/pipelines/chat/stream_writer.py b/backend/app/pipelines/chat/stream_writer.py new file mode 100644 index 0000000..deef0d1 --- /dev/null +++ b/backend/app/pipelines/chat/stream_writer.py @@ -0,0 +1,48 @@ +"""Data Stream Protocol SSE formatting helpers. + +Outputs events in the Vercel AI SDK 5.0 Data Stream Protocol format. +Each function returns a string ready to be yielded from a StreamingResponse. 
+""" + +from __future__ import annotations + +import json +import uuid + + +def format_start(message_id: str | None = None) -> str: + mid = message_id or f"msg_{uuid.uuid4().hex}" + return f"data: {json.dumps({'type': 'start', 'messageId': mid})}\n\n" + + +def format_text_start(text_id: str | None = None) -> str: + tid = text_id or f"text_{uuid.uuid4().hex}" + return f"data: {json.dumps({'type': 'text-start', 'id': tid})}\n\n" + + +def format_text_delta(text_id: str, delta: str) -> str: + return f"data: {json.dumps({'type': 'text-delta', 'id': text_id, 'delta': delta})}\n\n" + + +def format_text_end(text_id: str) -> str: + return f"data: {json.dumps({'type': 'text-end', 'id': text_id})}\n\n" + + +def format_data_part(data_type: str, data: dict, *, part_id: str | None = None) -> str: + """Format a custom data-* part. ``data_type`` should NOT include the ``data-`` prefix.""" + payload: dict = {"type": f"data-{data_type}", "data": data} + if part_id is not None: + payload["id"] = part_id + return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n" + + +def format_error(message: str) -> str: + return f"data: {json.dumps({'type': 'error', 'errorText': message})}\n\n" + + +def format_finish() -> str: + return f"data: {json.dumps({'type': 'finish'})}\n\n" + + +def format_done() -> str: + return "data: [DONE]\n\n" diff --git a/backend/tests/test_chat.py b/backend/tests/test_chat.py index b0fae03..1077ef4 100644 --- a/backend/tests/test_chat.py +++ b/backend/tests/test_chat.py @@ -1,18 +1,13 @@ -"""Tests for Chat and Conversation API endpoints.""" +"""Tests for Conversation CRUD API endpoints. + +Chat streaming tests are in test_chat_pipeline.py. +""" -import chromadb import pytest from httpx import ASGITransport, AsyncClient -from llama_index.core.embeddings import MockEmbedding -from app.api.v1.chat import _get_rag_service_for_chat -from app.database import Base, async_session_factory, engine +from app.database import Base, engine from app.main import app -from app.models import PaperStatus, Project -from app.models.chunk import PaperChunk -from app.models.paper import Paper -from app.services.llm.client import LLMClient -from app.services.rag_service import RAGService @pytest.fixture(autouse=True) @@ -31,63 +26,6 @@ async def client(): yield ac -@pytest.fixture -def mock_rag_llm(): - """Provide mock RAG service and LLM client.""" - llm = LLMClient(provider="mock") - rag = RAGService( - llm=llm, - chroma_client=chromadb.EphemeralClient(), - embed_model=MockEmbedding(embed_dim=128), - ) - return rag, llm - - -@pytest.fixture(autouse=True) -def override_chat_deps(mock_rag_llm): - rag, llm = mock_rag_llm - - async def _mock_get_rag_service(db): - return rag, llm - - app.dependency_overrides[_get_rag_service_for_chat] = _mock_get_rag_service - import app.api.v1.chat as chat_module - - _original = chat_module._get_rag_service_for_chat - chat_module._get_rag_service_for_chat = _mock_get_rag_service - yield - chat_module._get_rag_service_for_chat = _original - app.dependency_overrides.clear() - - -@pytest.fixture -async def project_id(): - async with async_session_factory() as session: - proj = Project(name="Chat Test", domain="test") - session.add(proj) - await session.flush() - - paper = Paper( - project_id=proj.id, - title="Test Paper", - abstract="Test abstract", - status=PaperStatus.INDEXED, - ) - session.add(paper) - await session.flush() - - chunk = PaperChunk( - paper_id=paper.id, - content="Super-resolution microscopy enables imaging beyond the diffraction limit.", - chunk_type="text", - 
page_number=1, - chunk_index=0, - ) - session.add(chunk) - await session.commit() - return proj.id - - # --- Conversation CRUD tests --- @@ -154,85 +92,3 @@ async def test_delete_conversation(client: AsyncClient): resp2 = await client.get(f"/api/v1/conversations/{conv_id}") assert resp2.status_code == 404 - - -# --- SSE Chat Stream tests --- - - -@pytest.mark.asyncio -async def test_chat_stream_creates_conversation(client: AsyncClient, project_id: int): - resp = await client.post( - "/api/v1/chat/stream", - json={ - "knowledge_base_ids": [project_id], - "message": "What is STED?", - }, - ) - assert resp.status_code == 200 - assert resp.headers["content-type"].startswith("text/event-stream") - - text = resp.text - assert "event: message_start" in text - assert "event: text_delta" in text - assert "event: message_end" in text - assert "conversation_id" in text - - -@pytest.mark.asyncio -async def test_chat_stream_continues_conversation(client: AsyncClient, project_id: int): - create_resp = await client.post( - "/api/v1/conversations", - json={"title": "Stream Test", "knowledge_base_ids": [project_id]}, - ) - conv_id = create_resp.json()["data"]["id"] - - resp = await client.post( - "/api/v1/chat/stream", - json={ - "conversation_id": conv_id, - "knowledge_base_ids": [project_id], - "message": "Tell me more", - }, - ) - assert resp.status_code == 200 - text = resp.text - assert "event: message_end" in text - - detail_resp = await client.get(f"/api/v1/conversations/{conv_id}") - messages = detail_resp.json()["data"]["messages"] - assert len(messages) == 2 - assert messages[0]["role"] == "user" - assert messages[1]["role"] == "assistant" - - -@pytest.mark.asyncio -async def test_chat_stream_tool_modes(client: AsyncClient, project_id: int): - for mode in ["qa", "citation_lookup", "review_outline", "gap_analysis"]: - resp = await client.post( - "/api/v1/chat/stream", - json={ - "knowledge_base_ids": [project_id], - "message": "Test", - "tool_mode": mode, - }, - ) - assert resp.status_code == 200 - assert "event: message_end" in resp.text - - -@pytest.mark.asyncio -async def test_chat_stream_no_kb_direct_llm(client: AsyncClient): - """When no knowledge_base_ids are provided, direct LLM chat should work.""" - resp = await client.post( - "/api/v1/chat/stream", - json={ - "knowledge_base_ids": [], - "message": "What is photosynthesis?", - }, - ) - assert resp.status_code == 200 - text = resp.text - assert "event: message_start" in text - assert "event: text_delta" in text - assert "event: message_end" in text - assert "event: citation" not in text diff --git a/backend/tests/test_chat_pipeline.py b/backend/tests/test_chat_pipeline.py new file mode 100644 index 0000000..c7f7f5c --- /dev/null +++ b/backend/tests/test_chat_pipeline.py @@ -0,0 +1,321 @@ +"""Tests for the LangGraph chat pipeline nodes and stream_writer.""" + +from __future__ import annotations + +import json + +import pytest + +from app.pipelines.chat.state import ChatMessageDict, ChatState, CitationDict +from app.pipelines.chat.stream_writer import ( + format_data_part, + format_done, + format_error, + format_finish, + format_start, + format_text_delta, + format_text_end, + format_text_start, +) + +# --------------------------------------------------------------------------- +# stream_writer format tests +# --------------------------------------------------------------------------- + + +class TestStreamWriter: + def test_format_start(self): + result = format_start("msg_abc") + parsed = json.loads(result.removeprefix("data: ").strip()) + 
assert parsed["type"] == "start" + assert parsed["messageId"] == "msg_abc" + + def test_format_start_auto_id(self): + result = format_start() + parsed = json.loads(result.removeprefix("data: ").strip()) + assert parsed["type"] == "start" + assert parsed["messageId"].startswith("msg_") + + def test_format_text_start(self): + result = format_text_start("text_1") + parsed = json.loads(result.removeprefix("data: ").strip()) + assert parsed == {"type": "text-start", "id": "text_1"} + + def test_format_text_delta(self): + result = format_text_delta("text_1", "hello") + parsed = json.loads(result.removeprefix("data: ").strip()) + assert parsed == {"type": "text-delta", "id": "text_1", "delta": "hello"} + + def test_format_text_end(self): + result = format_text_end("text_1") + parsed = json.loads(result.removeprefix("data: ").strip()) + assert parsed == {"type": "text-end", "id": "text_1"} + + def test_format_data_part(self): + result = format_data_part("citation", {"index": 1, "title": "Test"}, part_id="cit-1") + parsed = json.loads(result.removeprefix("data: ").strip()) + assert parsed["type"] == "data-citation" + assert parsed["id"] == "cit-1" + assert parsed["data"]["index"] == 1 + + def test_format_data_part_without_id(self): + result = format_data_part("thinking", {"step": "understand"}) + parsed = json.loads(result.removeprefix("data: ").strip()) + assert parsed["type"] == "data-thinking" + assert "id" not in parsed + assert parsed["data"]["step"] == "understand" + + def test_format_error(self): + result = format_error("something broke") + parsed = json.loads(result.removeprefix("data: ").strip()) + assert parsed == {"type": "error", "errorText": "something broke"} + + def test_format_finish(self): + result = format_finish() + parsed = json.loads(result.removeprefix("data: ").strip()) + assert parsed == {"type": "finish"} + + def test_format_done(self): + assert format_done() == "data: [DONE]\n\n" + + def test_unicode_in_data_part(self): + result = format_data_part("citation", {"title": "超分辨率显微镜"}) + assert "超分辨率显微镜" in result + + +# --------------------------------------------------------------------------- +# State type tests +# --------------------------------------------------------------------------- + + +class TestChatState: + def test_citation_dict(self): + cit: CitationDict = { + "index": 1, + "paper_title": "Test Paper", + "excerpt": "Some text", + "relevance_score": 0.9, + "chunk_type": "text", + } + assert cit["index"] == 1 + assert cit["paper_title"] == "Test Paper" + + def test_chat_message_dict(self): + msg: ChatMessageDict = {"role": "user", "content": "Hello"} + assert msg["role"] == "user" + + def test_chat_state_minimal(self): + state: ChatState = { + "message": "What is STED?", + "knowledge_base_ids": [1], + "tool_mode": "qa", + } + assert state["message"] == "What is STED?" 
+ + def test_chat_state_full(self): + state: ChatState = { + "message": "Test", + "knowledge_base_ids": [], + "tool_mode": "qa", + "conversation_id": None, + "model": "mock", + "rag_results": [], + "citations": [], + "all_contexts": [], + "history_messages": [], + "system_prompt": "You are helpful", + "full_messages": [], + "assistant_content": "", + "new_conversation_id": None, + "error": None, + } + assert state["error"] is None + + +# --------------------------------------------------------------------------- +# Graph compilation tests +# --------------------------------------------------------------------------- + + +class TestChatGraph: + def test_graph_compiles(self): + from app.pipelines.chat.graph import create_chat_pipeline + + pipeline = create_chat_pipeline() + assert pipeline is not None + + def test_graph_has_expected_nodes(self): + from app.pipelines.chat.graph import create_chat_pipeline + + pipeline = create_chat_pipeline() + graph = pipeline.get_graph() + node_names = set(graph.nodes.keys()) if isinstance(graph.nodes, dict) else {n for n in graph.nodes} + expected = {"understand", "retrieve", "rank", "clean", "generate", "persist"} + assert expected.issubset(node_names) + + +# --------------------------------------------------------------------------- +# Config helpers tests +# --------------------------------------------------------------------------- + + +class TestConfigHelpers: + def test_get_chat_db(self): + from app.pipelines.chat.config_helpers import get_chat_db + + mock_db = object() + config = {"configurable": {"db": mock_db}} + assert get_chat_db(config) is mock_db + + def test_set_and_get_services(self): + from app.pipelines.chat.config_helpers import ( + get_chat_llm, + get_chat_rag, + set_chat_services, + ) + + config = {"configurable": {"db": "x"}} + mock_llm = object() + mock_rag = object() + set_chat_services(config, llm=mock_llm, rag=mock_rag) + assert get_chat_llm(config) is mock_llm + assert get_chat_rag(config) is mock_rag + + def test_services_shared_across_config_copies(self): + """Verify that shallow config copies share the _services dict.""" + from app.pipelines.chat.config_helpers import ( + get_chat_llm, + set_chat_services, + ) + + config = {"configurable": {"db": "x"}} + set_chat_services(config, llm="test_llm", rag="test_rag") + + config_copy = {**config, "configurable": config["configurable"]} + assert get_chat_llm(config_copy) == "test_llm" + + +# --------------------------------------------------------------------------- +# Integration test: full endpoint (requires DB) +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +async def setup_db(): + from app.database import Base, engine + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) + + +@pytest.fixture(autouse=True) +def mock_services(monkeypatch): + """Mock _init_services so endpoint tests use mock LLM/RAG without DB lookups.""" + import app.api.v1.chat as chat_module + from app.services.llm.client import LLMClient + + async def _mock_init_services(db): + from llama_index.core.embeddings import MockEmbedding + + from app.services.rag_service import RAGService + + llm = LLMClient(provider="mock") + rag = RAGService(llm=llm, embed_model=MockEmbedding(embed_dim=128)) + return {"llm": llm, "rag": rag} + + monkeypatch.setattr(chat_module, "_init_services", _mock_init_services) + + +@pytest.fixture +async def 
client():
+    from httpx import ASGITransport, AsyncClient
+
+    from app.main import app
+
+    transport = ASGITransport(app=app)
+    async with AsyncClient(transport=transport, base_url="http://test") as ac:
+        yield ac
+
+
+@pytest.mark.asyncio
+async def test_stream_endpoint_data_stream_protocol(client):
+    """Verify the /stream endpoint emits Data Stream Protocol events."""
+    resp = await client.post(
+        "/api/v1/chat/stream",
+        json={"message": "Hello", "knowledge_base_ids": []},
+    )
+    assert resp.status_code == 200
+    assert resp.headers["content-type"].startswith("text/event-stream")
+
+    text = resp.text
+    lines = [line for line in text.split("\n") if line.startswith("data: ")]
+
+    event_types = []
+    error_text = None
+    for line in lines:
+        payload = line.removeprefix("data: ").strip()
+        if payload == "[DONE]":
+            event_types.append("[DONE]")
+            continue
+        try:
+            parsed = json.loads(payload)
+            etype = parsed.get("type", "unknown")
+            event_types.append(etype)
+            if etype == "error":
+                error_text = parsed.get("errorText", "")
+        except json.JSONDecodeError:
+            pass
+
+    assert error_text is None, f"Stream returned error: {error_text}"
+    assert "start" in event_types
+    assert "text-delta" in event_types
+    assert "finish" in event_types
+    assert "[DONE]" in event_types
+
+
+@pytest.mark.asyncio
+async def test_stream_endpoint_no_kb_skips_rag(client):
+    """Without knowledge_base_ids, the pipeline skips RAG nodes."""
+    resp = await client.post(
+        "/api/v1/chat/stream",
+        json={"message": "What is 2+2?", "knowledge_base_ids": []},
+    )
+    assert resp.status_code == 200
+    text = resp.text
+    assert "data-citation" not in text
+    assert "text-delta" in text
+
+
+@pytest.mark.asyncio
+async def test_stream_endpoint_persists_conversation(client):
+    """Verify that the stream creates a conversation and persists messages."""
+    resp = await client.post(
+        "/api/v1/chat/stream",
+        json={"message": "Tell me about AI", "knowledge_base_ids": []},
+    )
+    assert resp.status_code == 200
+
+    lines = [line for line in resp.text.split("\n") if line.startswith("data: ")]
+    conv_id = None
+    for line in lines:
+        payload = line.removeprefix("data: ").strip()
+        if payload == "[DONE]":
+            continue
+        try:
+            parsed = json.loads(payload)
+            if parsed.get("type") == "data-conversation":
+                conv_id = parsed["data"]["conversation_id"]
+        except json.JSONDecodeError:
+            pass
+
+    assert conv_id is not None
+
+    detail_resp = await client.get(f"/api/v1/conversations/{conv_id}")
+    assert detail_resp.status_code == 200
+    messages = detail_resp.json()["data"]["messages"]
+    assert len(messages) == 2
+    assert messages[0]["role"] == "user"
+    assert messages[1]["role"] == "assistant"
diff --git a/docs/brainstorms/2026-03-12-chat-message-routing-chain-brainstorm.md b/docs/brainstorms/2026-03-12-chat-message-routing-chain-brainstorm.md
new file mode 100644
index 0000000..36cd016
--- /dev/null
+++ b/docs/brainstorms/2026-03-12-chat-message-routing-chain-brainstorm.md
@@ -0,0 +1,340 @@
+---
+date: 2026-03-12
+topic: chat-message-routing-chain
+depends_on:
+  - 2026-03-12-frontend-ux-robustness-brainstorm.md
+  - 2026-03-11-ux-architecture-upgrade-brainstorm.md
+---
+
+# Chat Message Routing Chain: Full Refactor
+
+## What we are building
+
+A **full refactor** of Omelette's chat message handling chain, covering three dimensions: the frontend protocol layer, state management, and the backend processing pipeline. The goal is to move from the current "hand-rolled SSE parsing + giant-component state management + monolithic stream function" to "standardized protocol + abstracted state layer + orchestrated pipeline".
+
+Core changes:
+1. **Protocol layer**: custom SSE parsing → Vercel AI SDK 5.0 Data Stream Protocol
+2. **Frontend state**: ~15 useState hooks in PlaygroundPage → `useChat` + Transport abstraction
+3. **Backend pipeline**: monolithic `_stream_chat` function (230+ lines) → orchestrated LangGraph StateGraph nodes
+4. **Reliability**: no error handling / no retry on broken streams → standard error Part + Resumable Streams
+
+## Current state analysis: the existing message routing chain
+
+### End-to-end chain diagram
+
+```
+User input
+  ↓
+ChatInput.handleSubmit (trim → onSend → clear)
+  ↓
+PlaygroundPage.handleSend
+  ├─ create user/assistant LocalMessage
+  ├─ initialize pendingDeltaRef + assistantIdRef
+  ├─ setMessages([...prev, userMsg, assistantMsg])
+  ├─ setIsStreaming(true)
+  ├─ create AbortController
+  ↓
+streamChat(fetch POST /api/v1/chat/stream)  ←── chat-api.ts
+  ├─ fetch + ReadableStream.getReader()
+  ├─ TextDecoder + manual buffer split('\n')
+  ├─ parse `event:` + `data:` + blank line → yield SSEEvent
+  ↓
+for await (const event of gen)  ←── PlaygroundPage.tsx
+  ├─ text_delta → pendingDeltaRef + 80ms debounce flush
+  ├─ citation → isCitation() → normalizeCitation() → append
+  ├─ thinking_step → update/append thinkingSteps
+  ├─ citation_enhanced → update citations[index].excerpt
+  ├─ a2ui_surface → append to a2uiMessages
+  ├─ message_end → flushDelta + setConversationId + navigate
+  └─ error → ❌ not handled!
+  ↓
+finally: cleanup (timer, flush, isStreaming=false)
+```
+
+### Backend `_stream_chat` chain
+
+```
+POST /api/v1/chat/stream
+  ↓
+_stream_chat(request, db) → AsyncGenerator[str, None]
+  ├─ message_start
+  ├─ thinking_step(understand, running)
+  ├─ _get_rag_service_for_chat() → (rag, llm)
+  ├─ thinking_step(understand, done)
+  │
+  ├─ [if knowledge_base_ids]:
+  │   ├─ thinking_step(retrieve, running)
+  │   ├─ asyncio.gather(*rag_tasks)  ← RAGService.query()
+  │   ├─ thinking_step(retrieve, done)
+  │   ├─ thinking_step(rank, running)
+  │   ├─ Load papers, build citations → yield citation × N
+  │   ├─ thinking_step(rank, done)
+  │   ├─ thinking_step(clean, running)
+  │   ├─ _clean_excerpt × M (LLM, semaphore, timeout)
+  │   ├─ yield citation_enhanced × M
+  │   └─ thinking_step(clean, done)
+  │
+  ├─ Build history_messages from DB
+  ├─ Build messages (system + history + user)
+  ├─ thinking_step(generate, running)
+  ├─ llm.chat_stream() → yield text_delta × K
+  ├─ thinking_step(generate, done)
+  ├─ Create conversation + messages in DB
+  ├─ thinking_step(complete, done)
+  └─ message_end
+```
+
+### Existing problems
+
+| # | Problem | Severity | Impact |
+|---|---------|----------|--------|
+| R-1 | Hand-written SSE parser, no standardization | High | Cannot leverage ecosystem tooling, hard to test, bug-prone |
+| R-2 | Frontend does not handle the `error` SSE event | Critical | Backend emits an error, the frontend silently ignores it, the user sees nothing |
+| R-3 | No retry on broken SSE streams | High | A brief network drop loses the whole response |
+| R-4 | `response.body!` non-null assertion | Medium | Already changed to `?.`, but still no graceful degradation |
+| R-5 | ~15 useState/useRef hooks in PlaygroundPage | High | State logic tightly coupled to the UI, hard to test or reuse |
+| R-6 | Manual 80ms debounce for text_delta | Medium | Non-standard implementation, performance characteristics not controllable |
+| R-7 | Monolithic `_stream_chat` function, 230+ lines | High | Steps are coupled, hard to test/reuse/extend independently |
+| R-8 | thinking_step hard-coded in the stream function | Medium | Steps cannot be added/removed/reordered flexibly |
+| R-9 | Citation handling (cleaning/enhancement) inlined in the stream | Medium | Cleaning timeouts stall the entire stream |
+| R-10 | Mixed fetch/axios usage | Medium | Streams use fetch, everything else uses axios; inconsistent error handling |
+| R-11 | Non-standard message model (LocalMessage) | Medium | Custom interface, incompatible with the AI SDK UIMessage |
+
+## Why this approach
+
+### Options considered
+
+| Option | Description | Trade-offs |
+|--------|-------------|------------|
+| **A: Vercel AI SDK 5.0 + LangGraph (chosen)** | Frontend useChat + Data Stream Protocol, backend orchestrated with LangGraph StateGraph | Most thorough but largest change; gains standardized protocol + orchestrated pipeline + type safety |
+| B: Home-grown standardized SSE + middleware chain | Keep custom SSE but normalize the format; Express-style middleware chain on the backend | Lightweight and flexible, but loses AI SDK ecosystem benefits (auto retry, stream resumption, tool calls, etc.) |
+| C: Incremental patching | Only fix error handling, extract state, split the function | Smallest change, but does not address the root architectural problems; long-term tech debt accumulates |
+
+**Reasons for choosing Option A:**
+- The Vercel AI SDK 5.0 transport architecture natively supports custom backends (including a FastAPI Python backend)
+- The Data Stream Protocol's custom `data-*` Parts can cover all of Omelette's custom events (citation, thinking, a2ui)
+- LangGraph is already in the project, and StateGraph is a natural fit for a chat pipeline with "conditional branches + checkpoints + HITL"
+- Solves protocol standardization, state management, and backend orchestration in one pass, avoiding three separate refactors
+
+## Key decisions
+
+### 1. Protocol layer: Vercel AI SDK 5.0 Data Stream Protocol
+
+- **Decision**: migrate the frontend to `useChat` + `DefaultChatTransport` from `@ai-sdk/react`
+- **SSE format**: the backend emits standard Data Stream Protocol SSE (`data: {"type":"..."}` format)
+- **Custom event mapping**:
+
+  | Current event | AI SDK mapping | Notes |
+  |---------------|----------------|-------|
+  | `text_delta` | `text-start` + `text-delta` + `text-end` | Standard text streaming, tracked by ID |
+  | `citation` | `data-citation` | Custom data Part |
+  | `citation_enhanced` | `data-citation-enhanced` | Custom data Part |
+  | `thinking_step` | `data-thinking` | Custom data Part |
+  | `a2ui_surface` | `data-a2ui` | Custom data Part |
+  | `message_start` | `start` (messageId) | Standard Part |
+  | `message_end` | `finish` | Standard Part |
+  | `error` | `error` (errorText) | Standard Part |
+
+- **Rationale**: the Data Stream Protocol's `data-*` types are an extension point designed specifically for custom data; no hacks needed
+
+### 2. Frontend state management: useChat replaces manual useState
+
+- **Decision**: use the `useChat` hook to manage messages, status, and streaming state
+- **Message model**: migrate from `LocalMessage` to `UIMessage` + `parts`
+- **Custom data access**: handle `data-citation` and the other custom events via `useChat`'s `onMessage` or by filtering `parts` (see the sketch after this list)
+- **Impact**: PlaygroundPage shrinks from ~15 useState hooks to the core interaction state (sidebarCollapsed, toolMode, etc.)
+- **Rationale**: useChat already handles streaming status, message appending, abort, retry, and related logic internally
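+
+As an illustration of the consolidation described above, a minimal sketch of the hook usage — the `useChat`/`DefaultChatTransport` option names follow the AI SDK 5.0 docs, and `usePlaygroundChat` is a hypothetical name, not code in this PR:
+
+```tsx
+// Sketch only — assumes AI SDK 5.0 useChat/DefaultChatTransport; not the final PlaygroundPage code.
+import { useChat } from '@ai-sdk/react';
+import { DefaultChatTransport } from 'ai';
+
+export function usePlaygroundChat() {
+  const { messages, sendMessage, status, error, stop } = useChat({
+    transport: new DefaultChatTransport({ api: '/api/v1/chat/stream' }),
+  });
+
+  // Custom data-* parts arrive inside message.parts; derive citations/thinking from the last reply.
+  const lastAssistant = messages.filter((m) => m.role === 'assistant').at(-1);
+  const citations = lastAssistant?.parts.filter((p) => p.type === 'data-citation') ?? [];
+  const thinkingSteps = lastAssistant?.parts.filter((p) => p.type === 'data-thinking') ?? [];
+
+  return { messages, sendMessage, status, error, stop, citations, thinkingSteps };
+}
+```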
+
+### 3. Backend pipeline: LangGraph StateGraph
+
+- **Decision**: split `_stream_chat` into LangGraph nodes
+- **Node design**:
+
+  ```
+  [understand] → [retrieve] → [rank] → [clean] → [generate] → [persist] → [complete]
+                     ↓ (no KB)
+                 [generate] → [persist] → [complete]
+  ```
+
+  | Node | Responsibility | Input | Output |
+  |------|----------------|-------|--------|
+  | understand | Parse the request, obtain LLM/RAG services | request | llm, rag, parsed_query |
+  | retrieve | RAG queries across multiple knowledge bases | rag, kb_ids, query | raw_results |
+  | rank | Build citations, load paper metadata | raw_results | citations |
+  | clean | LLM-clean citation excerpts | citations, llm | enhanced_citations |
+  | generate | Stream the LLM answer | messages, llm | text_stream |
+  | persist | Save conversation and messages to DB | conversation, messages | conversation_id |
+  | complete | Emit the end-of-stream signal | all | message_end |
+
+- **SSE emission**: each node emits standard Data Stream Protocol events through a `StreamWriter`
+- **Conditional routing**: whether to skip straight to `generate` is decided by whether `knowledge_base_ids` is present
+- **Rationale**:
+  - Each node can be tested independently
+  - Conditional routing is built in (no nested if/else)
+  - HITL checkpoints can be added later (e.g. confirm citations before generating)
+  - Reuses the existing LangGraph checkpointing infrastructure
+
+### 4. Reliability: error handling (core) + stream recovery (enhancement)
+
+- **Core (done together with Phase 1)**:
+  - Errors: standard `error` Part + the frontend `useChat` `error` state handles it automatically
+  - Abort: `useChat`'s built-in `stop()` method
+- **Enhancement (later iteration)**:
+  - Frontend: AI SDK 5.0 built-in `reconnect` capability (`prepareReconnectToStreamRequest`)
+  - Backend: LangGraph checkpoints → resume from any node
+- **Scoping**: Resumable Streams is an independent enhancement and does not block the main refactor. For now, make sure errors and interruptions do not crash the UI; stream resumption is deferred to a second iteration.
+- **Rationale**: first fix "a broken stream crashes the UI" (critical), then "a broken stream can resume" (enhancement). A frontend sketch of the core handling follows below.
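+
+A minimal sketch of that Phase 1 handling, assuming the `status`/`error`/`stop`/`regenerate` fields returned by AI SDK 5.0's `useChat`; `ErrorBanner` and `ChatControls` are hypothetical names, not components in this repo:
+
+```tsx
+// Sketch only — error/stop wiring on top of useChat; ErrorBanner is a placeholder, not project code.
+import { useChat } from '@ai-sdk/react';
+import { DefaultChatTransport } from 'ai';
+
+function ErrorBanner({ message, onRetry }: { message: string; onRetry: () => void }) {
+  return <div role="alert">{message} <button onClick={onRetry}>Retry</button></div>;
+}
+
+export function ChatControls() {
+  const { status, error, stop, regenerate } = useChat({
+    transport: new DefaultChatTransport({ api: '/api/v1/chat/stream' }),
+  });
+
+  return (
+    <>
+      {/* Backend `error` parts and transport failures both surface via the error object. */}
+      {error && <ErrorBanner message={error.message} onRetry={() => regenerate()} />}
+      {/* Abort maps to useChat's built-in stop(); the partial text already streamed stays visible. */}
+      {(status === 'submitted' || status === 'streaming') && (
+        <button onClick={() => stop()}>Stop</button>
+      )}
+    </>
+  );
+}
+```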
+
+### 5. Backend SSE output adaptation
+
+- **Decision**: create a `StreamWriter` utility that encapsulates Data Stream Protocol output formatting
+- **Format**: `data: {"type":"...", ...}\n\n` (standard SSE)
+- **Header**: `x-vercel-ai-ui-message-stream: v1`
+- **Termination**: `data: [DONE]\n\n`
+- **Rationale**: one place to encapsulate the format, called uniformly by all nodes; the frontend useChat parses it automatically
+
+## Open-source projects used as reference
+
+### Vercel AI SDK 5.0 (protocol reference)
+
+- **Architecture**: three layers, Provider → Core → Framework
+- **Protocol**: Data Stream Protocol (SSE `data: {"type":"..."}`)
+- **State**: the `useChat` hook manages messages, status, error
+- **Transport**: `DefaultChatTransport` is replaceable; custom backends are supported
+- **Highlights**: custom `data-*` Parts, reconnect, abort, standardized tool calls
+
+### LibreChat (reliability reference)
+
+- **Resumable Streams**: after a disconnect the client recovers transparently and the server continues from its current position
+- **Deployment modes**: single instance uses Node EventEmitter pub/sub; multi-instance uses Redis Streams
+- **Text animation**: dynamic speed adjustment (10→16 chars/ms), faster as the queue grows
+- **Highlights**: resumable streaming matters a great deal for academic writing scenarios (long responses)
+
+### Open WebUI (pipeline reference)
+
+- **Pipeline**: three stages, Inlet → Process → Outlet
+- **Filter**: pluggable filter chain (monitoring, rewriting, blocking, translation, rate limiting)
+- **Async mode**: long-running operations (web search, 30-60s+) immediately return a task_id and push progress over WebSocket
+- **Highlights**: highly extensible pipeline, a good fit for academic workloads (search → dedupe → full-text retrieval → OCR can be very slow)
+
+
+## Technical feasibility verification
+
+### AI SDK 5.0 + FastAPI backend ✅
+
+**Conclusion**: feasible, but the backend must emit the Data Stream Protocol SSE format by hand.
+
+**Verified via**: [vercel/ai#7496](https://github.com/vercel/ai/issues/7496) + community working example
+
+**Backend output format** (verified to work):
+
+```python
+import json, uuid
+
+async def event_stream():
+    message_id = f"msg_{uuid.uuid4().hex}"
+    yield f'data: {json.dumps({"type": "start", "messageId": message_id})}\n\n'
+
+    text_id = f"text_{uuid.uuid4().hex}"
+    yield f'data: {json.dumps({"type": "text-start", "id": text_id})}\n\n'
+    for chunk in chunks:
+        yield f'data: {json.dumps({"type": "text-delta", "id": text_id, "delta": chunk})}\n\n'
+    yield f'data: {json.dumps({"type": "text-end", "id": text_id})}\n\n'
+
+    # Custom data Part (citation, etc.)
+    yield f'data: {json.dumps({"type": "data-citation", "data": citation_obj})}\n\n'
+
+    yield f'data: {json.dumps({"type": "finish"})}\n\n'
+    yield 'data: [DONE]\n\n'
+```
+
+**Key requirements**:
+- The response header must include `x-vercel-ai-ui-message-stream: v1`
+- Each line uses the `data: {JSON}\n\n` format (standard SSE)
+- The stream ends with `data: [DONE]\n\n`
+- There is no official Python SDK helper, so a `StreamWriter` utility has to be written in-house
+
+**Reference implementation**: [Pydantic AI's Vercel AI SDK protocol implementation](https://ai.pydantic.dev/ui/vercel-ai/)
+
+### LangGraph `get_stream_writer()` ✅
+
+**Conclusion**: fully supports emitting custom events in real time from inside a node.
+
+**Verified via**: [LangGraph official docs](https://reference.langchain.com/python/langgraph/config/get_stream_writer) + community usage
+
+```python
+from langgraph.config import get_stream_writer
+
+def retrieve_node(state: ChatState) -> ChatState:
+    writer = get_stream_writer()
+    writer({"type": "data-thinking", "data": {"step": "retrieve", "status": "running"}})
+
+    results = rag_service.query(state["query"], state["kb_ids"])
+
+    writer({"type": "data-thinking", "data": {"step": "retrieve", "status": "done"}})
+    return {**state, "rag_results": results}
+```
+
+**Key requirements**:
+- Python 3.11+ (needed for async ContextVar propagation)
+- `stream_mode=["updates", "custom"]`
+- `get_stream_writer()` must be called inside the node function body
+- For the `generate` node's `text_delta` streaming, `writer()` must be called inside the LLM streaming loop
+
+**Bridging LangGraph custom stream → Data Stream Protocol**:
+
+```python
+async for chunk in graph.astream(input_state, stream_mode=["updates", "custom"]):
+    if chunk[0] == "custom":  # custom event from get_stream_writer()
+        yield f'data: {json.dumps(chunk[1])}\n\n'
+```
+
+### Consuming `data-*` Parts on the frontend ✅
+
+The AI SDK 5.0 `UIMessage.parts` array contains every Part (including custom `data-*` ones); the frontend accesses them via `parts.filter()`:
+
+```tsx
+// In MessageBubble
+{message.parts.map((part, i) => {
+  switch (part.type) {
+    case 'text':
+      return /* render the markdown text part */;
+    case 'data-citation':
+      return /* render a citation card from part.data */;
+    case 'data-thinking':
+      return /* render a thinking step from part.data */;
+  }
+})}
+```
+
+### Hooking up conversation restore
+
+`useChat` supports the `initialMessages` and `chatId` parameters:
+
+```tsx
+const { messages, sendMessage, status } = useChat({
+  chatId: conversationId,              // maps to /chat/:conversationId
+  initialMessages: restoredMessages,   // history restored from the DB
+  transport: new DefaultChatTransport({ api: '/api/v1/chat/stream' }),
+});
+```
+
+When restoring from the route, load the history via `conversationApi.get(id)`, convert it to the `UIMessage[]` format, and pass it as `initialMessages`.
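+
+The conversion step itself is not spelled out above; a minimal sketch, assuming the persisted messages carry `role`, `content`, and optional `citations` (as in this PR's `Message` model) and that the `UIMessage`/part shapes follow AI SDK 5.0 — `StoredMessage` and `toUIMessages` are hypothetical names:
+
+```tsx
+// Sketch only — maps persisted backend messages onto AI SDK UIMessage parts for initialMessages.
+import type { UIMessage } from 'ai';
+
+interface StoredMessage {
+  id: number;
+  role: 'user' | 'assistant';
+  content: string;
+  citations?: unknown[] | null;
+}
+
+function toUIMessages(stored: StoredMessage[]): UIMessage[] {
+  return stored.map((m) => ({
+    id: String(m.id),
+    role: m.role,
+    parts: [
+      { type: 'text', text: m.content },
+      // Replay persisted citations as the same data-citation parts the live stream emits.
+      ...(m.citations ?? []).map((c) => ({ type: 'data-citation', data: c })),
+    ],
+  })) as UIMessage[];
+}
+```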
+``` + +### Conversation 恢复对接 + +`useChat` 支持 `initialMessages` 和 `chatId` 参数: + +```tsx +const { messages, sendMessage, status } = useChat({ + chatId: conversationId, // 对应 /chat/:conversationId + initialMessages: restoredMessages, // 从 DB 恢复的历史消息 + transport: new DefaultChatTransport({ api: '/api/v1/chat/stream' }), +}); +``` + +路由恢复时,通过 `conversationApi.get(id)` 加载历史消息并转换为 `UIMessage[]` 格式传入 `initialMessages`。 + +## 未陈述的假设 + +1. AI SDK 5.0 已进入 stable(当前最新为 beta,但核心 API 已稳定且有大量生产使用) +2. Python 3.12 满足 LangGraph `get_stream_writer()` 的 ContextVar 要求(需 ≥ 3.11) +3. `data-*` Part 的 `data` 字段可以是任意 JSON 可序列化对象(已由协议规范确认) +4. Resumable Streams 作为**增强特性**在核心重构完成后实施,不阻塞主体工作 + +## Resolved Questions + +1. **AI SDK 5.0 + FastAPI 是否可行?** → ✅ 可行,需手动格式化 SSE,社区有 working example 和 Pydantic AI 参考实现 +2. **LangGraph 节点能否实时流式输出 SSE?** → ✅ 可以,`get_stream_writer()` + `stream_mode="custom"` +3. **`data-*` 自定义 Part 前端怎么消费?** → 通过 `message.parts` 数组的 `type` 字段过滤 +4. **对话恢复如何对接 useChat?** → `initialMessages` + `chatId` 参数 + +## 下一步 + +→ `/ce:plan` 生成详细实施计划,全面重写前后端消息处理链路 diff --git a/docs/brainstorms/2026-03-12-chat-message-routing-chain-spec-flow-analysis.md b/docs/brainstorms/2026-03-12-chat-message-routing-chain-spec-flow-analysis.md new file mode 100644 index 0000000..66f30b1 --- /dev/null +++ b/docs/brainstorms/2026-03-12-chat-message-routing-chain-spec-flow-analysis.md @@ -0,0 +1,276 @@ +# Chat Message Routing Chain — Spec Flow Analysis + +**Date**: 2026-03-12 +**Source**: `docs/brainstorms/2026-03-12-chat-message-routing-chain-brainstorm.md` +**Analyzer**: spec-flow-analyzer + +--- + +## User Flow Overview + +### Flow 1: Happy Path — New Chat with RAG (knowledge_base_ids present) + +```mermaid +flowchart TD + A[User enters message + selects KBs] --> B[handleSend / sendMessage] + B --> C[POST /api/v1/chat/stream] + C --> D[understand node] + D --> E[retrieve node] + E --> F[rank node] + F --> G[clean node] + G --> H[generate node] + H --> I[persist node] + I --> J[complete node] + J --> K[Stream ends with DONE] + K --> L[useChat updates messages] + L --> M[Navigate to /chat/:id] +``` + +User sends message → backend runs full pipeline (understand → retrieve → rank → clean → generate → persist → complete) → SSE stream emits `start`, `text-start`, `text-delta`, `text-end`, `data-citation`, `data-thinking`, `data-citation-enhanced`, `finish`, `[DONE]` → frontend `useChat` parses and updates `messages` → user sees streaming response with citations and thinking steps → URL updates to `/chat/:conversationId`. + +### Flow 2: Happy Path — New Chat without RAG (no knowledge_base_ids) + +```mermaid +flowchart TD + A[User enters message, no KBs] --> B[sendMessage] + B --> C[POST /api/v1/chat/stream] + C --> D[understand node] + D --> E[Conditional: skip retrieve/rank/clean] + E --> F[generate node] + F --> G[persist node] + G --> H[complete node] + H --> I[Stream ends] +``` + +Same as Flow 1 but `retrieve`, `rank`, `clean` nodes are skipped. No `data-citation` or `data-citation-enhanced` events. + +### Flow 3: Continue Existing Conversation + +User navigates to `/chat/:conversationId` or sends a message in an existing conversation. Backend loads `history_messages` from DB (last 10), prepends to prompt. Same pipeline as Flow 1 or 2 depending on `knowledge_base_ids`. `conversation_id` in request prevents creating new conversation. + +### Flow 4: Conversation Restore (Page Load / Direct URL) + +User opens `/chat/123` directly. Frontend: +1. `routeConvId` from URL → `conversationApi.get(123)` +2. 
Load conversation + messages +3. Convert `ChatMessage[]` → `UIMessage[]` (or `initialMessages` format) +4. Pass to `useChat({ chatId: '123', initialMessages: restored })` +5. Render restored messages; user can continue chatting + +### Flow 5: User Aborts Stream (Stop Button) + +User clicks Stop during streaming. `useChat.stop()` → AbortController aborts fetch → backend receives cancellation → stream terminates. Frontend shows partial response, no error. + +### Flow 6: Error During Stream (Backend Exception) + +Backend node throws (e.g., LLM timeout, RAG failure). Backend catches, emits `error` Part `{"type":"error","errorText":"..."}`. Frontend `useChat` sets `error` state. User sees error UI (toast or inline). Spec says "useChat error state automatically handles" — exact UX (retry button, inline message, toast) not specified. + +### Flow 7: Network Failure / Connection Drop + +Fetch fails or stream breaks mid-response. No `[DONE]` received. Spec defers "Resumable Streams" to Phase 2. Current behavior: stream hangs or throws; user gets generic error. No reconnect logic in Phase 1. + +### Flow 8: Empty / Invalid Input + +User submits empty message or whitespace-only. Backend `ChatStreamRequest.message` has `min_length=1`. Frontend: ChatInput may or may not prevent submit. If submitted, backend returns 422. `useChat` error handling for non-2xx not specified. + +--- + +## Flow Permutations Matrix + +| Flow | User State | Context | knowledge_base_ids | Expected Behavior | +|------|------------|---------|--------------------|-------------------| +| 1 | Any | New chat | Non-empty | Full RAG pipeline, citations, thinking | +| 2 | Any | New chat | Empty/undefined | Direct LLM, no citations | +| 3 | Any | Existing conv | From conv or override | History loaded, same pipeline | +| 4 | Any | Direct URL | From conv | Restore messages, ready to continue | +| 5 | Any | Streaming | Any | Abort, partial response | +| 6 | Any | Any | Any | Error Part, useChat error | +| 7 | Any | Any | Any | No reconnect (Phase 1) | +| 8 | Any | Submit | N/A | Validation error | + +**Additional dimensions not fully specified:** +- **Tool mode** (qa, citation_lookup, review_outline, gap_analysis): Affects system prompt; no flow-specific behavior change in spec +- **Model override** (`request.model`): Passed but not described in pipeline +- **First-time vs returning user**: Same flows; restore is Flow 4 +- **Concurrent actions**: User sends message while another streams — race condition; spec does not address + +--- + +## Missing Elements & Gaps + +### Category: Error Handling + +| Gap | Description | Impact | Current Ambiguity | +|-----|-------------|--------|-------------------| +| **E-1** | `error` Part schema not defined | Frontend cannot reliably parse error | Is it `{"type":"error","errorText":"..."}` or `{"type":"error","code":"...","message":"..."}`? Current backend uses `{"code":"stream_error","message":"..."}`. | +| **E-2** | Partial failure in RAG (one KB fails) | User gets incomplete citations or full failure? | Current `_stream_chat` uses `return_exceptions=True` and continues; spec does not say if LangGraph nodes should do same. | +| **E-3** | LLM timeout / rate limit | No specific handling | Current `_clean_excerpt` has 10s timeout; main `chat_stream` has none. What timeout for generate node? | +| **E-4** | DB failure in persist node | Conversation/messages not saved | Should stream still complete with `finish`? Or emit `error`? User may see response but refresh loses it. 
| +| **E-5** | Non-2xx HTTP (422, 500) before stream starts | `useChat` behavior | Does `useChat` surface `response.ok === false` as `error`? Or does fetch throw? | +| **E-6** | Malformed SSE (invalid JSON in data) | Parser behavior | Current `streamChat` yields `{ raw: currentData }` on parse error. Data Stream Protocol: does useChat handle malformed lines? | + +### Category: Protocol & Integration + +| Gap | Description | Impact | Current Ambiguity | +|-----|-------------|--------|-------------------| +| **P-1** | Current backend uses `event: X\ndata: Y`; Data Stream Protocol uses `data: {"type":"X",...}` only | Breaking change | No `event:` line in new format. All info in JSON. Frontend must not expect `event:` prefix. | +| **P-2** | `text_delta` → `text-start` + `text-delta` + `text-end` mapping | Delta format change | Current: `{"delta":"x"}`. New: `{"type":"text-delta","id":"text_xxx","delta":"x"}`. Need `id` for correlation. | +| **P-3** | `message_start` → `start` with `messageId` | Field name change | Current: `message_id`. New: `messageId` (camelCase). | +| **P-4** | `message_end` → `finish` with `conversation_id` | Where does conversation_id go? | Data Stream Protocol `finish` may not include custom fields. Spec says backend yields `finish` — need to confirm `conversation_id` is in same Part or separate `data-*` Part. | +| **P-5** | Header `x-vercel-ai-ui-message-stream: v1` | Required for useChat | Backend must add this. Current backend does not send it. | +| **P-6** | `[DONE]` vs `data: [DONE]` | Termination format | Spec says `data: [DONE]\n\n`. Confirm exact string. | + +### Category: Frontend State & UX + +| Gap | Description | Impact | Current Ambiguity | +|-----|-------------|--------|-------------------| +| **F-1** | `LocalMessage` → `UIMessage` + `parts` migration | MessageBubble, ChatInput, etc. | MessageBubble expects `content`, `citations`, `thinkingSteps`, `a2uiMessages` as props. UIMessage has `parts`. Need adapter: `parts.filter(p => p.type === 'text')` → content, `parts.filter(p => p.type === 'data-citation')` → citations. | +| **F-2** | `loadingStage` derivation | MessageBubble uses `loadingStage` for UI | Current: 'searching' | 'citations' | 'generating' | 'complete'. UIMessage has no such field. Derive from `data-thinking` parts? Or add custom Part? | +| **F-3** | `initialMessages` format | Conversation restore | `ChatMessage[]` from API has `id`, `role`, `content`, `citations`. UIMessage has `id`, `role`, `parts`. Conversion logic not specified. | +| **F-4** | `chatId` type | useChat expects string? | Route has `conversationId` as number. useChat `chatId` may need string. | +| **F-5** | Tool mode, selectedKBs in useChat | useChat sends body | Need to pass `knowledge_base_ids`, `tool_mode`, `conversation_id` in request body. useChat's `body` or `sendMessage` options? | +| **F-6** | Navigation after stream end | Current: `navigate(/chat/${cid})` on message_end | useChat does not know conversation_id. Need `onFinish` or similar to get `conversation_id` from stream and navigate. | +| **F-7** | 80ms debounce removal | Current: manual debounce for text_delta | useChat handles streaming internally. No debounce needed. But does useChat batch updates? May affect perceived performance. | + +### Category: Backend Pipeline + +| Gap | Description | Impact | Current Ambiguity | +|-----|-------------|--------|-------------------| +| **B-1** | ChatState TypedDict | New state for chat graph | PipelineState exists for search/upload. 
Chat pipeline needs different state: `request`, `llm`, `rag`, `citations`, `messages`, etc. Not defined. | +| **B-2** | DB session / request context | Nodes need db | Current `_stream_chat` receives `db: AsyncSession`. LangGraph nodes receive `state`. How does `persist` node get db? Inject via config or context? | +| **B-3** | `get_stream_writer()` in async generator | LangGraph streams to caller | `graph.astream(..., stream_mode=["updates","custom"])` yields chunks. Caller (FastAPI endpoint) must consume and re-emit as SSE. Bridge code shown in spec but not full request lifecycle. | +| **B-4** | LLM streaming inside generate node | `writer()` in loop | `async for token in llm.chat_stream(...)`: call `writer({"type":"text-delta",...})` each token. Confirmed in spec. | +| **B-5** | Conditional edge: skip retrieve when no KB | Graph structure | `add_conditional_edges("understand", _route, {"retrieve": "retrieve", "generate": "generate"})`. Need `_route(state)` returning next node. | +| **B-6** | Conversation creation timing | New conv vs existing | Current: create conv before persist if `not conversation_id`. persist node must create conv + messages. | +| **B-7** | RAG `return_exceptions` | One KB fails | Keep current behavior (continue with partial results) or fail entire pipeline? | + +### Category: Backward Compatibility & Migration + +| Gap | Description | Impact | Current Ambiguity | +|-----|-------------|--------|-------------------| +| **M-1** | Rewrite API (`/api/v1/chat/rewrite`) | Uses same SSE format | `rewrite-api.ts` uses `event:` + `data:` format. Not migrated in spec. Stays on old format or migrate too? | +| **M-2** | RAG streaming (`/api/v1/rag/...`) | Another SSE endpoint | `rag.py` has `_sse()`. Out of scope? | +| **M-3** | Index pipeline SSE | `api.ts` IndexSSEEvent | Different event types. Unaffected. | +| **M-4** | E2E / tests | test_chat.py asserts `event: message_start` | Backend format change breaks tests. Must update to Data Stream Protocol assertions. | +| **M-5** | Feature flags / gradual rollout | None | Big bang migration. No way to run old and new in parallel. | + +### Category: Testing + +| Gap | Description | Impact | Current Ambiguity | +|-----|-------------|--------|-------------------| +| **T-1** | Backend: LangGraph node unit tests | Each node testable | No plan for mocking `get_stream_writer()`, db, RAG. | +| **T-2** | Backend: Stream format tests | Assert correct SSE | test_chat.py checks `event: message_start`. Need tests for `data: {"type":"start",...}`, `[DONE]`, etc. | +| **T-3** | Frontend: useChat integration | Mock fetch, assert state | No tests for PlaygroundPage today. Adding useChat: how to mock transport? | +| **T-4** | E2E: Full flow | Playwright | Current e2e fixtures use mock SSE. Need real backend or new mock format. | +| **T-5** | Error path tests | 500, timeout, abort | No tests for error Part, abort behavior. | + +### Category: Security & Validation + +| Gap | Description | Impact | Current Ambiguity | +|-----|-------------|--------|-------------------| +| **S-1** | knowledge_base_ids authorization | User can only query own KBs | Current backend does not check. Spec does not mention. | +| **S-2** | conversation_id authorization | User can only continue own conv | Same. | +| **S-3** | Rate limiting | Stream endpoint | Long-running, no rate limit specified. | +| **S-4** | Input sanitization | Message content | Passed to LLM. XSS in citations? Rendered in Markdown. 
| + +### Category: Accessibility & i18n + +| Gap | Description | Impact | Current Ambiguity | +|-----|-------------|--------|-------------------| +| **A-1** | Error messages | useChat error | May be raw backend message. Need i18n? | +| **A-2** | Loading/streaming announcements | Screen readers | useChat status. Does it expose `status` for aria-live? | + +--- + +## Critical Questions Requiring Clarification + +### Critical (blocks implementation or creates risks) + +1. **Q1: `finish` Part and `conversation_id`** + - **Question**: Where does `conversation_id` go in the Data Stream Protocol? The standard `finish` Part may not include custom fields. Does useChat support a custom `data-conversation-id` Part, or do we extend the `finish` payload? + - **Why it matters**: Frontend needs `conversation_id` to navigate to `/chat/:id` and set `chatId` for subsequent messages. + - **Assumption if unanswered**: Emit a separate `data-conversation-id` Part immediately before `finish`, and handle it in a custom `onMessage` or stream callback. + - **Example**: `data: {"type":"data-conversation-id","conversationId":123}\n\n` then `data: {"type":"finish"}\n\n`. + +2. **Q2: `error` Part schema** + - **Question**: What is the exact JSON schema for the `error` Part that useChat expects? Is it `{"type":"error","errorText":"..."}` or does it support `code` and `message`? + - **Why it matters**: Backend currently sends `{"code":"stream_error","message":"..."}`. Mismatch may cause useChat to not display the error. + - **Assumption**: Use AI SDK's documented `error` Part format; adapt backend to match. + - **Example**: Check [AI SDK Data Stream Protocol](https://sdk.vercel.ai/docs/ai-sdk-ui/stream-protocol) for exact schema. + +3. **Q3: DB session injection into LangGraph** + - **Question**: How does the `persist` node (and any node needing DB) get the `AsyncSession`? Via `configurable` in `astream()`, or a context variable? + - **Why it matters**: LangGraph nodes are stateless; db is request-scoped. + - **Assumption**: Pass `{"db": db}` in `configurable` when calling `graph.astream()`, and have nodes read `state` or a separate context. + - **Example**: `await graph.astream(input_state, config={"configurable": {"db": db}})` + +4. **Q4: ChatState definition** + - **Question**: What fields does the Chat LangGraph state have? At minimum: `request`, `llm`, `rag`, `all_sources`, `all_contexts`, `citations`, `history_messages`, `messages`, `full_response`, `conversation_id`. + - **Why it matters**: Nodes read/write state. Undefined state blocks implementation. + - **Assumption**: Define `ChatState(TypedDict, total=False)` mirroring `_stream_chat` variables. + +### Important (significantly affects UX or maintainability) + +5. **Q5: MessageBubble migration strategy** + - **Question**: Should MessageBubble be refactored to accept `UIMessage` (or `parts`) directly, or should PlaygroundPage adapt `UIMessage` to the current props (`content`, `citations`, `thinkingSteps`, etc.)? + - **Why it matters**: Affects component reuse and testability. + - **Assumption**: Create adapter in PlaygroundPage: `messageToBubbleProps(message: UIMessage)` to avoid changing MessageBubble initially. + +6. **Q6: `initialMessages` conversion** + - **Question**: What is the exact mapping from `ChatMessage` (API) to `UIMessage`/`initialMessages`? `ChatMessage` has `content`, `citations`; UIMessage has `parts`. + - **Why it matters**: Conversation restore must show history correctly. 
+ - **Assumption**: Build `parts`: `[{type:'text', text: m.content}, ...(m.citations?.map(c => ({type:'data-citation', data: c})) ?? [])]`. + +7. **Q7: RAG partial failure behavior** + - **Question**: When one of N knowledge bases fails (exception in `rag.query`), should the pipeline continue with results from the others, or fail entirely? + - **Why it matters**: Current behavior continues; spec does not state. + - **Assumption**: Continue (match current behavior). Emit `data-thinking` with status "partial" or "warning" if desired. + +8. **Q8: Persist node failure** + - **Question**: If DB commit fails in `persist`, should we emit `error` and not send `finish`, or send `finish` (user saw response) and log the failure? + - **Why it matters**: User may see response but lose it on refresh. + - **Assumption**: Emit `error` Part, do not send `finish`. User sees error; response not persisted. + +### Nice-to-have (improves clarity) + +9. **Q9: Rewrite API migration** + - **Question**: Is `/api/v1/chat/rewrite` in scope for this refactor, or does it stay on the old SSE format? + - **Assumption**: Out of scope; rewrite stays as-is. + +10. **Q10: Tool mode in useChat body** + - **Question**: How does useChat send `tool_mode`, `knowledge_base_ids`, `conversation_id`? Via `body` in transport options? + - **Assumption**: `DefaultChatTransport` or custom transport accepts `body` merge. Verify AI SDK 5.0 API. + +--- + +## Recommended Next Steps + +1. **Resolve Critical Questions (Q1–Q4)** + - Check AI SDK 5.0 Data Stream Protocol docs for `finish`, `error`, and custom `data-*` Parts. + - Define `ChatState` and db injection approach. + - Document in brainstorm or a short ADR. + +2. **Define Protocol Contract** + - Create a shared spec (or TypeScript + Python types) for: + - All Part types (`start`, `text-start`, `text-delta`, `text-end`, `finish`, `error`, `data-citation`, `data-thinking`, `data-citation-enhanced`, `data-a2ui`, `data-conversation-id`). + - Exact JSON schema for each. + - Use for backend StreamWriter and frontend parsing validation. + +3. **Draft Migration Plan** + - Phase 1a: Backend — implement StreamWriter, new endpoint (e.g. `/api/v1/chat/ai-stream`) that outputs Data Stream Protocol. Keep `/api/v1/chat/stream` for rollback. + - Phase 1b: Frontend — add useChat with transport pointing to new endpoint; feature-flag or route-based switch. + - Phase 1c: Migrate PlaygroundPage to useChat; adapter for MessageBubble props. + - Phase 1d: Remove old endpoint and streamChat; update tests. + +4. **Update Tests** + - Backend: `test_chat.py` — assert new SSE format (`data: {"type":"start"...}`, `[DONE]`), error Part, no `event:` lines. + - Backend: Add unit tests for each LangGraph node (mock writer, db, RAG). + - Frontend: Add integration test for useChat + mock transport. + +5. **Address Security Gaps (S-1, S-2)** + - Add auth checks: user can only access own `knowledge_base_ids` and `conversation_id`. + - Document in plan even if deferred. + +6. **Document Error Handling** + - Specify: error Part schema, partial RAG failure, persist failure, HTTP 4xx/5xx. + - Add to plan or brainstorm. 
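
To make step 2 ("Define Protocol Contract") concrete, below is a minimal TypeScript sketch of what the shared Part contract could look like. It only covers the Part types already listed above; the `error` payload shape and the `data-conversation-id` Part follow the assumptions stated in Q1/Q2 and must be verified against the AI SDK 5.0 Data Stream Protocol docs, and the `data` field layouts for citations, thinking steps, enhanced citations, and A2UI are illustrative placeholders rather than confirmed schemas.

```typescript
// Sketch of a shared Part contract (assumptions: `error` uses `errorText`,
// conversation id travels as a custom `data-conversation-id` Part — see Q1/Q2).

interface StartPart {
  type: 'start';
  messageId: string;
}

interface TextStartPart {
  type: 'text-start';
  id: string;
}

interface TextDeltaPart {
  type: 'text-delta';
  id: string;
  delta: string;
}

interface TextEndPart {
  type: 'text-end';
  id: string;
}

interface FinishPart {
  type: 'finish';
}

interface ErrorPart {
  type: 'error';
  errorText: string; // assumed field name — confirm against the AI SDK docs (Q2)
}

interface CitationPart {
  type: 'data-citation';
  id?: string; // optional part id
  data: {
    index: number;
    title: string;
    excerpt: string;
    paper_id?: number; // illustrative fields — align with the backend citation payload
  };
}

interface CitationEnhancedPart {
  type: 'data-citation-enhanced';
  data: Record<string, unknown>; // cleaned-excerpt payload — schema TBD
}

interface ThinkingPart {
  type: 'data-thinking';
  data: {
    step: string;
    status: 'running' | 'done' | 'error' | 'skipped';
    detail?: string;
  };
}

interface A2uiPart {
  type: 'data-a2ui';
  data: Record<string, unknown>; // A2UI payload — schema TBD
}

interface ConversationIdPart {
  type: 'data-conversation-id';
  data: { conversationId: number }; // assumed payload — see Q1
}

export type ChatStreamPart =
  | StartPart
  | TextStartPart
  | TextDeltaPart
  | TextEndPart
  | FinishPart
  | ErrorPart
  | CitationPart
  | CitationEnhancedPart
  | ThinkingPart
  | A2uiPart
  | ConversationIdPart;
```

Each SSE line is then `data: ` + `JSON.stringify(part)` followed by a blank line, terminated by `data: [DONE]`; the same union can be mirrored as Python `TypedDict`s for the backend `StreamWriter` so both sides validate against one source of truth.
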
diff --git a/docs/plans/2026-03-12-feat-chat-message-routing-chain-rewrite-plan.md b/docs/plans/2026-03-12-feat-chat-message-routing-chain-rewrite-plan.md new file mode 100644 index 0000000..0b68b74 --- /dev/null +++ b/docs/plans/2026-03-12-feat-chat-message-routing-chain-rewrite-plan.md @@ -0,0 +1,1068 @@ +--- +title: "feat: Rewrite chat message routing chain" +type: feat +status: completed +date: 2026-03-12 +origin: docs/brainstorms/2026-03-12-chat-message-routing-chain-brainstorm.md +--- + +# feat: 聊天消息路由链全面重写 + +## Enhancement Summary + +**Deepened on:** 2026-03-12 +**Sections enhanced:** 6 (ChatState, Citation 策略, 前端类型, 错误处理, 性能, 架构) +**Research agents used:** Python Reviewer, TypeScript Reviewer, Performance Oracle, Architecture Strategist, AI SDK Context7 Docs + +### Key Improvements + +1. **Citation 协调模式改进**:用 AI SDK `id`-based 协调替代分离的 `data-citation-enhanced` 事件——同一 `data-citation` 类型 + 相同 `id` 可直接更新已有 Part,更简洁 +2. **ChatState 类型加强**:引入 `CitationDict` 和 `ChatMessageDict` TypedDicts 替代 `dict[str, Any]` +3. **前端防抖**:`useChat` 每 token 触发 re-render,需添加 `useDeferredValue` 或自定义防抖(P0 性能问题) +4. **UIMessage 泛型**:使用 `useChat()` 获得类型安全的自定义 Part 访问 + +### New Considerations Discovered + +- AI SDK 5.0 `data-*` Part 支持 `id` 字段进行 reconciliation(同 id 的新 Part 更新旧 Part) +- AI SDK 5.0 支持 `transient: true` 标记(thinking steps 可设为 transient,不保存到消息历史) +- `useChat` 的 `onData` 回调可处理所有 data Part(含 transient 的) +- LangGraph overhead 极低(~0.5-2ms/token),不需要 checkpointer +- 当前 Paper 查询已经是批量的(不是 N+1),需保持该模式 + +## Overview + +对 Omelette 的聊天消息处理链路进行全面重写:前端从手工 SSE 解析 + 15 个 useState 迁移到 Vercel AI SDK 5.0 `useChat`;后端从 230+ 行的 `_stream_chat` 单体函数迁移到 LangGraph StateGraph 可编排节点管道。同时建立标准化的 Data Stream Protocol 通信协议,统一错误处理和消息模型。 + +## Problem Statement + +当前聊天消息处理链路存在 11 个已识别问题(见 brainstorm R-1 至 R-11),核心矛盾是: +- 前端手写 SSE 解析器无法利用生态工具、不处理 error 事件、断流无重试 +- PlaygroundPage ~15 个 useState/useRef 导致状态逻辑与 UI 高度耦合 +- 后端 `_stream_chat` 单体函数步骤间耦合,难以独立测试/复用/扩展 + +## Proposed Solution + +三层重构: +1. **协议层**:Vercel AI SDK 5.0 Data Stream Protocol(标准化 SSE 格式) +2. **前端层**:`useChat` hook 替代手动状态管理,`UIMessage` + `parts` 替代 `LocalMessage` +3. **后端层**:LangGraph StateGraph 编排 7 个节点,`get_stream_writer()` 发射标准 SSE 事件 + +(见 brainstorm: `docs/brainstorms/2026-03-12-chat-message-routing-chain-brainstorm.md`) + +## Technical Approach + +### Architecture + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ 前端 (React) │ +│ │ +│ useChat({ transport, chatId, initialMessages }) │ +│ ├─ messages: UIMessage[] (parts: text | data-citation | ...) 
│ +│ ├─ status: 'ready' | 'submitted' | 'streaming' | 'error' │ +│ ├─ sendMessage() │ +│ └─ stop() │ +│ │ +│ Components: │ +│ PlaygroundPage → MessageBubble → [TextPart, CitationCard, │ +│ ThinkingChain, A2UISurface] │ +└────────────────────────┬──────────────────────────────────────────┘ + │ POST /api/v1/chat/stream + │ Header: x-vercel-ai-ui-message-stream: v1 + │ SSE: data: {"type":"text-delta","id":"...","delta":"..."}\n\n + ▼ +┌─────────────────────────────────────────────────────────────────┐ +│ 后端 (FastAPI) │ +│ │ +│ POST /api/v1/chat/stream │ +│ → StreamingResponse(chat_graph_stream(request, db)) │ +│ │ +│ LangGraph StateGraph: │ +│ [understand] → [retrieve] → [rank] → [clean] → [generate] │ +│ ↓ (无KB) → [persist] → [end] │ +│ [generate] → [persist] → [end] │ +│ │ +│ 每个节点通过 get_stream_writer() 发射 Data Stream Protocol 事件 │ +└─────────────────────────────────────────────────────────────────┘ +``` + +### Data Stream Protocol 事件映射 + +| 阶段 | 事件类型 | Payload | 发射节点 | +|------|---------|---------|---------| +| 开始 | `start` | `{"type":"start","messageId":"msg_xxx"}` | endpoint | +| 思考 | `data-thinking` | `{"type":"data-thinking","data":{"step":"understand","status":"running","detail":"..."}}` | 各节点 | +| 引用 | `data-citation` | `{"type":"data-citation","id":"cit-0","data":{"index":0,"title":"...","excerpt":"原始摘要"}}` | rank | +| 引用更新 | `data-citation` (同 id) | `{"type":"data-citation","id":"cit-0","data":{"index":0,"title":"...","excerpt":"清洗后摘要"}}` | clean | +| 文本开始 | `text-start` | `{"type":"text-start","id":"text_xxx"}` | generate | +| 文本增量 | `text-delta` | `{"type":"text-delta","id":"text_xxx","delta":"..."}` | generate | +| 文本结束 | `text-end` | `{"type":"text-end","id":"text_xxx"}` | generate | +| 会话ID | `data-conversation` | `{"type":"data-conversation","data":{"conversation_id":123}}` | persist | +| 结束 | `finish` | `{"type":"finish"}` | endpoint | +| 终止 | `[DONE]` | `data: [DONE]` | endpoint | +| 错误 | `error` | `{"type":"error","errorText":"..."}` | 任意节点 | + +### ChatState 定义 + +```python +# backend/app/pipelines/chat/state.py +from typing import Any, Literal, TypedDict + +class CitationDict(TypedDict, total=False): + """强类型 citation,与前端 Citation interface 对齐。""" + index: int + paper_id: int | None + paper_title: str + chunk_type: str + page_number: int | None + relevance_score: float + excerpt: str + authors: list[str] | None + year: int | None + doi: str | None + +class ChatMessageDict(TypedDict): + """LLM 消息格式。""" + role: Literal["system", "user", "assistant"] + content: str + +class ChatState(TypedDict, total=False): + """LangGraph chat pipeline state.""" + + # --- 输入 (from request) --- + message: str + knowledge_base_ids: list[int] + tool_mode: str + conversation_id: int | None + + # --- 服务注入 (通过 config["configurable"],不在 state 中) --- + # config["configurable"]["db"]: AsyncSession + # config["configurable"]["llm"]: LLMClient (由 understand_node 注入) + # config["configurable"]["rag"]: RAGService (由 understand_node 注入) + + # --- 中间结果 --- + rag_results: list[dict[str, Any]] # RAG 内部格式,保持松散 + citations: list[CitationDict] + history_messages: list[ChatMessageDict] + system_prompt: str + full_messages: list[ChatMessageDict] + + # --- 输出 --- + assistant_content: str + new_conversation_id: int | None + error: str | None +``` + +**服务注入约定**(`config["configurable"]` 契约): + +```python +# backend/app/pipelines/chat/config_helpers.py +from langchain_core.runnables import RunnableConfig + +def get_chat_db(config: RunnableConfig): + return config["configurable"]["db"] + +def get_chat_llm(config: 
RunnableConfig): + return config["configurable"]["llm"] + +def get_chat_rag(config: RunnableConfig): + return config["configurable"]["rag"] +``` + +- `understand_node` 创建 `llm` 和 `rag` 并注入 `config["configurable"]` +- 后续节点通过 `get_chat_llm(config)` 等辅助函数访问 +- `db` 在端点调用时注入,全生命周期有效 + +### 服务注入模式 + +LangGraph 节点通过 `config["configurable"]` 访问请求级服务(DB session、LLM、RAG): + +```python +# backend/app/pipelines/chat/nodes.py + +from langchain_core.runnables import RunnableConfig + +async def understand_node(state: ChatState, config: RunnableConfig) -> dict: + db = config["configurable"]["db"] + svc = UserSettingsService(db) + llm_config = await svc.get_merged_llm_config() + llm = get_llm_client(config=llm_config) + embed = get_embedding_model() if llm_config.provider != "mock" else MockEmbedding(embed_dim=128) + rag = RAGService(llm=llm, embed_model=embed) + + # 注入到 configurable 供后续节点使用 + config["configurable"]["llm"] = llm + config["configurable"]["rag"] = rag + + writer = get_stream_writer() + writer({"type": "data-thinking", "data": {"step": "understand", "status": "done", ...}}) + + return {"history_messages": [...], "system_prompt": "..."} +``` + +**调用侧**: + +```python +# backend/app/api/v1/chat.py + +async def chat_graph_stream(request: ChatStreamRequest, db: AsyncSession): + graph = get_chat_graph() + config = {"configurable": {"db": db, "thread_id": str(uuid4())}} + initial_state = { + "message": request.message, + "knowledge_base_ids": request.knowledge_base_ids or [], + "tool_mode": request.tool_mode or "qa", + "conversation_id": request.conversation_id, + } + + yield f'data: {json.dumps({"type": "start", "messageId": f"msg_{uuid4().hex}"})}\n\n' + + async for mode, chunk in graph.astream(initial_state, config=config, stream_mode=["updates", "custom"]): + if mode == "custom": + yield f'data: {json.dumps(chunk)}\n\n' + + yield f'data: {json.dumps({"type": "finish"})}\n\n' + yield 'data: [DONE]\n\n' +``` + +### 前端 UIMessage Parts 适配 + +```typescript +// frontend/src/types/chat.ts + +// AI SDK 5.0 UIMessage 扩展类型 +interface CitationData { + index: number; + title: string; + authors: { name: string }[]; + year?: number; + doi?: string; + excerpt: string; + paper_id?: number; + source?: string; + confidence?: number; +} + +interface ThinkingData { + step: string; + label?: string; + status: 'running' | 'done' | 'error' | 'skipped'; + detail?: string; + duration_ms?: number; + summary?: string; +} + +interface ConversationData { + conversation_id: number; +} + +// 从 UIMessage.parts 提取自定义数据 +function getCitations(message: UIMessage): CitationData[] { + return message.parts + .filter(p => p.type === 'data-citation') + .map(p => p.data as CitationData); +} + +function getThinkingSteps(message: UIMessage): ThinkingData[] { + return message.parts + .filter(p => p.type === 'data-thinking') + .map(p => p.data as ThinkingData); +} +``` + +### MessageBubble 适配 + +```tsx +// frontend/src/components/playground/MessageBubble.tsx + +function MessageBubble({ message }: { message: UIMessage }) { + const citations = useMemo(() => getCitations(message), [message.parts]); + const thinkingSteps = useMemo(() => getThinkingSteps(message), [message.parts]); + + return ( + + {message.parts.map((part, i) => { + switch (part.type) { + case 'text': + return ; + case 'data-citation': + return null; // 在 CitationCardList 中统一渲染 + case 'data-thinking': + return null; // 在 ThinkingChain 中统一渲染 + default: + return null; + } + })} + {thinkingSteps.length > 0 && } + {citations.length > 0 && } + + ); +} + +export default 
memo(MessageBubble); +``` + +### Conversation 恢复 + +```tsx +// frontend/src/pages/PlaygroundPage.tsx + +function PlaygroundPage() { + const { conversationId: routeConvId } = useParams(); + const [initialMessages, setInitialMessages] = useState([]); + + // 从 DB 恢复消息并转换为 UIMessage 格式 + useEffect(() => { + if (routeConvId) { + conversationApi.get(Number(routeConvId)).then(conv => { + setInitialMessages(convertToUIMessages(conv.messages)); + }); + } + }, [routeConvId]); + + const { messages, sendMessage, status, stop, error } = useChat({ + chatId: routeConvId, + initialMessages, + transport: new DefaultChatTransport({ + api: '/api/v1/chat/stream', + headers: { 'Content-Type': 'application/json' }, + }), + }); + + // 从 streaming message 中提取 conversation_id 用于 URL 更新 + const latestConvId = useMemo(() => { + const lastAssistant = [...messages].reverse().find(m => m.role === 'assistant'); + if (!lastAssistant) return null; + const convPart = lastAssistant.parts.find(p => p.type === 'data-conversation'); + return convPart?.data?.conversation_id ?? null; + }, [messages]); + // ... +} +``` + +### Implementation Phases + +#### Phase 1: 后端 LangGraph Chat Pipeline + Data Stream Protocol + +**目标**:创建新的 LangGraph chat graph 和 Data Stream Protocol 端点,与旧端点并行运行 + +**文件变更**: + +| 操作 | 文件 | 说明 | +|------|------|------| +| 新建 | `backend/app/pipelines/chat/__init__.py` | 包初始化 | +| 新建 | `backend/app/pipelines/chat/state.py` | ChatState TypedDict | +| 新建 | `backend/app/pipelines/chat/nodes.py` | 7 个节点函数 | +| 新建 | `backend/app/pipelines/chat/graph.py` | StateGraph 定义 | +| 新建 | `backend/app/pipelines/chat/stream_writer.py` | Data Stream Protocol SSE 格式化工具 | +| 修改 | `backend/app/api/v1/chat.py` | 新增 `/stream/v2` 端点 | +| 新建 | `backend/tests/test_chat_pipeline.py` | 节点单元测试 | + +**详细任务**: + +**1.1 ChatState 定义** (`backend/app/pipelines/chat/state.py`) + +```python +from typing import Any, TypedDict + +class ChatState(TypedDict, total=False): + message: str + knowledge_base_ids: list[int] + tool_mode: str + conversation_id: int | None + rag_results: list[dict[str, Any]] + citations: list[dict[str, Any]] + enhanced_citations: list[dict[str, Any]] + history_messages: list[dict[str, Any]] + system_prompt: str + full_messages: list[dict[str, Any]] + assistant_content: str + new_conversation_id: int | None + error: str | None +``` + +**1.2 StreamWriter 工具** (`backend/app/pipelines/chat/stream_writer.py`) + +```python +import json +import uuid + +class DataStreamWriter: + """Formats events in Vercel AI SDK 5.0 Data Stream Protocol.""" + + @staticmethod + def start(message_id: str | None = None) -> str: + mid = message_id or f"msg_{uuid.uuid4().hex}" + return f'data: {json.dumps({"type": "start", "messageId": mid})}\n\n' + + @staticmethod + def text_start(text_id: str | None = None) -> str: + tid = text_id or f"text_{uuid.uuid4().hex}" + return f'data: {json.dumps({"type": "text-start", "id": tid})}\n\n' + + @staticmethod + def text_delta(text_id: str, delta: str) -> str: + return f'data: {json.dumps({"type": "text-delta", "id": text_id, "delta": delta})}\n\n' + + @staticmethod + def text_end(text_id: str) -> str: + return f'data: {json.dumps({"type": "text-end", "id": text_id})}\n\n' + + @staticmethod + def data_part(data_type: str, data: dict) -> str: + return f'data: {json.dumps({"type": data_type, "data": data})}\n\n' + + @staticmethod + def error(message: str) -> str: + return f'data: {json.dumps({"type": "error", "errorText": message})}\n\n' + + @staticmethod + def finish() -> str: + return f'data: {json.dumps({"type": 
"finish"})}\n\n' + + @staticmethod + def done() -> str: + return 'data: [DONE]\n\n' +``` + +**1.3 节点实现** (`backend/app/pipelines/chat/nodes.py`) + +每个节点遵循已有模式(`backend/app/pipelines/nodes.py`):接收 state + config,返回 partial state update,通过 `get_stream_writer()` 发射自定义事件。 + +| 节点 | 职责 | `get_stream_writer()` 事件 | +|------|------|---------------------------| +| `understand_node` | 获取 LLM/RAG 服务,加载历史消息,构建 system prompt | `data-thinking(understand)` | +| `retrieve_node` | `asyncio.gather()` 并行 RAG 查询多个 KB | `data-thinking(retrieve)` | +| `rank_node` | 批量加载 Paper 元数据,构建 citation 列表 | `data-thinking(rank)` + `data-citation(id=cit-N)` × N | +| `clean_node` | LLM 并行清洗 citation excerpts | `data-thinking(clean)` + `data-citation(id=cit-N, 同 id 更新)` × M | +| `generate_node` | LLM 流式生成回答 | `data-thinking(generate)` + `text-start` + `text-delta` × K + `text-end` | +| `persist_node` | 创建/更新 conversation 和 messages | `data-conversation` | + +**关键:`generate_node` 的流式输出** + +```python +async def generate_node(state: ChatState, config: RunnableConfig) -> dict: + writer = get_stream_writer() + llm = config["configurable"]["llm"] + + writer({"type": "data-thinking", "data": {"step": "generate", "status": "running"}}) + + text_id = f"text_{uuid.uuid4().hex}" + writer({"type": "text-start", "id": text_id}) + + content = "" + async for token in llm.chat_stream(state["full_messages"]): + content += token + writer({"type": "text-delta", "id": text_id, "delta": token}) + + writer({"type": "text-end", "id": text_id}) + writer({"type": "data-thinking", "data": {"step": "generate", "status": "done"}}) + + return {"assistant_content": content} +``` + +**注意**:`generate_node` 发射的 `text-start/delta/end` 不是 `data-*` 类型,是标准 Part,直接作为 `custom` stream 事件输出。 + +**1.4 Graph 定义** (`backend/app/pipelines/chat/graph.py`) + +```python +from langgraph.graph import END, StateGraph + +def _route_after_understand(state: ChatState) -> str: + if state.get("knowledge_base_ids"): + return "retrieve" + return "generate" + +def build_chat_graph(): + graph = StateGraph(ChatState) + + graph.add_node("understand", understand_node) + graph.add_node("retrieve", retrieve_node) + graph.add_node("rank", rank_node) + graph.add_node("clean", clean_node) + graph.add_node("generate", generate_node) + graph.add_node("persist", persist_node) + + graph.set_entry_point("understand") + graph.add_conditional_edges("understand", _route_after_understand, { + "retrieve": "retrieve", + "generate": "generate", + }) + graph.add_edge("retrieve", "rank") + graph.add_edge("rank", "clean") + graph.add_edge("clean", "generate") + graph.add_edge("generate", "persist") + graph.add_edge("persist", END) + + return graph.compile() +``` + +**1.5 新端点** (`backend/app/api/v1/chat.py`) + +在现有 `/stream` 旁新增 `/stream/v2`: + +```python +@router.post("/stream/v2") +async def chat_stream_v2( + request: ChatStreamRequest, + db: AsyncSession = Depends(get_db), +): + return StreamingResponse( + _stream_chat_v2(request, db), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + "x-vercel-ai-ui-message-stream": "v1", + }, + ) +``` + +**1.6 后端测试** (`backend/tests/test_chat_pipeline.py`) + +- 每个节点独立测试(mock state + config) +- `understand_node`:验证 LLM/RAG 初始化、history 加载 +- `retrieve_node`:验证 RAG 查询、部分失败处理 +- `rank_node`:验证批量 Paper 查询、citation 构建 +- `clean_node`:验证 LLM 清洗、超时降级 +- `generate_node`:验证流式输出事件序列 +- `persist_node`:验证 conversation 创建/更新 +- 集成测试:验证完整 graph 执行和事件序列 + +**1.7 端点级错误处理** + +```python +async def 
_stream_chat_v2(request: ChatStreamRequest, db: AsyncSession): + writer = DataStreamWriter + yield writer.start() + try: + graph = build_chat_graph() + config = {"configurable": {"db": db, "thread_id": str(uuid4())}} + initial_state = {...} + + async for mode, chunk in graph.astream(initial_state, config=config, stream_mode=["updates", "custom"]): + if mode == "custom": + yield f'data: {json.dumps(chunk)}\n\n' + + yield writer.finish() + except Exception as e: + logger.exception("Chat graph error") + yield writer.error(str(e)) + finally: + yield writer.done() +``` + +### Phase 1 Research Insights + +**错误处理(Python Reviewer)**: +- 端点级 `try/except` 捕获所有 graph 异常并发射 `error` Part +- 节点内错误通过 `writer({"type": "error", ...})` 发射 + `return {"error": str(e)}` +- `generate_node` 中 LLM 错误可能发生在 partial content 已发射后——保留已发射内容 + +**性能(Performance Oracle)**: +- LangGraph overhead ~0.5-2ms/token,远低于 LLM 50-200ms 延迟 +- **不要使用 checkpointer**——`graph.compile()` 不传 checkpointer,避免序列化开销 +- `json.dumps` 每 token ~2-5µs,500 token 总计 < 3ms,可忽略 +- 当前 Paper 查询已是批量的(`IN` 查询),`rank_node` 须保持此模式 + +**测试(Python Reviewer)**: +- `get_stream_writer()` 在 graph 外部调用时为 None——测试需通过可选参数注入或在最小 graph 中运行 +- 建议节点签名添加 `_writer` 可选参数用于测试注入 + +**架构(Architecture Strategist)**: +- LangGraph 是正确选择(与现有 pipeline 一致、支持 HITL、支持 checkpointing) +- 路由函数返回类型用 `Literal["retrieve", "generate"]` +- `DataStreamWriter` 改为模块级函数而非 class + @staticmethod + +**RAG 效率(Python Reviewer)**: +- 考虑添加 `RAGService.retrieve_only()` 避免生成未使用的 answer +- 或保持 `query()` 并忽略 answer(更简单但有微小浪费) + +**Phase 1 验证标准**: +- [ ] 所有节点有单元测试(mock writer + config) +- [ ] `/stream/v2` 输出标准 Data Stream Protocol SSE +- [ ] 端点级错误处理:graph 异常 → `error` Part +- [ ] 旧 `/stream` 端点仍然工作(无破坏性变更) +- [ ] `ruff lint` 通过 +- [ ] 不使用 checkpointer + +--- + +#### Phase 2: 前端 Vercel AI SDK 迁移 + +**目标**:安装 AI SDK 5.0,创建 `useChat` 集成,替换手动状态管理 + +**文件变更**: + +| 操作 | 文件 | 说明 | +|------|------|------| +| 修改 | `frontend/package.json` | 添加 `ai`, `@ai-sdk/react` | +| 新建 | `frontend/src/hooks/useChatStream.ts` | `useChat` 封装(含自定义 data-* 处理) | +| 新建 | `frontend/src/lib/chat-transport.ts` | 自定义 Transport(请求体格式适配) | +| 修改 | `frontend/src/types/chat.ts` | 新增 UIMessage 辅助类型 + 提取函数 | +| 修改 | `frontend/src/pages/PlaygroundPage.tsx` | 重写为 useChat 驱动 | +| 修改 | `frontend/src/components/playground/MessageBubble.tsx` | 适配 UIMessage parts | +| 修改 | `frontend/src/components/playground/ThinkingChain.tsx` | 适配 ThinkingData | +| 修改 | `frontend/src/components/playground/CitationCard.tsx` | 适配 CitationData | +| 修改 | `frontend/src/components/playground/CitationCardList.tsx` | 适配 CitationData[] | +| 删除 | `frontend/src/services/chat-api.ts` (streamChat 部分) | 不再需要手写 SSE 解析 | +| 修改 | `frontend/src/components/playground/ChatInput.tsx` | 适配 sendMessage API | + +**详细任务**: + +**2.1 安装依赖** + +```bash +cd frontend && npm install ai @ai-sdk/react +``` + +**2.2 自定义 Transport** (`frontend/src/lib/chat-transport.ts`) + +```typescript +import { DefaultChatTransport } from 'ai'; +import type { UIMessage } from 'ai'; + +function getMessageText(message: UIMessage): string { + return message.parts + .filter((p): p is { type: 'text'; text: string } => p.type === 'text') + .map(p => p.text) + .join(''); +} + +export function createChatTransport(options?: { + knowledgeBaseIds?: number[]; + toolMode?: string; +}) { + return new DefaultChatTransport({ + api: '/api/v1/chat/stream/v2', + headers: { 'Content-Type': 'application/json' }, + prepareSendMessagesRequest: ({ messages, id }) => ({ + body: { + message: getMessageText(messages[messages.length - 1]), + 
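        // Only the newest user message is posted; the backend (understand_node)
        // reloads prior conversation history from the DB, so the full message list is not sent.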
knowledge_base_ids: options?.knowledgeBaseIds ?? [], + tool_mode: options?.toolMode ?? 'qa', + conversation_id: id ? Number(id) : undefined, + }, + }), + }); +} +``` + +**关键细节**(TypeScript Reviewer 反馈): +- `messages[*].content` 在 AI SDK 5.0 已废弃 → 用 `getMessageText()` 从 parts 提取 +- `conversation_id` 通过 `id` 参数(即 `chatId`)传递到请求体 +- Transport 实例须通过 `useMemo` 稳定引用,数组 deps 需序列化 + +**2.3 UIMessage 类型定义** (`frontend/src/types/chat.ts`) + +```typescript +import type { UIMessage } from 'ai'; + +// 自定义 data-* Part 的 payload 类型(与后端 CitationDict 对齐) +export interface CitationData { + index: number; + paper_id?: number; + paper_title: string; + chunk_type?: string; + page_number?: number; + relevance_score?: number; + excerpt: string; + authors?: string[] | null; + year?: number | null; + doi?: string | null; +} + +export interface ThinkingData { + step: string; + label: string; // 必须,后端保证发射 + status: 'running' | 'done' | 'error' | 'skipped'; + detail?: string; + duration_ms?: number; + summary?: string; +} + +export interface ConversationData { + conversation_id: number; +} + +// AI SDK 5.0 泛型:强类型自定义 data parts +export type OmeletteDataParts = { + citation: CitationData; + thinking: ThinkingData; + conversation: ConversationData; +}; + +export type OmeletteUIMessage = UIMessage; + +// Citation 提取辅助(AI SDK id-based reconciliation 自动合并同 id Part) +export function getCitations(message: OmeletteUIMessage): CitationData[] { + return message.parts + .filter((p): p is { type: 'data-citation'; id?: string; data: CitationData } => + p.type === 'data-citation') + .map(p => p.data); +} + +export function getThinkingSteps(message: OmeletteUIMessage): ThinkingData[] { + return message.parts + .filter((p): p is { type: 'data-thinking'; data: ThinkingData } => + p.type === 'data-thinking') + .map(p => p.data); +} +``` + +**2.4 useChatStream Hook** (`frontend/src/hooks/useChatStream.ts`) + +```typescript +import { useChat } from '@ai-sdk/react'; +import { useDeferredValue, useMemo } from 'react'; +import { createChatTransport } from '@/lib/chat-transport'; +import type { + CitationData, ThinkingData, OmeletteUIMessage, + getCitations, getThinkingSteps, +} from '@/types/chat'; + +export function useChatStream(options: { + chatId?: string; + initialMessages?: OmeletteUIMessage[]; + knowledgeBaseIds?: number[]; + toolMode?: string; +}) { + // 稳定化 array deps(TypeScript Reviewer 反馈) + const kbIdsKey = useMemo( + () => JSON.stringify(options.knowledgeBaseIds ?? []), + [options.knowledgeBaseIds], + ); + + const transport = useMemo( + () => createChatTransport({ + knowledgeBaseIds: options.knowledgeBaseIds, + toolMode: options.toolMode, + }), + [kbIdsKey, options.toolMode], + ); + + const chat = useChat({ + id: options.chatId, // AI SDK 5.0 用 `id` 而非 `chatId` + initialMessages: options.initialMessages, + transport, + onData: (dataPart) => { + // 处理 transient parts(如通知) + if (dataPart.type === 'data-thinking' && dataPart.data.status === 'running') { + // thinking steps 可选择为 transient(不保存到消息历史) + } + }, + }); + + // 防抖流式内容(Performance Oracle P0 建议) + const deferredMessages = useDeferredValue(chat.messages); + + const lastAssistant = useMemo(() => { + return [...deferredMessages].reverse().find(m => m.role === 'assistant'); + }, [deferredMessages]); + + const citations = useMemo( + () => lastAssistant ? getCitations(lastAssistant) : [], + [lastAssistant], + ); + + const thinkingSteps = useMemo( + () => lastAssistant ? 
getThinkingSteps(lastAssistant) : [], + [lastAssistant], + ); + + const conversationId = useMemo(() => { + if (!lastAssistant) return null; + const part = lastAssistant.parts.find(p => p.type === 'data-conversation'); + return part ? (part.data as { conversation_id: number }).conversation_id : null; + }, [lastAssistant]); + + return { + ...chat, + messages: deferredMessages, + citations, + thinkingSteps, + conversationId, + }; +} +``` + +**关键改进(Research Insights)**: +- `useChat()` 泛型获得强类型 `data-*` Part 访问 +- `useDeferredValue(chat.messages)` 防止每 token re-render(Performance Oracle P0) +- `kbIdsKey = JSON.stringify(...)` 稳定化数组 deps(TypeScript Reviewer) +- `id` 替代 `chatId`(AI SDK 5.0 API 修正) +- `onData` 回调处理 transient parts +- Citation 用 type guard `(p): p is {...}` 替代 `as` 类型断言 + +**2.4 PlaygroundPage 重写** + +核心变化: +- 移除 `messages` useState → `useChatStream.messages` +- 移除 `isStreaming` useState → `status === 'streaming'` +- 移除 `pendingDeltaRef`、`flushTimerRef`、`assistantIdRef` → AI SDK 内部管理 +- 移除 `abortRef` → `stop()` +- 移除整个 `for await (event of gen)` 循环 → AI SDK 自动处理 +- 保留 `sidebarCollapsed`、`toolMode`、`selectedKBs` 等 UI 状态 + +**2.5 MessageBubble 适配** + +- `content` prop → 从 `message.parts` 中提取 text parts +- `citations` prop → `getCitations(message)` +- `thinkingSteps` prop → `getThinkingSteps(message)` +- `isStreaming` → 通过 `status` 判断 +- 保持 `memo()` 优化 + +**2.6 Conversation 恢复** + +```typescript +// 从后端 ChatMessage 转换为 UIMessage +function convertToUIMessages(messages: ChatMessage[]): UIMessage[] { + return messages.map(msg => ({ + id: String(msg.id), + role: msg.role as 'user' | 'assistant', + parts: [{ type: 'text' as const, text: msg.content }], + // 恢复的消息不包含 citation/thinking parts(已完成的对话) + })); +} +``` + +**2.7 删除旧代码** + +- `chat-api.ts` 中的 `streamChat` 函数和 `SSEEvent` 类型 +- `PlaygroundPage.tsx` 中的手动 SSE 处理逻辑 +- `LocalMessage` interface + +### Phase 2 Research Insights + +**类型安全(TypeScript Reviewer)**: +- `CitationData` 字段须与后端 `CitationDict` 对齐(`paper_title` 而非 `title`) +- 用 type guards 替代 `as CitationData` 类型断言 +- `convertToUIMessages` 须映射恢复消息的 citations 到 `data-citation` parts + +**性能(Performance Oracle P0)**: +- `useChat` 每 token 触发 re-render → 用 `useDeferredValue(chat.messages)` 防抖 +- `memo(MessageBubble)` 必须保留 +- `getCitations()` / `getThinkingSteps()` 放在 `useMemo` 中,依赖 `lastAssistant` + +**A2UISurface 迁移**: +- `LocalMessage.a2uiMessages` → `data-a2ui` Part(与 citation 同模式) +- 如果 A2UISurface 暂时不迁移,保留为独立 prop 从消息中提取 + +**错误/加载状态**: +- `status === 'submitted'`:已发送,等待首 token +- `status === 'streaming'`:流式中 +- `status === 'error'`:错误(`chat.error` 有详细信息) +- `chat.stop()`:中止 +- ChatInput disabled when `status !== 'ready'` + +**Phase 2 验证标准**: +- [ ] `useChat` 正常发送消息并接收流式响应 +- [ ] citations(含 id reconciliation 更新)正确渲染 +- [ ] thinking steps 正确渲染 +- [ ] 对话恢复(`/chat/:id`)正常工作(`initialMessages` + `id`) +- [ ] abort(`stop()`)正常工作 +- [ ] error 状态显示 toast 或内联错误 +- [ ] 流式文本无明显卡顿(`useDeferredValue` 生效) +- [ ] ESLint 无新增错误 +- [ ] TypeScript 编译通过 + +--- + +#### Phase 3: 清理与测试 + +**目标**:删除旧端点、补充测试、处理边缘情况 + +**文件变更**: + +| 操作 | 文件 | 说明 | +|------|------|------| +| 修改 | `backend/app/api/v1/chat.py` | `/stream/v2` → `/stream`(替换旧端点) | +| 修改 | `frontend/src/lib/chat-transport.ts` | URL 从 `/stream/v2` → `/stream` | +| 删除 | `backend/app/api/v1/chat.py` 旧代码 | 移除 `_stream_chat`、`_thinking`、`_clean_excerpt` | +| 修改 | `backend/tests/test_chat.py` | 更新为 Data Stream Protocol 格式断言 | +| 新建 | `frontend/src/hooks/__tests__/useChatStream.test.ts` | Hook 测试 | +| 新建 | 
`frontend/src/lib/__tests__/chat-transport.test.ts` | Transport 测试 | + +**详细任务**: + +**3.1 端点切换** + +- `/stream/v2` 重命名为 `/stream` +- 删除旧的 `_stream_chat` 函数和所有相关 helper +- 更新前端 Transport URL + +**3.2 边缘情况处理** + +| 场景 | 处理 | +|------|------| +| RAG 部分失败(某个 KB 查询失败) | `retrieve_node` 中 `asyncio.gather(return_exceptions=True)` + 过滤异常 | +| LLM 清洗超时 | `clean_node` 中 `asyncio.wait_for(timeout=10)` + 降级为原始 excerpt | +| DB 持久化失败 | `persist_node` catch → 发射 `data-thinking(persist, error)` + 继续完成流 | +| 用户 abort 中途 | `useChat.stop()` → 前端断开连接 → FastAPI `ClientDisconnect` | +| 无效 knowledge_base_ids | `retrieve_node` 跳过无效 ID + 发射 warning | +| conversation_id 不存在 | `persist_node` 创建新对话而非更新 | + +**3.3 测试补充** + +- 后端节点单元测试(Phase 1 已覆盖) +- 后端集成测试(完整 graph 执行 + SSE 事件验证) +- 前端 `useChatStream` hook 测试(MSW mock `/stream` 端点) +- 前端 `convertToUIMessages` 单元测试 +- 手动 E2E 验证(浏览器工具) + +**3.4 Rewrite API 评估** + +`backend/app/api/v1/rewrite.py` 也使用自定义 SSE 格式(`rewrite_delta`, `rewrite_end`)。本次不迁移——范围限定在 chat 端点。后续可复用 `DataStreamWriter` 统一所有 SSE 端点。 + +**Phase 3 验证标准**: +- [ ] 旧端点已删除,所有流量走新管道 +- [ ] 所有后端测试通过(`pytest`) +- [ ] 所有前端测试通过(`vitest`) +- [ ] 边缘情况有处理(部分失败、超时、abort) +- [ ] `ruff lint` + `eslint` 通过 + +## System-Wide Impact + +### Interaction Graph + +``` +用户点击发送 + → useChat.sendMessage() + → DefaultChatTransport.fetch(POST /api/v1/chat/stream) + → FastAPI chat_stream_v2() + → StreamingResponse(chat_graph_stream()) + → LangGraph graph.astream(stream_mode=["updates","custom"]) + → understand_node → retrieve_node → ... → persist_node + → get_stream_writer() → custom event → SSE data line + → yield SSE line + → Response body + → useChat internal SSE parser + → UIMessage.parts update + → React re-render + → MessageBubble → ThinkingChain, CitationCardList, MarkdownRenderer +``` + +### Error & Failure Propagation + +| 错误源 | 传播路径 | 处理 | +|--------|---------|------| +| LLM API 错误 | node → exception → graph → `_stream_chat_v2` catch → `error` Part | `useChat.error` 状态 | +| RAG 查询失败 | `retrieve_node` → `return_exceptions=True` → 过滤 | 跳过失败 KB,继续 | +| DB 错误 | `persist_node` → catch → `data-thinking(persist, error)` | 回答仍然显示,URL 不更新 | +| 网络断开 | fetch abort → `useChat` 检测 → `error` 状态 | 显示错误 toast | +| 用户 abort | `stop()` → abort signal → FastAPI `ClientDisconnect` | 清理状态 | + +### State Lifecycle Risks + +| 风险 | 缓解 | +|------|------| +| Graph 执行中 DB session 关闭 | 在 `chat_graph_stream` generator 中保持 `db` scope | +| 部分 citation 发射后 LLM 失败 | 前端已收到的 citation 保留显示,error 消息追加 | +| persist 失败导致对话丢失 | catch 异常 + log + 不影响已发射的回答内容 | + +### API Surface Parity + +| 接口 | 变更 | +|------|------| +| `POST /api/v1/chat/stream` | 输出格式从 `event:X\ndata:{}\n\n` → `data:{"type":"..."}\n\n` | +| `POST /api/v1/chat/rewrite` | **不变**(本次不迁移) | +| `GET /api/v1/rag/{id}/stream` | **不变**(本次不迁移) | +| `conversationApi.*` | **不变**(REST CRUD 不受影响) | + +### Integration Test Scenarios + +1. **完整 RAG 聊天流**:发送带 KB 的消息 → 验证 thinking steps 顺序 → 验证 citations 出现在 text 之前 → 验证 text streaming → 验证 conversation_id 更新 URL +2. **无 KB 直聊**:发送无 KB 的消息 → 验证跳过 retrieve/rank/clean → 直接 generate +3. **中途 abort**:发送消息 → 等 text_delta 出现 → stop() → 验证 UI 停止更新且不崩溃 +4. **对话恢复**:创建对话 → 导航到 `/chat/:id` → 验证 initialMessages 加载 → 继续发送消息 +5. 
**后端错误**:配置无效 LLM provider → 发送消息 → 验证 error Part 到达前端 → error 状态显示 + +## Acceptance Criteria + +### Functional Requirements + +- [ ] 用户发送消息后,收到标准 Data Stream Protocol SSE 响应 +- [ ] thinking steps 实时显示各节点状态(running → done) +- [ ] citations 在文本生成前出现 +- [ ] citation 更新(同 id reconciliation)正确替换摘要 +- [ ] 文本流式显示 +- [ ] conversation_id 正确传递并更新 URL +- [ ] 对话可通过 URL 恢复(`/chat/:id`) +- [ ] 错误正确显示(LLM 错误、网络错误) +- [ ] abort 正常工作 + +### Non-Functional Requirements + +- [ ] 首 token 延迟 ≤ 2s(与旧端点持平) +- [ ] 流式文本无明显卡顿 +- [ ] 旧端点的所有功能在新端点中可用 +- [ ] 新增代码有测试覆盖 + +### Quality Gates + +- [ ] `pytest` 通过,chat 节点测试 ≥ 7 个 +- [ ] `vitest` 通过 +- [ ] `ruff lint` + `eslint` 通过 +- [ ] 手动 E2E 验证聊天流程 + +## Success Metrics + +- PlaygroundPage 代码行数减少 ≥ 40%(手动状态管理被 useChat 替代) +- 后端 chat.py 拆分为 7 个可独立测试的节点 +- SSE 协议标准化(可被任何 AI SDK 兼容客户端消费) +- error 事件有处理(从 0 处理到 100% 处理) + +## Dependencies & Prerequisites + +| 依赖 | 版本要求 | 说明 | +|------|---------|------| +| `ai` (npm) | ^5.0.0 | Vercel AI SDK core | +| `@ai-sdk/react` (npm) | ^2.0.0 | React hooks | +| `langgraph` (pip) | >=0.4.0 | 已安装 | +| `langchain-core` (pip) | >=0.3 | 已安装 | +| Python | >=3.12 | 已满足(LangGraph `get_stream_writer()` 需 ≥3.11) | +| React | >=18 | 已满足(当前 v19) | + +## Risk Analysis & Mitigation + +| 风险 | 概率 | 影响 | 缓解 | +|------|------|------|------| +| AI SDK 5.0 beta 不稳定 | 中 | 高 | 先在 `/stream/v2` 并行运行,确认稳定后切换;Pydantic AI 已有生产参考 | +| `data-*` Part 前端消费 API 不明确 | 低 | 中 | 已验证 `message.parts.filter()` 可用(见 brainstorm 技术验证) | +| LangGraph `get_stream_writer()` + `astream` 组合行为未预期 | 中 | 中 | Phase 1 先写节点测试验证事件发射 | +| `useChat` 的 `prepareSendMessagesRequest` 不支持所需的请求体格式 | 低 | 高 | 备选:自定义 Transport 类继承 `ChatTransport` | +| 并行运行两个端点增加维护成本 | 低 | 低 | Phase 3 快速切换并删除旧代码 | + +## Future Considerations + +- **Resumable Streams**:AI SDK 5.0 支持 `prepareReconnectToStreamRequest`,结合 LangGraph checkpointing 可实现断流续传 +- **Rewrite API 统一**:复用 `DataStreamWriter` 迁移 `/rewrite` 端点 +- **工具调用**:AI SDK 5.0 原生支持 `tool-input-start/delta/available` + `tool-output-available` Part +- **多模型并行**:LangGraph 支持并行节点,可扩展为多模型对比回答 + +## Sources & References + +### Origin + +- **Brainstorm document**: [docs/brainstorms/2026-03-12-chat-message-routing-chain-brainstorm.md](docs/brainstorms/2026-03-12-chat-message-routing-chain-brainstorm.md) — 关键决策:AI SDK 5.0 Data Stream Protocol、LangGraph StateGraph、data-* 自定义 Part + +### Internal References + +- LangGraph 现有模式: `backend/app/pipelines/graphs.py:39-68` +- LangGraph 节点模式: `backend/app/pipelines/nodes.py:16-33` +- 当前 chat 端点: `backend/app/api/v1/chat.py:99-330` +- 当前 SSE 客户端: `frontend/src/services/chat-api.ts:27-78` +- 当前 Playground 状态: `frontend/src/pages/PlaygroundPage.tsx:31-40,114-269` +- HITL 模式: `docs/solutions/integration-issues/langgraph-hitl-interrupt-api-snapshot-next.md` +- Sync 调用模式: `docs/solutions/performance-issues/blocking-sync-calls-asyncio-to-thread.md` +- RAG 性能: `docs/solutions/performance-issues/2026-03-12-rag-rich-citation-performance-analysis.md` +- LangGraph 规则: `.cursor/rules/langgraph-pipelines.mdc` + +### External References + +- Vercel AI SDK 5.0 Stream Protocol: https://sdk.vercel.ai/docs/ai-sdk-ui/stream-protocol +- AI SDK 5.0 + FastAPI working example: https://github.com/vercel/ai/issues/7496#issuecomment-2379142 +- Pydantic AI 的 AI SDK 协议实现: https://ai.pydantic.dev/ui/vercel-ai/ +- LangGraph `get_stream_writer()`: https://reference.langchain.com/python/langgraph/config/get_stream_writer +- LangGraph + FastAPI SSE guide: https://dev.to/kasi_viswanath/streaming-ai-agent-with-fastapi-langgraph-2025-26-guide-1nkn diff --git 
a/docs/solutions/integration-issues/2026-03-12-chat-routing-chain-langgraph-aisdk-rewrite.md b/docs/solutions/integration-issues/2026-03-12-chat-routing-chain-langgraph-aisdk-rewrite.md new file mode 100644 index 0000000..13b10bc --- /dev/null +++ b/docs/solutions/integration-issues/2026-03-12-chat-routing-chain-langgraph-aisdk-rewrite.md @@ -0,0 +1,239 @@ +--- +title: "Chat Message Routing Chain Rewrite — Monolith to LangGraph + AI SDK 5.0" +date: 2026-03-12 +category: integration-issues +tags: + - chat + - langgraph + - streaming + - sse + - vercel-ai-sdk + - data-stream-protocol + - refactor + - useChat +severity: medium +components: + - backend/app/api/v1/chat.py + - backend/app/pipelines/chat/ + - frontend/src/hooks/use-chat-stream.ts + - frontend/src/lib/chat-transport.ts + - frontend/src/pages/PlaygroundPage.tsx + - frontend/src/components/playground/MessageBubbleV2.tsx + - frontend/src/types/chat.ts + - backend/tests/test_chat_pipeline.py +symptoms: | + Monolithic 250+ line _stream_chat function; manual SSE parsing on frontend + with 15+ useState/useRef; brittle state management; hard to extend or test; + custom non-standard SSE protocol. +root_cause: | + Chat endpoint and frontend evolved incrementally without a structured + streaming protocol or declarative state management. Backend lacked pipeline + abstraction; frontend lacked a standard SSE consumption layer. +resolution: | + Replaced with LangGraph StateGraph (6 nodes), Vercel AI SDK 5.0 Data Stream + Protocol, and useChat hook. 23 pipeline tests added. Old endpoint removed. +--- + +# Chat Message Routing Chain Rewrite + +## Problem + +### Backend: monolithic `_stream_chat` + +The chat endpoint was a single `_stream_chat` async generator (~250 lines) that handled every step inline: service initialization, conversation history loading, RAG retrieval, citation ranking, excerpt cleaning, LLM streaming, and persistence. Steps were tightly coupled, making unit testing and extension difficult. + +SSE events used a non-standard format: + +``` +event: text_delta +data: {"delta": "Hello"} +``` + +### Frontend: manual SSE parsing + state explosion + +PlaygroundPage used 15+ `useState`/`useRef` hooks for messages, streaming status, citations, thinking steps, pending deltas, flush timers, and abort refs. A custom `streamChat` async generator in `chat-api.ts` manually parsed SSE lines. + +### Protocol mismatch + +No standard protocol between backend and frontend. Each side maintained its own event format, making it brittle to extend with new event types. + +## Root Cause Analysis + +1. **No shared protocol** — backend and frontend each implemented their own SSE format +2. **Monolithic backend** — all chat logic in one function instead of a composable pipeline +3. **Frontend state explosion** — each streamed field managed with separate state +4. **No pipeline abstraction** — backend lacked graph-based orchestration +5. **Service wiring in wrong place** — LLM/RAG created inside stream logic instead of at endpoint layer + +## Solution + +### Architecture + +``` +Frontend (useChat) + → DefaultChatTransport → POST /api/v1/chat/stream + → FastAPI StreamingResponse + → LangGraph StateGraph.astream(stream_mode="custom") + → understand → [has KB?] 
→ retrieve → rank → clean → generate → persist + └─ no KB ──────────────────────→ generate → persist + → get_stream_writer() emits Data Stream Protocol events +``` + +### Backend: LangGraph StateGraph (6 nodes) + +```python +# backend/app/pipelines/chat/graph.py +graph = StateGraph(ChatState) +graph.add_node("understand", understand_node) +graph.add_node("retrieve", retrieve_node) +graph.add_node("rank", rank_node) +graph.add_node("clean", clean_node) +graph.add_node("generate", generate_node) +graph.add_node("persist", persist_node) +graph.set_entry_point("understand") +graph.add_conditional_edges("understand", _route_after_understand, + {"retrieve": "retrieve", "generate": "generate"}) +# ... edges ... +return graph.compile() +``` + +Each node uses `get_stream_writer()` to emit Data Stream Protocol events: + +```python +# In generate_node +writer = get_stream_writer() +writer({"type": "text-start", "id": text_id}) +async for token in llm.chat_stream(messages): + writer({"type": "text-delta", "id": text_id, "delta": token}) +writer({"type": "text-end", "id": text_id}) +``` + +### Protocol: Vercel AI SDK 5.0 Data Stream Protocol + +``` +data: {"type": "start", "messageId": "msg_xxx"} +data: {"type": "data-thinking", "data": {"step": "understand", ...}} +data: {"type": "text-start", "id": "text_xxx"} +data: {"type": "text-delta", "id": "text_xxx", "delta": "Hello"} +data: {"type": "text-end", "id": "text_xxx"} +data: {"type": "data-citation", "id": "cit-1", "data": {...}} +data: {"type": "data-conversation", "data": {"conversation_id": 123}} +data: {"type": "finish"} +data: [DONE] +``` + +### Frontend: `useChat` + custom transport + +```typescript +// frontend/src/lib/chat-transport.ts +export function createChatTransport(options) { + return new DefaultChatTransport({ + api: '/api/v1/chat/stream', + prepareSendMessagesRequest({ messages, trigger }) { + return { + body: { + message: getMessageText(lastUserMsg), + knowledge_base_ids: options.knowledgeBaseIds ?? [], + tool_mode: options.toolMode ?? 'qa', + // ... + }, + }; + }, + }); +} +``` + +```typescript +// frontend/src/hooks/use-chat-stream.ts +const chat = useChat({ + transport, + experimental_throttle: 80, + // ... +}); +const deferredMessages = useDeferredValue(chat.messages); +``` + +### Service injection via `_services` dict + +Services (LLM, RAG) are created at the endpoint layer and passed via a shared mutable dict in `config["configurable"]["_services"]` so all graph nodes share the same instances. + +```python +# Endpoint layer +services = await _init_services(db) +config = {"configurable": {"db": db, "_services": services}} + +# Any node +llm = get_chat_llm(config) # reads from _services dict +``` + +## Key Design Decisions + +| Decision | Rationale | +|----------|-----------| +| LangGraph StateGraph | Clear pipeline, conditional routing, easy to test/extend | +| `get_stream_writer()` + `stream_mode="custom"` | Nodes emit events directly; LangGraph handles streaming | +| Data Stream Protocol | Standard format compatible with `useChat` | +| Services at endpoint layer | Shared instances via `_services` dict; avoids deep-copy isolation | +| No checkpointer | Chat is stateless per request | +| `useDeferredValue` + `experimental_throttle: 80` | Reduces re-renders during streaming | +| `id`-based citation reconciliation | Same `data-citation` + same `id` → AI SDK updates existing Part | + +## Issues Encountered During Implementation + +### 1. 
LangGraph config deep-copy isolation + +**Problem**: `config["configurable"]` is deep-copied between nodes. Services injected in `understand_node` weren't visible in downstream nodes. + +**Fix**: Initialize services at the endpoint layer and pass them via `config["configurable"]["_services"]`. The `_services` dict itself is a nested mutable object that survives shallow copies. + +### 2. FastAPI StreamingResponse + Depends(get_db) + +**Problem**: DB session from `Depends(get_db)` can close before the streaming generator finishes. + +**Fix**: For the endpoint, use `Depends(get_db)` but ensure the session lifecycle extends through the full streaming response. For tests, monkeypatch `_init_services` to avoid DB-dependent initialization. + +### 3. SQLite test isolation + +**Problem**: Streaming endpoint creates its own session that may not see tables created by test fixtures (different SQLite connections). + +**Fix**: Monkeypatch `_init_services` to return mock LLM/RAG directly without querying `user_settings`. + +### 4. LangGraph `stream_mode` tuple unpacking + +**Problem**: `stream_mode=["custom"]` returns 2-tuples, but code expected 3-tuples. + +**Fix**: Use `stream_mode="custom"` (string, not list) for single-mode streaming. + +### 5. React duplicate key warning + +**Problem**: `ThinkingChain` used `step.step` as key, but multiple thinking events can share the same step name. + +**Fix**: Use `${step.step}-${index}` as key. + +## Prevention Strategies + +1. **Service injection**: Always initialize services at the entry point and share via a mutable container in config +2. **Streaming + DB**: Keep DB session creation inside the streaming generator; don't rely on dependency injection scoping +3. **Test design**: Mock heavy service initialization for streaming endpoint tests +4. **SDK versions**: Use string values for single-mode options; pin SDK versions +5. 
**React keys**: Use composite keys when items can repeat + +## Common Pitfalls + +| Pitfall | How to avoid | +|---------|-------------| +| Services not available in downstream nodes | Use shared `_services` dict initialized at endpoint | +| DB session closed mid-stream | Create session inside generator | +| Tests fail with "table not found" | Monkeypatch service init | +| Inconsistent stream tuple lengths | Use `stream_mode='custom'` (string) | +| React duplicate key warnings | Use `${type}-${index}` keys | + +## Related Documents + +- [LangGraph HITL Interrupt Pattern](langgraph-hitl-interrupt-api-snapshot-next.md) +- [Blocking Sync Calls — asyncio.to_thread](../performance-issues/blocking-sync-calls-asyncio-to-thread.md) +- [RAG Rich Citation Performance Analysis](../performance-issues/2026-03-12-rag-rich-citation-performance-analysis.md) +- [Chat Routing Chain Performance Analysis](../performance-issues/2026-03-12-chat-routing-chain-rewrite-performance-analysis.md) +- [Brainstorm](../../brainstorms/2026-03-12-chat-message-routing-chain-brainstorm.md) +- [Plan](../../plans/2026-03-12-feat-chat-message-routing-chain-rewrite-plan.md) +- [LangGraph Pipelines Rule](../../.cursor/rules/langgraph-pipelines.mdc) diff --git a/docs/solutions/performance-issues/2026-03-12-chat-routing-chain-rewrite-performance-analysis.md b/docs/solutions/performance-issues/2026-03-12-chat-routing-chain-rewrite-performance-analysis.md new file mode 100644 index 0000000..82472ee --- /dev/null +++ b/docs/solutions/performance-issues/2026-03-12-chat-routing-chain-rewrite-performance-analysis.md @@ -0,0 +1,205 @@ +--- +title: "Chat Message Routing Chain Rewrite — Performance Analysis" +date: 2026-03-12 +category: performance-issues +tags: + - chat + - langgraph + - streaming + - sse + - useChat + - performance +components: + - backend/app/api/v1/chat.py + - backend/app/pipelines/chat/ + - frontend/src/pages/PlaygroundPage.tsx + - frontend/src/hooks/useChatStream.ts +severity: medium +origin: docs/plans/2026-03-12-feat-chat-message-routing-chain-rewrite-plan.md +--- + +# Chat Message Routing Chain Rewrite — Performance Analysis + +## 1. Performance Summary + +This analysis evaluates the performance implications of migrating from the current monolithic `_stream_chat` async generator to a LangGraph StateGraph with 7 nodes, Data Stream Protocol, and Vercel AI SDK `useChat`. The rewrite introduces additional abstraction layers; the key question is whether the overhead is acceptable for the target SLA (first token ≤ 2s, smooth streaming). + +**Verdict**: The proposed architecture is **viable** with minor mitigations. Most overhead is in the **1–5ms per token** range and will be dominated by LLM latency (50–200ms/token). Critical mitigations: ensure `rank_node` batches Paper queries (current code already does), add optional frontend debounce if `useChat` causes jank, and validate LangGraph graph instance concurrency. + +--- + +## 2. Concern-by-Concern Analysis + +### 2.1 LangGraph Overhead: `get_stream_writer()` vs Direct Yield + +| Dimension | Assessment | +|-----------|------------| +| **Mechanism** | Current: `yield _sse("text_delta", {"delta": token})` directly from generator. Proposed: `writer({"type":"text-delta", ...})` inside `generate_node` → LangGraph queues to `stream_mode="custom"` → caller `async for mode, chunk` yields to SSE. | +| **Overhead per token** | ~0.5–2ms (estimate) | +| **Breakdown** | 1) `writer()` call → ContextVar lookup + queue put. 2) LangGraph `astream` iteration → async queue get. 
3) Caller `yield f'data: {json.dumps(chunk)}\n\n'` → same as current. | +| **Dominant cost** | LLM token latency (50–200ms typical) >> 1–2ms framework overhead. | + +**Recommendation**: +- **Benchmark**: Add `t0 = time.monotonic()` before first `writer({"type":"text-delta",...})` and log `(time.monotonic() - t0) * 1000 / token_count` after generate completes. Target: < 2ms/token average. +- **Mitigation**: None required if benchmark confirms < 5ms/token. If higher, profile LangGraph internals (queue operations, ContextVar). + +**Estimate**: Overhead adds **< 1%** to total stream time for typical 500-token responses. + +--- + +### 2.2 State Serialization Between Nodes + +| Dimension | Assessment | +|-----------|------------| +| **When serialization happens** | LangGraph merges node return dict into state and passes to next node. For TypedDict state, this is **in-process dict merge** — no pickle/JSON across process boundaries. | +| **Large fields** | `rag_results` (list of dicts, ~10–50 items × ~1KB each = 10–50KB), `citations` (similar), `full_messages` (system + history + user + context, 5–50KB). Total: ~50–150KB per state. | +| **Cost** | Dict merge + reference copy. Python dict operations are O(n) in field count; copying 50–150KB of nested dicts is ~0.1–0.5ms per node transition. | +| **Node transitions** | 5–7 transitions (understand → retrieve → rank → clean → generate → persist). Total: ~1–3ms. | + +**Recommendation**: +- **Reduce state size**: Store `rag_results` / `citations` as references; avoid duplicating large blobs. LangGraph's reducer (if any) may copy; verify TypedDict default is shallow merge. +- **Lazy loading**: Consider storing `paper_ids` in state and fetching Paper metadata only in `rank_node`, not carrying full `rag_results` through every node. +- **Benchmark**: Log `len(json.dumps(state))` at each node entry; alert if > 200KB. + +**Estimate**: 1–3ms total for full graph execution — negligible vs RAG (500ms–2s) and LLM (2–10s). + +--- + +### 2.3 First Token Latency (TTFT) + +| Dimension | Assessment | +|-----------|------------| +| **Current flow** | understand (DB + LLM init) → retrieve (RAG) → rank (Paper batch) → clean (LLM parallel) → generate (LLM stream). First token = after all pre-generate steps + first LLM chunk. | +| **Proposed flow** | Same DAG. LangGraph executes nodes sequentially; no parallelism change. | +| **Added latency** | 1) Graph setup: `get_chat_graph()` + `config` + `initial_state` — ~0.5ms. 2) Node dispatch overhead — ~0.2ms per node. 3) `stream_mode=["updates","custom"]` — first custom chunk is emitted as soon as `generate_node` yields first token to writer. | +| **Critical path** | TTFT = understand + retrieve + rank + clean + first LLM token. LangGraph adds ~2–5ms to this path. | + +**Recommendation**: +- **Preserve parallelism**: Ensure `retrieve_node` uses `asyncio.gather()` for multi-KB RAG; `clean_node` uses `asyncio.gather()` for excerpt cleaning. Plan already specifies this. +- **No checkpointing for chat**: Do **not** enable LangGraph checkpointer for chat — it would add serialization on every node. Use `graph.compile()` without checkpointer. +- **Benchmark**: Measure TTFT (user send → first `text-delta` received) for both old and new endpoints. Plan target: ≤ 2s. If new endpoint exceeds by > 100ms, profile node dispatch. + +**Estimate**: +2–5ms TTFT. Acceptable if baseline is 1.5–2s. 
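+
+To make the TTFT comparison concrete, the measurement can be scripted. The sketch below is illustrative only — it assumes the backend is reachable at `http://localhost:8000`, that `httpx` is installed, and that the request body mirrors `ChatStreamRequest` (`message`, `knowledge_base_ids`, `tool_mode`):
+
+```python
+# TTFT probe — illustrative sketch, not part of the pipeline.
+import asyncio
+import json
+import time
+
+import httpx
+
+
+async def measure_ttft(message: str) -> float:
+    """Seconds from request send until the first text-delta event arrives."""
+    payload = {"message": message, "knowledge_base_ids": [], "tool_mode": "qa"}
+    t0 = time.monotonic()
+    async with httpx.AsyncClient(base_url="http://localhost:8000", timeout=60) as client:
+        async with client.stream("POST", "/api/v1/chat/stream", json=payload) as resp:
+            async for line in resp.aiter_lines():
+                if not line.startswith("data: "):
+                    continue
+                raw = line[len("data: "):]
+                if raw == "[DONE]":
+                    break
+                if json.loads(raw).get("type") == "text-delta":
+                    return time.monotonic() - t0
+    raise RuntimeError("stream ended before any text-delta event")
+
+
+if __name__ == "__main__":
+    print(f"TTFT: {asyncio.run(measure_ttft('What is attention?')):.3f}s")
+```
+
+The probe targets the new Data Stream Protocol endpoint; adapting the parsing to the old endpoint's `event: text_delta` format gives a like-for-like number for the benchmark checklist in section 5.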
+ +--- + +### 2.4 Frontend Re-renders: useChat vs 80ms Debounce + +| Dimension | Assessment | +|-----------|------------| +| **Current** | 80ms debounce: `text_delta` events buffered, `setMessages` called at most every 80ms. ~12–15 state updates for 1s of streaming. | +| **useChat** | AI SDK parses SSE, updates `messages` state on each `text-delta` Part. **No built-in debounce** — each token can trigger a re-render. | +| **Impact** | 50 tokens/s → 50 React re-renders/s. Each re-render: MessageBubble, MarkdownRenderer (full re-parse), CitationCardList. Per RAG performance analysis: this is **high severity** — 500 tokens = 500 full Markdown parses. | +| **useChat internals** | SDK appends to `parts` array; React batches state updates within same tick, but each SSE message is a separate tick. Expect 1 re-render per `text-delta` unless SDK implements internal batching (unconfirmed). | + +**Recommendation**: +- **P0 — Add debounce layer**: Wrap the streaming message's text in a debounced update before passing to MessageBubble. Options: + 1. **Custom hook**: `useDebouncedStreamingContent(message, delayMs=80)` — returns content that updates at most every 80ms. + 2. **useDeferredValue**: `const deferredContent = useDeferredValue(streamingContent)` — React 18 defers non-urgent updates; may reduce render frequency. + 3. **AI SDK option**: Check if `useChat` has `throttle` or `debounce` in v5; if not, file feature request. +- **P1 — MessageBubble memo**: Ensure `memo(MessageBubble)` and stable `citations`/`thinkingSteps` via `useMemo` so sibling messages don't re-render. +- **Benchmark**: Chrome Performance — measure Long Tasks during streaming. Target: < 50ms. If useChat causes 50+ re-renders/s and Long Tasks > 100ms, debounce is mandatory. + +**Estimate**: Without debounce, **regression** vs current 80ms approach. With debounce, parity or better. + +--- + +### 2.5 Memory Footprint + +| Dimension | Assessment | +|-----------|------------| +| **State in memory** | ChatState holds: `rag_results`, `citations`, `enhanced_citations`, `full_messages`, `assistant_content`. Peak during generate: all of the above + streaming `assistant_content` (grows to ~2–10KB for 500 tokens). | +| **Per-request estimate** | ~100–200KB for typical RAG chat. 10 concurrent users = ~1–2MB. | +| **LangGraph overhead** | Graph definition is shared (compiled once). Per-invocation: state dict + node stack. Negligible. | +| **Risk** | Very long conversations (20+ turns) with large history_messages could bloat `full_messages`. Plan caps history at 10 (from current code). | + +**Recommendation**: +- **Bound history**: Keep `history_messages[-10:]` as in current implementation. +- **Streaming content**: In `generate_node`, accumulate `assistant_content` in a string; avoid storing full token list. Current plan does this. +- **Monitor**: Add optional memory logging in dev: `import tracemalloc` at request start, `tracemalloc.get_traced_memory()` at end. Alert if > 5MB per request. + +**Estimate**: 100–200KB/request. Safe for 50+ concurrent users on typical 4GB app process. + +--- + +### 2.6 Concurrent Streams: Graph Instance Safety + +| Dimension | Assessment | +|-----------|------------| +| **Graph lifecycle** | Plan: `graph = get_chat_graph()` — compiled once, reused. LangGraph compiled graphs are **stateless** — they don't hold request-specific data. | +| **Per-request data** | `config` (with `db`, `thread_id`) and `initial_state` are passed per `astream()` call. Each invocation gets its own execution context. 
| +| **ContextVar** | `get_stream_writer()` uses ContextVar — each async task has its own context. FastAPI spawns a new task per request. **Safe**. | +| **DB session** | `db` is request-scoped (`Depends(get_db)`). Each request has its own session. **Safe**. | + +**Recommendation**: +- **No shared mutable state**: Ensure no global variables in nodes (e.g. don't cache LLM client in module scope per user). Plan injects `llm`/`rag` via `config["configurable"]` — correct. +- **Load test**: Run 10 concurrent streaming requests; verify no cross-talk (wrong messages to wrong users) and no connection/session errors. + +**Estimate**: **Safe** for concurrent use. LangGraph and FastAPI patterns support this. + +--- + +### 2.7 SSE Formatting: json.dumps Per Token + +| Dimension | Assessment | +|-----------|------------| +| **Current** | `_sse("text_delta", {"delta": token})` → `json.dumps({"delta": "x"})` per token. ~20–30 bytes per token (single char) or ~50–100 bytes (word). | +| **Proposed** | Same: `json.dumps({"type":"text-delta","id":text_id,"delta":token})` — slightly larger (~60–80 bytes) due to `type` and `id`. | +| **Cost** | `json.dumps` for 80-byte dict: ~2–5 µs (microseconds) in CPython. 500 tokens = 1–2.5ms total. | +| **Comparison** | Network send of 80 bytes at 1Gbps = 0.0006ms. CPU cost dominates; still negligible. | + +**Recommendation**: +- **No change needed**: `json.dumps` is not a bottleneck. If profiling ever shows it (unlikely), consider `orjson.dumps` (2–3× faster) or pre-built template for `text-delta` (e.g. `f'{{"type":"text-delta","id":"{id}","delta":{escape(delta)}}}'`) — only if proven hot. + +**Estimate**: < 3ms total for 500-token response. **Negligible**. + +--- + +## 3. Paper DB Query Batching (Correction) + +The plan and user's "Current architecture" mention "No batching of Paper DB queries (N+1)". **The current `_stream_chat` already batches** (lines 169–173 in `chat.py`): + +```python +paper_ids = list({pid for pid in (src.get("paper_id") for src in all_sources) if pid is not None}) +if paper_ids: + result = await db.execute(select(Paper).where(Paper.id.in_(paper_ids))) + papers_by_id = {p.id: p for p in result.scalars().all()} +``` + +**Action**: Ensure `rank_node` in the new pipeline uses the same batch pattern. The plan's node table says "批量加载 Paper 元数据" — confirm implementation matches. + +--- + +## 4. Recommended Actions (Prioritized) + +| Priority | Action | Impact | Effort | +|----------|--------|--------|--------| +| P0 | Add frontend debounce (80ms) for streaming text if useChat updates per token | Prevents render jank, Long Tasks | Low | +| P0 | Verify `rank_node` batches Paper query (`Paper.id.in_(paper_ids)`) | Prevents N+1 regression | Low | +| P1 | Disable LangGraph checkpointer for chat graph | Avoids serialization overhead | Trivial | +| P1 | Benchmark TTFT: old vs new endpoint (target ≤ 2s) | Validates SLA | Low | +| P2 | Log state size at node boundaries in dev (`len(json.dumps(state))`) | Early warning for state bloat | Low | +| P2 | Optional: `useDeferredValue` for streaming content as secondary mitigation | May reduce render frequency | Low | +| P3 | Consider `orjson` for SSE if profiling shows json.dumps > 5% of stream time | Unlikely to be needed | Low | + +--- + +## 5. 
Benchmark Checklist + +Before/after comparison: + +| Metric | Current | Target (New) | How to Measure | +|--------|---------|--------------|----------------| +| TTFT (send → first text-delta) | Baseline | ≤ baseline + 100ms | Backend log + frontend Performance API | +| Tokens per second (throughput) | Baseline | ≥ 90% of baseline | Count text-delta events / elapsed time | +| Frontend Long Tasks during stream | < 50ms | < 50ms | Chrome Performance, Long Task observer | +| Memory per request | — | < 500KB | tracemalloc or process memory diff | +| 10 concurrent streams | — | No errors, no cross-talk | Load test script | + +--- + +## 6. Related Documents + +- **Plan**: `docs/plans/2026-03-12-feat-chat-message-routing-chain-rewrite-plan.md` +- **RAG performance**: `docs/solutions/performance-issues/2026-03-12-rag-rich-citation-performance-analysis.md` +- **LangGraph rules**: `.cursor/rules/langgraph-pipelines.mdc` diff --git a/frontend/package-lock.json b/frontend/package-lock.json index 95ca27d..c85ad68 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -10,8 +10,10 @@ "dependencies": { "@a2ui-sdk/react": "^0.4.0", "@a2ui-sdk/types": "^0.4.0", + "@ai-sdk/react": "^3.0.118", "@radix-ui/react-hover-card": "^1.1.15", "@tanstack/react-query": "^5.90.21", + "ai": "^6.0.116", "axios": "^1.13.6", "class-variance-authority": "^0.7.1", "clsx": "^2.1.1", @@ -213,6 +215,70 @@ "dev": true, "license": "MIT" }, + "node_modules/@ai-sdk/gateway": { + "version": "3.0.66", + "resolved": "https://registry.npmjs.org/@ai-sdk/gateway/-/gateway-3.0.66.tgz", + "integrity": "sha512-SIQ0YY0iMuv+07HLsZ+bB990zUJ6S4ujORAh+Jv1V2KGNn73qQKnGO0JBk+w+Res8YqOFSycwDoWcFlQrVxS4A==", + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/provider": "3.0.8", + "@ai-sdk/provider-utils": "4.0.19", + "@vercel/oidc": "3.1.0" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4.1.8" + } + }, + "node_modules/@ai-sdk/provider": { + "version": "3.0.8", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-3.0.8.tgz", + "integrity": "sha512-oGMAgGoQdBXbZqNG0Ze56CHjDZ1IDYOwGYxYjO5KLSlz5HiNQ9udIXsPZ61VWaHGZ5XW/jyjmr6t2xz2jGVwbQ==", + "license": "Apache-2.0", + "dependencies": { + "json-schema": "^0.4.0" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/@ai-sdk/provider-utils": { + "version": "4.0.19", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-4.0.19.tgz", + "integrity": "sha512-3eG55CrSWCu2SXlqq2QCsFjo3+E7+Gmg7i/oRVoSZzIodTuDSfLb3MRje67xE9RFea73Zao7Lm4mADIfUETKGg==", + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/provider": "3.0.8", + "@standard-schema/spec": "^1.1.0", + "eventsource-parser": "^3.0.6" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4.1.8" + } + }, + "node_modules/@ai-sdk/react": { + "version": "3.0.118", + "resolved": "https://registry.npmjs.org/@ai-sdk/react/-/react-3.0.118.tgz", + "integrity": "sha512-fBAix8Jftxse6/2YJnOFkwW1/O6EQK4DK68M9DlFmZGAzBmsaHXEPVS77sVIlkaOWCy11bE7434NAVXRY+3OsQ==", + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/provider-utils": "4.0.19", + "ai": "6.0.116", + "swr": "^2.2.5", + "throttleit": "2.1.0" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "react": "^18 || ~19.0.1 || ~19.1.2 || ^19.2.1" + } + }, "node_modules/@antfu/ni": { "version": "25.0.0", "resolved": "https://registry.npmjs.org/@antfu/ni/-/ni-25.0.0.tgz", @@ -2271,6 +2337,15 @@ "dev": true, "license": "MIT" }, + 
"node_modules/@opentelemetry/api": { + "version": "1.9.0", + "resolved": "https://registry.npmjs.org/@opentelemetry/api/-/api-1.9.0.tgz", + "integrity": "sha512-3giAOQvZiH5F9bMlMiv8+GSPMeqg0dbaeo58/0SlA9sxSqZhnUtxzX9/2FzyhS9sWQf5S0GJE0AKBrFqjpeYcg==", + "license": "Apache-2.0", + "engines": { + "node": ">=8.0.0" + } + }, "node_modules/@radix-ui/number": { "version": "1.1.1", "resolved": "https://registry.npmjs.org/@radix-ui/number/-/number-1.1.1.tgz", @@ -4150,7 +4225,6 @@ "version": "1.1.0", "resolved": "https://registry.npmjs.org/@standard-schema/spec/-/spec-1.1.0.tgz", "integrity": "sha512-l2aFy5jALhniG5HgqrD6jXLi/rUWrKvqN/qJx6yoJsgKhblVd+iqqU4RCXavm/jPityDo5TCvKMnpjKnOriy0w==", - "dev": true, "license": "MIT" }, "node_modules/@tailwindcss/node": { @@ -5083,6 +5157,15 @@ "integrity": "sha512-WmoN8qaIAo7WTYWbAZuG8PYEhn5fkz7dZrqTBZ7dtt//lL2Gwms1IcnQ5yHqjDfX8Ft5j4YzDM23f87zBfDe9g==", "license": "ISC" }, + "node_modules/@vercel/oidc": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/@vercel/oidc/-/oidc-3.1.0.tgz", + "integrity": "sha512-Fw28YZpRnA3cAHHDlkt7xQHiJ0fcL+NRcIqsocZQUSmbzeIKRpwttJjik5ZGanXP+vlA4SbTg+AbA3bP363l+w==", + "license": "Apache-2.0", + "engines": { + "node": ">= 20" + } + }, "node_modules/@vitejs/plugin-react": { "version": "5.1.4", "resolved": "https://registry.npmjs.org/@vitejs/plugin-react/-/plugin-react-5.1.4.tgz", @@ -5290,6 +5373,24 @@ "node": ">= 14" } }, + "node_modules/ai": { + "version": "6.0.116", + "resolved": "https://registry.npmjs.org/ai/-/ai-6.0.116.tgz", + "integrity": "sha512-7yM+cTmyRLeNIXwt4Vj+mrrJgVQ9RMIW5WO0ydoLoYkewIvsMcvUmqS4j2RJTUXaF1HphwmSKUMQ/HypNRGOmA==", + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/gateway": "3.0.66", + "@ai-sdk/provider": "3.0.8", + "@ai-sdk/provider-utils": "4.0.19", + "@opentelemetry/api": "1.9.0" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4.1.8" + } + }, "node_modules/ajv": { "version": "6.14.0", "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.14.0.tgz", @@ -6872,7 +6973,6 @@ "version": "3.0.6", "resolved": "https://registry.npmjs.org/eventsource-parser/-/eventsource-parser-3.0.6.tgz", "integrity": "sha512-Vo1ab+QXPzZ4tCa8SwIHJFaSzy4R6SHf7BY79rFBDf0idraZWAkYrDjDj8uWaSm3S2TK+hJ7/t1CEmZ7jXw+pg==", - "dev": true, "license": "MIT", "engines": { "node": ">=18.0.0" @@ -8460,6 +8560,12 @@ "integrity": "sha512-xyFwyhro/JEof6Ghe2iz2NcXoj2sloNsWr/XsERDK/oiPCfaNhl5ONfp+jQdAZRQQ0IJWNzH9zIZF7li91kh2w==", "license": "MIT" }, + "node_modules/json-schema": { + "version": "0.4.0", + "resolved": "https://registry.npmjs.org/json-schema/-/json-schema-0.4.0.tgz", + "integrity": "sha512-es94M3nTIfsEPisRafak+HDLfHXnKBhV3vU5eqPcS3flIWqcxJWgXHXiey3YrpaNsanY5ei1VoYEbOzijuq9BA==", + "license": "(AFL-2.1 OR BSD-3-Clause)" + }, "node_modules/json-schema-traverse": { "version": "0.4.1", "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-0.4.1.tgz", @@ -12034,6 +12140,19 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/swr": { + "version": "2.4.1", + "resolved": "https://registry.npmjs.org/swr/-/swr-2.4.1.tgz", + "integrity": "sha512-2CC6CiKQtEwaEeNiqWTAw9PGykW8SR5zZX8MZk6TeAvEAnVS7Visz8WzphqgtQ8v2xz/4Q5K+j+SeMaKXeeQIA==", + "license": "MIT", + "dependencies": { + "dequal": "^2.0.3", + "use-sync-external-store": "^1.6.0" + }, + "peerDependencies": { + "react": "^16.11.0 || ^17.0.0 || ^18.0.0 || ^19.0.0" + } + }, "node_modules/symbol-tree": { "version": "3.2.4", "resolved": 
"https://registry.npmjs.org/symbol-tree/-/symbol-tree-3.2.4.tgz", @@ -12085,6 +12204,18 @@ "url": "https://opencollective.com/webpack" } }, + "node_modules/throttleit": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/throttleit/-/throttleit-2.1.0.tgz", + "integrity": "sha512-nt6AMGKW1p/70DF/hGBdJB57B8Tspmbp5gfJ8ilhLnt7kkr2ye7hzD6NVG8GGErk2HWF34igrL2CXmNIkzKqKw==", + "license": "MIT", + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, "node_modules/tiny-invariant": { "version": "1.3.3", "resolved": "https://registry.npmjs.org/tiny-invariant/-/tiny-invariant-1.3.3.tgz", @@ -13236,7 +13367,6 @@ "version": "4.3.6", "resolved": "https://registry.npmjs.org/zod/-/zod-4.3.6.tgz", "integrity": "sha512-rftlrkhHZOcjDwkGlnUtZZkvaPHCsDATp4pGpuOOMDaTdDDXF91wuVDJoWoPsKX/3YPQ5fHuF3STjcYyKr+Qhg==", - "dev": true, "license": "MIT", "peer": true, "funding": { diff --git a/frontend/package.json b/frontend/package.json index bac701b..77f401f 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -15,8 +15,10 @@ "dependencies": { "@a2ui-sdk/react": "^0.4.0", "@a2ui-sdk/types": "^0.4.0", + "@ai-sdk/react": "^3.0.118", "@radix-ui/react-hover-card": "^1.1.15", "@tanstack/react-query": "^5.90.21", + "ai": "^6.0.116", "axios": "^1.13.6", "class-variance-authority": "^0.7.1", "clsx": "^2.1.1", diff --git a/frontend/src/components/playground/MessageBubbleV2.tsx b/frontend/src/components/playground/MessageBubbleV2.tsx new file mode 100644 index 0000000..eccb62d --- /dev/null +++ b/frontend/src/components/playground/MessageBubbleV2.tsx @@ -0,0 +1,155 @@ +import { memo, useState, useCallback, useMemo } from "react"; +import ReactMarkdown from "react-markdown"; +import remarkGfm from "remark-gfm"; +import remarkMath from "remark-math"; +import rehypeKatex from "rehype-katex"; +import rehypeHighlight from "rehype-highlight"; +import { User, Bot } from "lucide-react"; +import { cn } from "@/lib/utils"; +import remarkCitation from "@/lib/remark-citation"; +import InlineCitationTag from "./InlineCitationTag"; +import CitationCardList from "./CitationCardList"; +import ThinkingChain from "./ThinkingChain"; +import type { ThinkingStep } from "./ThinkingChain"; +import type { OmeletteUIMessage, Citation, ThinkingData } from "@/types/chat"; +import { getCitations, getThinkingSteps, getMessageText } from "@/types/chat"; + +interface MessageBubbleV2Props { + message: OmeletteUIMessage; + isStreaming?: boolean; +} + +function thinkingDataToStep(data: ThinkingData): ThinkingStep { + return { + step: data.step, + label: data.label, + status: data.status === "error" ? 
"error" : data.status, + detail: data.detail, + duration_ms: data.duration_ms, + summary: data.summary, + }; +} + +function MessageBubbleV2({ message, isStreaming }: MessageBubbleV2Props) { + const isUser = message.role === "user"; + + const content = useMemo(() => getMessageText(message), [message]); + const citations = useMemo(() => getCitations(message), [message]); + const thinkingData = useMemo(() => getThinkingSteps(message), [message]); + const thinkingSteps = useMemo( + () => thinkingData.map(thinkingDataToStep), + [thinkingData], + ); + + const hasThinkingSteps = thinkingSteps.length > 0; + const showLoading = isStreaming && !content && !hasThinkingSteps; + + const [highlightedCitationIndex, setHighlightedCitationIndex] = useState< + number | null + >(null); + + const citationMap = useMemo(() => { + const map = new Map(); + for (const c of citations) { + map.set(c.index, c); + } + return map; + }, [citations]); + + const handleClickCitation = useCallback((index: number) => { + setHighlightedCitationIndex(index); + }, []); + + const remarkPlugins = useMemo( + () => [remarkGfm, remarkMath, remarkCitation], + [], + ); + + const rehypePlugins = useMemo(() => [rehypeKatex, rehypeHighlight], []); + + const markdownComponents = useMemo( + () => + ({ + "citation-ref": ({ + index: citationIndex, + }: { + index?: number; + children?: React.ReactNode; + }) => { + if (citationIndex == null) return null; + return ( + + ); + }, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + }) as any, + [citationMap, handleClickCitation], + ); + + return ( + + + {isUser ? : } + + + + {isUser ? ( + {content} + ) : ( + <> + {hasThinkingSteps && } + + {showLoading && ( + + + ... + + + )} + + {content && ( + + + {content} + + {isStreaming && ( + + )} + + )} + + + > + )} + + + ); +} + +export default memo(MessageBubbleV2); diff --git a/frontend/src/components/playground/ThinkingChain.tsx b/frontend/src/components/playground/ThinkingChain.tsx index 95f9142..838c2a1 100644 --- a/frontend/src/components/playground/ThinkingChain.tsx +++ b/frontend/src/components/playground/ThinkingChain.tsx @@ -96,11 +96,11 @@ function ThinkingChain({ steps }: ThinkingChainProps) { className="overflow-hidden" > - {steps.filter((s) => s.step !== 'complete').map((step) => { + {steps.filter((s) => s.step !== 'complete').map((step, index) => { const Icon = STEP_ICONS[step.step] ?? 
Search; const StatusIcon = STATUS_ICON[step.status]; return ( - + void; + onError?: (error: Error) => void; +} + +interface UseChatStreamReturn { + messages: OmeletteUIMessage[]; + sendMessage: (text: string) => void; + stop: () => void; + status: 'ready' | 'submitted' | 'streaming' | 'error'; + error: Error | undefined; + isStreaming: boolean; + lastAssistantCitations: Citation[]; + lastAssistantThinking: ThinkingData[]; + setMessages: (msgs: OmeletteUIMessage[] | ((prev: OmeletteUIMessage[]) => OmeletteUIMessage[])) => void; +} + +export function useChatStream({ + conversationId, + knowledgeBaseIds, + toolMode, + model, + initialMessages, + onConversationId, + onError, +}: UseChatStreamOptions): UseChatStreamReturn { + const onConversationIdRef = useRef(onConversationId); + onConversationIdRef.current = onConversationId; + + const transport = useMemo( + () => + createChatTransport({ + conversationId, + knowledgeBaseIds, + toolMode, + model, + }), + // eslint-disable-next-line react-hooks/exhaustive-deps + [conversationId, JSON.stringify(knowledgeBaseIds), toolMode, model], + ); + + const chat = useChat({ + transport, + dataPartSchemas: {} as Record, + messages: initialMessages, + onError, + onFinish({ messages: finishedMessages }) { + const lastAssistant = [...finishedMessages].reverse().find((m) => m.role === 'assistant'); + if (lastAssistant) { + const cid = getConversationId(lastAssistant); + if (cid && onConversationIdRef.current) { + onConversationIdRef.current(cid); + } + } + }, + experimental_throttle: 80, + }); + + const deferredMessages = useDeferredValue(chat.messages); + + const isStreaming = chat.status === 'streaming' || chat.status === 'submitted'; + + const lastAssistant = useMemo(() => { + return [...deferredMessages].reverse().find((m) => m.role === 'assistant'); + }, [deferredMessages]); + + const lastAssistantCitations = useMemo(() => { + return lastAssistant ? getCitations(lastAssistant) : []; + }, [lastAssistant]); + + const lastAssistantThinking = useMemo(() => { + return lastAssistant ? getThinkingSteps(lastAssistant) : []; + }, [lastAssistant]); + + const sendMessage = useCallback( + (text: string) => { + chat.sendMessage({ text }); + }, + [chat], + ); + + return { + messages: deferredMessages, + sendMessage, + stop: chat.stop, + status: chat.status, + error: chat.error, + isStreaming, + lastAssistantCitations, + lastAssistantThinking, + setMessages: chat.setMessages, + }; +} diff --git a/frontend/src/lib/chat-transport.ts b/frontend/src/lib/chat-transport.ts new file mode 100644 index 0000000..cf00e5c --- /dev/null +++ b/frontend/src/lib/chat-transport.ts @@ -0,0 +1,34 @@ +import { DefaultChatTransport } from 'ai'; +import type { OmeletteUIMessage } from '@/types/chat'; +import { getMessageText } from '@/types/chat'; + +interface ChatTransportOptions { + conversationId?: number; + knowledgeBaseIds?: number[]; + toolMode?: string; + model?: string; +} + +export function createChatTransport(options: ChatTransportOptions) { + return new DefaultChatTransport({ + api: '/api/v1/chat/stream', + prepareSendMessagesRequest({ messages, trigger }) { + const lastUserMsg = [...messages].reverse().find((m) => m.role === 'user'); + const messageText = lastUserMsg ? getMessageText(lastUserMsg) : ''; + + return { + body: { + message: messageText, + conversation_id: options.conversationId ?? null, + knowledge_base_ids: options.knowledgeBaseIds ?? [], + tool_mode: options.toolMode ?? 'qa', + model: options.model ?? 
null, + trigger, + }, + headers: { + 'Content-Type': 'application/json', + }, + }; + }, + }); +} diff --git a/frontend/src/pages/PlaygroundPage.tsx b/frontend/src/pages/PlaygroundPage.tsx index e43a820..91d888e 100644 --- a/frontend/src/pages/PlaygroundPage.tsx +++ b/frontend/src/pages/PlaygroundPage.tsx @@ -1,4 +1,4 @@ -import { useState, useRef, useCallback, useEffect } from 'react'; +import { useState, useRef, useCallback, useEffect, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useParams, useNavigate } from 'react-router-dom'; import { useQuery } from '@tanstack/react-query'; @@ -16,44 +16,25 @@ import { PopoverTrigger, } from '@/components/ui/popover'; import ChatInput from '@/components/playground/ChatInput'; -import MessageBubble from '@/components/playground/MessageBubble'; +import MessageBubbleV2 from '@/components/playground/MessageBubbleV2'; import ChatHistorySidebar from '@/components/playground/ChatHistorySidebar'; import { useSidebarCollapsed } from '@/components/playground/sidebar-utils'; import { SidebarToggleButton } from '@/components/playground/SidebarToggleButton'; -import { streamChat, conversationApi } from '@/services/chat-api'; +import { conversationApi } from '@/services/chat-api'; import { projectApi } from '@/services/api'; -import type { ToolMode, Citation } from '@/types/chat'; -import { isCitation, normalizeCitation } from '@/types/chat'; -import type { LoadingStage } from '@/components/playground/MessageLoadingStages'; -import type { A2UIMessage } from '@a2ui-sdk/types/0.8'; -import type { ThinkingStep } from '@/components/playground/ThinkingChain'; - -interface LocalMessage { - id: string; - role: 'user' | 'assistant'; - content: string; - citations?: Citation[]; - isStreaming?: boolean; - loadingStage?: LoadingStage; - a2uiMessages?: A2UIMessage[]; - thinkingSteps?: ThinkingStep[]; -} +import { useChatStream } from '@/hooks/use-chat-stream'; +import type { ToolMode, OmeletteUIMessage, Citation } from '@/types/chat'; export default function PlaygroundPage() { const { t } = useTranslation(); const navigate = useNavigate(); const { conversationId: routeConvId } = useParams<{ conversationId: string }>(); - const [messages, setMessages] = useState([]); - const [isStreaming, setIsStreaming] = useState(false); - const [toolMode, setToolMode] = useState('qa'); - const [selectedKBs, setSelectedKBs] = useState([]); - const [conversationId, setConversationId] = useState(); - const [isRestoringConversation, setIsRestoringConversation] = useState(false); + const [toolModeOverride, setToolModeOverride] = useState(null); + const [selectedKBsOverride, setSelectedKBsOverride] = useState(null); + const [newConversationId, setNewConversationId] = useState(); const [sidebarCollapsed, setSidebarCollapsed] = useSidebarCollapsed(); const bottomRef = useRef(null); - const abortRef = useRef(null); - const hasRestoredRef = useRef(undefined); const { data: projectsData, isLoading: isLoadingProjects } = useQuery({ queryKey: ['projects'], @@ -61,224 +42,88 @@ export default function PlaygroundPage() { }); const projects = projectsData?.items ?? 
[]; - useEffect(() => { - if (!routeConvId || hasRestoredRef.current === routeConvId) return; - hasRestoredRef.current = routeConvId; - const convIdNum = Number(routeConvId); - if (Number.isNaN(convIdNum)) return; - - setIsRestoringConversation(true); - conversationApi.get(convIdNum) - .then((conv) => { - setConversationId(conv.id); - setToolMode((conv.tool_mode as ToolMode) || 'qa'); - if (conv.knowledge_base_ids?.length) { - setSelectedKBs(conv.knowledge_base_ids); - } - const restored: LocalMessage[] = (conv.messages ?? []).map((m) => ({ - id: `restored-${m.id}`, - role: m.role as 'user' | 'assistant', - content: m.content, - citations: (m.citations as Citation[]) ?? [], - })); - setMessages(restored); - }) - .catch(() => { - setMessages([]); - setConversationId(undefined); - }) - .finally(() => setIsRestoringConversation(false)); - }, [routeConvId]); - - useEffect(() => { - bottomRef.current?.scrollIntoView({ behavior: 'smooth' }); - }, [messages]); + const convIdNum = routeConvId ? Number(routeConvId) : undefined; + const { data: restoredConv, isLoading: isRestoringConversation, isError: restoreFailed } = useQuery({ + queryKey: ['conversation', convIdNum], + queryFn: () => conversationApi.get(convIdNum!), + enabled: convIdNum != null && !Number.isNaN(convIdNum), + }); - const pendingDeltaRef = useRef(''); - const flushTimerRef = useRef | undefined>(undefined); - const assistantIdRef = useRef(''); + const conversationId = restoredConv?.id ?? newConversationId; + const toolMode = toolModeOverride ?? (restoredConv?.tool_mode as ToolMode) ?? 'qa'; + const selectedKBs = selectedKBsOverride ?? restoredConv?.knowledge_base_ids ?? []; - const flushDelta = useCallback(() => { - if (!pendingDeltaRef.current || !assistantIdRef.current) return; - const delta = pendingDeltaRef.current; - const aid = assistantIdRef.current; - pendingDeltaRef.current = ''; - setMessages((prev) => - prev.map((m) => - m.id === aid - ? { ...m, content: m.content + delta, loadingStage: 'generating' as LoadingStage } - : m, - ), + const setToolMode = useCallback((mode: ToolMode) => setToolModeOverride(mode), []); + const setSelectedKBs = useCallback((fn: number[] | ((prev: number[]) => number[])) => { + setSelectedKBsOverride((prev) => + typeof fn === 'function' ? fn(prev ?? []) : fn, ); }, []); - const handleSend = useCallback( - async (message: string) => { - const userMsg: LocalMessage = { - id: `u-${Date.now()}`, - role: 'user', - content: message, - }; - const assistantMsg: LocalMessage = { - id: `a-${Date.now()}`, - role: 'assistant', - content: '', - citations: [], - isStreaming: true, - loadingStage: 'searching', + const restoredMessages = useMemo((): OmeletteUIMessage[] => { + if (!restoredConv) return []; + return (restoredConv.messages ?? []).map((m) => { + const parts: OmeletteUIMessage['parts'] = [{ type: 'text' as const, text: m.content }]; + if (m.role === 'assistant' && m.citations) { + for (const cit of m.citations as Citation[]) { + parts.push({ type: 'data-citation' as const, id: `cit-${cit.index}`, data: cit }); + } + } + return { + id: `restored-${m.id}`, + role: m.role as 'user' | 'assistant', + parts, }; + }); + }, [restoredConv]); - assistantIdRef.current = assistantMsg.id; - pendingDeltaRef.current = ''; - - setMessages((prev) => [...prev, userMsg, assistantMsg]); - setIsStreaming(true); - - const controller = new AbortController(); - abortRef.current = controller; - - try { - const gen = streamChat( - { - conversation_id: conversationId, - message, - knowledge_base_ids: selectedKBs.length > 0 ? 
selectedKBs : undefined, - tool_mode: toolMode, - }, - controller.signal, - ); + const handleConversationId = useCallback( + (cid: number) => { + setNewConversationId(cid); + if (!routeConvId) { + navigate(`/chat/${cid}`, { replace: true }); + } + }, + [routeConvId, navigate], + ); - for await (const event of gen) { - if (event.event === 'text_delta') { - const delta = (event.data as { delta: string }).delta; - pendingDeltaRef.current += delta; - if (!flushTimerRef.current) { - flushTimerRef.current = setTimeout(() => { - flushTimerRef.current = undefined; - flushDelta(); - }, 80); - } - } else if (event.event === 'citation') { - if (!isCitation(event.data)) { - console.warn('Invalid citation event', event.data); - continue; - } - const citation = normalizeCitation(event.data as Record); - setMessages((prev) => - prev.map((m) => - m.id === assistantMsg.id - ? { - ...m, - citations: [...(m.citations ?? []), citation], - loadingStage: 'citations' as LoadingStage, - } - : m, - ), - ); - } else if (event.event === 'thinking_step') { - const step = event.data as ThinkingStep; - setMessages((prev) => - prev.map((m) => { - if (m.id !== assistantMsg.id) return m; - const existing = m.thinkingSteps ?? []; - const idx = existing.findIndex((s) => s.step === step.step); - const updated = idx >= 0 - ? existing.map((s, i) => (i === idx ? { ...s, ...step } : s)) - : [...existing, step]; - return { ...m, thinkingSteps: updated }; - }), - ); - } else if (event.event === 'citation_enhanced') { - const { index, cleaned_excerpt } = event.data as { - index: number; - cleaned_excerpt: string; - }; - setMessages((prev) => - prev.map((m) => - m.id === assistantMsg.id - ? { - ...m, - citations: (m.citations ?? []).map((c) => - c.index === index ? { ...c, excerpt: cleaned_excerpt } : c, - ), - } - : m, - ), - ); - } else if (event.event === 'a2ui_surface') { - const a2uiMsg = event.data as unknown as A2UIMessage; - if (a2uiMsg.beginRendering || a2uiMsg.surfaceUpdate || a2uiMsg.dataModelUpdate) { - setMessages((prev) => - prev.map((m) => - m.id === assistantMsg.id - ? { - ...m, - a2uiMessages: [...(m.a2uiMessages ?? []), a2uiMsg], - } - : m, - ), - ); - } - } else if (event.event === 'message_end') { - if (flushTimerRef.current) { - clearTimeout(flushTimerRef.current); - flushTimerRef.current = undefined; - } - flushDelta(); + const { + messages, + sendMessage, + stop, + isStreaming, + setMessages, + } = useChatStream({ + conversationId, + knowledgeBaseIds: selectedKBs, + toolMode, + initialMessages: restoredMessages.length > 0 ? restoredMessages : undefined, + onConversationId: handleConversationId, + onError: (err) => toast.error(err.message || t('playground.streamError')), + }); - const cid = (event.data as { conversation_id?: number }) - .conversation_id; - if (cid) { - setConversationId(cid); - if (!routeConvId) { - navigate(`/chat/${cid}`, { replace: true }); - } - } - } - } - } catch (err) { - if ((err as Error).name !== 'AbortError') { - toast.error(t('playground.streamError')); - setMessages((prev) => - prev.map((m) => - m.id === assistantMsg.id - ? { ...m, content: m.content || t('playground.streamError') } - : m, - ), - ); - } - } finally { - if (flushTimerRef.current) { - clearTimeout(flushTimerRef.current); - flushTimerRef.current = undefined; - } - flushDelta(); + useEffect(() => { + bottomRef.current?.scrollIntoView({ behavior: 'smooth' }); + }, [messages]); - setMessages((prev) => - prev.map((m) => - m.id === assistantMsg.id - ? 
{ ...m, isStreaming: false, loadingStage: 'complete' as LoadingStage } - : m, - ), - ); - setIsStreaming(false); - abortRef.current = null; - assistantIdRef.current = ''; - } + const handleSend = useCallback( + (message: string) => { + sendMessage(message); }, - [conversationId, selectedKBs, toolMode, t, routeConvId, navigate, flushDelta], + [sendMessage], ); const handleStop = useCallback(() => { - abortRef.current?.abort(); - }, []); + stop(); + }, [stop]); const handleNewChat = () => { - abortRef.current?.abort(); + stop(); + setMessages([]); setMessages([]); - setConversationId(undefined); - setIsStreaming(false); - hasRestoredRef.current = undefined; + setNewConversationId(undefined); + setToolModeOverride(null); + setSelectedKBsOverride(null); navigate('/', { replace: true }); }; @@ -294,7 +139,7 @@ export default function PlaygroundPage() { return ; } - if (routeConvId && !conversationId && !isRestoringConversation && hasRestoredRef.current === routeConvId) { + if (routeConvId && !conversationId && !isRestoringConversation && restoreFailed) { return ( - ))} diff --git a/frontend/src/services/chat-api.ts b/frontend/src/services/chat-api.ts index 2dcfedc..2888deb 100644 --- a/frontend/src/services/chat-api.ts +++ b/frontend/src/services/chat-api.ts @@ -3,8 +3,6 @@ import type { PaginatedData } from '@/lib/api'; import type { Conversation, ConversationCreate, - ChatStreamRequest, - SSEEvent, } from '@/types/chat'; export const conversationApi = { @@ -24,60 +22,6 @@ export const conversationApi = { api.delete(`/conversations/${id}`).then(r => r.data), }; -export async function* streamChat( - request: ChatStreamRequest, - signal?: AbortSignal, -): AsyncGenerator { - const response = await fetch('/api/v1/chat/stream', { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify(request), - signal, - }); - - if (!response.ok) { - throw new Error(`Chat stream error: ${response.status}`); - } - - const reader = response.body?.getReader(); - if (!reader) throw new Error('No response body'); - - const decoder = new TextDecoder(); - let buffer = ''; - - try { - while (true) { - const { done, value } = await reader.read(); - if (done) break; - - buffer += decoder.decode(value, { stream: true }); - const lines = buffer.split('\n'); - buffer = lines.pop() ?? 
''; - - let currentEvent = ''; - let currentData = ''; - - for (const line of lines) { - if (line.startsWith('event: ')) { - currentEvent = line.slice(7).trim(); - } else if (line.startsWith('data: ')) { - currentData = line.slice(6); - } else if (line === '' && currentEvent && currentData) { - try { - yield { event: currentEvent, data: JSON.parse(currentData) }; - } catch { - yield { event: currentEvent, data: { raw: currentData } }; - } - currentEvent = ''; - currentData = ''; - } - } - } - } finally { - reader.releaseLock(); - } -} - export const settingsApi = { get: () => api.get>('/settings').then(r => r.data), diff --git a/frontend/src/types/chat.ts b/frontend/src/types/chat.ts index fc1ea66..bf84acb 100644 --- a/frontend/src/types/chat.ts +++ b/frontend/src/types/chat.ts @@ -1,3 +1,9 @@ +import type { UIMessage, UIMessagePart } from 'ai'; + +// --------------------------------------------------------------------------- +// Domain types +// --------------------------------------------------------------------------- + export interface Conversation { id: number; title: string; @@ -80,3 +86,59 @@ export interface ChatStreamRequest { model?: string; tool_mode?: string; } + +// --------------------------------------------------------------------------- +// AI SDK 5.0 data part types (maps to backend data-* stream events) +// --------------------------------------------------------------------------- + +export interface ThinkingData { + step: string; + label: string; + status: 'running' | 'done' | 'error'; + detail?: string; + duration_ms?: number; + summary?: string; +} + +export interface ConversationData { + conversation_id: number; +} + +export type OmeletteDataParts = { + citation: Citation; + thinking: ThinkingData; + conversation: ConversationData; +}; + +export type OmeletteUIMessage = UIMessage; +export type OmelettePart = UIMessagePart>; + +// --------------------------------------------------------------------------- +// Part extraction helpers +// --------------------------------------------------------------------------- + +export function getCitations(message: OmeletteUIMessage): Citation[] { + return message.parts + .filter((p): p is { type: 'data-citation'; id?: string; data: Citation } => p.type === 'data-citation') + .map((p) => p.data); +} + +export function getThinkingSteps(message: OmeletteUIMessage): ThinkingData[] { + return message.parts + .filter((p): p is { type: 'data-thinking'; id?: string; data: ThinkingData } => p.type === 'data-thinking') + .map((p) => p.data); +} + +export function getConversationId(message: OmeletteUIMessage): number | undefined { + const part = message.parts.find( + (p): p is { type: 'data-conversation'; id?: string; data: ConversationData } => p.type === 'data-conversation', + ); + return part?.data.conversation_id; +} + +export function getMessageText(message: OmeletteUIMessage): string { + return message.parts + .filter((p): p is { type: 'text'; text: string } => p.type === 'text') + .map((p) => p.text) + .join(''); +}