Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,25 +185,36 @@ def run(
if assistant_turn is None:
break

# Record assistant turn in neutral format
state.messages.append({
"role": "assistant",
"content": assistant_turn.text,
"tool_calls": assistant_turn.tool_calls,
})

state.total_input_tokens += assistant_turn.in_tokens
state.total_output_tokens += assistant_turn.out_tokens
state.total_cache_read_tokens += getattr(assistant_turn, 'cache_read_tokens', 0)
state.total_cache_write_tokens += getattr(assistant_turn, 'cache_write_tokens', 0)
yield TurnDone(assistant_turn.in_tokens, assistant_turn.out_tokens)

if not assistant_turn.tool_calls:
# Record assistant turn and stop
state.messages.append({
"role": "assistant",
"content": assistant_turn.text,
"tool_calls": assistant_turn.tool_calls,
})
break # No tools → conversation turn complete

# ── Execute tools (parallel when safe) ────────────────────────────
tool_calls = assistant_turn.tool_calls

# Rewrite colliding tool_call ids BEFORE appending to state
# (so _collect_used_ids doesn't see this turn's own ids as taken)
from id_uniquify import uniquify_tool_call_ids
uniquify_tool_call_ids(tool_calls, state)

# Record assistant turn in neutral format (after uniquify so ids match)
state.messages.append({
"role": "assistant",
"content": assistant_turn.text,
"tool_calls": assistant_turn.tool_calls,
})

# Check permissions first (must be sequential — may prompt user)
permissions: dict[str, bool] = {}
for tc in tool_calls:
Expand Down
5 changes: 3 additions & 2 deletions checkpoint/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import json
import os
import shutil
import sys
import time
from datetime import datetime, timedelta
from pathlib import Path
Expand Down Expand Up @@ -97,7 +98,7 @@ def track_file_edit(session_id: str, file_path: str) -> str | None:
except OSError:
return None
if size > _MAX_FILE_SIZE:
print(f"[checkpoint] skipping large file ({size} bytes): {file_path}")
print(f"[checkpoint] skipping large file ({size} bytes): {file_path}", file=sys.stderr)
return None

# Copy file to backups/
Expand All @@ -107,7 +108,7 @@ def track_file_edit(session_id: str, file_path: str) -> str | None:
try:
shutil.copy2(str(p), str(backup_path))
except Exception as e:
print(f"[checkpoint] backup failed for {file_path}: {e}")
print(f"[checkpoint] backup failed for {file_path}: {e}", file=sys.stderr)
return None

return backup_name
Expand Down
79 changes: 79 additions & 0 deletions coercion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
"""Parameter type coercion for LLM tool calls.

LLMs sometimes send typed values as strings (e.g. ``"42"`` instead of ``42``
for an integer property). This module coerces string parameters to their
schema-declared types so tool handlers receive the expected Python types.

Coercion failure is intentionally *not* a hard error: the original string
is kept so the tool handler can surface a clear type mismatch to the model.
"""
from __future__ import annotations

import json as _json


def coerce_params(params: dict, schema: dict) -> dict:
    """Return a copy of *params* with string values coerced to schema types.

    Supports both schema layouts seen in the codebase:
      - top-level ``properties`` (rare, e.g. test fixtures)
      - Anthropic-style ``input_schema.properties`` (all built-in tools)

    Parameters with no declared property, or non-string values, pass
    through unchanged.
    """
    properties = schema.get("properties")
    if not properties:
        properties = schema.get("input_schema", {}).get("properties", {})
    if not properties:
        return dict(params)

    coerced: dict = {}
    for name, raw in params.items():
        coerced[name] = _coerce_value_for(name, raw, properties)
    return coerced


def _coerce_value_for(key: str, value, props: dict):
    """Return *value* converted per its property schema, or unchanged.

    Only string values with a registered coercer for their declared
    ``type`` are converted; everything else passes straight through.
    """
    spec = props.get(key)
    if not spec or not isinstance(value, str):
        return value
    convert = _COERCERS.get(spec.get("type"))
    return convert(value) if convert is not None else value


def _coerce_int(value):
try:
return int(value)
except ValueError:
return value


def _coerce_float(value):
try:
return float(value)
except ValueError:
return value


def _coerce_bool(value):
"""Coerce string to bool. Returns original string if unrecognised."""
low = value.lower()
if low in ("true", "1", "yes"):
return True
if low in ("false", "0", "no"):
return False
return value # tool handler reports the real type mismatch


def _coerce_json(value):
try:
return _json.loads(value)
except (ValueError, _json.JSONDecodeError):
return value


# Maps JSON-schema "type" names to their string-coercion functions.
# "string" (and any unknown type) is intentionally absent: with no
# registered coercer, _coerce_value_for returns the value untouched.
_COERCERS = {
    "integer": _coerce_int,
    "number": _coerce_float,
    "boolean": _coerce_bool,
    "array": _coerce_json,
    "object": _coerce_json,
}
78 changes: 78 additions & 0 deletions id_uniquify.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""Prevent tool_call_id collisions across turns.

LLMs pick short ids (e.g. ``r1``, ``w2``) and reuse them freely across turns.
When a previous id is reused, it can cause API errors or wrong tool results
being associated with wrong calls.

The fix: on ingest, rewrite any id that already exists in the message history
to ``t{turn}_{original}`` (with a numeric suffix if that still collides).
Same-turn ``depends_on`` references are rewritten in lockstep so the DAG
resolves correctly.
"""
from __future__ import annotations


def _collect_used_ids(state) -> set[str]:
"""Gather all tool_call ids already present in state.messages."""
used: set[str] = set()
for msg in state.messages:
role = msg.get("role")
if role == "assistant":
for tc in msg.get("tool_calls") or []:
tid = tc.get("id")
if tid:
used.add(tid)
elif role == "tool":
tid = msg.get("tool_call_id")
if tid:
used.add(tid)
return used


def _pick_fresh_id(original: str, turn: int, used: set[str]) -> str:
candidate = f"t{turn}_{original}"
if candidate not in used:
return candidate
suffix = 2
while f"{candidate}_{suffix}" in used:
suffix += 1
return f"{candidate}_{suffix}"


def uniquify_tool_call_ids(tool_calls: list, state) -> dict[str, str]:
    """Rewrite colliding tool_call ids in-place and rewrite depends_on refs.

    Only ids that already exist in state.messages are remapped; unseen ids
    pass through untouched, so simple sessions behave exactly as before.
    Same-turn ``depends_on`` lists are patched to track any renames.

    Returns:
        Mapping of ``{original_id: new_id}`` for any ids that were remapped.
    """
    if not tool_calls:
        return {}

    taken = _collect_used_ids(state)
    renamed: dict[str, str] = {}

    for call in tool_calls:
        current = call.get("id")
        if not current:
            continue
        if current not in taken:
            # Fresh id — reserve it so a same-turn duplicate gets remapped.
            taken.add(current)
            continue
        replacement = _pick_fresh_id(current, state.turn_count, taken)
        renamed[current] = replacement
        call["id"] = replacement
        taken.add(replacement)

    if renamed:
        # Same-turn depends_on references must follow the renamed ids.
        for call in tool_calls:
            args = call.get("input") or {}
            refs = args.get("depends_on")
            if refs:
                args["depends_on"] = [renamed.get(r, r) for r in refs]

    return renamed
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ py-modules = [
"cc_config",
"context",
"error_classifier",
"id_uniquify",
"coercion",
"health",
"jobs",
"logging_utils",
Expand Down
42 changes: 42 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""Shared pytest fixtures for all tests."""

from __future__ import annotations

import pytest

from tool_registry import ToolDef, register_tool, _registry


# --------------- quota stub (avoids ImportError on CI for calc_cost) --------

@pytest.fixture(autouse=True)
def _no_quota(monkeypatch):
    """Disable quota.record_usage so tests never hit the real billing path."""
    import quota

    def _ignore_usage(*_args, **_kwargs):
        return None

    monkeypatch.setattr(quota, "record_usage", _ignore_usage)


# --------------- receiver tool fixture -------------------------------------

@pytest.fixture
def receiver_tool():
    """Register a 'receiver' tool that captures the params it was called with.

    Yields:
        A dict whose ``"seen"`` key holds a copy of the FIRST invocation's
        params (setdefault semantics — later calls do not overwrite it).

    Teardown removes the tool from the registry unless it was registered
    before the fixture ran.
    """
    received = {}
    had_before = "receiver" in _registry

    def _capture(params, _cfg):
        # Snapshot the first call's params for assertions.
        received.setdefault("seen", dict(params))
        # Fix: the old lambda returned `setdefault(...) and "ok"`, which
        # evaluated to {} (not "ok") when params was empty.
        return "ok"

    register_tool(ToolDef(
        name="receiver",
        schema={
            "name": "receiver",
            "description": "records params for assertions",
            "input_schema": {
                "type": "object",
                "properties": {"msg": {"type": "string"}},
                "required": ["msg"],
            },
        },
        func=_capture,
        read_only=True, concurrent_safe=True,
    ))
    yield received
    if not had_before:
        _registry.pop("receiver", None)
112 changes: 112 additions & 0 deletions tests/test_checkpoint_e2e.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
"""End-to-end: drive a real agent.run() conversation where the LLM calls Write,
and verify the checkpoint hook intercepts the call and files a backup to disk.

Only the LLM provider is mocked (via monkeypatching agent.stream). The Write
tool, checkpoint hooks and checkpoint store all run for real against tmp_path.
"""
from __future__ import annotations

import pytest

import tools as _tools_init # noqa: F401 - force built-in tool registration
from agent import AgentState, run
from providers import AssistantTurn
from checkpoint import hooks as checkpoint_hooks
from checkpoint import store as checkpoint_store


def _scripted_stream(turns):
    """Build a fake agent.stream that replays *turns*, one spec per call.

    Each spec dict may carry "text" and "tool_calls"; missing keys default
    to empty. Token counts are fixed at 1/1.
    """
    queue = iter(turns)

    def fake_stream(**_kwargs):
        spec = next(queue)
        turn = AssistantTurn(
            text=spec.get("text", ""),
            tool_calls=spec.get("tool_calls") or [],
            in_tokens=1,
            out_tokens=1,
        )
        yield turn

    return fake_stream


@pytest.fixture
def sandboxed_checkpoints(tmp_path, monkeypatch):
    """Run checkpoint store against tmp_path and install hooks on built-in tools.

    Yields tmp_path so tests can create target files and inspect the
    resulting ``.checkpoints/e2e-session/backups`` directory.
    """
    # Point the store at a throwaway root so no real checkpoints are touched.
    monkeypatch.setattr(
        checkpoint_store, "_checkpoints_root", lambda: tmp_path / ".checkpoints"
    )
    checkpoint_store.reset_file_versions()
    # Fixed session id — tests look up backups under .checkpoints/e2e-session/.
    checkpoint_hooks.set_session("e2e-session")
    checkpoint_hooks.reset_tracked()
    checkpoint_hooks.install_hooks()
    yield tmp_path
    # Teardown: clear tracked state so it does not leak into the next test.
    checkpoint_hooks.reset_tracked()


def test_llm_write_triggers_checkpoint_backup(monkeypatch, sandboxed_checkpoints):
    """When the LLM calls Write, the checkpoint hook must back the pre-edit file up.

    Pre-populate a small file, then let the LLM overwrite it via the Write
    tool. The hook should copy the old content into checkpoints/.../backups/
    before the Write executes, so the backup holds the original bytes.
    """
    original_body = "print('before')\n"
    rewritten_body = "print('after')\n"
    victim = sandboxed_checkpoints / "hello.py"
    victim.write_text(original_body, encoding="utf-8")

    script = [
        {
            "tool_calls": [{
                "id": "w1",
                "name": "Write",
                "input": {"file_path": str(victim), "content": rewritten_body},
            }],
        },
        {"text": "done"},
    ]
    monkeypatch.setattr("agent.stream", _scripted_stream(script))

    config = {
        "model": "test",
        "permission_mode": "accept-all",
        "_session_id": "e2e-session",
    }
    state = AgentState()
    # Drain the generator so the whole conversation runs.
    for _event in run("overwrite the file", state, config, "system prompt"):
        pass

    # After the turn: Write applied the new content
    assert victim.read_text(encoding="utf-8") == rewritten_body

    # And the checkpoint hook filed a backup with the pre-edit content
    backups_dir = (
        sandboxed_checkpoints / ".checkpoints" / "e2e-session" / "backups"
    )
    backups = list(backups_dir.iterdir())
    assert backups, "checkpoint hook did not create a backup file"
    assert any(
        b.read_text(encoding="utf-8") == original_body for b in backups
    )


def test_oversized_write_logs_to_stderr_not_stdout(
    monkeypatch, sandboxed_checkpoints, capfd
):
    """Over the _MAX_FILE_SIZE threshold the hook skips + logs — to stderr only.

    This is the actual user-visible contract of PR #47: checkpoint skips must
    not pollute stdout (which carries the conversation transcript), they must
    land on stderr where operators look.
    """
    # Shrink the threshold so a 100-byte file counts as "large".
    monkeypatch.setattr(checkpoint_store, "_MAX_FILE_SIZE", 20)
    oversized = sandboxed_checkpoints / "big.py"
    oversized.write_text("x" * 100, encoding="utf-8")

    script = [
        {
            "tool_calls": [{
                "id": "w1",
                "name": "Write",
                "input": {"file_path": str(oversized), "content": "y" * 100},
            }],
        },
        {"text": "ok"},
    ]
    monkeypatch.setattr("agent.stream", _scripted_stream(script))

    config = {
        "model": "test",
        "permission_mode": "accept-all",
        "_session_id": "e2e-session",
        "disabled_tools": ["Agent"],
    }
    for _event in run("rewrite", AgentState(), config, "sys"):
        pass

    captured_out, captured_err = capfd.readouterr()
    marker = "[checkpoint] skipping large file"
    assert marker in captured_err
    assert marker not in captured_out
Loading
Loading