diff --git a/followup_compaction.py b/followup_compaction.py new file mode 100644 index 0000000..5c225fc --- /dev/null +++ b/followup_compaction.py @@ -0,0 +1,278 @@ +"""Follow-up compaction: destroy past-turn tool content before each API call. + +At each user turn boundary, ALL tool messages and assistant tool_calls from +prior turns are completely removed (no stubs). The current turn is always +kept intact. + +Non-destructive to state.messages -- produces a new list so persistence and +resume keep the full history. +""" +from __future__ import annotations + +import html +import re +import time + + +_THINKING_BLOCK_RE = re.compile(r'.*?\s*', re.DOTALL) + + +_ARGS_PREFERRED_KEY = { + "Read": "file_path", "Edit": "file_path", "Write": "file_path", + "NotebookEdit": "notebook_path", + "Glob": "pattern", "Grep": "pattern", + "Bash": "command", + "WebFetch": "url", "WebSearch": "query", +} + + +def _escape_xml_attr(s: str) -> str: + return html.escape(str(s), quote=True) + + +def _input_brief(tool_name: str, input_dict: dict, max_len: int = 60) -> str: + if not input_dict: + return "" + val = input_dict.get(_ARGS_PREFERRED_KEY.get(tool_name, "")) + if val is None: + for v in input_dict.values(): + if isinstance(v, str) and v: + val = v + break + if val is None: + return "" + val = str(val).replace("\n", " ") + if len(val) > max_len: + val = val[: max_len - 3] + "..." + return val + + +def _build_tc_lookup(tool_calls: list | None) -> dict: + lookup: dict = {} + for tc in tool_calls or []: + tid = tc.get("id", "") + if tid: + lookup[tid] = (tc.get("name", "tool"), tc.get("input") or {}) + return lookup + + +def _xml_replacer(tc_lookup: dict, target_ids: set | None = None): + def _replacer(match): + name, tid = match.group(1), match.group(2) + if target_ids is not None and tid not in target_ids: + return match.group(0) + tc_name, tc_input = tc_lookup.get(tid, (name, {})) + brief = _input_brief(tc_name, tc_input) + return f'' + return _replacer + + +_TOOL_USE_RE = re.compile( + r']*>.*?', + re.DOTALL, +) + + +def compact_assistant_xml(content: str, tool_calls: list | None = None) -> str: + """Replace ALL inline XML tool_use blocks with one-line summaries.""" + if not content or " str: + """Replace only XML blocks whose id is in target_ids, leaving others intact.""" + if not content or " bool: + if user_idx == 0: + return True + prev = messages[user_idx - 1] + return prev.get("role") == "assistant" and not prev.get("tool_calls") + + +def compact_tool_history(messages: list, keep_last_n_turns: int = 0) -> list: + """Completely remove prior-turn tool content. + + At user turn boundaries, ALL tool messages and assistant tool_calls from + prior turns are destroyed (no stubs). Assistant messages that become empty + after stripping are also removed. + + The current turn (last ``keep_last_n_turns + 1`` user messages onward) is + kept intact. + """ + user_indices = [i for i, m in enumerate(messages) if m.get("role") == "user"] + if not user_indices: + return list(messages) + + valid_boundaries = [i for i in user_indices + if _is_completed_boundary(messages, i)] + + total_keep = keep_last_n_turns + 1 + if total_keep >= len(valid_boundaries): + return list(messages) + + current_turn_start = valid_boundaries[-total_keep] + + result = [] + for i, msg in enumerate(messages): + if i >= current_turn_start: + result.append(msg) + continue + + role = msg.get("role") + + if role == "tool": + continue + + if role == "user": + result.append(msg) + continue + + if role == "assistant": + tool_calls = msg.get("tool_calls") + content = msg.get("content", "") or "" + + if tool_calls: + content = compact_assistant_xml(content, tool_calls) + cleaned = dict(msg) + cleaned.pop("tool_calls", None) + cleaned["content"] = content + if not content.strip(): + continue + result.append(cleaned) + else: + if content.strip(): + result.append(msg) + continue + + result.append(msg) + + return result + + +def _mark_compaction_boundary(messages: list) -> None: + """Mark the last message before the current user turn with _cache_breakpoint. + + This tells messages_to_anthropic where to place cache_control so the + compacted prefix is cached and current-loop messages stay fresh. + """ + user_indices = [i for i, m in enumerate(messages) if m.get("role") == "user"] + if len(user_indices) < 2: + return + valid_boundaries = [] + for idx in user_indices: + if idx == 0: + valid_boundaries.append(idx) + else: + prev = messages[idx - 1] + role = prev.get("role") + if role == "assistant" and not prev.get("tool_calls"): + valid_boundaries.append(idx) + elif role == "user": + valid_boundaries.append(idx) + if len(valid_boundaries) < 2: + return + current_start = valid_boundaries[-1] + if current_start > 0: + messages[current_start - 1]["_cache_breakpoint"] = True + + +def _strip_thinking_from_messages(messages: list) -> list: + """Remove ... blocks from assistant message content. + + Non-destructive: returns a new list with new dicts where needed. + Handles both string and list-of-blocks content formats. + """ + result = [] + for msg in messages: + if msg.get("role") != "assistant": + result.append(msg) + continue + content = msg.get("content", "") + if isinstance(content, str) and "" in content: + cleaned = _THINKING_BLOCK_RE.sub("", content) + result.append({**msg, "content": cleaned or "."}) + elif isinstance(content, list): + new_blocks = [] + changed = False + for block in content: + if isinstance(block, dict) and block.get("type") == "text" and "" in block.get("text", ""): + cleaned = _THINKING_BLOCK_RE.sub("", block["text"]) + new_blocks.append({**block, "text": cleaned or "."}) + changed = True + else: + new_blocks.append(block) + result.append({**msg, "content": new_blocks} if changed else msg) + else: + result.append(msg) + return result + + +def build_messages_for_api(state, config: dict) -> list: + """Compact prior-turn tool content at user boundaries, then apply ContextGC. + + compact_tool_history runs on every build so the post-compaction prefix is + byte-stable across every call in a turn, not just the one that immediately + follows a user message. The function is idempotent: it always leaves the + last user turn intact and only touches prior-turn tool content. + """ + compacted = compact_tool_history(list(state.messages)) + result = _apply_context_gc(compacted, state) + try: + from context_gc import strip_trashed_stubs + result = strip_trashed_stubs(result) + except ImportError: + pass + result = _strip_thinking_from_messages(result) + _mark_compaction_boundary(result) + return result + + +def _apply_context_gc(messages: list, state) -> list: + """Apply model-driven GC decisions. Notes and audit info are injected + into the last user message in dispatch.py, keeping them out of system + blocks for Anthropic cache stability.""" + try: + from context_gc import apply_gc + except ImportError: + return messages + gc_state = getattr(state, 'gc_state', None) + if not gc_state: + return messages + if not gc_state.trashed_ids and not gc_state.snippets: + return messages + + try: + from compaction import estimate_tokens + tokens_before = estimate_tokens(messages) + except ImportError: + tokens_before = None + + result = apply_gc(messages, gc_state) + + if tokens_before is not None: + try: + tokens_after = estimate_tokens(result) + if tokens_before != tokens_after and hasattr(state, 'compaction_log'): + state.compaction_log.append({ + "event": "context_gc", + "timestamp": time.time(), + "turn": getattr(state, "turn_count", 0), + "trashed_count": len(gc_state.trashed_ids), + "snippet_count": len(gc_state.snippets), + "notes_count": len(gc_state.notes), + "tokens_est_saved": tokens_before - tokens_after, + }) + except ImportError: + pass + return result diff --git a/tests/test_followup_compaction.py b/tests/test_followup_compaction.py new file mode 100644 index 0000000..7c9cc4c --- /dev/null +++ b/tests/test_followup_compaction.py @@ -0,0 +1,435 @@ +"""Tests for followup_compaction.py.""" +from __future__ import annotations + +import os +import sys + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +from followup_compaction import ( + compact_tool_history, + compact_assistant_xml, + compact_assistant_xml_selective, + build_messages_for_api, +) + + +def _turn(user_text: str, tool_calls: list, tool_results: list) -> list: + """Build [user, assistant+tool_calls, tool_result, tool_result, ...].""" + msgs = [{"role": "user", "content": user_text}] + msgs.append({"role": "assistant", "content": "ok", "tool_calls": tool_calls}) + msgs.extend(tool_results) + return msgs + + +HEAVY = "X" * 5000 +_TERM_ASST = {"role": "assistant", "content": "done"} + + +class TestCompactToolHistory: + def test_removes_old_tool_messages(self): + history = ( + _turn( + "first question", + [{"id": "t1", "name": "Read", "input": {"file_path": "a.py"}}], + [{"role": "tool", "tool_call_id": "t1", "name": "Read", "content": HEAVY}], + ) + + [{"role": "assistant", "content": "done", "tool_calls": []}] + + _turn( + "follow up", + [{"id": "t2", "name": "Read", "input": {"file_path": "b.py"}}], + [{"role": "tool", "tool_call_id": "t2", "name": "Read", "content": HEAVY}], + ) + ) + out = compact_tool_history(history, keep_last_n_turns=0) + prior_tools = [m for m in out if m.get("role") == "tool" and m.get("tool_call_id") == "t1"] + assert len(prior_tools) == 0 + current_tools = [m for m in out if m.get("role") == "tool" and m.get("tool_call_id") == "t2"] + assert len(current_tools) == 1 + assert current_tools[0]["content"] == HEAVY + + def test_strips_tool_calls_from_prior_assistant(self): + history = ( + _turn( + "first", + [{"id": "t1", "name": "Read", "input": {"file_path": "a.py"}}], + [{"role": "tool", "tool_call_id": "t1", "name": "Read", "content": HEAVY}], + ) + + [_TERM_ASST] + + [{"role": "user", "content": "follow"}] + ) + out = compact_tool_history(history, keep_last_n_turns=0) + prior_assistants = [m for m in out if m.get("role") == "assistant"] + for a in prior_assistants: + assert "tool_calls" not in a + + def test_current_turn_intact(self): + history = ( + _turn( + "first", + [{"id": "t1", "name": "Read", "input": {"file_path": "a.py"}}], + [{"role": "tool", "tool_call_id": "t1", "name": "Read", "content": HEAVY}], + ) + + [_TERM_ASST] + + _turn( + "current", + [{"id": "t2", "name": "Bash", "input": {"command": "ls"}}], + [{"role": "tool", "tool_call_id": "t2", "name": "Bash", "content": HEAVY}], + ) + ) + out = compact_tool_history(history, keep_last_n_turns=0) + assert out[-1]["content"] == HEAVY + + def test_keep_last_n_turns_1(self): + history = ( + _turn( + "first", + [{"id": "t1", "name": "Read", "input": {"file_path": "a.py"}}], + [{"role": "tool", "tool_call_id": "t1", "name": "Read", "content": HEAVY}], + ) + + [_TERM_ASST] + + _turn( + "second", + [{"id": "t2", "name": "Read", "input": {"file_path": "b.py"}}], + [{"role": "tool", "tool_call_id": "t2", "name": "Read", "content": HEAVY}], + ) + + [_TERM_ASST] + + _turn( + "third (current)", + [{"id": "t3", "name": "Read", "input": {"file_path": "c.py"}}], + [{"role": "tool", "tool_call_id": "t3", "name": "Read", "content": HEAVY}], + ) + ) + out = compact_tool_history(history, keep_last_n_turns=1) + by_id = {m["tool_call_id"]: m for m in out if m.get("role") == "tool"} + assert "t1" not in by_id + assert by_id["t2"]["content"] == HEAVY + assert by_id["t3"]["content"] == HEAVY + + def test_removes_empty_assistant_after_stripping(self): + history = [ + {"role": "user", "content": "first"}, + {"role": "assistant", "content": "", "tool_calls": [ + {"id": "t1", "name": "Read", "input": {"file_path": "a.py"}} + ]}, + {"role": "tool", "tool_call_id": "t1", "name": "Read", "content": HEAVY}, + _TERM_ASST, + {"role": "user", "content": "follow"}, + ] + out = compact_tool_history(history, keep_last_n_turns=0) + assert len(out) == 3 + assert out[0]["content"] == "first" + assert out[1]["content"] == "done" + assert out[2]["content"] == "follow" + + def test_non_destructive(self): + original = _turn( + "q", + [{"id": "t1", "name": "Read", "input": {"file_path": "a.py"}}], + [{"role": "tool", "tool_call_id": "t1", "name": "Read", "content": HEAVY}], + ) + [_TERM_ASST, {"role": "user", "content": "follow"}] + snapshot_content = original[2]["content"] + compact_tool_history(original, keep_last_n_turns=0) + assert original[2]["content"] == snapshot_content + assert original[2]["content"] == HEAVY + + def test_no_compaction_when_only_one_turn(self): + history = _turn( + "only", + [{"id": "t1", "name": "Read", "input": {"file_path": "a.py"}}], + [{"role": "tool", "tool_call_id": "t1", "name": "Read", "content": HEAVY}], + ) + out = compact_tool_history(history, keep_last_n_turns=0) + assert out[2]["content"] == HEAVY + + +class _FakeState: + def __init__(self, messages): + self.messages = messages + self.compaction_log = [] + self.turn_count = 1 + + +class TestBuildMessagesForApi: + def test_mid_loop_compacts_prior_turns_keeps_current(self): + """Mid-loop (last msg is tool): prior turn gets compacted, current turn intact.""" + history = [ + {"role": "user", "content": "q"}, + {"role": "assistant", "content": "ok", "tool_calls": [{"id": "t1", "name": "Read", "input": {"file_path": "a.py"}}]}, + {"role": "tool", "tool_call_id": "t1", "name": "Read", "content": HEAVY}, + {"role": "assistant", "content": "done"}, + {"role": "user", "content": "follow"}, + {"role": "assistant", "content": "ok", "tool_calls": [{"id": "t2", "name": "Read", "input": {"file_path": "b.py"}}]}, + {"role": "tool", "tool_call_id": "t2", "name": "Read", "content": HEAVY}, + ] + state = _FakeState(history) + result = build_messages_for_api(state, {}) + tool_msgs = [m for m in result if m.get("role") == "tool"] + assert len(tool_msgs) == 1, "t1 (prior turn) must be compacted, t2 (current turn) kept" + assert HEAVY in tool_msgs[0]["content"] + assert tool_msgs[0]["tool_call_id"] == "t2" + + def test_user_turn_compacts_prior_tools(self): + """When last message is user (new turn), prior tool content is removed.""" + history = [ + {"role": "user", "content": "q"}, + {"role": "assistant", "content": "ok", "tool_calls": [{"id": "t1", "name": "Read", "input": {"file_path": "a.py"}}]}, + {"role": "tool", "tool_call_id": "t1", "name": "Read", "content": HEAVY}, + {"role": "assistant", "content": "done"}, + {"role": "user", "content": "follow"}, + ] + state = _FakeState(history) + result = build_messages_for_api(state, {}) + tool_msgs = [m for m in result if m.get("role") == "tool"] + assert len(tool_msgs) == 0 + + def test_audit_does_not_mutate_last_user_message(self): + """The audit must live in the system volatile block, NOT be prepended to any user msg.""" + history = [ + {"role": "user", "content": "do stuff"}, + {"role": "assistant", "content": "ok", "tool_calls": [ + {"id": "b99", "name": "Bash", "input": {"command": "ls"}} + ]}, + {"role": "tool", "tool_call_id": "b99", "name": "Bash", "content": "a" * 700}, + ] + state = _FakeState(history) + result = build_messages_for_api(state, {}) + user_msg = next(m for m in result if m.get("role") == "user") + assert user_msg["content"] == "do stuff" + assert "Verbatim tool_results" not in user_msg["content"] + + +class TestCompactAssistantXml: + def test_strips_xml_keeps_prose(self): + content = ( + 'Analysis here.\n\n' + 'foo.py' + '\n\nMore text.' + ) + tool_calls = [{"id": "r1", "name": "Read", "input": {"file_path": "foo.py"}}] + result = compact_assistant_xml(content, tool_calls) + assert "Analysis here." in result + assert "More text." in result + assert 'a.py' + '\ntext\n' + 'ls' + ) + tool_calls = [ + {"id": "r1", "name": "Read", "input": {"file_path": "a.py"}}, + {"id": "b1", "name": "Bash", "input": {"command": "ls"}}, + ] + result = compact_assistant_xml(content, tool_calls) + assert '' + '' + '' + '' + ) + tool_calls = [{"id": "w1", "name": "Write", "input": {"file_path": "test.py"}}] + result = compact_assistant_xml(content, tool_calls) + assert 'a.py' + '\ntext\n' + 'b.py' + 'big code' + ) + tool_calls = [ + {"id": "r1", "name": "Read", "input": {"file_path": "a.py"}}, + {"id": "w1", "name": "Write", "input": {"file_path": "b.py"}}, + ] + result = compact_assistant_xml_selective(content, tool_calls, {"w1"}) + assert '' in result + assert 'a.py' + ) + history = [ + {"role": "user", "content": "first question"}, + {"role": "assistant", "content": xml_content, "tool_calls": [ + {"id": "r1", "name": "Read", "input": {"file_path": "a.py"}} + ]}, + {"role": "tool", "tool_call_id": "r1", "name": "Read", "content": HEAVY}, + _TERM_ASST, + {"role": "user", "content": "follow up"}, + {"role": "assistant", "content": "ok", "tool_calls": [ + {"id": "r2", "name": "Read", "input": {"file_path": "b.py"}} + ]}, + {"role": "tool", "tool_call_id": "r2", "name": "Read", "content": HEAVY}, + ] + out = compact_tool_history(history, keep_last_n_turns=0) + prior_asst = out[1] + assert ' not a valid boundary -> no compaction.""" + history = [ + {"role": "user", "content": "first"}, + {"role": "assistant", "content": "ok", "tool_calls": [ + {"id": "t1", "name": "Read", "input": {"file_path": "a.py"}} + ]}, + {"role": "tool", "tool_call_id": "t1", "name": "Read", "content": HEAVY}, + {"role": "user", "content": "after interrupt"}, + ] + out = compact_tool_history(history, keep_last_n_turns=0) + tool_msgs = [m for m in out if m.get("role") == "tool"] + assert len(tool_msgs) == 1 + assert tool_msgs[0]["content"] == HEAVY + + def test_completed_then_interrupted_preserves_interrupted(self): + """Completed turn compacted, interrupted turn preserved.""" + history = [ + {"role": "user", "content": "turn1"}, + {"role": "assistant", "content": "ok", "tool_calls": [ + {"id": "t1", "name": "Read", "input": {"file_path": "a.py"}} + ]}, + {"role": "tool", "tool_call_id": "t1", "name": "Read", "content": HEAVY}, + {"role": "assistant", "content": "done"}, + {"role": "user", "content": "turn2"}, + {"role": "assistant", "content": "ok", "tool_calls": [ + {"id": "t2", "name": "Bash", "input": {"command": "ls"}} + ]}, + {"role": "tool", "tool_call_id": "t2", "name": "Bash", "content": HEAVY}, + {"role": "user", "content": "after interrupt"}, + ] + out = compact_tool_history(history, keep_last_n_turns=0) + by_id = {m["tool_call_id"]: m for m in out if m.get("role") == "tool"} + assert "t1" not in by_id, "completed turn's tools should be compacted" + assert "t2" in by_id, "interrupted turn's tools must be preserved" + assert by_id["t2"]["content"] == HEAVY + + def test_multiple_interrupted_turns_all_preserved(self): + """Two consecutive interrupted turns: neither is compacted.""" + history = [ + {"role": "user", "content": "turn1"}, + {"role": "assistant", "content": "ok", "tool_calls": [ + {"id": "t1", "name": "Read", "input": {"file_path": "a.py"}} + ]}, + {"role": "tool", "tool_call_id": "t1", "name": "Read", "content": HEAVY}, + {"role": "user", "content": "turn2 after interrupt"}, + {"role": "assistant", "content": "ok", "tool_calls": [ + {"id": "t2", "name": "Read", "input": {"file_path": "b.py"}} + ]}, + {"role": "tool", "tool_call_id": "t2", "name": "Read", "content": HEAVY}, + {"role": "user", "content": "turn3 after interrupt"}, + ] + out = compact_tool_history(history, keep_last_n_turns=0) + tool_msgs = [m for m in out if m.get("role") == "tool"] + assert len(tool_msgs) == 2, "both interrupted turns' tools must be preserved" + + def test_build_messages_preserves_interrupted(self): + """build_messages_for_api preserves interrupted turn tool results.""" + history = [ + {"role": "user", "content": "first"}, + {"role": "assistant", "content": "ok", "tool_calls": [ + {"id": "t1", "name": "Read", "input": {"file_path": "a.py"}} + ]}, + {"role": "tool", "tool_call_id": "t1", "name": "Read", "content": HEAVY}, + {"role": "user", "content": "after ctrl+c"}, + ] + state = _FakeState(history) + result = build_messages_for_api(state, {}) + tool_msgs = [m for m in result if m.get("role") == "tool"] + assert len(tool_msgs) == 1 + assert tool_msgs[0]["content"] == HEAVY