forked from SesameAILabs/csm
-
-
Notifications
You must be signed in to change notification settings - Fork 71
Expand file tree
/
Copy pathrag_system.py
More file actions
executable file
·330 lines (269 loc) · 11.9 KB
/
rag_system.py
File metadata and controls
executable file
·330 lines (269 loc) · 11.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
import sqlite3
import numpy as np
import json
from pathlib import Path
import time
from typing import List, Dict, Any, Tuple, Optional
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import torch
class RAGSystem:
    """Embedding-backed retrieval store for conversation history.

    Conversations live in a SQLite database; their sentence-transformer
    embeddings are cached both in memory and as JSON shard files on disk so
    they survive restarts without re-encoding.
    """

    def __init__(self, db_path: str, model_name: str = "all-MiniLM-L6-v2", cache_dir: str = "./embeddings_cache"):
        """
        Initialize the enhanced RAG system with embeddings.

        Args:
            db_path: Path to the SQLite database
            model_name: Name of the sentence-transformer model to use
            cache_dir: Directory to cache embeddings
        """
        self.db_path = db_path
        self.cache_dir = Path(cache_dir)
        # parents=True: a nested cache path (e.g. "data/cache") no longer
        # raises FileNotFoundError on first run.
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        # Load embedding model (GPU when available).
        print(f"Loading embedding model: {model_name}")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = SentenceTransformer(model_name, device=self.device)
        print(f"Embedding model loaded on {self.device}")

        # In-memory cache: chunk_id -> embedding vector.
        self.embedding_cache: Dict[str, np.ndarray] = self._load_embedding_cache()

        # Ensure the schema exists, then backfill embeddings for any
        # conversations stored while the system was offline.
        self._initialize_db()
        self._load_conversations()

    def _initialize_db(self):
        """Create the conversations/embeddings tables if they don't exist."""
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS conversations (
                    id INTEGER PRIMARY KEY,
                    user_message TEXT,
                    ai_message TEXT,
                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
                )
            """)
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS embeddings (
                    id INTEGER PRIMARY KEY,
                    conversation_id INTEGER,
                    text TEXT,
                    embedding_file TEXT,
                    chunk_id TEXT,
                    FOREIGN KEY (conversation_id) REFERENCES conversations(id)
                )
            """)
            conn.commit()
        finally:
            # Close even if table creation raises.
            conn.close()

    def _load_embedding_cache(self) -> Dict[str, np.ndarray]:
        """Load all cached embeddings from the JSON shards in cache_dir.

        Returns:
            Mapping of chunk_id -> embedding (numpy array). Unreadable shard
            files are skipped with a warning rather than aborting the load.
        """
        cache: Dict[str, np.ndarray] = {}
        for cache_file in self.cache_dir.glob("*.json"):
            try:
                with open(cache_file, "r") as f:
                    cache_data = json.load(f)
                for chunk_id, embedding_data in cache_data.items():
                    cache[chunk_id] = np.array(embedding_data)
            except Exception as e:
                print(f"Error loading cache file {cache_file}: {e}")
        print(f"Loaded {len(cache)} cached embeddings")
        return cache

    def _save_embedding_to_cache(self, chunk_id: str, embedding: np.ndarray):
        """Persist one embedding into its on-disk JSON shard.

        Shards are keyed by the first two characters of chunk_id, so every
        "conv_*" chunk currently lands in the single file "co.json".
        """
        cache_file = self.cache_dir / f"{chunk_id[:2]}.json"
        cache_data: Dict[str, Any] = {}
        if cache_file.exists():
            try:
                with open(cache_file, "r") as f:
                    cache_data = json.load(f)
            except (OSError, json.JSONDecodeError) as e:
                # Corrupt/unreadable shard: rebuild it, but say so instead of
                # silently swallowing everything (was a bare `except:`).
                print(f"Error reading cache file {cache_file}, rewriting it: {e}")
                cache_data = {}
        cache_data[chunk_id] = embedding.tolist()
        with open(cache_file, "w") as f:
            json.dump(cache_data, f)

    def _load_conversations(self):
        """Embed and index any conversations not yet in the embeddings table."""
        conn = None
        try:
            conn = sqlite3.connect(self.db_path)
            cursor = conn.cursor()
            # The table may not exist yet on the very first run.
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='conversations'")
            if not cursor.fetchone():
                print("Conversations table does not exist yet")
                return
            # Conversations with no matching embeddings row are unprocessed.
            cursor.execute("""
                SELECT c.id, c.user_message, c.ai_message
                FROM conversations c
                LEFT JOIN embeddings e ON c.id = e.conversation_id
                WHERE e.id IS NULL
            """)
            conversations = cursor.fetchall()
            if not conversations:
                return
            print(f"Processing embeddings for {len(conversations)} new conversations")
            for conv_id, user_message, ai_message in conversations:
                # Skip rows with NULL messages; nothing meaningful to embed.
                if user_message is not None and ai_message is not None:
                    self._process_conversation(conv_id, user_message, ai_message, conn)
            print("Finished processing conversation embeddings")
        except Exception as e:
            print(f"Error loading conversations: {e}")
        finally:
            # Original leaked the connection on the error path; always close.
            if conn is not None:
                conn.close()

    def _process_conversation(self, conv_id: int, user_message: str, ai_message: str, conn: sqlite3.Connection):
        """Embed one conversation and record it in the embeddings table.

        Uses the caller's open connection; commits on success, prints and
        continues on failure (best-effort, matching the rest of the class).
        """
        try:
            cursor = conn.cursor()
            full_text = f"User: {user_message}\nAI: {ai_message}"
            # One chunk per conversation; long messages are not split further.
            chunk_id = f"conv_{conv_id}"
            if chunk_id in self.embedding_cache:
                embedding = self.embedding_cache[chunk_id]
            else:
                embedding = self.model.encode(full_text)
                self.embedding_cache[chunk_id] = embedding
                self._save_embedding_to_cache(chunk_id, embedding)
            # Record which shard file holds the vector
            # (must match the naming in _save_embedding_to_cache).
            embedding_file = f"{chunk_id[:2]}.json"
            cursor.execute(
                "INSERT INTO embeddings (conversation_id, text, embedding_file, chunk_id) VALUES (?, ?, ?, ?)",
                (conv_id, full_text, embedding_file, chunk_id)
            )
            conn.commit()
        except Exception as e:
            print(f"Error processing conversation {conv_id}: {e}")

    def add_conversation(self, user_message: str, ai_message: str) -> int:
        """
        Add a new conversation to the RAG system.

        Args:
            user_message: The user's message text.
            ai_message: The assistant's reply text.

        Returns:
            The id of the newly added conversation, or -1 on failure.
        """
        conn = None
        try:
            conn = sqlite3.connect(self.db_path)
            cursor = conn.cursor()
            cursor.execute(
                "INSERT INTO conversations (user_message, ai_message) VALUES (?, ?)",
                (user_message, ai_message)
            )
            conv_id = cursor.lastrowid
            # Index the new row immediately so it is searchable right away.
            self._process_conversation(conv_id, user_message, ai_message, conn)
            conn.commit()
            return conv_id
        except Exception as e:
            print(f"Error adding conversation: {e}")
            return -1
        finally:
            # Original leaked the connection on the error path; always close.
            if conn is not None:
                conn.close()

    def query(self, query_text: str, top_k: int = 3) -> List[Tuple[str, float]]:
        """
        Query the RAG system for relevant context.

        Args:
            query_text: The query text
            top_k: Number of top results to return

        Returns:
            List of (text, similarity_score) tuples, most similar first;
            empty list on blank input or error.
        """
        if not query_text or not query_text.strip():
            print("Error: Empty query text")
            return []
        try:
            query_embedding = self.model.encode(query_text)
            return self._find_similar(query_embedding, top_k)
        except Exception as e:
            print(f"Error during query: {e}")
            return []

    def get_context(self, query_text: str, top_k: int = 3, threshold: float = 0.6) -> str:
        """
        Get formatted context from the RAG system.

        Args:
            query_text: The query text
            top_k: Number of top results to return
            threshold: Minimum similarity score to include

        Returns:
            Relevant snippets joined by "---" separators, or "" when nothing
            clears the threshold.
        """
        results = self.query(query_text, top_k)
        # Keep only results relevant enough to be worth showing downstream.
        context_parts = [
            f"Relevance: {score:.2f}\n{text}"
            for text, score in results
            if score >= threshold
        ]
        return "\n---\n".join(context_parts)

    def _find_similar(self, query_embedding: np.ndarray, top_k: int) -> List[Tuple[str, float]]:
        """Rank all stored conversations by cosine similarity to the query.

        Embeddings come from the in-memory cache, falling back to the on-disk
        shard recorded in each database row (see _get_cached_embedding).
        """
        conn = None
        try:
            conn = sqlite3.connect(self.db_path)
            cursor = conn.cursor()
            # The table may not exist yet on the very first run.
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='embeddings'")
            if not cursor.fetchone():
                print("Embeddings table does not exist yet")
                return []
            cursor.execute("SELECT id, text, embedding_file, chunk_id FROM embeddings")
            rows = cursor.fetchall()
            if not rows:
                return []
            similarities: List[Tuple[str, float]] = []
            for _db_id, text, embedding_file, chunk_id in rows:
                embedding = self._get_cached_embedding(chunk_id, embedding_file)
                if embedding is None:
                    # Vector lost from both caches; skip this row.
                    continue
                similarity = cosine_similarity(
                    query_embedding.reshape(1, -1),
                    embedding.reshape(1, -1)
                )[0][0]
                similarities.append((text, similarity))
            similarities.sort(key=lambda x: x[1], reverse=True)
            return similarities[:top_k]
        except Exception as e:
            print(f"Error finding similar documents: {e}")
            return []
        finally:
            # Original leaked the connection on the error path; always close.
            if conn is not None:
                conn.close()

    def _get_cached_embedding(self, chunk_id: str, embedding_file: str) -> Optional[np.ndarray]:
        """Return the embedding for chunk_id, reloading it from its on-disk
        shard when missing from the in-memory cache; None if not found."""
        if chunk_id in self.embedding_cache:
            return self.embedding_cache[chunk_id]
        # Should not normally happen: reload from the shard file and re-cache.
        cache_file = self.cache_dir / embedding_file
        if not cache_file.exists():
            return None
        with open(cache_file, "r") as f:
            cache_data = json.load(f)
        if chunk_id not in cache_data:
            return None
        embedding = np.array(cache_data[chunk_id])
        self.embedding_cache[chunk_id] = embedding
        return embedding

    def refresh(self):
        """Re-scan the database and embed any conversations added since load."""
        self._load_conversations()
# Example usage
if __name__ == "__main__":
    # Build the retrieval index over the default conversations database.
    rag_system = RAGSystem("conversations.db")