From 02ed7ea4086e5afcd0dfcbd28dbaabc5dc3fbc29 Mon Sep 17 00:00:00 2001 From: Mehdi Abdi Date: Thu, 20 Feb 2025 03:12:47 +0330 Subject: [PATCH] feat: added llama2:7b --- Dockerfile | 8 ++- README.md | 13 +++- app/agents/educator/agent.py | 83 ++++++++++++++++++++------ app/agents/educator/feed.py | 7 ++- app/config.py | 4 ++ app/routers/api.py | 2 +- app/services/classification_service.py | 15 +---- 7 files changed, 96 insertions(+), 36 deletions(-) diff --git a/Dockerfile b/Dockerfile index 0eaf989..d47a716 100644 --- a/Dockerfile +++ b/Dockerfile @@ -10,7 +10,13 @@ ENV PYTHONUNBUFFERED=1 # Install system dependencies RUN apt-get update \ - && apt-get install -y --no-install-recommends gcc \ + && apt-get install -y --no-install-recommends \ + gcc \ + make \ + cmake \ + build-essential \ + libmupdf-dev \ + tesseract-ocr \ && apt-get clean \ && rm -rf /var/lib/apt/lists/* diff --git a/README.md b/README.md index 5357942..a5e8522 100644 --- a/README.md +++ b/README.md @@ -2,12 +2,19 @@ ## install dev dependencies -`pip install -r requirements-dev.txt` -`pre-commit install` +```shell + pip install -r requirements-dev.txt +``` + +```shell + pre-commit install +``` ## run project -`docker compose up [--build]` +```shell + docker compose up [--build] +``` ## Optional tools Install [Rest Client](https://marketplace.visualstudio.com/items?itemName=humao.rest-client) extension to be able to send requests in the `requests.http` file. 
diff --git a/app/agents/educator/agent.py b/app/agents/educator/agent.py index 1352ec8..ba71024 100644 --- a/app/agents/educator/agent.py +++ b/app/agents/educator/agent.py @@ -1,24 +1,79 @@ # app/agents/educator_agent.py +import json + +import requests from langchain.agents import AgentType, initialize_agent from langchain.chains import RetrievalQA +from langchain.llms.base import LLM from langchain.tools import Tool from langchain_community.utilities import DuckDuckGoSearchAPIWrapper from langchain_community.vectorstores import FAISS -from langchain_openai import ChatOpenAI, OpenAIEmbeddings +from langchain_openai import OpenAIEmbeddings from app.config import Config +class OllamaLLM(LLM): + """ + A custom LangChain LLM wrapper for the Ollama API. + """ + + model_name: str = "llama2:7b" + temperature: float = 0.0 + max_tokens: int = 256 + + def __init__( + self, + model_name: str = None, + temperature: float = 0.0, + max_tokens: int = 256, + **kwargs, + ): + super().__init__(**kwargs) # Initialize Pydantic fields + if model_name is not None: + self.model_name = model_name + self.temperature = temperature + self.max_tokens = max_tokens + + @property + def _llm_type(self) -> str: + return "ollama" + + def _call(self, prompt: str, stop: list[str] = None) -> str: + payload = { + "model": self.model_name, + "prompt": prompt, + "stream": False, + "options": { + "temperature": self.temperature, + "num_predict": self.max_tokens, + }, + } + url = "http://localhost:11434/api/generate" + response = requests.post(url, json=payload, timeout=120) + response.raise_for_status() + + try: + result = response.json() + except json.decoder.JSONDecodeError: + # Fallback: Assume the first line is valid JSON. + text = response.text.strip() + first_line = text.splitlines()[0] + result = json.loads(first_line) + + # Adjust response parsing based on your Ollama API's output format.
+ if "message" in result and "content" in result["message"]: + return result["message"]["content"] + return result.get("response", "") + + class BlockchainEducatorAgent: def __init__(self): - # initialize qa tool - self.llm = ChatOpenAI( - openai_api_key=Config.OPENAI_API_KEY, - openai_api_base=Config.OPENAI_BASE_URL, - model="gpt-4", - temperature=0, - ) + # Replace ChatOpenAI with our OllamaLLM instance + self.llm = OllamaLLM(model_name="llama2:7b", temperature=0, max_tokens=256) + + # Load the vector store as before self.vector_store = FAISS.load_local( "faiss_index", OpenAIEmbeddings(), allow_dangerous_deserialization=True ) @@ -32,7 +86,7 @@ def __init__(self): description="Useful for answering questions about blockchain concepts", ) - # Initialize search tool + # Initialize the web search tool using DuckDuckGo search = DuckDuckGoSearchAPIWrapper() self.search_tool = Tool( name="Web Search", @@ -40,9 +94,8 @@ def __init__(self): description="Useful for searching the internet for recent or additional information", ) - # Initialize agent with both QA and search capabilities + # Combine both tools into the agent tools = [self.qa_tool, self.search_tool] - self.agent = initialize_agent( tools=tools, llm=self.llm, @@ -52,19 +105,15 @@ def __init__(self): def handle_query(self, query: str, chat_history: list = None) -> str: try: - # Format chat history into context if available + # Format chat history as context if available context = "" if chat_history and len(chat_history) > 0: context = "Previous conversation:\n" - for interaction in chat_history[ - -3: - ]: # Use last 3 interactions for context + for interaction in chat_history[-3:]: context += f"Human: {interaction['query']}\nAssistant: {interaction['response']}\n" context += "\nCurrent question: " - # Combine context and query full_query = f"{context}{query}" if context else query - response = self.agent.invoke(full_query) return response.get("output") except Exception as e: diff --git
a/app/agents/educator/feed.py b/app/agents/educator/feed.py index b82de61..8a38b7b 100644 --- a/app/agents/educator/feed.py +++ b/app/agents/educator/feed.py @@ -3,9 +3,9 @@ import fitz from dotenv import load_dotenv +from langchain_community.embeddings import HuggingFaceEmbeddings from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_community.vectorstores import FAISS -from langchain_openai import OpenAIEmbeddings load_dotenv() @@ -40,7 +40,10 @@ def load_and_chunk_pdfs(pdf_files): documents = load_and_chunk_pdfs(pdf_files) print(f"Loaded {len(documents)} documents") -embedding_model = OpenAIEmbeddings(api_key=os.getenv("OPENAI_API_KEY")) +# Use a local Hugging Face embedding model for generating embeddings +embedding_model = HuggingFaceEmbeddings( + model_name="sentence-transformers/all-MiniLM-L6-v2" +) print("Embedding documents...") vector_db = FAISS.from_texts(documents, embedding_model) print("Vector database created") diff --git a/app/config.py b/app/config.py index f976906..da2a22c 100644 --- a/app/config.py +++ b/app/config.py @@ -1,5 +1,9 @@ import os +from dotenv import load_dotenv + +load_dotenv() + class Config: # OpenAI configuration diff --git a/app/routers/api.py b/app/routers/api.py index f724c47..0f3c734 100644 --- a/app/routers/api.py +++ b/app/routers/api.py @@ -25,7 +25,7 @@ async def process_query(query: Query, request: Request): classification = classifier.classify_query(query.query, chat_history).lower() agent = agents.get(classification) if not agent: - raise HTTPException(status_code=400, detail="Could not classify the query") + raise HTTPException(status_code=400, detail="Could not classify the query") # Pass chat history to agent response = agent.handle_query(query.query, chat_history) diff --git a/app/services/classification_service.py b/app/services/classification_service.py index a4269e5..56d4e6c 100644 --- a/app/services/classification_service.py +++ b/app/services/classification_service.py @@ -2,19 +2,13 @@ from
typing import Dict, List -from langchain_openai import ChatOpenAI - +from app.agents.educator.agent import OllamaLLM from app.config import Config class ClassificationService: def __init__(self): - self.llm = ChatOpenAI( - openai_api_key=Config.OPENAI_API_KEY, - openai_api_base=Config.OPENAI_BASE_URL, - model="gpt-4", - temperature=0, - ) + self.llm = OllamaLLM(model_name="llama2:7b", temperature=0, max_tokens=256) def _format_chat_context(self, chat_history: List[Dict]) -> str: """Format chat history into a context string for classification.""" @@ -39,8 +33,5 @@ def classify_query(self, query: str, chat_history: List[Dict] = None) -> str: f"Query: {query}\nCategory:" f"Just return the category, no other text." ) - response = self.llm.invoke( - input=prompt, - ) - classification = response.content + classification = self.llm.invoke(prompt)