diff --git a/agent.py b/agent.py index dba7fe2..fd31162 100644 --- a/agent.py +++ b/agent.py @@ -12,6 +12,7 @@ import tools as _tools_init # ensure built-in tools are registered on import from providers import stream, AssistantTurn, TextChunk, ThinkingChunk, detect_provider from compaction import maybe_compact, estimate_tokens, get_context_limit, compact_messages +from context_gc import GCState import logging_utils as _log import quota as _quota from circuit_breaker import CircuitOpenError as _CircuitOpenError @@ -32,6 +33,11 @@ class AgentState: total_input_tokens: int = 0 total_output_tokens: int = 0 turn_count: int = 0 + # Persisted so trashed_ids, snippets and notes survive /save and /load. + # Without this, restoring a session leaks back tool_results the model had trashed. + gc_state: GCState = field(default_factory=GCState) + # Timeline of note changes (for debugging / replay) + notes_timeline: list = field(default_factory=list) @dataclass @@ -85,8 +91,11 @@ def run( user_msg["images"] = [pending_img] state.messages.append(user_msg) - # Inject runtime metadata into config so tools (e.g. Agent) can access it - config = {**config, "_depth": depth, "_system_prompt": system_prompt} + # Inject runtime metadata into config so tools (e.g. Agent, ContextGC) can access it. + # ContextGC reads and mutates config["_gc_state"]; without this binding, every call + # returns "Error: no GC state available" and no trashed_id is ever recorded. 
def _serialize_gc_state(gc_state) -> dict:
    """Produce a JSON-safe dict view of ContextGC state.

    trashed_ids is emitted sorted so repeated saves are byte-stable; a stable
    serialization keeps trashed tool_results from leaking back into the
    model's context after a /load.
    """
    if gc_state is None:
        trashed, snippets, notes = [], {}, {}
    else:
        trashed = sorted(gc_state.trashed_ids)
        snippets = dict(gc_state.snippets)
        notes = dict(gc_state.notes)
    return {"trashed_ids": trashed, "snippets": snippets, "notes": notes}


def _restore_state_from_data(state, data: dict) -> None:
    """In-place inverse of _build_session_data: apply a loaded dict onto state.

    Single point of truth shared by /load, /resume and /cloudsave load so the
    save/restore roundtrip covers the full AgentState surface, including
    gc_state.
    """
    from context_gc import GCState

    state.messages = data.get("messages", [])
    state.turn_count = data.get("turn_count", 0)
    state.total_input_tokens = data.get("total_input_tokens", 0)
    state.total_output_tokens = data.get("total_output_tokens", 0)

    raw = data.get("gc_state") or {}
    state.gc_state = GCState(
        trashed_ids=set(raw.get("trashed_ids") or []),
        snippets=dict(raw.get("snippets") or {}),
        notes=dict(raw.get("notes") or {}),
    )
"""Model-driven context garbage collection for conversation history.

The LLM can trash consumed tool results, keep anchored snippets, and persist
named notes across turns to manage its own context window. This module holds
the shared GCState plus the ContextGC / NoteSave / NoteRead tool handlers and
the message transforms applied before each API call.
"""
from __future__ import annotations

import re
import time
from dataclasses import dataclass, field


# ── Constants ────────────────────────────────────────────────────────────────

# This note name can never be removed via trash_notes.
METHODOLOGY_NOTE = "methodology"


# ── Stub detection ───────────────────────────────────────────────────────────

# NOTE(review): this pattern matches everywhere (r'\s*'); it looks like an
# XML/markup literal was stripped from the source during extraction — confirm
# the intended elided-tag pattern upstream before relying on it.
_ELIDED_RE = re.compile(r'\s*')

# Matches only stubs emitted by apply_gc (model-driven trash) or auto-trash.
# Deliberately does NOT match breadcrumbs from compact_tool_history — those
# must survive so the model retains a trace of prior tool calls.
_TRASHED_STUB_RE = re.compile(r'^\[.{1,60} -- (?:trashed by model|auto-trashed)\]$')
_AUTO_TRASHED_RE = re.compile(r'^\[.{1,60} -- auto-trashed\]$')


def _is_stub(content: str) -> bool:
    """True for any GC stub (model-trashed OR auto-trashed).

    Used by the verbatim audit to skip all stubs.
    """
    if not content or len(content) > 200:
        return False
    text = content.strip()
    return bool(text) and bool(_TRASHED_STUB_RE.match(text))


def _is_auto_trashed_stub(content: str) -> bool:
    """True only for auto-trashed stubs (ContextGC's own tool results)."""
    if not content or len(content) > 200:
        return False
    text = content.strip()
    return bool(text) and bool(_AUTO_TRASHED_RE.match(text))


# ── GCState ──────────────────────────────────────────────────────────────────

@dataclass
class GCState:
    # tool_call ids whose results get replaced by one-line stubs
    trashed_ids: set = field(default_factory=set)
    # tool_call id -> snippet spec (keep_after / keep_before / keep_between)
    snippets: dict = field(default_factory=dict)
    # note name -> note content (the model's working memory)
    notes: dict = field(default_factory=dict)
    # once enabled, past assistant XML is compacted before each call
    compact_xml: bool = False


# ── Process ContextGC tool call ──────────────────────────────────────────────

def _record_notes_delta(gc_state, snapshot: dict, config: dict) -> None:
    """Append an entry to state.notes_timeline when the notes dict changed."""
    added = [k for k in gc_state.notes if k not in snapshot]
    updated = [k for k in gc_state.notes if k in snapshot and gc_state.notes[k] != snapshot[k]]
    removed = [k for k in snapshot if k not in gc_state.notes]
    if not (added or updated or removed):
        return
    state = config.get("_state")
    if state is None or not hasattr(state, "notes_timeline"):
        return
    state.notes_timeline.append({
        "turn": getattr(state, "turn_count", 0),
        "timestamp": time.time(),
        "notes": dict(gc_state.notes),
        "delta": {"added": added, "updated": updated, "removed": removed},
    })


def process_gc_call(params: dict, config: dict) -> str:
    """Apply one ContextGC tool call to the shared GCState in *config*.

    Mutates config["_gc_state"] (trash / snippets / notes / compact_xml) and
    returns a short human-readable summary for the tool result.
    """
    gc_state = config.get("_gc_state")
    if gc_state is None:
        return "Error: no GC state available"

    trash_ids = params.get("trash") or []
    keep_snippets = params.get("keep_snippets") or []
    new_notes = params.get("notes") or []
    drop_notes = params.get("trash_notes") or []

    snapshot = dict(gc_state.notes)

    # Trash first: a result trashed in the same call must not keep a snippet.
    for rid in trash_ids:
        gc_state.trashed_ids.add(rid)
        gc_state.snippets.pop(rid, None)

    for spec in keep_snippets:
        sid = spec.get("id")
        if sid and sid not in gc_state.trashed_ids:
            gc_state.snippets[sid] = spec

    for entry in new_notes:
        if entry.get("name"):
            gc_state.notes[entry["name"]] = entry.get("content", "")

    # The methodology note is shielded from deletion.
    shielded = METHODOLOGY_NOTE in drop_notes
    for note_name in drop_notes:
        if note_name != METHODOLOGY_NOTE:
            gc_state.notes.pop(note_name, None)

    if params.get("compact_xml"):
        gc_state.compact_xml = True

    _record_notes_delta(gc_state, snapshot, config)

    summary = []
    if trash_ids:
        summary.append(f"trashed {len(trash_ids)} results")
    if keep_snippets:
        summary.append(f"kept snippets for {len(keep_snippets)} results")
    if new_notes:
        summary.append(f"{len(new_notes)} notes saved")
    if drop_notes:
        removed_count = len(drop_notes) - (1 if shielded else 0)
        if removed_count:
            summary.append(f"{removed_count} notes removed")
        if shielded:
            summary.append(f"note '{METHODOLOGY_NOTE}' protected from trash")
    if params.get("compact_xml"):
        summary.append("XML compaction enabled")
    summary.append(f"{len(gc_state.notes)} active notes, {len(gc_state.trashed_ids)} total trashed")
    return "GC applied: " + ", ".join(summary)
# ── NoteSave / NoteRead ─────────────────────────────────────────────────────

def note_save(params: dict, config: dict) -> str:
    """Create or update one working-memory note; reports created/updated/unchanged."""
    gc_state = config.get("_gc_state")
    if gc_state is None:
        return "Error: no GC state available"

    name = params.get("name", "")
    if not name:
        return "Error: 'name' is required"
    content = params.get("content", "")

    previous = dict(gc_state.notes)
    gc_state.notes[name] = content

    if name not in previous:
        action = "created"
    elif previous[name] != content:
        action = "updated"
    else:
        action = "unchanged"

    # Only real changes are worth a timeline entry.
    if action != "unchanged":
        state = config.get("_state")
        if state is not None and hasattr(state, "notes_timeline"):
            state.notes_timeline.append({
                "turn": getattr(state, "turn_count", 0),
                "timestamp": time.time(),
                "notes": dict(gc_state.notes),
                "delta": {
                    "added": [name] if action == "created" else [],
                    "updated": [name] if action == "updated" else [],
                    "removed": [],
                },
            })

    return f"Note '{name}' {action}. {len(gc_state.notes)} active notes."


def note_read(params: dict, config: dict) -> str:
    """Read one note by name, or render every active note."""
    gc_state = config.get("_gc_state")
    if gc_state is None:
        return "Error: no GC state available"

    name = params.get("name")
    if name:
        content = gc_state.notes.get(name)
        if content is None:
            listing = ", ".join(sorted(gc_state.notes)) or "(none)"
            return f"Note '{name}' not found. Active notes: {listing}"
        return f"## {name}\n{content}"

    if not gc_state.notes:
        return "No active notes."
    return "\n\n".join(f"## {n}\n{c}" for n, c in gc_state.notes.items())


# ── Apply GC (transform messages before API call) ───────────────────────────

def apply_gc(messages: list, gc_state: GCState) -> list:
    """Rewrite *messages* according to the model's GC decisions.

    Trashed tool_results become one-line stubs, snippet-kept results are
    trimmed to their anchors, and (optionally) past assistant XML is
    compacted. Returns the input list unchanged when there is nothing to do.
    """
    if not gc_state.trashed_ids and not gc_state.snippets and not gc_state.compact_xml:
        return messages

    compact_all = None
    compact_selective = None
    final_assistant = -1

    if gc_state.compact_xml:
        from followup_compaction import compact_assistant_xml
        compact_all = compact_assistant_xml
        # The most recent assistant message always stays verbatim.
        for pos in range(len(messages) - 1, -1, -1):
            if messages[pos].get("role") == "assistant":
                final_assistant = pos
                break

    if gc_state.trashed_ids:
        from followup_compaction import compact_assistant_xml_selective
        compact_selective = compact_assistant_xml_selective

    out = []
    for pos, msg in enumerate(messages):
        role = msg.get("role")

        if role == "assistant" and msg.get("tool_calls"):
            if compact_all and pos != final_assistant:
                replacement = dict(msg)
                replacement["content"] = compact_all(msg["content"], msg["tool_calls"])
                out.append(replacement)
                continue
            if compact_selective:
                call_ids = {tc.get("id") for tc in msg["tool_calls"]}
                hit = call_ids & gc_state.trashed_ids
                if hit:
                    replacement = dict(msg)
                    replacement["content"] = compact_selective(
                        msg["content"], msg["tool_calls"], hit,
                    )
                    out.append(replacement)
                    continue
            out.append(msg)
            continue

        if role != "tool":
            out.append(msg)
            continue

        call_id = msg.get("tool_call_id", "")
        if call_id in gc_state.trashed_ids:
            replacement = dict(msg)
            tool_name = msg.get("name", "tool")
            replacement["content"] = f"[{tool_name} result -- trashed by model]"
            out.append(replacement)
        elif call_id in gc_state.snippets:
            replacement = dict(msg)
            replacement["content"] = _apply_snippet(msg["content"], gc_state.snippets[call_id])
            out.append(replacement)
        else:
            out.append(msg)
    return out
def _apply_snippet(content: str, snippet: dict) -> str:
    """Trim a tool result to the anchored region described by *snippet*.

    Supports keep_after / keep_before / keep_between anchor specs; on any
    anchor miss the full content is returned with a GC warning appended.
    """
    if not content:
        return content
    lines = content.split("\n")

    if "keep_after" in snippet:
        anchor = snippet["keep_after"]
        hit = _find_anchor_line(lines, anchor)
        if hit is None:
            return content + f"\n[GC warning: anchor {anchor!r} not found, kept full result]"
        kept = lines[hit:]
        trimmed = len(lines) - len(kept)
        return f"[{trimmed} lines trimmed, kept after {anchor!r}]\n" + "\n".join(kept)

    if "keep_before" in snippet:
        anchor = snippet["keep_before"]
        hit = _find_anchor_line(lines, anchor)
        if hit is None:
            return content + f"\n[GC warning: anchor {anchor!r} not found, kept full result]"
        kept = lines[:hit]
        trimmed = len(lines) - len(kept)
        return "\n".join(kept) + f"\n[{trimmed} lines trimmed at {anchor!r}]"

    if "keep_between" in snippet:
        anchors = snippet["keep_between"]
        if len(anchors) != 2:
            return content + "\n[GC warning: keep_between needs exactly 2 anchors]"
        first, last = anchors
        lo = _find_anchor_line(lines, first)
        if lo is None:
            return content + f"\n[GC warning: start anchor {first!r} not found]"
        # The end anchor is searched at/after the start anchor, never before.
        hi = _find_anchor_line(lines, last, start_from=lo)
        if hi is None:
            return content + f"\n[GC warning: end anchor {last!r} not found]"
        kept = lines[lo:hi + 1]
        head = f"[{lo} lines trimmed before {first!r}]"
        tail = f"[{len(lines) - hi - 1} lines trimmed after {last!r}]"
        return head + "\n" + "\n".join(kept) + "\n" + tail

    # No recognized anchor key: leave the content untouched.
    return content


def _find_anchor_line(lines: list, text: str, start_from: int = 0) -> int | None:
    """Index of the first line at or after *start_from* containing *text*."""
    return next((i for i in range(start_from, len(lines)) if text in lines[i]), None)


# ── Notes injection ─────────────────────────────────────────────────────────

def inject_notes(messages: list, notes: dict) -> list:
    """Prepend the working-memory notes block to the last user message.

    Returns *messages* itself when there are no notes; otherwise a shallow
    copy in which only the last user message is replaced.
    """
    if not notes:
        return messages
    rendered = "\n\n".join(f"## {name}\n{body}" for name, body in notes.items())
    block = "[Your working memory notes]\n" + rendered + "\n[/Notes]"
    out = list(messages)
    for pos in range(len(out) - 1, -1, -1):
        if out[pos].get("role") == "user":
            patched = dict(out[pos])
            patched["content"] = block + "\n\n" + patched["content"]
            out[pos] = patched
            break
    return out


# ── Verbatim audit ──────────────────────────────────────────────────────────

# Per-tool "most informative" argument to show in audit lines.
_ARGS_PREFERRED_KEY = {
    "Read": "file_path", "Edit": "file_path", "Write": "file_path",
    "NotebookEdit": "notebook_path",
    "Glob": "pattern", "Grep": "pattern",
    "Bash": "command",
    "WebFetch": "url", "WebSearch": "query",
}


def _summarize_args(tool_name: str, input_dict: dict, max_len: int = 60) -> str:
    """One-line summary of a tool call's key argument, truncated to *max_len*."""
    if not input_dict:
        return ""
    value = input_dict.get(_ARGS_PREFERRED_KEY.get(tool_name, ""))
    if value is None:
        # Fall back to the first non-empty string argument.
        value = next((v for v in input_dict.values() if isinstance(v, str) and v), None)
    if value is None:
        return ""
    flat = str(value).replace("\n", " ")
    return flat if len(flat) <= max_len else flat[: max_len - 3] + "..."
+ """ + from compaction import estimate_tokens + args_by_id: dict[str, dict] = {} + for message in messages: + if message.get("role") != "assistant": + continue + for tc in message.get("tool_calls") or []: + tc_id = tc.get("id") + if tc_id: + args_by_id[tc_id] = tc.get("input") or {} + lines = [] + for message in messages: + if message.get("role") != "tool": + continue + content = message.get("content", "") + if isinstance(content, list): + content = "".join( + block.get("text", "") if isinstance(block, dict) else str(block) + for block in content + ) + if _is_stub(content): + continue + tool_call_id = message.get("tool_call_id", "?") + tool_name = message.get("name", "?") + size = estimate_tokens([{"content": content}]) + args = _summarize_args(tool_name, args_by_id.get(tool_call_id, {})) + suffix = f" {args}" if args else "" + lines.append(f"- {tool_call_id} ({tool_name}{suffix}): {size} tk") + if not lines: + return "" + return ( + "[Verbatim tool_results still in your context -- trash any you've already consumed]\n" + + "\n".join(lines) + + "\n[/Verbatim audit]" + ) + + +def prepend_verbatim_audit(messages: list) -> list: + """Prepend the verbatim audit note to the last user message.""" + note = build_verbatim_audit_note(messages) + if not note: + return messages + result = list(messages) + for i in range(len(result) - 1, -1, -1): + if result[i].get("role") == "user": + result[i] = dict(result[i]) + result[i]["content"] = note + "\n\n" + result[i]["content"] + break + return result + + +# ── Strip auto-trashed stubs ───────────────────────────────────────────── + +def strip_trashed_stubs(messages: list) -> list: + """Remove auto-trashed tool messages and their tool_call entries entirely.""" + stubbed_ids = set() + for msg in messages: + if msg.get("role") == "tool": + content = msg.get("content", "") + if _is_auto_trashed_stub(content): + tc_id = msg.get("tool_call_id", "") + if tc_id: + stubbed_ids.add(tc_id) + if not stubbed_ids: + return messages + result = 
[] + for msg in messages: + role = msg.get("role") + if role == "tool" and msg.get("tool_call_id", "") in stubbed_ids: + continue + if role == "assistant" and msg.get("tool_calls"): + original_tcs = msg["tool_calls"] + remaining = [tc for tc in original_tcs if tc.get("id") not in stubbed_ids] + if len(remaining) == len(original_tcs): + result.append(msg) + continue + cleaned = dict(msg) + content = cleaned.get("content", "") or "" + if not remaining: + content = _ELIDED_RE.sub("", content).strip() + cleaned.pop("tool_calls", None) + else: + cleaned["tool_calls"] = remaining + cleaned["content"] = content + result.append(cleaned) + continue + result.append(msg) + return result diff --git a/followup_compaction.py b/followup_compaction.py new file mode 100644 index 0000000..a5eb61d --- /dev/null +++ b/followup_compaction.py @@ -0,0 +1,162 @@ +"""Follow-up compaction: stub past-turn tool_results before each API call. + +Non-destructive: produces a new message list, leaves `state.messages` intact +so persistence and resume keep the full history. +""" +from __future__ import annotations + +import html +import json +import time +from typing import Iterable + +DEFAULT_EXEMPT_TOOLS = frozenset({"Edit", "Write", "TodoWrite"}) + + +def compact_tool_history( + messages: list, + keep_last_n_turns: int = 0, + exempt_tools: Iterable[str] = DEFAULT_EXEMPT_TOOLS, +) -> list: + """Return a NEW list where past-turn tool_result contents are replaced by stubs. + + A "turn" begins at a role='user' message. The current turn (from the last + user message onward) is always kept intact. 
+ """ + exempt = frozenset(exempt_tools) + user_indices = [i for i, m in enumerate(messages) if m.get("role") == "user"] + if len(user_indices) <= keep_last_n_turns + 1: + return list(messages) + + cutoff = user_indices[-(keep_last_n_turns + 1)] + tool_call_lookup = _build_tool_call_lookup(messages) + + compacted = [] + for index, message in enumerate(messages): + if index >= cutoff: + compacted.append(message) + continue + role = message.get("role") + if role != "tool" or message.get("name") in exempt: + compacted.append(message) + continue + tool_call_id = message.get("tool_call_id", "") + name, inp = tool_call_lookup.get( + tool_call_id, (message.get("name", "tool"), {}) + ) + stubbed = dict(message) + stubbed["content"] = _build_stub(name, inp) + compacted.append(stubbed) + return compacted + + +def _build_tool_call_lookup(messages: list) -> dict: + lookup: dict = {} + for message in messages: + if message.get("role") != "assistant": + continue + for tool_call in message.get("tool_calls") or []: + lookup[tool_call.get("id", "")] = ( + tool_call.get("name", ""), + tool_call.get("input") or {}, + ) + return lookup + + +def _escape_xml_attr(value: str) -> str: + return html.escape(value, quote=False).replace('"', '"') + + +def _build_stub(name: str, input_dict: dict) -> str: + brief = _input_brief(name, input_dict) + return f'' + + +def _input_brief(name: str, inp: dict) -> str: + if name == "Read": + path = inp.get("file_path", "?") + parts = [f"file_path={path}"] + if "offset" in inp: + parts.append(f"offset={inp['offset']}") + if "limit" in inp: + parts.append(f"limit={inp['limit']}") + return ", ".join(parts) + if name == "Bash": + cmd = (inp.get("command") or "").replace("\n", " ") + if len(cmd) > 100: + cmd = cmd[:97] + "..." 
+ return f"command={cmd!r}" + if name == "Grep": + parts = [f"pattern={inp.get('pattern', '?')!r}"] + if "path" in inp: + parts.append(f"path={inp['path']}") + return ", ".join(parts) + if name == "Glob": + return f"pattern={inp.get('pattern', '?')!r}" + try: + rendered = json.dumps(inp, ensure_ascii=False) + except (TypeError, ValueError): + rendered = str(inp) + if len(rendered) > 120: + rendered = rendered[:117] + "..." + return rendered + + +def build_messages_for_api(state, config: dict) -> list: + """Apply follow-up compaction + model-driven GC, then inject working memory notes.""" + if not config.get("followup_compaction_enabled", True): + compacted = list(state.messages) + else: + keep = config.get("followup_keep_last_n_turns", 0) + exempt = config.get("followup_exempt_tools", DEFAULT_EXEMPT_TOOLS) + compacted = compact_tool_history(state.messages, keep_last_n_turns=keep, exempt_tools=exempt) + + from compaction import estimate_tokens + tokens_before = estimate_tokens(state.messages) + tokens_after = estimate_tokens(compacted) + if tokens_before != tokens_after: + state.compaction_log.append({ + "event": "followup_compact", + "timestamp": time.time(), + "turn": getattr(state, "turn_count", 0), + "tokens_est_before": tokens_before, + "tokens_est_after": tokens_after, + "tokens_est_saved": tokens_before - tokens_after, + }) + + return _apply_context_gc(compacted, state) + + +def _apply_context_gc(messages: list, state) -> list: + """Apply model-driven GC decisions and inject working memory notes. + + Falls back to returning messages unchanged when the context_gc module is + absent (this PR can ship independently of PR #55). The import is narrow: + only ImportError is swallowed; any other error propagates. 
+ """ + try: + from context_gc import apply_gc, inject_notes, prepend_verbatim_audit + except ImportError: + return messages + gc_state = getattr(state, 'gc_state', None) + if not gc_state: + return prepend_verbatim_audit(messages) + if not gc_state.trashed_ids and not gc_state.snippets and not gc_state.notes: + return prepend_verbatim_audit(messages) + + from compaction import estimate_tokens + tokens_before = estimate_tokens(messages) + result = apply_gc(messages, gc_state) + result = inject_notes(result, gc_state.notes) + tokens_after = estimate_tokens(result) + if tokens_before != tokens_after: + state.compaction_log.append({ + "event": "context_gc", + "timestamp": time.time(), + "turn": getattr(state, "turn_count", 0), + "trashed_count": len(gc_state.trashed_ids), + "snippet_count": len(gc_state.snippets), + "notes_count": len(gc_state.notes), + "tokens_est_saved": tokens_before - tokens_after, + }) + return prepend_verbatim_audit(result) diff --git a/tests/test_context_gc.py b/tests/test_context_gc.py new file mode 100644 index 0000000..9ea2d2d --- /dev/null +++ b/tests/test_context_gc.py @@ -0,0 +1,346 @@ +"""Tests for context_gc module.""" +import pytest + +from context_gc import ( + GCState, process_gc_call, note_save, note_read, + apply_gc, _apply_snippet, _find_anchor_line, + inject_notes, build_verbatim_audit_note, prepend_verbatim_audit, + _is_stub, _is_auto_trashed_stub, strip_trashed_stubs, + METHODOLOGY_NOTE, _summarize_args, +) + + +class TestGCState: + def test_defaults(self): + gs = GCState() + assert gs.trashed_ids == set() + assert gs.snippets == {} + assert gs.notes == {} + assert gs.compact_xml is False + + +class TestProcessGCCall: + def _make_config(self): + return {"_gc_state": GCState()} + + def test_no_gc_state(self): + result = process_gc_call({}, {}) + assert "Error" in result + + def test_trash(self): + cfg = self._make_config() + result = process_gc_call({"trash": ["r1", "r2"]}, cfg) + assert "trashed 2 results" in result + assert 
"r1" in cfg["_gc_state"].trashed_ids + assert "r2" in cfg["_gc_state"].trashed_ids + + def test_notes(self): + cfg = self._make_config() + result = process_gc_call( + {"notes": [{"name": "key", "content": "value"}]}, cfg + ) + assert "1 notes saved" in result + assert cfg["_gc_state"].notes["key"] == "value" + + def test_trash_notes(self): + cfg = self._make_config() + cfg["_gc_state"].notes["old"] = "data" + result = process_gc_call({"trash_notes": ["old"]}, cfg) + assert "1 notes removed" in result + assert "old" not in cfg["_gc_state"].notes + + def test_methodology_protected(self): + cfg = self._make_config() + cfg["_gc_state"].notes[METHODOLOGY_NOTE] = "important" + result = process_gc_call({"trash_notes": [METHODOLOGY_NOTE]}, cfg) + assert "protected from trash" in result + assert METHODOLOGY_NOTE in cfg["_gc_state"].notes + + def test_keep_snippets(self): + cfg = self._make_config() + result = process_gc_call( + {"keep_snippets": [{"id": "r1", "keep_after": "def main"}]}, cfg + ) + assert "kept snippets for 1 results" in result + assert "r1" in cfg["_gc_state"].snippets + + def test_snippet_ignored_if_trashed(self): + cfg = self._make_config() + cfg["_gc_state"].trashed_ids.add("r1") + process_gc_call( + {"keep_snippets": [{"id": "r1", "keep_after": "x"}]}, cfg + ) + assert "r1" not in cfg["_gc_state"].snippets + + def test_compact_xml(self): + cfg = self._make_config() + result = process_gc_call({"compact_xml": True}, cfg) + assert "XML compaction enabled" in result + assert cfg["_gc_state"].compact_xml is True + + +class TestNoteSave: + def _make_config(self): + return {"_gc_state": GCState()} + + def test_no_gc_state(self): + assert "Error" in note_save({}, {}) + + def test_missing_name(self): + cfg = self._make_config() + assert "Error" in note_save({"content": "x"}, cfg) + + def test_create(self): + cfg = self._make_config() + result = note_save({"name": "k", "content": "v"}, cfg) + assert "created" in result + assert cfg["_gc_state"].notes["k"] == "v" 
+ + def test_update(self): + cfg = self._make_config() + cfg["_gc_state"].notes["k"] = "old" + result = note_save({"name": "k", "content": "new"}, cfg) + assert "updated" in result + assert cfg["_gc_state"].notes["k"] == "new" + + def test_unchanged(self): + cfg = self._make_config() + cfg["_gc_state"].notes["k"] = "same" + result = note_save({"name": "k", "content": "same"}, cfg) + assert "unchanged" in result + + +class TestNoteRead: + def _make_config(self): + return {"_gc_state": GCState()} + + def test_no_gc_state(self): + assert "Error" in note_read({}, {}) + + def test_read_specific(self): + cfg = self._make_config() + cfg["_gc_state"].notes["k"] = "v" + result = note_read({"name": "k"}, cfg) + assert "## k\nv" in result + + def test_read_not_found(self): + cfg = self._make_config() + result = note_read({"name": "missing"}, cfg) + assert "not found" in result + + def test_read_all(self): + cfg = self._make_config() + cfg["_gc_state"].notes["a"] = "1" + cfg["_gc_state"].notes["b"] = "2" + result = note_read({}, cfg) + assert "## a\n1" in result + assert "## b\n2" in result + + def test_read_all_empty(self): + cfg = self._make_config() + result = note_read({}, cfg) + assert "No active notes" in result + + +class TestApplyGC: + def test_no_changes(self): + gs = GCState() + msgs = [{"role": "user", "content": "hi"}] + assert apply_gc(msgs, gs) is msgs + + def test_trash_tool_result(self): + gs = GCState() + gs.trashed_ids.add("tc1") + msgs = [ + {"role": "tool", "tool_call_id": "tc1", "name": "Read", "content": "big data..."}, + {"role": "tool", "tool_call_id": "tc2", "name": "Grep", "content": "kept"}, + ] + result = apply_gc(msgs, gs) + assert "trashed by model" in result[0]["content"] + assert result[1]["content"] == "kept" + + def test_snippet_applied(self): + gs = GCState() + gs.snippets["tc1"] = {"id": "tc1", "keep_after": "def main"} + content = "import os\n\ndef main():\n pass\n" + msgs = [{"role": "tool", "tool_call_id": "tc1", "name": "Read", 
"content": content}] + result = apply_gc(msgs, gs) + assert "def main" in result[0]["content"] + assert "import os" not in result[0]["content"] + + def test_non_tool_messages_pass_through(self): + gs = GCState() + gs.trashed_ids.add("x") + msgs = [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "world"}, + ] + result = apply_gc(msgs, gs) + assert len(result) == 2 + assert result[0]["content"] == "hello" + + +class TestApplySnippet: + def test_keep_after(self): + content = "line1\nline2\ndef main():\n pass" + result = _apply_snippet(content, {"keep_after": "def main"}) + assert "def main" in result + assert "line1" not in result + assert "2 lines trimmed" in result + + def test_keep_before(self): + content = "line1\nline2\nclass Foo:\n pass" + result = _apply_snippet(content, {"keep_before": "class Foo"}) + assert "line1" in result + assert "class Foo:\n pass" not in result + assert "2 lines trimmed" in result + + def test_keep_between(self): + content = "a\nb\nSTART\nc\nd\nEND\ne\nf" + result = _apply_snippet(content, {"keep_between": ["START", "END"]}) + assert "START" in result + assert "END" in result + assert "a\n" not in result + + def test_anchor_not_found(self): + content = "some text" + result = _apply_snippet(content, {"keep_after": "MISSING"}) + assert "GC warning" in result + assert "some text" in result + + def test_empty_content(self): + assert _apply_snippet("", {"keep_after": "x"}) == "" + + def test_keep_between_bad_anchors(self): + result = _apply_snippet("text", {"keep_between": ["a"]}) + assert "needs exactly 2 anchors" in result + + +class TestFindAnchorLine: + def test_found(self): + assert _find_anchor_line(["a", "b", "c"], "b") == 1 + + def test_not_found(self): + assert _find_anchor_line(["a", "b"], "z") is None + + def test_start_from(self): + assert _find_anchor_line(["a", "b", "a"], "a", start_from=1) == 2 + + +class TestInjectNotes: + def test_empty_notes(self): + msgs = [{"role": "user", "content": "hi"}] + 
assert inject_notes(msgs, {}) is msgs + + def test_inject(self): + msgs = [{"role": "user", "content": "hello"}] + result = inject_notes(msgs, {"key": "value"}) + assert "[Your working memory notes]" in result[0]["content"] + assert "## key\nvalue" in result[0]["content"] + assert "hello" in result[0]["content"] + + def test_injects_in_last_user_msg(self): + msgs = [ + {"role": "user", "content": "first"}, + {"role": "assistant", "content": "reply"}, + {"role": "user", "content": "second"}, + ] + result = inject_notes(msgs, {"n": "v"}) + assert "[Your working memory notes]" in result[2]["content"] + assert result[0]["content"] == "first" + + +class TestStubs: + def test_is_stub_trashed(self): + assert _is_stub("[Read result -- trashed by model]") is True + + def test_is_stub_auto_trashed(self): + assert _is_stub("[ContextGC result -- auto-trashed]") is True + + def test_is_stub_normal_content(self): + assert _is_stub("some normal content") is False + + def test_is_stub_too_long(self): + assert _is_stub("x" * 201) is False + + def test_is_auto_trashed(self): + assert _is_auto_trashed_stub("[ContextGC result -- auto-trashed]") is True + assert _is_auto_trashed_stub("[Read result -- trashed by model]") is False + + def test_strip_trashed_stubs(self): + msgs = [ + {"role": "assistant", "content": "text", "tool_calls": [ + {"id": "gc1", "name": "ContextGC", "input": {}}, + {"id": "r1", "name": "Read", "input": {}}, + ]}, + {"role": "tool", "tool_call_id": "gc1", "name": "ContextGC", + "content": "[ContextGC result -- auto-trashed]"}, + {"role": "tool", "tool_call_id": "r1", "name": "Read", + "content": "file content"}, + ] + result = strip_trashed_stubs(msgs) + assert len(result) == 2 # assistant + Read result + # gc1 tool_call removed from assistant + assert len(result[0]["tool_calls"]) == 1 + assert result[0]["tool_calls"][0]["id"] == "r1" + + def test_strip_no_stubs(self): + msgs = [{"role": "user", "content": "hi"}] + assert strip_trashed_stubs(msgs) is msgs + + 
+class TestSummarizeArgs: + def test_read(self): + assert _summarize_args("Read", {"file_path": "/a/b.py"}) == "/a/b.py" + + def test_bash(self): + assert "echo" in _summarize_args("Bash", {"command": "echo hi"}) + + def test_truncate(self): + result = _summarize_args("Read", {"file_path": "x" * 100}, max_len=20) + assert len(result) == 20 + assert result.endswith("...") + + def test_empty(self): + assert _summarize_args("Read", {}) == "" + + def test_fallback(self): + result = _summarize_args("Custom", {"arg": "val"}) + assert result == "val" + + +class TestVerbatimAudit: + def test_empty(self): + assert build_verbatim_audit_note([]) == "" + + def test_skips_trashed(self): + msgs = [{"role": "tool", "tool_call_id": "t1", "name": "Read", + "content": "[Read result -- trashed by model]"}] + assert build_verbatim_audit_note(msgs) == "" + + def test_skips_auto_trashed(self): + msgs = [{"role": "tool", "tool_call_id": "t1", "name": "ContextGC", + "content": "[ContextGC result -- auto-trashed]"}] + assert build_verbatim_audit_note(msgs) == "" + + def test_includes_verbatim(self): + msgs = [ + {"role": "assistant", "content": "", "tool_calls": [ + {"id": "r1", "name": "Read", "input": {"file_path": "test.py"}} + ]}, + {"role": "tool", "tool_call_id": "r1", "name": "Read", + "content": "file content here"}, + ] + result = build_verbatim_audit_note(msgs) + assert "r1" in result + assert "Read" in result + assert "test.py" in result + assert "tk" in result + + def test_prepend(self): + msgs = [ + {"role": "user", "content": "hi"}, + {"role": "tool", "tool_call_id": "r1", "name": "Read", "content": "data"}, + ] + result = prepend_verbatim_audit(msgs) + assert "[Verbatim" in result[0]["content"] diff --git a/tests/test_context_gc_e2e.py b/tests/test_context_gc_e2e.py new file mode 100644 index 0000000..b74d2b8 --- /dev/null +++ b/tests/test_context_gc_e2e.py @@ -0,0 +1,153 @@ +"""End-to-end: drive a real agent.run() conversation where the LLM calls +ContextGC, and verify 
gc_state ends up correctly populated + survives a +session save/load roundtrip. + +Only the LLM provider is mocked (via monkeypatching agent.stream). The tool +registry, session serializer and ContextGC dispatch all run for real. +""" +from __future__ import annotations + +import json +from pathlib import Path + +import pytest + +import tools as _tools_init # noqa: F401 - force built-in tool registration +from agent import AgentState, run +from providers import AssistantTurn, TextChunk +from tool_registry import ToolDef, register_tool +from commands.session import _build_session_data, _restore_state_from_data + + +def _scripted_stream(turns): + """Yield pre-scripted AssistantTurn objects one per call to stream(...). + + Signature matches providers.stream(**kwargs). We ignore all kwargs. + """ + cursor = iter(turns) + + def fake_stream(**_kwargs): + spec = next(cursor) + if spec.get("text"): + yield TextChunk(spec["text"]) + yield AssistantTurn( + text=spec.get("text", ""), + tool_calls=spec.get("tool_calls") or [], + in_tokens=1, + out_tokens=1, + ) + + return fake_stream + + +@pytest.fixture +def echo_tool(): + """Register a simple echo tool that returns its input verbatim. + + Non-invasive: leaves the rest of the registry intact (built-ins + plugins + loaded at module import) so unrelated tests sharing the process still see + their tools. Only the echo entry is removed on teardown. 
+ """ + from tool_registry import _registry # private, but fine for test isolation + had_echo_before = "echo" in _registry + register_tool(ToolDef( + name="echo", + schema={ + "name": "echo", + "description": "echo", + "input_schema": { + "type": "object", + "properties": {"text": {"type": "string"}}, + }, + }, + func=lambda params, _cfg: f"echoed: {params.get('text', '')}", + read_only=True, concurrent_safe=True, + )) + yield + if not had_echo_before: + _registry.pop("echo", None) + + +def test_llm_trashes_tool_result_via_contextgc_end_to_end(monkeypatch, echo_tool): + """LLM calls echo, then ContextGC(trash=[echo_id]); gc_state is mutated.""" + turns = [ + # Turn 1 (first stream call): LLM issues the echo tool call. + {"tool_calls": [ + {"id": "echo_42", "name": "echo", "input": {"text": "hi"}}, + ]}, + # Turn 2: LLM follows up with ContextGC to trash echo_42. + {"tool_calls": [ + {"id": "gc_1", "name": "ContextGC", "input": {"trash": ["echo_42"]}}, + ]}, + # Turn 3: LLM emits plain text; no tool_calls → loop exits. + {"text": "all set"}, + ] + monkeypatch.setattr("agent.stream", _scripted_stream(turns)) + + state = AgentState() + config = {"model": "test", "permission_mode": "accept-all", "_session_id": "gc_e2e"} + + list(run("please echo and clean up", state, config, "system prompt")) + + assert state.gc_state.trashed_ids == {"echo_42"} + assert "[ContextGC result]" not in state.gc_state.trashed_ids + # Neither the ContextGC tool result nor the echo result are deleted from + # state.messages -- only the OUTGOING messages on the next turn are reshaped. 
+ tool_results = [m for m in state.messages if m.get("role") == "tool"] + assert len(tool_results) == 2 + + +def test_gc_state_survives_save_and_reload_via_session_helpers(monkeypatch, echo_tool, tmp_path): + """Roundtrip through _build_session_data / _restore_state_from_data.""" + turns = [ + {"tool_calls": [ + {"id": "echo_1", "name": "echo", "input": {"text": "x"}}, + ]}, + {"tool_calls": [ + {"id": "gc_1", "name": "ContextGC", "input": {"trash": ["echo_1"]}}, + ]}, + {"text": "done"}, + ] + monkeypatch.setattr("agent.stream", _scripted_stream(turns)) + + state = AgentState() + list(run("go", state, {"model": "test", "permission_mode": "accept-all", + "_session_id": "rt"}, "sys")) + assert state.gc_state.trashed_ids == {"echo_1"} + + # Serialize to disk (through JSON to exercise the real save path). + session_path: Path = tmp_path / "session.json" + session_path.write_text( + json.dumps(_build_session_data(state), default=str), encoding="utf-8" + ) + + # Restore into a brand-new state — trashed_ids must come back intact. 
+ reloaded = AgentState() + _restore_state_from_data( + reloaded, json.loads(session_path.read_text(encoding="utf-8")) + ) + assert reloaded.gc_state.trashed_ids == {"echo_1"} + + +def test_disabled_tools_hides_contextgc_schema_from_llm(monkeypatch, echo_tool): + """With config['disabled_tools']=['ContextGC'] the LLM never sees the schema.""" + captured_schemas = [] + + def spy_stream(**kwargs): + captured_schemas.append([s["name"] for s in kwargs.get("tool_schemas") or []]) + yield AssistantTurn(text="hello", tool_calls=[], in_tokens=1, out_tokens=1) + + monkeypatch.setattr("agent.stream", spy_stream) + + state = AgentState() + list(run("hi", state, { + "model": "test", + "permission_mode": "accept-all", + "_session_id": "gated", + "disabled_tools": ["ContextGC"], + }, "sys")) + + assert captured_schemas, "stream() must have been called at least once" + for schemas in captured_schemas: + assert "ContextGC" not in schemas + assert "echo" in schemas # non-disabled tool still present diff --git a/tests/test_followup_compaction.py b/tests/test_followup_compaction.py new file mode 100644 index 0000000..024a803 --- /dev/null +++ b/tests/test_followup_compaction.py @@ -0,0 +1,138 @@ +"""Tests for followup_compaction module.""" +import pytest + +from followup_compaction import ( + compact_tool_history, _build_tool_call_lookup, _build_stub, + _input_brief, _escape_xml_attr, + DEFAULT_EXEMPT_TOOLS, +) + + +class TestCompactToolHistory: + def _make_messages(self): + return [ + {"role": "user", "content": "turn 1"}, + {"role": "assistant", "content": "ok", "tool_calls": [ + {"id": "tc1", "name": "Read", "input": {"file_path": "/a.py"}}, + ]}, + {"role": "tool", "tool_call_id": "tc1", "name": "Read", "content": "file contents..."}, + {"role": "user", "content": "turn 2"}, + {"role": "assistant", "content": "done"}, + ] + + def test_stubs_old_tool_results(self): + msgs = self._make_messages() + result = compact_tool_history(msgs) + assert "") + + def test_quote(self): + assert 
"&quot;" in _escape_xml_attr('"hello"') + + +class TestBuildToolCallLookup: + def test_builds_lookup(self): + msgs = [ + {"role": "assistant", "tool_calls": [ + {"id": "tc1", "name": "Read", "input": {"file_path": "/x"}}, + {"id": "tc2", "name": "Bash", "input": {"command": "ls"}}, + ]}, + ] + lookup = _build_tool_call_lookup(msgs) + assert lookup["tc1"] == ("Read", {"file_path": "/x"}) + assert lookup["tc2"] == ("Bash", {"command": "ls"}) + + def test_skips_non_assistant(self): + msgs = [{"role": "user", "content": "hi"}] + assert _build_tool_call_lookup(msgs) == {} diff --git a/tests/test_gc_state_persistence.py b/tests/test_gc_state_persistence.py new file mode 100644 index 0000000..958bf09 --- /dev/null +++ b/tests/test_gc_state_persistence.py @@ -0,0 +1,80 @@ +"""gc_state must be a real field on AgentState and survive save/reload. + +Guard against the leak class where ContextGC-trashed tool_call_ids silently +re-materialize after /save + /load because they were only held in a per-turn +config dict, not on AgentState itself. 
+""" +from __future__ import annotations + +import json + +from agent import AgentState +from context_gc import GCState +from commands.session import _build_session_data, _restore_state_from_data + + +def test_agent_state_has_gc_state_by_default(): + state = AgentState() + assert isinstance(state.gc_state, GCState) + assert state.gc_state.trashed_ids == set() + assert state.gc_state.notes == {} + + +def test_two_agent_states_have_independent_gc_state(): + a = AgentState() + b = AgentState() + a.gc_state.trashed_ids.add("toolcall_1") + assert "toolcall_1" not in b.gc_state.trashed_ids + + +def test_session_save_includes_gc_state_as_sortable_json(): + state = AgentState() + state.gc_state.trashed_ids = {"id_b", "id_a", "id_c"} + state.gc_state.notes = {"task": "do the thing"} + + data = _build_session_data(state) + serialized = json.dumps(data) + assert "gc_state" in data + assert data["gc_state"]["trashed_ids"] == ["id_a", "id_b", "id_c"] + assert data["gc_state"]["notes"] == {"task": "do the thing"} + assert '"trashed_ids":' in serialized + + +def test_session_load_restores_gc_state(): + fresh = AgentState() + _restore_state_from_data(fresh, { + "messages": [], + "gc_state": { + "trashed_ids": ["a", "b"], + "notes": {"k": "v"}, + "snippets": {}, + }, + }) + assert fresh.gc_state.trashed_ids == {"a", "b"} + assert fresh.gc_state.notes == {"k": "v"} + + +def test_session_load_missing_gc_state_returns_fresh_empty(): + fresh = AgentState() + _restore_state_from_data(fresh, {"messages": []}) + assert fresh.gc_state.trashed_ids == set() + assert fresh.gc_state.notes == {} + + +def test_save_then_load_roundtrip_preserves_trashed_ids(): + """End-to-end: trash ids, serialize, rehydrate — ids must still be trashed. + + This is the exact leak the bug class introduces: if roundtrip drops + trashed_ids, previously-elided tool_results come back into context and + inflate the prompt by whatever they were trimmed from. 
+ """ + before = AgentState() + before.gc_state.trashed_ids = {"tool_a", "tool_b"} + before.gc_state.snippets = {"tool_c": {"keep_after": "### Result"}} + + data = json.loads(json.dumps(_build_session_data(before), default=str)) + after = AgentState() + _restore_state_from_data(after, data) + + assert after.gc_state.trashed_ids == {"tool_a", "tool_b"} + assert after.gc_state.snippets == {"tool_c": {"keep_after": "### Result"}} diff --git a/tests/test_tool_registry.py b/tests/test_tool_registry.py index 2a7a8d4..e798e9e 100644 --- a/tests/test_tool_registry.py +++ b/tests/test_tool_registry.py @@ -89,6 +89,19 @@ def test_get_tool_schemas(): assert schemas[0]["name"] == "echo" +def test_get_tool_schemas_honours_disabled_list(): + register_tool(_make_echo_tool("kept")) + register_tool(_make_echo_tool("hidden")) + names = [s["name"] for s in get_tool_schemas(disabled=["hidden"])] + assert names == ["kept"] + + +def test_execute_tool_refuses_disabled_tool(): + register_tool(_make_echo_tool("gated")) + result = execute_tool("gated", {"text": "x"}, config={"disabled_tools": ["gated"]}) + assert "disabled" in result.lower() + + # ------------------------------------------------------------------ # execute_tool # ------------------------------------------------------------------ diff --git a/tool_registry.py b/tool_registry.py index f0a66c2..2d0132f 100644 --- a/tool_registry.py +++ b/tool_registry.py @@ -8,7 +8,7 @@ import hashlib import json from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, Iterable, List, Optional @dataclass @@ -64,14 +64,26 @@ def get_tool(name: str) -> Optional[ToolDef]: return _registry.get(name) +def is_concurrent_safe(name: str) -> bool: + """Return True if the named tool is safe to run in parallel.""" + tool = _registry.get(name) + return tool.concurrent_safe if tool else False + + def get_all_tools() -> List[ToolDef]: """Return all registered tools (insertion 
order).""" return list(_registry.values()) -def get_tool_schemas() -> List[Dict[str, Any]]: - """Return the schemas of all registered tools (for API tool parameter).""" - return [t.schema for t in _registry.values()] +def get_tool_schemas(disabled: Iterable[str] = ()) -> List[Dict[str, Any]]: + """Return the schemas of all registered tools (for API tool parameter). + + Tools whose name appears in ``disabled`` are omitted: the LLM never sees + them, so it cannot call them. Use this to opt a new tool (e.g. ContextGC) + out of a session for backwards-compatibility without touching the registry. + """ + skip = frozenset(disabled or ()) + return [t.schema for t in _registry.values() if t.name not in skip] def execute_tool( @@ -91,6 +103,10 @@ def execute_tool( Returns: Tool result string, possibly truncated. """ + disabled = frozenset(config.get("disabled_tools") or ()) + if name in disabled: + return f"Error: tool '{name}' is disabled in this session (see config['disabled_tools'])." + tool = get_tool(name) if tool is None: return f"Error: tool '{name}' not found." diff --git a/tools/__init__.py b/tools/__init__.py index 8731a8c..cee1e55 100644 --- a/tools/__init__.py +++ b/tools/__init__.py @@ -333,6 +333,103 @@ "required": ["seconds"], }, }, + { + "name": "NoteSave", + "description": ( + "Save or update a working-memory note. Notes persist across turns and are " + "injected into your context automatically. Use for plans, key findings, " + "extracted facts, and methodology tracking." + ), + "input_schema": { + "type": "object", + "properties": { + "name": {"type": "string", "description": "Unique note name (overwrites if exists)"}, + "content": {"type": "string", "description": "Note content (markdown supported)"}, + }, + "required": ["name", "content"], + }, + }, + { + "name": "NoteRead", + "description": ( + "Read one or all working-memory notes. " + "Omit 'name' to list all active notes with their content." 
+ ), + "input_schema": { + "type": "object", + "properties": { + "name": {"type": "string", "description": "Note name to read. Omit to read all."}, + }, + "required": [], + }, + }, + { + "name": "ContextGC", + "description": ( + "Garbage-collect your context to free space. MANDATORY: call this at the end of " + "every turn with tool calls. Trash tool results you no longer need, keep only " + "relevant snippets from large results, and save key information in notes that " + "persist across turns. Use compact_xml=true to strip verbose XML from your own " + "old assistant outputs." + ), + "input_schema": { + "type": "object", + "properties": { + "trash": { + "type": "array", + "items": {"type": "string"}, + "description": ( + "tool_call_ids to fully discard. Works on ANY tool result: " + "Read, Grep, Bash, Skill, GetFolderDescription, WebFetch, etc." + ), + }, + "keep_snippets": { + "type": "array", + "items": { + "type": "object", + "properties": { + "id": {"type": "string", "description": "tool_call_id of the result to trim"}, + "keep_after": {"type": "string", "description": "Keep from line containing this text to end"}, + "keep_before": {"type": "string", "description": "Keep from start to line before this text"}, + "keep_between": { + "type": "array", + "items": {"type": "string"}, + "description": "Keep between two text anchors [start_text, end_text]", + }, + }, + "required": ["id"], + }, + "description": "Partial keeps: trim results using text anchors", + }, + "notes": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": {"type": "string", "description": "Unique note name"}, + "content": {"type": "string", "description": "Note content"}, + }, + "required": ["name", "content"], + }, + "description": "Named scratchpad entries that persist across turns", + }, + "trash_notes": { + "type": "array", + "items": {"type": "string"}, + "description": "Note names to discard", + }, + "compact_xml": { + "type": "boolean", + "description": ( + 
"Strip verbose XML tool_use blocks from old assistant messages, " + "replacing each with a one-line summary. Keeps prose intact. " + "Once enabled, stays on for the rest of the session." + ), + }, + }, + "required": [], + }, + }, ] @@ -479,6 +576,31 @@ def _register_builtins() -> None: read_only=False, concurrent_safe=True, ), ] + + # NoteSave / NoteRead tools + from context_gc import note_save, note_read + _tool_defs.append(ToolDef( + name="NoteSave", + schema=_schemas["NoteSave"], + func=lambda p, c: note_save(p, c), + read_only=False, concurrent_safe=True, + )) + _tool_defs.append(ToolDef( + name="NoteRead", + schema=_schemas["NoteRead"], + func=lambda p, c: note_read(p, c), + read_only=True, concurrent_safe=True, + )) + + # ContextGC tool + from context_gc import process_gc_call + _tool_defs.append(ToolDef( + name="ContextGC", + schema=_schemas["ContextGC"], + func=lambda p, c: process_gc_call(p, c), + read_only=True, concurrent_safe=True, + )) + for td in _tool_defs: register_tool(td)