diff --git a/refactron/rag/indexer.py b/refactron/rag/indexer.py
index 690cfe3..2a6587f 100644
--- a/refactron/rag/indexer.py
+++ b/refactron/rag/indexer.py
@@ -6,6 +6,7 @@
 from dataclasses import dataclass
+from functools import lru_cache
 from pathlib import Path
 from typing import Any, Dict, List, Optional, cast
 
 try:
     import chromadb
@@ -19,6 +20,32 @@
     SentenceTransformer = None
     CHROMA_AVAILABLE = False
 
+
+@lru_cache(maxsize=2)
+def get_sentence_transformer(model_name: str) -> Any:
+    """Return a shared SentenceTransformer instance for ``model_name``.
+
+    Results are memoized per model name (the two most recently used are
+    kept), so callers asking for the same model reuse one loaded instance
+    instead of each constructing their own copy.
+
+    Args:
+        model_name: Model identifier passed to ``SentenceTransformer``.
+
+    Returns:
+        The cached model instance for ``model_name``.
+
+    Raises:
+        RuntimeError: If ``SentenceTransformer`` is ``None`` (import fallback).
+    """
+    if SentenceTransformer is None:
+        raise RuntimeError(
+            "sentence-transformers is not available. "
+            "Install with: pip install sentence-transformers"
+        )
+    return SentenceTransformer(model_name)
+
+
 from refactron.rag.chunker import CodeChunk
 from refactron.rag.parser import CodeParser
 
@@ -46,7 +73,7 @@ class RAGIndexer:
     def __init__(
         self,
         workspace_path: Path,
-        embedding_model: str = "all-MiniLM-L6-v2",
+        embedding_model: Any = "all-MiniLM-L6-v2",
         collection_name: str = "code_chunks",
         llm_client: Optional[GroqClient] = None,
     ):
@@ -74,8 +101,12 @@ def __init__(
         self.llm_client = llm_client
 
         # Initialize embedding model
-        self.embedding_model_name = embedding_model
-        self.embedding_model = SentenceTransformer(embedding_model)
+        if isinstance(embedding_model, str):
+            self.embedding_model_name = embedding_model
+            self.embedding_model = get_sentence_transformer(embedding_model)
+        else:
+            self.embedding_model_name = "custom_model"
+            self.embedding_model = embedding_model
 
         # Initialize ChromaDB
         self.client = chromadb.PersistentClient(
diff --git a/refactron/rag/retriever.py b/refactron/rag/retriever.py
index 9fcd16f..dc801bd 100644
--- a/refactron/rag/retriever.py
+++ b/refactron/rag/retriever.py
@@ -4,7 +4,9 @@
 
 from dataclasses import dataclass
 from pathlib import Path
-from typing import List, Optional
+from typing import Any, List, Optional
+
+from refactron.rag.indexer import get_sentence_transformer
 
 try:
     import chromadb
@@ -38,7 +40,7 @@ class ContextRetriever:
     def __init__(
         self,
         workspace_path: Path,
-        embedding_model: str = "all-MiniLM-L6-v2",
+        embedding_model: Any = "all-MiniLM-L6-v2",
         collection_name: str = "code_chunks",
     ):
         """Initialize the context retriever.
@@ -58,7 +60,10 @@ def __init__(
         self.index_path = self.workspace_path / ".rag"
 
         # Initialize embedding model
-        self.embedding_model = SentenceTransformer(embedding_model)
+        if isinstance(embedding_model, str):
+            self.embedding_model = get_sentence_transformer(embedding_model)
+        else:
+            self.embedding_model = embedding_model
 
         # Initialize ChromaDB client
         self.client = chromadb.PersistentClient(