diff --git a/api/memoryalpha/rag.py b/api/memoryalpha/rag.py index 6c0fdb4..b8a8396 100644 --- a/api/memoryalpha/rag.py +++ b/api/memoryalpha/rag.py @@ -115,7 +115,29 @@ def __init__(self, path=self.chroma_db_path, settings=Settings(allow_reset=False) ) - self.collection = self.client.get_collection(self.collection_name) + + # Initialize CLIP model for consistent embeddings with the database + logger.info("Loading CLIP model for embedding compatibility...") + self.clip_model = SentenceTransformer('clip-ViT-B-32') + logger.info("CLIP model loaded successfully") + + # Create CLIP embedding function to match the one used during data creation + from chromadb.utils import embedding_functions + + class CLIPEmbeddingFunction(embedding_functions.EmbeddingFunction): + def __init__(self, clip_model): + self.clip_model = clip_model + + def __call__(self, input): + """Generate embeddings using CLIP model""" + embeddings = [] + for text in input: + embedding = self.clip_model.encode(text) + embeddings.append(embedding.tolist()) + return embeddings + + self.clip_ef = CLIPEmbeddingFunction(self.clip_model) + self.collection = self.client.get_collection(self.collection_name, embedding_function=self.clip_ef) # Initialize Ollama client self.ollama_client = ollama.Client(host=self.ollama_url) @@ -139,12 +161,18 @@ def _cosine_similarity(self, query_embedding: np.ndarray, doc_embeddings: np.nda return np.dot(doc_norms, query_norm) def search(self, query: str, top_k: int = 10) -> List[Dict[str, Any]]: - results = self.collection.query(query_texts=[query], n_results=top_k) + # Search only text documents (filter out image documents for now) + results = self.collection.query( + query_texts=[query], + n_results=top_k, + where={"content_type": "text"} # Only search text documents + ) docs = [ { "content": doc, "title": meta["title"], - "distance": dist + "distance": dist, + "content_type": meta.get("content_type", "text") } for doc, meta, dist in zip(results["documents"][0], results["metadatas"][0], results["distances"][0]) ] diff --git a/api/memoryalpha/rag/query.py b/api/memoryalpha/rag/query.py deleted file mode 100644 index e69de29..0000000