355 changes: 53 additions & 302 deletions backend/app/api/v1/chat.py

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions backend/app/pipelines/chat/__init__.py
@@ -0,0 +1,5 @@
"""Chat pipeline — LangGraph StateGraph for streaming chat with RAG."""

from app.pipelines.chat.graph import create_chat_pipeline

__all__ = ["create_chat_pipeline"]
49 changes: 49 additions & 0 deletions backend/app/pipelines/chat/config_helpers.py
@@ -0,0 +1,49 @@
"""Helpers for accessing request-scoped services from LangGraph config.

LangGraph shallow-copies the config dict for each node, but nested mutable
objects remain shared. We store a ``_services`` dict in ``configurable``
so that ``understand_node`` can create ``llm``/``rag`` once and downstream
nodes read them through the same reference.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from langchain_core.runnables import RunnableConfig
    from sqlalchemy.ext.asyncio import AsyncSession

    from app.services.llm.client import LLMClient
    from app.services.rag_service import RAGService


def _services(config: RunnableConfig) -> dict[str, Any]:
    """Return the shared mutable services dict, creating it if needed."""
    cfg = config["configurable"]
    if "_services" not in cfg:
        cfg["_services"] = {}
    return cfg["_services"]


def get_chat_db(config: RunnableConfig) -> AsyncSession:
    return config["configurable"]["db"]


def get_chat_llm(config: RunnableConfig) -> LLMClient:
    return _services(config)["llm"]


def get_chat_rag(config: RunnableConfig) -> RAGService:
    return _services(config)["rag"]


def set_chat_services(config: RunnableConfig, *, llm: Any, rag: Any) -> None:
    """Called by understand_node to share LLM and RAG with downstream nodes."""
    svc = _services(config)
    svc["llm"] = llm
    svc["rag"] = rag


def get_configurable(config: RunnableConfig) -> dict[str, Any]:
    return config["configurable"]
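A usage sketch of the sharing pattern (the node bodies below are hypothetical illustrations, not this PR's real nodes; only the `config_helpers` functions are from the diff, and `object()` stands in for the unshown `LLMClient`/`RAGService` construction):

```python
# Hypothetical node bodies illustrating the _services sharing pattern;
# only the config_helpers imports are from this PR.
from typing import Any

from langchain_core.runnables import RunnableConfig

from app.pipelines.chat.config_helpers import get_chat_llm, set_chat_services


async def understand_node(state: dict[str, Any], config: RunnableConfig) -> dict[str, Any]:
    # Build the request-scoped services once and publish them. object() is a
    # stand-in; the PR's real LLMClient/RAGService construction is not shown.
    llm, rag = object(), object()
    set_chat_services(config, llm=llm, rag=rag)
    return {}


async def generate_node(state: dict[str, Any], config: RunnableConfig) -> dict[str, Any]:
    # LangGraph shallow-copies ``config`` per node, but the nested
    # ``_services`` dict stays shared, so this reads back the very objects
    # that understand_node created.
    llm = get_chat_llm(config)
    assert llm is not None
    return {}
```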
57 changes: 57 additions & 0 deletions backend/app/pipelines/chat/graph.py
@@ -0,0 +1,57 @@
"""Chat pipeline graph definition.

Flow:
    understand → [has KB?] → retrieve → rank → clean → generate → persist
                     └─ no KB ────────────────────────→ generate → persist
"""

from __future__ import annotations

from langgraph.graph import END, StateGraph

from app.pipelines.chat.nodes import (
    clean_node,
    generate_node,
    persist_node,
    rank_node,
    retrieve_node,
    understand_node,
)
from app.pipelines.chat.state import ChatState


def _route_after_understand(state: ChatState) -> str:
    """Skip RAG nodes when no knowledge bases are selected."""
    kb_ids = state.get("knowledge_base_ids", [])
    if kb_ids:
        return "retrieve"
    return "generate"


def create_chat_pipeline():
    """Compile the chat StateGraph.

    No checkpointer — chat streams are stateless one-shot invocations.
    """
    graph = StateGraph(ChatState)

    graph.add_node("understand", understand_node)
    graph.add_node("retrieve", retrieve_node)
    graph.add_node("rank", rank_node)
    graph.add_node("clean", clean_node)
    graph.add_node("generate", generate_node)
    graph.add_node("persist", persist_node)

    graph.set_entry_point("understand")
    graph.add_conditional_edges(
        "understand",
        _route_after_understand,
        {"retrieve": "retrieve", "generate": "generate"},
    )
    graph.add_edge("retrieve", "rank")
    graph.add_edge("rank", "clean")
    graph.add_edge("clean", "generate")
    graph.add_edge("generate", "persist")
    graph.add_edge("persist", END)

    return graph.compile()
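A minimal invocation sketch to go with the compiled graph. The rewritten backend/app/api/v1/chat.py is not rendered above, so the state keys (`message`) and the streaming event handling here are assumptions, not the PR's actual endpoint code; `knowledge_base_ids` and the `configurable["db"]` slot come from the diff itself:

```python
# Hypothetical caller; the real chat.py diff is not rendered above, so the
# state/config shapes here are assumptions labeled as such.
from app.pipelines.chat import create_chat_pipeline

pipeline = create_chat_pipeline()  # compiled once, reused across requests


async def stream_chat(db, message: str, kb_ids: list[str]):
    state = {"message": message, "knowledge_base_ids": kb_ids}
    # ``db`` goes into configurable so get_chat_db() can read it in any node.
    config = {"configurable": {"db": db}}
    # astream_events is LangGraph's standard streaming API; the event name
    # assumes generate_node streams through a LangChain chat model.
    async for event in pipeline.astream_events(state, config, version="v2"):
        if event["event"] == "on_chat_model_stream":
            yield event["data"]["chunk"].content
```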