diff --git a/src/ai.py b/src/ai.py index 5b3e0e5..f98bf59 100644 --- a/src/ai.py +++ b/src/ai.py @@ -55,7 +55,6 @@ def __init__( self.SummaryCompleted = False self.SetTemplates(PromptTemplate, SummaryPromptTemplate) - self.ManageOllama() def SetTemplates(self, PromptTemplate, SummaryPromptTemplate): """ @@ -101,21 +100,35 @@ def SetRepoPath(self, Path): """ self.RepoPath = Path - @unittest.skip("Not needed for test.") - def ManageOllama(self): - """ - Manage Ollama server and model availability. + def CheckIfModelAvailability(self): + """Check if the specified model exists and pull if necessary.""" + try: + Result = subprocess.run( + ["ollama", "list"], + shell=True, + capture_output=True, + text=True, + encoding="utf-8" + ) # nosec - Ensures the Ollama server is running and the required model - is available. If the server is not running, it attempts to start it. - If the model is not available, it downloads the specified model. - """ - OllamaPath = shutil.which("ollama") - if not OllamaPath: - print("Ollama executable not found. Please install Ollama.") + if self.ModelName not in Result.stdout: + print(f"Model '{self.ModelName}' not found. Downloading...") + subprocess.run( + ["ollama", "pull", self.ModelName], + shell=True, + capture_output=True, + text=True, + encoding="utf-8" + ) # nosec + print(f"Model '{self.ModelName}' downloaded successfully.") + else: + print(f"Model '{self.ModelName}' already exists.") + except Exception as E: + print(f"Failed to check/download model '{self.ModelName}': {E}") exit(1) - # Check if Ollama server is running + def CheckModelStatus(self): + """Check if Ollama server is running.""" try: Response = requests.get("http://localhost:11434/health", timeout=5) if Response.status_code == 200: @@ -126,36 +139,37 @@ def ManageOllama(self): print("Ollama server not running. Attempting to start...") try: subprocess.Popen( - [OllamaPath, "serve"], + ["ollama", "stop"], stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL + stderr=subprocess.DEVNULL, + shell=True, + text=True, + encoding="utf-8" + ) # nosec + subprocess.Popen( + ["ollama", "serve"], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + shell=True, + text=True, + encoding="utf-8" ) # nosec print("Ollama server started successfully.") except Exception as E: print(f"Failed to start Ollama server: {E}") exit(1) - # Check if the specified model exists and pull if necessary - try: - Result = subprocess.run( - [OllamaPath, "list"], - capture_output=True, - text=True - ) # nosec + @unittest.skip("Not needed for test.") + def ManageOllama(self): + """ + Manage Ollama server and model availability. - if self.ModelName not in Result.stdout: - print(f"Model '{self.ModelName}' not found. Downloading...") - subprocess.run( - [OllamaPath, "pull", self.ModelName], - capture_output=True, - text=True - ) # nosec - print(f"Model '{self.ModelName}' downloaded successfully.") - else: - print(f"Model '{self.ModelName}' already exists.") - except Exception as E: - print(f"Failed to check/download model '{self.ModelName}': {E}") - exit(1) + Ensures the Ollama server is running and the required model + is available. If the server is not running, it attempts to start it. + If the model is not available, it downloads the specified model. + """ + self.CheckIfModelAvailability() + self.CheckModelStatus() def LoadDocuments(self): """ @@ -200,9 +214,16 @@ def CreateVectorStore(self, Docs): Splits = TextSplitter.create_documents( [Doc["Content"] for Doc in Docs] ) + + # Falls Chroma beschädigt ist, löschen und neu erstellen + ChromaPath = "chromadb_store" + if os.path.exists(ChromaPath): + shutil.rmtree(ChromaPath) # Löscht die bestehende Datenbank + self.VectorStore = Chroma.from_documents( Splits, embedding=self.Embeddings, + persist_directory=ChromaPath # Persistenz aktivieren ) print("Vector store created successfully.") @@ -259,23 +280,30 @@ def AskQuestion(self, Query): "Please analyze the repository first." ) - if not self.VectorStore: + if self.VectorStore is None: return ( "No vector store available. " "Please analyze a repository first." ) - # Retrieve relevant documents using similarity search - RelevantDocs = self.VectorStore.similarity_search(Query, k=5) - Context = "\n\n".join(Doc.page_content for Doc in RelevantDocs) + try: + # Retrieve relevant documents using similarity search + RelevantDocs = self.VectorStore.similarity_search(Query, k=5) + if not RelevantDocs: + return "No relevant documents found for the query." - # Create prompt based on retrieved context - Prompt = self.PromptTemplate.format(Context=Context, Question=Query) - Response = self.Assistant.invoke(Prompt) + # Create prompt based on retrieved context + Context = "\n\n".join(Doc.page_content for Doc in RelevantDocs) + Prompt = self.PromptTemplate.format(Context=Context, + Question=Query) + Response = self.Assistant.invoke(Prompt) - if isinstance(Response, AIMessage): - return Response.content - return str(Response) + if isinstance(Response, AIMessage): + return Response.content + return str(Response) + + except Exception as E: + return f"Error during query processing: {E}" def WriteSummary(self, Content): """ diff --git a/src/main.py b/src/main.py index 86b7d19..98821cf 100644 --- a/src/main.py +++ b/src/main.py @@ -1,5 +1,6 @@ """This file runs the AI model and lets you interact with it.""" +import streamlit as st from ai import AIAssistant from ui import StreamlitUI @@ -15,19 +16,26 @@ def __init__( summary_prompt_template ): """Initialize the application.""" - self.ai_assistant = AIAssistant( - model_name, - creativity, - prompt_template, - summary_prompt_template - ) - self.ui = StreamlitUI(self.ai_assistant) - - def run(self): + if "AI_Assistant" not in st.session_state: + st.session_state.AI_Assistant = AIAssistant( + model_name, + creativity, + prompt_template, + summary_prompt_template + ) + st.session_state.AI_Assistant.ManageOllama() + print("Starting AI...") + + self.ui = StreamlitUI() + + def Run(self): """Run the application.""" self.ui.Run() + print("Starting UI...") if __name__ == "__main__": - app = MainApp("llama3.1:8b", 1.0, "", "") - app.run() + if "MainApp" not in st.session_state: + st.session_state.MainApp = MainApp("llama3.1:8b", 1.0, "", "") + + st.session_state.MainApp.Run() diff --git a/src/ui.py b/src/ui.py index 90bc8c9..dbe4624 100644 --- a/src/ui.py +++ b/src/ui.py @@ -1,20 +1,21 @@ """This file provides the UI for the AI model.""" -from ai import AIAssistant import streamlit as st class StreamlitUI: """Class for managing the Streamlit user interface.""" - def __init__(self, AIAssistant: AIAssistant): + def __init__(self): """Initialize Process.""" - self.AIAssistant = AIAssistant - self.ChatHistory = "" - self.RepoPath = "" + if "ChatHistory" not in st.session_state: + st.session_state.ChatHistory = "" + if "RepoPath" not in st.session_state: + st.session_state.RepoPath = "" def Run(self): """Run the Streamlit UI.""" + st.set_page_config("AI Assistant") st.title("AI Assistant") st.write( """Welcome to the AI Repo Summarizer!\n @@ -22,22 +23,25 @@ def Run(self): ) # Path Input - self.RepoPath = st.text_input("Set Repository Path:", self.RepoPath) + st.session_state.RepoPath = st.text_input("Set Repository Path:", + st.session_state.RepoPath) if st.button("Set Path"): - if self.RepoPath: - st.write(f"Repository path set to: {self.RepoPath}") - self.AIAssistant.SetRepoPath(self.ChatHistory) + if st.session_state.RepoPath: + st.write(f"""Repository path set to: + {st.session_state.RepoPath}""") + st.session_state.AI_Assistant.SetRepoPath( + st.session_state.RepoPath) st.write("Analyzing repository...") - result = self.AIAssistant.AnalyzeRepository() + result = st.session_state.AI_Assistant.AnalyzeRepository() st.write(result) # User Input UserInput = st.text_input("Your question:") if st.button("Send Question"): if UserInput: - Response = self.AIAssistant.AskQuestion(UserInput) - self.ChatHistory += f"User: {UserInput}\nAI: {Response}\n\n" - st.write(Response) + Response = st.session_state.AI_Assistant.AskQuestion(UserInput) + st.session_state.ChatHistory += f"""User: + {UserInput}\nAI: {Response}\n\n""" # Show Conversation History - st.text_area("Chat History", self.ChatHistory, height=300) + st.text_area("Chat History", st.session_state.ChatHistory, height=300) diff --git a/tests/test_ui.py b/tests/test_ui.py index 54b747c..33770a5 100644 --- a/tests/test_ui.py +++ b/tests/test_ui.py @@ -1,7 +1,6 @@ """This file runs the tests for the UI.""" import pytest -from src.ai import AIAssistant from src.ui import StreamlitUI @@ -11,34 +10,17 @@ class TestUI: """Value Section.""" @pytest.mark.parametrize( - """ModelName, Creativity, Prompt, SummaryPrompt, FileTypes, - expected_AIAssistant, expected_UI""", + """ expected_UI""", [ ( - "llama3.1:8b", - 1, - "This is a test prompt.", - "This is a summary test prompt.", - [".py", ".js", ".java", ".md", ".txt"], - True, - True, + True ), ], ) def test_UI( self, - ModelName, - Creativity, - Prompt, - SummaryPrompt, - FileTypes, - expected_AIAssistant, expected_UI, ): """Initialize the UI.""" - self.AIAssistant = AIAssistant( - ModelName, Creativity, Prompt, SummaryPrompt, FileTypes - ) - self.UI = StreamlitUI(self.AIAssistant) - assert (self.AIAssistant is not None) == expected_AIAssistant + self.UI = StreamlitUI() assert (self.UI is not None) == expected_UI