Skip to content

Commit 7d34e65

Browse files
authored
feat: add filter for search_memories (#553)
* add filter for search_memories
* fix: data type incorrect
* fix
* fix textual filter bug and resolve conversation
1 parent e631649 commit 7d34e65

File tree

11 files changed

+253
-44
lines changed

11 files changed

+253
-44
lines changed

src/memos/api/product_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,7 @@ class APIADDRequest(BaseRequest):
469469
),
470470
)
471471

472-
info: dict[str, str] | None = Field(
472+
info: dict[str, Any] | None = Field(
473473
None,
474474
description=(
475475
"Additional metadata for the add request. "

src/memos/mem_scheduler/optimized_scheduler.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,8 @@ def mix_search_memories(
138138
target_session_id = search_req.session_id
139139
if not target_session_id:
140140
target_session_id = "default_session"
141-
search_filter = {"session_id": search_req.session_id} if search_req.session_id else None
141+
search_priority = {"session_id": search_req.session_id} if search_req.session_id else None
142+
search_filter = search_req.filter
142143

143144
# Rerank Memories - reranker expects TextualMemoryItem objects
144145

@@ -155,6 +156,7 @@ def mix_search_memories(
155156
mode=SearchMode.FAST,
156157
manual_close_internet=not search_req.internet_search,
157158
search_filter=search_filter,
159+
search_priority=search_priority,
158160
info=info,
159161
)
160162

@@ -178,7 +180,7 @@ def mix_search_memories(
178180
query=search_req.query, # Use search_req.query instead of undefined query
179181
graph_results=history_memories, # Pass TextualMemoryItem objects directly
180182
top_k=search_req.top_k, # Use search_req.top_k instead of undefined top_k
181-
search_filter=search_filter,
183+
search_priority=search_priority,
182184
)
183185
logger.info(f"Reranked {len(sorted_history_memories)} history memories.")
184186
processed_hist_mem = self.searcher.post_retrieve(
@@ -234,6 +236,7 @@ def mix_search_memories(
234236
mode=SearchMode.FAST,
235237
memory_type="All",
236238
search_filter=search_filter,
239+
search_priority=search_priority,
237240
info=info,
238241
)
239242
else:

src/memos/memories/textual/prefer_text_memory/retrievers.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,11 @@ def __init__(self, llm_provider=None, embedder=None, reranker=None, vector_db=No
1717

1818
@abstractmethod
1919
def retrieve(
20-
self, query: str, top_k: int, info: dict[str, Any] | None = None
20+
self,
21+
query: str,
22+
top_k: int,
23+
info: dict[str, Any] | None = None,
24+
search_filter: dict[str, Any] | None = None,
2125
) -> list[TextualMemoryItem]:
2226
"""Retrieve memories from the retriever."""
2327

@@ -76,14 +80,19 @@ def _original_text_reranker(
7680
return prefs_mem
7781

7882
def retrieve(
79-
self, query: str, top_k: int, info: dict[str, Any] | None = None
83+
self,
84+
query: str,
85+
top_k: int,
86+
info: dict[str, Any] | None = None,
87+
search_filter: dict[str, Any] | None = None,
8088
) -> list[TextualMemoryItem]:
8189
"""Retrieve memories from the naive retriever."""
8290
# TODO: un-support rewrite query and session filter now
8391
if info:
8492
info = info.copy() # Create a copy to avoid modifying the original
8593
info.pop("chat_history", None)
8694
info.pop("session_id", None)
95+
search_filter = {"and": [info, search_filter]}
8796
query_embeddings = self.embedder.embed([query]) # Pass as list to get list of embeddings
8897
query_embedding = query_embeddings[0] # Get the first (and only) embedding
8998

@@ -96,15 +105,15 @@ def retrieve(
96105
query,
97106
"explicit_preference",
98107
top_k * 2,
99-
info,
108+
search_filter,
100109
)
101110
future_implicit = executor.submit(
102111
self.vector_db.search,
103112
query_embedding,
104113
query,
105114
"implicit_preference",
106115
top_k * 2,
107-
info,
116+
search_filter,
108117
)
109118

110119
# Wait for all results

src/memos/memories/textual/preference.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,9 @@ def get_memory(
7676
"""
7777
return self.extractor.extract(messages, type, info)
7878

79-
def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMemoryItem]:
79+
def search(
80+
self, query: str, top_k: int, info=None, search_filter=None, **kwargs
81+
) -> list[TextualMemoryItem]:
8082
"""Search for memories based on a query.
8183
Args:
8284
query (str): The query to search for.
@@ -85,7 +87,8 @@ def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMem
8587
Returns:
8688
list[TextualMemoryItem]: List of matching memories.
8789
"""
88-
return self.retriever.retrieve(query, top_k, info)
90+
logger.info(f"search_filter for preference memory: {search_filter}")
91+
return self.retriever.retrieve(query, top_k, info, search_filter)
8992

9093
def load(self, dir: str) -> None:
9194
"""Load memories from the specified directory.

src/memos/memories/textual/simple_preference.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,9 @@ def get_memory(
5050
"""
5151
return self.extractor.extract(messages, type, info)
5252

53-
def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMemoryItem]:
53+
def search(
54+
self, query: str, top_k: int, info=None, search_filter=None, **kwargs
55+
) -> list[TextualMemoryItem]:
5456
"""Search for memories based on a query.
5557
Args:
5658
query (str): The query to search for.
@@ -59,7 +61,7 @@ def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMem
5961
Returns:
6062
list[TextualMemoryItem]: List of matching memories.
6163
"""
62-
return self.retriever.retrieve(query, top_k, info)
64+
return self.retriever.retrieve(query, top_k, info, search_filter)
6365

6466
def add(self, memories: list[TextualMemoryItem | dict[str, Any]]) -> list[str]:
6567
"""Add memories.

src/memos/memories/textual/tree.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ def search(
162162
mode: str = "fast",
163163
memory_type: str = "All",
164164
manual_close_internet: bool = True,
165+
search_priority: dict | None = None,
165166
search_filter: dict | None = None,
166167
user_name: str | None = None,
167168
) -> list[TextualMemoryItem]:
@@ -209,7 +210,14 @@ def search(
209210
manual_close_internet=manual_close_internet,
210211
)
211212
return searcher.search(
212-
query, top_k, info, mode, memory_type, search_filter, user_name=user_name
213+
query,
214+
top_k,
215+
info,
216+
mode,
217+
memory_type,
218+
search_filter,
219+
search_priority,
220+
user_name=user_name,
213221
)
214222

215223
def get_relevant_subgraph(

src/memos/memories/textual/tree_text_memory/retrieve/recall.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def retrieve(
3838
memory_scope: str,
3939
query_embedding: list[list[float]] | None = None,
4040
search_filter: dict | None = None,
41+
search_priority: dict | None = None,
4142
user_name: str | None = None,
4243
id_filter: dict | None = None,
4344
use_fast_graph: bool = False,
@@ -62,9 +63,12 @@ def retrieve(
6263
raise ValueError(f"Unsupported memory scope: {memory_scope}")
6364

6465
if memory_scope == "WorkingMemory":
65-
# For working memory, retrieve all entries (no filtering)
66+
# For working memory, retrieve all entries (no session-oriented filtering)
6667
working_memories = self.graph_store.get_all_memory_items(
67-
scope="WorkingMemory", include_embedding=False, user_name=user_name
68+
scope="WorkingMemory",
69+
include_embedding=False,
70+
user_name=user_name,
71+
filter=search_filter,
6872
)
6973
return [TextualMemoryItem.from_dict(record) for record in working_memories[:top_k]]
7074

@@ -84,6 +88,7 @@ def retrieve(
8488
memory_scope,
8589
top_k,
8690
search_filter=search_filter,
91+
search_priority=search_priority,
8792
user_name=user_name,
8893
)
8994
if self.use_bm25:
@@ -274,6 +279,7 @@ def _vector_recall(
274279
status: str = "activated",
275280
cube_name: str | None = None,
276281
search_filter: dict | None = None,
282+
search_priority: dict | None = None,
277283
user_name: str | None = None,
278284
) -> list[TextualMemoryItem]:
279285
"""
@@ -283,39 +289,41 @@ def _vector_recall(
283289
if not query_embedding:
284290
return []
285291

286-
def search_single(vec, filt=None):
292+
def search_single(vec, search_priority=None, search_filter=None):
287293
return (
288294
self.graph_store.search_by_embedding(
289295
vector=vec,
290296
top_k=top_k,
291297
status=status,
292298
scope=memory_scope,
293299
cube_name=cube_name,
294-
search_filter=filt,
300+
search_filter=search_priority,
301+
filter=search_filter,
295302
user_name=user_name,
296303
)
297304
or []
298305
)
299306

300307
def search_path_a():
301-
"""Path A: search without filter"""
308+
"""Path A: search without priority"""
302309
path_a_hits = []
303310
with ContextThreadPoolExecutor() as executor:
304311
futures = [
305-
executor.submit(search_single, vec, None) for vec in query_embedding[:max_num]
312+
executor.submit(search_single, vec, None, search_filter)
313+
for vec in query_embedding[:max_num]
306314
]
307315
for f in concurrent.futures.as_completed(futures):
308316
path_a_hits.extend(f.result() or [])
309317
return path_a_hits
310318

311319
def search_path_b():
312-
"""Path B: search with filter"""
313-
if not search_filter:
320+
"""Path B: search with priority"""
321+
if not search_priority:
314322
return []
315323
path_b_hits = []
316324
with ContextThreadPoolExecutor() as executor:
317325
futures = [
318-
executor.submit(search_single, vec, search_filter)
326+
executor.submit(search_single, vec, search_priority, search_filter)
319327
for vec in query_embedding[:max_num]
320328
]
321329
for f in concurrent.futures.as_completed(futures):

src/memos/memories/textual/tree_text_memory/retrieve/searcher.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,14 +69,20 @@ def retrieve(
6969
mode="fast",
7070
memory_type="All",
7171
search_filter: dict | None = None,
72+
search_priority: dict | None = None,
7273
user_name: str | None = None,
7374
**kwargs,
7475
) -> list[tuple[TextualMemoryItem, float]]:
7576
logger.info(
7677
f"[RECALL] Start query='{query}', top_k={top_k}, mode={mode}, memory_type={memory_type}"
7778
)
7879
parsed_goal, query_embedding, context, query = self._parse_task(
79-
query, info, mode, search_filter=search_filter, user_name=user_name
80+
query,
81+
info,
82+
mode,
83+
search_filter=search_filter,
84+
search_priority=search_priority,
85+
user_name=user_name,
8086
)
8187
results = self._retrieve_paths(
8288
query,
@@ -87,6 +93,7 @@ def retrieve(
8793
mode,
8894
memory_type,
8995
search_filter,
96+
search_priority,
9097
user_name,
9198
)
9299
return results
@@ -112,6 +119,7 @@ def search(
112119
mode="fast",
113120
memory_type="All",
114121
search_filter: dict | None = None,
122+
search_priority: dict | None = None,
115123
user_name: str | None = None,
116124
) -> list[TextualMemoryItem]:
117125
"""
@@ -128,6 +136,7 @@ def search(
128136
memory_type (str): Type restriction for search.
129137
['All', 'WorkingMemory', 'LongTermMemory', 'UserMemory']
130138
search_filter (dict, optional): Optional metadata filters for search results.
139+
search_priority (dict, optional): Optional metadata priority for search results.
131140
Returns:
132141
list[TextualMemoryItem]: List of matching memories.
133142
"""
@@ -147,6 +156,7 @@ def search(
147156
mode=mode,
148157
memory_type=memory_type,
149158
search_filter=search_filter,
159+
search_priority=search_priority,
150160
user_name=user_name,
151161
)
152162

@@ -174,6 +184,7 @@ def _parse_task(
174184
mode,
175185
top_k=5,
176186
search_filter: dict | None = None,
187+
search_priority: dict | None = None,
177188
user_name: str | None = None,
178189
):
179190
"""Parse user query, do embedding search and create context"""
@@ -192,7 +203,8 @@ def _parse_task(
192203
query_embedding,
193204
top_k=top_k,
194205
status="activated",
195-
search_filter=search_filter,
206+
search_filter=search_priority,
207+
filter=search_filter,
196208
user_name=user_name,
197209
)
198210
]
@@ -244,6 +256,7 @@ def _retrieve_paths(
244256
mode,
245257
memory_type,
246258
search_filter: dict | None = None,
259+
search_priority: dict | None = None,
247260
user_name: str | None = None,
248261
):
249262
"""Run A/B/C retrieval paths in parallel"""
@@ -264,6 +277,7 @@ def _retrieve_paths(
264277
top_k,
265278
memory_type,
266279
search_filter,
280+
search_priority,
267281
user_name,
268282
id_filter,
269283
)
@@ -277,6 +291,7 @@ def _retrieve_paths(
277291
top_k,
278292
memory_type,
279293
search_filter,
294+
search_priority,
280295
user_name,
281296
id_filter,
282297
mode=mode,
@@ -313,6 +328,7 @@ def _retrieve_from_working_memory(
313328
top_k,
314329
memory_type,
315330
search_filter: dict | None = None,
331+
search_priority: dict | None = None,
316332
user_name: str | None = None,
317333
id_filter: dict | None = None,
318334
):
@@ -326,6 +342,7 @@ def _retrieve_from_working_memory(
326342
top_k=top_k,
327343
memory_scope="WorkingMemory",
328344
search_filter=search_filter,
345+
search_priority=search_priority,
329346
user_name=user_name,
330347
id_filter=id_filter,
331348
use_fast_graph=self.use_fast_graph,
@@ -349,6 +366,7 @@ def _retrieve_from_long_term_and_user(
349366
top_k,
350367
memory_type,
351368
search_filter: dict | None = None,
369+
search_priority: dict | None = None,
352370
user_name: str | None = None,
353371
id_filter: dict | None = None,
354372
mode: str = "fast",
@@ -378,6 +396,7 @@ def _retrieve_from_long_term_and_user(
378396
top_k=top_k * 2,
379397
memory_scope="LongTermMemory",
380398
search_filter=search_filter,
399+
search_priority=search_priority,
381400
user_name=user_name,
382401
id_filter=id_filter,
383402
use_fast_graph=self.use_fast_graph,
@@ -393,6 +412,7 @@ def _retrieve_from_long_term_and_user(
393412
top_k=top_k * 2,
394413
memory_scope="UserMemory",
395414
search_filter=search_filter,
415+
search_priority=search_priority,
396416
user_name=user_name,
397417
id_filter=id_filter,
398418
use_fast_graph=self.use_fast_graph,

0 commit comments

Comments
 (0)