Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions api/memoryalpha/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,29 @@ def __init__(self,
path=self.chroma_db_path,
settings=Settings(allow_reset=False)
)
self.collection = self.client.get_collection(self.collection_name)

# Initialize CLIP model for consistent embeddings with the database
logger.info("Loading CLIP model for embedding compatibility...")
self.clip_model = SentenceTransformer('clip-ViT-B-32')
logger.info("CLIP model loaded successfully")

# Create CLIP embedding function to match the one used during data creation
from chromadb.utils import embedding_functions

class CLIPEmbeddingFunction(embedding_functions.EmbeddingFunction):
def __init__(self, clip_model):
self.clip_model = clip_model

def __call__(self, input):
"""Generate embeddings using CLIP model"""
embeddings = []
for text in input:
embedding = self.clip_model.encode(text)
embeddings.append(embedding.tolist())
return embeddings

self.clip_ef = CLIPEmbeddingFunction(self.clip_model)
self.collection = self.client.get_collection(self.collection_name, embedding_function=self.clip_ef)

# Initialize Ollama client
self.ollama_client = ollama.Client(host=self.ollama_url)
Expand All @@ -139,12 +161,18 @@ def _cosine_similarity(self, query_embedding: np.ndarray, doc_embeddings: np.nda
return np.dot(doc_norms, query_norm)

def search(self, query: str, top_k: int = 10) -> List[Dict[str, Any]]:
results = self.collection.query(query_texts=[query], n_results=top_k)
# Search only text documents (filter out image documents for now)
results = self.collection.query(
query_texts=[query],
n_results=top_k,
where={"content_type": "text"} # Only search text documents
)
docs = [
{
"content": doc,
"title": meta["title"],
"distance": dist
"distance": dist,
"content_type": meta.get("content_type", "text")
}
for doc, meta, dist in zip(results["documents"][0], results["metadatas"][0], results["distances"][0])
]
Expand Down
Empty file removed api/memoryalpha/rag/query.py
Empty file.