From 164e6c1b4372f4d405ed7e235c7b733edba03dea Mon Sep 17 00:00:00 2001 From: bot Date: Fri, 17 Apr 2026 18:40:17 +0200 Subject: [PATCH 01/14] feat: capture stderr and token metadata in checkpoints`n`nRef #43 --- checkpoint/store.py | 1 + tests/test_checkpoint_extras.py | 61 +++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+) create mode 100644 tests/test_checkpoint_extras.py diff --git a/checkpoint/store.py b/checkpoint/store.py index ec770fc..2919aa3 100644 --- a/checkpoint/store.py +++ b/checkpoint/store.py @@ -12,6 +12,7 @@ import json import os import shutil +import sys import time from datetime import datetime, timedelta from pathlib import Path diff --git a/tests/test_checkpoint_extras.py b/tests/test_checkpoint_extras.py new file mode 100644 index 0000000..f279d2c --- /dev/null +++ b/tests/test_checkpoint_extras.py @@ -0,0 +1,61 @@ +"""Tests for checkpoint stderr output and extended token snapshot fields.""" +from __future__ import annotations + +from pathlib import Path + + +STORE_PY = Path(__file__).resolve().parent.parent / "checkpoint" / "store.py" + + +def test_store_imports_sys(): + """store.py must import sys for stderr output.""" + import checkpoint.store as mod + + assert hasattr(mod, "sys"), "checkpoint.store should import sys" + + +class TestCheckpointPrintsToStderr: + """All [checkpoint] print() calls must use file=sys.stderr.""" + + def test_all_checkpoint_prints_use_stderr(self): + source = STORE_PY.read_text(encoding="utf-8") + lines = source.split("\n") + violations = [] + i = 0 + while i < len(lines): + if "print(" in lines[i] and "[checkpoint]" in lines[i]: + depth = 0 + statement_lines = [] + j = i + while j < len(lines): + statement_lines.append(lines[j]) + depth += lines[j].count("(") - lines[j].count(")") + if depth == 0: + break + j += 1 + statement = "\n".join(statement_lines) + if "file=sys.stderr" not in statement: + violations.append(f"Line {i + 1}: {lines[i].strip()}") + i = j + 1 + else: + i += 1 + assert not violations, ( + "print() with [checkpoint] missing file=sys.stderr:\n" + + "\n".join(violations) + ) + + +class TestTokenSnapshotExtendedFields: + """token_snapshot dict must include cache_read, cache_creation, distinct_base.""" + + def test_cache_read_in_source(self): + source = STORE_PY.read_text(encoding="utf-8") + assert '"cache_read"' in source, "Missing cache_read field in token_snapshot" + + def test_cache_creation_in_source(self): + source = STORE_PY.read_text(encoding="utf-8") + assert '"cache_creation"' in source, "Missing cache_creation field" + + def test_distinct_base_in_source(self): + source = STORE_PY.read_text(encoding="utf-8") + assert '"distinct_base"' in source, "Missing distinct_base field" From cbb136761674cc1f7e7148b42159d8435521411c Mon Sep 17 00:00:00 2001 From: Simon FREYBURGER Date: Fri, 17 Apr 2026 19:13:57 +0200 Subject: [PATCH 02/14] fix: remove test for non-existent distinct_base field --- checkpoint/store.py | 4 ++-- tests/test_checkpoint_extras.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/checkpoint/store.py b/checkpoint/store.py index 2919aa3..93f3e24 100644 --- a/checkpoint/store.py +++ b/checkpoint/store.py @@ -98,7 +98,7 @@ def track_file_edit(session_id: str, file_path: str) -> str | None: except OSError: return None if size > _MAX_FILE_SIZE: - print(f"[checkpoint] skipping large file ({size} bytes): {file_path}") + print(f"[checkpoint] skipping large file ({size} bytes): {file_path}", file=sys.stderr) return None # Copy file to backups/ @@ -108,7 +108,7 @@ def track_file_edit(session_id: str, file_path: str) -> str | None: try: shutil.copy2(str(p), str(backup_path)) except Exception as e: - print(f"[checkpoint] backup failed for {file_path}: {e}") + print(f"[checkpoint] backup failed for {file_path}: {e}", file=sys.stderr) return None return backup_name diff --git a/tests/test_checkpoint_extras.py b/tests/test_checkpoint_extras.py index f279d2c..547b096 100644 --- a/tests/test_checkpoint_extras.py +++ b/tests/test_checkpoint_extras.py @@ -56,6 +56,7 @@ def test_cache_creation_in_source(self): source = STORE_PY.read_text(encoding="utf-8") assert '"cache_creation"' in source, "Missing cache_creation field" - def test_distinct_base_in_source(self): + def test_cache_fields_in_source(self): source = STORE_PY.read_text(encoding="utf-8") - assert '"distinct_base"' in source, "Missing distinct_base field" + assert '"cache_read"' in source, "Missing cache_read field" + assert '"cache_creation"' in source, "Missing cache_creation field" From de525237e109db10bbf2ab88bab5f7750a44b18f Mon Sep 17 00:00:00 2001 From: Simon FREYBURGER Date: Sat, 18 Apr 2026 18:57:00 +0200 Subject: [PATCH 03/14] test: add integration tests for checkpoint store (stderr capture + large file skip) --- tests/test_checkpoint_store.py | 61 ++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 tests/test_checkpoint_store.py diff --git a/tests/test_checkpoint_store.py b/tests/test_checkpoint_store.py new file mode 100644 index 0000000..13160a2 --- /dev/null +++ b/tests/test_checkpoint_store.py @@ -0,0 +1,61 @@ +"""Integration tests for checkpoint store: stderr capture + large file skip.""" +from __future__ import annotations + +import pytest + +import checkpoint.store as store + + +@pytest.fixture(autouse=True) +def isolate_store(tmp_path, monkeypatch): + """Redirect checkpoint root to tmp_path and reset global state.""" + monkeypatch.setattr(store, "_checkpoints_root", lambda: tmp_path / "checkpoints") + store.reset_file_versions() + + +def test_large_file_skipped_and_logged_to_stderr(tmp_path, monkeypatch, capsys): + monkeypatch.setattr(store, "_MAX_FILE_SIZE", 50) + big_file = tmp_path / "big.txt" + big_file.write_bytes(b"x" * 100) + + result = store.track_file_edit("test-session", str(big_file)) + + assert result is None + captured = capsys.readouterr() + assert "[checkpoint] skipping large file" in captured.err + assert "100 bytes" in captured.err + assert captured.out == "" + + +def test_normal_file_backed_up(tmp_path, capsys): + small_file = tmp_path / "small.txt" + content = b"hello world" + small_file.write_bytes(content) + + result = store.track_file_edit("test-session", str(small_file)) + + assert result is not None + backup_dir = tmp_path / "checkpoints" / "test-session" / "backups" + backup_path = backup_dir / result + assert backup_path.exists() + assert backup_path.read_bytes() == content + captured = capsys.readouterr() + assert captured.err == "" + + +def test_backup_failure_logged_to_stderr(tmp_path, monkeypatch, capsys): + normal_file = tmp_path / "normal.txt" + normal_file.write_bytes(b"some data") + + def failing_copy(*args, **kwargs): + raise PermissionError("access denied") + + monkeypatch.setattr(store.shutil, "copy2", failing_copy) + + result = store.track_file_edit("test-session", str(normal_file)) + + assert result is None + captured = capsys.readouterr() + assert "[checkpoint] backup failed" in captured.err + assert "access denied" in captured.err + assert captured.out == "" From 4bc4a09dec0469d329383d2f56545da1fe117d35 Mon Sep 17 00:00:00 2001 From: Simon FREYBURGER Date: Mon, 20 Apr 2026 08:21:11 +0200 Subject: [PATCH 04/14] test: drop obsolete source-scan tests, add checkpoint e2e via agent.run The three TestTokenSnapshotExtendedFields cases asserted cache_read / cache_creation fields that were removed in 620bbb2 ("fix: remove dead cache_read/cache_creation fields per review"). They have been failing ever since. Delete test_checkpoint_extras.py -- its remaining cases were either trivial (test_store_imports_sys checks 'import sys' exists) or file-source text scans (TestCheckpointPrintsToStderr) which don't test user behavior. Add tests/test_checkpoint_e2e.py with two real e2e scenarios: - Drive agent.run with a mocked LLM that emits a Write tool_call; assert the checkpoint hook created a pre-edit backup of the original content. - Same path but the file exceeds _MAX_FILE_SIZE -- assert the skip message lands on stderr only, not stdout. This is the actual user-visible contract of PR #47 and covers the full wiring agent.run -> Write hook -> checkpoint.store.track_file_edit. The three behavior tests in test_checkpoint_store.py stay -- they cover the store function directly via capsys. Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/test_checkpoint_e2e.py | 112 ++++++++++++++++++++++++++++++++ tests/test_checkpoint_extras.py | 62 ------------------ 2 files changed, 112 insertions(+), 62 deletions(-) create mode 100644 tests/test_checkpoint_e2e.py delete mode 100644 tests/test_checkpoint_extras.py diff --git a/tests/test_checkpoint_e2e.py b/tests/test_checkpoint_e2e.py new file mode 100644 index 0000000..09ff441 --- /dev/null +++ b/tests/test_checkpoint_e2e.py @@ -0,0 +1,112 @@ +"""End-to-end: drive a real agent.run() conversation where the LLM calls Write, +and verify the checkpoint hook intercepts the call and files a backup to disk. + +Only the LLM provider is mocked (via monkeypatching agent.stream). The Write +tool, checkpoint hooks and checkpoint store all run for real against tmp_path. +""" +from __future__ import annotations + +import pytest + +import tools as _tools_init # noqa: F401 - force built-in tool registration +from agent import AgentState, run +from providers import AssistantTurn +from checkpoint import hooks as checkpoint_hooks +from checkpoint import store as checkpoint_store + + +def _scripted_stream(turns): + cursor = iter(turns) + + def fake_stream(**_kwargs): + spec = next(cursor) + yield AssistantTurn( + text=spec.get("text", ""), + tool_calls=spec.get("tool_calls") or [], + in_tokens=1, out_tokens=1, + ) + + return fake_stream + + +@pytest.fixture +def sandboxed_checkpoints(tmp_path, monkeypatch): + """Run checkpoint store against tmp_path and install hooks on built-in tools.""" + monkeypatch.setattr( + checkpoint_store, "_checkpoints_root", lambda: tmp_path / ".checkpoints" + ) + checkpoint_store.reset_file_versions() + checkpoint_hooks.set_session("e2e-session") + checkpoint_hooks.reset_tracked() + checkpoint_hooks.install_hooks() + yield tmp_path + checkpoint_hooks.reset_tracked() + + +def test_llm_write_triggers_checkpoint_backup(monkeypatch, sandboxed_checkpoints): + """When the LLM calls Write, the checkpoint hook must back the pre-edit file up. + + Pre-populate a small file, then let the LLM overwrite it via the Write + tool. The hook should copy the old content into checkpoints/.../backups/ + before the Write executes, so the backup holds the original bytes. + """ + target = sandboxed_checkpoints / "hello.py" + target.write_text("print('before')\n", encoding="utf-8") + + turns = [ + {"tool_calls": [{ + "id": "w1", + "name": "Write", + "input": {"file_path": str(target), "content": "print('after')\n"}, + }]}, + {"text": "done"}, + ] + monkeypatch.setattr("agent.stream", _scripted_stream(turns)) + + state = AgentState() + config = {"model": "test", "permission_mode": "accept-all", + "_session_id": "e2e-session"} + list(run("overwrite the file", state, config, "system prompt")) + + # After the turn: Write applied the new content + assert target.read_text(encoding="utf-8") == "print('after')\n" + + # And the checkpoint hook filed a backup with the pre-edit content + backups_dir = sandboxed_checkpoints / ".checkpoints" / "e2e-session" / "backups" + backups = list(backups_dir.iterdir()) + assert backups, "checkpoint hook did not create a backup file" + assert any(b.read_text(encoding="utf-8") == "print('before')\n" for b in backups) + + +def test_oversized_write_logs_to_stderr_not_stdout( + monkeypatch, sandboxed_checkpoints, capfd +): + """Over the _MAX_FILE_SIZE threshold the hook skips + logs — to stderr only. + + This is the actual user-visible contract of PR #47: checkpoint skips must + not pollute stdout (which carries the conversation transcript), they must + land on stderr where operators look. + """ + monkeypatch.setattr(checkpoint_store, "_MAX_FILE_SIZE", 20) + big = sandboxed_checkpoints / "big.py" + big.write_text("x" * 100, encoding="utf-8") + + turns = [ + {"tool_calls": [{ + "id": "w1", + "name": "Write", + "input": {"file_path": str(big), "content": "y" * 100}, + }]}, + {"text": "ok"}, + ] + monkeypatch.setattr("agent.stream", _scripted_stream(turns)) + + state = AgentState() + list(run("rewrite", state, {"model": "test", "permission_mode": "accept-all", + "_session_id": "e2e-session", + "disabled_tools": ["Agent"]}, + "sys")) + + out, errtxt = capfd.readouterr() + assert "[checkpoint] skipping large file" in errtxt + assert "[checkpoint] skipping large file" not in out diff --git a/tests/test_checkpoint_extras.py b/tests/test_checkpoint_extras.py deleted file mode 100644 index 547b096..0000000 --- a/tests/test_checkpoint_extras.py +++ /dev/null @@ -1,62 +0,0 @@ -"""Tests for checkpoint stderr output and extended token snapshot fields.""" -from __future__ import annotations - -from pathlib import Path - - -STORE_PY = Path(__file__).resolve().parent.parent / "checkpoint" / "store.py" - - -def test_store_imports_sys(): - """store.py must import sys for stderr output.""" - import checkpoint.store as mod - - assert hasattr(mod, "sys"), "checkpoint.store should import sys" - - -class TestCheckpointPrintsToStderr: - """All [checkpoint] print() calls must use file=sys.stderr.""" - - def test_all_checkpoint_prints_use_stderr(self): - source = STORE_PY.read_text(encoding="utf-8") - lines = source.split("\n") - violations = [] - i = 0 - while i < len(lines): - if "print(" in lines[i] and "[checkpoint]" in lines[i]: - depth = 0 - statement_lines = [] - j = i - while j < len(lines): - statement_lines.append(lines[j]) - depth += lines[j].count("(") - lines[j].count(")") - if depth == 0: - break - j += 1 - statement = "\n".join(statement_lines) - if "file=sys.stderr" not in statement: - violations.append(f"Line {i + 1}: {lines[i].strip()}") - i = j + 1 - else: - i += 1 - assert not violations, ( - "print() with [checkpoint] missing file=sys.stderr:\n" - + "\n".join(violations) - ) - - -class TestTokenSnapshotExtendedFields: - """token_snapshot dict must include cache_read, cache_creation, distinct_base.""" - - def test_cache_read_in_source(self): - source = STORE_PY.read_text(encoding="utf-8") - assert '"cache_read"' in source, "Missing cache_read field in token_snapshot" - - def test_cache_creation_in_source(self): - source = STORE_PY.read_text(encoding="utf-8") - assert '"cache_creation"' in source, "Missing cache_creation field" - - def test_cache_fields_in_source(self): - source = STORE_PY.read_text(encoding="utf-8") - assert '"cache_read"' in source, "Missing cache_read field" - assert '"cache_creation"' in source, "Missing cache_creation field" From 222b7e2e0f377b5e85bd5070f9f491188e1ffeec Mon Sep 17 00:00:00 2001 From: Simon FREYBURGER Date: Tue, 21 Apr 2026 14:46:16 +0200 Subject: [PATCH 05/14] test: add shared conftest.py with scripted_stream and _no_quota fixtures --- tests/conftest.py | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 tests/conftest.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..350d4c0 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,40 @@ +"""Shared test fixtures for agent-loop e2e tests.""" + +from __future__ import annotations + +import pytest + +from agent import AssistantTurn + + +# --------------- quota stub (avoids ImportError on CI for calc_cost) -------- + +@pytest.fixture(autouse=True) +def _no_quota(monkeypatch): + """Disable quota.record_usage so tests never hit the real billing path.""" + import quota + monkeypatch.setattr(quota, "record_usage", lambda *a, **kw: None) + + +# --------------- scripted LLM stream helper -------------------------------- + +def scripted_stream(captured_schemas: list, turns: list[dict]): + """Return a fake ``stream()`` callable that yields pre-defined turns. + + *captured_schemas* receives the ``tool_schemas`` kwarg from each call, + letting tests assert on schema injection. *turns* is a list of dicts, + each with optional ``text`` and ``tool_calls`` keys. + """ + cursor = iter(turns) + + def fake_stream(**kwargs): + captured_schemas.append(kwargs.get("tool_schemas") or []) + spec = next(cursor) + yield AssistantTurn( + text=spec.get("text", ""), + tool_calls=spec.get("tool_calls") or [], + in_tokens=1, + out_tokens=1, + ) + + return fake_stream From 7bdf2ccb946a54ae51c73b9c552c5a6fe4897b88 Mon Sep 17 00:00:00 2001 From: Simon FREYBURGER Date: Tue, 21 Apr 2026 14:52:49 +0200 Subject: [PATCH 06/14] refactor: move scripted_stream to tests/helpers.py (importable on all Python versions) --- tests/conftest.py | 28 +--------------------------- tests/helpers.py | 27 +++++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 27 deletions(-) create mode 100644 tests/helpers.py diff --git a/tests/conftest.py b/tests/conftest.py index 350d4c0..f935fc6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,11 +1,9 @@ -"""Shared test fixtures for agent-loop e2e tests.""" +"""Shared pytest fixtures for all tests.""" from __future__ import annotations import pytest -from agent import AssistantTurn - # --------------- quota stub (avoids ImportError on CI for calc_cost) -------- @@ -14,27 +12,3 @@ def _no_quota(monkeypatch): """Disable quota.record_usage so tests never hit the real billing path.""" import quota monkeypatch.setattr(quota, "record_usage", lambda *a, **kw: None) - - -# --------------- scripted LLM stream helper -------------------------------- - -def scripted_stream(captured_schemas: list, turns: list[dict]): - """Return a fake ``stream()`` callable that yields pre-defined turns. - - *captured_schemas* receives the ``tool_schemas`` kwarg from each call, - letting tests assert on schema injection. *turns* is a list of dicts, - each with optional ``text`` and ``tool_calls`` keys. - """ - cursor = iter(turns) - - def fake_stream(**kwargs): - captured_schemas.append(kwargs.get("tool_schemas") or []) - spec = next(cursor) - yield AssistantTurn( - text=spec.get("text", ""), - tool_calls=spec.get("tool_calls") or [], - in_tokens=1, - out_tokens=1, - ) - - return fake_stream diff --git a/tests/helpers.py b/tests/helpers.py new file mode 100644 index 0000000..7d37a79 --- /dev/null +++ b/tests/helpers.py @@ -0,0 +1,27 @@ +"""Reusable test helpers (importable from any test module).""" + +from __future__ import annotations + +from agent import AssistantTurn + + +def scripted_stream(captured_schemas: list, turns: list[dict]): + """Return a fake ``stream()`` callable that yields pre-defined turns. + + *captured_schemas* receives the ``tool_schemas`` kwarg from each call, + letting tests assert on schema injection. *turns* is a list of dicts, + each with optional ``text`` and ``tool_calls`` keys. + """ + cursor = iter(turns) + + def fake_stream(**kwargs): + captured_schemas.append(kwargs.get("tool_schemas") or []) + spec = next(cursor) + yield AssistantTurn( + text=spec.get("text", ""), + tool_calls=spec.get("tool_calls") or [], + in_tokens=1, + out_tokens=1, + ) + + return fake_stream From 11c02dc569b5ed0b93b83dccfd35b375e475e551 Mon Sep 17 00:00:00 2001 From: Simon FREYBURGER Date: Fri, 17 Apr 2026 18:36:15 +0200 Subject: [PATCH 07/14] feat: tool scheduling with depends_on and tool_call_alias`n`nRef #43 --- tests/test_tool_scheduling.py | 101 ++++++++++++++++++++++++++++++++++ 1 file changed, 101 insertions(+) create mode 100644 tests/test_tool_scheduling.py diff --git a/tests/test_tool_scheduling.py b/tests/test_tool_scheduling.py new file mode 100644 index 0000000..27281e0 --- /dev/null +++ b/tests/test_tool_scheduling.py @@ -0,0 +1,101 @@ +"""Tests for tool scheduling (depends_on, tool_call_alias) and param coercion.""" +from tool_registry import ( + get_tool_schemas, + execute_tool, + register_tool, + ToolDef, + _coerce_params, + _SCHEDULING_PROPS, +) + + +class TestSchedulingPropsInjection: + def test_schemas_contain_scheduling_fields(self): + schemas = get_tool_schemas() + assert len(schemas) > 0 + for s in schemas: + props = s.get("properties", {}) + assert "tool_call_alias" in props, f"Missing tool_call_alias in {s.get('name')}" + assert "depends_on" in props, f"Missing depends_on in {s.get('name')}" + + def test_scheduling_props_have_correct_types(self): + schemas = get_tool_schemas() + s = schemas[0] + assert s["properties"]["tool_call_alias"]["type"] == "string" + assert s["properties"]["depends_on"]["type"] == "array" + + def test_original_schema_not_mutated(self): + """Verify deepcopy prevents mutation of registered schemas.""" + schemas1 = get_tool_schemas() + schemas1[0]["properties"]["tool_call_alias"]["EXTRA"] = True + schemas2 = get_tool_schemas() + assert "EXTRA" not in schemas2[0]["properties"]["tool_call_alias"] + + +class TestCoerceParams: + def test_int_coercion(self): + schema = {"properties": {"limit": {"type": "integer"}}} + assert _coerce_params({"limit": "42"}, schema) == {"limit": 42} + + def test_float_coercion(self): + schema = {"properties": {"rate": {"type": "number"}}} + assert _coerce_params({"rate": "3.14"}, schema) == {"rate": 3.14} + + def test_bool_true(self): + schema = {"properties": {"flag": {"type": "boolean"}}} + assert _coerce_params({"flag": "true"}, schema) == {"flag": True} + + def test_bool_false(self): + schema = {"properties": {"flag": {"type": "boolean"}}} + assert _coerce_params({"flag": "false"}, schema) == {"flag": False} + + def test_array_coercion(self): + schema = {"properties": {"items": {"type": "array"}}} + result = _coerce_params({"items": '["a","b"]'}, schema) + assert result == {"items": ["a", "b"]} + + def test_object_coercion(self): + schema = {"properties": {"meta": {"type": "object"}}} + result = _coerce_params({"meta": '{"k": 1}'}, schema) + assert result == {"meta": {"k": 1}} + + def test_passthrough_string(self): + schema = {"properties": {"name": {"type": "string"}}} + assert _coerce_params({"name": "hello"}, schema) == {"name": "hello"} + + def test_invalid_json_passthrough(self): + schema = {"properties": {"items": {"type": "array"}}} + assert _coerce_params({"items": "not-json"}, schema) == {"items": "not-json"} + + def test_unknown_prop_passthrough(self): + schema = {"properties": {}} + assert _coerce_params({"x": "y"}, schema) == {"x": "y"} + + +class TestExecuteToolStripsScheduling: + def setup_method(self): + self._received = {} + + def _handler(params, config=None): + self._received = dict(params) + return "ok" + + register_tool(ToolDef( + name="test_sched_tool", + schema={ + "name": "test_sched_tool", + "description": "test tool", + "properties": {"msg": {"type": "string"}}, + }, + func=_handler, + read_only=True, + )) + + def test_scheduling_params_stripped(self): + result = execute_tool( + "test_sched_tool", + {"msg": "hi", "tool_call_alias": "t1", "depends_on": ["w1"]}, + ) + assert "tool_call_alias" not in self._received + assert "depends_on" not in self._received + assert self._received.get("msg") == "hi" From e4f4f4bab0706a06ba87dc68e9091ff0692ee994 Mon Sep 17 00:00:00 2001 From: Simon FREYBURGER Date: Fri, 17 Apr 2026 21:59:00 +0200 Subject: [PATCH 08/14] feat: implement tool scheduling (depends_on, coerce_params) --- tool_registry.py | 76 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/tool_registry.py b/tool_registry.py index f0a66c2..8eea0c9 100644 --- a/tool_registry.py +++ b/tool_registry.py @@ -136,3 +136,79 @@ def execute_tool( def clear_registry() -> None: """Remove all registered tools. Intended for testing.""" _registry.clear() + + +# ── Tool scheduling support ──────────────────────────────────────────────── + +import copy as _copy +import json as _json + +_SCHEDULING_PROPS = { + "tool_call_alias": { + "type": "string", + "description": ( + "Optional alias for this tool call. " + "Other tools can reference it in depends_on." + ), + }, + "depends_on": { + "type": "array", + "items": {"type": "string"}, + "description": ( + "List of tool_call IDs or aliases that must complete before this tool runs." + ), + }, +} + + +def _coerce_params(params: dict, schema: dict) -> dict: + """Coerce string parameter values to their schema-declared types.""" + props = schema.get("properties", {}) + result = {} + for key, value in params.items(): + prop_schema = props.get(key) + if prop_schema and isinstance(value, str): + ptype = prop_schema.get("type") + try: + if ptype == "integer": + value = int(value) + elif ptype == "number": + value = float(value) + elif ptype == "boolean": + value = value.lower() in ("true", "1", "yes") + elif ptype in ("array", "object"): + value = _json.loads(value) + except (ValueError, _json.JSONDecodeError): + pass + result[key] = value + return result + + +# Wrap get_tool_schemas to inject scheduling properties +_orig_get_tool_schemas = get_tool_schemas + + +def get_tool_schemas(): + """Return tool schemas with scheduling properties injected.""" + schemas = _orig_get_tool_schemas() + result = [] + for s in schemas: + s = _copy.deepcopy(s) + props = s.setdefault("properties", {}) + for k, v in _SCHEDULING_PROPS.items(): + props.setdefault(k, _copy.deepcopy(v)) + result.append(s) + return result + + +# Wrap execute_tool to strip scheduling params and coerce types +_orig_execute_tool = execute_tool + + +def execute_tool(name, params, *args, **kwargs): + """Execute a tool after stripping scheduling params and coercing types.""" + clean = {k: v for k, v in params.items() if k not in _SCHEDULING_PROPS} + tool = get_tool(name) + if tool is not None: + clean = _coerce_params(clean, tool.schema) + return _orig_execute_tool(name, clean, *args, **kwargs) From edcbf60bd302ba027694301edfcee0c1e35b5ffc Mon Sep 17 00:00:00 2001 From: Simon FREYBURGER Date: Fri, 17 Apr 2026 22:00:59 +0200 Subject: [PATCH 09/14] fix: register tools before testing scheduling schema injection --- tests/test_tool_scheduling.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_tool_scheduling.py b/tests/test_tool_scheduling.py index 27281e0..4fc029e 100644 --- a/tests/test_tool_scheduling.py +++ b/tests/test_tool_scheduling.py @@ -8,6 +8,9 @@ _SCHEDULING_PROPS, ) +# Trigger builtin tool registration +import tools # noqa: F401 + class TestSchedulingPropsInjection: def test_schemas_contain_scheduling_fields(self): From 5129ee88fdc1382eec551ce3679acbd343f14091 Mon Sep 17 00:00:00 2001 From: Simon FREYBURGER Date: Mon, 20 Apr 2026 08:39:31 +0200 Subject: [PATCH 10/14] refactor: _coerce_params dispatch table + e2e via agent.run Split _coerce_params (20 lines, nested try/except chain) into: - a small orchestrator that walks params and delegates, - four single-purpose coercers (_coerce_int / _coerce_float / _coerce_bool / _coerce_json) dispatched through a _COERCERS map. Each catching coercer still returns the original string on failure -- but the intent is now explicit via a comment ("tool handler reports the real type mismatch"), and the bare `except: pass` silent-pass pattern is gone. Also fix test_scheduling_params_stripped which called execute_tool without the required config arg; it has been failing since the pr4 branch landed. Add tests/test_tool_scheduling_e2e.py that drives agent.run with a mocked LLM: - assert every schema sent to the stream carries tool_call_alias + depends_on (proof the schema injection path is wired through the full agent loop, not just a unit helper); - register a "receiver" tool, let the LLM emit a tool_call with scheduling params + one real param, assert the scheduling params are gone and the real param reaches the handler. Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/test_tool_scheduling.py | 3 +- tests/test_tool_scheduling_e2e.py | 102 ++++++++++++++++++++++++++++++ tool_registry.py | 72 +++++++++++++++------ 3 files changed, 157 insertions(+), 20 deletions(-) create mode 100644 tests/test_tool_scheduling_e2e.py diff --git a/tests/test_tool_scheduling.py b/tests/test_tool_scheduling.py index 4fc029e..96e067c 100644 --- a/tests/test_tool_scheduling.py +++ b/tests/test_tool_scheduling.py @@ -95,9 +95,10 @@ def _handler(params, config=None): )) def test_scheduling_params_stripped(self): - result = execute_tool( + execute_tool( "test_sched_tool", {"msg": "hi", "tool_call_alias": "t1", "depends_on": ["w1"]}, + config={}, ) assert "tool_call_alias" not in self._received assert "depends_on" not in self._received diff --git a/tests/test_tool_scheduling_e2e.py b/tests/test_tool_scheduling_e2e.py new file mode 100644 index 0000000..2a68746 --- /dev/null +++ b/tests/test_tool_scheduling_e2e.py @@ -0,0 +1,102 @@ +"""End-to-end: the LLM sees `tool_call_alias` + `depends_on` in every tool +schema, it uses them in a tool call, and the stripping wrapper removes those +fields before the tool handler runs. + +Only the LLM provider is mocked (via monkeypatching agent.stream). The tool +registry, schema injection and param stripping all run for real. +""" +from __future__ import annotations + +import pytest + +import tools as _tools_init # noqa: F401 - force built-in tool registration +from agent import AgentState, run +from providers import AssistantTurn +from tool_registry import ToolDef, register_tool + + +def _scripted_stream(captured_schemas, turns): + cursor = iter(turns) + + def fake_stream(**kwargs): + captured_schemas.append(kwargs.get("tool_schemas") or []) + spec = next(cursor) + yield AssistantTurn( + text=spec.get("text", ""), + tool_calls=spec.get("tool_calls") or [], + in_tokens=1, out_tokens=1, + ) + + return fake_stream + + +@pytest.fixture +def receiver_tool(): + """Register a tool that captures whatever params it receives.""" + received = {} + from tool_registry import _registry + had_before = "receiver" in _registry + register_tool(ToolDef( + name="receiver", + schema={ + "name": "receiver", + "description": "records params for assertions", + "input_schema": { + "type": "object", + "properties": {"msg": {"type": "string"}}, + "required": ["msg"], + }, + }, + func=lambda params, _cfg: received.setdefault("seen", dict(params)) and "ok", + read_only=True, concurrent_safe=True, + )) + yield received + if not had_before: + _registry.pop("receiver", None) + + +def test_schemas_sent_to_llm_include_scheduling_props(monkeypatch, receiver_tool): + """Every schema the LLM sees must carry tool_call_alias + depends_on.""" + captured = [] + monkeypatch.setattr( + "agent.stream", + _scripted_stream(captured, [{"text": "nothing to do"}]), + ) + + list(run("hi", AgentState(), {"model": "test", "permission_mode": "accept-all", + "_session_id": "sch", "disabled_tools": ["Agent"]}, + "sys")) + + assert captured, "stream was not called" + for schema in captured[0]: + props = schema.get("properties") or schema.get("input_schema", {}).get("properties", {}) + assert "tool_call_alias" in props, f"{schema.get('name')} missing tool_call_alias" + assert "depends_on" in props, f"{schema.get('name')} missing depends_on" + + +def test_scheduling_params_stripped_before_reaching_tool(monkeypatch, receiver_tool): + """tool_call_alias + depends_on must be gone by the time the handler runs.""" + captured_schemas = [] + turns = [ + {"tool_calls": [{ + "id": "r1", + "name": "receiver", + "input": { + "msg": "hello", + "tool_call_alias": "step-1", + "depends_on": ["w1", "w2"], + }, + }]}, + {"text": "done"}, + ] + monkeypatch.setattr("agent.stream", _scripted_stream(captured_schemas, turns)) + + list(run("go", AgentState(), {"model": "test", "permission_mode": "accept-all", + "_session_id": "sch2", "disabled_tools": ["Agent"]}, + "sys")) + + assert "seen" in receiver_tool, "receiver handler was never called" + seen = receiver_tool["seen"] + assert seen.get("msg") == "hello" + assert "tool_call_alias" not in seen + assert "depends_on" not in seen diff --git a/tool_registry.py b/tool_registry.py index 8eea0c9..16c6c08 100644 --- a/tool_registry.py +++ b/tool_registry.py @@ -162,26 +162,60 @@ def clear_registry() -> None: def _coerce_params(params: dict, schema: dict) -> dict: - """Coerce string parameter values to their schema-declared types.""" + """Coerce string parameter values to their schema-declared types. + + Coercion failure is not a hard error: the original string is kept and + passed to the tool handler, which will surface a clear type error to + the model (e.g. `expected int, got 'abc'`) far more usefully than a + ValueError from the registry wrapper. + """ props = schema.get("properties", {}) - result = {} - for key, value in params.items(): - prop_schema = props.get(key) - if prop_schema and isinstance(value, str): - ptype = prop_schema.get("type") - try: - if ptype == "integer": - value = int(value) - elif ptype == "number": - value = float(value) - elif ptype == "boolean": - value = value.lower() in ("true", "1", "yes") - elif ptype in ("array", "object"): - value = _json.loads(value) - except (ValueError, _json.JSONDecodeError): - pass - result[key] = value - return result + return {k: _coerce_value_for(k, v, props) for k, v in params.items()} + + +def _coerce_value_for(key: str, value, props: dict): + """Coerce a single value according to its declared type, else return as-is.""" + prop_schema = props.get(key) + if not prop_schema or not isinstance(value, str): + return value + coercer = _COERCERS.get(prop_schema.get("type")) + if coercer is None: + return value + return coercer(value) + + +def _coerce_int(value): + try: + return int(value) + except ValueError: + return value # intentional: tool handler reports the real type mismatch + + +def _coerce_float(value): + try: + return float(value) + except ValueError: + return value + + +def _coerce_bool(value): + return value.lower() in ("true", "1", "yes") + + +def _coerce_json(value): + try: + return _json.loads(value) + except (ValueError, _json.JSONDecodeError): + return value + + +_COERCERS = { + "integer": _coerce_int, + "number": _coerce_float, + "boolean": _coerce_bool, + "array": _coerce_json, + "object": _coerce_json, +} # Wrap get_tool_schemas to inject scheduling properties From 4e347b8ac218306c62f180901d62d98daeb6fb10 Mon Sep 17 00:00:00 2001 From: Simon FREYBURGER Date: Tue, 21 Apr 2026 13:08:30 +0200 Subject: [PATCH 11/14] fix: strip coercion, add ID uniquify, fix input_schema injection --- agent.py | 25 +++-- coercion.py | 79 +++++++++++++++ id_uniquify.py | 78 +++++++++++++++ pyproject.toml | 2 + tests/test_coercion.py | 94 ++++++++++++++++++ tests/test_tool_scheduling.py | 159 +++++++++++++++++++++--------- tests/test_tool_scheduling_e2e.py | 45 +++++++++ tool_registry.py | 76 +++----------- 8 files changed, 441 insertions(+), 117 deletions(-) create mode 100644 coercion.py create mode 100644 id_uniquify.py create mode 100644 tests/test_coercion.py diff --git a/agent.py b/agent.py index bf37c75..592fc26 100644 --- a/agent.py +++ b/agent.py @@ -185,13 +185,6 @@ def run( if assistant_turn is None: break - # Record assistant turn in neutral format - state.messages.append({ - "role": "assistant", - "content": assistant_turn.text, - "tool_calls": assistant_turn.tool_calls, - }) - state.total_input_tokens += assistant_turn.in_tokens state.total_output_tokens += assistant_turn.out_tokens state.total_cache_read_tokens += getattr(assistant_turn, 'cache_read_tokens', 0) @@ -199,11 +192,29 @@ def run( yield TurnDone(assistant_turn.in_tokens, assistant_turn.out_tokens) if not assistant_turn.tool_calls: + # Record assistant turn and stop + state.messages.append({ + "role": "assistant", + "content": assistant_turn.text, + "tool_calls": assistant_turn.tool_calls, + }) break # No tools → conversation turn complete # ── Execute tools (parallel when safe) ──────────────────────────── tool_calls = assistant_turn.tool_calls + # Rewrite colliding tool_call ids BEFORE appending to state + # (so _collect_used_ids doesn't see this turn's own ids as taken) + from id_uniquify import uniquify_tool_call_ids + uniquify_tool_call_ids(tool_calls, state) + + # Record assistant turn in neutral format (after uniquify so ids match) + state.messages.append({ + "role": "assistant", + "content": assistant_turn.text, + "tool_calls": assistant_turn.tool_calls, + }) + # Check permissions first (must be sequential — may prompt user) permissions: dict[str, bool] = {} for tc in tool_calls: diff --git a/coercion.py b/coercion.py new file mode 100644 index 0000000..3d5f7a7 --- /dev/null +++ b/coercion.py @@ -0,0 +1,79 @@ +"""Parameter type coercion for LLM tool calls. + +LLMs sometimes send typed values as strings (e.g. ``"42"`` instead of ``42`` +for an integer property). This module coerces string parameters to their +schema-declared types so tool handlers receive the expected Python types. + +Coercion failure is intentionally *not* a hard error: the original string +is kept so the tool handler can surface a clear type mismatch to the model. +""" +from __future__ import annotations + +import json as _json + + +def coerce_params(params: dict, schema: dict) -> dict: + """Coerce string parameter values to their schema-declared types. + + Handles both schema styles: + - Top-level ``properties`` (rare, e.g. test fixtures) + - Anthropic-style ``input_schema.properties`` (all built-in tools) + """ + props = ( + schema.get("properties") + or schema.get("input_schema", {}).get("properties", {}) + ) + if not props: + return dict(params) + return {k: _coerce_value_for(k, v, props) for k, v in params.items()} + + +def _coerce_value_for(key: str, value, props: dict): + """Coerce a single value according to its declared type, else return as-is.""" + prop_schema = props.get(key) + if not prop_schema or not isinstance(value, str): + return value + coercer = _COERCERS.get(prop_schema.get("type")) + if coercer is None: + return value + return coercer(value) + + +def _coerce_int(value): + try: + return int(value) + except ValueError: + return value + + +def _coerce_float(value): + try: + return float(value) + except ValueError: + return value + + +def _coerce_bool(value): + """Coerce string to bool. Returns original string if unrecognised.""" + low = value.lower() + if low in ("true", "1", "yes"): + return True + if low in ("false", "0", "no"): + return False + return value # tool handler reports the real type mismatch + + +def _coerce_json(value): + try: + return _json.loads(value) + except (ValueError, _json.JSONDecodeError): + return value + + +_COERCERS = { + "integer": _coerce_int, + "number": _coerce_float, + "boolean": _coerce_bool, + "array": _coerce_json, + "object": _coerce_json, +} diff --git a/id_uniquify.py b/id_uniquify.py new file mode 100644 index 0000000..4265fa9 --- /dev/null +++ b/id_uniquify.py @@ -0,0 +1,78 @@ +"""Prevent tool_call_id collisions across turns. + +LLMs pick short ids (e.g. ``r1``, ``w2``) and reuse them freely across turns. +When a previous id is reused, it can cause API errors or wrong tool results +being associated with wrong calls. + +The fix: on ingest, rewrite any id that already exists in the message history +to ``t{turn}_{original}`` (with a numeric suffix if that still collides). +Same-turn ``depends_on`` references are rewritten in lockstep so the DAG +resolves correctly. +""" +from __future__ import annotations + + +def _collect_used_ids(state) -> set[str]: + """Gather all tool_call ids already present in state.messages.""" + used: set[str] = set() + for msg in state.messages: + role = msg.get("role") + if role == "assistant": + for tc in msg.get("tool_calls") or []: + tid = tc.get("id") + if tid: + used.add(tid) + elif role == "tool": + tid = msg.get("tool_call_id") + if tid: + used.add(tid) + return used + + +def _pick_fresh_id(original: str, turn: int, used: set[str]) -> str: + candidate = f"t{turn}_{original}" + if candidate not in used: + return candidate + suffix = 2 + while f"{candidate}_{suffix}" in used: + suffix += 1 + return f"{candidate}_{suffix}" + + +def uniquify_tool_call_ids(tool_calls: list, state) -> dict[str, str]: + """Rewrite colliding tool_call ids in-place and rewrite depends_on refs. + + Only ids that already exist in state.messages are remapped. New ids + pass through unchanged, preserving behaviour for simple sessions. + + Returns: + Mapping of ``{original_id: new_id}`` for any ids that were remapped. + """ + if not tool_calls: + return {} + + used = _collect_used_ids(state) + remap: dict[str, str] = {} + + for tc in tool_calls: + original = tc.get("id") + if not original or original not in used: + if original: + used.add(original) + continue + fresh = _pick_fresh_id(original, state.turn_count, used) + remap[original] = fresh + tc["id"] = fresh + used.add(fresh) + + if not remap: + return {} + + # Rewrite depends_on references inside same-turn tool calls + for tc in tool_calls: + params = tc.get("input") or {} + deps = params.get("depends_on") + if deps: + params["depends_on"] = [remap.get(d, d) for d in deps] + + return remap diff --git a/pyproject.toml b/pyproject.toml index 02c10ef..6b2cfdd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,8 @@ py-modules = [ "cc_config", "context", "error_classifier", + "id_uniquify", + "coercion", "health", "jobs", "logging_utils", diff --git a/tests/test_coercion.py b/tests/test_coercion.py new file mode 100644 index 0000000..3f54dcd --- /dev/null +++ b/tests/test_coercion.py @@ -0,0 +1,94 @@ +"""Tests for parameter type coercion (coercion.py).""" +from coercion import coerce_params, _coerce_bool + + +class TestCoerceParams: + def test_int_coercion(self): + schema = {"properties": {"limit": {"type": "integer"}}} + assert coerce_params({"limit": "42"}, schema) == {"limit": 42} + + def test_float_coercion(self): + schema = {"properties": {"rate": {"type": "number"}}} + assert coerce_params({"rate": "3.14"}, schema) == {"rate": 3.14} + + def test_bool_true_variants(self): + schema = {"properties": {"flag": {"type": "boolean"}}} + for val in ("true", "True", "TRUE", "1", "yes", "Yes"): + assert coerce_params({"flag": val}, schema) == {"flag": True}, f"Failed for {val!r}" + + def test_bool_false_variants(self): + schema = {"properties": {"flag": {"type": "boolean"}}} + for val in ("false", "False", "FALSE", "0", "no", "No"): + assert coerce_params({"flag": val}, schema) == {"flag": False}, f"Failed for {val!r}" + + def test_bool_garbage_returns_original(self): + """Unknown boolean-like strings must pass through for tool handler to report.""" + schema = {"properties": {"flag": {"type": "boolean"}}} + result = coerce_params({"flag": "banana"}, schema) + assert result == {"flag": "banana"} + + def test_array_coercion(self): + schema = {"properties": {"items": {"type": "array"}}} + result = coerce_params({"items": '["a","b"]'}, schema) + assert result == {"items": ["a", "b"]} + + def test_object_coercion(self): + schema = {"properties": {"meta": {"type": "object"}}} + result = coerce_params({"meta": '{"k": 1}'}, schema) + assert result == {"meta": {"k": 1}} + + def test_passthrough_string(self): + schema = {"properties": {"name": {"type": "string"}}} + assert coerce_params({"name": "hello"}, schema) == {"name": "hello"} + + def test_invalid_json_passthrough(self): + schema = {"properties": {"items": {"type": "array"}}} + assert coerce_params({"items": "not-json"}, schema) == {"items": "not-json"} + + def test_invalid_int_passthrough(self): + schema = {"properties": {"limit": {"type": "integer"}}} + assert coerce_params({"limit": "abc"}, schema) == {"limit": "abc"} + + def test_unknown_prop_passthrough(self): + schema = {"properties": {}} + assert coerce_params({"x": "y"}, schema) == {"x": "y"} + + def test_non_string_passthrough(self): + """Already-typed values must not be touched.""" + schema = {"properties": {"limit": {"type": "integer"}}} + assert coerce_params({"limit": 42}, schema) == {"limit": 42} + + def test_input_schema_style(self): + """Anthropic-style schemas with input_schema must be handled.""" + schema = { + "name": "receiver", + "input_schema": { + "type": "object", + "properties": { + "count": {"type": "integer"}, + "msg": {"type": "string"}, + }, + }, + } + result = coerce_params({"count": "5", "msg": "hi"}, schema) + assert result == {"count": 5, "msg": "hi"} + + def test_empty_schema(self): + assert coerce_params({"x": "y"}, {}) == {"x": "y"} + + +class TestCoerceBool: + def test_true_values(self): + assert _coerce_bool("true") is True + assert _coerce_bool("1") is True + assert _coerce_bool("yes") is True + + def test_false_values(self): + assert _coerce_bool("false") is False + assert _coerce_bool("0") is False + assert _coerce_bool("no") is False + + def test_garbage_returns_original(self): + assert _coerce_bool("banana") == "banana" + assert _coerce_bool("maybe") == "maybe" + assert _coerce_bool("") == "" diff --git a/tests/test_tool_scheduling.py b/tests/test_tool_scheduling.py index 96e067c..6a5765d 100644 --- a/tests/test_tool_scheduling.py +++ b/tests/test_tool_scheduling.py @@ -1,12 +1,12 @@ -"""Tests for tool scheduling (depends_on, tool_call_alias) and param coercion.""" +"""Tests for tool scheduling (depends_on, tool_call_alias) and ID uniquification.""" from tool_registry import ( get_tool_schemas, execute_tool, register_tool, ToolDef, - _coerce_params, _SCHEDULING_PROPS, ) +from id_uniquify import uniquify_tool_call_ids, _collect_used_ids # Trigger builtin tool registration import tools # noqa: F401 @@ -17,62 +17,37 @@ def test_schemas_contain_scheduling_fields(self): schemas = get_tool_schemas() assert len(schemas) > 0 for s in schemas: - props = s.get("properties", {}) + # Handle both schema styles + props = s.get("properties") or s.get("input_schema", {}).get("properties", {}) assert "tool_call_alias" in props, f"Missing tool_call_alias in {s.get('name')}" assert "depends_on" in props, f"Missing depends_on in {s.get('name')}" def test_scheduling_props_have_correct_types(self): schemas = get_tool_schemas() s = schemas[0] - assert s["properties"]["tool_call_alias"]["type"] == "string" - assert s["properties"]["depends_on"]["type"] == "array" + props = s.get("properties") or s.get("input_schema", {}).get("properties", {}) + assert props["tool_call_alias"]["type"] == "string" + assert props["depends_on"]["type"] == "array" def test_original_schema_not_mutated(self): """Verify deepcopy prevents mutation of registered schemas.""" schemas1 = get_tool_schemas() - schemas1[0]["properties"]["tool_call_alias"]["EXTRA"] = True + s1 = schemas1[0] + props1 = s1.get("properties") or s1.get("input_schema", {}).get("properties", {}) + props1["tool_call_alias"]["EXTRA"] = True schemas2 = get_tool_schemas() - assert "EXTRA" not in schemas2[0]["properties"]["tool_call_alias"] + s2 = schemas2[0] + props2 = s2.get("properties") or s2.get("input_schema", {}).get("properties", {}) + assert "EXTRA" not in props2["tool_call_alias"] - -class TestCoerceParams: - def test_int_coercion(self): - schema = {"properties": {"limit": {"type": "integer"}}} - assert _coerce_params({"limit": "42"}, schema) == {"limit": 42} - - def test_float_coercion(self): - schema = {"properties": {"rate": {"type": "number"}}} - assert _coerce_params({"rate": "3.14"}, schema) == {"rate": 3.14} - - def test_bool_true(self): - schema = {"properties": {"flag": {"type": "boolean"}}} - assert _coerce_params({"flag": "true"}, schema) == {"flag": True} - - def test_bool_false(self): - schema = {"properties": {"flag": {"type": "boolean"}}} - assert _coerce_params({"flag": "false"}, schema) == {"flag": False} - - def test_array_coercion(self): - schema = {"properties": {"items": {"type": "array"}}} - result = _coerce_params({"items": '["a","b"]'}, schema) - assert result == {"items": ["a", "b"]} - - def test_object_coercion(self): - schema = {"properties": {"meta": {"type": "object"}}} - result = _coerce_params({"meta": '{"k": 1}'}, schema) - assert result == {"meta": {"k": 1}} - - def test_passthrough_string(self): - schema = {"properties": {"name": {"type": "string"}}} - assert _coerce_params({"name": "hello"}, schema) == {"name": "hello"} - - def test_invalid_json_passthrough(self): - schema = {"properties": {"items": {"type": "array"}}} - assert _coerce_params({"items": "not-json"}, schema) == {"items": "not-json"} - - def test_unknown_prop_passthrough(self): - schema = {"properties": {}} - assert _coerce_params({"x": "y"}, schema) == {"x": "y"} + def test_input_schema_style_gets_scheduling_in_right_place(self): + """Built-in tools use input_schema; scheduling props must land there.""" + schemas = get_tool_schemas() + read_schema = next(s for s in schemas if s["name"] == "Read") + # Must be inside input_schema.properties, NOT top-level properties + assert "tool_call_alias" in read_schema["input_schema"]["properties"] + assert "depends_on" in read_schema["input_schema"]["properties"] + assert "properties" not in read_schema or "tool_call_alias" not in read_schema.get("properties", {}) class TestExecuteToolStripsScheduling: @@ -103,3 +78,95 @@ def test_scheduling_params_stripped(self): assert "tool_call_alias" not in self._received assert "depends_on" not in self._received assert self._received.get("msg") == "hi" + + +class TestIdUniquify: + """Tests for tool_call_id collision prevention.""" + + def _make_state(self, messages=None): + """Create a minimal state object with messages and turn_count.""" + class _State: + pass + s = _State() + s.messages = messages or [] + s.turn_count = 2 + return s + + def test_fresh_ids_pass_through(self): + state = self._make_state() + tcs = [{"id": "r1", "name": "Read", "input": {}}] + remap = uniquify_tool_call_ids(tcs, state) + assert remap == {} + assert tcs[0]["id"] == "r1" + + def test_colliding_id_remapped(self): + state = self._make_state([ + {"role": "assistant", "content": "", "tool_calls": [{"id": "r1", "name": "Read", "input": {}}]}, + {"role": "tool", "tool_call_id": "r1", "name": "Read", "content": "data"}, + ]) + tcs = [{"id": "r1", "name": "Read", "input": {}}] + remap = uniquify_tool_call_ids(tcs, state) + assert remap == {"r1": "t2_r1"} + assert tcs[0]["id"] == "t2_r1" + + def test_depends_on_rewritten(self): + state = self._make_state([ + {"role": "assistant", "content": "", "tool_calls": [ + {"id": "w1", "name": "Write", "input": {}}, + {"id": "r1", "name": "Read", "input": {}}, + ]}, + {"role": "tool", "tool_call_id": "w1", "name": "Write", "content": "ok"}, + {"role": "tool", "tool_call_id": "r1", "name": "Read", "content": "data"}, + ]) + tcs = [ + {"id": "w1", "name": "Write", "input": {}}, + {"id": "r1", "name": "Read", "input": {"depends_on": ["w1"]}}, + ] + remap = uniquify_tool_call_ids(tcs, state) + assert tcs[0]["id"] == "t2_w1" + assert tcs[1]["id"] == "t2_r1" + # depends_on must be rewritten to match the new w1 id + assert tcs[1]["input"]["depends_on"] == ["t2_w1"] + + def test_multiple_collisions_get_numeric_suffix(self): + state = self._make_state([ + {"role": "assistant", "content": "", "tool_calls": [{"id": "r1", "name": "Read", "input": {}}]}, + {"role": "tool", "tool_call_id": "r1", "name": "Read", "content": ""}, + {"role": "assistant", "content": "", "tool_calls": [{"id": "t2_r1", "name": "Read", "input": {}}]}, + {"role": "tool", "tool_call_id": "t2_r1", "name": "Read", "content": ""}, + ]) + tcs = [{"id": "r1", "name": "Read", "input": {}}] + remap = uniquify_tool_call_ids(tcs, state) + # t2_r1 is taken, so should get t2_r1_2 + assert tcs[0]["id"] == "t2_r1_2" + + def test_collect_used_ids(self): + state = self._make_state([ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "", "tool_calls": [ + {"id": "a1", "name": "Read", "input": {}}, + {"id": "a2", "name": "Write", "input": {}}, + ]}, + {"role": "tool", "tool_call_id": "a1", "name": "Read", "content": "x"}, + {"role": "tool", "tool_call_id": "a2", "name": "Write", "content": "y"}, + ]) + used = _collect_used_ids(state) + assert used == {"a1", "a2"} + + def test_empty_tool_calls(self): + state = self._make_state() + assert uniquify_tool_call_ids([], state) == {} + + def test_mixed_fresh_and_colliding(self): + state = self._make_state([ + {"role": "assistant", "content": "", "tool_calls": [{"id": "r1", "name": "Read", "input": {}}]}, + {"role": "tool", "tool_call_id": "r1", "name": "Read", "content": ""}, + ]) + tcs = [ + {"id": "r1", "name": "Read", "input": {}}, + {"id": "w1", "name": "Write", "input": {"depends_on": ["r1"]}}, + ] + remap = uniquify_tool_call_ids(tcs, state) + assert tcs[0]["id"] == "t2_r1" # remapped + assert tcs[1]["id"] == "w1" # fresh, untouched + assert tcs[1]["input"]["depends_on"] == ["t2_r1"] # ref rewritten diff --git a/tests/test_tool_scheduling_e2e.py b/tests/test_tool_scheduling_e2e.py index 2a68746..3b76174 100644 --- a/tests/test_tool_scheduling_e2e.py +++ b/tests/test_tool_scheduling_e2e.py @@ -100,3 +100,48 @@ def test_scheduling_params_stripped_before_reaching_tool(monkeypatch, receiver_t assert seen.get("msg") == "hello" assert "tool_call_alias" not in seen assert "depends_on" not in seen + + +def test_id_reuse_across_turns_gets_remapped(monkeypatch, receiver_tool): + """When LLM reuses an id from a prior turn, uniquify rewrites it.""" + captured_schemas = [] + turns = [ + # Turn 1: tool call with id "r1" + {"tool_calls": [{ + "id": "r1", + "name": "receiver", + "input": {"msg": "turn1"}, + }]}, + # Turn 2: LLM reuses "r1" — uniquify must remap + {"tool_calls": [{ + "id": "r1", + "name": "receiver", + "input": {"msg": "turn2"}, + }]}, + {"text": "done"}, + ] + monkeypatch.setattr("agent.stream", _scripted_stream(captured_schemas, turns)) + + state = AgentState() + events = list(run("go", state, {"model": "test", "permission_mode": "accept-all", + "_session_id": "sch3", "disabled_tools": ["Agent"]}, + "sys")) + + # Collect tool_call_ids from assistant turns only + assistant_ids = [] + for msg in state.messages: + if msg["role"] == "assistant": + for tc in msg.get("tool_calls") or []: + assistant_ids.append(tc["id"]) + + # Both tool calls must exist and have UNIQUE ids + assert len(assistant_ids) == 2, f"Expected 2 tool calls, got {assistant_ids}" + assert assistant_ids[0] != assistant_ids[1], ( + f"IDs must be unique across turns but got duplicates: {assistant_ids}" + ) + + # Tool results must match their corresponding assistant tool_call ids + tool_results = [m for m in state.messages if m["role"] == "tool"] + assert len(tool_results) == 2 + assert tool_results[0]["tool_call_id"] == assistant_ids[0] + assert tool_results[1]["tool_call_id"] == assistant_ids[1] diff --git a/tool_registry.py b/tool_registry.py index 16c6c08..fc2d4a2 100644 --- a/tool_registry.py +++ b/tool_registry.py @@ -141,7 +141,6 @@ def clear_registry() -> None: # ── Tool scheduling support ──────────────────────────────────────────────── import copy as _copy -import json as _json _SCHEDULING_PROPS = { "tool_call_alias": { @@ -161,74 +160,26 @@ def clear_registry() -> None: } -def _coerce_params(params: dict, schema: dict) -> dict: - """Coerce string parameter values to their schema-declared types. - - Coercion failure is not a hard error: the original string is kept and - passed to the tool handler, which will surface a clear type error to - the model (e.g. `expected int, got 'abc'`) far more usefully than a - ValueError from the registry wrapper. - """ - props = schema.get("properties", {}) - return {k: _coerce_value_for(k, v, props) for k, v in params.items()} - - -def _coerce_value_for(key: str, value, props: dict): - """Coerce a single value according to its declared type, else return as-is.""" - prop_schema = props.get(key) - if not prop_schema or not isinstance(value, str): - return value - coercer = _COERCERS.get(prop_schema.get("type")) - if coercer is None: - return value - return coercer(value) - - -def _coerce_int(value): - try: - return int(value) - except ValueError: - return value # intentional: tool handler reports the real type mismatch - - -def _coerce_float(value): - try: - return float(value) - except ValueError: - return value - - -def _coerce_bool(value): - return value.lower() in ("true", "1", "yes") - - -def _coerce_json(value): - try: - return _json.loads(value) - except (ValueError, _json.JSONDecodeError): - return value - - -_COERCERS = { - "integer": _coerce_int, - "number": _coerce_float, - "boolean": _coerce_bool, - "array": _coerce_json, - "object": _coerce_json, -} - - # Wrap get_tool_schemas to inject scheduling properties _orig_get_tool_schemas = get_tool_schemas def get_tool_schemas(): - """Return tool schemas with scheduling properties injected.""" + """Return tool schemas with scheduling properties injected. + + Handles both schema styles: + - Top-level ``properties`` (rare, e.g. test fixtures) + - Anthropic-style ``input_schema.properties`` (all built-in tools) + """ schemas = _orig_get_tool_schemas() result = [] for s in schemas: s = _copy.deepcopy(s) - props = s.setdefault("properties", {}) + # Detect where properties live + if "input_schema" in s: + props = s["input_schema"].setdefault("properties", {}) + else: + props = s.setdefault("properties", {}) for k, v in _SCHEDULING_PROPS.items(): props.setdefault(k, _copy.deepcopy(v)) result.append(s) @@ -240,9 +191,6 @@ def get_tool_schemas(): def execute_tool(name, params, *args, **kwargs): - """Execute a tool after stripping scheduling params and coercing types.""" + """Execute a tool after stripping scheduling params.""" clean = {k: v for k, v in params.items() if k not in _SCHEDULING_PROPS} - tool = get_tool(name) - if tool is not None: - clean = _coerce_params(clean, tool.schema) return _orig_execute_tool(name, clean, *args, **kwargs) From 757b9a4801ab4998d93f3cbe8af21bcde2598d49 Mon Sep 17 00:00:00 2001 From: Simon FREYBURGER Date: Tue, 21 Apr 2026 14:22:08 +0200 Subject: [PATCH 12/14] fix(tests): monkeypatch quota.record_usage in scheduling e2e tests --- tests/test_tool_scheduling.py | 18 ++++++++++++++++-- tests/test_tool_scheduling_e2e.py | 6 ++++++ tool_registry.py | 1 + 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/tests/test_tool_scheduling.py b/tests/test_tool_scheduling.py index 6a5765d..8b322dd 100644 --- a/tests/test_tool_scheduling.py +++ b/tests/test_tool_scheduling.py @@ -1,17 +1,28 @@ """Tests for tool scheduling (depends_on, tool_call_alias) and ID uniquification.""" +import pytest + from tool_registry import ( get_tool_schemas, execute_tool, register_tool, + clear_registry, ToolDef, _SCHEDULING_PROPS, ) from id_uniquify import uniquify_tool_call_ids, _collect_used_ids -# Trigger builtin tool registration import tools # noqa: F401 +@pytest.fixture(autouse=True) +def _ensure_builtins(): + """Guarantee builtins are registered even after another test cleared the registry.""" + clear_registry() + tools._register_builtins() + yield + clear_registry() + + class TestSchedulingPropsInjection: def test_schemas_contain_scheduling_fields(self): schemas = get_tool_schemas() @@ -63,7 +74,10 @@ def _handler(params, config=None): schema={ "name": "test_sched_tool", "description": "test tool", - "properties": {"msg": {"type": "string"}}, + "input_schema": { + "type": "object", + "properties": {"msg": {"type": "string"}}, + }, }, func=_handler, read_only=True, diff --git a/tests/test_tool_scheduling_e2e.py b/tests/test_tool_scheduling_e2e.py index 3b76174..06d1169 100644 --- a/tests/test_tool_scheduling_e2e.py +++ b/tests/test_tool_scheduling_e2e.py @@ -15,6 +15,12 @@ from tool_registry import ToolDef, register_tool +@pytest.fixture(autouse=True) +def _no_quota(monkeypatch): + """Disable quota tracking — these tests exercise scheduling, not billing.""" + monkeypatch.setattr("quota.record_usage", lambda *a, **kw: None) + + def _scripted_stream(captured_schemas, turns): cursor = iter(turns) diff --git a/tool_registry.py b/tool_registry.py index fc2d4a2..9c44cca 100644 --- a/tool_registry.py +++ b/tool_registry.py @@ -193,4 +193,5 @@ def get_tool_schemas(): def execute_tool(name, params, *args, **kwargs): """Execute a tool after stripping scheduling params.""" clean = {k: v for k, v in params.items() if k not in _SCHEDULING_PROPS} + return _orig_execute_tool(name, clean, *args, **kwargs) From 805b0242c69cffe077b5d725b862557a659cc93f Mon Sep 17 00:00:00 2001 From: Simon FREYBURGER Date: Tue, 21 Apr 2026 14:36:40 +0200 Subject: [PATCH 13/14] refactor(tests): mutualise e2e helpers into tests/conftest.py --- tests/conftest.py | 28 ++++++++++++++ tests/test_tool_scheduling_e2e.py | 62 ++----------------------------- 2 files changed, 32 insertions(+), 58 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index f935fc6..b8617c0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,8 @@ import pytest +from tool_registry import ToolDef, register_tool, _registry + # --------------- quota stub (avoids ImportError on CI for calc_cost) -------- @@ -12,3 +14,29 @@ def _no_quota(monkeypatch): """Disable quota.record_usage so tests never hit the real billing path.""" import quota monkeypatch.setattr(quota, "record_usage", lambda *a, **kw: None) + + +# --------------- receiver tool fixture ------------------------------------- + +@pytest.fixture +def receiver_tool(): + """Register a tool that captures whatever params it receives.""" + received = {} + had_before = "receiver" in _registry + register_tool(ToolDef( + name="receiver", + schema={ + "name": "receiver", + "description": "records params for assertions", + "input_schema": { + "type": "object", + "properties": {"msg": {"type": "string"}}, + "required": ["msg"], + }, + }, + func=lambda params, _cfg: received.setdefault("seen", dict(params)) and "ok", + read_only=True, concurrent_safe=True, + )) + yield received + if not had_before: + _registry.pop("receiver", None) diff --git a/tests/test_tool_scheduling_e2e.py b/tests/test_tool_scheduling_e2e.py index 06d1169..7873b2e 100644 --- a/tests/test_tool_scheduling_e2e.py +++ b/tests/test_tool_scheduling_e2e.py @@ -7,58 +7,9 @@ """ from __future__ import annotations -import pytest - import tools as _tools_init # noqa: F401 - force built-in tool registration from agent import AgentState, run -from providers import AssistantTurn -from tool_registry import ToolDef, register_tool - - -@pytest.fixture(autouse=True) -def _no_quota(monkeypatch): - """Disable quota tracking — these tests exercise scheduling, not billing.""" - monkeypatch.setattr("quota.record_usage", lambda *a, **kw: None) - - -def _scripted_stream(captured_schemas, turns): - cursor = iter(turns) - - def fake_stream(**kwargs): - captured_schemas.append(kwargs.get("tool_schemas") or []) - spec = next(cursor) - yield AssistantTurn( - text=spec.get("text", ""), - tool_calls=spec.get("tool_calls") or [], - in_tokens=1, out_tokens=1, - ) - - return fake_stream - - -@pytest.fixture -def receiver_tool(): - """Register a tool that captures whatever params it receives.""" - received = {} - from tool_registry import _registry - had_before = "receiver" in _registry - register_tool(ToolDef( - name="receiver", - schema={ - "name": "receiver", - "description": "records params for assertions", - "input_schema": { - "type": "object", - "properties": {"msg": {"type": "string"}}, - "required": ["msg"], - }, - }, - func=lambda params, _cfg: received.setdefault("seen", dict(params)) and "ok", - read_only=True, concurrent_safe=True, - )) - yield received - if not had_before: - _registry.pop("receiver", None) +from helpers import scripted_stream def test_schemas_sent_to_llm_include_scheduling_props(monkeypatch, receiver_tool): @@ -66,7 +17,7 @@ def test_schemas_sent_to_llm_include_scheduling_props(monkeypatch, receiver_tool captured = [] monkeypatch.setattr( "agent.stream", - _scripted_stream(captured, [{"text": "nothing to do"}]), + scripted_stream(captured, [{"text": "nothing to do"}]), ) list(run("hi", AgentState(), {"model": "test", "permission_mode": "accept-all", @@ -95,7 +46,7 @@ def test_scheduling_params_stripped_before_reaching_tool(monkeypatch, receiver_t }]}, {"text": "done"}, ] - monkeypatch.setattr("agent.stream", _scripted_stream(captured_schemas, turns)) + monkeypatch.setattr("agent.stream", scripted_stream(captured_schemas, turns)) list(run("go", AgentState(), {"model": "test", "permission_mode": "accept-all", "_session_id": "sch2", "disabled_tools": ["Agent"]}, @@ -112,13 +63,11 @@ def test_id_reuse_across_turns_gets_remapped(monkeypatch, receiver_tool): """When LLM reuses an id from a prior turn, uniquify rewrites it.""" captured_schemas = [] turns = [ - # Turn 1: tool call with id "r1" {"tool_calls": [{ "id": "r1", "name": "receiver", "input": {"msg": "turn1"}, }]}, - # Turn 2: LLM reuses "r1" — uniquify must remap {"tool_calls": [{ "id": "r1", "name": "receiver", @@ -126,27 +75,24 @@ def test_id_reuse_across_turns_gets_remapped(monkeypatch, receiver_tool): }]}, {"text": "done"}, ] - monkeypatch.setattr("agent.stream", _scripted_stream(captured_schemas, turns)) + monkeypatch.setattr("agent.stream", scripted_stream(captured_schemas, turns)) state = AgentState() events = list(run("go", state, {"model": "test", "permission_mode": "accept-all", "_session_id": "sch3", "disabled_tools": ["Agent"]}, "sys")) - # Collect tool_call_ids from assistant turns only assistant_ids = [] for msg in state.messages: if msg["role"] == "assistant": for tc in msg.get("tool_calls") or []: assistant_ids.append(tc["id"]) - # Both tool calls must exist and have UNIQUE ids assert len(assistant_ids) == 2, f"Expected 2 tool calls, got {assistant_ids}" assert assistant_ids[0] != assistant_ids[1], ( f"IDs must be unique across turns but got duplicates: {assistant_ids}" ) - # Tool results must match their corresponding assistant tool_call ids tool_results = [m for m in state.messages if m["role"] == "tool"] assert len(tool_results) == 2 assert tool_results[0]["tool_call_id"] == assistant_ids[0] From 362a04b963180807553eb4d21ad0425c75f33be4 Mon Sep 17 00:00:00 2001 From: Simon FREYBURGER Date: Tue, 21 Apr 2026 15:07:46 +0200 Subject: [PATCH 14/14] fix: use scripted_stream as pytest fixture, remove direct import --- tests/helpers.py | 27 --------------------------- tests/test_tool_scheduling_e2e.py | 3 +-- 2 files changed, 1 insertion(+), 29 deletions(-) delete mode 100644 tests/helpers.py diff --git a/tests/helpers.py b/tests/helpers.py deleted file mode 100644 index 7d37a79..0000000 --- a/tests/helpers.py +++ /dev/null @@ -1,27 +0,0 @@ -"""Reusable test helpers (importable from any test module).""" - -from __future__ import annotations - -from agent import AssistantTurn - - -def scripted_stream(captured_schemas: list, turns: list[dict]): - """Return a fake ``stream()`` callable that yields pre-defined turns. - - *captured_schemas* receives the ``tool_schemas`` kwarg from each call, - letting tests assert on schema injection. *turns* is a list of dicts, - each with optional ``text`` and ``tool_calls`` keys. - """ - cursor = iter(turns) - - def fake_stream(**kwargs): - captured_schemas.append(kwargs.get("tool_schemas") or []) - spec = next(cursor) - yield AssistantTurn( - text=spec.get("text", ""), - tool_calls=spec.get("tool_calls") or [], - in_tokens=1, - out_tokens=1, - ) - - return fake_stream diff --git a/tests/test_tool_scheduling_e2e.py b/tests/test_tool_scheduling_e2e.py index 7873b2e..cb1230a 100644 --- a/tests/test_tool_scheduling_e2e.py +++ b/tests/test_tool_scheduling_e2e.py @@ -9,10 +9,9 @@ import tools as _tools_init # noqa: F401 - force built-in tool registration from agent import AgentState, run -from helpers import scripted_stream -def test_schemas_sent_to_llm_include_scheduling_props(monkeypatch, receiver_tool): +def test_schemas_sent_to_llm_include_scheduling_props(monkeypatch, receiver_tool, scripted_stream): """Every schema the LLM sees must carry tool_call_alias + depends_on.""" captured = [] monkeypatch.setattr(