Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import tools as _tools_init # ensure built-in tools are registered on import
from providers import stream, AssistantTurn, TextChunk, ThinkingChunk, detect_provider
from compaction import maybe_compact, estimate_tokens, get_context_limit, compact_messages
from context_gc import GCState
import logging_utils as _log
import quota as _quota
from circuit_breaker import CircuitOpenError as _CircuitOpenError
Expand All @@ -32,6 +33,11 @@ class AgentState:
total_input_tokens: int = 0
total_output_tokens: int = 0
turn_count: int = 0
# Persisted so trashed_ids, snippets and notes survive /save and /load.
# Without this, restoring a session leaks back tool_results the model had trashed.
gc_state: GCState = field(default_factory=GCState)
# Timeline of note changes (for debugging / replay)
notes_timeline: list = field(default_factory=list)


@dataclass
Expand Down Expand Up @@ -85,8 +91,11 @@ def run(
user_msg["images"] = [pending_img]
state.messages.append(user_msg)

# Inject runtime metadata into config so tools (e.g. Agent) can access it
config = {**config, "_depth": depth, "_system_prompt": system_prompt}
# Inject runtime metadata into config so tools (e.g. Agent, ContextGC) can access it.
# ContextGC reads and mutates config["_gc_state"]; without this binding, every call
# returns "Error: no GC state available" and no trashed_id is ever recorded.
config = {**config, "_depth": depth, "_system_prompt": system_prompt,
"_gc_state": state.gc_state, "_state": state}
session_id = config.get("_session_id", "default")

# Wire up structured logging from config (idempotent, cheap)
Expand Down Expand Up @@ -120,7 +129,7 @@ def run(
model=config["model"],
system=system_prompt,
messages=state.messages,
tool_schemas=get_tool_schemas(),
tool_schemas=get_tool_schemas(disabled=config.get("disabled_tools") or ()),
config=config,
):
if isinstance(event, (TextChunk, ThinkingChunk)):
Expand Down
51 changes: 39 additions & 12 deletions commands/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,45 @@ def _build_session_data(state, session_id: str | None = None) -> dict:
"turn_count": state.turn_count,
"total_input_tokens": state.total_input_tokens,
"total_output_tokens": state.total_output_tokens,
"gc_state": _serialize_gc_state(getattr(state, "gc_state", None)),
}


def _serialize_gc_state(gc_state) -> dict:
"""JSON-safe view of ContextGC state (trashed_ids, snippets, notes).

Must be stable across saves so trashed_ids surviving a /load cannot leak
back into the model's context window (the "gc_state leak" class).
"""
if gc_state is None:
return {"trashed_ids": [], "snippets": {}, "notes": {}}
return {
"trashed_ids": sorted(gc_state.trashed_ids),
"snippets": dict(gc_state.snippets),
"notes": dict(gc_state.notes),
}


def _restore_state_from_data(state, data: dict) -> None:
"""Apply a loaded session dict onto a state in-place.

Single point of truth for /load, /resume and /cloudsave load. Covers the
full AgentState surface including gc_state — forgetting any of these is
how the session-save/restore roundtrip drifts from in-memory state.
"""
from context_gc import GCState
state.messages = data.get("messages", [])
state.turn_count = data.get("turn_count", 0)
state.total_input_tokens = data.get("total_input_tokens", 0)
state.total_output_tokens = data.get("total_output_tokens", 0)
gc = data.get("gc_state") or {}
state.gc_state = GCState(
trashed_ids=set(gc.get("trashed_ids") or []),
snippets=dict(gc.get("snippets") or {}),
notes=dict(gc.get("notes") or {}),
)


# ── /save ──────────────────────────────────────────────────────────────────

def cmd_save(args: str, state, config) -> bool:
Expand Down Expand Up @@ -312,10 +348,7 @@ def cmd_load(args: str, state, config) -> bool:
except Exception as e:
err(f"Cannot read session file: {e}")
return True
state.messages = data.get("messages", [])
state.turn_count = data.get("turn_count", 0)
state.total_input_tokens = data.get("total_input_tokens", 0)
state.total_output_tokens = data.get("total_output_tokens", 0)
_restore_state_from_data(state, data)
ok(f"Session loaded from {path} ({len(state.messages)} messages)")
return True

Expand Down Expand Up @@ -353,10 +386,7 @@ def cmd_resume(args: str, state, config) -> bool:
except Exception as e:
err(f"Cannot read session file: {e}")
return True
state.messages = data.get("messages", [])
state.turn_count = data.get("turn_count", 0)
state.total_input_tokens = data.get("total_input_tokens", 0)
state.total_output_tokens = data.get("total_output_tokens", 0)
_restore_state_from_data(state, data)
ok(f"Session loaded from {path} ({len(state.messages)} messages)")
return True

Expand Down Expand Up @@ -522,10 +552,7 @@ def cmd_cloudsave(args: str, state, config) -> bool:
if err_msg:
err(err_msg)
return True
state.messages = data.get("messages", [])
state.turn_count = data.get("turn_count", 0)
state.total_input_tokens = data.get("total_input_tokens", 0)
state.total_output_tokens = data.get("total_output_tokens", 0)
_restore_state_from_data(state, data)
ok(f"Session loaded from Gist ({len(state.messages)} messages).")
return True

Expand Down
Loading