diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000..d29023f
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,3 @@
+vectorstore/* filter=lfs diff=lfs merge=lfs -text
+vectorstore/*.sqlite filter=lfs diff=lfs merge=lfs -text
+vectorstore/**/*.bin filter=lfs diff=lfs merge=lfs -text
diff --git a/evaluation/evaluation.py b/evaluation/evaluation.py
new file mode 100644
index 0000000..ad424e9
--- /dev/null
+++ b/evaluation/evaluation.py
@@ -0,0 +1,125 @@
+# This script evaluates a set of question-answer pairs using the RAG chatbot.
+import json
+
+from loguru import logger
+from paper_query.chatbots import HybridQueryChatbot
+from paper_query.constants import METRICS_JSON, assets_dir
+from ragchecker import RAGChecker, RAGResults
+from ragchecker.metrics import all_metrics
+
+QNA_BENCHMARKS = "evaluation/qna_benchmarks.json"
+INTERMEDIATE_RESULTS = "evaluation/qna_benchmarks_answered.json"
+
+
+def answer_queries(qa: dict) -> dict:
+    """Answer test questions in the RAG benchmark using the Chatbot.
+
+    Parameters
+    ----------
+    qa : dict
+        A dictionary containing the benchmark questions and answers.
+
+        {
+            "results": [  # A list of QA pairs
+                {
+                    "query_id": "000",
+                    "query": "This is the question for the first example",
+                    "gt_answer": "This is the ground truth answer for the first example"
+                },
+                {
+                    "query_id": "001",
+                    "query": "This is the question for the second example",
+                    "gt_answer": "This is the ground truth answer for the second example"
+                },
+                ...
+            ]
+        }
+
+    Returns
+    -------
+    dict
+        A dictionary containing the benchmark questions and answers with responses and retrieved
+        contexts.
+
+        {
+            "results": [  # A list of QA pairs with responses
+                {
+                    "query_id": "000",
+                    "query": "This is the question for the first example",
+                    "gt_answer": "This is the ground truth answer for the first example",
+                    "response": "This is the response generated by the chatbot",
+                    "retrieved_context": [
+                        {"doc_id": "doc1", "text": "Content from document 1"},
+                        {"doc_id": "doc2", "text": "Content from document 2"}
+                    ]
+                },
+                ...
+ ] + } + + """ + if not qa or "results" not in qa: + raise ValueError("Input must contain 'results' key") + + chatbot = HybridQueryChatbot( + model_name="gpt-4.1", + model_provider="openai", + paper_path=str(assets_dir / "strainrelief_preprint.pdf"), + references_dir=str(assets_dir / "references"), + ) + + for qa_pair in qa["results"]: + if "query" not in qa_pair.keys() or "gt_answer" not in qa_pair.keys(): + raise ValueError("Each QA pair must contain 'query' and 'gt_answer' keys") + + # Stream the response (consuming all chunks) + for _ in chatbot.stream_response(qa_pair["query"]): + pass + + # Extract response and context from chat history + last_message = chatbot.chat_history[-1] + qa_pair["response"] = last_message.content + qa_pair["retrieved_context"] = _extract_context(last_message) + + return qa + + +def _extract_context(message) -> list[dict]: + """Extract and format retrieved context from chat message.""" + context_data = message.response_metadata.get("context", []) + return [ + {"doc_id": context["Document Title"], "text": context["Content"]} + for context in context_data + if "Document Title" in context and "Content" in context + ] + + +if __name__ == "__main__": + # Load the RAG benchmark Q&As from a JSON file + with open(QNA_BENCHMARKS) as f: + qa = json.load(f) + + # Answer each query using the Chatbot + qa_answered = answer_queries(qa) + + # Save the updated Q&As with responses and retrieved documents + with open(INTERMIDATE_RESULTS, "w") as f: + json.dump(qa_answered, f, indent=4) + + # Initialise RAGResults from the answered Q&As + rag_results = RAGResults.from_dict(qa_answered) + + # Set up the evaluator + evaluator = RAGChecker( + extractor_name="openai/gpt-4.1", + checker_name="openai/gpt-4.1", + batch_size_extractor=32, + batch_size_checker=32, + ) + + # Evaluate results with selected metrics or certain groups + # e.g., retriever_metrics, generator_metrics, all_metrics + evaluator.evaluate(rag_results, all_metrics, save_path=METRICS_JSON) + + logger.info(rag_results) + logger.info(f"Evaluation complete. 
diff --git a/evaluation/qna_benchmarks.json b/evaluation/qna_benchmarks.json
new file mode 100644
index 0000000..d0aad58
--- /dev/null
+++ b/evaluation/qna_benchmarks.json
@@ -0,0 +1,14 @@
+{
+    "results": [
+        {
+            "query_id": "000",
+            "query": "This is the question for the first example",
+            "gt_answer": "This is the ground truth answer for the first example"
+        },
+        {
+            "query_id": "001",
+            "query": "This is the question for the second example",
+            "gt_answer": "This is the ground truth answer for the second example"
+        }
+    ]
+}
diff --git a/requirements.in b/requirements.in
index 2996245..4c600e5 100644
--- a/requirements.in
+++ b/requirements.in
@@ -9,3 +9,4 @@ chromadb
 sentence-transformers
 accelerate
 loguru
+ragchecker
diff --git a/src/paper_query/chatbots/_chatbots.py b/src/paper_query/chatbots/_chatbots.py
index 6110f24..2efdf42 100644
--- a/src/paper_query/chatbots/_chatbots.py
+++ b/src/paper_query/chatbots/_chatbots.py
@@ -31,7 +31,9 @@ def __init__(self, model_name: str, model_provider: str):
         self.model = setup_model(model_name, model_provider)
         self.chain = setup_chain(self.model, prompt=base_prompt)
 
-    def stream_response(self, user_input: str, chain_args: dict = {}) -> Generator[str, None, None]:
+    def stream_response(
+        self, user_input: str, chain_args: dict = {}, metadata: dict | None = None
+    ) -> Generator[str, None, None]:
         """Process user input and stream AI response."""
         # Add user message to history before streaming
         logger.debug(f'User input:\n"{user_input}"')
@@ -45,7 +47,11 @@ def stream_response(self, user_input: str, chain_args: dict = {}) -> Generator[s
             yield chunk
 
         # After streaming is complete, add the full response to chat history
-        self.chat_history.append(AIMessage(content=full_response))
+        self.chat_history.append(
+            AIMessage(
+                content=full_response, response_metadata={"context": metadata} if metadata else {}
+            )
+        )
 
         logger.debug(f'AI response:\n"{full_response}"')
 
@@ -121,6 +127,13 @@ def stream_response(self, user_input: str) -> Generator[str, None, None]:
         relevant_references = "\n".join(
             [f"From {doc.metadata[RAG_DOC_ID]}:\n{doc.page_content}" for doc in relevant_docs]
         )
+        relevant_metadata = [
+            {
+                "Document Title": doc.metadata.get(RAG_DOC_ID, "N/A"),
+                "Content": doc.page_content,
+            }
+            for doc in relevant_docs
+        ]
 
         # Log the context documents
         logger.debug(f"Context: {len(relevant_docs)} documents returned.")
@@ -133,7 +146,9 @@ def stream_response(self, user_input: str) -> Generator[str, None, None]:
         )
 
         return super().stream_response(
-            user_input, {"paper_text": self.paper_text, "relevant_references": relevant_references}
+            user_input,
+            {"paper_text": self.paper_text, "relevant_references": relevant_references},
+            relevant_metadata,
         )
 
 
@@ -189,6 +204,13 @@ def stream_response(self, user_input: str) -> Generator[str, None, None]:
         relevant_code = "\n".join(
             [f"From {doc.metadata[RAG_DOC_ID]}:\n{doc.page_content}" for doc in relevant_docs]
         )
+        relevant_metadata = [
+            {
+                "Document Title": doc.metadata.get(RAG_DOC_ID, "N/A"),
+                "Content": doc.page_content,
+            }
+            for doc in relevant_docs
+        ]
 
         # Log the context documents
         logger.debug(f"Context: {len(relevant_docs)} documents returned.")
@@ -201,7 +223,9 @@ def stream_response(self, user_input: str) -> Generator[str, None, None]:
         )
 
         return super().stream_response(
-            user_input, {"paper_text": self.paper_text, "relevant_code": relevant_code}
+            user_input,
+            {"paper_text": self.paper_text, "relevant_code": relevant_code},
+            relevant_metadata,
         )
 
 
@@ -264,14 +288,22 @@ def stream_response(self, user_input: str) -> Generator[str, None, None]:
         relevant_references = "\n".join(
             [f"From {doc.metadata[RAG_DOC_ID]}:\n{doc.page_content}" for doc in relevant_docs]
         )
+        relevant_metadata = [
+            {
+                "Document Title": doc.metadata.get(RAG_DOC_ID, "N/A"),
+                "Content": doc.page_content,
+            }
+            for doc in relevant_docs
+        ]
 
         # Log the context documents
         logger.debug(f"Context: {len(relevant_docs)} documents returned.")
-        for i, doc in enumerate(relevant_docs, start=1):
-            contents = doc.page_content[:200].replace("\n", " ")
+        for i, doc in enumerate(relevant_metadata, start=1):
+            contents = doc["Content"][:200].replace("\n", " ")
             logger.debug(
-                f"""Context Document {i}:\nDocument Title: {doc.metadata.get(RAG_DOC_ID, "N/A")}
-                Page Content: {contents}...
+                f"""Context Document {i}:
+                Document Title: {doc["Document Title"]}
+                Page Content: {contents}
                 """
             )
 
@@ -281,4 +313,5 @@ def stream_response(self, user_input: str) -> Generator[str, None, None]:
                 "paper_text": self.paper_text,
                 "relevant_references": relevant_references,
             },
+            relevant_metadata,
         )
diff --git a/src/paper_query/constants/__init__.py b/src/paper_query/constants/__init__.py
index 02ab920..7f710b0 100644
--- a/src/paper_query/constants/__init__.py
+++ b/src/paper_query/constants/__init__.py
@@ -1,5 +1,13 @@
 from ._api_keys import GROQ_API_KEY, HUGGINGFACE_API_KEY, OPENAI_API_KEY
-from ._paths import PERSIST_DIRECTORY, assets_dir, data_dir, project_dir, src_dir, test_dir
+from ._paths import (
+    METRICS_JSON,
+    PERSIST_DIRECTORY,
+    assets_dir,
+    data_dir,
+    project_dir,
+    src_dir,
+    test_dir,
+)
 from ._strings import RAG_DOC_ID, STREAMLIT_CHEAP_MODEL, STREAMLIT_EXPENSIVE_MODEL
 
 __all__ = [
@@ -7,6 +15,7 @@
     "HUGGINGFACE_API_KEY",
     "GROQ_API_KEY",
     "PERSIST_DIRECTORY",
+    "METRICS_JSON",
     "project_dir",
     "src_dir",
     "test_dir",
diff --git a/src/paper_query/constants/_paths.py b/src/paper_query/constants/_paths.py
index fce8790..17aa888 100644
--- a/src/paper_query/constants/_paths.py
+++ b/src/paper_query/constants/_paths.py
@@ -8,3 +8,4 @@
 assets_dir: Path = project_dir / "assets"
 
 PERSIST_DIRECTORY: str = str(project_dir / "vectorstore")
+METRICS_JSON: str = str(project_dir / "evaluation" / "rag_evaluation_results.json")
diff --git a/src/paper_query/data/loaders.py b/src/paper_query/data/loaders.py
index c57e19a..8f3b3dc 100644
--- a/src/paper_query/data/loaders.py
+++ b/src/paper_query/data/loaders.py
@@ -10,13 +10,24 @@
 from paper_query.llm import setup_model
 
 
-def pypdf_loader(file_path: str) -> Document:
+def pypdf_loader(file_path: str, interpret_images: bool = False, **image_kwargs) -> Document:
+    """Function to load a PDF file, optionally interpreting images."""
+    if interpret_images and "model" not in image_kwargs:
+        raise ValueError("When interpret_images is True, 'model' must be provided in image_kwargs.")
+
+    if interpret_images:
+        return _pypdf_loader_w_images(file_path, **image_kwargs)
+    else:
+        return _pypdf_loader(file_path)
+
+
+def _pypdf_loader(file_path: str) -> Document:
     """Function to load text from a PDF file."""
     logger.debug("Loading PDF file using PyPDFLoader")
     return PyPDFLoader(file_path, mode="single").load()[0]
 
 
-def pypdf_loader_w_images(
+def _pypdf_loader_w_images(
     file_path: str, model: str, provider: str, max_tokens: int = 1024
 ) -> Document:
     """Function to load text from a PDF file with images."""
diff --git a/src/paper_query/llm/models.py b/src/paper_query/llm/models.py
index f183ea1..b0610ba 100644
--- a/src/paper_query/llm/models.py
+++ b/src/paper_query/llm/models.py
@@ -1,10 +1,11 @@
 import os
 
 from langchain.chat_models import init_chat_model
+from langchain_core.language_models.chat_models import BaseChatModel
 from loguru import logger
 
 
-def setup_model(model_name: str, model_provider: str, **kwargs):
+def setup_model(model_name: str, model_provider: str, **kwargs) -> BaseChatModel:
     """Initialize the chat model."""
     logger.info(f"Initializing {model_name} model from {model_provider}")
     if model_provider == "openai":
diff --git a/src/paper_query/ui/strain_relief_app.py b/src/paper_query/ui/strain_relief_app.py
index c4eead8..2cbe89d 100644
--- a/src/paper_query/ui/strain_relief_app.py
+++ b/src/paper_query/ui/strain_relief_app.py
@@ -30,7 +30,7 @@ def strain_relief_chatbot():
     """Chatbot for the StrainRelief paper."""
     initialize_session_state()
 
-    st.title("The StrainRelief Chatbot")
+    st.title("StrainReliefChat")
 
     chat_tab, about_tab = st.tabs(["Chat", "About"])
     st.sidebar.title("API Configuration")
@@ -46,12 +46,14 @@ def strain_relief_chatbot():
     # Display current model
     st.sidebar.markdown(f"Using **{st.session_state.model_name}** model.")
 
-    st.session_state.chatbot = HybridQueryChatbot(
-        model_name=st.session_state.model_name.lower(),
-        model_provider="openai",
-        paper_path=str(assets_dir / "strainrelief_preprint.pdf"),
-        references_dir=str(assets_dir / "references"),
-    )
+    # Only instantiate chatbot once and store in session state
+    if st.session_state.chatbot is None:
+        st.session_state.chatbot = HybridQueryChatbot(
+            model_name=st.session_state.model_name.lower(),
+            model_provider="openai",
+            paper_path=str(assets_dir / "strainrelief_preprint.pdf"),
+            references_dir=str(assets_dir / "references"),
+        )
 
     with chat_tab:
         if "messages" not in st.session_state:
diff --git a/test/data/test_loaders.py b/test/data/test_loaders.py
index 32bbef6..f3c3f73 100644
--- a/test/data/test_loaders.py
+++ b/test/data/test_loaders.py
@@ -5,7 +5,6 @@
 from paper_query.data.loaders import (
     code_loader,
     pypdf_loader,
-    pypdf_loader_w_images,
     references_loader,
 )
 
@@ -13,7 +12,7 @@
 def test_pypdf_loader(test_assets_dir):
     """Test the pypdf_loader function."""
     path = test_assets_dir / "example_pdf.pdf"
-    doc = pypdf_loader(path)
+    doc = pypdf_loader(path, interpret_images=False)
 
     assert isinstance(doc, Document)
@@ -22,7 +21,7 @@
 def test_pypdf_loader_w_images(test_assets_dir):
     """Test the pypdf_loader_w_images function."""
     path = test_assets_dir / "example_pdf.pdf"
     # TODO: change to free model
-    doc = pypdf_loader_w_images(path, "gpt-4.1-nano", "openai")
+    doc = pypdf_loader(path, interpret_images=True, model="gpt-4.1-nano", provider="openai")
 
     assert isinstance(doc, Document)
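
Usage sketch (illustrative, not part of the patch): assuming the script is run from the repository root and that RAGChecker writes plain JSON to the save_path passed above, the new evaluation flow could be exercised and inspected roughly as follows.

    # Run the benchmark end to end: reads evaluation/qna_benchmarks.json,
    # writes evaluation/qna_benchmarks_answered.json and the metrics file.
    #   $ python evaluation/evaluation.py

    # Then load the saved metrics for a quick look (path comes from the new constant).
    import json

    from paper_query.constants import METRICS_JSON

    with open(METRICS_JSON) as f:
        print(json.dumps(json.load(f), indent=2))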