diff --git a/.env b/.env index d945293..35bc98c 100644 --- a/.env +++ b/.env @@ -1,4 +1,6 @@ DEFAULT_MODEL="qwen3:0.6b" +DEFAULT_IMAGE_MODEL="qwen2.5vl:3b" + OLLAMA_URL="http://ollama:11434" DB_PATH="/data/enmemoryalpha_db" COLLECTION_NAME="memoryalpha" \ No newline at end of file diff --git a/.github/workflows/ci-build.yml b/.github/workflows/ci-build.yml index 4dd3aec..f06e5e5 100644 --- a/.github/workflows/ci-build.yml +++ b/.github/workflows/ci-build.yml @@ -52,22 +52,6 @@ jobs: exit 1 fi - - name: Test streaming endpoint - run: | - # Test the streaming endpoint - response=$(timeout 30 curl -s -N -H "Accept: text/event-stream" \ - "http://localhost:8000/memoryalpha/rag/stream?question=What%20is%20a%20Transporter?&thinkingmode=DISABLED&max_tokens=50&top_k=3" \ - | head -10) - - # Check if streaming response contains data events - if echo "$response" | grep -q "data:"; then - echo "✅ Streaming endpoint test passed" - else - echo "❌ Streaming endpoint test failed - no streaming data found" - echo "Response: $response" - exit 1 - fi - - name: Generate OpenAPI spec run: | # Download OpenAPI spec diff --git a/Dockerfile b/Dockerfile index 67f4ed3..f9353b9 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,14 +3,16 @@ FROM python:3.12-slim-bullseye AS devcontainer ARG DEBIAN_FRONTEND=noninteractive RUN apt-get update &&\ - apt-get -y install curl wget git + apt-get -y install jq curl wget git COPY ./requirements.txt /tmp/pip-tmp/ RUN pip install --no-cache-dir -r /tmp/pip-tmp/requirements.txt \ && rm -rf /tmp/pip-tmp WORKDIR /data -RUN wget https://github.com/aniongithub/memoryalpha-vectordb/releases/latest/download/enmemoryalpha_db.tar.gz &&\ + +ARG MEMORYALPHA_DB_RELEASE=v0.5.0 +RUN wget https://github.com/aniongithub/memoryalpha-vectordb/releases/download/${MEMORYALPHA_DB_RELEASE}/enmemoryalpha_db.tar.gz &&\ tar -xzf enmemoryalpha_db.tar.gz &&\ rm enmemoryalpha_db.tar.gz &&\ chmod -R 0777 /data diff --git a/README.md b/README.md index deddcd5..59b1793 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ A REST API for Retrieval-Augmented Generation (RAG) over Star Trek's MemoryAlpha ## Overview -This project provides a streaming REST API that enables natural language queries over the comprehensive Star Trek MemoryAlpha database. It uses the vectorized database from [memoryalpha-vectordb](https://github.com/aniongithub/memoryalpha-vectordb) and combines it with local LLMs via Ollama to provide accurate, context-aware responses about Star Trek lore. +This project provides a REST API that enables natural language queries over the comprehensive Star Trek MemoryAlpha database. It uses the vectorized database from [memoryalpha-vectordb](https://github.com/aniongithub/memoryalpha-vectordb) and combines it with local LLMs via Ollama to provide accurate, context-aware responses about Star Trek lore. 
The system implements: @@ -169,7 +169,7 @@ graph TD ### Components -- **FastAPI**: REST API framework with streaming support +- **FastAPI**: REST API framework and OpenAPI spec generation - **ChromaDB**: Vector database for document storage and retrieval - **Ollama**: Local LLM inference server - **Cross-Encoder**: Document reranking for improved relevance diff --git a/api/main.py b/api/main.py index 54695ba..2b1c7b5 100644 --- a/api/main.py +++ b/api/main.py @@ -2,8 +2,8 @@ from contextlib import asynccontextmanager from fastapi import FastAPI from .memoryalpha.health import router as health_router -from .memoryalpha.stream import router as stream_router from .memoryalpha.ask import router as ask_router +from .memoryalpha.identify import router as identify_router # Configure logging logging.basicConfig(level=logging.INFO) @@ -21,5 +21,5 @@ async def lifespan(app: FastAPI): app = FastAPI(lifespan=lifespan) app.include_router(health_router) -app.include_router(stream_router) -app.include_router(ask_router) \ No newline at end of file +app.include_router(ask_router) +app.include_router(identify_router) \ No newline at end of file diff --git a/api/memoryalpha/identify.py b/api/memoryalpha/identify.py new file mode 100644 index 0000000..ee8368b --- /dev/null +++ b/api/memoryalpha/identify.py @@ -0,0 +1,31 @@ +from fastapi import APIRouter, File, UploadFile, Query +from fastapi.responses import JSONResponse +import tempfile +import os +from .rag import MemoryAlphaRAG + +router = APIRouter() + +# Singleton or global instance for demo; in production, manage lifecycle properly +rag_instance = MemoryAlphaRAG() + +@router.post("/memoryalpha/rag/identify", summary="Multimodal Image Search") +def identify_endpoint( + file: UploadFile = File(...), + top_k: int = Query(5, description="Number of results to return") +): + """ + Accepts an image file upload, performs multimodal image search, and returns results. 
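+
+    The image is sent as multipart/form-data and `top_k` controls how many
+    candidate records are retrieved from the image collection. On success the
+    JSON body is whatever MemoryAlphaRAG.search_image returns (currently
+    {"model_answer": ...}); failures are reported as {"error": ...} with HTTP 500.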
+ """ + try: + # Save uploaded file to a temp location + with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[-1]) as tmp: + tmp.write(file.file.read()) + image_path = tmp.name + # Perform image search + results = rag_instance.search_image(image_path, top_k=top_k) + # Clean up temp file + os.remove(image_path) + return JSONResponse(content=results) + except Exception as e: + return JSONResponse(status_code=500, content={"error": str(e)}) diff --git a/api/memoryalpha/rag.py b/api/memoryalpha/rag.py index b8a8396..46e5315 100644 --- a/api/memoryalpha/rag.py +++ b/api/memoryalpha/rag.py @@ -2,9 +2,7 @@ import os import sys -import json import re -import requests import logging import warnings import numpy as np @@ -14,12 +12,6 @@ from sentence_transformers import CrossEncoder, SentenceTransformer import ollama -# Optional prompt UI -from prompt_toolkit import prompt -from prompt_toolkit.history import InMemoryHistory -from prompt_toolkit.auto_suggest import AutoSuggestFromHistory -from prompt_toolkit.completion import WordCompleter - # RAG components import pysqlite3 sys.modules["sqlite3"] = pysqlite3 @@ -63,11 +55,9 @@ class MemoryAlphaRAG: def __init__(self, chroma_db_path: str = os.getenv("DB_PATH"), ollama_url: str = os.getenv("OLLAMA_URL"), - model: str = os.getenv("DEFAULT_MODEL"), collection_name: str = os.getenv("COLLECTION_NAME", "memoryalpha"), rerank_method: str = "cross-encoder", thinking_mode: ThinkingMode = ThinkingMode.DISABLED, - enable_streaming: bool = True, max_history_turns: int = 5, thinking_text: str = "Processing..."): @@ -75,17 +65,13 @@ def __init__(self, raise ValueError("chroma_db_path must be provided or set in CHROMA_DB_PATH environment variable.") if not ollama_url: raise ValueError("ollama_url must be provided or set in OLLAMA_URL environment variable.") - if not model: - raise ValueError("model must be provided or set in DEFAULT_MODEL environment variable.") if not collection_name: raise ValueError("collection_name must be provided or set in COLLECTION_NAME environment variable.") self.chroma_db_path = chroma_db_path self.ollama_url = ollama_url - self.model = model self.collection_name = collection_name self.thinking_mode = thinking_mode - self.enable_streaming = enable_streaming self.max_history_turns = max_history_turns self.rerank_method = rerank_method self.thinking_text = thinking_text @@ -115,64 +101,59 @@ def __init__(self, path=self.chroma_db_path, settings=Settings(allow_reset=False) ) - - # Initialize CLIP model for consistent embeddings with the database - logger.info("Loading CLIP model for embedding compatibility...") + + # Initialize text embedding model and collection + logger.info("Loading text embedding model all-MiniLM-L6-v2 for text collection...") + self.text_model = SentenceTransformer('all-MiniLM-L6-v2') + logger.info("Text model loaded successfully") + from chromadb.utils import embedding_functions + class TextEmbeddingFunction(embedding_functions.EmbeddingFunction): + def __init__(self, text_model): + self.text_model = text_model + def __call__(self, input): + embeddings = [] + for text in input: + embedding = self.text_model.encode(text) + embeddings.append(embedding.tolist()) + return embeddings + self.text_ef = TextEmbeddingFunction(self.text_model) + self.text_collection = self.client.get_or_create_collection("memoryalpha_text", embedding_function=self.text_ef) + + # Initialize CLIP model and image collection + logger.info("Loading CLIP model for image collection...") self.clip_model = 
SentenceTransformer('clip-ViT-B-32') logger.info("CLIP model loaded successfully") - - # Create CLIP embedding function to match the one used during data creation - from chromadb.utils import embedding_functions - class CLIPEmbeddingFunction(embedding_functions.EmbeddingFunction): def __init__(self, clip_model): self.clip_model = clip_model - def __call__(self, input): - """Generate embeddings using CLIP model""" embeddings = [] - for text in input: - embedding = self.clip_model.encode(text) + for img in input: + embedding = self.clip_model.encode(img) embeddings.append(embedding.tolist()) return embeddings - self.clip_ef = CLIPEmbeddingFunction(self.clip_model) - self.collection = self.client.get_collection(self.collection_name, embedding_function=self.clip_ef) - + self.image_collection = self.client.get_or_create_collection("memoryalpha_images", embedding_function=self.clip_ef) + # Initialize Ollama client self.ollama_client = ollama.Client(host=self.ollama_url) - self._warm_up_model() - - def _warm_up_model(self): - try: - self.ollama_client.generate( - model=self.model, - prompt="System ready.", - stream=False, - keep_alive=-1 - ) - except Exception as e: - logger.warning(f"Model warm-up failed: {e}") - def _cosine_similarity(self, query_embedding: np.ndarray, doc_embeddings: np.ndarray) -> np.ndarray: query_norm = query_embedding / np.linalg.norm(query_embedding) doc_norms = doc_embeddings / np.linalg.norm(doc_embeddings, axis=1, keepdims=True) return np.dot(doc_norms, query_norm) def search(self, query: str, top_k: int = 10) -> List[Dict[str, Any]]: - # Search only text documents (filter out image documents for now) - results = self.collection.query( - query_texts=[query], - n_results=top_k, - where={"content_type": "text"} # Only search text documents + # Search only text documents using the text collection and text embedding model + results = self.text_collection.query( + query_texts=[query], + n_results=top_k ) docs = [ { "content": doc, "title": meta["title"], - "distance": dist, - "content_type": meta.get("content_type", "text") + "distance": dist } for doc, meta, dist in zip(results["documents"][0], results["metadatas"][0], results["distances"][0]) ] @@ -196,49 +177,44 @@ def search(self, query: str, top_k: int = 10) -> List[Dict[str, Any]]: def build_prompt(self, query: str, docs: List[Dict[str, Any]]) -> tuple[str, str]: system_prompt = get_system_prompt(self.thinking_mode) - char_limit = 800 context_text = "\n\n".join( - f"=== {doc['title']} ===\n{doc['content'][:char_limit]}" for doc in docs + f"=== {doc['title']} ===\n{doc['content']}" for doc in docs ) user_prompt = get_user_prompt(context_text, query) return system_prompt, user_prompt - def ask(self, query: str, max_tokens: int = 2048, top_k: int = 10, top_p: float = 0.8, temperature: float = 0.3) -> str: + def ask(self, query: str, max_tokens: int = 2048, top_k: int = 10, top_p: float = 0.8, temperature: float = 0.3, + model: str = os.getenv("DEFAULT_MODEL")) -> str: + """ + Ask a question using the specified model (defaults to $DEFAULT_MODEL if not provided). 
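+
+        Retrieval pulls top_k documents from the text collection, the last three
+        conversation turns are replayed as chat history, and max_tokens is passed
+        to Ollama as num_predict.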
+ """ + + if not model: + raise ValueError("model must be provided or set in DEFAULT_MODEL environment variable.") + docs = self.search(query, top_k=top_k) system_prompt, user_prompt = self.build_prompt(query, docs) - + # Build messages for chat messages = [ {"role": "system", "content": system_prompt} ] - + # Add conversation history for exchange in self.conversation_history[-3:]: # Last 3 exchanges messages.append({"role": "user", "content": exchange["question"]}) messages.append({"role": "assistant", "content": exchange["answer"]}) - + # Add current query messages.append({"role": "user", "content": user_prompt}) - full_response = "" - - if self.enable_streaming: - for chunk in self.ollama_client.chat( - model=self.model, - messages=messages, - stream=True, - options={"temperature": temperature, "top_p": top_p, "num_predict": max_tokens} - ): - if 'message' in chunk and 'content' in chunk['message']: - full_response += chunk['message']['content'] - else: - result = self.ollama_client.chat( - model=self.model, - messages=messages, - stream=False, - options={"temperature": temperature, "top_p": top_p, "num_predict": max_tokens} - ) - full_response = result['message']['content'] + result = self.ollama_client.chat( + model=model, + messages=messages, + stream=False, + options={"temperature": temperature, "top_p": top_p, "num_predict": max_tokens} + ) + full_response = result['message']['content'] # Handle thinking mode response processing if self.thinking_mode == ThinkingMode.DISABLED: @@ -269,4 +245,121 @@ def _replace_thinking_tags(self, answer: str) -> str: def _update_history(self, question: str, answer: str): self.conversation_history.append({"question": question, "answer": answer}) - self.conversation_history = self.conversation_history[-self.max_history_turns:] \ No newline at end of file + self.conversation_history = self.conversation_history[-self.max_history_turns:] + + def search_image(self, image_path: str, top_k: int = 5, + model: str = os.getenv("DEFAULT_IMAGE_MODEL")) -> Dict[str, Any]: + """ + 1. Generates CLIP embedding for the provided image + 2. Searches image records, retrieves top_k + 3. Downloads actual images for image results + 4. Uses source page titles to fetch text context from text collection + 5. Passes all info to the model to guess the theme and image + """ + from PIL import Image + import requests + import tempfile + import os + + if not model: + raise ValueError("model must be provided or set in DEFAULT_IMAGE_MODEL environment variable.") + + # 1. Generate CLIP embedding for the image + image = Image.open(image_path).convert('RGB') + image_embedding = self.clip_model.encode(image) + image_embedding = image_embedding.tolist() + + # 2. Search image records only + image_results = self.image_collection.query( + query_embeddings=[image_embedding], + n_results=top_k + ) + + # 3. 
Download actual images for image results and prepare for attachment + downloaded_images = [] + image_binaries = [] + image_docs = image_results['documents'][0] + image_metas = image_results['metadatas'][0] + image_urls = [meta.get('image_url') for meta in image_metas] + for idx, url in enumerate(image_urls): + if url: + try: + resp = requests.get(url, timeout=30) + if resp.status_code == 200: + with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as tmp: + tmp.write(resp.content) + downloaded_images.append(tmp.name) + image_binaries.append(resp.content) + else: + downloaded_images.append(None) + image_binaries.append(None) + except Exception: + downloaded_images.append(None) + image_binaries.append(None) + else: + downloaded_images.append(None) + image_binaries.append(None) + + # 4. Use source page titles to fetch text context from text collection + source_titles = [meta.get('source_page') for meta in image_metas if meta.get('source_page')] + text_contexts = [] + if source_titles: + # Query text collection for each source page title + for title in source_titles: + text_results = self.text_collection.query( + query_texts=[title], + n_results=1 + ) + if text_results['documents'][0]: + doc = text_results['documents'][0][0] + meta = text_results['metadatas'][0][0] + dist = text_results['distances'][0][0] + text_contexts.append(f"Text Context for '{title}':\nTitle: {meta.get('title', 'Unknown')}\nSimilarity: {1-dist:.4f}\nContent: {doc[:300]}\n") + + # 5. Number and format results, reference images as Image 1, Image 2, etc. + formatted_images = [] + image_indices = [] + for i, (doc, meta, dist, img_path, img_bin) in enumerate(zip(image_docs, image_metas, image_results['distances'][0], downloaded_images, image_binaries), 1): + if img_bin: + formatted_images.append(f"Image {i}:\nImage Name: {meta.get('image_name', 'Unknown')}\nSource Page: {meta.get('source_page', 'Unknown')}\nSimilarity: {1-dist:.4f}\nDescription: {doc}\n(Refer to attached Image {i})\n") + image_indices.append(i-1) # index in image_binaries + else: + formatted_images.append(f"Image {i}:\nImage Name: {meta.get('image_name', 'Unknown')}\nSource Page: {meta.get('source_page', 'Unknown')}\nSimilarity: {1-dist:.4f}\nDescription: {doc}\nImage download failed.\n") + + # 6. Pass all info to the model, attach images + prompt = ( + "You are an expert Star Trek analyst. Your task is to identify the user-provided image (attached as Image 0) as specifically as possible. " + "Among the retrieved images and their metadata below, determine which image best matches Image 0. " + "Use the metadata (image name, source page, description, similarity score, and text context) of the closest match to identify the user image. " + "Do NOT mention the match number, just provide the identification. " + "If no close match is found, say so clearly.\n\n" + ) + prompt += "\n".join(formatted_images) + if text_contexts: + prompt += "\n".join(text_contexts) + prompt += "\nRespond with one or two lines identifying the user-provided image, based on the closest match and its metadata." 
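+
+        # Send the assembled prompt to the vision model, attaching only the
+        # retrieved images that downloaded successfully (tracked in image_indices).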
+ + messages = [ + {"role": "system", "content": "You are an expert Star Trek analyst."}, + {"role": "user", "content": prompt, "images": [image_binaries[i] for i in image_indices]} + ] + + # Only attach images that were successfully downloaded + response = self.ollama_client.chat( + model=model, + messages=messages, + stream=False + ) + answer = response['message']['content'] + + # Clean up temp images + for img_path in downloaded_images: + if img_path and os.path.exists(img_path): + try: + os.remove(img_path) + except Exception: + pass + + return { + "model_answer": answer + } \ No newline at end of file diff --git a/api/memoryalpha/stream.py b/api/memoryalpha/stream.py deleted file mode 100644 index 10339cc..0000000 --- a/api/memoryalpha/stream.py +++ /dev/null @@ -1,187 +0,0 @@ -from fastapi import APIRouter, Query -from fastapi.responses import StreamingResponse -import json -import re -import time -import logging - -from .rag import MemoryAlphaRAG, ThinkingMode - -router = APIRouter() -logger = logging.getLogger(__name__) - -# Lazy-loaded singleton -rag_instance = None - -def get_rag_instance(): - global rag_instance - if rag_instance is None: - logger.info("Initializing MemoryAlpha RAG instance...") - rag_instance = MemoryAlphaRAG() - logger.info("MemoryAlpha RAG instance initialized successfully") - return rag_instance - -@router.get("/memoryalpha/rag/stream") -def stream_endpoint( - question: str = Query(..., description="The user question"), - thinkingmode: str = Query("DISABLED", description="Thinking mode: DISABLED, QUIET, or VERBOSE"), - max_tokens: int = Query(2048, description="Maximum tokens to generate"), - top_k: int = Query(10, description="Number of documents to retrieve"), - top_p: float = Query(0.8, description="Sampling parameter"), - temperature: float = Query(0.3, description="Randomness/creativity of output") -): - """ - Query the RAG pipeline and return streaming response chunks. 
- """ - - def generate_stream(): - try: - start_time = time.time() - - # Get RAG instance (lazy-loaded) - rag = get_rag_instance() - - # Set the thinking mode for this request - rag.thinking_mode = ThinkingMode[thinkingmode.upper()] - - # Phase 1: Document retrieval - search_start = time.time() - docs = rag.search(question, top_k=top_k) - search_duration = time.time() - search_start - - # Phase 2: Prompt building - prompt_start = time.time() - system_prompt, user_prompt = rag.build_prompt(question, docs) - - # Build messages for chat - messages = [ - {"role": "system", "content": system_prompt} - ] - - # Add conversation history - for exchange in rag.conversation_history[-3:]: # Last 3 exchanges - messages.append({"role": "user", "content": exchange["question"]}) - messages.append({"role": "assistant", "content": exchange["answer"]}) - - # Add current query - messages.append({"role": "user", "content": user_prompt}) - - # Estimate input tokens (rough approximation) - full_prompt = system_prompt + "\n\n" + user_prompt - for msg in messages[1:]: # Skip system message already included - full_prompt += "\n" + msg["content"] - input_tokens = len(full_prompt.split()) * 1.3 # Rough token estimate - prompt_duration = time.time() - prompt_start - - full_response = "" - current_buffer = "" - in_thinking_block = False - - # Phase 3: LLM generation - generation_start = time.time() - first_token_time = None - - # Stream the response - for chunk in rag.ollama_client.chat( - model=rag.model, - messages=messages, - stream=True, - options={"temperature": temperature, "top_p": top_p, "num_predict": max_tokens} - ): - if 'message' in chunk and 'content' in chunk['message']: - content = chunk['message']['content'] - full_response += content - current_buffer += content - - # Track time to first token - if first_token_time is None and content: - first_token_time = time.time() - generation_start - - # Process content based on thinking mode - if rag.thinking_mode == ThinkingMode.DISABLED: - # Filter out thinking blocks in real-time - output_content = "" - i = 0 - while i < len(current_buffer): - if current_buffer[i:].startswith("") and not in_thinking_block: - in_thinking_block = True - i += 7 # Skip "" - elif current_buffer[i:].startswith("") and in_thinking_block: - in_thinking_block = False - i += 8 # Skip "" - current_buffer = current_buffer[i:] - i = 0 - elif not in_thinking_block: - output_content += current_buffer[i] - i += 1 - else: - i += 1 - - # Only update buffer with remaining unprocessed content - if not in_thinking_block: - current_buffer = "" - - # Send filtered content - if output_content: - chunk_data = {"chunk": output_content} - yield f"data: {json.dumps(chunk_data)}\n\n" - else: - # For other modes, send content as-is - chunk_data = {"chunk": content} - yield f"data: {json.dumps(chunk_data)}\n\n" - - generation_duration = time.time() - generation_start - - # Phase 4: Post-processing - processing_start = time.time() - output_tokens = len(full_response.split()) * 1.3 # Rough token estimate - - # Process final response based on thinking mode - if rag.thinking_mode == ThinkingMode.DISABLED: - final_response = rag._clean_response(full_response) - elif rag.thinking_mode == ThinkingMode.QUIET: - final_response = rag._replace_thinking_tags(full_response) - else: # VERBOSE - final_response = full_response.strip() - - # Update history with final processed response - rag._update_history(question, final_response) - processing_duration = time.time() - processing_start - - # Calculate total duration - 
total_duration = time.time() - start_time - - # Send completion signal with comprehensive metrics - metrics = { - "done": True, - "metrics": { - "duration_seconds": round(total_duration, 3), - "phase_timings": { - "search_seconds": round(search_duration, 3), - "prompt_building_seconds": round(prompt_duration, 3), - "generation_seconds": round(generation_duration, 3), - "post_processing_seconds": round(processing_duration, 3), - "time_to_first_token_seconds": round(first_token_time, 3) if first_token_time else None - }, - "input_tokens_estimated": int(input_tokens), - "output_tokens_estimated": int(output_tokens), - "total_tokens_estimated": int(input_tokens + output_tokens), - "documents_retrieved": len(docs), - "model": rag.model - } - } - yield f"data: {json.dumps(metrics)}\n\n" - - except Exception as e: - # Send error - error_data = {"error": str(e)} - yield f"data: {json.dumps(error_data)}\n\n" - - return StreamingResponse( - generate_stream(), - media_type="text/plain", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - } - ) diff --git a/chat.sh b/chat.sh index 7344aea..e54bdc9 100755 --- a/chat.sh +++ b/chat.sh @@ -12,39 +12,105 @@ echo "🖖 Welcome to MemoryAlpha RAG Chat" echo "Type 'quit' or 'exit' to end the session" echo "----------------------------------------" -while true; do - # Prompt for user input - echo -n "❓ Ask about Star Trek: " - read -r question - - # Check for exit commands - if [[ "$question" == "quit" || "$question" == "exit" || "$question" == "q" ]]; then - echo "🖖 Live long and prosper!" - break - fi - - # Skip empty questions - if [[ -z "$question" ]]; then - continue - fi - - # URL encode the question +# Function to handle text question +ask_question() { + local question="$1" + local encoded_question encoded_question=$(printf '%s' "$question" | jq -sRr @uri) - echo "🤖 LCARS Response:" echo "----------------------------------------" - - # Make the streaming request - curl -s -N -H "Accept: text/event-stream" \ - "${BASE_URL}/memoryalpha/rag/stream?question=${encoded_question}&thinkingmode=${THINKING_MODE}&max_tokens=${MAX_TOKENS}&top_k=${TOP_K}&top_p=${TOP_P}&temperature=${TEMPERATURE}" \ - | while IFS= read -r line; do - if [[ $line == data:* ]]; then - chunk=$(echo "${line#data: }" | jq -r '.chunk // empty') - if [[ -n "$chunk" ]]; then - printf "%s" "$chunk" - fi + local response + response=$(curl -s \ + "${BASE_URL}/memoryalpha/rag/ask?question=${encoded_question}&thinkingmode=${THINKING_MODE}&max_tokens=${MAX_TOKENS}&top_k=${TOP_K}&top_p=${TOP_P}&temperature=${TEMPERATURE}") + local answer + answer=$(echo "$response" | jq -r '.response // empty') + if [[ -n "$answer" ]]; then + printf "%s\n" "$answer" + else + local error + error=$(echo "$response" | jq -r '.error // empty') + if [[ -n "$error" ]]; then + printf "Error: %s\n" "$error" + else + printf "No response received.\n" + fi + fi + echo "----------------------------------------" +} + +# Function to handle image identification +identify_image() { + local image_path="$1" + local tmpfile="" + # Check if local file exists + if [[ -f "$image_path" ]]; then + tmpfile="$image_path" + else + # Try to download + echo "Attempting to download image from URL: $image_path" + tmpfile="/tmp/maimg_$$.img" + if ! curl -sSL "$image_path" -o "$tmpfile"; then + echo "Failed to download image. Returning to menu." 
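+            # Drop any partial download before returning to the menu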
+ [[ -f "$tmpfile" ]] && rm -f "$tmpfile" + return + fi + fi + echo "🤖 LCARS Image Identification:" + echo "----------------------------------------" + local response + response=$(curl -s -X POST \ + -F "file=@${tmpfile}" \ + "${BASE_URL}/memoryalpha/rag/identify?top_k=${TOP_K}") + local answer + answer=$(echo "$response" | jq -r '.model_answer // empty') + if [[ -n "$answer" ]]; then + printf "%s\n" "$answer" + else + local error + error=$(echo "$response" | jq -r '.error // empty') + if [[ -n "$error" ]]; then + printf "Error: %s\n" "$error" + else + printf "No response received.\n" + fi + fi + echo "----------------------------------------" + # Clean up temp file if downloaded + if [[ "$tmpfile" != "$image_path" ]]; then + rm -f "$tmpfile" + fi +} + +while true; do + echo "Choose an option:" + echo " 1) Ask a Star Trek question" + echo " 2) Identify an image" + echo " q) Quit" + echo -n "Enter choice [1/2/q]: " + read -r choice + case "$choice" in + 1) + echo -n "❓ Enter your question: " + read -r question + if [[ -z "$question" ]]; then + continue + fi + ask_question "$question" + ;; + 2) + echo -n "🖼️ Enter local image path or image URL: " + read -r image_path + if [[ -z "$image_path" ]]; then + continue fi - done - - echo -e "\n----------------------------------------" + identify_image "$image_path" + ;; + q|quit|exit) + echo "🖖 Live long and prosper!" + break + ;; + *) + echo "Invalid choice. Please enter 1, 2, or q." + ;; + esac done diff --git a/requirements.txt b/requirements.txt index 1a13ac4..a8a8c23 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,5 @@ prompt_toolkit==3.0.51 sentence-transformers==5.0.0 ollama==0.5.2 fastapi==0.116.1 -uvicorn==0.35.0 \ No newline at end of file +uvicorn==0.35.0 +python-multipart==0.0.20 \ No newline at end of file diff --git a/test_images/image01.jpg b/test_images/image01.jpg new file mode 100644 index 0000000..d8a16aa Binary files /dev/null and b/test_images/image01.jpg differ diff --git a/test_images/image02.jpg b/test_images/image02.jpg new file mode 100644 index 0000000..d3fd9df Binary files /dev/null and b/test_images/image02.jpg differ diff --git a/test_images/image03.jpg b/test_images/image03.jpg new file mode 100644 index 0000000..715781c Binary files /dev/null and b/test_images/image03.jpg differ diff --git a/test_images/image04.jpg b/test_images/image04.jpg new file mode 100644 index 0000000..cf72389 Binary files /dev/null and b/test_images/image04.jpg differ diff --git a/test_images/image05.jpg b/test_images/image05.jpg new file mode 100644 index 0000000..09744e9 Binary files /dev/null and b/test_images/image05.jpg differ diff --git a/test_images/image06.jpg b/test_images/image06.jpg new file mode 100644 index 0000000..1518e63 Binary files /dev/null and b/test_images/image06.jpg differ diff --git a/wait-for-ollama.sh b/wait-for-ollama.sh index 241fc27..b561af7 100755 --- a/wait-for-ollama.sh +++ b/wait-for-ollama.sh @@ -7,16 +7,22 @@ done echo "✅ Ollama is ready." -# Check if DEFAULT_MODEL is available, if not, pull it -echo "🔍 Checking if model '$DEFAULT_MODEL' is available..." -if curl -s "$OLLAMA_URL/api/tags" | grep -q "\"name\":\"$DEFAULT_MODEL\""; then - echo "✅ Model '$DEFAULT_MODEL' is already available." -else - echo "📥 Model '$DEFAULT_MODEL' not found. Pulling it now..." - curl -X POST "$OLLAMA_URL/api/pull" -H "Content-Type: application/json" -d "{\"name\":\"$DEFAULT_MODEL\"}" - echo "" - echo "✅ Model '$DEFAULT_MODEL' has been pulled successfully." 
-fi +pull_model() { + local model_name="$1" + echo "🔍 Checking if model '$model_name' is available..." + if curl -s "$OLLAMA_URL/api/tags" | grep -q "\"name\":\"$model_name\""; then + echo "✅ Model '$model_name' is already available." + else + echo "📥 Model '$model_name' not found. Pulling it now..." + curl -X POST "$OLLAMA_URL/api/pull" -H "Content-Type: application/json" -d "{\"name\":\"$model_name\"}" + echo "" + echo "✅ Model '$model_name' has been pulled successfully." + fi +} + +# Pull the default models +pull_model "$DEFAULT_MODEL" +pull_model "$DEFAULT_IMAGE_MODEL" # Warm up ollama with the default model echo "🤖 Warming up Ollama with $DEFAULT_MODEL..."
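
A quick manual check of the new image endpoint (a sketch, assuming the API is reachable on localhost:8000 as in the CI workflow, using one of the test images added above):

    curl -s -X POST -F "file=@test_images/image01.jpg" \
      "http://localhost:8000/memoryalpha/rag/identify?top_k=5" | jq -r '.model_answer // .error'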