Skip to content
Draft
4 changes: 3 additions & 1 deletion src/memos/api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,10 @@ def get_reranker_config() -> dict[str, Any]:
"backend": "http_bge",
"config": {
"url": os.getenv("MOS_RERANKER_URL"),
"model": "bge-reranker-v2-m3",
"model": os.getenv("MOS_RERANKER_MODEL", "bge-reranker-v2-m3"),
"timeout": 10,
"headers_extra": os.getenv("MOS_RERANKER_HEADERS_EXTRA"),
"rerank_source": os.getenv("MOS_RERANK_SOURCE"),
},
}
else:
Expand Down
107 changes: 107 additions & 0 deletions src/memos/reranker/concat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import re

from typing import Any, Literal
from .item import DialogueRankingTracker

# Matches a single leading "[...]" tag (with optional surrounding whitespace) at the
# very start of a string; used with .sub("", text) to strip that prefix from memory text.
_TAG1 = re.compile(r"^\s*\[[^\]]*\]\s*")


def concat_single_turn(
    graph_results: list,
) -> tuple[DialogueRankingTracker, dict[str, Any]]:
    """
    Split every memory item's sources into two-message pairs and register them for ranking.

    For each item, ``item.metadata.sources`` is walked two entries at a time. Each
    pair is handed to a ``DialogueRankingTracker`` together with the item's id, the
    pair's position and the item's memory text (leading ``[...]`` tag removed).
    NOTE(review): the pairing assumes sources alternate user / assistant messages;
    that ordering is not verified here.

    Args:
        graph_results: Memory items exposing ``id``, ``memory`` and
            ``metadata.sources``. Source entries may be strings or dicts.

    Returns:
        A ``(tracker, original_items)`` tuple: the populated tracker and a mapping
        from item id to the original item.
    """
    tracker = DialogueRankingTracker()
    original_items: dict[str, Any] = {}

    def extract_content(msg: dict[str, Any] | str) -> str:
        """Return the 'content' of a dict message, or the string form of anything else."""
        if isinstance(msg, dict):
            return msg.get("content", str(msg))
        return str(msg)

    for item in graph_results:
        raw_memory = getattr(item, "memory", None)
        memory = _TAG1.sub("", raw_memory) if isinstance(raw_memory, str) else raw_memory
        # A missing or None ``sources`` is treated as "no dialogue" instead of crashing on len().
        sources = getattr(item.metadata, "sources", None) or []
        original_items[item.id] = item

        for start in range(0, len(sources), 2):
            user_msg = sources[start]
            # An odd number of sources leaves the last user message without a reply.
            assistant_msg = sources[start + 1] if start + 1 < len(sources) else ""

            # Only register pairs where at least one side carries content.
            if extract_content(user_msg) or extract_content(assistant_msg):
                tracker.add_dialogue_pair(item.id, start // 2, user_msg, assistant_msg, memory)
    return tracker, original_items


def process_source(
items: list[tuple[Any, str | dict[str, Any] | list[Any]]] | None = None,
concat_strategy: Literal["user", "assistant", "single_turn"] = "user",
) -> str:
"""
Args:
items: List of tuples where each tuple contains (memory, source).
source can be str, Dict, or List.
recent_num: Number of recent items to concatenate.
Returns:
str: Concatenated source.
"""
if items is None:
items = []
concat_data = []
memory = None
for item in items:
memory, source = item
for content in source:
if isinstance(content, str):
if "assistant:" in content:
continue
concat_data.append(content)
if memory is not None:
concat_data = [memory, *concat_data]
return "\n".join(concat_data)


def concat_original_source(
    graph_results: list,
    merge_field: list[str] | None = None,
    concat_strategy: Literal["user", "assistant", "single_turn"] = "user",
) -> list[str]:
    """
    Build one ranking document per memory item from its memory text and metadata fields.

    Args:
        graph_results (list[TextualMemoryItem]): Memory items exposing ``memory``
            and ``metadata``.
        merge_field (list[str] | str | None): Names of ``item.metadata`` attributes
            to merge. Defaults to ``["sources"]``. A plain string is accepted and
            split on commas, so a value read from an environment variable works.
        concat_strategy: Forwarded to ``process_source``.

    Returns:
        list[str]: One string per item, in input order, as produced by
        ``process_source`` (memory text followed by the original source lines).
    """
    if merge_field is None:
        merge_field = ["sources"]
    elif isinstance(merge_field, str):
        # Iterating a bare string would look up one attribute per character.
        merge_field = [name.strip() for name in merge_field.split(",") if name.strip()]

    documents = []
    for item in graph_results:
        raw_memory = getattr(item, "memory", None)
        # Drop a leading "[...]" tag from the memory text before ranking.
        memory = _TAG1.sub("", raw_memory) if isinstance(raw_memory, str) else raw_memory
        sources = []
        for field in merge_field:
            source = getattr(item.metadata, field, None)
            if source is None:
                source = []
            elif isinstance(source, str):
                # Wrap a single string so it is handled as one entry, not per character.
                source = [source] if source else []
            sources.append((memory, source))
        documents.append(process_source(sources, concat_strategy=concat_strategy))
    return documents
10 changes: 10 additions & 0 deletions src/memos/reranker/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .cosine_local import CosineLocalReranker
from .http_bge import HTTPBGEReranker
from .noop import NoopReranker
from .http_bge_strategy import HTTPBGEWithStrategyReranker


if TYPE_CHECKING:
Expand All @@ -29,6 +30,7 @@ def from_config(cfg: RerankerConfigFactory | None) -> BaseReranker | None:
model=c.get("model", "bge-reranker-v2-m3"),
timeout=int(c.get("timeout", 10)),
headers_extra=c.get("headers_extra"),
rerank_source=c.get("rerank_source"),
)

if backend in {"cosine_local", "cosine"}:
Expand All @@ -40,4 +42,12 @@ def from_config(cfg: RerankerConfigFactory | None) -> BaseReranker | None:
if backend in {"noop", "none", "disabled"}:
return NoopReranker()

if backend in {"http_bge_strategy", "bge_strategy"}:
return HTTPBGEWithStrategyReranker(
reranker_url=c.get("url") or c.get("endpoint") or c.get("reranker_url"),
model=c.get("model", "bge-reranker-v2-m3"),
timeout=int(c.get("timeout", 10)),
headers_extra=c.get("headers_extra")
)

raise ValueError(f"Unknown reranker backend: {cfg.backend}")
27 changes: 21 additions & 6 deletions src/memos/reranker/http_bge.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@

import requests

from memos.log import get_logger

from .base import BaseReranker
from .concat import concat_original_source


logger = get_logger(__name__)


if TYPE_CHECKING:
Expand All @@ -28,6 +34,7 @@ def __init__(
model: str = "bge-reranker-v2-m3",
timeout: int = 10,
headers_extra: dict | None = None,
rerank_source: list[str] | None = None,
):
if not reranker_url:
raise ValueError("reranker_url must not be empty")
Expand All @@ -36,6 +43,7 @@ def __init__(
self.model = model
self.timeout = timeout
self.headers_extra = headers_extra or {}
self.concat_source = rerank_source

def rerank(
self,
Expand All @@ -47,11 +55,18 @@ def rerank(
if not graph_results:
return []

documents = [
(_TAG1.sub("", m) if isinstance((m := getattr(item, "memory", None)), str) else m)
for item in graph_results
]
documents = [d for d in documents if isinstance(d, str) and d]
documents = []
if self.concat_source:
documents = concat_original_source(graph_results, self.concat_source)
else:
documents = [
(_TAG1.sub("", m) if isinstance((m := getattr(item, "memory", None)), str) else m)
for item in graph_results
]
documents = [d for d in documents if isinstance(d, str) and d]

logger.info(f"[HTTPBGERerankerSample] query: {query} , documents: {documents[:5]}...")

if not documents:
return []

Expand Down Expand Up @@ -95,5 +110,5 @@ def rerank(
return [(item, 0.0) for item in graph_results[:top_k]]

except Exception as e:
print(f"[HTTPBGEReranker] request failed: {e}")
logger.error(f"[HTTPBGEReranker] request failed: {e}")
return [(item, 0.0) for item in graph_results[:top_k]]
145 changes: 145 additions & 0 deletions src/memos/reranker/http_bge_strategy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import re
import json
import requests
from typing import TYPE_CHECKING, Literal

from memos.log import get_logger

from .base import BaseReranker
from .item import DialogueRankingTracker
from .concat import concat_original_source, concat_single_turn

logger = get_logger(__name__)

from memos.memories.textual.item import TextualMemoryItem

# Matches a single leading "[...]" tag (with optional surrounding whitespace) at the
# very start of a string.
# NOTE(review): not referenced in the visible part of this module — confirm it is still needed.
_TAG1 = re.compile(r"^\s*\[[^\]]*\]\s*")


class HTTPBGEWithStrategyReranker(BaseReranker):
    """
    HTTP-based BGE reranker that ranks source-derived documents instead of raw memories.

    Documents are built from each memory item according to ``concat_strategy``,
    sent to the rerank endpoint, and the ranked documents are mapped back to
    memory items through a ``DialogueRankingTracker``. Only the ``"single_turn"``
    strategy is implemented; ``"user"`` and ``"assistant"`` raise
    ``NotImplementedError`` when ``rerank`` is called.
    """

    def __init__(
        self,
        reranker_url: str,
        token: str = "",
        model: str = "bge-reranker-v2-m3",
        timeout: int = 10,
        headers_extra: dict | None = None,
        rerank_source: list[str] | None = None,
        concat_strategy: Literal["user", "assistant", "single_turn"] = "single_turn",
        source_weight: float = 0.3,
    ):
        """
        Args:
            reranker_url: Endpoint that receives the rerank POST request.
            token: Optional bearer token; sent as an ``Authorization`` header when set.
            model: Model name placed in the request payload.
            timeout: Request timeout passed to ``requests.post``.
            headers_extra: Extra HTTP headers merged into the request headers.
            rerank_source: Stored on the instance; not read by the strategy
                implemented in this class.
            concat_strategy: How documents are built from memory items.
            source_weight: Stored on the instance; not read within this class.

        Raises:
            ValueError: If ``reranker_url`` is empty.
        """
        if not reranker_url:
            raise ValueError("reranker_url must not be empty")

        self.reranker_url = reranker_url
        self.token = token or ""
        self.model = model
        self.timeout = timeout
        self.headers_extra = headers_extra or {}
        self.rerank_source = rerank_source
        self.concat_strategy = concat_strategy
        self.source_weight = source_weight

    def _prepare_documents(
        self, graph_results: list
    ) -> tuple[DialogueRankingTracker, dict[str, TextualMemoryItem], list[str]]:
        """Prepare documents based on the concatenation strategy.

        Args:
            graph_results: List of graph results

        Returns:
            Tracker, mapping of item id to original item, and the documents to rank.

        Raises:
            NotImplementedError: For the "user" and "assistant" strategies.
            ValueError: For any other unknown strategy.
        """
        if self.concat_strategy == "single_turn":
            tracker, original_items = concat_single_turn(graph_results)
            return tracker, original_items, tracker.get_documents_for_ranking()

        if self.concat_strategy == "user":
            raise NotImplementedError("User strategy is not implemented")

        if self.concat_strategy == "assistant":
            raise NotImplementedError("Assistant strategy is not implemented")

        raise ValueError(f"Unknown concat_strategy: {self.concat_strategy}")

    @staticmethod
    def _parse_rankings(
        data: dict, num_documents: int
    ) -> tuple[list[int], list[float]] | None:
        """Extract ``(ranked_indices, scores)`` from a rerank response.

        Two response shapes are understood:
          * ``{"results": [{"index": i, "relevance_score" | "score": s}, ...]}`` —
            rows are kept in the order returned; rows whose index does not refer
            to a submitted document are ignored.
          * ``{"data": [{"score": s}, ...]}`` — row position is the document index;
            rows are sorted by score, highest first.

        Returns ``None`` when neither key is present.
        """
        if "results" in data:
            ranked_indices: list[int] = []
            scores: list[float] = []
            for row in data["results"]:
                idx = row.get("index")
                if isinstance(idx, int) and 0 <= idx < num_documents:
                    score = float(row.get("relevance_score", row.get("score", 0.0)))
                    ranked_indices.append(idx)
                    scores.append(score)
            return ranked_indices, scores

        if "data" in data:
            indexed_scores = [
                (position, float(row.get("score", 0.0)))
                for position, row in enumerate(data["data"])
            ]
            # Ignore surplus rows so every index refers to a document that was actually sent.
            indexed_scores = indexed_scores[:num_documents]
            indexed_scores.sort(key=lambda pair: pair[1], reverse=True)
            return (
                [position for position, _ in indexed_scores],
                [score for _, score in indexed_scores],
            )

        return None

    def rerank(
        self,
        query: str,
        graph_results: list,
        top_k: int,
        **kwargs,
    ) -> list[tuple[TextualMemoryItem, float]]:
        """Rerank ``graph_results`` against ``query`` via the HTTP endpoint.

        Returns an empty list when there are no inputs or no documents could be
        built. If the request, response parsing or reconstruction fails, or the
        response has an unknown shape, the first ``top_k`` inputs are returned
        with a score of 0.0 (best-effort fallback).
        """
        if not graph_results:
            return []

        tracker, original_items, documents = self._prepare_documents(graph_results)

        logger.info(
            f"[HTTPBGEWithStrategyReranker] strategy: {self.concat_strategy}, "
            f"query: {query}, documents count: {len(documents)}"
        )
        logger.debug(f"[HTTPBGEWithStrategyReranker] sample documents: {documents[:2]}...")

        if not documents:
            return []

        headers = {"Content-Type": "application/json", **self.headers_extra}
        if self.token:
            headers["Authorization"] = f"Bearer {self.token}"

        payload = {"model": self.model, "query": query, "documents": documents}

        try:
            resp = requests.post(
                self.reranker_url, headers=headers, json=payload, timeout=self.timeout
            )
            resp.raise_for_status()
            data = resp.json()
            logger.info(f"[HTTPBGEWithStrategyReranker] response: {json.dumps(data, indent=4)}")

            parsed = self._parse_rankings(data, len(documents))
            if parsed is None:
                # Fallback: return original items with zero scores
                return [(item, 0.0) for item in graph_results[:top_k]]
            ranked_indices, scores = parsed

            # Reconstruct memory items from ranked dialogue pairs
            reconstructed_items = tracker.reconstruct_memory_items(
                ranked_indices, scores, original_items, top_k
            )

            logger.info(
                f"[HTTPBGEWithStrategyReranker] reconstructed {len(reconstructed_items)} memory items"
            )
            return reconstructed_items

        except Exception as e:
            # Deliberate best-effort: any failure degrades to unranked results.
            logger.error(f"[HTTPBGEWithStrategyReranker] rerank failed: {e}")
            return [(item, 0.0) for item in graph_results[:top_k]]
Loading