66"""
77
88from memos .api .handlers .base_handler import BaseHandler , HandlerDependencies
9- from memos .api .product_models import APIADDRequest , MemoryResponse
9+ from memos .api .product_models import APIADDRequest , APIFeedbackRequest , MemoryResponse
1010from memos .memories .textual .item import (
1111 list_all_fields ,
1212)
@@ -30,7 +30,9 @@ def __init__(self, dependencies: HandlerDependencies):
3030 dependencies: HandlerDependencies instance
3131 """
3232 super ().__init__ (dependencies )
33- self ._validate_dependencies ("naive_mem_cube" , "mem_reader" , "mem_scheduler" )
33+ self ._validate_dependencies (
34+ "naive_mem_cube" , "mem_reader" , "mem_scheduler" , "feedback_server"
35+ )
3436
3537 def handle_add_memories (self , add_req : APIADDRequest ) -> MemoryResponse :
3638 """
@@ -56,6 +58,39 @@ def handle_add_memories(self, add_req: APIADDRequest) -> MemoryResponse:
5658
5759 cube_view = self ._build_cube_view (add_req )
5860
61+ if add_req .is_feedback :
62+ chat_history = add_req .chat_history
63+ messages = add_req .messages
64+ if chat_history is None :
65+ chat_history = []
66+ if messages is None :
67+ messages = []
68+ concatenate_chat = chat_history + messages
69+
70+ last_user_index = max (i for i , d in enumerate (concatenate_chat ) if d ["role" ] == "user" )
71+ feedback_content = concatenate_chat [last_user_index ]["content" ]
72+ feedback_history = concatenate_chat [:last_user_index ]
73+
74+ feedback_req = APIFeedbackRequest (
75+ user_id = add_req .user_id ,
76+ session_id = add_req .session_id ,
77+ task_id = add_req .task_id ,
78+ history = feedback_history ,
79+ feedback_content = feedback_content ,
80+ writable_cube_ids = add_req .writable_cube_ids ,
81+ async_mode = add_req .async_mode ,
82+ )
83+ process_record = cube_view .feedback_memories (feedback_req )
84+
85+ self .logger .info (
86+ f"[FeedbackHandler] Final feedback results count={ len (process_record )} "
87+ )
88+
89+ return MemoryResponse (
90+ message = "Memory feedback successfully" ,
91+ data = [process_record ],
92+ )
93+
5994 results = cube_view .add_memories (add_req )
6095
6196 self .logger .info (f"[AddHandler] Final add results count={ len (results )} " )
@@ -88,6 +123,7 @@ def _build_cube_view(self, add_req: APIADDRequest) -> MemCubeView:
88123 mem_reader = self .mem_reader ,
89124 mem_scheduler = self .mem_scheduler ,
90125 logger = self .logger ,
126+ feedback_server = self .feedback_server ,
91127 searcher = None ,
92128 )
93129 else :
@@ -98,6 +134,7 @@ def _build_cube_view(self, add_req: APIADDRequest) -> MemCubeView:
98134 mem_reader = self .mem_reader ,
99135 mem_scheduler = self .mem_scheduler ,
100136 logger = self .logger ,
137+ feedback_server = self .feedback_server ,
101138 searcher = None ,
102139 )
103140 for cube_id in cube_ids
0 commit comments