-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrag.py
More file actions
132 lines (101 loc) · 5.04 KB
/
rag.py
File metadata and controls
132 lines (101 loc) · 5.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from dotenv import load_dotenv
# Load API keys etc. from .env BEFORE the LangChain/OpenAI imports below, so
# they see the variables; override=True lets .env win over the inherited env.
load_dotenv(dotenv_path=".env", override=True)
from langsmith import traceable
# NOTE(review): StrOutputParser, Document, List, Optional and TypedDict are
# imported but unused in this file — candidates for removal.
from langchain_core.output_parsers import StrOutputParser
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_openai import ChatOpenAI
from langchain.schema import Document
from typing import List, Optional
from typing_extensions import TypedDict
from pydantic import BaseModel, Field
# Project-local vector-store retriever (see datastore.py).
from datastore import retriever

# One shared chat model for the whole module; temperature=0 for
# deterministic, repeatable answers.
llm = ChatOpenAI(model_name="gpt-4o", temperature=0)
# ========================================================================================================================
# CORE RAG FUNCTIONS
# ========================================================================================================================
@traceable(run_type="retriever")
def retrieve_documents(question: str) -> list:
    """Fetch documents relevant to *question* from the vector datastore."""
    print("Retrieving documents...\n")
    # Delegate the similarity search entirely to the configured retriever.
    return retriever.invoke(question)
@traceable
def generate_response(question: str, documents: list):
    """Ask the LLM to answer *question*, grounded in the retrieved *documents*.

    Returns the raw model message (callers read ``.content``).
    """
    print("Reviewing documents...\n")
    rag_prompt = """You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question.
If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise.
Question: {question}
Context: {context}
Answer:"""
    # Join all retrieved chunks into a single context string.
    context_text = "\n\n".join(d.page_content for d in documents)
    system_text = rag_prompt.format(context=context_text, question=question)
    messages = [SystemMessage(content=system_text), HumanMessage(content=question)]
    return llm.invoke(messages)
class GradeDocuments(BaseModel):
    """Binary score for relevance check on retrieved documents."""
    # The description below is part of the structured-output schema shown to
    # the LLM, so it doubles as the grading instruction for this field.
    binary_score: str = Field(
        description="Documents are relevant to the question, 'yes' or 'no'"
    )
# ========================================================================================================================
# GUARDRAILS AND REFLECTIONS
# ========================================================================================================================
@traceable
def grade_documents(question: str, documents: list) -> list:
    """Filter *documents* down to those an LLM grader deems relevant to *question*.

    Each document is graded independently via a structured-output call that
    yields a binary 'yes'/'no' verdict (see GradeDocuments); only documents
    graded 'yes' are returned. May return an empty list.
    """
    grade_documents_llm = llm.with_structured_output(GradeDocuments)
    grade_documents_system_prompt = """You are a grader assessing relevance of a retrieved document to a user question. \n
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question."""
    grade_documents_prompt = "Here is the retrieved document: \n\n {document} \n\n Here is the user question: \n\n {question}"
    # Hoisted: the system message is identical for every document, so build it once.
    system_message = SystemMessage(content=grade_documents_system_prompt)
    filtered_docs = []
    for d in documents:
        prompt_formatted = grade_documents_prompt.format(document=d.page_content, question=question)
        score = grade_documents_llm.invoke(
            [system_message, HumanMessage(content=prompt_formatted)]
        )
        # Normalize case/whitespace so an off-spec 'Yes' or ' yes ' from the
        # model still counts as relevant (previously a strict == "yes" check).
        if score.binary_score.strip().lower() == "yes":
            filtered_docs.append(d)
    return filtered_docs
@traceable
def decide_to_generate(filtered_documents: list) -> bool:
    """Return True when at least one document survived relevance grading.

    An empty (or None) list means every retrieval was filtered out, so the
    caller should skip generation and report that nothing relevant was found.
    """
    # Truthiness handles both [] and None; the original if/else ladder
    # returning literal booleans was redundant.
    return bool(filtered_documents)
# ========================================================================================================================
# COMPILED RAG APPLICATION
# ========================================================================================================================
@traceable
def rag(question: str):
    """End-to-end RAG pipeline: retrieve, grade for relevance, then generate.

    Returns a dict with a single 'answer' key; when no retrieved document
    passes the relevance guardrail, the answer is a fixed fallback message.
    """
    docs = retrieve_documents(question)
    relevant = grade_documents(question, docs)
    if decide_to_generate(relevant):
        return {"answer": generate_response(question, relevant).content}
    return {"answer": "No relevant documents found. Try a different query."}
def run():
    """Interactive loop: read a question, run the RAG pipeline, print the answer."""
    print("\nHi! I'm a basic RAG chatbot.\nAsk a question to receive a RAG result. Note: I do not remember previous messages, so ask one question at a time!\n")
    while True:
        user = input('User (q to quit): ')
        # 'q' or 'Q' ends the session; anything else is treated as a question.
        if user in {'q', 'Q'}:
            print('Goodbye!')
            break
        result = rag(user)
        print(result["answer"])
# Script entry point: start the interactive chatbot loop.
if __name__ == "__main__":
    run()