diff --git a/CHANGELOG.md b/CHANGELOG.md
index 09e6c6d..0e81235 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -23,8 +23,15 @@ elfmem uses [Semantic Versioning](https://semver.org/).
 ### Added
+- **`MemorySystem.learn_document(text, chunk_size, chunker, skip_llm)`:** Ingest a document in one call — chunks text, learns each chunk, auto-consolidates via `dream()` at `inbox_threshold` intervals. Accepts an optional `chunker` callback (e.g. `nltk.sent_tokenize`); default splits at sentence boundaries. Returns `LearnDocumentResult` with chunk and consolidation counts.
+- **`LearnDocumentResult` type:** New result type with `chunks_total`, `chunks_created`, `chunks_duplicate`, `consolidations`, `blocks_promoted`. Exported from `elfmem`.
+- **BM25 keyword search in retrieval pipeline (stage 2b):** `hybrid_retrieve()` now runs BM25 in parallel with vector search, discovering blocks with strong keyword overlap that embedding similarity misses. Soft dependency on `rank_bm25` — when not installed, the stage is silently skipped (zero regression). Install via `pip install elfmem[bm25]`.
+- **`dream(skip_llm, skip_contradictions)` parameters:** `dream()` now forwards `skip_llm` and `skip_contradictions` to `consolidate()`, enabling fast-path consolidation without bypassing policy tracking or threshold persistence.
 - **`MABenchConfig.context_window_tokens`:** New config field (default 4096) representing the LM Studio model's context window. All answer-context truncation derives from this value; set to 2048 for smaller models.
+### Fixed
+- **Config wiring: `contradiction_threshold`, `near_dup_exact_threshold`, `near_dup_near_threshold`:** These three `MemoryConfig` fields existed but were not passed from `MemorySystem.consolidate()` to the consolidation operation. Custom config values were silently ignored (defaults matched, so no observable bug at default settings). Now wired through.
+
 ### Added
 - **LoCoMo benchmark harness:** Complete benchmark suite for evaluating elfmem against LoCoMo (ACL 2024) — 10 conversations, 1,986 QA pairs, 5 categories. Includes metrics (Porter-stemmed F1), typed data loading, BM25 hybrid retrieval, observation transform, and CLI runner with `--test`, `--baselines`, `--resume`, `--top-k`, `--category` flags. Results conform to `benchmark_report_spec.md`.
 - **`consolidate(skip_llm=True)`:** Bypass all LLM calls during consolidation (embed + promote only). Reduces ingestion from hours to seconds for bulk import and benchmarks.
diff --git a/benchmarks/memoryagentbench/adapter.py b/benchmarks/memoryagentbench/adapter.py
index 2adccb3..a535c68 100644
--- a/benchmarks/memoryagentbench/adapter.py
+++ b/benchmarks/memoryagentbench/adapter.py
@@ -1,9 +1,12 @@
-"""elfmem adapter for MemoryAgentBench — chunk ingestion + retrieval.
+"""elfmem adapter for MemoryAgentBench — document ingestion + retrieval.
 
 Key difference from LoCoMo: MemoryAgentBench's Conflict Resolution competency
 tests contradiction detection — elfmem's primary moat. For CR, we use FULL
 consolidation (contradiction detection ON). For other competencies, we use
-skip_contradictions for speed.
+skip_llm for speed.
+
+BM25 hybrid retrieval is handled natively by elfmem's retrieval pipeline
+(stage 2b in hybrid_retrieve). No adapter-level BM25 or RRF needed.
""" from __future__ import annotations @@ -11,17 +14,14 @@ import logging import tempfile import time -from dataclasses import dataclass, field +from dataclasses import dataclass from pathlib import Path from typing import Any import nltk -from rank_bm25 import BM25Okapi - -from elfmem import ElfmemConfig, MemorySystem - from benchmarks.memoryagentbench.config import MABenchConfig +from elfmem import ElfmemConfig, MemorySystem nltk.download("punkt_tab", quiet=True) @@ -50,65 +50,6 @@ class ExampleResult: ingestion_seconds: float -class _BM25Index: - """BM25 index over block content for keyword-boosted retrieval. - - Built after consolidation from elfmem's active block content (summaries - when available, raw content otherwise). Stores block IDs so RRF merge - can match by ID rather than approximate content-prefix heuristics. - """ - - def __init__(self) -> None: - self._ids: list[str] = [] - self._contents: list[str] = [] - self._bm25: BM25Okapi | None = None - - def add(self, block_id: str, content: str) -> None: - self._ids.append(block_id) - self._contents.append(content) - - def build(self) -> None: - if self._contents: - tokenized = [c.lower().split() for c in self._contents] - self._bm25 = BM25Okapi(tokenized) - - def search(self, query: str, top_k: int = 10) -> list[tuple[str, str, float]]: - """Return (block_id, content, bm25_score) triples ranked by score.""" - if self._bm25 is None: - return [] - scores = self._bm25.get_scores(query.lower().split()) - ranked = sorted( - zip(self._ids, self._contents, scores), - key=lambda x: x[2], - reverse=True, - ) - return ranked[:top_k] - - -def chunk_text(text: str, chunk_size: int = 1024) -> list[str]: - """Split text into sentence-aligned chunks of ~chunk_size words. - - Uses NLTK sentence tokenization for clean boundaries. - """ - sentences = nltk.sent_tokenize(text) - chunks: list[str] = [] - current: list[str] = [] - current_len = 0 - - for sentence in sentences: - sent_len = len(sentence.split()) - if current_len + sent_len > chunk_size and current: - chunks.append(" ".join(current)) - current = [] - current_len = 0 - current.append(sentence) - current_len += sent_len - - if current: - chunks.append(" ".join(current)) - return chunks - - def _context_budget_words(config: MABenchConfig) -> int: """Answer-context word budget derived from the model's context window. @@ -153,34 +94,6 @@ def build_elfmem_config(config: MABenchConfig) -> ElfmemConfig: }) -def _rrf_merge( - vector_blocks: list, - bm25_results: list[tuple[str, str, float]], - top_k: int, - k: int = 60, -) -> tuple[list, str]: - """Merge vector search and BM25 results via Reciprocal Rank Fusion. - - BM25 results carry block IDs (from _BM25Index.search), so matching is - exact — no content-prefix heuristics, no supplementary raw-chunk fallback. - Both retrieval paths use the same block content (summaries or raw text), - so the merged context is consistent. 
- """ - block_scores: dict[str, float] = {} - block_map: dict[str, object] = {} - for rank, block in enumerate(vector_blocks): - block_scores[block.id] = 1.0 / (k + rank) - block_map[block.id] = block - - for rank, (block_id, _content, _score) in enumerate(bm25_results): - if block_id in block_map: - block_scores[block_id] += 1.0 / (k + rank) - - ranked = sorted(block_map.values(), key=lambda b: block_scores[b.id], reverse=True) - trimmed = ranked[:top_k] - return trimmed, "\n\n".join(b.content for b in trimmed) - - async def process_example( example: dict[str, Any], competency: str, @@ -199,6 +112,7 @@ async def process_example( # Conflict Resolution needs contradiction detection — elfmem's moat is_conflict_resolution = competency == "Conflict_Resolution" + skip_llm = not is_conflict_resolution elfmem_cfg = build_elfmem_config(config) @@ -207,80 +121,32 @@ async def process_example( system = await MemorySystem.from_config(db_path, config=elfmem_cfg) try: - # Phase 1: Chunk and ingest context + # Phase 1: Ingest document via learn_document() start_ingest = time.monotonic() - chunks = chunk_text(context, config.chunk_size) - total_promoted = 0 - - # CR needs contradiction detection — elfmem's core capability. - # Other competencies use skip_llm for speed (embedding-only retrieval). - skip_llm = not is_conflict_resolution - - await system.begin_session(task_type="ingestion") - for i, chunk in enumerate(chunks): - await system.learn( - content=chunk, - tags=[f"chunk:{i}"], - category="knowledge", - source="memoryagentbench", - ) - # Consolidate periodically - if (i + 1) % config.consolidate_every_n_chunks == 0: - await system.end_session() - await system.begin_session(task_type="consolidation") - result = await system.consolidate(skip_llm=skip_llm) - total_promoted += result.promoted - await system.end_session() - await system.begin_session(task_type="ingestion") - await system.end_session() - - # Final consolidation for remaining chunks - await system.begin_session(task_type="consolidation") - result = await system.consolidate(skip_llm=skip_llm) - total_promoted += result.promoted - await system.end_session() - - # Build BM25 from active block content after consolidation. - # With skip_llm=False (CR), blocks carry LLM-generated summaries — - # the same content used by elfmem's vector retrieval. This ensures - # BM25 and vector search are consistent, enabling exact ID-based RRF - # merging without heuristic content-prefix matching. - # With skip_llm=True (other competencies), content falls back to the - # raw chunk text — same as the previous raw-chunk approach. 
- await system.begin_session(task_type="retrieval") - index_frame = await system.frame("attention", query=None, top_k=10000) - await system.end_session() - - bm25_index = _BM25Index() - for block in index_frame.blocks: - bm25_index.add(block.id, block.content) - bm25_index.build() - + doc_result = await system.learn_document( + context, + chunk_size=config.chunk_size, + chunker=nltk.sent_tokenize, + source="memoryagentbench", + skip_llm=skip_llm, + ) ingestion_time = time.monotonic() - start_ingest - log.info(f" Ingested {len(chunks)} chunks, {total_promoted} promoted in {ingestion_time:.1f}s") + log.info( + f" Ingested {doc_result.chunks_total} chunks, " + f"{doc_result.blocks_promoted} promoted in {ingestion_time:.1f}s" + ) # Phase 2: Answer each question qa_results: list[QAResult] = [] - for q_idx, (question, answer_list) in enumerate(zip(questions, answers)): + for _q_idx, (question, answer_list) in enumerate(zip(questions, answers, strict=False)): q_start = time.monotonic() - await system.begin_session(task_type="retrieval") - frame_result = await system.frame("attention", query=question, top_k=config.top_k) - await system.end_session() - - # Build context from blocks directly, bypassing the frame's hardcoded - # token_budget (2000 tokens in the attention frame definition). Both - # the BM25 and no-BM25 paths are now bounded only by _context_budget_words, - # which is derived from config.context_window_tokens. - blocks = frame_result.blocks + # elfmem's recall() includes BM25 (stage 2b) + graph expansion + # + contradiction suppression natively. + blocks = await system.recall(query=question, top_k=config.top_k) context_text = "\n\n".join(b.content for b in blocks) - bm25_hits = bm25_index.search(question, top_k=config.top_k) - if bm25_hits: - blocks, context_text = _rrf_merge(blocks, bm25_hits, config.top_k) # Truncate context to fit the model's context window. - # Budget is derived from config.context_window_tokens minus - # system prompt, template, question, and answer overhead. 
max_context_words = _context_budget_words(config) words = context_text.split() if len(words) > max_context_words: @@ -304,8 +170,8 @@ async def process_example( return ExampleResult( source=source, competency=competency, - chunks_ingested=len(chunks), - blocks_promoted=total_promoted, + chunks_ingested=doc_result.chunks_total, + blocks_promoted=doc_result.blocks_promoted, qa_results=qa_results, ingestion_seconds=ingestion_time, ) diff --git a/benchmarks/memoryagentbench/runner.py b/benchmarks/memoryagentbench/runner.py index 5fdc2c6..c0c6dc5 100644 --- a/benchmarks/memoryagentbench/runner.py +++ b/benchmarks/memoryagentbench/runner.py @@ -11,12 +11,12 @@ import logging import os import time -from datetime import datetime, timezone +from datetime import UTC, datetime from pathlib import Path from datasets import load_dataset -from benchmarks.memoryagentbench.adapter import ExampleResult, process_example +from benchmarks.memoryagentbench.adapter import process_example from benchmarks.memoryagentbench.config import MABenchConfig from benchmarks.memoryagentbench.metrics import score_question from benchmarks.shared.answerer import generate_answer @@ -34,6 +34,39 @@ ] +# ── Resume helpers ─────────────────────────────────────────────────────────── + + +def _example_id(source: str, index: int) -> str: + """Composite ID for an example: 'source/index'.""" + return f"{source}/{index}" + + +def _get_completed_example_ids(output_path: Path) -> set[str]: + """Load set of example IDs already written to output JSON.""" + if not output_path.exists(): + return set() + data = json.loads(output_path.read_text()) + return {q["example_id"] for q in data.get("questions", []) if "example_id" in q} + + +def _append_example_results( + output_path: Path, new_questions: list[dict], +) -> None: + """Atomically append question results to the output JSON. + + Uses tmpfile + rename for crash safety on POSIX filesystems. 
+ """ + existing = json.loads(output_path.read_text()) if output_path.exists() else {"questions": []} + existing["questions"].extend(new_questions) + tmp = output_path.with_suffix(".tmp") + tmp.write_text(json.dumps(existing, indent=2, default=str)) + tmp.rename(output_path) + + +# ── Scoring ────────────────────────────────────────────────────────────────── + + async def _answer_and_score( qa, config: MABenchConfig, ) -> dict: @@ -56,6 +89,9 @@ async def _answer_and_score( } +# ── Main runner ────────────────────────────────────────────────────────────── + + async def run(args: argparse.Namespace) -> None: """Run the benchmark.""" config = MABenchConfig() @@ -73,10 +109,17 @@ async def run(args: argparse.Namespace) -> None: config.resume = True config.output_dir.mkdir(parents=True, exist_ok=True) - timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ") + timestamp = datetime.now(UTC).strftime("%Y%m%dT%H%M%SZ") output_path = args.output or config.output_dir / f"{timestamp}_mabench_elfmem.json" output_path = Path(output_path) + # Resume: load previously completed examples + completed_ids: set[str] = set() + if config.resume: + completed_ids = _get_completed_example_ids(output_path) + if completed_ids: + log.info(f"Resuming: {len(completed_ids)} example(s) already completed") + log.info("Loading MemoryAgentBench dataset from HuggingFace...") ds = load_dataset("ai-hyz/MemoryAgentBench") @@ -104,19 +147,33 @@ async def run(args: argparse.Namespace) -> None: for ex_idx, example in enumerate(examples): metadata = example.get("metadata", {}) source = metadata.get("source", "unknown") if isinstance(metadata, dict) else "unknown" + ex_id = _example_id(source, ex_idx) n_questions = len(example["questions"]) + + # Resume: skip completed examples + if ex_id in completed_ids: + log.info(f" [{ex_idx+1}/{len(examples)}] {source}: skipped (resumed)") + continue + log.info(f"\n [{ex_idx+1}/{len(examples)}] {source}: {n_questions} questions") try: ex_result = await process_example(example, competency, config) + example_questions: list[dict] = [] for qa in ex_result.qa_results: scored = await _answer_and_score(qa, config) scored["competency"] = competency scored["source"] = source - all_results.append(scored) + scored["example_id"] = ex_id + example_questions.append(scored) comp_f1s.append(scored["f1"]) + all_results.extend(example_questions) + + # Atomic write after each example — crash-safe resume + _append_example_results(output_path, example_questions) + avg_f1 = sum(comp_f1s[-n_questions:]) / n_questions if n_questions else 0 log.info(f" {n_questions} Qs answered | avg F1={avg_f1:.3f}") @@ -125,7 +182,7 @@ async def run(args: argparse.Namespace) -> None: competency_scores[competency] = comp_f1s - # Build report + # Build final report (overwrites with complete metadata) duration = time.monotonic() - start_time scores_by_comp: dict[str, dict] = {} all_f1s: list[float] = [] @@ -142,11 +199,18 @@ async def run(args: argparse.Namespace) -> None: except ImportError: elfmem_version = "unknown" + # Merge any previously completed results with this run's results + if output_path.exists(): + existing = json.loads(output_path.read_text()) + all_questions = existing.get("questions", []) + else: + all_questions = all_results + report = { "meta": { "benchmark": "memoryagentbench", "version": "1.0", - "timestamp": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"), + "timestamp": datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ"), "duration_seconds": round(duration, 1), "elfmem_version": elfmem_version, 
"models": { @@ -170,7 +234,7 @@ async def run(args: argparse.Namespace) -> None: "overall": round(overall * 100, 1), "by_competency": scores_by_comp, }, - "questions": all_results, + "questions": all_questions, } output_path.write_text(json.dumps(report, indent=2, default=str)) diff --git a/pyproject.toml b/pyproject.toml index cb83cc8..bf1682d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,7 @@ Changelog = "https://github.com/emson/elfmem/blob/main/CHANGELOG.md" Security = "https://github.com/emson/elfmem/blob/main/SECURITY.md" [project.optional-dependencies] +bm25 = ["rank_bm25>=0.2.2"] mcp = ["fastmcp>=2.0"] cli = ["typer>=0.12"] tools = ["fastmcp>=2.0", "typer>=0.12"] diff --git a/src/elfmem/__init__.py b/src/elfmem/__init__.py index 14a95a6..c6b483c 100644 --- a/src/elfmem/__init__.py +++ b/src/elfmem/__init__.py @@ -41,6 +41,7 @@ DisconnectResult, DisplacedEdge, FrameResult, + LearnDocumentResult, LearnResult, OperationRecord, OutcomeResult, @@ -58,6 +59,7 @@ "ConsolidationPolicy", # Result types "LearnResult", + "LearnDocumentResult", "ConsolidateResult", "FrameResult", "CurateResult", diff --git a/src/elfmem/api.py b/src/elfmem/api.py index 45ab281..0f989bd 100644 --- a/src/elfmem/api.py +++ b/src/elfmem/api.py @@ -4,9 +4,10 @@ import json import os +import re import time from collections import deque -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Callable from contextlib import asynccontextmanager, suppress from datetime import UTC, datetime from typing import Any, Literal @@ -56,6 +57,7 @@ CurateResult, DisconnectResult, FrameResult, + LearnDocumentResult, LearnResult, OperationRecord, OutcomeResult, @@ -65,6 +67,33 @@ TokenUsage, ) +# ── Document chunking helpers ──────────────────────────────────────────────── + +_SENTENCE_SPLIT_RE = re.compile(r"(?<=[.!?])\s+") + + +def _default_chunker(text: str) -> list[str]: + """Split text into sentences at [.!?] followed by whitespace.""" + return [s.strip() for s in _SENTENCE_SPLIT_RE.split(text) if s.strip()] + + +def _assemble_chunks(sentences: list[str], chunk_size: int) -> list[str]: + """Combine sentences into chunks of ~chunk_size words.""" + chunks: list[str] = [] + current: list[str] = [] + current_len = 0 + for sentence in sentences: + sent_len = len(sentence.split()) + if current_len + sent_len > chunk_size and current: + chunks.append(" ".join(current)) + current = [] + current_len = 0 + current.append(sentence) + current_len += sent_len + if current: + chunks.append(" ".join(current)) + return chunks + class MemorySystem: """Adaptive memory system for LLM agents. 
@@ -696,6 +725,9 @@ async def consolidate( embedding_svc=self._embedding, current_active_hours=current_hours, self_alignment_threshold=mem.self_alignment_threshold, + contradiction_threshold=mem.contradiction_threshold, + near_dup_exact_threshold=mem.near_dup_exact_threshold, + near_dup_near_threshold=mem.near_dup_near_threshold, edge_score_threshold=mem.edge_score_threshold, edge_degree_cap=mem.edge_degree_cap, contradiction_similarity_prefilter=mem.contradiction_similarity_prefilter, @@ -774,7 +806,93 @@ async def remember( await self.begin_session() return await self.learn(content, tags=tags, category=category, source=source) - async def dream(self) -> ConsolidateResult | None: + async def learn_document( + self, + text: str, + *, + chunk_size: int = 256, + chunker: Callable[[str], list[str]] | None = None, + tags: list[str] | None = None, + category: str = "knowledge", + source: str = "document", + skip_llm: bool = False, + ) -> LearnDocumentResult: + """Ingest a document: chunk, learn each chunk, auto-consolidate. + + USE WHEN: Ingesting a document, article, or long-form text. Handles + chunking, learning, and consolidation in one call. + + DON'T USE WHEN: Ingesting a single fact or short observation — use + learn() instead. + + COST: O(chunks) learn() calls + dream() at inbox_threshold intervals. + With skip_llm=True, dream() runs the fast embedding-only path. + + RETURNS: LearnDocumentResult with chunk and consolidation counts. + + NEXT: recall() or frame() to query the ingested knowledge. + + Args: + text: The full document text to ingest. + chunk_size: Target words per chunk (default 256). + chunker: Sentence splitter function. Receives the full text, + returns a list of sentences. Default splits at [.!?] + followed by whitespace. Pass ``nltk.sent_tokenize`` for + better accuracy with abbreviations and edge cases. + tags: Tags applied to every chunk (in addition to ``chunk:N``). + category: Block category (default "knowledge"). + source: Block source identifier (default "document"). + skip_llm: Forward to dream() — bypass LLM calls during + consolidation (embed + promote only). + + Example:: + + result = await system.learn_document(article_text, chunk_size=200) + print(result) # Ingested 12 chunks: 12 created, 2 consolidations, 10 promoted. + """ + split_fn = chunker or _default_chunker + sentences = split_fn(text) + chunks = _assemble_chunks(sentences, chunk_size) + + created = 0 + duplicates = 0 + consolidations = 0 + promoted = 0 + + for i, chunk in enumerate(chunks): + result = await self.learn( + content=chunk, + tags=[f"chunk:{i}", *(tags or [])], + category=category, + source=source, + ) + if result.status == "created": + created += 1 + elif result.status == "duplicate_rejected": + duplicates += 1 + + if self.should_dream: + dr = await self.dream(skip_llm=skip_llm) + if dr is not None: + consolidations += 1 + promoted += dr.promoted + + doc_result = LearnDocumentResult( + chunks_total=len(chunks), + chunks_created=created, + chunks_duplicate=duplicates, + consolidations=consolidations, + blocks_promoted=promoted, + ) + self._record_op("learn_document", doc_result.summary) + return doc_result + + async def dream( + self, + *, + skip_llm: bool = False, + skip_contradictions: bool = False, + ) -> ConsolidateResult | None: """Consolidate pending blocks at a natural pause point. The breathing rhythm: learn fast (remember), process deliberately (dream). @@ -785,8 +903,8 @@ async def dream(self) -> ConsolidateResult | None: DON'T USE WHEN: In a tight loop. 
One call processes all pending blocks. - COST: LLM call per pending block. Returns None immediately (zero cost) - if inbox is empty. + COST: LLM call per pending block (unless skip_llm=True). Returns None + immediately (zero cost) if inbox is empty. RETURNS: ConsolidateResult if blocks were processed; None if inbox was empty. None is not an error — it means "nothing needed doing." @@ -798,6 +916,12 @@ async def dream(self) -> ConsolidateResult | None: (e.g., another process), dream() may return None despite inbox having items. Call status() for DB-accurate inbox_count. + Args: + skip_llm: Bypass all LLM calls (embed + promote only). Fastest path + for bulk ingestion where alignment scoring isn't needed. + skip_contradictions: Keep LLM summaries but skip O(n²) contradiction + detection. Good balance of quality and speed. + Example:: # Always-on agent loop @@ -809,7 +933,9 @@ async def dream(self) -> ConsolidateResult | None: """ if self._pending == 0: return None - result = await self.consolidate() + result = await self.consolidate( + skip_llm=skip_llm, skip_contradictions=skip_contradictions, + ) # Feed policy so it can adapt the threshold for the next cycle, # then persist the new threshold so it survives process restarts. # Persistence is best-effort: a DB failure here is non-fatal — diff --git a/src/elfmem/guide.py b/src/elfmem/guide.py index fa05ce5..703812e 100644 --- a/src/elfmem/guide.py +++ b/src/elfmem/guide.py @@ -146,6 +146,34 @@ def __str__(self) -> str: "print(result) # Stored block a1b2c3d4. Status: created." ), ), + "learn_document": AgentGuide( + name="learn_document", + what="Ingest a document: chunk, learn each chunk, auto-consolidate.", + when=( + "The agent needs to ingest a document, article, or long-form text. " + "Handles chunking, learning, and consolidation in one call." + ), + when_not=( + "Single facts or short observations — use learn() instead. " + "Already-chunked data — use learn() in a loop." + ), + cost=( + "O(chunks) learn() calls + dream() at inbox_threshold intervals. " + "With skip_llm=True, dream() uses the fast embedding-only path." + ), + returns=( + "LearnDocumentResult with chunks_total, chunks_created, " + "chunks_duplicate, consolidations, blocks_promoted." + ), + next=( + "recall() or frame() to query the ingested knowledge. " + "Consolidation happened automatically during ingestion." + ), + example=( + "result = await system.learn_document(article_text, chunk_size=200)\n" + "print(result) # Ingested 12 chunks: 12 created, 2 consolidations." + ), + ), "consolidate": AgentGuide( name="consolidate", what="Process inbox blocks: score, embed, deduplicate, and promote to active memory.", diff --git a/src/elfmem/memory/retrieval.py b/src/elfmem/memory/retrieval.py index 04e1721..f266e81 100644 --- a/src/elfmem/memory/retrieval.py +++ b/src/elfmem/memory/retrieval.py @@ -1,7 +1,13 @@ -"""4-stage hybrid retrieval pipeline — pure (no side effects).""" +"""6-stage hybrid retrieval pipeline — pure (no side effects). + +Stages: pre-filter → vector → BM25 → graph expand → composite score → MMR. +BM25 (stage 2b) requires the optional ``rank_bm25`` package. When not installed, +the stage is silently skipped and retrieval works as a 5-stage vector-only pipeline. +""" from __future__ import annotations +import logging import math from typing import Any @@ -19,6 +25,16 @@ ) from elfmem.types import ScoredBlock +# Soft dependency — retrieval works without it. 
+try: + from rank_bm25 import BM25Okapi + + _HAS_BM25 = True +except ImportError: # pragma: no cover + _HAS_BM25 = False + +log = logging.getLogger(__name__) + N_SEEDS_MULTIPLIER = 4 CONTRADICTION_OVERSAMPLE = 2 DEFAULT_SEARCH_WINDOW_HOURS = 200.0 @@ -37,13 +53,14 @@ async def hybrid_retrieve( tag_filter: str | None = None, search_window_hours: float = DEFAULT_SEARCH_WINDOW_HOURS, ) -> list[ScoredBlock]: - """Execute the 5-stage hybrid retrieval pipeline. + """Execute the 6-stage hybrid retrieval pipeline. - Stage 1 — Pre-filter: active blocks within search window. - Stage 2 — Vector search: cosine similarity → top N_seeds. (Skipped if no query.) - Stage 3 — Graph expand: 1-hop neighbours of seeds. (Skipped if no query.) - Stage 4 — Composite score: rank all candidates. - Stage 5 — MMR diversity: reorder for relevance + diversity. (Query-aware only.) + Stage 1 — Pre-filter: active blocks within search window. + Stage 2 — Vector search: cosine similarity → top N_seeds. (Skipped if no query.) + Stage 2b — BM25 keyword search: term overlap → top N_bm25. (Requires rank_bm25.) + Stage 3 — Graph expand: 1-hop neighbours of seeds. (Skipped if no query.) + Stage 4 — Composite score: rank all candidates. + Stage 5 — MMR diversity: reorder for relevance + diversity. (Query-aware only.) Returns top_k * CONTRADICTION_OVERSAMPLE ScoredBlocks for contradiction headroom. """ @@ -64,11 +81,23 @@ async def hybrid_retrieve( candidates, n_seeds=top_k * N_SEEDS_MULTIPLIER, ) + + # Stage 2b: BM25 keyword candidates (additive — discovers blocks + # that vector search missed due to vocabulary mismatch). + seed_ids_set = {b["id"] for b, _ in seed_pairs} + bm25_pairs = _stage_2b_bm25_search( + candidates, query, n_seeds=top_k * N_SEEDS_MULTIPLIER, + ) + for block, _bm25_score in bm25_pairs: + if block["id"] not in seed_ids_set: + seed_pairs.append((block, 0.0)) + seed_ids_set.add(block["id"]) + seed_ids = [b["id"] for b, _ in seed_pairs] expanded = await _stage_3_graph_expand( conn, seed_ids=seed_ids, - existing_candidate_ids={b["id"] for b, _ in seed_pairs}, + existing_candidate_ids=seed_ids_set, ) # candidates for scoring: (block, similarity, was_expanded) @@ -154,6 +183,27 @@ async def _stage_2_vector_search( return scored[:n_seeds] +def _stage_2b_bm25_search( + candidates: list[dict[str, Any]], + query: str, + n_seeds: int, +) -> list[tuple[dict[str, Any], float]]: + """Stage 2b: BM25 keyword search over pre-filtered candidates. + + Discovers blocks with strong term overlap that vector search may miss + (vocabulary mismatch, exact entity names, etc.). Requires the optional + ``rank_bm25`` package — returns ``[]`` when not installed. + """ + if not _HAS_BM25 or not candidates: + return [] + contents = [b.get("summary") or b.get("content", "") for b in candidates] + tokenized = [c.lower().split() for c in contents] + bm25 = BM25Okapi(tokenized) + scores = bm25.get_scores(query.lower().split()) + ranked = sorted(zip(candidates, scores, strict=False), key=lambda x: x[1], reverse=True) + return ranked[:n_seeds] + + async def _stage_3_graph_expand( conn: AsyncConnection, seed_ids: list[str], diff --git a/src/elfmem/types.py b/src/elfmem/types.py index a22d453..8a45c3f 100644 --- a/src/elfmem/types.py +++ b/src/elfmem/types.py @@ -130,6 +130,48 @@ def to_dict(self) -> dict[str, Any]: return {"block_id": self.block_id, "status": self.status} +@dataclass +class LearnDocumentResult: + """Result of ingesting a document via learn_document(). + + USE WHEN: Inspecting how a document was ingested. 
+ DON'T USE WHEN: You need per-block detail — check individual LearnResults. + COST: Zero — this is a pure summary. + RETURNS: Chunk and consolidation counts. + NEXT: recall() or frame() to query the ingested knowledge. + """ + + chunks_total: int + chunks_created: int + chunks_duplicate: int + consolidations: int + blocks_promoted: int + + @property + def summary(self) -> str: + if self.chunks_total == 0: + return "No chunks to ingest (empty document)." + parts = [f"{self.chunks_created} created"] + if self.chunks_duplicate: + parts.append(f"{self.chunks_duplicate} duplicate") + if self.consolidations: + parts.append(f"{self.consolidations} consolidations") + parts.append(f"{self.blocks_promoted} promoted") + return f"Ingested {self.chunks_total} chunks: {', '.join(parts)}." + + def __str__(self) -> str: + return self.summary + + def to_dict(self) -> dict[str, int]: + return { + "chunks_total": self.chunks_total, + "chunks_created": self.chunks_created, + "chunks_duplicate": self.chunks_duplicate, + "consolidations": self.consolidations, + "blocks_promoted": self.blocks_promoted, + } + + @dataclass class ConsolidateResult: processed: int diff --git a/tests/benchmarks/test_mabench_adapter.py b/tests/benchmarks/test_mabench_adapter.py index e1bf922..e115789 100644 --- a/tests/benchmarks/test_mabench_adapter.py +++ b/tests/benchmarks/test_mabench_adapter.py @@ -1,86 +1,38 @@ -"""Tests for MemoryAgentBench adapter utilities — BM25 index and RRF merge.""" +"""Tests for MemoryAgentBench adapter utilities.""" import pytest -pytest.importorskip("rank_bm25", reason="rank-bm25 not installed") +pytest.importorskip("nltk", reason="nltk not installed") -from benchmarks.memoryagentbench.adapter import _BM25Index, _rrf_merge +from benchmarks.memoryagentbench.adapter import _context_budget_words, build_elfmem_config +from benchmarks.memoryagentbench.config import MABenchConfig -class _FakeBlock: - """Minimal stand-in for ScoredBlock in RRF merge tests.""" +class TestContextBudgetWords: + def test_default_config_budget_positive(self) -> None: + config = MABenchConfig() + budget = _context_budget_words(config) + assert budget > 100 - def __init__(self, block_id: str, content: str) -> None: - self.id = block_id - self.content = content + def test_larger_context_window_increases_budget(self) -> None: + small = MABenchConfig(context_window_tokens=2048) + large = MABenchConfig(context_window_tokens=8192) + assert _context_budget_words(large) > _context_budget_words(small) + def test_minimum_budget_is_100(self) -> None: + tiny = MABenchConfig(context_window_tokens=100) + assert _context_budget_words(tiny) == 100 -class TestBM25Index: - def test_search_returns_id_content_score_triples(self) -> None: - idx = _BM25Index() - idx.add("id1", "python async programming") - idx.add("id2", "java synchronous code") - idx.add("id3", "ruby sequential scripts") # 3 docs: IDF for 'async' is positive - idx.build() - results = idx.search("async", top_k=3) - block_id, content, score = results[0] - assert block_id == "id1" - assert content == "python async programming" - assert score > 0 - def test_search_ranks_by_relevance(self) -> None: - idx = _BM25Index() - idx.add("a", "python async") - idx.add("b", "java synchronous") - idx.build() - results = idx.search("python", top_k=2) - assert results[0][0] == "a" +class TestBuildElfmemConfig: + def test_returns_valid_config(self) -> None: + config = MABenchConfig() + elfmem_cfg = build_elfmem_config(config) + assert elfmem_cfg.llm.model == config.elfmem_llm_model + assert 
elfmem_cfg.embeddings.model == config.elfmem_embedding_model + assert elfmem_cfg.memory.top_k == config.top_k - def test_search_empty_index_returns_empty(self) -> None: - idx = _BM25Index() - idx.build() - assert idx.search("anything") == [] - - def test_search_respects_top_k(self) -> None: - idx = _BM25Index() - for i in range(10): - idx.add(f"id{i}", f"content about topic {i}") - idx.build() - assert len(idx.search("content topic", top_k=3)) == 3 - - -class TestRRFMerge: - def test_matched_bm25_block_scores_higher(self) -> None: - """A block present in both vector and BM25 results scores higher than one in vector only.""" - b1 = _FakeBlock("b1", "async python programming") - b2 = _FakeBlock("b2", "java synchronous code") - vector_blocks = [b1, b2] # b1 ranked first by vector - - bm25_results = [("b1", "async python programming", 5.0)] # b1 also in BM25 - - _, context = _rrf_merge(vector_blocks, bm25_results, top_k=2) - # b1 should appear first — boosted by both retrieval paths - assert context.startswith("async python programming") - - def test_unmatched_bm25_id_ignored(self) -> None: - """BM25 result for unknown block ID has no effect — no supplementary raw chunks.""" - b1 = _FakeBlock("b1", "block one") - vector_blocks = [b1] - bm25_results = [("unknown_id", "some raw text", 3.0)] - - trimmed, context = _rrf_merge(vector_blocks, bm25_results, top_k=1) - assert len(trimmed) == 1 - assert "some raw text" not in context - - def test_context_is_newline_joined_block_content(self) -> None: - b1 = _FakeBlock("b1", "first block") - b2 = _FakeBlock("b2", "second block") - _, context = _rrf_merge([b1, b2], [], top_k=2) - assert "first block" in context - assert "second block" in context - assert "\n\n" in context - - def test_top_k_limits_output(self) -> None: - blocks = [_FakeBlock(f"b{i}", f"block {i}") for i in range(5)] - trimmed, _ = _rrf_merge(blocks, [], top_k=2) - assert len(trimmed) == 2 + def test_contradiction_prefilter_forwarded(self) -> None: + config = MABenchConfig(contradiction_similarity_prefilter=0.65) + elfmem_cfg = build_elfmem_config(config) + assert elfmem_cfg.memory.contradiction_similarity_prefilter == 0.65 diff --git a/tests/test_bm25_retrieval.py b/tests/test_bm25_retrieval.py new file mode 100644 index 0000000..82325f6 --- /dev/null +++ b/tests/test_bm25_retrieval.py @@ -0,0 +1,46 @@ +"""Tests for BM25 hybrid retrieval — tested through the public recall() API.""" + +import pytest + +from elfmem import ElfmemConfig, MemorySystem +from elfmem.config import MemoryConfig + + +@pytest.fixture +async def system(test_engine, mock_llm, mock_embedding) -> MemorySystem: + """MemorySystem with inbox_threshold=3 for fast consolidation.""" + cfg = ElfmemConfig(memory=MemoryConfig(inbox_threshold=3)) + return MemorySystem( + engine=test_engine, + llm_service=mock_llm, + embedding_service=mock_embedding, + config=cfg, + ) + + +class TestBM25Retrieval: + """BM25 integration tested through recall() — the public API.""" + + async def test_recall_returns_results(self, system): + """recall() with query returns relevant blocks.""" + for i in range(4): + await system.learn(f"Knowledge item number {i} about topic alpha") + await system.consolidate() + + blocks = await system.recall(query="alpha", top_k=5) + assert len(blocks) > 0 + + async def test_recall_no_duplicate_blocks(self, system): + """Blocks appear at most once in recall results.""" + for i in range(4): + await system.learn(f"Unique fact {i} about gamma topic") + await system.consolidate() + + blocks = await 
system.recall(query="gamma", top_k=10) + block_ids = [b.id for b in blocks] + assert len(block_ids) == len(set(block_ids)) + + async def test_recall_empty_database_returns_empty(self, system): + """recall() on empty database returns empty list, not error.""" + blocks = await system.recall(query="anything", top_k=5) + assert blocks == [] diff --git a/tests/test_learn_document.py b/tests/test_learn_document.py new file mode 100644 index 0000000..a69bc79 --- /dev/null +++ b/tests/test_learn_document.py @@ -0,0 +1,166 @@ +"""Tests for learn_document(), dream(skip_llm=), and config wiring fixes.""" + +import pytest + +from elfmem import ElfmemConfig, LearnDocumentResult, MemorySystem +from elfmem.config import MemoryConfig + +# ── Fixtures ───────────────────────────────────────────────────────────────── + + +@pytest.fixture +async def system(test_engine, mock_llm, mock_embedding) -> MemorySystem: + """MemorySystem with inbox_threshold=3 for fast consolidation cycles.""" + cfg = ElfmemConfig(memory=MemoryConfig(inbox_threshold=3)) + return MemorySystem( + engine=test_engine, + llm_service=mock_llm, + embedding_service=mock_embedding, + config=cfg, + ) + + +# ── learn_document() ──────────────────────────────────────────────────────── + + +class TestLearnDocument: + async def test_basic_ingestion(self, system): + text = "Fact one. Fact two. Fact three. Fact four. Fact five." + result = await system.learn_document(text, chunk_size=2) + assert isinstance(result, LearnDocumentResult) + assert result.chunks_total > 0 + assert result.chunks_created > 0 + assert result.chunks_duplicate == 0 + + async def test_auto_consolidation_triggers(self, system): + """With inbox_threshold=3, learning 5+ chunks should trigger dream().""" + sentences = [f"Sentence number {i} with unique content." for i in range(10)] + text = " ".join(sentences) + result = await system.learn_document(text, chunk_size=3) + assert result.consolidations > 0 + assert result.blocks_promoted > 0 + + async def test_empty_text(self, system): + result = await system.learn_document("") + assert result.chunks_total == 0 + assert result.chunks_created == 0 + assert result.consolidations == 0 + + async def test_custom_chunker(self, system): + text = "line one\nline two\nline three" + result = await system.learn_document( + text, + chunker=lambda t: t.split("\n"), + chunk_size=2, # small enough that each line becomes its own chunk + ) + assert result.chunks_total == 3 + assert result.chunks_created == 3 + + async def test_duplicate_ingestion(self, system): + text = "A unique fact about penguins." + first = await system.learn_document(text) + assert first.chunks_created == 1 + second = await system.learn_document(text) + assert second.chunks_duplicate == 1 + assert second.chunks_created == 0 + + async def test_skip_llm_forwarded(self, system): + """skip_llm=True should produce consolidation without LLM calls.""" + sentences = [f"Fact {i} about topic X." for i in range(5)] + text = " ".join(sentences) + result = await system.learn_document(text, chunk_size=2, skip_llm=True) + if result.consolidations > 0: + assert result.blocks_promoted > 0 + + async def test_tags_applied_to_all_chunks(self, system): + text = "Alpha info. Beta info." + await system.learn_document(text, chunk_size=100, tags=["source:test"]) + blocks = await system.recall(query="info", top_k=5) + for block in blocks: + assert "source:test" in block.tags + + async def test_result_to_dict(self, system): + text = "Some content here." 
+ result = await system.learn_document(text) + d = result.to_dict() + assert "chunks_total" in d + assert "chunks_created" in d + assert "consolidations" in d + + async def test_ingested_content_is_recallable(self, system): + """After learn_document, content is queryable via recall().""" + text = "The capital of France is Paris. Berlin is the capital of Germany." + await system.learn_document(text, chunk_size=100) + # Force consolidation for any remaining inbox blocks + await system.consolidate() + blocks = await system.recall(query="capital", top_k=5) + assert len(blocks) > 0 + + +# ── dream(skip_llm=) ──────────────────────────────────────────────────────── + + +class TestDreamSkipLLM: + async def test_dream_skip_llm_promotes_blocks(self, system): + """dream(skip_llm=True) should promote inbox blocks without LLM.""" + for i in range(3): + await system.learn(f"Fact number {i} about testing") + result = await system.dream(skip_llm=True) + assert result is not None + assert result.promoted > 0 + + async def test_dream_skip_contradictions(self, system): + """dream(skip_contradictions=True) should skip contradiction detection.""" + for i in range(3): + await system.learn(f"Knowledge item {i}") + result = await system.dream(skip_contradictions=True) + assert result is not None + assert result.promoted > 0 + + async def test_dream_empty_inbox_returns_none(self, system): + result = await system.dream(skip_llm=True) + assert result is None + + +# ── Config wiring ──────────────────────────────────────────────────────────── + + +class TestConfigWiring: + async def test_near_dup_threshold_from_config(self, test_engine, mock_llm, mock_embedding): + """near_dup_near_threshold from config should take effect in consolidation.""" + cfg = ElfmemConfig( + memory=MemoryConfig( + inbox_threshold=3, + near_dup_near_threshold=0.30, + ), + ) + system = MemorySystem( + engine=test_engine, + llm_service=mock_llm, + embedding_service=mock_embedding, + config=cfg, + ) + await system.learn("The cat sat on the mat") + await system.learn("The cat sat on the mat today") + await system.learn("A different topic entirely") + result = await system.consolidate() + assert result.processed == 3 + + async def test_contradiction_threshold_from_config(self, test_engine, mock_llm, mock_embedding): + """contradiction_threshold from config should be wired through.""" + cfg = ElfmemConfig( + memory=MemoryConfig( + inbox_threshold=3, + contradiction_threshold=0.50, + ), + ) + system = MemorySystem( + engine=test_engine, + llm_service=mock_llm, + embedding_service=mock_embedding, + config=cfg, + ) + for i in range(3): + await system.learn(f"Statement {i}") + result = await system.consolidate() + assert result.processed == 3 diff --git a/uv.lock b/uv.lock index 482ed4e..0025b4d 100644 --- a/uv.lock +++ b/uv.lock @@ -585,6 +585,9 @@ dependencies = [ ] [package.optional-dependencies] +bm25 = [ + { name = "rank-bm25" }, +] cli = [ { name = "typer" }, ] @@ -635,6 +638,7 @@ requires-dist = [ { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.23" }, { name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=5.0" }, { name = "pyyaml", specifier = ">=6.0" }, + { name = "rank-bm25", marker = "extra == 'bm25'", specifier = ">=0.2.2" }, { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.3" }, { name = "sqlalchemy", specifier = ">=2.0" }, { name = "typer", marker = "extra == 'cli'", specifier = ">=0.12" }, @@ -642,7 +646,7 @@ requires-dist = [ { name = "typer", marker = "extra == 'tools'", specifier = ">=0.12" 
}, { name = "types-pyyaml", marker = "extra == 'dev'" }, ] -provides-extras = ["mcp", "cli", "tools", "viz", "dev"] +provides-extras = ["bm25", "mcp", "cli", "tools", "viz", "dev"] [[package]] name = "email-validator" @@ -2159,6 +2163,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/04/11/432f32f8097b03e3cd5fe57e88efb685d964e2e5178a48ed61e841f7fdce/pyyaml_env_tag-1.1-py3-none-any.whl", hash = "sha256:17109e1a528561e32f026364712fee1264bc2ea6715120891174ed1b980d2e04", size = 4722, upload-time = "2025-05-13T15:23:59.629Z" }, ] +[[package]] +name = "rank-bm25" +version = "0.2.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/fc/0a/f9579384aa017d8b4c15613f86954b92a95a93d641cc849182467cf0bb3b/rank_bm25-0.2.2.tar.gz", hash = "sha256:096ccef76f8188563419aaf384a02f0ea459503fdf77901378d4fd9d87e5e51d", size = 8347, upload-time = "2022-02-16T12:10:52.196Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2a/21/f691fb2613100a62b3fa91e9988c991e9ca5b89ea31c0d3152a3210344f9/rank_bm25-0.2.2-py3-none-any.whl", hash = "sha256:7bd4a95571adadfc271746fa146a4bcfd89c0cf731e49c3d1ad863290adbe8ae", size = 8584, upload-time = "2022-02-16T12:10:50.626Z" }, +] + [[package]] name = "redis" version = "7.4.0"
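
Usage sketch for the API surface this changeset adds. This is illustrative only and not part of the patch: the database path, the default-constructed `ElfmemConfig`, and the sample text are placeholders, and the BM25 stage only activates when the optional extra is installed (`pip install elfmem[bm25]`).

    import asyncio
    from pathlib import Path

    from elfmem import ElfmemConfig, MemorySystem


    async def main() -> None:
        # Placeholder path/config; real settings depend on your deployment.
        system = await MemorySystem.from_config(Path("memory.db"), config=ElfmemConfig())

        article = (
            "The capital of France is Paris. "
            "Berlin is the capital of Germany. "
            "Madrid is the capital of Spain."
        )

        # One call: chunk at sentence boundaries, learn each chunk, and
        # auto-consolidate via dream() at inbox_threshold intervals.
        result = await system.learn_document(article, chunk_size=200, skip_llm=True)
        print(result)  # LearnDocumentResult summary string

        # Flush anything still in the inbox on the fast path (no LLM calls).
        await system.dream(skip_llm=True)

        # recall() runs the hybrid pipeline, including BM25 stage 2b when
        # rank_bm25 is installed; otherwise that stage is skipped.
        blocks = await system.recall(query="capital of France", top_k=5)
        for block in blocks:
            print(block.content)


    asyncio.run(main())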