Skip to content

Commit f714027

Browse files
whipser030黑布林CaralHsifridayL
authored
feat: feedback interface (#541)
* update reader and search strategy * set strategy reader and search config * fix install problem * fix * fix test * turn off graph recall * turn off graph recall * turn off graph recall * fix Searcher input bug * fix Searcher * fix Search * fix bug * adjust strategy reader * adjust strategy reader * adjust search config input * reformat code * re pr * format repair * fix time issue * develop feedback process * feedback handler configuration * upgrade feedback using * add threshold * update prompt * update prompt * fix handler * add feedback scheduler * add handler change node update * add handler change node update * add handler change node update * add handler change node update --------- Co-authored-by: 黑布林 <11641432+heiheiyouyou@user.noreply.gitee.com> Co-authored-by: CaralHsi <caralhsi@gmail.com> Co-authored-by: chunyu li <78344051+fridayL@users.noreply.github.com>
1 parent 0c0a402 commit f714027

File tree

20 files changed

+1661
-9
lines changed

20 files changed

+1661
-9
lines changed

examples/api/product_api.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,23 @@ def chat_stream(query: str, session_id: str, history: list | None = None):
119119
print(payload)
120120

121121

122+
def feedback_memory(feedback_content: str, history: list | None = None):
123+
url = f"{BASE_URL}/feedback"
124+
data = {
125+
"user_id": USER_ID,
126+
"writable_cube_ids": [MEM_CUBE_ID],
127+
"history": history,
128+
"feedback_content": feedback_content,
129+
"async_mode": "sync",
130+
"corrected_answer": "false",
131+
}
132+
133+
print("[*] Feedbacking memory ...")
134+
resp = requests.post(url, headers=HEADERS, data=json.dumps(data), timeout=30)
135+
print(resp.status_code, resp.text)
136+
return resp.json()
137+
138+
122139
if __name__ == "__main__":
123140
print("===== STEP 1: Register User =====")
124141
register_user()
@@ -140,5 +157,14 @@ def chat_stream(query: str, session_id: str, history: list | None = None):
140157
],
141158
)
142159

143-
print("\n===== STEP 4: Stream Chat =====")
160+
print("\n===== STEP 5: Stream Chat =====")
144161
chat_stream("我刚和你说什么了呢", SESSION_ID2, history=[])
162+
163+
print("\n===== STEP 6: Feedback Memory =====")
164+
feedback_memory(
165+
feedback_content="错啦,我今天没有吃拉面",
166+
history=[
167+
{"role": "user", "content": "我刚和你说什么了呢"},
168+
{"role": "assistant", "content": "你今天吃了好吃的拉面"},
169+
],
170+
)

src/memos/api/handlers/add_handler.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"""
77

88
from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies
9-
from memos.api.product_models import APIADDRequest, MemoryResponse
9+
from memos.api.product_models import APIADDRequest, APIFeedbackRequest, MemoryResponse
1010
from memos.memories.textual.item import (
1111
list_all_fields,
1212
)
@@ -30,7 +30,9 @@ def __init__(self, dependencies: HandlerDependencies):
3030
dependencies: HandlerDependencies instance
3131
"""
3232
super().__init__(dependencies)
33-
self._validate_dependencies("naive_mem_cube", "mem_reader", "mem_scheduler")
33+
self._validate_dependencies(
34+
"naive_mem_cube", "mem_reader", "mem_scheduler", "feedback_server"
35+
)
3436

3537
def handle_add_memories(self, add_req: APIADDRequest) -> MemoryResponse:
3638
"""
@@ -56,6 +58,39 @@ def handle_add_memories(self, add_req: APIADDRequest) -> MemoryResponse:
5658

5759
cube_view = self._build_cube_view(add_req)
5860

61+
if add_req.is_feedback:
62+
chat_history = add_req.chat_history
63+
messages = add_req.messages
64+
if chat_history is None:
65+
chat_history = []
66+
if messages is None:
67+
messages = []
68+
concatenate_chat = chat_history + messages
69+
70+
last_user_index = max(i for i, d in enumerate(concatenate_chat) if d["role"] == "user")
71+
feedback_content = concatenate_chat[last_user_index]["content"]
72+
feedback_history = concatenate_chat[:last_user_index]
73+
74+
feedback_req = APIFeedbackRequest(
75+
user_id=add_req.user_id,
76+
session_id=add_req.session_id,
77+
task_id=add_req.task_id,
78+
history=feedback_history,
79+
feedback_content=feedback_content,
80+
writable_cube_ids=add_req.writable_cube_ids,
81+
async_mode=add_req.async_mode,
82+
)
83+
process_record = cube_view.feedback_memories(feedback_req)
84+
85+
self.logger.info(
86+
f"[FeedbackHandler] Final feedback results count={len(process_record)}"
87+
)
88+
89+
return MemoryResponse(
90+
message="Memory feedback successfully",
91+
data=[process_record],
92+
)
93+
5994
results = cube_view.add_memories(add_req)
6095

6196
self.logger.info(f"[AddHandler] Final add results count={len(results)}")
@@ -88,6 +123,7 @@ def _build_cube_view(self, add_req: APIADDRequest) -> MemCubeView:
88123
mem_reader=self.mem_reader,
89124
mem_scheduler=self.mem_scheduler,
90125
logger=self.logger,
126+
feedback_server=self.feedback_server,
91127
searcher=None,
92128
)
93129
else:
@@ -98,6 +134,7 @@ def _build_cube_view(self, add_req: APIADDRequest) -> MemCubeView:
98134
mem_reader=self.mem_reader,
99135
mem_scheduler=self.mem_scheduler,
100136
logger=self.logger,
137+
feedback_server=self.feedback_server,
101138
searcher=None,
102139
)
103140
for cube_id in cube_ids

src/memos/api/handlers/base_handler.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def __init__(
3737
internet_retriever: Any | None = None,
3838
memory_manager: Any | None = None,
3939
mos_server: Any | None = None,
40+
feedback_server: Any | None = None,
4041
**kwargs,
4142
):
4243
"""
@@ -68,6 +69,7 @@ def __init__(
6869
self.internet_retriever = internet_retriever
6970
self.memory_manager = memory_manager
7071
self.mos_server = mos_server
72+
self.feedback_server = feedback_server
7173

7274
# Store any additional dependencies
7375
for key, value in kwargs.items():
@@ -166,6 +168,11 @@ def deepsearch_agent(self):
166168
"""Get deepsearch agent instance."""
167169
return self.deps.deepsearch_agent
168170

171+
@property
172+
def feedback_server(self):
173+
"""Get feedback server instance."""
174+
return self.deps.feedback_server
175+
169176
def _validate_dependencies(self, *required_deps: str) -> None:
170177
"""
171178
Validate that required dependencies are available.

src/memos/api/handlers/component_init.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from memos.llms.factory import LLMFactory
3030
from memos.log import get_logger
3131
from memos.mem_cube.navie import NaiveMemCube
32+
from memos.mem_feedback.simple_feedback import SimpleMemFeedback
3233
from memos.mem_os.product_server import MOSServer
3334
from memos.mem_reader.factory import MemReaderFactory
3435
from memos.mem_scheduler.orm_modules.base_model import BaseDBManager
@@ -295,6 +296,16 @@ def init_server() -> dict[str, Any]:
295296
)
296297
logger.debug("Searcher created")
297298

299+
# Initialize feedback server
300+
feedback_server = SimpleMemFeedback(
301+
llm=llm,
302+
embedder=embedder,
303+
graph_store=graph_db,
304+
memory_manager=memory_manager,
305+
mem_reader=mem_reader,
306+
searcher=searcher,
307+
)
308+
298309
# Initialize Scheduler
299310
scheduler_config_dict = APIConfig.get_scheduler_config()
300311
scheduler_config = SchedulerConfigFactory(
@@ -308,7 +319,9 @@ def init_server() -> dict[str, Any]:
308319
mem_reader=mem_reader,
309320
redis_client=redis_client,
310321
)
311-
mem_scheduler.init_mem_cube(mem_cube=naive_mem_cube, searcher=searcher)
322+
mem_scheduler.init_mem_cube(
323+
mem_cube=naive_mem_cube, searcher=searcher, feedback_server=feedback_server
324+
)
312325
logger.debug("Scheduler initialized")
313326

314327
# Initialize SchedulerAPIModule
@@ -356,6 +369,7 @@ def init_server() -> dict[str, Any]:
356369
"text_mem": text_mem,
357370
"pref_mem": pref_mem,
358371
"online_bot": online_bot,
372+
"feedback_server": feedback_server,
359373
"redis_client": redis_client,
360374
"deepsearch_agent": deepsearch_agent,
361375
}
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
"""
2+
Feeback handler for memory add/update functionality.
3+
"""
4+
5+
from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies
6+
from memos.api.product_models import APIFeedbackRequest, MemoryResponse
7+
from memos.log import get_logger
8+
from memos.multi_mem_cube.composite_cube import CompositeCubeView
9+
from memos.multi_mem_cube.single_cube import SingleCubeView
10+
from memos.multi_mem_cube.views import MemCubeView
11+
12+
13+
logger = get_logger(__name__)
14+
15+
16+
class FeedbackHandler(BaseHandler):
17+
"""
18+
Handler for memory feedback operations.
19+
20+
Provides fast, fine-grained, and mixture-based feedback modes.
21+
"""
22+
23+
def __init__(self, dependencies: HandlerDependencies):
24+
"""
25+
Initialize feedback handler.
26+
27+
Args:
28+
dependencies: HandlerDependencies instance
29+
"""
30+
super().__init__(dependencies)
31+
self._validate_dependencies("mem_reader", "mem_scheduler", "searcher")
32+
33+
def handle_feedback_memories(self, feedback_req: APIFeedbackRequest) -> MemoryResponse:
34+
"""
35+
Main handler for feedback memories endpoint.
36+
37+
Args:
38+
feedback_req: feedback request containing content and parameters
39+
40+
Returns:
41+
MemoryResponse with formatted results
42+
"""
43+
cube_view = self._build_cube_view(feedback_req)
44+
45+
process_record = cube_view.feedback_memories(feedback_req)
46+
47+
self.logger.info(f"[FeedbackHandler] Final feedback results count={len(process_record)}")
48+
49+
return MemoryResponse(
50+
message="Memory feedback successfully",
51+
data=[process_record],
52+
)
53+
54+
def _resolve_cube_ids(self, feedback_req: APIFeedbackRequest) -> list[str]:
55+
"""
56+
Normalize target cube ids from feedback_req.
57+
"""
58+
if feedback_req.writable_cube_ids:
59+
return list(dict.fromkeys(feedback_req.writable_cube_ids))
60+
61+
return [feedback_req.user_id]
62+
63+
def _build_cube_view(self, feedback_req: APIFeedbackRequest) -> MemCubeView:
64+
cube_ids = self._resolve_cube_ids(feedback_req)
65+
66+
if len(cube_ids) == 1:
67+
cube_id = cube_ids[0]
68+
return SingleCubeView(
69+
cube_id=cube_id,
70+
naive_mem_cube=None,
71+
mem_reader=None,
72+
mem_scheduler=self.mem_scheduler,
73+
logger=self.logger,
74+
searcher=None,
75+
feedback_server=self.feedback_server,
76+
)
77+
else:
78+
single_views = [
79+
SingleCubeView(
80+
cube_id=cube_id,
81+
naive_mem_cube=None,
82+
mem_reader=None,
83+
mem_scheduler=self.mem_scheduler,
84+
logger=self.logger,
85+
searcher=None,
86+
feedback_server=self.feedback_server,
87+
)
88+
for cube_id in cube_ids
89+
]
90+
return CompositeCubeView(
91+
cube_views=single_views,
92+
logger=self.logger,
93+
)

src/memos/api/product_models.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
# Import message types from core types module
88
from memos.log import get_logger
9-
from memos.types import PermissionDict, SearchMode
9+
from memos.types import MessageDict, MessageList, MessagesType, PermissionDict, SearchMode
1010

1111

1212
logger = get_logger(__name__)
@@ -628,6 +628,38 @@ def _convert_deprecated_fields(self) -> "APIADDRequest":
628628
return self
629629

630630

631+
class APIFeedbackRequest(BaseRequest):
632+
"""Request model for processing feedback info."""
633+
634+
user_id: str = Field(..., description="User ID")
635+
session_id: str | None = Field(
636+
"default_session", description="Session ID for soft-filtering memories"
637+
)
638+
task_id: str | None = Field(None, description="Task ID for monitering async tasks")
639+
history: list[MessageDict] | None = Field(..., description="Chat history")
640+
retrieved_memory_ids: list[str] | None = Field(
641+
None, description="Retrieved memory ids at last turn"
642+
)
643+
feedback_content: str | None = Field(..., description="Feedback content to process")
644+
feedback_time: str | None = Field(None, description="Feedback time")
645+
# ==== Multi-cube writing ====
646+
writable_cube_ids: list[str] | None = Field(
647+
None, description="List of cube IDs user can write for multi-cube add"
648+
)
649+
async_mode: Literal["sync", "async"] = Field(
650+
"async", description="feedback mode: sync or async"
651+
)
652+
corrected_answer: bool = Field(False, description="Whether need return corrected answer")
653+
# ==== Backward compatibility ====
654+
mem_cube_id: str | None = Field(
655+
None,
656+
description=(
657+
"(Deprecated) Single cube ID to search in. "
658+
"Prefer `readable_cube_ids` for multi-cube search."
659+
),
660+
)
661+
662+
631663
class APIChatCompleteRequest(BaseRequest):
632664
"""Request model for chat operations."""
633665

src/memos/api/routers/server_router.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,12 @@
2121
from memos.api.handlers.add_handler import AddHandler
2222
from memos.api.handlers.base_handler import HandlerDependencies
2323
from memos.api.handlers.chat_handler import ChatHandler
24+
from memos.api.handlers.feedback_handler import FeedbackHandler
2425
from memos.api.handlers.search_handler import SearchHandler
2526
from memos.api.product_models import (
2627
APIADDRequest,
2728
APIChatCompleteRequest,
29+
APIFeedbackRequest,
2830
APISearchRequest,
2931
ChatRequest,
3032
DeleteMemoryRequest,
@@ -66,7 +68,7 @@
6668
add_handler,
6769
online_bot=components.get("online_bot"),
6870
)
69-
71+
feedback_handler = FeedbackHandler(dependencies)
7072
# Extract commonly used components for function-based handlers
7173
# (These can be accessed from the components dict without unpacking all of them)
7274
mem_scheduler: BaseScheduler = components["mem_scheduler"]
@@ -265,3 +267,18 @@ def delete_memories(memory_req: DeleteMemoryRequest):
265267
return handlers.memory_handler.handle_delete_memories(
266268
delete_mem_req=memory_req, naive_mem_cube=naive_mem_cube
267269
)
270+
271+
272+
# =============================================================================
273+
# Feedback API Endpoints
274+
# =============================================================================
275+
276+
277+
@router.post("/feedback", summary="Feedback memories", response_model=MemoryResponse)
278+
def feedback_memories(feedback_req: APIFeedbackRequest):
279+
"""
280+
Feedback memories for a specific user.
281+
282+
This endpoint uses the class-based FeedbackHandler for better code organization.
283+
"""
284+
return feedback_handler.handle_feedback_memories(feedback_req)

0 commit comments

Comments
 (0)