forked from IammSwanand/Inscribe.AI
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsearch.py
More file actions
107 lines (87 loc) · 3.57 KB
/
search.py
File metadata and controls
107 lines (87 loc) · 3.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# search.py
import os
import chromadb
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.chains import RetrievalQA
from langchain_groq import ChatGroq
from langchain.prompts import PromptTemplate
from langchain.retrievers import MultiQueryRetriever
# 👈 NEW IMPORTS for Compression
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
from dotenv import load_dotenv
# Load variables from a local .env file so the os.getenv calls below see them.
load_dotenv()
# Directory where the Chroma vector store is persisted (overridable via env var).
CHROMA_DIR = os.getenv("CHROMA_PERSIST_DIR", "./chroma_db")
# API key for the Groq LLM backend; expected to be provided via the environment.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
# Name of the Chroma collection holding the embedded documents.
COLLECTION_NAME = "legal_docs"
# Embeddings
# Sentence-transformer embedding model, instantiated once at import time and
# shared by every chain built in this module.
hf = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
def get_retrieval_qa(model_name="llama-3.1-8b-instant"):
    """Build and return a RetrievalQA chain backed by a Groq LLM.

    Pipeline: Chroma vector store -> MultiQueryRetriever (LLM-generated
    sub-queries) -> ContextualCompressionRetriever (LLM extracts only the
    relevant passages from each retrieved chunk) -> "stuff" QA chain with a
    strict, citation-forcing prompt.

    Args:
        model_name: Groq model identifier used both for retrieval-side work
            (query expansion, compression) and for answering.

    Returns:
        A RetrievalQA chain configured to also return source documents.
    """
    client = chromadb.PersistentClient(path=CHROMA_DIR)
    # Ensure the collection exists up front; the handle itself is not needed
    # because the Chroma wrapper below addresses the collection by name.
    client.get_or_create_collection(name=COLLECTION_NAME)
    vectordb = Chroma(
        persist_directory=CHROMA_DIR,
        embedding_function=hf,
        collection_name=COLLECTION_NAME,
        client=client,
    )
    # temperature=0 keeps answers deterministic/extractive rather than creative.
    llm = ChatGroq(
        model=model_name,
        groq_api_key=GROQ_API_KEY,
        temperature=0,
        max_tokens=1024,
    )
    # 1. Base retriever: raw similarity search. k=10 over-fetches on purpose so
    # the downstream compressor has enough candidate context to filter.
    base_retriever = vectordb.as_retriever(search_kwargs={"k": 10})
    # 2. Query expansion: the LLM rewrites the question into several
    # sub-queries and the union of their hits is returned.
    mq_retriever = MultiQueryRetriever.from_llm(
        retriever=base_retriever,
        llm=llm,
    )
    # 3. Contextual compression: the LLM extracts only the passages from each
    # retrieved chunk that are actually relevant to the question.
    compressor = LLMChainExtractor.from_llm(llm)
    # 4. Final retriever: MultiQuery runs first, then every chunk passes
    # through the compressor.
    final_retriever = ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=mq_retriever,
    )
    # Strict, structured prompt: forbids out-of-context answers and forces
    # inline [source_file, page N] citations.
    CUSTOM_PROMPT_TEMPLATE = """
You are a highly efficient legal assistant. Use ONLY the context below to answer the question.
If the context does not contain the answer, reply: "Not found in the documents."
Cite the source inline by referencing the source_file and page number in square brackets,
right after the relevant sentence (example: [contract.docx, page 2]).
Format the response in a clear and structured way with headings and bullet points.
---------------------
Context:
{context}
---------------------
Question:
{question}
Answer (use only the context above):
"""
    CUSTOM_PROMPT = PromptTemplate(
        template=CUSTOM_PROMPT_TEMPLATE,
        input_variables=["context", "question"],
    )
    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=final_retriever,  # the compressed, multi-query retriever
        return_source_documents=True,
        chain_type_kwargs={"prompt": CUSTOM_PROMPT},
    )
    return qa
def answer_query(query: str):
    """Answer a natural-language query against the indexed documents.

    Builds a RetrievalQA chain (a fresh one per call), runs the query, and
    wraps the LLM's answer under a Markdown heading.

    Args:
        query: Question to answer from the document store.

    Returns:
        Dict with a single "result" key holding the Markdown-formatted answer.
    """
    qa = get_retrieval_qa()
    result = qa.invoke({"query": query})
    structured_answer = f"### 📄 Answer\n{result['result']}\n\n"
    return {"result": structured_answer}