Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions .github/workflows/check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@ jobs:
- uses: actions/checkout@v4
- uses: ./.github/actions/setup
- run: poetry install --with checks
- run: poetry run invoke checks.format
- run: poetry run invoke checks.code
- run: poetry run invoke checks.type
- run: poetry run invoke checks.security

- name: Run Ruff
run: poetry run ruff check .

- name: Run Mypy
run: poetry run mypy src/

- name: Run Tests
run: poetry run pytest
98 changes: 89 additions & 9 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

27 changes: 27 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,33 @@ packages = [
{ include = "cp_genie", from = "src" }
]

# Ruff
[tool.ruff]
fix = true
indent-width = 4
line-length = 100
target-version = "py312"

[tool.ruff.format]
docstring-code-format = true

[tool.ruff.lint.pydocstyle]
convention = "google"

[tool.ruff.lint.per-file-ignores]
"tests/*.py" = ["D100", "D103"]


# Checks
[tool.poetry.group.checks]
optional = true

[tool.poetry.group.checks.dependencies]
ruff = "^0.4.0"
mypy = "^1.8.0"
pytest = "^8.1.1"


[tool.poetry.scripts]
start = "run:main"

Expand Down
3 changes: 2 additions & 1 deletion src/cp_genie/api/v1/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
from fastapi import Request

from cp_genie.domain.rag.normal import NormalRAG
from cp_genie.domain.rag.agentic import AgenticRAG
from cp_genie.api.v1.schema import ChatRequest, ChatResponse
from cp_genie.infrastructure.chat_memory import get_by_session_id

router = APIRouter(tags=["chat"])

RAG_CLASSES = {
"normal": NormalRAG,
# "agentic": AgenticRAG,
"agentic": AgenticRAG,
}


Expand Down
122 changes: 78 additions & 44 deletions src/cp_genie/domain/rag/agentic.py
Original file line number Diff line number Diff line change
@@ -1,75 +1,109 @@
from langgraph.graph import StateGraph, END
from langchain_core.runnables import RunnableLambda
from langchain.prompts import ChatPromptTemplate
from langchain.chains.combine_documents import create_stuff_documents_chain
from cp_genie.domain.rag.base import State
from langchain.agents import Tool, AgentExecutor, create_tool_calling_agent
from langchain_core.messages import HumanMessage, ToolMessage
from langchain_core.tools import tool
from langgraph.prebuilt import ToolNode, tools_condition
from typing import List
from langchain_core.documents import Document


class AgenticRAG:
def __init__(self, llm, tools, memory):
def __init__(self, llm, retriever, memory):
self.llm = llm
self.tools = tools # list of Tool objects
self.retriever = retriever
self.memory = memory
self.chain = self._build_graph()

def _build_graph(self) -> StateGraph:
def query(state) -> State:
self.memory.add_user_message(state["question"])
llm_with_tools = self.llm.bind_tools(self.tools)
response = llm_with_tools.invoke(state["messages"])
return {"messages": state["messages"] + [response]}

def retrieve(state) -> State:
tool_inputs = {"input": state["question"]}
tool_result = agent_executor.invoke(tool_inputs)
return {**state, "context": tool_result.get("output", [])}

def generate(state) -> State:
result = combine_chain.invoke(
{
"messages": state["messages"],
"context": state["context"],
"question": state["question"],
}
)
self.memory.add_user_message(state["question"])
self.memory.add_ai_message(result)
return {**state, "output": result}
@tool
def retrieve(query: str) -> List[Document]:
"""
Retrieves information related to the input query
from a vector database containing information
on Computer Engineering at Chulalongkorn University.
Use this tool ONLY when the user asks a question that requires specific
knowledge about the university or department. Do not use for general conversation.
"""
docs = self.retriever.invoke(query)
return docs

tools = [retrieve]
tool_node = ToolNode(tools)
llm_with_tools = self.llm.bind_tools(tools)

prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"Use the available tools to retrieve useful context and answer as concisely as possible.",
),
("system", "Answer as concisely as possible."),
(
"human",
"chat history: {messages}\nretrieved context: {context}\nquestion: {question}",
"chat history: {messages}\nretrieved context: {context}\n",
),
]
)
combine_chain = create_stuff_documents_chain(self.llm, prompt)

# Set up an agent to choose tools
agent = create_tool_calling_agent(self.llm, self.tools)
agent_executor = AgentExecutor(agent=agent, tools=self.tools, verbose=True)
def query_or_respond(state: State) -> dict:
print("--- Agent Node: Query or Respond ---")
messages = state["messages"]
response = llm_with_tools.invoke(messages)
return {"messages": [response]}

def generate(state: State) -> dict:
print("--- Generate Node ---")
messages = state["messages"]
last_message = messages[-1]

if not isinstance(last_message, ToolMessage):
raise ValueError(
"Last message is not a ToolMessage. Generation node expects tool output."
)

retrieved_docs = last_message.content
if not isinstance(retrieved_docs, list) or not all(
isinstance(doc, Document) for doc in retrieved_docs
):
if isinstance(retrieved_docs, str):
retrieved_docs = [Document(page_content=retrieved_docs)]
else:
retrieved_docs = [Document(page_content=str(retrieved_docs))]

generation = combine_chain.invoke(
{
"messages": messages,
"context": retrieved_docs,
}
)

self.memory.add_ai_message(generation)
return {"messages": [generation]}

graph = StateGraph(State)
graph.add_node("retrieve", RunnableLambda(retrieve))
graph.add_node("generate", RunnableLambda(generate))
graph.add_node("agent", query_or_respond)
graph.add_node("tools", tool_node)
graph.add_node("generate", generate)

graph.set_entry_point("agent")

graph.add_conditional_edges(
"agent",
tools_condition,
{
"tools": "tools",
END: END,
},
)

graph.set_entry_point("retrieve")
graph.add_edge("retrieve", "generate")
graph.add_edge("tools", "generate")
graph.add_edge("generate", END)

return graph.compile()

def invoke(self, input: dict) -> State:
state: State = {
"messages": self.memory.messages,
"question": input["input"],
def invoke(self, input) -> State:
self.memory.add_user_message(HumanMessage(content=input))
initial_state: State = {
"messages": self.memory.get_messages(),
"context": [],
"output": "",
}
return self.chain.invoke(state)
return self.chain.invoke(initial_state)
Loading
Loading