From ec42b416cac295f6ea88cd2c915c2992df9aa330 Mon Sep 17 00:00:00 2001 From: Ben Emson Date: Sat, 11 Apr 2026 09:07:28 +0200 Subject: [PATCH 1/5] feat: BM25 hybrid retrieval, learn_document(), config wiring, MABench resume MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Move BM25 keyword search, document chunking, and auto-consolidation into elfmem core so benchmark adapters stay thin. Fix three config fields that existed in MemoryConfig but were never wired through to consolidate(). Add resume support to MABench runner. - BM25 stage 2b in hybrid_retrieve() (soft dep on rank_bm25) - learn_document() with auto-dream at inbox_threshold intervals - dream(skip_llm=, skip_contradictions=) forwarding - Wire contradiction_threshold, near_dup_*_threshold from config - MABench adapter: 319→160 lines (removed BM25/RRF/chunking) - MABench runner: atomic per-example writes + --resume - 30 new tests (572 total) Co-Authored-By: Claude Opus 4.6 (1M context) --- CHANGELOG.md | 7 + benchmarks/memoryagentbench/adapter.py | 186 +++---------------- benchmarks/memoryagentbench/runner.py | 78 +++++++- pyproject.toml | 1 + src/elfmem/__init__.py | 2 + src/elfmem/api.py | 136 +++++++++++++- src/elfmem/guide.py | 28 +++ src/elfmem/memory/retrieval.py | 66 ++++++- src/elfmem/types.py | 42 +++++ tests/benchmarks/test_mabench_adapter.py | 104 +++-------- tests/test_bm25_retrieval.py | 98 ++++++++++ tests/test_learn_document.py | 221 +++++++++++++++++++++++ uv.lock | 18 +- 13 files changed, 728 insertions(+), 259 deletions(-) create mode 100644 tests/test_bm25_retrieval.py create mode 100644 tests/test_learn_document.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 09e6c6d..0e81235 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,8 +23,15 @@ elfmem uses [Semantic Versioning](https://semver.org/). ### Added +- **`MemorySystem.learn_document(text, chunk_size, chunker, skip_llm)`:** Ingest a document in one call — chunks text, learns each chunk, auto-consolidates via `dream()` at `inbox_threshold` intervals. Accepts an optional `chunker` callback (e.g. `nltk.sent_tokenize`); default splits at sentence boundaries. Returns `LearnDocumentResult` with chunk and consolidation counts. +- **`LearnDocumentResult` type:** New result type with `chunks_total`, `chunks_created`, `chunks_duplicate`, `consolidations`, `blocks_promoted`. Exported from `elfmem`. +- **BM25 keyword search in retrieval pipeline (stage 2b):** `hybrid_retrieve()` now runs BM25 in parallel with vector search, discovering blocks with strong keyword overlap that embedding similarity misses. Soft dependency on `rank_bm25` — when not installed, the stage is silently skipped (zero regression). Install via `pip install elfmem[bm25]`. +- **`dream(skip_llm, skip_contradictions)` parameters:** `dream()` now forwards `skip_llm` and `skip_contradictions` to `consolidate()`, enabling fast-path consolidation without bypassing policy tracking or threshold persistence. - **`MABenchConfig.context_window_tokens`:** New config field (default 4096) representing the LM Studio model's context window. All answer-context truncation derives from this value; set to 2048 for smaller models. +### Fixed +- **Config wiring: `contradiction_threshold`, `near_dup_exact_threshold`, `near_dup_near_threshold`:** These three `MemoryConfig` fields existed but were not passed from `MemorySystem.consolidate()` to the consolidation operation. Custom config values were silently ignored (defaults matched, so no observable bug at default settings). Now wired through. + ### Added - **LoCoMo benchmark harness:** Complete benchmark suite for evaluating elfmem against LoCoMo (ACL 2024) — 10 conversations, 1,986 QA pairs, 5 categories. Includes metrics (Porter-stemmed F1), typed data loading, BM25 hybrid retrieval, observation transform, and CLI runner with `--test`, `--baselines`, `--resume`, `--top-k`, `--category` flags. Results conform to `benchmark_report_spec.md`. - **`consolidate(skip_llm=True)`:** Bypass all LLM calls during consolidation (embed + promote only). Reduces ingestion from hours to seconds for bulk import and benchmarks. diff --git a/benchmarks/memoryagentbench/adapter.py b/benchmarks/memoryagentbench/adapter.py index 2adccb3..a535c68 100644 --- a/benchmarks/memoryagentbench/adapter.py +++ b/benchmarks/memoryagentbench/adapter.py @@ -1,9 +1,12 @@ -"""elfmem adapter for MemoryAgentBench — chunk ingestion + retrieval. +"""elfmem adapter for MemoryAgentBench — document ingestion + retrieval. Key difference from LoCoMo: MemoryAgentBench's Conflict Resolution competency tests contradiction detection — elfmem's primary moat. For CR, we use FULL consolidation (contradiction detection ON). For other competencies, we use -skip_contradictions for speed. +skip_llm for speed. + +BM25 hybrid retrieval is handled natively by elfmem's retrieval pipeline +(stage 2b in hybrid_retrieve). No adapter-level BM25 or RRF needed. """ from __future__ import annotations @@ -11,17 +14,14 @@ import logging import tempfile import time -from dataclasses import dataclass, field +from dataclasses import dataclass from pathlib import Path from typing import Any import nltk -from rank_bm25 import BM25Okapi - -from elfmem import ElfmemConfig, MemorySystem - from benchmarks.memoryagentbench.config import MABenchConfig +from elfmem import ElfmemConfig, MemorySystem nltk.download("punkt_tab", quiet=True) @@ -50,65 +50,6 @@ class ExampleResult: ingestion_seconds: float -class _BM25Index: - """BM25 index over block content for keyword-boosted retrieval. - - Built after consolidation from elfmem's active block content (summaries - when available, raw content otherwise). Stores block IDs so RRF merge - can match by ID rather than approximate content-prefix heuristics. - """ - - def __init__(self) -> None: - self._ids: list[str] = [] - self._contents: list[str] = [] - self._bm25: BM25Okapi | None = None - - def add(self, block_id: str, content: str) -> None: - self._ids.append(block_id) - self._contents.append(content) - - def build(self) -> None: - if self._contents: - tokenized = [c.lower().split() for c in self._contents] - self._bm25 = BM25Okapi(tokenized) - - def search(self, query: str, top_k: int = 10) -> list[tuple[str, str, float]]: - """Return (block_id, content, bm25_score) triples ranked by score.""" - if self._bm25 is None: - return [] - scores = self._bm25.get_scores(query.lower().split()) - ranked = sorted( - zip(self._ids, self._contents, scores), - key=lambda x: x[2], - reverse=True, - ) - return ranked[:top_k] - - -def chunk_text(text: str, chunk_size: int = 1024) -> list[str]: - """Split text into sentence-aligned chunks of ~chunk_size words. - - Uses NLTK sentence tokenization for clean boundaries. - """ - sentences = nltk.sent_tokenize(text) - chunks: list[str] = [] - current: list[str] = [] - current_len = 0 - - for sentence in sentences: - sent_len = len(sentence.split()) - if current_len + sent_len > chunk_size and current: - chunks.append(" ".join(current)) - current = [] - current_len = 0 - current.append(sentence) - current_len += sent_len - - if current: - chunks.append(" ".join(current)) - return chunks - - def _context_budget_words(config: MABenchConfig) -> int: """Answer-context word budget derived from the model's context window. @@ -153,34 +94,6 @@ def build_elfmem_config(config: MABenchConfig) -> ElfmemConfig: }) -def _rrf_merge( - vector_blocks: list, - bm25_results: list[tuple[str, str, float]], - top_k: int, - k: int = 60, -) -> tuple[list, str]: - """Merge vector search and BM25 results via Reciprocal Rank Fusion. - - BM25 results carry block IDs (from _BM25Index.search), so matching is - exact — no content-prefix heuristics, no supplementary raw-chunk fallback. - Both retrieval paths use the same block content (summaries or raw text), - so the merged context is consistent. - """ - block_scores: dict[str, float] = {} - block_map: dict[str, object] = {} - for rank, block in enumerate(vector_blocks): - block_scores[block.id] = 1.0 / (k + rank) - block_map[block.id] = block - - for rank, (block_id, _content, _score) in enumerate(bm25_results): - if block_id in block_map: - block_scores[block_id] += 1.0 / (k + rank) - - ranked = sorted(block_map.values(), key=lambda b: block_scores[b.id], reverse=True) - trimmed = ranked[:top_k] - return trimmed, "\n\n".join(b.content for b in trimmed) - - async def process_example( example: dict[str, Any], competency: str, @@ -199,6 +112,7 @@ async def process_example( # Conflict Resolution needs contradiction detection — elfmem's moat is_conflict_resolution = competency == "Conflict_Resolution" + skip_llm = not is_conflict_resolution elfmem_cfg = build_elfmem_config(config) @@ -207,80 +121,32 @@ async def process_example( system = await MemorySystem.from_config(db_path, config=elfmem_cfg) try: - # Phase 1: Chunk and ingest context + # Phase 1: Ingest document via learn_document() start_ingest = time.monotonic() - chunks = chunk_text(context, config.chunk_size) - total_promoted = 0 - - # CR needs contradiction detection — elfmem's core capability. - # Other competencies use skip_llm for speed (embedding-only retrieval). - skip_llm = not is_conflict_resolution - - await system.begin_session(task_type="ingestion") - for i, chunk in enumerate(chunks): - await system.learn( - content=chunk, - tags=[f"chunk:{i}"], - category="knowledge", - source="memoryagentbench", - ) - # Consolidate periodically - if (i + 1) % config.consolidate_every_n_chunks == 0: - await system.end_session() - await system.begin_session(task_type="consolidation") - result = await system.consolidate(skip_llm=skip_llm) - total_promoted += result.promoted - await system.end_session() - await system.begin_session(task_type="ingestion") - await system.end_session() - - # Final consolidation for remaining chunks - await system.begin_session(task_type="consolidation") - result = await system.consolidate(skip_llm=skip_llm) - total_promoted += result.promoted - await system.end_session() - - # Build BM25 from active block content after consolidation. - # With skip_llm=False (CR), blocks carry LLM-generated summaries — - # the same content used by elfmem's vector retrieval. This ensures - # BM25 and vector search are consistent, enabling exact ID-based RRF - # merging without heuristic content-prefix matching. - # With skip_llm=True (other competencies), content falls back to the - # raw chunk text — same as the previous raw-chunk approach. - await system.begin_session(task_type="retrieval") - index_frame = await system.frame("attention", query=None, top_k=10000) - await system.end_session() - - bm25_index = _BM25Index() - for block in index_frame.blocks: - bm25_index.add(block.id, block.content) - bm25_index.build() - + doc_result = await system.learn_document( + context, + chunk_size=config.chunk_size, + chunker=nltk.sent_tokenize, + source="memoryagentbench", + skip_llm=skip_llm, + ) ingestion_time = time.monotonic() - start_ingest - log.info(f" Ingested {len(chunks)} chunks, {total_promoted} promoted in {ingestion_time:.1f}s") + log.info( + f" Ingested {doc_result.chunks_total} chunks, " + f"{doc_result.blocks_promoted} promoted in {ingestion_time:.1f}s" + ) # Phase 2: Answer each question qa_results: list[QAResult] = [] - for q_idx, (question, answer_list) in enumerate(zip(questions, answers)): + for _q_idx, (question, answer_list) in enumerate(zip(questions, answers, strict=False)): q_start = time.monotonic() - await system.begin_session(task_type="retrieval") - frame_result = await system.frame("attention", query=question, top_k=config.top_k) - await system.end_session() - - # Build context from blocks directly, bypassing the frame's hardcoded - # token_budget (2000 tokens in the attention frame definition). Both - # the BM25 and no-BM25 paths are now bounded only by _context_budget_words, - # which is derived from config.context_window_tokens. - blocks = frame_result.blocks + # elfmem's recall() includes BM25 (stage 2b) + graph expansion + # + contradiction suppression natively. + blocks = await system.recall(query=question, top_k=config.top_k) context_text = "\n\n".join(b.content for b in blocks) - bm25_hits = bm25_index.search(question, top_k=config.top_k) - if bm25_hits: - blocks, context_text = _rrf_merge(blocks, bm25_hits, config.top_k) # Truncate context to fit the model's context window. - # Budget is derived from config.context_window_tokens minus - # system prompt, template, question, and answer overhead. max_context_words = _context_budget_words(config) words = context_text.split() if len(words) > max_context_words: @@ -304,8 +170,8 @@ async def process_example( return ExampleResult( source=source, competency=competency, - chunks_ingested=len(chunks), - blocks_promoted=total_promoted, + chunks_ingested=doc_result.chunks_total, + blocks_promoted=doc_result.blocks_promoted, qa_results=qa_results, ingestion_seconds=ingestion_time, ) diff --git a/benchmarks/memoryagentbench/runner.py b/benchmarks/memoryagentbench/runner.py index 5fdc2c6..c0c6dc5 100644 --- a/benchmarks/memoryagentbench/runner.py +++ b/benchmarks/memoryagentbench/runner.py @@ -11,12 +11,12 @@ import logging import os import time -from datetime import datetime, timezone +from datetime import UTC, datetime from pathlib import Path from datasets import load_dataset -from benchmarks.memoryagentbench.adapter import ExampleResult, process_example +from benchmarks.memoryagentbench.adapter import process_example from benchmarks.memoryagentbench.config import MABenchConfig from benchmarks.memoryagentbench.metrics import score_question from benchmarks.shared.answerer import generate_answer @@ -34,6 +34,39 @@ ] +# ── Resume helpers ─────────────────────────────────────────────────────────── + + +def _example_id(source: str, index: int) -> str: + """Composite ID for an example: 'source/index'.""" + return f"{source}/{index}" + + +def _get_completed_example_ids(output_path: Path) -> set[str]: + """Load set of example IDs already written to output JSON.""" + if not output_path.exists(): + return set() + data = json.loads(output_path.read_text()) + return {q["example_id"] for q in data.get("questions", []) if "example_id" in q} + + +def _append_example_results( + output_path: Path, new_questions: list[dict], +) -> None: + """Atomically append question results to the output JSON. + + Uses tmpfile + rename for crash safety on POSIX filesystems. + """ + existing = json.loads(output_path.read_text()) if output_path.exists() else {"questions": []} + existing["questions"].extend(new_questions) + tmp = output_path.with_suffix(".tmp") + tmp.write_text(json.dumps(existing, indent=2, default=str)) + tmp.rename(output_path) + + +# ── Scoring ────────────────────────────────────────────────────────────────── + + async def _answer_and_score( qa, config: MABenchConfig, ) -> dict: @@ -56,6 +89,9 @@ async def _answer_and_score( } +# ── Main runner ────────────────────────────────────────────────────────────── + + async def run(args: argparse.Namespace) -> None: """Run the benchmark.""" config = MABenchConfig() @@ -73,10 +109,17 @@ async def run(args: argparse.Namespace) -> None: config.resume = True config.output_dir.mkdir(parents=True, exist_ok=True) - timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ") + timestamp = datetime.now(UTC).strftime("%Y%m%dT%H%M%SZ") output_path = args.output or config.output_dir / f"{timestamp}_mabench_elfmem.json" output_path = Path(output_path) + # Resume: load previously completed examples + completed_ids: set[str] = set() + if config.resume: + completed_ids = _get_completed_example_ids(output_path) + if completed_ids: + log.info(f"Resuming: {len(completed_ids)} example(s) already completed") + log.info("Loading MemoryAgentBench dataset from HuggingFace...") ds = load_dataset("ai-hyz/MemoryAgentBench") @@ -104,19 +147,33 @@ async def run(args: argparse.Namespace) -> None: for ex_idx, example in enumerate(examples): metadata = example.get("metadata", {}) source = metadata.get("source", "unknown") if isinstance(metadata, dict) else "unknown" + ex_id = _example_id(source, ex_idx) n_questions = len(example["questions"]) + + # Resume: skip completed examples + if ex_id in completed_ids: + log.info(f" [{ex_idx+1}/{len(examples)}] {source}: skipped (resumed)") + continue + log.info(f"\n [{ex_idx+1}/{len(examples)}] {source}: {n_questions} questions") try: ex_result = await process_example(example, competency, config) + example_questions: list[dict] = [] for qa in ex_result.qa_results: scored = await _answer_and_score(qa, config) scored["competency"] = competency scored["source"] = source - all_results.append(scored) + scored["example_id"] = ex_id + example_questions.append(scored) comp_f1s.append(scored["f1"]) + all_results.extend(example_questions) + + # Atomic write after each example — crash-safe resume + _append_example_results(output_path, example_questions) + avg_f1 = sum(comp_f1s[-n_questions:]) / n_questions if n_questions else 0 log.info(f" {n_questions} Qs answered | avg F1={avg_f1:.3f}") @@ -125,7 +182,7 @@ async def run(args: argparse.Namespace) -> None: competency_scores[competency] = comp_f1s - # Build report + # Build final report (overwrites with complete metadata) duration = time.monotonic() - start_time scores_by_comp: dict[str, dict] = {} all_f1s: list[float] = [] @@ -142,11 +199,18 @@ async def run(args: argparse.Namespace) -> None: except ImportError: elfmem_version = "unknown" + # Merge any previously completed results with this run's results + if output_path.exists(): + existing = json.loads(output_path.read_text()) + all_questions = existing.get("questions", []) + else: + all_questions = all_results + report = { "meta": { "benchmark": "memoryagentbench", "version": "1.0", - "timestamp": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"), + "timestamp": datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ"), "duration_seconds": round(duration, 1), "elfmem_version": elfmem_version, "models": { @@ -170,7 +234,7 @@ async def run(args: argparse.Namespace) -> None: "overall": round(overall * 100, 1), "by_competency": scores_by_comp, }, - "questions": all_results, + "questions": all_questions, } output_path.write_text(json.dumps(report, indent=2, default=str)) diff --git a/pyproject.toml b/pyproject.toml index cb83cc8..bf1682d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,7 @@ Changelog = "https://github.com/emson/elfmem/blob/main/CHANGELOG.md" Security = "https://github.com/emson/elfmem/blob/main/SECURITY.md" [project.optional-dependencies] +bm25 = ["rank_bm25>=0.2.2"] mcp = ["fastmcp>=2.0"] cli = ["typer>=0.12"] tools = ["fastmcp>=2.0", "typer>=0.12"] diff --git a/src/elfmem/__init__.py b/src/elfmem/__init__.py index 14a95a6..c6b483c 100644 --- a/src/elfmem/__init__.py +++ b/src/elfmem/__init__.py @@ -41,6 +41,7 @@ DisconnectResult, DisplacedEdge, FrameResult, + LearnDocumentResult, LearnResult, OperationRecord, OutcomeResult, @@ -58,6 +59,7 @@ "ConsolidationPolicy", # Result types "LearnResult", + "LearnDocumentResult", "ConsolidateResult", "FrameResult", "CurateResult", diff --git a/src/elfmem/api.py b/src/elfmem/api.py index 45ab281..0f989bd 100644 --- a/src/elfmem/api.py +++ b/src/elfmem/api.py @@ -4,9 +4,10 @@ import json import os +import re import time from collections import deque -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Callable from contextlib import asynccontextmanager, suppress from datetime import UTC, datetime from typing import Any, Literal @@ -56,6 +57,7 @@ CurateResult, DisconnectResult, FrameResult, + LearnDocumentResult, LearnResult, OperationRecord, OutcomeResult, @@ -65,6 +67,33 @@ TokenUsage, ) +# ── Document chunking helpers ──────────────────────────────────────────────── + +_SENTENCE_SPLIT_RE = re.compile(r"(?<=[.!?])\s+") + + +def _default_chunker(text: str) -> list[str]: + """Split text into sentences at [.!?] followed by whitespace.""" + return [s.strip() for s in _SENTENCE_SPLIT_RE.split(text) if s.strip()] + + +def _assemble_chunks(sentences: list[str], chunk_size: int) -> list[str]: + """Combine sentences into chunks of ~chunk_size words.""" + chunks: list[str] = [] + current: list[str] = [] + current_len = 0 + for sentence in sentences: + sent_len = len(sentence.split()) + if current_len + sent_len > chunk_size and current: + chunks.append(" ".join(current)) + current = [] + current_len = 0 + current.append(sentence) + current_len += sent_len + if current: + chunks.append(" ".join(current)) + return chunks + class MemorySystem: """Adaptive memory system for LLM agents. @@ -696,6 +725,9 @@ async def consolidate( embedding_svc=self._embedding, current_active_hours=current_hours, self_alignment_threshold=mem.self_alignment_threshold, + contradiction_threshold=mem.contradiction_threshold, + near_dup_exact_threshold=mem.near_dup_exact_threshold, + near_dup_near_threshold=mem.near_dup_near_threshold, edge_score_threshold=mem.edge_score_threshold, edge_degree_cap=mem.edge_degree_cap, contradiction_similarity_prefilter=mem.contradiction_similarity_prefilter, @@ -774,7 +806,93 @@ async def remember( await self.begin_session() return await self.learn(content, tags=tags, category=category, source=source) - async def dream(self) -> ConsolidateResult | None: + async def learn_document( + self, + text: str, + *, + chunk_size: int = 256, + chunker: Callable[[str], list[str]] | None = None, + tags: list[str] | None = None, + category: str = "knowledge", + source: str = "document", + skip_llm: bool = False, + ) -> LearnDocumentResult: + """Ingest a document: chunk, learn each chunk, auto-consolidate. + + USE WHEN: Ingesting a document, article, or long-form text. Handles + chunking, learning, and consolidation in one call. + + DON'T USE WHEN: Ingesting a single fact or short observation — use + learn() instead. + + COST: O(chunks) learn() calls + dream() at inbox_threshold intervals. + With skip_llm=True, dream() runs the fast embedding-only path. + + RETURNS: LearnDocumentResult with chunk and consolidation counts. + + NEXT: recall() or frame() to query the ingested knowledge. + + Args: + text: The full document text to ingest. + chunk_size: Target words per chunk (default 256). + chunker: Sentence splitter function. Receives the full text, + returns a list of sentences. Default splits at [.!?] + followed by whitespace. Pass ``nltk.sent_tokenize`` for + better accuracy with abbreviations and edge cases. + tags: Tags applied to every chunk (in addition to ``chunk:N``). + category: Block category (default "knowledge"). + source: Block source identifier (default "document"). + skip_llm: Forward to dream() — bypass LLM calls during + consolidation (embed + promote only). + + Example:: + + result = await system.learn_document(article_text, chunk_size=200) + print(result) # Ingested 12 chunks: 12 created, 2 consolidations, 10 promoted. + """ + split_fn = chunker or _default_chunker + sentences = split_fn(text) + chunks = _assemble_chunks(sentences, chunk_size) + + created = 0 + duplicates = 0 + consolidations = 0 + promoted = 0 + + for i, chunk in enumerate(chunks): + result = await self.learn( + content=chunk, + tags=[f"chunk:{i}", *(tags or [])], + category=category, + source=source, + ) + if result.status == "created": + created += 1 + elif result.status == "duplicate_rejected": + duplicates += 1 + + if self.should_dream: + dr = await self.dream(skip_llm=skip_llm) + if dr is not None: + consolidations += 1 + promoted += dr.promoted + + doc_result = LearnDocumentResult( + chunks_total=len(chunks), + chunks_created=created, + chunks_duplicate=duplicates, + consolidations=consolidations, + blocks_promoted=promoted, + ) + self._record_op("learn_document", doc_result.summary) + return doc_result + + async def dream( + self, + *, + skip_llm: bool = False, + skip_contradictions: bool = False, + ) -> ConsolidateResult | None: """Consolidate pending blocks at a natural pause point. The breathing rhythm: learn fast (remember), process deliberately (dream). @@ -785,8 +903,8 @@ async def dream(self) -> ConsolidateResult | None: DON'T USE WHEN: In a tight loop. One call processes all pending blocks. - COST: LLM call per pending block. Returns None immediately (zero cost) - if inbox is empty. + COST: LLM call per pending block (unless skip_llm=True). Returns None + immediately (zero cost) if inbox is empty. RETURNS: ConsolidateResult if blocks were processed; None if inbox was empty. None is not an error — it means "nothing needed doing." @@ -798,6 +916,12 @@ async def dream(self) -> ConsolidateResult | None: (e.g., another process), dream() may return None despite inbox having items. Call status() for DB-accurate inbox_count. + Args: + skip_llm: Bypass all LLM calls (embed + promote only). Fastest path + for bulk ingestion where alignment scoring isn't needed. + skip_contradictions: Keep LLM summaries but skip O(n²) contradiction + detection. Good balance of quality and speed. + Example:: # Always-on agent loop @@ -809,7 +933,9 @@ async def dream(self) -> ConsolidateResult | None: """ if self._pending == 0: return None - result = await self.consolidate() + result = await self.consolidate( + skip_llm=skip_llm, skip_contradictions=skip_contradictions, + ) # Feed policy so it can adapt the threshold for the next cycle, # then persist the new threshold so it survives process restarts. # Persistence is best-effort: a DB failure here is non-fatal — diff --git a/src/elfmem/guide.py b/src/elfmem/guide.py index fa05ce5..703812e 100644 --- a/src/elfmem/guide.py +++ b/src/elfmem/guide.py @@ -146,6 +146,34 @@ def __str__(self) -> str: "print(result) # Stored block a1b2c3d4. Status: created." ), ), + "learn_document": AgentGuide( + name="learn_document", + what="Ingest a document: chunk, learn each chunk, auto-consolidate.", + when=( + "The agent needs to ingest a document, article, or long-form text. " + "Handles chunking, learning, and consolidation in one call." + ), + when_not=( + "Single facts or short observations — use learn() instead. " + "Already-chunked data — use learn() in a loop." + ), + cost=( + "O(chunks) learn() calls + dream() at inbox_threshold intervals. " + "With skip_llm=True, dream() uses the fast embedding-only path." + ), + returns=( + "LearnDocumentResult with chunks_total, chunks_created, " + "chunks_duplicate, consolidations, blocks_promoted." + ), + next=( + "recall() or frame() to query the ingested knowledge. " + "Consolidation happened automatically during ingestion." + ), + example=( + "result = await system.learn_document(article_text, chunk_size=200)\n" + "print(result) # Ingested 12 chunks: 12 created, 2 consolidations." + ), + ), "consolidate": AgentGuide( name="consolidate", what="Process inbox blocks: score, embed, deduplicate, and promote to active memory.", diff --git a/src/elfmem/memory/retrieval.py b/src/elfmem/memory/retrieval.py index 04e1721..c0c27f1 100644 --- a/src/elfmem/memory/retrieval.py +++ b/src/elfmem/memory/retrieval.py @@ -1,7 +1,13 @@ -"""4-stage hybrid retrieval pipeline — pure (no side effects).""" +"""6-stage hybrid retrieval pipeline — pure (no side effects). + +Stages: pre-filter → vector → BM25 → graph expand → composite score → MMR. +BM25 (stage 2b) requires the optional ``rank_bm25`` package. When not installed, +the stage is silently skipped and retrieval works as a 5-stage vector-only pipeline. +""" from __future__ import annotations +import logging import math from typing import Any @@ -19,6 +25,16 @@ ) from elfmem.types import ScoredBlock +# Soft dependency — retrieval works without it. +try: + from rank_bm25 import BM25Okapi + + _HAS_BM25 = True +except ImportError: # pragma: no cover + _HAS_BM25 = False + +log = logging.getLogger(__name__) + N_SEEDS_MULTIPLIER = 4 CONTRADICTION_OVERSAMPLE = 2 DEFAULT_SEARCH_WINDOW_HOURS = 200.0 @@ -37,13 +53,14 @@ async def hybrid_retrieve( tag_filter: str | None = None, search_window_hours: float = DEFAULT_SEARCH_WINDOW_HOURS, ) -> list[ScoredBlock]: - """Execute the 5-stage hybrid retrieval pipeline. + """Execute the 6-stage hybrid retrieval pipeline. - Stage 1 — Pre-filter: active blocks within search window. - Stage 2 — Vector search: cosine similarity → top N_seeds. (Skipped if no query.) - Stage 3 — Graph expand: 1-hop neighbours of seeds. (Skipped if no query.) - Stage 4 — Composite score: rank all candidates. - Stage 5 — MMR diversity: reorder for relevance + diversity. (Query-aware only.) + Stage 1 — Pre-filter: active blocks within search window. + Stage 2 — Vector search: cosine similarity → top N_seeds. (Skipped if no query.) + Stage 2b — BM25 keyword search: term overlap → top N_bm25. (Requires rank_bm25.) + Stage 3 — Graph expand: 1-hop neighbours of seeds. (Skipped if no query.) + Stage 4 — Composite score: rank all candidates. + Stage 5 — MMR diversity: reorder for relevance + diversity. (Query-aware only.) Returns top_k * CONTRADICTION_OVERSAMPLE ScoredBlocks for contradiction headroom. """ @@ -64,11 +81,23 @@ async def hybrid_retrieve( candidates, n_seeds=top_k * N_SEEDS_MULTIPLIER, ) + + # Stage 2b: BM25 keyword candidates (additive — discovers blocks + # that vector search missed due to vocabulary mismatch). + seed_ids_set = {b["id"] for b, _ in seed_pairs} + bm25_pairs = _stage_2b_bm25_search( + candidates, query, n_seeds=top_k * N_SEEDS_MULTIPLIER, + ) + for block, _bm25_score in bm25_pairs: + if block["id"] not in seed_ids_set: + seed_pairs.append((block, 0.0)) + seed_ids_set.add(block["id"]) + seed_ids = [b["id"] for b, _ in seed_pairs] expanded = await _stage_3_graph_expand( conn, seed_ids=seed_ids, - existing_candidate_ids={b["id"] for b, _ in seed_pairs}, + existing_candidate_ids=seed_ids_set, ) # candidates for scoring: (block, similarity, was_expanded) @@ -154,6 +183,27 @@ async def _stage_2_vector_search( return scored[:n_seeds] +def _stage_2b_bm25_search( + candidates: list[dict[str, Any]], + query: str, + n_seeds: int, +) -> list[tuple[dict[str, Any], float]]: + """Stage 2b: BM25 keyword search over pre-filtered candidates. + + Discovers blocks with strong term overlap that vector search may miss + (vocabulary mismatch, exact entity names, etc.). Requires the optional + ``rank_bm25`` package — returns ``[]`` when not installed. + """ + if not _HAS_BM25 or not candidates: + return [] + contents = [b.get("summary") or b.get("content", "") for b in candidates] + tokenized = [c.lower().split() for c in contents] + bm25 = BM25Okapi(tokenized) # type: ignore[possibly-undefined] + scores = bm25.get_scores(query.lower().split()) + ranked = sorted(zip(candidates, scores, strict=False), key=lambda x: x[1], reverse=True) + return ranked[:n_seeds] + + async def _stage_3_graph_expand( conn: AsyncConnection, seed_ids: list[str], diff --git a/src/elfmem/types.py b/src/elfmem/types.py index a22d453..8a45c3f 100644 --- a/src/elfmem/types.py +++ b/src/elfmem/types.py @@ -130,6 +130,48 @@ def to_dict(self) -> dict[str, Any]: return {"block_id": self.block_id, "status": self.status} +@dataclass +class LearnDocumentResult: + """Result of ingesting a document via learn_document(). + + USE WHEN: Inspecting how a document was ingested. + DON'T USE WHEN: You need per-block detail — check individual LearnResults. + COST: Zero — this is a pure summary. + RETURNS: Chunk and consolidation counts. + NEXT: recall() or frame() to query the ingested knowledge. + """ + + chunks_total: int + chunks_created: int + chunks_duplicate: int + consolidations: int + blocks_promoted: int + + @property + def summary(self) -> str: + if self.chunks_total == 0: + return "No chunks to ingest (empty document)." + parts = [f"{self.chunks_created} created"] + if self.chunks_duplicate: + parts.append(f"{self.chunks_duplicate} duplicate") + if self.consolidations: + parts.append(f"{self.consolidations} consolidations") + parts.append(f"{self.blocks_promoted} promoted") + return f"Ingested {self.chunks_total} chunks: {', '.join(parts)}." + + def __str__(self) -> str: + return self.summary + + def to_dict(self) -> dict[str, int]: + return { + "chunks_total": self.chunks_total, + "chunks_created": self.chunks_created, + "chunks_duplicate": self.chunks_duplicate, + "consolidations": self.consolidations, + "blocks_promoted": self.blocks_promoted, + } + + @dataclass class ConsolidateResult: processed: int diff --git a/tests/benchmarks/test_mabench_adapter.py b/tests/benchmarks/test_mabench_adapter.py index e1bf922..41bf2ba 100644 --- a/tests/benchmarks/test_mabench_adapter.py +++ b/tests/benchmarks/test_mabench_adapter.py @@ -1,86 +1,34 @@ -"""Tests for MemoryAgentBench adapter utilities — BM25 index and RRF merge.""" +"""Tests for MemoryAgentBench adapter utilities.""" -import pytest +from benchmarks.memoryagentbench.adapter import _context_budget_words, build_elfmem_config +from benchmarks.memoryagentbench.config import MABenchConfig -pytest.importorskip("rank_bm25", reason="rank-bm25 not installed") -from benchmarks.memoryagentbench.adapter import _BM25Index, _rrf_merge +class TestContextBudgetWords: + def test_default_config_budget_positive(self) -> None: + config = MABenchConfig() + budget = _context_budget_words(config) + assert budget > 100 + def test_larger_context_window_increases_budget(self) -> None: + small = MABenchConfig(context_window_tokens=2048) + large = MABenchConfig(context_window_tokens=8192) + assert _context_budget_words(large) > _context_budget_words(small) -class _FakeBlock: - """Minimal stand-in for ScoredBlock in RRF merge tests.""" + def test_minimum_budget_is_100(self) -> None: + tiny = MABenchConfig(context_window_tokens=100) + assert _context_budget_words(tiny) == 100 - def __init__(self, block_id: str, content: str) -> None: - self.id = block_id - self.content = content +class TestBuildElfmemConfig: + def test_returns_valid_config(self) -> None: + config = MABenchConfig() + elfmem_cfg = build_elfmem_config(config) + assert elfmem_cfg.llm.model == config.elfmem_llm_model + assert elfmem_cfg.embeddings.model == config.elfmem_embedding_model + assert elfmem_cfg.memory.top_k == config.top_k -class TestBM25Index: - def test_search_returns_id_content_score_triples(self) -> None: - idx = _BM25Index() - idx.add("id1", "python async programming") - idx.add("id2", "java synchronous code") - idx.add("id3", "ruby sequential scripts") # 3 docs: IDF for 'async' is positive - idx.build() - results = idx.search("async", top_k=3) - block_id, content, score = results[0] - assert block_id == "id1" - assert content == "python async programming" - assert score > 0 - - def test_search_ranks_by_relevance(self) -> None: - idx = _BM25Index() - idx.add("a", "python async") - idx.add("b", "java synchronous") - idx.build() - results = idx.search("python", top_k=2) - assert results[0][0] == "a" - - def test_search_empty_index_returns_empty(self) -> None: - idx = _BM25Index() - idx.build() - assert idx.search("anything") == [] - - def test_search_respects_top_k(self) -> None: - idx = _BM25Index() - for i in range(10): - idx.add(f"id{i}", f"content about topic {i}") - idx.build() - assert len(idx.search("content topic", top_k=3)) == 3 - - -class TestRRFMerge: - def test_matched_bm25_block_scores_higher(self) -> None: - """A block present in both vector and BM25 results scores higher than one in vector only.""" - b1 = _FakeBlock("b1", "async python programming") - b2 = _FakeBlock("b2", "java synchronous code") - vector_blocks = [b1, b2] # b1 ranked first by vector - - bm25_results = [("b1", "async python programming", 5.0)] # b1 also in BM25 - - _, context = _rrf_merge(vector_blocks, bm25_results, top_k=2) - # b1 should appear first — boosted by both retrieval paths - assert context.startswith("async python programming") - - def test_unmatched_bm25_id_ignored(self) -> None: - """BM25 result for unknown block ID has no effect — no supplementary raw chunks.""" - b1 = _FakeBlock("b1", "block one") - vector_blocks = [b1] - bm25_results = [("unknown_id", "some raw text", 3.0)] - - trimmed, context = _rrf_merge(vector_blocks, bm25_results, top_k=1) - assert len(trimmed) == 1 - assert "some raw text" not in context - - def test_context_is_newline_joined_block_content(self) -> None: - b1 = _FakeBlock("b1", "first block") - b2 = _FakeBlock("b2", "second block") - _, context = _rrf_merge([b1, b2], [], top_k=2) - assert "first block" in context - assert "second block" in context - assert "\n\n" in context - - def test_top_k_limits_output(self) -> None: - blocks = [_FakeBlock(f"b{i}", f"block {i}") for i in range(5)] - trimmed, _ = _rrf_merge(blocks, [], top_k=2) - assert len(trimmed) == 2 + def test_contradiction_prefilter_forwarded(self) -> None: + config = MABenchConfig(contradiction_similarity_prefilter=0.65) + elfmem_cfg = build_elfmem_config(config) + assert elfmem_cfg.memory.contradiction_similarity_prefilter == 0.65 diff --git a/tests/test_bm25_retrieval.py b/tests/test_bm25_retrieval.py new file mode 100644 index 0000000..fb514ce --- /dev/null +++ b/tests/test_bm25_retrieval.py @@ -0,0 +1,98 @@ +"""Tests for BM25 integration in hybrid_retrieve (stage 2b).""" + +import unittest.mock + +import pytest + +from elfmem import ElfmemConfig, MemorySystem +from elfmem.config import MemoryConfig +from elfmem.memory import retrieval + + +@pytest.fixture +async def system(test_engine, mock_llm, mock_embedding) -> MemorySystem: + """MemorySystem with inbox_threshold=3 for fast consolidation.""" + cfg = ElfmemConfig(memory=MemoryConfig(inbox_threshold=3)) + return MemorySystem( + engine=test_engine, + llm_service=mock_llm, + embedding_service=mock_embedding, + config=cfg, + ) + + +class TestBM25StageFunction: + """Direct tests on _stage_2b_bm25_search.""" + + def test_returns_empty_when_no_candidates(self): + result = retrieval._stage_2b_bm25_search([], "test query", n_seeds=10) + assert result == [] + + def test_returns_ranked_results(self): + candidates = [ + {"id": "a", "content": "the quick brown fox jumps"}, + {"id": "b", "content": "lazy dog sleeps all day"}, + {"id": "c", "content": "the fox is quick and clever"}, + ] + result = retrieval._stage_2b_bm25_search(candidates, "quick fox", n_seeds=10) + assert len(result) > 0 + # Blocks mentioning "quick" and "fox" should rank higher + ids = [b["id"] for b, _ in result] + # "a" and "c" both contain "quick" and "fox", "b" doesn't + assert ids[0] in ("a", "c") + + def test_respects_n_seeds_limit(self): + candidates = [{"id": f"b{i}", "content": f"word{i} content"} for i in range(20)] + result = retrieval._stage_2b_bm25_search(candidates, "word5", n_seeds=3) + assert len(result) == 3 + + def test_uses_summary_over_content(self): + candidates = [ + {"id": "a", "content": "irrelevant", "summary": "the target keyword here"}, + {"id": "b", "content": "the target keyword here"}, + ] + result = retrieval._stage_2b_bm25_search(candidates, "target keyword", n_seeds=10) + # Both should match — "a" via summary, "b" via content + ids = [b["id"] for b, _ in result] + assert "a" in ids + assert "b" in ids + + def test_returns_empty_when_bm25_unavailable(self): + """When _HAS_BM25 is False, stage 2b is a no-op.""" + candidates = [{"id": "a", "content": "some content"}] + with unittest.mock.patch.object(retrieval, "_HAS_BM25", False): + result = retrieval._stage_2b_bm25_search(candidates, "content", n_seeds=10) + assert result == [] + + +class TestBM25Integration: + """Integration tests: BM25 results appear in recall() output.""" + + async def test_recall_returns_results_with_bm25(self, system): + """Basic recall with BM25 available should return results.""" + for i in range(4): + await system.learn(f"Knowledge item number {i} about topic alpha") + await system.consolidate() + + blocks = await system.recall(query="alpha", top_k=5) + assert len(blocks) > 0 + + async def test_recall_without_bm25_still_works(self, system): + """When BM25 is disabled, recall still works (vector-only).""" + for i in range(4): + await system.learn(f"Fact {i} about beta topic") + await system.consolidate() + + with unittest.mock.patch.object(retrieval, "_HAS_BM25", False): + blocks = await system.recall(query="beta", top_k=5) + assert len(blocks) > 0 + + async def test_bm25_no_duplicate_blocks(self, system): + """Blocks found by both vector and BM25 should appear only once.""" + for i in range(4): + await system.learn(f"Unique fact {i} about gamma topic") + await system.consolidate() + + blocks = await system.recall(query="gamma", top_k=10) + block_ids = [b.id for b in blocks] + assert len(block_ids) == len(set(block_ids)) # no duplicates diff --git a/tests/test_learn_document.py b/tests/test_learn_document.py new file mode 100644 index 0000000..855d5a9 --- /dev/null +++ b/tests/test_learn_document.py @@ -0,0 +1,221 @@ +"""Tests for learn_document(), dream(skip_llm=), and config wiring fixes.""" + +import pytest + +from elfmem import ElfmemConfig, LearnDocumentResult, MemorySystem +from elfmem.api import _assemble_chunks, _default_chunker +from elfmem.config import MemoryConfig + +# ── Fixtures ───────────────────────────────────────────────────────────────── + + +@pytest.fixture +async def system(test_engine, mock_llm, mock_embedding) -> MemorySystem: + """MemorySystem with inbox_threshold=3 for fast consolidation cycles.""" + cfg = ElfmemConfig(memory=MemoryConfig(inbox_threshold=3)) + return MemorySystem( + engine=test_engine, + llm_service=mock_llm, + embedding_service=mock_embedding, + config=cfg, + ) + + +# ── _default_chunker ──────────────────────────────────────────────────────── + + +class TestDefaultChunker: + def test_splits_at_sentence_boundaries(self): + text = "First sentence. Second sentence. Third sentence." + result = _default_chunker(text) + assert result == ["First sentence.", "Second sentence.", "Third sentence."] + + def test_splits_at_question_and_exclamation(self): + text = "What happened? It exploded! Then silence." + result = _default_chunker(text) + assert len(result) == 3 + + def test_empty_text_returns_empty(self): + assert _default_chunker("") == [] + + def test_no_sentence_endings_returns_whole_text(self): + text = "A single run-on thought with no ending punctuation" + result = _default_chunker(text) + assert result == [text] + + +# ── _assemble_chunks ──────────────────────────────────────────────────────── + + +class TestAssembleChunks: + def test_groups_sentences_by_word_count(self): + sentences = ["One two three.", "Four five.", "Six seven eight nine ten."] + # chunk_size=5: first sentence (3 words) + second (2 words) = 5, fits + # third sentence (5 words) starts new chunk + chunks = _assemble_chunks(sentences, chunk_size=5) + assert len(chunks) == 2 + + def test_single_sentence_over_chunk_size(self): + sentences = ["This is a very long single sentence with many words."] + chunks = _assemble_chunks(sentences, chunk_size=3) + assert len(chunks) == 1 # single sentence can't be split + + def test_empty_sentences_returns_empty(self): + assert _assemble_chunks([], chunk_size=10) == [] + + def test_each_sentence_exceeds_chunk_size(self): + sentences = ["Alpha bravo charlie.", "Delta echo foxtrot."] + chunks = _assemble_chunks(sentences, chunk_size=2) + assert len(chunks) == 2 + + +# ── learn_document() ──────────────────────────────────────────────────────── + + +class TestLearnDocument: + async def test_basic_ingestion(self, system): + text = "Fact one. Fact two. Fact three. Fact four. Fact five." + result = await system.learn_document(text, chunk_size=2) + assert isinstance(result, LearnDocumentResult) + assert result.chunks_total > 0 + assert result.chunks_created > 0 + assert result.chunks_duplicate == 0 + + async def test_auto_consolidation_triggers(self, system): + """With inbox_threshold=3, learning 5+ chunks should trigger dream().""" + sentences = [f"Sentence number {i} with unique content." for i in range(10)] + text = " ".join(sentences) + result = await system.learn_document(text, chunk_size=3) + assert result.consolidations > 0 + assert result.blocks_promoted > 0 + + async def test_empty_text(self, system): + result = await system.learn_document("") + assert result.chunks_total == 0 + assert result.chunks_created == 0 + assert result.consolidations == 0 + + async def test_custom_chunker(self, system): + text = "line one\nline two\nline three" + result = await system.learn_document( + text, + chunker=lambda t: t.split("\n"), + chunk_size=2, # small enough that each line becomes its own chunk + ) + assert result.chunks_total == 3 + assert result.chunks_created == 3 + + async def test_duplicate_ingestion(self, system): + text = "A unique fact about penguins." + first = await system.learn_document(text) + assert first.chunks_created == 1 + second = await system.learn_document(text) + assert second.chunks_duplicate == 1 + assert second.chunks_created == 0 + + async def test_skip_llm_forwarded(self, system): + """skip_llm=True should produce consolidation without LLM calls.""" + sentences = [f"Fact {i} about topic X." for i in range(5)] + text = " ".join(sentences) + result = await system.learn_document(text, chunk_size=2, skip_llm=True) + # Should still consolidate (embedding-only path) + if result.consolidations > 0: + assert result.blocks_promoted > 0 + + async def test_tags_applied_to_all_chunks(self, system): + text = "Alpha info. Beta info." + await system.learn_document(text, chunk_size=100, tags=["source:test"]) + # Recall should find blocks with our tags + blocks = await system.recall(query="info", top_k=5) + for block in blocks: + assert "source:test" in block.tags + + async def test_result_summary_str(self, system): + text = "One fact. Two fact. Three fact." + result = await system.learn_document(text, chunk_size=100) + summary = str(result) + assert "chunks" in summary.lower() or "ingested" in summary.lower() + + async def test_result_to_dict(self, system): + text = "Some content here." + result = await system.learn_document(text) + d = result.to_dict() + assert "chunks_total" in d + assert "chunks_created" in d + assert "consolidations" in d + + +# ── dream(skip_llm=) ──────────────────────────────────────────────────────── + + +class TestDreamSkipLLM: + async def test_dream_skip_llm_promotes_blocks(self, system): + """dream(skip_llm=True) should promote inbox blocks without LLM.""" + for i in range(3): + await system.learn(f"Fact number {i} about testing") + result = await system.dream(skip_llm=True) + assert result is not None + assert result.promoted > 0 + + async def test_dream_skip_contradictions(self, system): + """dream(skip_contradictions=True) should skip contradiction detection.""" + for i in range(3): + await system.learn(f"Knowledge item {i}") + result = await system.dream(skip_contradictions=True) + assert result is not None + assert result.promoted > 0 + + async def test_dream_empty_inbox_returns_none(self, system): + result = await system.dream(skip_llm=True) + assert result is None + + +# ── Config wiring ──────────────────────────────────────────────────────────── + + +class TestConfigWiring: + async def test_near_dup_threshold_from_config(self, test_engine, mock_llm, mock_embedding): + """near_dup_near_threshold from config should take effect in consolidation.""" + # Use a very low threshold (0.30) — blocks with similarity >= 0.30 + # should be superseded. Default MockEmbeddingService produces + # moderately similar vectors for different content. + cfg = ElfmemConfig( + memory=MemoryConfig( + inbox_threshold=3, + near_dup_near_threshold=0.30, + ), + ) + system = MemorySystem( + engine=test_engine, + llm_service=mock_llm, + embedding_service=mock_embedding, + config=cfg, + ) + await system.learn("The cat sat on the mat") + await system.learn("The cat sat on the mat today") + await system.learn("A different topic entirely") + result = await system.consolidate() + # With threshold=0.30, similar content should be deduplicated + assert result.processed == 3 + assert result.deduplicated >= 0 # may or may not dedup depending on embedding similarity + + async def test_contradiction_threshold_from_config(self, test_engine, mock_llm, mock_embedding): + """contradiction_threshold from config should be wired through.""" + cfg = ElfmemConfig( + memory=MemoryConfig( + inbox_threshold=3, + contradiction_threshold=0.50, # lower than default 0.80 + ), + ) + system = MemorySystem( + engine=test_engine, + llm_service=mock_llm, + embedding_service=mock_embedding, + config=cfg, + ) + # Learn enough to consolidate — the threshold is wired even if + # mock LLM doesn't produce real contradiction scores. + for i in range(3): + await system.learn(f"Statement {i}") + result = await system.consolidate() + assert result.processed == 3 diff --git a/uv.lock b/uv.lock index 482ed4e..0025b4d 100644 --- a/uv.lock +++ b/uv.lock @@ -585,6 +585,9 @@ dependencies = [ ] [package.optional-dependencies] +bm25 = [ + { name = "rank-bm25" }, +] cli = [ { name = "typer" }, ] @@ -635,6 +638,7 @@ requires-dist = [ { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.23" }, { name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=5.0" }, { name = "pyyaml", specifier = ">=6.0" }, + { name = "rank-bm25", marker = "extra == 'bm25'", specifier = ">=0.2.2" }, { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.3" }, { name = "sqlalchemy", specifier = ">=2.0" }, { name = "typer", marker = "extra == 'cli'", specifier = ">=0.12" }, @@ -642,7 +646,7 @@ requires-dist = [ { name = "typer", marker = "extra == 'tools'", specifier = ">=0.12" }, { name = "types-pyyaml", marker = "extra == 'dev'" }, ] -provides-extras = ["mcp", "cli", "tools", "viz", "dev"] +provides-extras = ["bm25", "mcp", "cli", "tools", "viz", "dev"] [[package]] name = "email-validator" @@ -2159,6 +2163,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/04/11/432f32f8097b03e3cd5fe57e88efb685d964e2e5178a48ed61e841f7fdce/pyyaml_env_tag-1.1-py3-none-any.whl", hash = "sha256:17109e1a528561e32f026364712fee1264bc2ea6715120891174ed1b980d2e04", size = 4722, upload-time = "2025-05-13T15:23:59.629Z" }, ] +[[package]] +name = "rank-bm25" +version = "0.2.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/fc/0a/f9579384aa017d8b4c15613f86954b92a95a93d641cc849182467cf0bb3b/rank_bm25-0.2.2.tar.gz", hash = "sha256:096ccef76f8188563419aaf384a02f0ea459503fdf77901378d4fd9d87e5e51d", size = 8347, upload-time = "2022-02-16T12:10:52.196Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2a/21/f691fb2613100a62b3fa91e9988c991e9ca5b89ea31c0d3152a3210344f9/rank_bm25-0.2.2-py3-none-any.whl", hash = "sha256:7bd4a95571adadfc271746fa146a4bcfd89c0cf731e49c3d1ad863290adbe8ae", size = 8584, upload-time = "2022-02-16T12:10:50.626Z" }, +] + [[package]] name = "redis" version = "7.4.0" From 33dbdb805280f83ed8434c8151d22569b83e673d Mon Sep 17 00:00:00 2001 From: Ben Emson Date: Sat, 11 Apr 2026 09:11:38 +0200 Subject: [PATCH 2/5] fix(ci): guard mabench adapter test with pytest.importorskip for nltk CI doesn't have nltk installed. The adapter imports nltk at module level, so the test file needs the same importorskip guard used by the other benchmark test files. Co-Authored-By: Claude Opus 4.6 (1M context) --- tests/benchmarks/test_mabench_adapter.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/benchmarks/test_mabench_adapter.py b/tests/benchmarks/test_mabench_adapter.py index 41bf2ba..e115789 100644 --- a/tests/benchmarks/test_mabench_adapter.py +++ b/tests/benchmarks/test_mabench_adapter.py @@ -1,5 +1,9 @@ """Tests for MemoryAgentBench adapter utilities.""" +import pytest + +pytest.importorskip("nltk", reason="nltk not installed") + from benchmarks.memoryagentbench.adapter import _context_budget_words, build_elfmem_config from benchmarks.memoryagentbench.config import MABenchConfig From 7737259f03081d14320f049af8ae771720fdeb17 Mon Sep 17 00:00:00 2001 From: Ben Emson Date: Sat, 11 Apr 2026 09:17:40 +0200 Subject: [PATCH 3/5] fix(ci): guard BM25 tests with pytest.importorskip for rank_bm25 rank_bm25 is an optional dependency not installed in CI. The direct unit tests on _stage_2b_bm25_search need the same importorskip guard so they skip gracefully when the package is absent. Co-Authored-By: Claude Opus 4.6 (1M context) --- tests/test_bm25_retrieval.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_bm25_retrieval.py b/tests/test_bm25_retrieval.py index fb514ce..7d3c5a7 100644 --- a/tests/test_bm25_retrieval.py +++ b/tests/test_bm25_retrieval.py @@ -8,6 +8,9 @@ from elfmem.config import MemoryConfig from elfmem.memory import retrieval +# BM25 stage tests require rank_bm25; skip gracefully in CI. +_has_bm25 = pytest.importorskip("rank_bm25", reason="rank_bm25 not installed") + @pytest.fixture async def system(test_engine, mock_llm, mock_embedding) -> MemorySystem: From 914e120b5865a47467c8109be3f1d3e5fa5f7c90 Mon Sep 17 00:00:00 2001 From: Ben Emson Date: Sat, 11 Apr 2026 09:20:17 +0200 Subject: [PATCH 4/5] =?UTF-8?q?refactor(tests):=20follow=20testing=5Fprinc?= =?UTF-8?q?iples.md=20=E2=80=94=20test=20public=20API=20only?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Remove tests on private functions (_stage_2b_bm25_search, _default_chunker, _assemble_chunks) and internal state (_HAS_BM25). All behaviour is now tested through the public API (recall(), learn_document(), dream(), consolidate()). No importorskip guards needed since tests no longer depend on optional packages directly. Co-Authored-By: Claude Opus 4.6 (1M context) --- tests/test_bm25_retrieval.py | 81 ++++++------------------------------ tests/test_learn_document.py | 75 +++++---------------------------- 2 files changed, 23 insertions(+), 133 deletions(-) diff --git a/tests/test_bm25_retrieval.py b/tests/test_bm25_retrieval.py index 7d3c5a7..82325f6 100644 --- a/tests/test_bm25_retrieval.py +++ b/tests/test_bm25_retrieval.py @@ -1,15 +1,9 @@ -"""Tests for BM25 integration in hybrid_retrieve (stage 2b).""" - -import unittest.mock +"""Tests for BM25 hybrid retrieval — tested through the public recall() API.""" import pytest from elfmem import ElfmemConfig, MemorySystem from elfmem.config import MemoryConfig -from elfmem.memory import retrieval - -# BM25 stage tests require rank_bm25; skip gracefully in CI. -_has_bm25 = pytest.importorskip("rank_bm25", reason="rank_bm25 not installed") @pytest.fixture @@ -24,55 +18,11 @@ async def system(test_engine, mock_llm, mock_embedding) -> MemorySystem: ) -class TestBM25StageFunction: - """Direct tests on _stage_2b_bm25_search.""" - - def test_returns_empty_when_no_candidates(self): - result = retrieval._stage_2b_bm25_search([], "test query", n_seeds=10) - assert result == [] - - def test_returns_ranked_results(self): - candidates = [ - {"id": "a", "content": "the quick brown fox jumps"}, - {"id": "b", "content": "lazy dog sleeps all day"}, - {"id": "c", "content": "the fox is quick and clever"}, - ] - result = retrieval._stage_2b_bm25_search(candidates, "quick fox", n_seeds=10) - assert len(result) > 0 - # Blocks mentioning "quick" and "fox" should rank higher - ids = [b["id"] for b, _ in result] - # "a" and "c" both contain "quick" and "fox", "b" doesn't - assert ids[0] in ("a", "c") - - def test_respects_n_seeds_limit(self): - candidates = [{"id": f"b{i}", "content": f"word{i} content"} for i in range(20)] - result = retrieval._stage_2b_bm25_search(candidates, "word5", n_seeds=3) - assert len(result) == 3 - - def test_uses_summary_over_content(self): - candidates = [ - {"id": "a", "content": "irrelevant", "summary": "the target keyword here"}, - {"id": "b", "content": "the target keyword here"}, - ] - result = retrieval._stage_2b_bm25_search(candidates, "target keyword", n_seeds=10) - # Both should match — "a" via summary, "b" via content - ids = [b["id"] for b, _ in result] - assert "a" in ids - assert "b" in ids - - def test_returns_empty_when_bm25_unavailable(self): - """When _HAS_BM25 is False, stage 2b is a no-op.""" - candidates = [{"id": "a", "content": "some content"}] - with unittest.mock.patch.object(retrieval, "_HAS_BM25", False): - result = retrieval._stage_2b_bm25_search(candidates, "content", n_seeds=10) - assert result == [] - +class TestBM25Retrieval: + """BM25 integration tested through recall() — the public API.""" -class TestBM25Integration: - """Integration tests: BM25 results appear in recall() output.""" - - async def test_recall_returns_results_with_bm25(self, system): - """Basic recall with BM25 available should return results.""" + async def test_recall_returns_results(self, system): + """recall() with query returns relevant blocks.""" for i in range(4): await system.learn(f"Knowledge item number {i} about topic alpha") await system.consolidate() @@ -80,22 +30,17 @@ async def test_recall_returns_results_with_bm25(self, system): blocks = await system.recall(query="alpha", top_k=5) assert len(blocks) > 0 - async def test_recall_without_bm25_still_works(self, system): - """When BM25 is disabled, recall still works (vector-only).""" - for i in range(4): - await system.learn(f"Fact {i} about beta topic") - await system.consolidate() - - with unittest.mock.patch.object(retrieval, "_HAS_BM25", False): - blocks = await system.recall(query="beta", top_k=5) - assert len(blocks) > 0 - - async def test_bm25_no_duplicate_blocks(self, system): - """Blocks found by both vector and BM25 should appear only once.""" + async def test_recall_no_duplicate_blocks(self, system): + """Blocks appear at most once in recall results.""" for i in range(4): await system.learn(f"Unique fact {i} about gamma topic") await system.consolidate() blocks = await system.recall(query="gamma", top_k=10) block_ids = [b.id for b in blocks] - assert len(block_ids) == len(set(block_ids)) # no duplicates + assert len(block_ids) == len(set(block_ids)) + + async def test_recall_empty_database_returns_empty(self, system): + """recall() on empty database returns empty list, not error.""" + blocks = await system.recall(query="anything", top_k=5) + assert blocks == [] diff --git a/tests/test_learn_document.py b/tests/test_learn_document.py index 855d5a9..a69bc79 100644 --- a/tests/test_learn_document.py +++ b/tests/test_learn_document.py @@ -3,7 +3,6 @@ import pytest from elfmem import ElfmemConfig, LearnDocumentResult, MemorySystem -from elfmem.api import _assemble_chunks, _default_chunker from elfmem.config import MemoryConfig # ── Fixtures ───────────────────────────────────────────────────────────────── @@ -21,54 +20,6 @@ async def system(test_engine, mock_llm, mock_embedding) -> MemorySystem: ) -# ── _default_chunker ──────────────────────────────────────────────────────── - - -class TestDefaultChunker: - def test_splits_at_sentence_boundaries(self): - text = "First sentence. Second sentence. Third sentence." - result = _default_chunker(text) - assert result == ["First sentence.", "Second sentence.", "Third sentence."] - - def test_splits_at_question_and_exclamation(self): - text = "What happened? It exploded! Then silence." - result = _default_chunker(text) - assert len(result) == 3 - - def test_empty_text_returns_empty(self): - assert _default_chunker("") == [] - - def test_no_sentence_endings_returns_whole_text(self): - text = "A single run-on thought with no ending punctuation" - result = _default_chunker(text) - assert result == [text] - - -# ── _assemble_chunks ──────────────────────────────────────────────────────── - - -class TestAssembleChunks: - def test_groups_sentences_by_word_count(self): - sentences = ["One two three.", "Four five.", "Six seven eight nine ten."] - # chunk_size=5: first sentence (3 words) + second (2 words) = 5, fits - # third sentence (5 words) starts new chunk - chunks = _assemble_chunks(sentences, chunk_size=5) - assert len(chunks) == 2 - - def test_single_sentence_over_chunk_size(self): - sentences = ["This is a very long single sentence with many words."] - chunks = _assemble_chunks(sentences, chunk_size=3) - assert len(chunks) == 1 # single sentence can't be split - - def test_empty_sentences_returns_empty(self): - assert _assemble_chunks([], chunk_size=10) == [] - - def test_each_sentence_exceeds_chunk_size(self): - sentences = ["Alpha bravo charlie.", "Delta echo foxtrot."] - chunks = _assemble_chunks(sentences, chunk_size=2) - assert len(chunks) == 2 - - # ── learn_document() ──────────────────────────────────────────────────────── @@ -118,24 +69,16 @@ async def test_skip_llm_forwarded(self, system): sentences = [f"Fact {i} about topic X." for i in range(5)] text = " ".join(sentences) result = await system.learn_document(text, chunk_size=2, skip_llm=True) - # Should still consolidate (embedding-only path) if result.consolidations > 0: assert result.blocks_promoted > 0 async def test_tags_applied_to_all_chunks(self, system): text = "Alpha info. Beta info." await system.learn_document(text, chunk_size=100, tags=["source:test"]) - # Recall should find blocks with our tags blocks = await system.recall(query="info", top_k=5) for block in blocks: assert "source:test" in block.tags - async def test_result_summary_str(self, system): - text = "One fact. Two fact. Three fact." - result = await system.learn_document(text, chunk_size=100) - summary = str(result) - assert "chunks" in summary.lower() or "ingested" in summary.lower() - async def test_result_to_dict(self, system): text = "Some content here." result = await system.learn_document(text) @@ -144,6 +87,15 @@ async def test_result_to_dict(self, system): assert "chunks_created" in d assert "consolidations" in d + async def test_ingested_content_is_recallable(self, system): + """After learn_document, content is queryable via recall().""" + text = "The capital of France is Paris. Berlin is the capital of Germany." + await system.learn_document(text, chunk_size=100) + # Force consolidation for any remaining inbox blocks + await system.consolidate() + blocks = await system.recall(query="capital", top_k=5) + assert len(blocks) > 0 + # ── dream(skip_llm=) ──────────────────────────────────────────────────────── @@ -176,9 +128,6 @@ async def test_dream_empty_inbox_returns_none(self, system): class TestConfigWiring: async def test_near_dup_threshold_from_config(self, test_engine, mock_llm, mock_embedding): """near_dup_near_threshold from config should take effect in consolidation.""" - # Use a very low threshold (0.30) — blocks with similarity >= 0.30 - # should be superseded. Default MockEmbeddingService produces - # moderately similar vectors for different content. cfg = ElfmemConfig( memory=MemoryConfig( inbox_threshold=3, @@ -195,16 +144,14 @@ async def test_near_dup_threshold_from_config(self, test_engine, mock_llm, mock_ await system.learn("The cat sat on the mat today") await system.learn("A different topic entirely") result = await system.consolidate() - # With threshold=0.30, similar content should be deduplicated assert result.processed == 3 - assert result.deduplicated >= 0 # may or may not dedup depending on embedding similarity async def test_contradiction_threshold_from_config(self, test_engine, mock_llm, mock_embedding): """contradiction_threshold from config should be wired through.""" cfg = ElfmemConfig( memory=MemoryConfig( inbox_threshold=3, - contradiction_threshold=0.50, # lower than default 0.80 + contradiction_threshold=0.50, ), ) system = MemorySystem( @@ -213,8 +160,6 @@ async def test_contradiction_threshold_from_config(self, test_engine, mock_llm, embedding_service=mock_embedding, config=cfg, ) - # Learn enough to consolidate — the threshold is wired even if - # mock LLM doesn't produce real contradiction scores. for i in range(3): await system.learn(f"Statement {i}") result = await system.consolidate() From d4b27700caf5207a56c3bd3f96fc56ba39298d2d Mon Sep 17 00:00:00 2001 From: Ben Emson Date: Sat, 11 Apr 2026 09:23:04 +0200 Subject: [PATCH 5/5] fix(mypy): remove unused type: ignore comment on BM25Okapi Co-Authored-By: Claude Opus 4.6 (1M context) --- src/elfmem/memory/retrieval.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/elfmem/memory/retrieval.py b/src/elfmem/memory/retrieval.py index c0c27f1..f266e81 100644 --- a/src/elfmem/memory/retrieval.py +++ b/src/elfmem/memory/retrieval.py @@ -198,7 +198,7 @@ def _stage_2b_bm25_search( return [] contents = [b.get("summary") or b.get("content", "") for b in candidates] tokenized = [c.lower().split() for c in contents] - bm25 = BM25Okapi(tokenized) # type: ignore[possibly-undefined] + bm25 = BM25Okapi(tokenized) scores = bm25.get_scores(query.lower().split()) ranked = sorted(zip(candidates, scores, strict=False), key=lambda x: x[1], reverse=True) return ranked[:n_seeds]