diff --git a/src/memos/api/config.py b/src/memos/api/config.py
index e7cc5d65..18327ef6 100644
--- a/src/memos/api/config.py
+++ b/src/memos/api/config.py
@@ -100,8 +100,10 @@ def get_reranker_config() -> dict[str, Any]:
                 "backend": "http_bge",
                 "config": {
                     "url": os.getenv("MOS_RERANKER_URL"),
-                    "model": "bge-reranker-v2-m3",
+                    "model": os.getenv("MOS_RERANKER_MODEL", "bge-reranker-v2-m3"),
                     "timeout": 10,
+                    "headers_extra": os.getenv("MOS_RERANKER_HEADERS_EXTRA"),
+                    "rerank_source": os.getenv("MOS_RERANK_SOURCE"),
                 },
             }
         else:
diff --git a/src/memos/reranker/concat.py b/src/memos/reranker/concat.py
new file mode 100644
index 00000000..43f184ed
--- /dev/null
+++ b/src/memos/reranker/concat.py
@@ -0,0 +1,137 @@
+import re
+
+from typing import Any, Literal
+
+from .item import DialogueRankingTracker
+
+
+# Strips one leading "[...]" tag (and the whitespace around it) from a memory string.
+_TAG1 = re.compile(r"^\s*\[[^\]]*\]\s*")
+
+
+def _clean_memory(item: Any) -> str | None:
+    """Return ``item.memory`` without its leading tag, or ``None`` if it is not a string."""
+    raw = getattr(item, "memory", None)
+    return _TAG1.sub("", raw) if isinstance(raw, str) else None
+
+
+def _message_text(msg: Any) -> str:
+    """Return the text of a source message given as a plain string or a dict."""
+    if isinstance(msg, dict):
+        return msg.get("content", str(msg))
+    return str(msg)
+
+
+def concat_single_turn(
+    graph_results: list,
+) -> tuple[DialogueRankingTracker, dict[str, Any]]:
+    """
+    Split each memory's sources into (user, assistant) turns and register them for ranking.
+
+    ``metadata.sources`` is read as an alternating ``[user, assistant, user, ...]``
+    sequence; a trailing unpaired message is padded with an empty assistant reply.
+    Turns whose two messages are both empty are skipped, so a memory without any
+    sources contributes nothing to the tracker.
+
+    Args:
+        graph_results: Memory items exposing ``id``, ``memory`` and ``metadata``.
+
+    Returns:
+        ``(tracker, original_items)``: the tracker holding one dialogue pair per
+        turn, and a dict mapping each item's ``id`` to the item itself.
+
+    Example:
+        Sources ``["user: hello", "assistant: hi there", "user: how are you?"]``
+        produce two pairs: ``("user: hello", "assistant: hi there")`` and
+        ``("user: how are you?", "")``.
+    """
+    tracker = DialogueRankingTracker()
+    original_items: dict[str, Any] = {}
+
+    for item in graph_results:
+        # DialoguePair.memory must be a string, so a missing memory becomes "".
+        memory = _clean_memory(item) or ""
+        # ``sources`` may be absent from the metadata or explicitly None.
+        sources = getattr(getattr(item, "metadata", None), "sources", None) or []
+        original_items[item.id] = item
+
+        for start in range(0, len(sources), 2):
+            user_msg = sources[start]
+            assistant_msg = sources[start + 1] if start + 1 < len(sources) else ""
+            if _message_text(user_msg) or _message_text(assistant_msg):
+                tracker.add_dialogue_pair(item.id, start // 2, user_msg, assistant_msg, memory)
+
+    return tracker, original_items
+
+
+def process_source(
+    items: list[tuple[Any, str | dict[str, Any] | list[Any]]] | None = None,
+    concat_strategy: Literal["user", "assistant", "single_turn"] = "user",
+) -> str:
+    """
+    Build one reranker document from ``(memory, source)`` tuples.
+
+    String messages containing ``"assistant:"`` and dict messages whose ``role`` is
+    ``"assistant"`` are dropped. The remaining message texts are joined with
+    newlines and prefixed with the memory of the last tuple.
+
+    Args:
+        items: ``(memory, source)`` tuples. ``source`` may be a single message
+            (``str`` or ``dict``), a list of messages, or empty/``None``.
+        concat_strategy: Accepted for interface compatibility but not consulted;
+            the filtering described above is always applied.
+
+    Returns:
+        The newline-joined document, ``""`` when there is nothing to join.
+    """
+    lines: list[str] = []
+    memory = None
+    for memory, source in items or []:
+        if not source:
+            continue
+        # A bare str/dict is a single message; iterating it directly would walk
+        # its characters or keys instead of messages.
+        messages = [source] if isinstance(source, (str, dict)) else source
+        for message in messages:
+            text = message
+            if isinstance(message, dict):
+                if str(message.get("role", "")).lower() == "assistant":
+                    continue
+                text = message.get("content")
+            if isinstance(text, str) and "assistant:" not in text:
+                lines.append(text)
+    if isinstance(memory, str):
+        lines.insert(0, memory)
+    return "\n".join(lines)
+
+
+def concat_original_source(
+    graph_results: list,
+    merge_field: list[str] | None = None,
+    concat_strategy: Literal["user", "assistant", "single_turn"] = "user",
+) -> list[str]:
+    """
+    Build one document per memory item: its memory text followed by its source dialogue.
+
+    Args:
+        graph_results: Memory items exposing ``memory`` and ``metadata``.
+        merge_field: Names of the ``metadata`` attributes to merge; defaults to
+            ``["sources"]``. A comma-separated string is accepted too, because the
+            value can arrive unparsed from an environment variable.
+        concat_strategy: Forwarded to ``process_source``.
+
+    Returns:
+        One concatenated document per item, in the order of ``graph_results``.
+    """
+    if isinstance(merge_field, str):
+        # Iterating a bare string would look up one attribute per character.
+        merge_field = [name.strip() for name in merge_field.split(",") if name.strip()]
+    fields = list(merge_field) if merge_field else ["sources"]
+
+    documents = []
+    for item in graph_results:
+        memory = _clean_memory(item)
+        metadata = getattr(item, "metadata", None)
+        pairs = [(memory, getattr(metadata, name, None)) for name in fields]
+        documents.append(process_source(pairs, concat_strategy))
+    return documents
diff --git a/src/memos/reranker/factory.py b/src/memos/reranker/factory.py
index 244b6928..c47ef158 100644
--- a/src/memos/reranker/factory.py
+++ b/src/memos/reranker/factory.py
@@ -6,6 +6,7 @@
 from .cosine_local import CosineLocalReranker
 from .http_bge import HTTPBGEReranker
 from .noop import NoopReranker
+from .http_bge_strategy import HTTPBGEWithStrategyReranker
 
 
 if TYPE_CHECKING:
@@ -29,6 +30,7 @@ def from_config(cfg: RerankerConfigFactory | None) -> BaseReranker | None:
                 model=c.get("model", "bge-reranker-v2-m3"),
                 timeout=int(c.get("timeout", 10)),
                 headers_extra=c.get("headers_extra"),
+                rerank_source=c.get("rerank_source"),
             )
 
         if backend in {"cosine_local", "cosine"}:
@@ -40,4 +42,12 @@ def from_config(cfg: RerankerConfigFactory | None) -> BaseReranker | None:
         if backend in {"noop", "none", "disabled"}:
             return NoopReranker()
 
+        if backend in {"http_bge_strategy", "bge_strategy"}:
+            return HTTPBGEWithStrategyReranker(
+                reranker_url=c.get("url") or c.get("endpoint") or c.get("reranker_url"),
+                model=c.get("model", "bge-reranker-v2-m3"),
+                timeout=int(c.get("timeout", 10)),
+                headers_extra=c.get("headers_extra")
+            )
+
         raise ValueError(f"Unknown reranker backend: {cfg.backend}")
diff --git a/src/memos/reranker/http_bge.py b/src/memos/reranker/http_bge.py
index a852f325..c54a1ade 100644
--- a/src/memos/reranker/http_bge.py
+++ b/src/memos/reranker/http_bge.py
@@ -7,7 +7,13 @@
 
 import requests
 
+from memos.log import get_logger
+
 from .base import BaseReranker
+from .concat import concat_original_source
+
+
+logger = get_logger(__name__)
 
 
 if TYPE_CHECKING:
@@ -28,6 +34,7 @@ def __init__(
         model: str = "bge-reranker-v2-m3",
         timeout: int = 10,
         headers_extra: dict | None = None,
+        rerank_source: list[str] | None = None,
     ):
         if not reranker_url:
             raise ValueError("reranker_url must not be empty")
@@ -36,6 +43,7 @@ def __init__(
         self.model = model
         self.timeout = timeout
         self.headers_extra = headers_extra or {}
+        self.concat_source = rerank_source
 
     def rerank(
         self,
@@ -47,11 +55,18 @@ def rerank(
         if not graph_results:
             return []
 
-        documents = [
-            (_TAG1.sub("", m) if isinstance((m := getattr(item, "memory", None)), str) else m)
-            for item in graph_results
-        ]
-        documents = [d for d in documents if isinstance(d, str) and d]
+        documents = []
+        if self.concat_source:
+            documents = concat_original_source(graph_results, self.concat_source)
+        else:
+            documents = [
+                (_TAG1.sub("", m) if isinstance((m := getattr(item, "memory", None)), str) else m)
+                for item in graph_results
+            ]
+            documents = [d for d in documents if isinstance(d, str) and d]
+
+        logger.info(f"[HTTPBGERerankerSample] query: {query} , documents: {documents[:5]}...")
+
         if not documents:
             return []
 
@@ -95,5 +110,5 @@ def rerank(
                 return [(item, 0.0) for item in graph_results[:top_k]]
 
         except Exception as e:
-            print(f"[HTTPBGEReranker] request failed: {e}")
+            logger.error(f"[HTTPBGEReranker] request failed: {e}")
             return [(item, 0.0) for item in graph_results[:top_k]]
diff --git a/src/memos/reranker/http_bge_strategy.py b/src/memos/reranker/http_bge_strategy.py
new file mode 100644
index 00000000..6d48c261
--- /dev/null
+++ b/src/memos/reranker/http_bge_strategy.py
@@ -0,0 +1,184 @@
+import json
+import re
+
+from typing import TYPE_CHECKING, Any, Literal
+
+import requests
+
+from memos.log import get_logger
+from memos.memories.textual.item import TextualMemoryItem
+
+from .base import BaseReranker
+from .concat import concat_original_source, concat_single_turn
+from .item import DialogueRankingTracker
+
+
+logger = get_logger(__name__)
+
+_TAG1 = re.compile(r"^\s*\[[^\]]*\]\s*")
+
+
+class HTTPBGEWithStrategyReranker(BaseReranker):
+    """
+    HTTP-based BGE reranker that ranks memories through their source dialogue.
+
+    Only the ``"single_turn"`` strategy is implemented: each (user, assistant) turn
+    of a memory is sent as its own document and the best turn scores the memory.
+    """
+
+    def __init__(
+        self,
+        reranker_url: str,
+        token: str = "",
+        model: str = "bge-reranker-v2-m3",
+        timeout: int = 10,
+        headers_extra: dict | None = None,
+        rerank_source: list[str] | None = None,
+        concat_strategy: Literal["user", "assistant", "single_turn"] = "single_turn",
+        source_weight: float = 0.3,
+    ):
+        """
+        Args:
+            reranker_url: Endpoint that receives the rerank POST request.
+            token: Optional bearer token sent in the ``Authorization`` header.
+            model: Model name placed in the request payload.
+            timeout: Request timeout passed to ``requests.post``.
+            headers_extra: Extra HTTP headers, as a dict or a JSON object string
+                (the form an environment variable delivers).
+            rerank_source: Stored on the instance; not consulted by this class.
+            concat_strategy: How documents are built; see ``_prepare_documents``.
+            source_weight: Stored on the instance; not consulted by this class.
+
+        Raises:
+            ValueError: If ``reranker_url`` is empty or ``headers_extra`` is not a
+                dict / JSON object.
+        """
+        if not reranker_url:
+            raise ValueError("reranker_url must not be empty")
+
+        if isinstance(headers_extra, str):
+            # Unpacking a str with ``**`` in ``rerank`` would raise TypeError.
+            headers_extra = json.loads(headers_extra) if headers_extra.strip() else {}
+        if headers_extra is not None and not isinstance(headers_extra, dict):
+            raise ValueError("headers_extra must be a dict or a JSON object string")
+
+        self.reranker_url = reranker_url
+        self.token = token or ""
+        self.model = model
+        self.timeout = timeout
+        self.headers_extra = headers_extra or {}
+        self.rerank_source = rerank_source
+        self.concat_strategy = concat_strategy
+        self.source_weight = source_weight
+
+    def _prepare_documents(
+        self, graph_results: list
+    ) -> tuple[DialogueRankingTracker, dict[str, Any], list[str]]:
+        """
+        Build the documents to rank according to ``self.concat_strategy``.
+
+        Args:
+            graph_results: Memory items to rerank.
+
+        Returns:
+            ``(tracker, original_items, documents)`` where ``documents[i]`` is the
+            text of the tracker's i-th dialogue pair.
+
+        Raises:
+            NotImplementedError: For the ``"user"`` and ``"assistant"`` strategies.
+            ValueError: For any other unknown strategy.
+        """
+        if self.concat_strategy == "single_turn":
+            tracker, original_items = concat_single_turn(graph_results)
+            return tracker, original_items, tracker.get_documents_for_ranking()
+        if self.concat_strategy == "user":
+            raise NotImplementedError("User strategy is not implemented")
+        if self.concat_strategy == "assistant":
+            raise NotImplementedError("Assistant strategy is not implemented")
+        raise ValueError(f"Unknown concat_strategy: {self.concat_strategy}")
+
+    def rerank(
+        self,
+        query: str,
+        graph_results: list,
+        top_k: int,
+        **kwargs,
+    ) -> list[tuple[TextualMemoryItem, float]]:
+        """
+        Rerank ``graph_results`` against ``query`` using the remote BGE service.
+
+        Args:
+            query: Query text sent to the reranker.
+            graph_results: Candidate memory items.
+            top_k: Maximum number of ``(item, score)`` tuples to return.
+            **kwargs: Ignored.
+
+        Returns:
+            ``(item, score)`` tuples sorted by descending score. When the request
+            fails or the response has neither ``results`` nor ``data``, the first
+            ``top_k`` inputs are returned with score ``0.0``. An empty list is
+            returned when there is nothing to rank.
+        """
+        if not graph_results:
+            return []
+
+        tracker, original_items, documents = self._prepare_documents(graph_results)
+
+        logger.info(
+            f"[HTTPBGEWithSourceReranker] strategy: {self.concat_strategy}, "
+            f"query: {query}, documents count: {len(documents)}"
+        )
+        logger.debug(f"[HTTPBGEWithSourceReranker] sample documents: {documents[:2]}...")
+
+        if not documents:
+            return []
+
+        headers = {"Content-Type": "application/json", **self.headers_extra}
+        if self.token:
+            headers["Authorization"] = f"Bearer {self.token}"
+
+        payload = {"model": self.model, "query": query, "documents": documents}
+
+        try:
+            resp = requests.post(
+                self.reranker_url, headers=headers, json=payload, timeout=self.timeout
+            )
+            resp.raise_for_status()
+            data = resp.json()
+            # The full response can be large; keep it out of info-level logs.
+            logger.debug(f"[HTTPBGEWithStrategyReranker] response: {json.dumps(data, indent=4)}")
+
+            if "results" in data:
+                scored = [
+                    (row["index"], float(row.get("relevance_score", row.get("score", 0.0))))
+                    for row in data.get("results", [])
+                    if isinstance(row.get("index"), int) and 0 <= row["index"] < len(documents)
+                ]
+            elif "data" in data:
+                scored = [
+                    (position, float(row.get("score", 0.0)))
+                    for position, row in enumerate(data.get("data", []))
+                ]
+            else:
+                # Unrecognised response shape: keep the input order, zero scores.
+                return [(item, 0.0) for item in graph_results[:top_k]]
+
+            # The tracker only inspects a prefix of the ranking, so it must be ordered
+            # best-first whichever response shape the service used.
+            scored.sort(key=lambda pair: pair[1], reverse=True)
+            ranked_indices = [index for index, _ in scored]
+            scores = [score for _, score in scored]
+
+            reconstructed_items = tracker.reconstruct_memory_items(
+                ranked_indices, scores, original_items, top_k
+            )
+
+            logger.info(
+                f"[HTTPBGEDialogueReranker] reconstructed {len(reconstructed_items)} memory items"
+            )
+            return reconstructed_items
+
+        except Exception as e:
+            # Fall back to the unranked inputs instead of propagating the failure.
+            logger.error(f"[HTTPBGEWithSourceReranker] request failed: {e}")
+            return [(item, 0.0) for item in graph_results[:top_k]]
diff --git a/src/memos/reranker/item.py b/src/memos/reranker/item.py
new file mode 100644
index 00000000..467b09db
--- /dev/null
+++ b/src/memos/reranker/item.py
@@ -0,0 +1,162 @@
+from collections import defaultdict
+from copy import deepcopy
+from typing import Any
+
+from pydantic import BaseModel
+
+
+def _message_text(msg: str | dict[str, Any]) -> str:
+    """Return the text of a message given as a plain string or a ``{"content": ...}`` dict."""
+    if isinstance(msg, dict):
+        return msg.get("content", str(msg))
+    return str(msg)
+
+
+class DialoguePair(BaseModel):
+    """Represents a single dialogue pair extracted from sources."""
+
+    pair_id: str  # Unique identifier for this dialogue pair
+    memory_id: str  # ID of the source TextualMemoryItem
+    memory: str  # Memory text prepended to the pair when building the ranking document
+    pair_index: int  # Index of this pair within the source memory's dialogue
+    user_msg: str | dict[str, Any]  # User message content
+    assistant_msg: str | dict[str, Any]  # Assistant message content
+    combined_text: str  # The concatenated text used for ranking
+
+    def extract_content(self, msg: str | dict[str, Any]) -> str:
+        """Extract content from message, handling both string and dict formats."""
+        return _message_text(msg)
+
+    @property
+    def user_content(self) -> str:
+        """Get user message content as string."""
+        return self.extract_content(self.user_msg)
+
+    @property
+    def assistant_content(self) -> str:
+        """Get assistant message content as string."""
+        return self.extract_content(self.assistant_msg)
+
+
+class DialogueRankingTracker:
+    """Tracks dialogue pairs and their rankings for memory reconstruction."""
+
+    def __init__(self):
+        # Position in this list is the document index sent to / returned by the reranker.
+        self.dialogue_pairs: list[DialoguePair] = []
+
+    def add_dialogue_pair(
+        self,
+        memory_id: str,
+        pair_index: int,
+        user_msg: str | dict[str, Any],
+        assistant_msg: str | dict[str, Any],
+        memory: str,
+    ) -> str:
+        """
+        Register one dialogue pair and return its unique ID.
+
+        Args:
+            memory_id: ID of the memory item the pair was extracted from.
+            pair_index: Position of the pair within that memory's dialogue.
+            user_msg: User message (string or dict with a ``content`` key).
+            assistant_msg: Assistant message (string or dict with a ``content`` key).
+            memory: Memory text; a non-string value is stored as ``""``.
+
+        Returns:
+            The pair ID, ``"{memory_id}_{pair_index}"``.
+        """
+        pair_id = f"{memory_id}_{pair_index}"
+        combined_text = f"{_message_text(user_msg)}\n{_message_text(assistant_msg)}"
+
+        self.dialogue_pairs.append(
+            DialoguePair(
+                pair_id=pair_id,
+                memory_id=memory_id,
+                pair_index=pair_index,
+                user_msg=user_msg,
+                assistant_msg=assistant_msg,
+                combined_text=combined_text,
+                # The model field is a required ``str``; None would fail validation.
+                memory=memory if isinstance(memory, str) else "",
+            )
+        )
+        return pair_id
+
+    def get_documents_for_ranking(self, concat_memory: bool = True) -> list[str]:
+        """
+        Return one ranking document per tracked pair, in insertion order.
+
+        Args:
+            concat_memory: When true (default), prefix each pair's text with its
+                memory text and a blank line; when false, return the pair text only.
+        """
+        if not concat_memory:
+            return [pair.combined_text for pair in self.dialogue_pairs]
+        return [pair.memory + "\n\n" + pair.combined_text for pair in self.dialogue_pairs]
+
+    def get_dialogue_pair_by_index(self, index: int) -> DialoguePair | None:
+        """Get dialogue pair by its index in the ranking results."""
+        if 0 <= index < len(self.dialogue_pairs):
+            return self.dialogue_pairs[index]
+        return None
+
+    def reconstruct_memory_items(
+        self,
+        ranked_indices: list[int],
+        scores: list[float],
+        original_memory_items: dict[str, Any],
+        top_k: int,
+    ) -> list[tuple[Any, float]]:
+        """
+        Reconstruct memory items from ranked dialogue pairs.
+
+        Args:
+            ranked_indices: Dialogue pair indices, best first.
+            scores: Relevance scores aligned with ``ranked_indices``.
+            original_memory_items: Dict mapping memory_id to the original item.
+            top_k: Maximum number of items to return.
+
+        Returns:
+            ``(item_copy, score)`` tuples sorted by descending score, at most
+            ``top_k`` of them. Each copy's ``metadata.sources`` is replaced by the
+            messages of one retained pair; the score is the memory's best pair score.
+        """
+        memory_groups: dict[str, list[DialoguePair]] = defaultdict(list)
+        memory_scores: dict[str, list[float]] = defaultdict(list)
+
+        # Several pairs can belong to one memory, so look at more than top_k pairs
+        # to have a chance of ending up with top_k distinct memories.
+        window = top_k * 3
+        for idx, score in zip(ranked_indices[:window], scores[:window]):
+            dialogue_pair = self.get_dialogue_pair_by_index(idx)
+            if dialogue_pair is not None:
+                memory_groups[dialogue_pair.memory_id].append(dialogue_pair)
+                memory_scores[dialogue_pair.memory_id].append(score)
+
+        reconstructed_items = []
+        for memory_id, pairs in memory_groups.items():
+            if memory_id not in original_memory_items:
+                continue
+
+            # Work on a copy so the caller's item keeps its full sources.
+            reconstructed_item = deepcopy(original_memory_items[memory_id])
+
+            # NOTE(review): after ordering by dialogue position only the earliest
+            # retained pair is kept, which is not necessarily the best-scoring one.
+            # Confirm this is intended.
+            pairs.sort(key=lambda p: p.pair_index)
+            new_sources = []
+            for pair in pairs[:1]:
+                # Skip the "" placeholder used to pad an unpaired final message.
+                new_sources.extend(msg for msg in (pair.user_msg, pair.assistant_msg) if msg)
+
+            if hasattr(reconstructed_item.metadata, "sources"):
+                reconstructed_item.metadata.sources = new_sources
+
+            pair_scores = memory_scores[memory_id]
+            aggregated_score = max(pair_scores) if pair_scores else 0.0
+            reconstructed_items.append((reconstructed_item, aggregated_score))
+
+        reconstructed_items.sort(key=lambda x: x[1], reverse=True)
+        return reconstructed_items[:top_k]