From 2446a3c2a5423720dfaccd706a82649c983a3de0 Mon Sep 17 00:00:00 2001 From: coe0718 Date: Mon, 30 Mar 2026 20:46:24 -0400 Subject: [PATCH] Improve episodic memory ranking with reinforcement and decay --- docs/memory.md | 5 +- src/memory/__tests__/context-builder.test.ts | 55 ++++++++++ src/memory/__tests__/episodic.test.ts | 62 ++++++++++++ src/memory/__tests__/ranking.test.ts | 78 ++++++++++++++ src/memory/context-builder.ts | 12 ++- src/memory/episodic.ts | 43 ++++---- src/memory/ranking.ts | 101 +++++++++++++++++++ 7 files changed, 326 insertions(+), 30 deletions(-) create mode 100644 src/memory/__tests__/ranking.test.ts create mode 100644 src/memory/ranking.ts diff --git a/docs/memory.md b/docs/memory.md index 5ac6a74..70e42e8 100644 --- a/docs/memory.md +++ b/docs/memory.md @@ -17,6 +17,8 @@ Session transcripts stored as embeddings. Each episode contains: Search: "What happened last time I worked on the auth service?" +Episode ranking is not raw vector score alone. Retrieval blends semantic match with importance, reinforcement from repeated access, and decay over time so durable memories stay available while stale one-off memories fade. + ### Tier 2: Semantic Memory Accumulated facts with contradiction detection and temporal validity: @@ -50,7 +52,8 @@ Before each agent invocation, the context builder: 3. Searches semantic memory (top 20 facts) 4. Searches procedural memory (top 5 procedures) 5. Budgets results to fit within the token limit (default: 50,000 tokens) -6. Formats results into the memory section of the system prompt +6. Filters out stale, low-signal episodic memories before prompt injection +7. 
Formats results into the memory section of the system prompt ## Consolidation diff --git a/src/memory/__tests__/context-builder.test.ts b/src/memory/__tests__/context-builder.test.ts index d2ce1f4..13300c1 100644 --- a/src/memory/__tests__/context-builder.test.ts +++ b/src/memory/__tests__/context-builder.test.ts @@ -124,6 +124,61 @@ describe("MemoryContextBuilder", () => { expect(result).toContain("success"); }); + test("filters stale low-signal episodes from prompt context", async () => { + const memory = createMockMemorySystem({ + episodes: Promise.resolve([ + { + id: "stale-ep", + type: "task" as const, + summary: "One-off stale note", + detail: "No longer important", + parent_id: null, + session_id: "s1", + user_id: "u1", + tools_used: [], + files_touched: [], + outcome: "success" as const, + outcome_detail: "", + lessons: [], + started_at: new Date(Date.now() - 90 * 24 * 3600 * 1000).toISOString(), + ended_at: new Date(Date.now() - 90 * 24 * 3600 * 1000).toISOString(), + duration_seconds: 300, + importance: 0.2, + access_count: 0, + last_accessed_at: new Date(Date.now() - 90 * 24 * 3600 * 1000).toISOString(), + decay_rate: 1.0, + }, + { + id: "durable-ep", + type: "task" as const, + summary: "Repeated deployment pattern", + detail: "Still referenced often", + parent_id: null, + session_id: "s2", + user_id: "u1", + tools_used: ["Bash"], + files_touched: [], + outcome: "success" as const, + outcome_detail: "", + lessons: [], + started_at: new Date(Date.now() - 45 * 24 * 3600 * 1000).toISOString(), + ended_at: new Date(Date.now() - 45 * 24 * 3600 * 1000).toISOString(), + duration_seconds: 300, + importance: 0.8, + access_count: 4, + last_accessed_at: new Date(Date.now() - 24 * 3600 * 1000).toISOString(), + decay_rate: 1.0, + }, + ]), + }); + + const builder = new MemoryContextBuilder(memory, TEST_CONFIG); + const result = await builder.build("deployment"); + + expect(result).toContain("Repeated deployment pattern"); + expect(result).not.toContain("One-off stale 
note"); + }); + test("formats procedure section correctly", async () => { const memory = createMockMemorySystem({ procedure: Promise.resolve({ diff --git a/src/memory/__tests__/episodic.test.ts b/src/memory/__tests__/episodic.test.ts index c3537e1..1997e21 100644 --- a/src/memory/__tests__/episodic.test.ts +++ b/src/memory/__tests__/episodic.test.ts @@ -234,4 +234,66 @@ describe("EpisodicStore", () => { expect(episodes[0].id).toBe("new-ep"); expect(episodes[1].id).toBe("old-ep"); }); + + test("recall() metadata strategy favors reinforced memories", async () => { + const vec = make768dVector(); + const now = Date.now(); + + globalThis.fetch = mock((url: string | Request) => { + const urlStr = typeof url === "string" ? url : url.url; + + if (urlStr.includes("/api/embed")) { + return Promise.resolve(new Response(JSON.stringify({ embeddings: [vec] }), { status: 200 })); + } + + if (urlStr.includes("/points/query")) { + return Promise.resolve( + new Response( + JSON.stringify({ + result: { + points: [ + { + id: "stale-ep", + score: 0.82, + payload: { + type: "task", + summary: "Stale one-off episode", + importance: 0.3, + access_count: 0, + last_accessed_at: new Date(now - 45 * 24 * 3600 * 1000).toISOString(), + started_at: now - 45 * 24 * 3600 * 1000, + }, + }, + { + id: "durable-ep", + score: 0.7, + payload: { + type: "task", + summary: "Frequently reused deployment memory", + importance: 0.8, + access_count: 6, + last_accessed_at: new Date(now - 2 * 24 * 3600 * 1000).toISOString(), + started_at: now - 45 * 24 * 3600 * 1000, + }, + }, + ], + }, + }), + { status: 200, headers: { "Content-Type": "application/json" } }, + ), + ); + } + + return Promise.resolve(new Response(JSON.stringify({ status: "ok" }), { status: 200 })); + }) as unknown as typeof fetch; + + const qdrant = new QdrantClient(TEST_CONFIG); + const embedder = new EmbeddingClient(TEST_CONFIG); + const store = new EpisodicStore(qdrant, embedder, TEST_CONFIG); + + const episodes = await 
store.recall("deployment", { strategy: "metadata" }); + + expect(episodes[0].id).toBe("durable-ep"); + expect(episodes[1].id).toBe("stale-ep"); + }); }); diff --git a/src/memory/__tests__/ranking.test.ts b/src/memory/__tests__/ranking.test.ts new file mode 100644 index 0000000..126d586 --- /dev/null +++ b/src/memory/__tests__/ranking.test.ts @@ -0,0 +1,78 @@ +import { describe, expect, test } from "bun:test"; +import { calculateEpisodeRecallScore, shouldIncludeEpisodeInContext } from "../ranking.ts"; +import type { Episode } from "../types.ts"; + +function makeEpisode(overrides?: Partial<Episode>): Episode { + return { + id: "ep-1", + type: "task", + summary: "Memory summary", + detail: "Memory detail", + parent_id: null, + session_id: "session-1", + user_id: "user-1", + tools_used: [], + files_touched: [], + outcome: "success", + outcome_detail: "Completed successfully", + lessons: [], + started_at: new Date(Date.now() - 24 * 3600 * 1000).toISOString(), + ended_at: new Date().toISOString(), + duration_seconds: 60, + importance: 0.6, + access_count: 0, + last_accessed_at: new Date().toISOString(), + decay_rate: 1, + ...overrides, + }; +} + +describe("memory ranking", () => { + test("metadata strategy rewards reinforced memories", () => { + const staleWeak = calculateEpisodeRecallScore( + 0.82, + { + importance: 0.3, + accessCount: 0, + startedAt: Date.now() - 45 * 24 * 3600 * 1000, + lastAccessedAt: new Date(Date.now() - 45 * 24 * 3600 * 1000).toISOString(), + decayRate: 1, + }, + "metadata", + ); + + const durableRepeat = calculateEpisodeRecallScore( + 0.7, + { + importance: 0.8, + accessCount: 6, + startedAt: Date.now() - 45 * 24 * 3600 * 1000, + lastAccessedAt: new Date(Date.now() - 2 * 24 * 3600 * 1000).toISOString(), + decayRate: 1, + }, + "metadata", + ); + + expect(durableRepeat).toBeGreaterThan(staleWeak); + }); + + test("context filtering drops stale low-signal memories", () => { + const staleWeak = makeEpisode({ + importance: 0.2, + access_count: 0, + started_at: 
new Date(Date.now() - 60 * 24 * 3600 * 1000).toISOString(), + last_accessed_at: new Date(Date.now() - 60 * 24 * 3600 * 1000).toISOString(), + }); + + const durableRepeat = makeEpisode({ + id: "ep-2", + importance: 0.85, + access_count: 5, + started_at: new Date(Date.now() - 60 * 24 * 3600 * 1000).toISOString(), + last_accessed_at: new Date(Date.now() - 24 * 3600 * 1000).toISOString(), + }); + + expect(shouldIncludeEpisodeInContext(staleWeak)).toBe(false); + expect(shouldIncludeEpisodeInContext(durableRepeat)).toBe(true); + }); +}); diff --git a/src/memory/context-builder.ts b/src/memory/context-builder.ts index c0c9b12..00f5381 100644 --- a/src/memory/context-builder.ts +++ b/src/memory/context-builder.ts @@ -1,4 +1,5 @@ import type { MemoryConfig } from "../config/types.ts"; +import { shouldIncludeEpisodeInContext } from "./ranking.ts"; import type { MemorySystem } from "./system.ts"; import type { Episode, Procedure, SemanticFact } from "./types.ts"; @@ -44,10 +45,13 @@ export class MemoryContextBuilder { // Recent memories provide episode context if (episodes.length > 0 && tokenBudget > 500) { - const episodeSection = this.formatEpisodes(episodes, tokenBudget); + const durableEpisodes = episodes.filter(shouldIncludeEpisodeInContext); + const episodeSection = this.formatEpisodes(durableEpisodes, tokenBudget); const episodeTokens = this.estimateTokens(episodeSection); - sections.push(episodeSection); - tokenBudget -= episodeTokens; + if (episodeSection) { + sections.push(episodeSection); + tokenBudget -= episodeTokens; + } } // Relevant procedures @@ -70,6 +74,8 @@ export class MemoryContextBuilder { } private formatEpisodes(episodes: Episode[], tokenBudget: number): string { + if (episodes.length === 0) return ""; + const header = "## Recent Memories\n"; let content = header; const maxChars = tokenBudget * CHARS_PER_TOKEN; diff --git a/src/memory/episodic.ts b/src/memory/episodic.ts index 0c5b1e5..64c0674 100644 --- a/src/memory/episodic.ts +++ 
b/src/memory/episodic.ts @@ -1,6 +1,7 @@ import type { MemoryConfig } from "../config/types.ts"; import { type EmbeddingClient, textToSparseVector } from "./embeddings.ts"; import type { QdrantClient } from "./qdrant-client.ts"; +import { calculateEpisodeRecallScore } from "./ranking.ts"; import type { Episode, QdrantSearchResult, RecallOptions } from "./types.ts"; const COLLECTION_SCHEMA = { @@ -128,6 +129,7 @@ export class EpisodicStore { for (const id of ids) { try { await this.qdrant.updatePayload(this.collectionName, id, { + access_count: { $inc: 1 }, last_accessed_at: new Date().toISOString(), }); } catch { @@ -165,34 +167,23 @@ export class EpisodicStore { return { must }; } - private applyStrategy(results: QdrantSearchResult[], strategy: string): QdrantSearchResult[] { - const now = Date.now(); - + private applyStrategy(results: QdrantSearchResult[], strategy: RecallOptions["strategy"]): QdrantSearchResult[] { return results .map((r) => { - const startedAt = (r.payload.started_at as number) ?? 0; - const importance = (r.payload.importance as number) ?? 0.5; - const hoursSince = (now - startedAt) / (1000 * 60 * 60); - const recencyScore = Math.exp(-0.01 * hoursSince); - - let finalScore: number; - switch (strategy) { - case "similarity": - finalScore = r.score * 0.7 + importance * 0.2 + recencyScore * 0.1; - break; - case "temporal": - finalScore = recencyScore * 0.7 + r.score * 0.2 + importance * 0.1; - break; - case "metadata": - finalScore = r.score * 0.5 + recencyScore * 0.3 + importance * 0.2; - break; - default: - // recency-biased (default) - finalScore = r.score * 0.4 + recencyScore * 0.4 + importance * 0.2; - break; - } - - return { ...r, score: finalScore }; + return { + ...r, + score: calculateEpisodeRecallScore( + r.score, + { + importance: (r.payload.importance as number) ?? 0.5, + accessCount: (r.payload.access_count as number) ?? 0, + startedAt: (r.payload.started_at as number) ?? 
0, + lastAccessedAt: (r.payload.last_accessed_at as string | undefined) ?? undefined, + decayRate: (r.payload.decay_rate as number) ?? 1, + }, + strategy, + ), + }; }) .sort((a, b) => b.score - a.score); } diff --git a/src/memory/ranking.ts b/src/memory/ranking.ts new file mode 100644 index 0000000..9b02c26 --- /dev/null +++ b/src/memory/ranking.ts @@ -0,0 +1,101 @@ +import type { Episode, RecallOptions } from "./types.ts"; + +type EpisodeRankingMetadata = { + importance?: number; + accessCount?: number; + startedAt?: number | string; + lastAccessedAt?: string; + decayRate?: number; +}; + +const MIN_DECAY_RATE = 0.25; +const MAX_DECAY_RATE = 3; +const RECENCY_HALF_LIFE_HOURS = 24 * 14; +const ACCESS_HALF_LIFE_HOURS = 24 * 21; +const ACCESS_SATURATION = Math.log1p(8); +const CONTEXT_SCORE_THRESHOLD = 0.25; + +export function calculateEpisodeRecallScore( + searchScore: number, + metadata: EpisodeRankingMetadata, + strategy: RecallOptions["strategy"] = "recency", +): number { + const signals = getEpisodeSignals(metadata); + + switch (strategy) { + case "similarity": + return weightedAverage(searchScore, signals.durability, signals.recency, 0.55, 0.3, 0.15); + case "temporal": + return weightedAverage(searchScore, signals.durability, signals.recency, 0.25, 0.2, 0.55); + case "metadata": + return weightedAverage(searchScore, signals.durability, signals.recency, 0.2, 0.6, 0.2); + default: + return weightedAverage(searchScore, signals.durability, signals.recency, 0.3, 0.3, 0.4); + } +} + +export function calculateEpisodeContextScore(episode: Episode): number { + const signals = getEpisodeSignals({ + importance: episode.importance, + accessCount: episode.access_count, + startedAt: episode.started_at, + lastAccessedAt: episode.last_accessed_at, + decayRate: episode.decay_rate, + }); + + return weightedAverage(signals.durability, signals.recency, 0, 0.6, 0.4, 0); +} + +export function shouldIncludeEpisodeInContext(episode: Episode): boolean { + if (episode.importance >= 
0.85) return true; + if (episode.access_count >= 3) return true; + + return calculateEpisodeContextScore(episode) >= CONTEXT_SCORE_THRESHOLD; +} + +function getEpisodeSignals(metadata: EpisodeRankingMetadata): { durability: number; recency: number } { + const importance = clamp(metadata.importance ?? 0.5, 0, 1); + const accessCount = Math.max(0, metadata.accessCount ?? 0); + const decayRate = clamp(metadata.decayRate ?? 1, MIN_DECAY_RATE, MAX_DECAY_RATE); + const ageHours = hoursSince(metadata.startedAt); + const lastAccessHours = metadata.lastAccessedAt ? hoursSince(metadata.lastAccessedAt) : Number.POSITIVE_INFINITY; + + const recency = exponentialDecay(ageHours, RECENCY_HALF_LIFE_HOURS, decayRate); + const accessFreshness = + lastAccessHours === Number.POSITIVE_INFINITY + ? 0 + : exponentialDecay(lastAccessHours, ACCESS_HALF_LIFE_HOURS, decayRate); + const accessReinforcement = clamp(Math.log1p(accessCount) / ACCESS_SATURATION, 0, 1); + const durability = weightedAverage(importance, accessReinforcement, accessFreshness, 0.55, 0.3, 0.15); + + return { durability, recency }; +} + +function weightedAverage(a: number, b: number, c: number, aWeight: number, bWeight: number, cWeight: number): number { + return a * aWeight + b * bWeight + c * cWeight; +} + +function exponentialDecay(ageHours: number, halfLifeHours: number, decayRate: number): number { + if (!Number.isFinite(ageHours) || ageHours < 0) return 1; + return Math.exp(-((ageHours / halfLifeHours) * decayRate)); +} + +function hoursSince(value?: number | string): number { + if (value == null) return Number.POSITIVE_INFINITY; + + const timestamp = + typeof value === "number" + ? value + : (() => { + const parsed = Date.parse(value); + return Number.isNaN(parsed) ? 
Number.NaN : parsed; + })(); + + if (!Number.isFinite(timestamp)) return Number.POSITIVE_INFINITY; + + return Math.max(0, (Date.now() - timestamp) / (1000 * 60 * 60)); +} + +function clamp(value: number, min: number, max: number): number { + return Math.min(Math.max(value, min), max); +}