Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions checkpoint/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import json
import os
import shutil
import sys
import time
from datetime import datetime, timedelta
from pathlib import Path
Expand Down Expand Up @@ -97,7 +98,7 @@ def track_file_edit(session_id: str, file_path: str) -> str | None:
except OSError:
return None
if size > _MAX_FILE_SIZE:
print(f"[checkpoint] skipping large file ({size} bytes): {file_path}")
print(f"[checkpoint] skipping large file ({size} bytes): {file_path}", file=sys.stderr)
return None

# Copy file to backups/
Expand All @@ -107,7 +108,7 @@ def track_file_edit(session_id: str, file_path: str) -> str | None:
try:
shutil.copy2(str(p), str(backup_path))
except Exception as e:
print(f"[checkpoint] backup failed for {file_path}: {e}")
print(f"[checkpoint] backup failed for {file_path}: {e}", file=sys.stderr)
return None

return backup_name
Expand Down
14 changes: 14 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""Shared pytest fixtures for all tests."""

from __future__ import annotations

import pytest


# --------------- quota stub (avoids ImportError on CI for calc_cost) --------

@pytest.fixture(autouse=True)
def _no_quota(monkeypatch):
"""Disable quota.record_usage so tests never hit the real billing path."""
import quota
monkeypatch.setattr(quota, "record_usage", lambda *a, **kw: None)
27 changes: 27 additions & 0 deletions tests/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""Reusable test helpers (importable from any test module)."""

from __future__ import annotations

from agent import AssistantTurn


def scripted_stream(captured_schemas: list, turns: list[dict]):
"""Return a fake ``stream()`` callable that yields pre-defined turns.

*captured_schemas* receives the ``tool_schemas`` kwarg from each call,
letting tests assert on schema injection. *turns* is a list of dicts,
each with optional ``text`` and ``tool_calls`` keys.
"""
cursor = iter(turns)

def fake_stream(**kwargs):
captured_schemas.append(kwargs.get("tool_schemas") or [])
spec = next(cursor)
yield AssistantTurn(
text=spec.get("text", ""),
tool_calls=spec.get("tool_calls") or [],
in_tokens=1,
out_tokens=1,
)

return fake_stream
112 changes: 112 additions & 0 deletions tests/test_checkpoint_e2e.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
"""End-to-end: drive a real agent.run() conversation where the LLM calls Write,
and verify the checkpoint hook intercepts the call and files a backup to disk.

Only the LLM provider is mocked (via monkeypatching agent.stream). The Write
tool, checkpoint hooks and checkpoint store all run for real against tmp_path.
"""
from __future__ import annotations

import pytest

import tools as _tools_init # noqa: F401 - force built-in tool registration
from agent import AgentState, run
from providers import AssistantTurn
from checkpoint import hooks as checkpoint_hooks
from checkpoint import store as checkpoint_store


def _scripted_stream(turns):
cursor = iter(turns)

def fake_stream(**_kwargs):
spec = next(cursor)
yield AssistantTurn(
text=spec.get("text", ""),
tool_calls=spec.get("tool_calls") or [],
in_tokens=1, out_tokens=1,
)

return fake_stream


@pytest.fixture
def sandboxed_checkpoints(tmp_path, monkeypatch):
"""Run checkpoint store against tmp_path and install hooks on built-in tools."""
monkeypatch.setattr(
checkpoint_store, "_checkpoints_root", lambda: tmp_path / ".checkpoints"
)
checkpoint_store.reset_file_versions()
checkpoint_hooks.set_session("e2e-session")
checkpoint_hooks.reset_tracked()
checkpoint_hooks.install_hooks()
yield tmp_path
checkpoint_hooks.reset_tracked()


def test_llm_write_triggers_checkpoint_backup(monkeypatch, sandboxed_checkpoints):
"""When the LLM calls Write, the checkpoint hook must back the pre-edit file up.

Pre-populate a small file, then let the LLM overwrite it via the Write
tool. The hook should copy the old content into checkpoints/.../backups/
before the Write executes, so the backup holds the original bytes.
"""
target = sandboxed_checkpoints / "hello.py"
target.write_text("print('before')\n", encoding="utf-8")

turns = [
{"tool_calls": [{
"id": "w1",
"name": "Write",
"input": {"file_path": str(target), "content": "print('after')\n"},
}]},
{"text": "done"},
]
monkeypatch.setattr("agent.stream", _scripted_stream(turns))

state = AgentState()
config = {"model": "test", "permission_mode": "accept-all",
"_session_id": "e2e-session"}
list(run("overwrite the file", state, config, "system prompt"))

# After the turn: Write applied the new content
assert target.read_text(encoding="utf-8") == "print('after')\n"

# And the checkpoint hook filed a backup with the pre-edit content
backups_dir = sandboxed_checkpoints / ".checkpoints" / "e2e-session" / "backups"
backups = list(backups_dir.iterdir())
assert backups, "checkpoint hook did not create a backup file"
assert any(b.read_text(encoding="utf-8") == "print('before')\n" for b in backups)


def test_oversized_write_logs_to_stderr_not_stdout(
monkeypatch, sandboxed_checkpoints, capfd
):
"""Over the _MAX_FILE_SIZE threshold the hook skips + logs — to stderr only.

This is the actual user-visible contract of PR #47: checkpoint skips must
not pollute stdout (which carries the conversation transcript), they must
land on stderr where operators look.
"""
monkeypatch.setattr(checkpoint_store, "_MAX_FILE_SIZE", 20)
big = sandboxed_checkpoints / "big.py"
big.write_text("x" * 100, encoding="utf-8")

turns = [
{"tool_calls": [{
"id": "w1",
"name": "Write",
"input": {"file_path": str(big), "content": "y" * 100},
}]},
{"text": "ok"},
]
monkeypatch.setattr("agent.stream", _scripted_stream(turns))

state = AgentState()
list(run("rewrite", state, {"model": "test", "permission_mode": "accept-all",
"_session_id": "e2e-session",
"disabled_tools": ["Agent"]},
"sys"))

out, errtxt = capfd.readouterr()
assert "[checkpoint] skipping large file" in errtxt
assert "[checkpoint] skipping large file" not in out
61 changes: 61 additions & 0 deletions tests/test_checkpoint_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""Integration tests for checkpoint store: stderr capture + large file skip."""
from __future__ import annotations

import pytest

import checkpoint.store as store


@pytest.fixture(autouse=True)
def isolate_store(tmp_path, monkeypatch):
"""Redirect checkpoint root to tmp_path and reset global state."""
monkeypatch.setattr(store, "_checkpoints_root", lambda: tmp_path / "checkpoints")
store.reset_file_versions()


def test_large_file_skipped_and_logged_to_stderr(tmp_path, monkeypatch, capsys):
monkeypatch.setattr(store, "_MAX_FILE_SIZE", 50)
big_file = tmp_path / "big.txt"
big_file.write_bytes(b"x" * 100)

result = store.track_file_edit("test-session", str(big_file))

assert result is None
captured = capsys.readouterr()
assert "[checkpoint] skipping large file" in captured.err
assert "100 bytes" in captured.err
assert captured.out == ""


def test_normal_file_backed_up(tmp_path, capsys):
small_file = tmp_path / "small.txt"
content = b"hello world"
small_file.write_bytes(content)

result = store.track_file_edit("test-session", str(small_file))

assert result is not None
backup_dir = tmp_path / "checkpoints" / "test-session" / "backups"
backup_path = backup_dir / result
assert backup_path.exists()
assert backup_path.read_bytes() == content
captured = capsys.readouterr()
assert captured.err == ""


def test_backup_failure_logged_to_stderr(tmp_path, monkeypatch, capsys):
normal_file = tmp_path / "normal.txt"
normal_file.write_bytes(b"some data")

def failing_copy(*args, **kwargs):
raise PermissionError("access denied")

monkeypatch.setattr(store.shutil, "copy2", failing_copy)

result = store.track_file_edit("test-session", str(normal_file))

assert result is None
captured = capsys.readouterr()
assert "[checkpoint] backup failed" in captured.err
assert "access denied" in captured.err
assert captured.out == ""
Loading