diff --git a/dump.rdb b/dump.rdb
new file mode 100644
index 000000000..9199ccdf3
Binary files /dev/null and b/dump.rdb differ
diff --git a/examples/data/config/mem_scheduler/general_scheduler_config.yaml b/examples/data/config/mem_scheduler/general_scheduler_config.yaml
index 2360bb14b..cc3de38a8 100644
--- a/examples/data/config/mem_scheduler/general_scheduler_config.yaml
+++ b/examples/data/config/mem_scheduler/general_scheduler_config.yaml
@@ -4,7 +4,7 @@ config:
   act_mem_update_interval: 30
   context_window_size: 10
   thread_pool_max_workers: 5
-  consume_interval_seconds: 1
+  consume_interval_seconds: 0.01
   working_mem_monitor_capacity: 20
   activation_mem_monitor_capacity: 5
   enable_parallel_dispatch: true
diff --git a/examples/data/config/mem_scheduler/memos_config_w_optimized_scheduler.yaml b/examples/data/config/mem_scheduler/memos_config_w_optimized_scheduler.yaml
index 2d3958e60..cfb2a050c 100644
--- a/examples/data/config/mem_scheduler/memos_config_w_optimized_scheduler.yaml
+++ b/examples/data/config/mem_scheduler/memos_config_w_optimized_scheduler.yaml
@@ -38,7 +38,7 @@ mem_scheduler:
   act_mem_update_interval: 30
   context_window_size: 10
   thread_pool_max_workers: 10
-  consume_interval_seconds: 1
+  consume_interval_seconds: 0.01
   working_mem_monitor_capacity: 20
   activation_mem_monitor_capacity: 5
   enable_parallel_dispatch: true
diff --git a/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml b/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml
index cdfa49a76..bd9910300 100644
--- a/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml
+++ b/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml
@@ -38,7 +38,7 @@ mem_scheduler:
   act_mem_update_interval: 30
   context_window_size: 10
   thread_pool_max_workers: 10
-  consume_interval_seconds: 1
+  consume_interval_seconds: 0.01
   working_mem_monitor_capacity: 20
   activation_mem_monitor_capacity: 5
   enable_parallel_dispatch: true
diff --git a/examples/mem_scheduler/task_fair_schedule.py b/examples/mem_scheduler/task_fair_schedule.py
new file mode 100644
index 000000000..86f996162
--- /dev/null
+++ b/examples/mem_scheduler/task_fair_schedule.py
@@ -0,0 +1,88 @@
+import sys
+
+from collections import defaultdict
+from pathlib import Path
+
+from memos.api.routers.server_router import mem_scheduler
+from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+
+
+FILE_PATH = Path(__file__).absolute()
+BASE_DIR = FILE_PATH.parent.parent.parent
+sys.path.insert(0, str(BASE_DIR))
+
+
+def make_message(user_id: str, mem_cube_id: str, label: str, idx: int | str) -> ScheduleMessageItem:
+    return ScheduleMessageItem(
+        item_id=f"{user_id}:{mem_cube_id}:{label}:{idx}",
+        user_id=user_id,
+        mem_cube_id=mem_cube_id,
+        label=label,
+        content=f"msg-{idx} for {user_id}/{mem_cube_id}/{label}",
+    )
+
+
+def seed_messages_for_test_fairness(queue, combos, per_stream):
+    # Send an overwhelming number of messages from a single stream
+    (u, c, label) = combos[0]
+    task_target = 100
+    print(f"{u}:{c}:{label} submit {task_target} messages")
+    for i in range(task_target):
+        msg = make_message(u, c, label, f"overwhelm_{i}")
+        queue.submit_messages(msg)
+
+    for u, c, label in combos:
+        print(f"{u}:{c}:{label} submit {per_stream} messages")
+        for i in range(per_stream):
+            msg = make_message(u, c, label, i)
+            queue.submit_messages(msg)
+    print("======= seed_messages Done ===========")
+
+
+def count_by_stream(messages):
+    counts = defaultdict(int)
+    for m in messages:
+        key = f"{m.user_id}:{m.mem_cube_id}:{m.label}"
+        counts[key] += 1
+    return counts
+
+
+def run_fair_redis_schedule(batch_size: int = 3):
+    print("=== Redis Fairness Demo ===")
+    print(f"use_redis_queue: {mem_scheduler.use_redis_queue}")
+    mem_scheduler.consume_batch = batch_size
+    queue = mem_scheduler.memos_message_queue
+
+    # Isolate and clear queue
+    queue.debug_mode_on(debug_stream_prefix="fair_redis_schedule")
+    queue.clear()
+
+    # Define multiple streams: (user_id, mem_cube_id, task_label)
+    combos = [
+        ("u1", "u1", "labelX"),
+        ("u1", "u1", "labelY"),
+        ("u2", "u2", "labelX"),
+        ("u2", "u2", "labelY"),
+    ]
+    per_stream = 5
+
+    # Seed messages evenly across streams
+    seed_messages_for_test_fairness(queue, combos, per_stream)
+
+    # Compute target batch size (fair split across streams)
+    print(f"Request batch_size={batch_size} for {len(combos)} streams")
+
+    for _ in range(len(combos)):
+        # Fetch one brokered pack
+        msgs = queue.get_messages(batch_size=batch_size)
+        print(f"Fetched {len(msgs)} messages in this pack")
+
+        # Check fairness: counts per stream
+        counts = count_by_stream(msgs)
+        for k in sorted(counts):
+            print(f"{k}: {counts[k]}")
+
+
+if __name__ == "__main__":
+    # task 1: fair redis schedule
+    run_fair_redis_schedule()
diff --git a/examples/mem_scheduler/task_stop_rerun.py b/examples/mem_scheduler/task_stop_rerun.py
new file mode 100644
index 000000000..c421cbeab
--- /dev/null
+++ b/examples/mem_scheduler/task_stop_rerun.py
@@ -0,0 +1,86 @@
+from pathlib import Path
+from time import sleep
+
+# Note: we skip API handler status/wait utilities in this demo
+from memos.api.routers.server_router import mem_scheduler
+from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+
+
+# Debug: Print scheduler configuration
+print("=== Scheduler Configuration Debug ===")
+print(f"Scheduler type: {type(mem_scheduler).__name__}")
+print(f"Config: {mem_scheduler.config}")
+print(f"use_redis_queue: {mem_scheduler.use_redis_queue}")
+print(f"Queue type: {type(mem_scheduler.memos_message_queue).__name__}")
+print(f"Queue maxsize: {getattr(mem_scheduler.memos_message_queue, 'maxsize', 'N/A')}")
+print("=====================================\n")
+
+queue = mem_scheduler.memos_message_queue
+queue.debug_mode_on(debug_stream_prefix="task_stop_rerun")
+
+
+# Define a handler function
+def my_test_handler(messages: list[ScheduleMessageItem]):
+    print(f"My test handler received {len(messages)} messages: {[one.item_id for one in messages]}")
+    for msg in messages:
+        # Create a file named by task_id (use item_id as numeric id 0..99)
+        task_id = str(msg.item_id)
+        file_path = tmp_dir / f"{task_id}.txt"
+        try:
+            print(f"writing {file_path}...")
+            file_path.write_text(f"Task {task_id} processed.\n")
+            sleep(5)
+        except Exception as e:
+            print(f"Failed to write {file_path}: {e}")
+
+
+def submit_tasks():
+    mem_scheduler.memos_message_queue.clear()
+
+    # Create 100 messages (task_id 0..99)
+    users = ["user_A", "user_B"]
+    messages_to_send = [
+        ScheduleMessageItem(
+            item_id=str(i),
+            user_id=users[i % 2],
+            mem_cube_id="test_mem_cube",
+            label=TEST_HANDLER_LABEL,
+            content=f"Create file for task {i}",
+        )
+        for i in range(100)
+    ]
+    # Submit messages in batch and print completion
+    print(f"Submitting {len(messages_to_send)} messages to the scheduler...")
+    mem_scheduler.memos_message_queue.submit_messages(messages_to_send)
+    print(f"Task submission done! tasks in queue: {mem_scheduler.get_tasks_status()}")
+
+
+# Register the handler
+TEST_HANDLER_LABEL = "test_handler"
+mem_scheduler.register_handlers({TEST_HANDLER_LABEL: my_test_handler})
+
+
+tmp_dir = Path("./tmp")
+tmp_dir.mkdir(exist_ok=True)
+
+# Test stop-and-restart: if tmp already has >1 files, skip submission and print info
+existing_count = len(list(Path("tmp").glob("*.txt"))) if Path("tmp").exists() else 0
+if existing_count > 1:
+    print(f"Skip submission: found {existing_count} files in tmp (>1), continue processing")
+else:
+    submit_tasks()
+
+# Wait until tmp has 100 files (i.e., the queue drains)
+poll_interval = 1
+expected = 100
+tmp_dir = Path("tmp")
+while mem_scheduler.get_tasks_status()["remaining"] != 0:
+    count = len(list(tmp_dir.glob("*.txt"))) if tmp_dir.exists() else 0
+    user_status_running = mem_scheduler.get_tasks_status()
+    print(f"[Monitor] user_status_running: {user_status_running}; Files in tmp: {count}/{expected}")
+    sleep(poll_interval)
+print(f"[Result] Final files in tmp: {len(list(tmp_dir.glob('*.txt')))}")
+
+# Stop the scheduler
+print("Stopping the scheduler...")
+mem_scheduler.stop()
diff --git a/src/memos/api/handlers/add_handler.py b/src/memos/api/handlers/add_handler.py
index 1bd83eae7..33febf5f0 100644
--- a/src/memos/api/handlers/add_handler.py
+++ b/src/memos/api/handlers/add_handler.py
@@ -45,7 +45,9 @@ def handle_add_memories(self, add_req: APIADDRequest) -> MemoryResponse:
         Returns:
             MemoryResponse with added memory information
         """
-        self.logger.info(f"[AddHandler] Add Req is: {add_req}")
+        self.logger.info(
+            f"[DIAGNOSTIC] server_router -> add_handler.handle_add_memories called (Modified at 2025-11-29 18:46). Full request: {add_req.model_dump_json(indent=2)}"
+        )
 
         if add_req.info:
             exclude_fields = list_all_fields()
diff --git a/src/memos/api/handlers/base_handler.py b/src/memos/api/handlers/base_handler.py
index 9df3310ec..fcfbac989 100644
--- a/src/memos/api/handlers/base_handler.py
+++ b/src/memos/api/handlers/base_handler.py
@@ -8,7 +8,7 @@
 from typing import Any
 
 from memos.log import get_logger
-from memos.mem_scheduler.base_scheduler import BaseScheduler
+from memos.mem_scheduler.optimized_scheduler import OptimizedScheduler
 from memos.memories.textual.tree_text_memory.retrieve.advanced_searcher import AdvancedSearcher
 
 
@@ -127,7 +127,7 @@ def mem_reader(self):
         return self.deps.mem_reader
 
     @property
-    def mem_scheduler(self) -> BaseScheduler:
+    def mem_scheduler(self) -> OptimizedScheduler:
         """Get scheduler instance."""
         return self.deps.mem_scheduler
diff --git a/src/memos/api/routers/product_router.py b/src/memos/api/routers/product_router.py
index 71e384014..609d61124 100644
--- a/src/memos/api/routers/product_router.py
+++ b/src/memos/api/routers/product_router.py
@@ -188,6 +188,7 @@ def get_all_memories(memory_req: GetMemoryPlaygroundRequest):
 @router.post("/add", summary="add a new memory", response_model=SimpleResponse)
 def create_memory(memory_req: MemoryCreateRequest):
     """Create a new memory for a specific user."""
+    logger.info("DIAGNOSTIC: /product/add endpoint called. This confirms the new code is deployed.")
 
     # Initialize status_tracker outside try block to avoid NameError in except blocks
     status_tracker = None
diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py
index edf50feb1..75d0976a1 100644
--- a/src/memos/mem_os/core.py
+++ b/src/memos/mem_os/core.py
@@ -788,6 +788,9 @@ def process_textual_memory():
                     timestamp=datetime.utcnow(),
                     task_id=task_id,
                 )
+                logger.info(
+                    f"[DIAGNOSTIC] core.add: Submitting message to scheduler: {message_item.model_dump_json(indent=2)}"
+                )
                 self.mem_scheduler.memos_message_queue.submit_messages(
                     messages=[message_item]
                 )
diff --git a/src/memos/mem_os/utils/default_config.py b/src/memos/mem_os/utils/default_config.py
index 967654d84..bf9f847d0 100644
--- a/src/memos/mem_os/utils/default_config.py
+++ b/src/memos/mem_os/utils/default_config.py
@@ -110,7 +110,7 @@ def get_default_config(
             "act_mem_update_interval": kwargs.get("scheduler_act_mem_update_interval", 300),
             "context_window_size": kwargs.get("scheduler_context_window_size", 5),
             "thread_pool_max_workers": kwargs.get("scheduler_thread_pool_max_workers", 10),
-            "consume_interval_seconds": kwargs.get("scheduler_consume_interval_seconds", 3),
+            "consume_interval_seconds": kwargs.get("scheduler_consume_interval_seconds", 0.01),
             "enable_parallel_dispatch": kwargs.get("scheduler_enable_parallel_dispatch", True),
             "enable_activation_memory": True,
         },
diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py
index 6f4bf1b88..44967a999 100644
--- a/src/memos/mem_scheduler/base_scheduler.py
+++ b/src/memos/mem_scheduler/base_scheduler.py
@@ -137,7 +137,6 @@ def __init__(self, config: BaseSchedulerConfig):
         self.dispatcher = SchedulerDispatcher(
             config=self.config,
             memos_message_queue=self.memos_message_queue,
-            use_redis_queue=self.use_redis_queue,
             max_workers=self.thread_pool_max_workers,
             enable_parallel_dispatch=self.enable_parallel_dispatch,
             status_tracker=self.status_tracker,
@@ -232,8 +231,8 @@ def initialize_modules(
 
         # start queue monitor if enabled and a bot is set later
 
-    def debug_mode_on(self):
-        self.memos_message_queue.debug_mode_on()
+    def debug_mode_on(self, debug_stream_prefix="debug_mode"):
+        self.memos_message_queue.debug_mode_on(debug_stream_prefix=debug_stream_prefix)
 
     def _cleanup_on_init_failure(self):
         """Clean up resources if initialization fails."""
@@ -594,6 +593,11 @@ def _submit_web_logs(
         Args:
             messages: Single log message or list of log messages
         """
+        messages_list = [messages] if isinstance(messages, ScheduleLogForWebItem) else messages
+        for message in messages_list:
+            logger.info(
+                f"[DIAGNOSTIC] base_scheduler._submit_web_logs called. Message to publish: {message.model_dump_json(indent=2)}"
+            )
         if self.rabbitmq_config is None:
             return
diff --git a/src/memos/mem_scheduler/general_modules/scheduler_logger.py b/src/memos/mem_scheduler/general_modules/scheduler_logger.py
index 89cd9b7ba..62dd0ef69 100644
--- a/src/memos/mem_scheduler/general_modules/scheduler_logger.py
+++ b/src/memos/mem_scheduler/general_modules/scheduler_logger.py
@@ -113,9 +113,10 @@ def create_event_log(
         metadata: list[dict],
         memory_len: int,
         memcube_name: str | None = None,
+        log_content: str | None = None,
     ) -> ScheduleLogForWebItem:
         item = self.create_autofilled_log_item(
-            log_content="",
+            log_content=log_content or "",
             label=label,
             from_memory_type=from_memory_type,
             to_memory_type=to_memory_type,
diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py
index 3e3298b10..6a910e884 100644
--- a/src/memos/mem_scheduler/general_scheduler.py
+++ b/src/memos/mem_scheduler/general_scheduler.py
@@ -367,16 +367,19 @@ def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
                     if kb_log_content:
                         event = self.create_event_log(
                             label="knowledgeBaseUpdate",
+                            # 1. Drop the log_content argument here
+                            # 2. Fill in the memory types
                             from_memory_type=USER_INPUT_TYPE,
                             to_memory_type=LONG_TERM_MEMORY_TYPE,
                             user_id=msg.user_id,
                             mem_cube_id=msg.mem_cube_id,
                             mem_cube=self.current_mem_cube,
                             memcube_log_content=kb_log_content,
-                            metadata=None,  # Per design doc for KB logs
+                            metadata=None,
                             memory_len=len(kb_log_content),
                             memcube_name=self._map_memcube_name(msg.mem_cube_id),
                         )
+                        # 3. Assign log_content after the event is created
                         event.log_content = (
                             f"Knowledge Base Memory Update: {len(kb_log_content)} changes."
                         )
@@ -474,6 +477,9 @@ def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
             logger.error(f"Error: {e}", exc_info=True)
 
     def _mem_read_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
+        logger.info(
+            f"[DIAGNOSTIC] general_scheduler._mem_read_message_consumer called. Received messages: {[msg.model_dump_json(indent=2) for msg in messages]}"
+        )
         logger.info(f"Messages {messages} assigned to {MEM_READ_LABEL} handler.")
 
         def process_message(message: ScheduleMessageItem):
@@ -538,6 +544,9 @@ def _process_memories_with_reader(
         task_id: str | None = None,
         info: dict | None = None,
     ) -> None:
+        logger.info(
+            f"[DIAGNOSTIC] general_scheduler._process_memories_with_reader called. mem_ids: {mem_ids}, user_id: {user_id}, mem_cube_id: {mem_cube_id}, task_id: {task_id}"
+        )
         """
         Process memories using mem_reader for enhanced memory processing.
@@ -635,6 +644,9 @@ def _process_memories_with_reader(
                     }
                 )
             if kb_log_content:
+                logger.info(
+                    f"[DIAGNOSTIC] general_scheduler._process_memories_with_reader: Creating event log for KB update. Label: knowledgeBaseUpdate, user_id: {user_id}, mem_cube_id: {mem_cube_id}, task_id: {task_id}. KB content: {json.dumps(kb_log_content, indent=2)}"
+                )
                 event = self.create_event_log(
                     label="knowledgeBaseUpdate",
                     from_memory_type=USER_INPUT_TYPE,
diff --git a/src/memos/mem_scheduler/memory_manage_modules/retriever.py b/src/memos/mem_scheduler/memory_manage_modules/retriever.py
index 6cf3a9e58..2278abc2a 100644
--- a/src/memos/mem_scheduler/memory_manage_modules/retriever.py
+++ b/src/memos/mem_scheduler/memory_manage_modules/retriever.py
@@ -209,10 +209,9 @@ def _split_batches(
     def recall_for_missing_memories(
         self,
         query: str,
-        memories: list[TextualMemoryItem],
+        memories: list[str],
     ) -> tuple[str, bool]:
-        text_memories = [one.memory for one in memories] if memories else []
-        text_memories = "\n".join([f"- {mem}" for i, mem in enumerate(text_memories)])
+        text_memories = "\n".join([f"- {mem}" for i, mem in enumerate(memories)])
 
         prompt = self.build_prompt(
             template_name="enlarge_recall",
diff --git a/src/memos/mem_scheduler/optimized_scheduler.py b/src/memos/mem_scheduler/optimized_scheduler.py
index 0e64ea9a0..6b6cf0e78 100644
--- a/src/memos/mem_scheduler/optimized_scheduler.py
+++ b/src/memos/mem_scheduler/optimized_scheduler.py
@@ -148,12 +148,13 @@ def mix_search_memories(
             "chat_history": search_req.chat_history,
         }
 
-        fast_retrieved_memories = self.searcher.retrieve(
+        raw_retrieved_memories = self.searcher.retrieve(
             query=search_req.query,
             user_name=user_context.mem_cube_id,
             top_k=search_req.top_k,
-            mode=SearchMode.FAST,
+            mode=SearchMode.FINE,
             manual_close_internet=not search_req.internet_search,
+            moscube=search_req.moscube,
             search_filter=search_filter,
             info=info,
         )
@@ -166,12 +167,21 @@ def mix_search_memories(
         )
         logger.info(f"Found {len(history_memories)} history memories.")
         if not history_memories:
-            memories = self.searcher.post_retrieve(
-                retrieved_results=fast_retrieved_memories,
+            # Post retrieve
+            raw_memories = self.searcher.post_retrieve(
+                retrieved_results=raw_retrieved_memories,
                 top_k=search_req.top_k,
                 user_name=user_context.mem_cube_id,
                 info=info,
             )
+
+            # Enhance with query
+            enhanced_memories, _ = self.retriever.enhance_memories_with_query(
+                query_history=[search_req.query],
+                memories=raw_memories,
+            )
+            formatted_memories = [format_textual_memory_item(item) for item in enhanced_memories]
+            return formatted_memories
         else:
             # if history memories can directly answer
             sorted_history_memories = self.reranker.rerank(
@@ -181,83 +191,26 @@ def mix_search_memories(
                 search_filter=search_filter,
             )
             logger.info(f"Reranked {len(sorted_history_memories)} history memories.")
-            processed_hist_mem = self.searcher.post_retrieve(
-                retrieved_results=sorted_history_memories,
+            merged_memories = self.searcher.post_retrieve(
+                retrieved_results=raw_retrieved_memories + sorted_history_memories,
                 top_k=search_req.top_k,
                 user_name=user_context.mem_cube_id,
                 info=info,
             )
-
-            can_answer = self.retriever.evaluate_memory_answer_ability(
-                query=search_req.query, memory_texts=[one.memory for one in processed_hist_mem]
+            memories = merged_memories[: search_req.top_k]
+
+            formatted_memories = [format_textual_memory_item(item) for item in memories]
+            logger.info("Submitted memory history async task.")
+            self.submit_memory_history_async_task(
+                search_req=search_req,
+                user_context=user_context,
+                memories_to_store={
+                    "memories": [one.to_dict() for one in memories],
+                    "formatted_memories": formatted_memories,
+                },
             )
-            if can_answer:
-                logger.info("History memories can answer the query.")
-                sorted_results = fast_retrieved_memories + sorted_history_memories
-                combined_results = self.searcher.post_retrieve(
-                    retrieved_results=sorted_results,
-                    top_k=search_req.top_k,
-                    user_name=user_context.mem_cube_id,
-                    info=info,
-                )
-                memories = combined_results[: search_req.top_k]
-            else:
-                logger.info("History memories cannot answer the query, enhancing memories.")
-                sorted_results = fast_retrieved_memories + sorted_history_memories
-                combined_results = self.searcher.post_retrieve(
-                    retrieved_results=sorted_results,
-                    top_k=search_req.top_k,
-                    user_name=user_context.mem_cube_id,
-                    info=info,
-                )
-                enhanced_memories, _ = self.retriever.enhance_memories_with_query(
-                    query_history=[search_req.query],
-                    memories=combined_results,
-                )
-
-                if len(enhanced_memories) < search_req.top_k:
-                    logger.info(
-                        f"Enhanced memories ({len(enhanced_memories)}) are less than top_k ({search_req.top_k}). Recalling for more."
-                    )
-                    missing_info_hint, trigger = self.retriever.recall_for_missing_memories(
-                        query=search_req.query,
-                        memories=combined_results,
-                    )
-                    retrieval_size = search_req.top_k - len(enhanced_memories)
-                    if trigger:
-                        logger.info(f"Triggering additional search with hint: {missing_info_hint}")
-                        additional_memories = self.searcher.search(
-                            query=missing_info_hint,
-                            user_name=user_context.mem_cube_id,
-                            top_k=retrieval_size,
-                            mode=SearchMode.FAST,
-                            memory_type="All",
-                            search_filter=search_filter,
-                            info=info,
-                        )
-                    else:
-                        logger.info("Not triggering additional search, using combined results.")
-                        additional_memories = combined_results[:retrieval_size]
-                    logger.info(
-                        f"Added {len(additional_memories)} more memories. Total enhanced memories: {len(enhanced_memories)}"
-                    )
-                    enhanced_memories += additional_memories
-
-                memories = enhanced_memories[: search_req.top_k]
-
-        formatted_memories = [format_textual_memory_item(item) for item in memories]
-        logger.info("Submitted memory history async task.")
-        self.submit_memory_history_async_task(
-            search_req=search_req,
-            user_context=user_context,
-            memories_to_store={
-                "memories": [one.to_dict() for one in memories],
-                "formatted_memories": formatted_memories,
-            },
-        )
-
-        return formatted_memories
+            return formatted_memories
 
     def update_search_memories_to_redis(
diff --git a/src/memos/mem_scheduler/schemas/general_schemas.py b/src/memos/mem_scheduler/schemas/general_schemas.py
index 91d442720..3e82eeb2a 100644
--- a/src/memos/mem_scheduler/schemas/general_schemas.py
+++ b/src/memos/mem_scheduler/schemas/general_schemas.py
@@ -24,7 +24,7 @@
 DEFAULT_ACT_MEM_DUMP_PATH = f"{BASE_DIR}/outputs/mem_scheduler/mem_cube_scheduler_test.kv_cache"
 DEFAULT_THREAD_POOL_MAX_WORKERS = 50
 DEFAULT_CONSUME_INTERVAL_SECONDS = 0.01
-DEFAULT_CONSUME_BATCH = 1
+DEFAULT_CONSUME_BATCH = 3
 DEFAULT_DISPATCHER_MONITOR_CHECK_INTERVAL = 300
 DEFAULT_DISPATCHER_MONITOR_MAX_FAILURES = 2
 DEFAULT_STUCK_THREAD_TOLERANCE = 10
diff --git a/src/memos/mem_scheduler/schemas/message_schemas.py b/src/memos/mem_scheduler/schemas/message_schemas.py
index 87738671c..65f81d3b6 100644
--- a/src/memos/mem_scheduler/schemas/message_schemas.py
+++ b/src/memos/mem_scheduler/schemas/message_schemas.py
@@ -84,6 +84,7 @@ def to_dict(self) -> dict:
             "content": self.content,
             "timestamp": self.timestamp.isoformat(),
             "user_name": self.user_name,
+            "task_id": self.task_id if self.task_id is not None else "",
         }
 
     @classmethod
@@ -97,6 +98,7 @@ def from_dict(cls, data: dict) -> "ScheduleMessageItem":
             content=data["content"],
             timestamp=datetime.fromisoformat(data["timestamp"]),
             user_name=data.get("user_name"),
+            task_id=data.get("task_id"),
         )
diff --git a/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py
index c361a77a2..abbc4671b 100644
--- a/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py
+++ b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py
@@ -16,6 +16,9 @@
 )
 from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
 from memos.mem_scheduler.schemas.task_schemas import RunningTaskItem
+from memos.mem_scheduler.task_schedule_modules.local_queue import SchedulerLocalQueue
+from memos.mem_scheduler.task_schedule_modules.redis_queue import SchedulerRedisQueue
+from memos.mem_scheduler.task_schedule_modules.task_queue import ScheduleTaskQueue
 from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube
 from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker
 
@@ -39,8 +42,7 @@ class SchedulerDispatcher(BaseSchedulerModule):
     def __init__(
         self,
         max_workers: int = 30,
-        memos_message_queue: Any | None = None,
-        use_redis_queue: bool | None = None,
+        memos_message_queue: ScheduleTaskQueue | None = None,
         enable_parallel_dispatch: bool = True,
         config=None,
         status_tracker: TaskStatusTracker | None = None,
@@ -53,8 +55,12 @@ def __init__(
 
         # Main dispatcher thread pool
         self.max_workers = max_workers
-        self.memos_message_queue = memos_message_queue
-        self.use_redis_queue = use_redis_queue
+        # Accept either a ScheduleTaskQueue wrapper or a concrete queue instance
+        self.memos_message_queue = (
+            memos_message_queue.memos_message_queue
+            if hasattr(memos_message_queue, "memos_message_queue")
+            else memos_message_queue
+        )
 
         # Get multi-task timeout from config
         self.multi_task_running_timeout = (
@@ -159,24 +165,29 @@ def wrapped_handler(messages: list[ScheduleMessageItem]):
                         self.metrics.task_completed(user_id=m.user_id, task_type=m.label)
 
                 # acknowledge redis messages
-                if self.use_redis_queue and self.memos_message_queue is not None:
+                if (
+                    isinstance(self.memos_message_queue, SchedulerRedisQueue)
+                    and self.memos_message_queue is not None
+                ):
                     for msg in messages:
                         redis_message_id = msg.redis_message_id
                         # Acknowledge message processing
                         self.memos_message_queue.ack_message(
                             user_id=msg.user_id,
                             mem_cube_id=msg.mem_cube_id,
+                            task_label=msg.label,
                             redis_message_id=redis_message_id,
                         )
 
                 # Mark task as completed and remove from tracking
-                with self._task_lock:
-                    if task_item.item_id in self._running_tasks:
-                        task_item.mark_completed(result)
-                        del self._running_tasks[task_item.item_id]
-                        self._completed_tasks.append(task_item)
-                        if len(self._completed_tasks) > self.completed_tasks_max_show_size:
-                            self._completed_tasks.pop(0)
+                if isinstance(self.memos_message_queue, SchedulerLocalQueue):
+                    with self._task_lock:
+                        if task_item.item_id in self._running_tasks:
+                            task_item.mark_completed(result)
+                            del self._running_tasks[task_item.item_id]
+                            self._completed_tasks.append(task_item)
+                            if len(self._completed_tasks) > self.completed_tasks_max_show_size:
+                                self._completed_tasks.pop(0)
 
                 logger.info(f"Task completed: {task_item.get_execution_info()}")
                 return result
@@ -188,12 +199,13 @@ def wrapped_handler(messages: list[ScheduleMessageItem]):
                     task_id=task_item.item_id, user_id=task_item.user_id, error_message=str(e)
                 )
                 # Mark task as failed and remove from tracking
-                with self._task_lock:
-                    if task_item.item_id in self._running_tasks:
-                        task_item.mark_failed(str(e))
-                        del self._running_tasks[task_item.item_id]
-                        if len(self._completed_tasks) > self.completed_tasks_max_show_size:
-                            self._completed_tasks.pop(0)
+                if isinstance(self.memos_message_queue, SchedulerLocalQueue):
+                    with self._task_lock:
+                        if task_item.item_id in self._running_tasks:
+                            task_item.mark_failed(str(e))
+                            del self._running_tasks[task_item.item_id]
+                            if len(self._completed_tasks) > self.completed_tasks_max_show_size:
+                                self._completed_tasks.pop(0)
 
                 logger.error(f"Task failed: {task_item.get_execution_info()}, Error: {e}")
                 raise
@@ -383,10 +395,6 @@ def dispatch(self, msg_list: list[ScheduleMessageItem]):
                 messages=msgs,
             )
 
-            # Add to running tasks
-            with self._task_lock:
-                self._running_tasks[task_item.item_id] = task_item
-
             # Create wrapped handler for task tracking
             wrapped_handler = self._create_task_wrapper(handler, task_item)
diff --git a/src/memos/mem_scheduler/task_schedule_modules/orchestrator.py b/src/memos/mem_scheduler/task_schedule_modules/orchestrator.py
new file mode 100644
index 000000000..d03648bba
--- /dev/null
+++ b/src/memos/mem_scheduler/task_schedule_modules/orchestrator.py
@@ -0,0 +1,47 @@
+"""
+Scheduler Orchestrator for Redis-backed task queues.
+
+This module provides an orchestrator class that works with `SchedulerRedisQueue` to:
+- Broker tasks from Redis streams according to per-user priority weights.
+- Maintain a cache of fetched messages and assemble balanced batches across
+  `(user_id, mem_cube_id, task_label)` groups.
+
+Stream format:
+- Keys follow: `{prefix}:{user_id}:{mem_cube_id}:{task_label}`
+
+Default behavior:
+- All users have priority 1, so fetch sizes are equal per user.
+"""
+
+from __future__ import annotations
+
+from memos.log import get_logger
+
+
+logger = get_logger(__name__)
+
+
+class SchedulerOrchestrator:
+    def __init__(self, queue):
+        """
+        Args:
+            queue: An instance of `SchedulerRedisQueue`.
+        """
+        self.queue = queue
+        # Cache of fetched messages grouped by (user_id, mem_cube_id, task_label)
+        self._cache = None
+
+    def get_stream_priorities(self) -> None | dict:
+        return None
+
+    def get_stream_quotas(self, stream_keys, consume_batch_size) -> dict:
+        stream_priorities = self.get_stream_priorities()
+        stream_quotas = {}
+        for stream_key in stream_keys:
+            if stream_priorities is None:
+                # No priorities configured: every stream gets the same quota
+                stream_quotas[stream_key] = consume_batch_size
+            else:
+                # TODO: weight quotas by per-stream priority (not implemented yet)
+                stream_quotas[stream_key] = consume_batch_size
+        return stream_quotas
diff --git a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py
index dc2b9af26..1ab5162b5 100644
--- a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py
+++ b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py
@@ -9,11 +9,13 @@
 import re
 import time
 
+from collections import deque
 from collections.abc import Callable
 from uuid import uuid4
 
 from memos.log import get_logger
 from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+from memos.mem_scheduler.task_schedule_modules.orchestrator import SchedulerOrchestrator
 from memos.mem_scheduler.webservice_modules.redis_service import RedisSchedulerModule
 
 
@@ -35,7 +37,8 @@ class SchedulerRedisQueue(RedisSchedulerModule):
     def __init__(
         self,
         stream_key_prefix: str = os.getenv(
-            "MEMSCHEDULER_REDIS_STREAM_KEY_PREFIX", "scheduler:messages:stream"
+            "MEMSCHEDULER_REDIS_STREAM_KEY_PREFIX",
+            "scheduler:messages:stream:v2",
         ),
         consumer_group: str = "scheduler_group",
         consumer_name: str | None = "scheduler_consumer",
@@ -78,20 +81,62 @@ def __init__(
         # Task tracking for mem_scheduler_wait compatibility
         self._unfinished_tasks = 0
 
+        logger.info(
f"[REDIS_QUEUE] Initialized with stream_prefix='{self.stream_key_prefix}', " + f"consumer_group='{self.consumer_group}', consumer_name='{self.consumer_name}'" + ) + # Auto-initialize Redis connection if self.auto_initialize_redis(): self._is_connected = True self.seen_streams = set() - # Task Broker - # Task Orchestrator + self.message_pack_cache = deque() + self.orchestrator = SchedulerOrchestrator(queue=self) - def get_stream_key(self, user_id: str, mem_cube_id: str) -> str: - stream_key = f"{self.stream_key_prefix}:{user_id}:{mem_cube_id}" + def get_stream_key(self, user_id: str, mem_cube_id: str, task_label: str) -> str: + stream_key = f"{self.stream_key_prefix}:{user_id}:{mem_cube_id}:{task_label}" return stream_key + def task_broker( + self, + consume_batch_size: int, + ) -> list[list[ScheduleMessageItem]]: + stream_keys = self.get_stream_keys(stream_key_prefix=self.stream_key_prefix) + if not stream_keys: + return [] + + stream_quotas = self.orchestrator.get_stream_quotas( + stream_keys=stream_keys, consume_batch_size=consume_batch_size + ) + cache: list[ScheduleMessageItem] = [] + for stream_key in stream_keys: + messages = self.get( + stream_key=stream_key, + block=False, + batch_size=stream_quotas[stream_key], + ) + cache.extend(messages) + + # pack messages + packed: list[list[ScheduleMessageItem]] = [] + for i in range(0, len(cache), consume_batch_size): + packed.append(cache[i : i + consume_batch_size]) + # reset cache using deque for efficient consumption + self.message_pack_cache = deque(packed) + # return list for compatibility with type hint + return list(self.message_pack_cache) + + def get_messages(self, batch_size: int) -> list[ScheduleMessageItem]: + if not self.message_pack_cache: + self.task_broker(consume_batch_size=batch_size) + if self.message_pack_cache: + return self.message_pack_cache.popleft() + # No messages available + return [] + def _ensure_consumer_group(self, stream_key) -> None: """Ensure the consumer group exists for the stream.""" if not self._redis_conn: @@ -135,7 +180,7 @@ def put( try: stream_key = self.get_stream_key( - user_id=message.user_id, mem_cube_id=message.mem_cube_id + user_id=message.user_id, mem_cube_id=message.mem_cube_id, task_label=message.label ) if stream_key not in self.seen_streams: @@ -158,8 +203,12 @@ def put( logger.error(f"Failed to add message to Redis queue: {e}") raise - def ack_message(self, user_id, mem_cube_id, redis_message_id) -> None: - stream_key = self.get_stream_key(user_id=user_id, mem_cube_id=mem_cube_id) + def ack_message( + self, user_id: str, mem_cube_id: str, task_label: str, redis_message_id + ) -> None: + stream_key = self.get_stream_key( + user_id=user_id, mem_cube_id=mem_cube_id, task_label=task_label + ) self.redis.xack(stream_key, self.consumer_group, redis_message_id) @@ -195,7 +244,7 @@ def get( self.consumer_group, self.consumer_name, {stream_key: ">"}, - count=batch_size if not batch_size else 1, + count=batch_size if batch_size is not None else None, block=redis_timeout, ) except Exception as read_err: @@ -210,7 +259,7 @@ def get( self.consumer_group, self.consumer_name, {stream_key: ">"}, - count=batch_size if not batch_size else 1, + count=batch_size if batch_size is not None else None, block=redis_timeout, ) else: @@ -358,18 +407,22 @@ def join(self) -> None: which is complex. For now, this is a no-op. 
""" - def clear(self) -> None: + def clear(self, stream_key=None) -> None: """Clear all messages from the queue.""" if not self._is_connected or not self._redis_conn: return try: - stream_keys = self.get_stream_keys() - - for stream_key in stream_keys: - # Delete the entire stream + if stream_key is not None: self._redis_conn.delete(stream_key) logger.info(f"Cleared Redis stream: {stream_key}") + else: + stream_keys = self.get_stream_keys() + + for stream_key in stream_keys: + # Delete the entire stream + self._redis_conn.delete(stream_key) + logger.info(f"Cleared Redis stream: {stream_key}") except Exception as e: logger.error(f"Failed to clear Redis queue: {e}") diff --git a/src/memos/mem_scheduler/task_schedule_modules/task_queue.py b/src/memos/mem_scheduler/task_schedule_modules/task_queue.py index 6d824f4b1..b7559eaf4 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/task_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/task_queue.py @@ -35,8 +35,9 @@ def __init__( def ack_message( self, - user_id, - mem_cube_id, + user_id: str, + mem_cube_id: str, + task_label: str, redis_message_id, ) -> None: if not isinstance(self.memos_message_queue, SchedulerRedisQueue): @@ -46,12 +47,13 @@ def ack_message( self.memos_message_queue.ack_message( user_id=user_id, mem_cube_id=mem_cube_id, + task_label=task_label, redis_message_id=redis_message_id, ) - def debug_mode_on(self): + def debug_mode_on(self, debug_stream_prefix="debug_mode"): self.memos_message_queue.stream_key_prefix = ( - f"debug_mode:{self.memos_message_queue.stream_key_prefix}" + f"{debug_stream_prefix}:{self.memos_message_queue.stream_key_prefix}" ) def get_stream_keys(self) -> list[str]: @@ -97,6 +99,8 @@ def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageIt ) def get_messages(self, batch_size: int) -> list[ScheduleMessageItem]: + if isinstance(self.memos_message_queue, SchedulerRedisQueue): + return self.memos_message_queue.get_messages(batch_size=batch_size) stream_keys = self.get_stream_keys() if len(stream_keys) == 0: diff --git a/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py b/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py index 2762ddaca..1cc97961d 100644 --- a/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py +++ b/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py @@ -1,4 +1,5 @@ import json +import os import ssl import threading import time @@ -270,15 +271,36 @@ def rabbitmq_publish_message(self, message: dict): """ import pika + exchange_name = self.rabbitmq_exchange_name + routing_key = self.rabbit_queue_name + + if message.get("label") == "knowledgeBaseUpdate": + kb_specific_exchange_name = os.getenv("MEMSCHEDULER_RABBITMQ_EXCHANGE_NAME") + + if kb_specific_exchange_name: + exchange_name = kb_specific_exchange_name + + routing_key = "" # User specified empty routing key for KB updates + + logger.info( + f"[DIAGNOSTIC] Publishing KB Update message. " + f"ENV_EXCHANGE_NAME_USED: {kb_specific_exchange_name is not None}. " + f"Current configured Exchange: {exchange_name}, Routing Key: '{routing_key}'." + ) + logger.info(f" - Message Content: {json.dumps(message, indent=2)}") + with self._rabbitmq_lock: if not self.is_rabbitmq_connected(): logger.error("Cannot publish - no active connection") return False + logger.info( + f"[DIAGNOSTIC] rabbitmq_service.rabbitmq_publish_message: Attempting to publish message. 
Exchange: {exchange_name}, Routing Key: {routing_key}, Message Content: {json.dumps(message, indent=2)}" + ) try: self.rabbitmq_channel.basic_publish( - exchange=self.rabbitmq_exchange_name, - routing_key=self.rabbit_queue_name, + exchange=exchange_name, + routing_key=routing_key, body=json.dumps(message), properties=pika.BasicProperties( delivery_mode=2, # Persistent diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py index 22cd44b8c..8d64b77cd 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py @@ -47,7 +47,7 @@ def __init__( self.stage_retrieve_top = 3 self.process_llm = process_llm - self.thinking_stages = 0 # TODO: to increase thinking depth when the algorithm is reliable + self.thinking_stages = 3 self.max_retry_times = 2 self.deep_search_top_k_bar = 2 @@ -69,8 +69,7 @@ def stage_retrieve( query: str, previous_retrieval_phrases: list[str], text_memories: str, - context: str | None = None, - ) -> tuple[bool, str, str, list[str]]: + ) -> tuple[bool, str, list[str]]: """Run a retrieval-expansion stage and parse structured LLM output. Returns a tuple of: @@ -91,8 +90,6 @@ def stage_retrieve( "previous_retrieval_phrases": prev_phrases_text, "memories": text_memories, } - if context is not None: - args["context"] = context prompt = self.build_prompt(**args) max_attempts = max(0, self.max_retry_times) + 1 @@ -109,8 +106,6 @@ def stage_retrieve( reason = result.get("reason", "") - context_out = str(result.get("context", "")) - phrases_val = result.get("retrieval_phrases", result.get("retrival_phrases", [])) if isinstance(phrases_val, list): retrieval_phrases = [str(p).strip() for p in phrases_val if str(p).strip()] @@ -119,7 +114,7 @@ def stage_retrieve( else: retrieval_phrases = [] - return can_answer, reason, context_out, retrieval_phrases + return can_answer, reason, retrieval_phrases except Exception as e: if attempt < max_attempts: @@ -132,39 +127,6 @@ def stage_retrieve( ) raise e - def summarize_memories(self, query: str, context: str, text_memories: str, top_k: int): - args = { - "template_name": "memory_summary", - "query": query, - "context": context, - "memories": text_memories, - "top_k": top_k, - } - - prompt = self.build_prompt(**args) - - max_attempts = max(0, self.max_retry_times) + 1 - for attempt in range(1, max_attempts + 1): - try: - llm_response = self.process_llm.generate([{"role": "user", "content": prompt}]) - result = parse_structured_output(content=llm_response) - context, mem_list = result["context"], result["memories"] - if not isinstance(mem_list, list): - logger.error(f"The result of summarize_memories is {result}") - return context, mem_list - except Exception as e: - if attempt < max_attempts: - logger.debug( - f"[summarize_memories]🔁 retry {attempt}/{max_attempts} failed: {e!s}" - ) - time.sleep(1) - else: - logger.error( - f"[summarize_memories]❌ all {max_attempts} attempts failed: {e!s}; \nprompt: {prompt}", - exc_info=True, - ) - raise e - def judge_memories(self, query: str, text_memories: str): args = { "template_name": "memory_judgement", @@ -223,22 +185,32 @@ def get_final_memories(self, user_id: str, top_k: int, mem_list: list[str]): result_memories = enhanced_memories[:top_k] return result_memories - def recreate_enhancement( + def memory_recreate_enhancement( self, query: str, + top_k: int, text_memories: list[str], retries: 
int, ) -> list: attempt = 0 text_memories = "\n".join([f"- [{i}] {mem}" for i, mem in enumerate(text_memories)]) prompt_name = "memory_recreate_enhancement" - prompt = self.build_prompt(template_name=prompt_name, query=query, memories=text_memories) + prompt = self.build_prompt( + template_name=prompt_name, query=query, top_k=top_k, memories=text_memories + ) llm_response = None while attempt <= max(0, retries) + 1: try: llm_response = self.process_llm.generate([{"role": "user", "content": prompt}]) processed_text_memories = parse_structured_output(content=llm_response) + logger.debug( + f"[memory_recreate_enhancement]\n " + f"- original memories: \n" + f"{text_memories}\n" + f"- final memories: \n" + f"{processed_text_memories['answer']}" + ) return processed_text_memories["answer"] except Exception as e: attempt += 1 @@ -278,16 +250,15 @@ def deep_search( user_name=user_name, info=info, ) - if top_k < self.deep_search_top_k_bar or len(memories) == 0: + if len(memories) == 0: logger.warning("Requirements not met; returning memories as-is.") return memories user_id = memories[0].metadata.user_id - context = None mem_list, _ = self.tree_memories_to_text_memories(memories=memories) retrieved_memories = copy.deepcopy(retrieved_memories) - retrieved_memories_from_deep_search = [] + rewritten_flag = False for current_stage_id in range(self.thinking_stages + 1): try: # at last @@ -303,179 +274,31 @@ def deep_search( f"previous retrieval phrases have been tried: {previous_retrieval_phrases}; " f"final can_answer: {can_answer}; reason: {reason}" ) - mem_list = self.recreate_enhancement( - query=query, text_memories=mem_list, retries=self.max_retry_times - ) - enhanced_memories = self.get_final_memories( - user_id=user_id, top_k=top_k, mem_list=mem_list - ) - return enhanced_memories - - can_answer, reason, context, retrieval_phrases = self.stage_retrieve( - stage_id=current_stage_id + 1, - query=query, - previous_retrieval_phrases=previous_retrieval_phrases, - context=context, - text_memories="- " + "\n- ".join(mem_list) + "\n", - ) - if can_answer: - logger.info( - f"Stage {current_stage_id}: determined answer can be provided, creating enhanced memories; reason: {reason}", - ) - - enhanced_memories = self.get_final_memories( - user_id=user_id, top_k=top_k, mem_list=mem_list - ) - return enhanced_memories - else: - previous_retrieval_phrases.extend(retrieval_phrases) - logger.info( - f"Start complementary retrieval for Stage {current_stage_id}; " - f"previous retrieval phrases have been tried: {previous_retrieval_phrases}; " - f"can_answer: {can_answer}; reason: {reason}" - ) - logger.info( - "Stage %d - Found %d new retrieval phrases", - current_stage_id, - len(retrieval_phrases), - ) - # Search for additional memories based on retrieval phrases - additional_retrieved_memories = [] - for phrase in retrieval_phrases: - _retrieved_memories = self.retrieve( - query=phrase, - user_name=user_name, - top_k=self.stage_retrieve_top, - mode=SearchMode.FAST, - memory_type=memory_type, - search_filter=search_filter, - info=info, - ) - logger.info( - "Found %d additional memories for phrase: '%s'", - len(_retrieved_memories), - phrase[:30] + "..." 
if len(phrase) > 30 else phrase, - ) - additional_retrieved_memories.extend(_retrieved_memories) - retrieved_memories_from_deep_search.extend(additional_retrieved_memories) - merged_memories = self.post_retrieve( - retrieved_results=retrieved_memories + additional_retrieved_memories, - top_k=top_k * 2, - user_name=user_name, - info=info, - ) - - _mem_list, _ = self.tree_memories_to_text_memories(memories=merged_memories) - mem_list = _mem_list - mem_list = list(set(mem_list)) - logger.info( - "After stage %d, total memories in list: %d", - current_stage_id, - len(mem_list), - ) - - # enhance memories - mem_list = self.recreate_enhancement( - query=query, text_memories=mem_list, retries=self.max_retry_times - ) - logger.info("After summarization, memory list contains %d items", len(mem_list)) - - except Exception as e: - logger.error("Error in stage %d: %s", current_stage_id, str(e), exc_info=True) - # Continue to next stage instead of failing completely - continue - logger.error("Deep search failed, returning original memories") - return memories - - def deep_search_backup( - self, - query: str, - top_k: int, - info=None, - memory_type="All", - search_filter: dict | None = None, - user_name: str | None = None, - **kwargs, - ): - previous_retrieval_phrases = [query] - retrieved_memories = self.retrieve( - query=query, - user_name=user_name, - top_k=top_k, - mode=SearchMode.FAST, - memory_type=memory_type, - search_filter=search_filter, - info=info, - ) - memories = self.post_retrieve( - retrieved_results=retrieved_memories, - top_k=top_k, - user_name=user_name, - info=info, - ) - if top_k < self.deep_search_top_k_bar or len(memories) == 0: - logger.warning("Requirements not met; returning memories as-is.") - return memories - - user_id = memories[0].metadata.user_id - context = None - - mem_list, _ = self.tree_memories_to_text_memories(memories=memories) - retrieved_memories = copy.deepcopy(retrieved_memories) - retrieved_memories_from_deep_search = [] - for current_stage_id in range(self.thinking_stages + 1): - try: - # at last - if current_stage_id == self.thinking_stages: - # eval to finish - reason, can_answer = self.judge_memories( - query=query, - text_memories="- " + "\n- ".join(mem_list) + "\n", - ) - - logger.info( - f"Final Stage: Stage {current_stage_id}; " - f"previous retrieval phrases have been tried: {previous_retrieval_phrases}; " - f"final can_answer: {can_answer}; reason: {reason}" - ) - if len(retrieved_memories_from_deep_search) == 0: - memories = self.post_retrieve( - retrieved_results=retrieved_memories, - top_k=top_k, - user_name=user_name, - info=info, - ) - return memories[:top_k] - else: + if rewritten_flag: enhanced_memories = self.get_final_memories( user_id=user_id, top_k=top_k, mem_list=mem_list ) - return enhanced_memories + else: + enhanced_memories = memories + return enhanced_memories[:top_k] - can_answer, reason, context, retrieval_phrases = self.stage_retrieve( + can_answer, reason, retrieval_phrases = self.stage_retrieve( stage_id=current_stage_id + 1, query=query, previous_retrieval_phrases=previous_retrieval_phrases, - context=context, text_memories="- " + "\n- ".join(mem_list) + "\n", ) if can_answer: logger.info( f"Stage {current_stage_id}: determined answer can be provided, creating enhanced memories; reason: {reason}", ) - if len(retrieved_memories_from_deep_search) == 0: - memories = self.post_retrieve( - retrieved_results=retrieved_memories, - top_k=top_k, - user_name=user_name, - info=info, - ) - return memories[:top_k] - else: + if 
rewritten_flag: enhanced_memories = self.get_final_memories( user_id=user_id, top_k=top_k, mem_list=mem_list ) - return enhanced_memories + else: + enhanced_memories = memories + return enhanced_memories[:top_k] else: previous_retrieval_phrases.extend(retrieval_phrases) logger.info( @@ -506,32 +329,28 @@ def deep_search_backup( phrase[:30] + "..." if len(phrase) > 30 else phrase, ) additional_retrieved_memories.extend(_retrieved_memories) - retrieved_memories_from_deep_search.extend(additional_retrieved_memories) merged_memories = self.post_retrieve( retrieved_results=retrieved_memories + additional_retrieved_memories, top_k=top_k * 2, user_name=user_name, info=info, ) - + rewritten_flag = True _mem_list, _ = self.tree_memories_to_text_memories(memories=merged_memories) mem_list = _mem_list mem_list = list(set(mem_list)) + mem_list = self.memory_recreate_enhancement( + query=query, + top_k=top_k, + text_memories=mem_list, + retries=self.max_retry_times, + ) logger.info( "After stage %d, total memories in list: %d", current_stage_id, len(mem_list), ) - # Summarize memories - context, mem_list = self.summarize_memories( - query=query, - context=context, - text_memories="- " + "\n- ".join(mem_list) + "\n", - top_k=top_k, - ) - logger.info("After summarization, memory list contains %d items", len(mem_list)) - except Exception as e: logger.error("Error in stage %d: %s", current_stage_id, str(e), exc_info=True) # Continue to next stage instead of failing completely diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index 9c5be2fae..099b9aa13 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -55,6 +55,10 @@ def add_memories(self, add_req: APIADDRequest) -> list[dict[str, Any]]: This is basically your current handle_add_memories logic, but scoped to a single cube_id. """ + sync_mode = add_req.async_mode or self._get_sync_mode() + self.logger.info( + f"[DIAGNOSTIC] single_cube.add_memories called for cube_id: {self.cube_id}. sync_mode: {sync_mode}. Request: {add_req.model_dump_json(indent=2)}" + ) user_context = UserContext( user_id=add_req.user_id, mem_cube_id=self.cube_id, @@ -132,6 +136,7 @@ def search_memories(self, search_req: APISearchRequest) -> dict[str, Any]: ) self.logger.info(f"Search memories result: {memories_result}") + self.logger.info(f"Search {len(memories_result)} memories.") return memories_result def _get_search_mode(self, mode: str) -> str: @@ -152,7 +157,7 @@ def _search_text( user_context: UserContext, search_mode: str, ) -> list[dict[str, Any]]: - """G + """ Search text memories based on mode. 
 
         Args:
@@ -277,7 +282,7 @@ def _fine_search(
             )
             missing_info_hint, trigger = self.mem_scheduler.retriever.recall_for_missing_memories(
                 query=search_req.query,
-                memories=raw_memories,
+                memories=[mem.memory for mem in enhanced_memories],
             )
             retrieval_size = len(raw_memories) - len(enhanced_memories)
             logger.info(f"Retrieval size: {retrieval_size}")
@@ -530,7 +535,7 @@ def _process_pref_mem(
 
         return [
             {
-                "memory": memory.memory,
+                "memory": memory.metadata.preference,
                 "memory_id": memory_id,
                 "memory_type": memory.metadata.preference_type,
             }
diff --git a/src/memos/templates/advanced_search_prompts.py b/src/memos/templates/advanced_search_prompts.py
index 13e80a79a..baf2f7536 100644
--- a/src/memos/templates/advanced_search_prompts.py
+++ b/src/memos/templates/advanced_search_prompts.py
@@ -1,54 +1,4 @@
-MEMORY_SUMMARY_PROMPT = """
-# Memory Summary and Context Assembly
-
-## Role
-You are a precise context assembler. Given a user query and a set of retrieved memories (each indexed), your task is to synthesize a factual, concise, and coherent context using only the information explicitly present in the memories.
-
-## Instructions
-
-### Core Principles
-- Use ONLY facts from the provided memories. Do not invent, infer, guess, or hallucinate.
-- Resolve all pronouns (e.g., "he", "it", "they") and vague terms (e.g., "this", "that", "some people") to explicit entities using memory content.
-- Merge overlapping or redundant facts. Preserve temporal, spatial, and relational details.
-- Each fact must be atomic, unambiguous, and verifiable.
-- Preserve all key details: who, what, when, where, why — if present in memory.
-- Created a summarized facts for answering query at the first item, and separate logically coherent separate memories.
-- Begin the <memories> with a single, aggregated summary that directly answers the query using the most relevant facts.
-- The total number of facts in <memories> must not exceed {top_k}.
-- If additional context is relevant, try to weave it together logically—or chronologically—based on how the pieces connect.
-- **Must preserve the full timeline of all memories**: if multiple events or states are mentioned with temporal markers (e.g., dates, sequences, phases), their chronological order must be retained in both <context> and <memories>.
-
-### Processing Logic
-- Aggregate logically connected memories (e.g., events involving the same person, cause-effect chains, repeated entities).
-- Exclude any memory that does not directly support answering the query.
-- Prioritize specificity: e.g., "Travis Tang moved to Singapore in 2021" > "He relocated abroad."
-
-## Input
-- Query: {query}
-- Current context:
-{context}
-- Current Memories:
-{memories}
-
-## Output Format (STRICT TAG-BASED)
-Respond ONLY with the following XML-style tags. Do NOT include any other text, explanations, or formatting.
-
-<context>
-A single, compact, fluent paragraph synthesizing the above facts into a coherent narrative directly relevant to the query. Use resolved entities and logical flow. No bullet points. No markdown. No commentary.
-</context>
-
-<memories>
-- Aggregated summary
-- Fact 1
-- Fact 2
-</memories>
-
-Answer:
-"""
-
-# Stage 1: determine answerability; if not answerable, produce concrete retrieval phrases for missing info
 STAGE1_EXPAND_RETRIEVE_PROMPT = """
-# Stage 1 — Answerability and Missing Retrieval Phrases
-
 ## Goal
 Determine whether the current memories can answer the query using concrete, specific facts. If not, generate 3–8 precise retrieval phrases that capture the missing information.
@@ -76,9 +26,6 @@
 true or false
 </can_answer>
 
-<context>
-summary of current memories
-</context>
 
 <reason>
 Brief, one-sentence explanation for why the query is or isn't answerable with current memories.
@@ -94,27 +41,24 @@
 
 # Stage 2: if Stage 1 phrases still fail, rewrite the retrieval query and phrases to maximize recall
 STAGE2_EXPAND_RETRIEVE_PROMPT = """
-# Stage 2 — Rewrite Retrieval Query and Phrases to Improve Recall
-
 ## Goal
-If Stage 1's retrieval phrases failed to yield an answer, rewrite the original query and expand the phrase list to maximize recall of relevant memories. Use canonicalization, synonym expansion, and constraint enrichment.
+Rewrite the original query and generate an improved list of retrieval phrases to maximize recall of relevant memories. Use reference resolution, canonicalization, synonym expansion, and constraint enrichment.
 
 ## Rewrite Strategy
-- Canonicalize entities: use full names, official titles, or known aliases.
-- Normalize time formats: e.g., "last year" → "2024", "in 2021" → "2021".
-- Add discriminative tokens: entity + attribute + time + location where applicable.
-- Split complex queries into focused sub-queries targeting distinct facets.
-- Never include pronouns, vague terms, or subjective language.
+- **Resolve ambiguous references**: Replace pronouns (e.g., “she”, “they”, “it”) and vague terms (e.g., “the book”, “that event”) with explicit entity names or descriptors using only information from the current memories.
+- **Canonicalize entities**: Use full names (e.g., “Melanie Smith”), known roles (e.g., “Caroline’s mentor”), or unambiguous identifiers when available.
+- **Normalize temporal expressions**: Convert relative time references (e.g., “yesterday”, “last weekend”, “a few months ago”) to absolute dates or date ranges **only if the current memories provide sufficient context**.
+- **Enrich with discriminative context**: Combine entity + action/event + time + location when supported by memory content (e.g., “Melanie pottery class July 2023”).
+- **Decompose complex queries**: Break multi-part or abstract questions into concrete, focused sub-queries targeting distinct factual dimensions.
+- **Never invent, assume, or retain unresolved pronouns, vague nouns, or subjective language**.
 
 ## Input
 - Query: {query}
 - Previous retrieval phrases: {previous_retrieval_phrases}
-- Context: {context}
 - Current Memories:
 {memories}
-
 
 ## Output (STRICT TAG-BASED FORMAT)
 Respond ONLY with the following structure. Do not add any other text, explanation, or formatting.
@@ -122,13 +66,10 @@
 <can_answer>
 true or false
 </can_answer>
 
 <reason>
-Brief explanation (1–2 sentences) of how this rewrite improves recall over Stage 1 phrases.
+Brief explanation (1–2 sentences) of how this rewrite improves recall—e.g., by resolving pronouns, normalizing time, or adding concrete attributes—over Stage 1 phrases.
 </reason>
 
-<context>
-summary of current memories
-</context>
-
 <retrieval_phrases>
-- new phrase 1 (Rewritten version of the original query. More precise, canonical, and retrieval-optimized.)
+- new phrase 1 (Rewritten, canonical, fully grounded in memory content)
 - new phrase 2
 ...
 </retrieval_phrases>
@@ -139,22 +80,19 @@
 
 # Stage 3: generate grounded hypotheses to guide retrieval when still not answerable
 STAGE3_EXPAND_RETRIEVE_PROMPT = """
-# Stage 3 — Hypothesis Generation for Retrieval
-
 ## Goal
-When the query remains unanswerable, generate grounded, plausible hypotheses based ONLY on provided context and memories. Each hypothesis must imply a concrete retrieval target and validation criteria.
+As the query remains unanswerable, generate grounded, plausible hypotheses based ONLY on the provided memories. Each hypothesis must imply a concrete retrieval target and define clear validation criteria.
 
 ## Rules
-- Base hypotheses strictly on facts from the memories. No new entities or assumptions.
-- Frame each hypothesis as a testable statement: "If [X] is true, then the query is answered."
-- For each hypothesis, define 1–3 specific evidence requirements that would confirm it.
-- Do NOT guess. Do NOT invent. Only extrapolate from existing facts.
+- Base hypotheses strictly on facts from the memories. Do NOT introduce new entities, events, or assumptions.
+- Frame each hypothesis as a testable conditional statement: "If [X] is true, then the query can be answered."
+- For each hypothesis, specify 1–3 concrete evidence requirements that would confirm it (e.g., a specific date, name, or event description).
+- Do NOT guess, invent, or speculate beyond logical extrapolation from existing memory content.
 
 ## Input
 - Query: {query}
 - Previous retrieval phrases: {previous_retrieval_phrases}
-- Context: {context}
 - Memories:
 {memories}
 
@@ -164,24 +102,20 @@
 true or false
 </can_answer>
 
-<context>
-summary of current memories
-</context>
 
 <hypotheses>
-- statement:
-  retrieval_query:
   validation_criteria:
-  -
-  -
-- statement:
+  retrieval_query:
   validation_criteria:
-  -
+  -
-
 </hypotheses>
 
 <retrieval_phrases>
-- hypothesis retrieval query 1 (searchable query derived from the hypothesis)
-- hypothesis retrieval query 2:
+-
+-
 ...
 </retrieval_phrases>
@@ -229,33 +163,36 @@
 """
 
 MEMORY_RECREATE_ENHANCEMENT_PROMPT = """
-You are a knowledgeable and precise AI assistant.
+You are a precise and detail-oriented AI assistant specialized in temporal memory reconstruction, reference resolution, and relevance-aware memory fusion.
 
 # GOAL
-Transform raw memories into clean, query-relevant facts — preserving timestamps and resolving ambiguities without inference.
-
-# RULES & THINKING STEPS
-1. Keep ONLY what’s relevant to the user’s query. Delete irrelevant memories entirely.
-2. Preserve ALL explicit timestamps (e.g., “on October 6”, “daily”, “after injury”).
-3. Resolve all ambiguities using only memory content:
-   - Pronouns → full name: “she” → “Melanie”
-   - Vague nouns → specific detail: “home” → “her childhood home in Guangzhou”
-   - “the user” → identity from context (e.g., “Melanie” if travel/running memories)
-4. Never invent, assume, or extrapolate.
-5. Each output line must be a standalone, clear, factual statement.
-6. Output format: one line per fact, starting with "- ", no extra text.
+Transform the original memories into a clean, unambiguous, and consolidated set of factual statements that:
+1. **Resolve all vague or relative references** (e.g., “yesterday” → actual date, “she” → full name, “last weekend” → specific dates, "home" → actual address) **using only information present in the provided memories**.
+2. **Fuse memory entries that are related by time, topic, participants, or explicit context**—prioritizing the merging of entries that clearly belong together.
+3. **Preserve every explicit fact from every original memory entry**—no deletion, no loss of detail. Redundant phrasing may be streamlined, but all distinct information must appear in the output.
+4. **Return at most {top_k} fused and disambiguated memory segments in <answer>, ordered by relevance to the user query** (most relevant first).
@@ -229,33 +163,36 @@
 """

 MEMORY_RECREATE_ENHANCEMENT_PROMPT = """
-You are a knowledgeable and precise AI assistant.
+You are a precise and detail-oriented AI assistant specialized in temporal memory reconstruction, reference resolution, and relevance-aware memory fusion.

 # GOAL
-Transform raw memories into clean, query-relevant facts — preserving timestamps and resolving ambiguities without inference.
-
-# RULES & THINKING STEPS
-1. Keep ONLY what’s relevant to the user’s query. Delete irrelevant memories entirely.
-2. Preserve ALL explicit timestamps (e.g., “on October 6”, “daily”, “after injury”).
-3. Resolve all ambiguities using only memory content:
-   - Pronouns → full name: “she” → “Melanie”
-   - Vague nouns → specific detail: “home” → “her childhood home in Guangzhou”
-   - “the user” → identity from context (e.g., “Melanie” if travel/running memories)
-4. Never invent, assume, or extrapolate.
-5. Each output line must be a standalone, clear, factual statement.
-6. Output format: one line per fact, starting with "- ", no extra text.
+Transform the original memories into a clean, unambiguous, and consolidated set of factual statements that:
+1. **Resolve all vague or relative references** (e.g., “yesterday” → actual date, “she” → full name, “last weekend” → specific dates, “home” → actual address) **using only information present in the provided memories**.
+2. **Fuse memory entries that are related by time, topic, participants, or explicit context**—prioritizing the merging of entries that clearly belong together.
+3. **Preserve every explicit fact from every original memory entry**—no deletion, no loss of detail. Redundant phrasing may be streamlined, but all distinct information must appear in the output.
+4. **Return at most {top_k} fused and disambiguated memory segments in <enhanced_memories>, ordered by relevance to the user query** (most relevant first).
+
+# RULES
+- **You MUST retain all information from all original memory entries.** Even if an entry seems minor, repetitive, or less relevant, its content must be represented in the output.
+- **Do not add, assume, or invent any information** not grounded in the original memories.
+- **Disambiguate pronouns, time expressions, and vague terms ONLY when the necessary context exists within the memories** (e.g., if “yesterday” appears in a message dated July 3, resolve it to July 2).
+- **If you cannot resolve a vague reference (e.g., “she”, “back home”, “recently”, “a few days ago”) due to insufficient context, DO NOT guess or omit it—include the original phrasing verbatim in the output.**
+- **Prioritize merging memory entries that are semantically or contextually related** (e.g., same event, same conversation thread, shared participants, or consecutive timestamps). Grouping should reflect natural coherence, not just proximity.
+- **The total number of bullets in <enhanced_memories> must not exceed {top_k}.** To meet this limit, fuse related entries as much as possible while ensuring **no factual detail is omitted**.
+- **Never sacrifice factual completeness for brevity.** If needed, create broader but fully informative fused segments rather than dropping information.
+- **Each bullet in <enhanced_memories> must be a self-contained, fluent sentence or clause** that includes all resolved details from the original entries it represents. If part of an entry cannot be resolved, preserve that part exactly as written.
+- **Sort the final list by how directly and specifically each segment addresses the user’s query**—not by chronology or source.

 # OUTPUT FORMAT (STRICT)
-Return ONLY the following block, with **one enhanced memory per line**.
-Each line MUST start with "- " (dash + space).
+Return ONLY the following structure:

-Wrap the final output inside:
-<enhanced_memories>
-- enhanced memory 1
-- enhanced memory 2
-...
-</enhanced_memories>
+<enhanced_memories>
+- [Fully resolved, fused memory segment most relevant to the query — containing all facts from the original entries it covers; unresolved parts kept verbatim]
+- [Next most relevant resolved and fused segment — again, with no factual loss]
+- [...]
+</enhanced_memories>

 ## User Query
 {query}

@@ -265,9 +202,7 @@
 Final Output:
 """

-
 PROMPT_MAPPING = {
-    "memory_summary": MEMORY_SUMMARY_PROMPT,
     "memory_judgement": MEMORY_JUDGMENT_PROMPT,
     "stage1_expand_retrieve": STAGE1_EXPAND_RETRIEVE_PROMPT,
     "stage2_expand_retrieve": STAGE2_EXPAND_RETRIEVE_PROMPT,
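The PROMPT_MAPPING keys select a template by name. A minimal usage sketch, assuming the templates are consumed with plain str.format() (the {query}/{memories} placeholders and the escaped {{ }} braces elsewhere in these files point that way):

    def render_prompt(name: str, **kwargs) -> str:
        # name is one of the PROMPT_MAPPING keys listed above,
        # e.g. "stage2_expand_retrieve"
        return PROMPT_MAPPING[name].format(**kwargs)


    prompt = render_prompt(
        "stage2_expand_retrieve",
        query="Where did Melanie take a pottery class?",
        previous_retrieval_phrases="- Melanie pottery",
        memories="- Melanie mentioned taking a pottery class in July 2023.",
    )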
diff --git a/src/memos/templates/mem_scheduler_prompts.py b/src/memos/templates/mem_scheduler_prompts.py
index 7f7415e79..acbae2281 100644
--- a/src/memos/templates/mem_scheduler_prompts.py
+++ b/src/memos/templates/mem_scheduler_prompts.py
@@ -393,6 +393,79 @@
 
 MEMORY_RECREATE_ENHANCEMENT_PROMPT = """
 You are a knowledgeable and precise AI assistant.
+
+# GOAL
+Transform raw memories into clean, complete, and fully disambiguated statements that preserve original meaning and explicit details.
+
+# RULES & THINKING STEPS
+1. Preserve ALL explicit timestamps (e.g., “on October 6”, “daily”).
+2. Resolve all ambiguities using only memory content. If disambiguation cannot be performed using only the provided memories, retain the original phrasing exactly as written. Never guess, infer, or fabricate missing information:
+   - Pronouns → full name (e.g., “she” → “Caroline”)
+   - Relative time expressions → concrete dates or full context (e.g., “last night” → “on the evening of November 25, 2025”)
+   - Vague references → specific, grounded details (e.g., “the event” → “the LGBTQ+ art workshop in Malmö”)
+   - Incomplete descriptions → full version from memory (e.g., “the activity” → “the abstract painting session at the community center”)
+3. Merge memories that are largely repetitive in content but contain complementary or distinct details. Combine them into a single, cohesive statement that preserves all unique information from each original memory. Do not merge memories that describe different events, even if they share a theme.
+4. Keep ONLY what’s relevant to the user’s query. Delete irrelevant memories entirely.
+
+# OUTPUT FORMAT (STRICT)
+Return ONLY the following block, with **one enhanced memory per line**.
+Each line MUST start with "- " (dash + space).
+
+Wrap the final output inside:
+<enhanced_memories>
+- enhanced memory 1
+- enhanced memory 2
+...
+</enhanced_memories>
+
+## User Query
+{query_history}
+
+## Original Memories
+{memories}
+
+Final Output:
+"""
+
+MEMORY_RECREATE_ENHANCEMENT_PROMPT_BACKUP_1 = """
+You are a knowledgeable and precise AI assistant.
+
+# GOAL
+Transform raw memories into clean, complete, and fully disambiguated statements that preserve original meaning and explicit details.
+
+# RULES & THINKING STEPS
+1. Preserve ALL explicit timestamps (e.g., “on October 6”, “daily”).
+2. Resolve all ambiguities using only memory content. If disambiguation cannot be performed using only the provided memories, retain the original phrasing exactly as written. Never guess, infer, or fabricate missing information:
+   - Pronouns → full name (e.g., “she” → “Caroline”)
+   - Relative time expressions → concrete dates or full context (e.g., “last night” → “on the evening of November 25, 2025”)
+   - Vague references → specific, grounded details (e.g., “the event” → “the LGBTQ+ art workshop in Malmö”)
+   - Incomplete descriptions → full version from memory (e.g., “the activity” → “the abstract painting session at the community center”)
+3. Merge memories that are largely repetitive in content but contain complementary or distinct details. Combine them into a single, cohesive statement that preserves all unique information from each original memory. Do not merge memories that describe different events, even if they share a theme.
+4. Keep ONLY what’s relevant to the user’s query. Delete irrelevant memories entirely.
+
+# OUTPUT FORMAT (STRICT)
+Return ONLY the following block, with **one enhanced memory per line**.
+Each line MUST start with "- " (dash + space).
+
+Wrap the final output inside:
+<enhanced_memories>
+- enhanced memory 1
+- enhanced memory 2
+...
+</enhanced_memories>
+
+## User Query
+{query_history}
+
+## Original Memories
+{memories}
+
+Final Output:
+"""
+
+
+MEMORY_RECREATE_ENHANCEMENT_PROMPT_BACKUP_2 = """
+You are a knowledgeable and precise AI assistant.
+
 # GOAL
 Transform raw memories into clean, query-relevant facts — preserving timestamps and resolving ambiguities without inference.
@@ -427,7 +500,6 @@
 Final Output:
 """
 
-# Rewrite version: return enhanced memories with original IDs
 MEMORY_REWRITE_ENHANCEMENT_PROMPT = """
 You are a knowledgeable and precise AI assistant.
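A consumer of MEMORY_RECREATE_ENHANCEMENT_PROMPT would then collect the bullets from the wrapped block. A sketch reusing the hypothetical extract_tag helper above (the scheduler's real parsing code may differ):

    def parse_enhanced_memories(response: str, top_k: int | None = None) -> list[str]:
        """Collect the '- ' bullets inside <enhanced_memories>...</enhanced_memories>."""
        block = extract_tag(response, "enhanced_memories") or ""
        bullets = [ln[2:].strip() for ln in block.splitlines() if ln.startswith("- ")]
        # The recreate prompt already caps output at top_k bullets; trim defensively anyway.
        return bullets[:top_k] if top_k is not None else bullets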
@@ -470,10 +542,43 @@
 Final Output:
 """
+
 # One-sentence prompt for recalling missing information to answer the query (English)
 ENLARGE_RECALL_PROMPT_ONE_SENTENCE = """
 You are a precise AI assistant. Your job is to analyze the user's query and the available memories to identify what specific information is missing to fully answer the query.
+
+# GOAL
+Identify the specific missing facts needed to fully answer the user's query and generate a concise hint for recalling them.
+
+# RULES
+- Analyze the user's query to understand what information is being asked.
+- Review the available memories to see what information is already present.
+- Identify the gap between the user's query and the available memories.
+- Generate a single, concise hint that prompts retrieval of the missing information.
+- The hint should be a direct question or a statement that clearly indicates what is needed.
+
+# OUTPUT FORMAT
+A JSON object with:
+- trigger_recall: true if information is missing, false if the memories are sufficient.
+- hint: a clear, specific prompt to retrieve the missing information (or an empty string if trigger_recall is false):
+{{
+    "trigger_recall": <true or false>,
+    "hint": "<a paraphrase to retrieve supporting memories>"
+}}
+
+## User Query
+{query}
+
+## Available Memories
+{memories_inline}
+
+Final Output:
+"""
+
+ENLARGE_RECALL_PROMPT_ONE_SENTENCE_BACKUP = """
+You are a precise AI assistant. Your job is to analyze the user's query and the available memories to identify what specific information is missing to fully answer the query.
+
 # GOAL
 Identify the specific missing facts needed to fully answer the user's query and generate a concise hint for recalling them.
@@ -505,7 +610,6 @@
 Final Output:
 """
 
-
 PROMPT_MAPPING = {
     "intent_recognizing": INTENT_RECOGNIZING_PROMPT,
     "memory_reranking": MEMORY_RERANKING_PROMPT,
diff --git a/src/memos/types/general_types.py b/src/memos/types/general_types.py
index f796e682a..3706b49da 100644
--- a/src/memos/types/general_types.py
+++ b/src/memos/types/general_types.py
@@ -36,7 +36,6 @@
     "MessagesType",
     "Permission",
     "PermissionDict",
-    "RawMessageList",
     "SearchMode",
     "UserContext",
     "UserID",
@@ -50,7 +49,7 @@
 
 # Message structure
 class MessageDict(TypedDict, total=False):
-    """Typed dictionary for chat message dictionaries; will be deprecated in favor of ChatCompletionMessageParam."""
+    """Typed dictionary for chat message dictionaries."""
 
     role: MessageRole
     content: str
@@ -102,11 +101,10 @@ class FineStrategy(str, Enum):
     REWRITE = "rewrite"
     RECREATE = "recreate"
     DEEP_SEARCH = "deep_search"
-    AGENTIC_SEARCH = "agentic_search"
 
 
 # algorithm strategies
-DEFAULT_FINE_STRATEGY = FineStrategy.DEEP_SEARCH
+DEFAULT_FINE_STRATEGY = FineStrategy.RECREATE
 FINE_STRATEGY = DEFAULT_FINE_STRATEGY
 
 # Read fine strategy from environment variable `FINE_STRATEGY`.
diff --git a/tests/mem_scheduler/test_dispatcher.py b/tests/mem_scheduler/test_dispatcher.py
index fe889559c..ccc4d77a1 100644
--- a/tests/mem_scheduler/test_dispatcher.py
+++ b/tests/mem_scheduler/test_dispatcher.py
@@ -157,7 +157,10 @@ def test_dispatch_serial(self):
         """Test dispatching messages in serial mode."""
         # Create a new dispatcher with parallel dispatch disabled
         serial_dispatcher = SchedulerDispatcher(
-            max_workers=2, enable_parallel_dispatch=False, metrics=MagicMock()
+            max_workers=2,
+            memos_message_queue=self.dispatcher.memos_message_queue,
+            enable_parallel_dispatch=False,
+            metrics=MagicMock(),
         )
 
         # Create fresh mock handlers for this test
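With AGENTIC_SEARCH removed, the default fine strategy becomes RECREATE, and the comment above notes that the strategy can still be overridden via the `FINE_STRATEGY` environment variable. A sketch of how that resolution could look (the exact parsing in general_types.py may differ):

    import os

    from memos.types.general_types import DEFAULT_FINE_STRATEGY, FineStrategy


    def resolve_fine_strategy() -> FineStrategy:
        """Fall back to the default (now RECREATE) when the env var is unset or invalid."""
        raw = os.getenv("FINE_STRATEGY", "").strip().lower()
        try:
            return FineStrategy(raw)
        except ValueError:
            return DEFAULT_FINE_STRATEGY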