diff --git a/agent.py b/agent.py
index bf37c75..592fc26 100644
--- a/agent.py
+++ b/agent.py
@@ -185,13 +185,6 @@ def run(
         if assistant_turn is None:
             break
 
-        # Record assistant turn in neutral format
-        state.messages.append({
-            "role": "assistant",
-            "content": assistant_turn.text,
-            "tool_calls": assistant_turn.tool_calls,
-        })
-
         state.total_input_tokens += assistant_turn.in_tokens
         state.total_output_tokens += assistant_turn.out_tokens
         state.total_cache_read_tokens += getattr(assistant_turn, 'cache_read_tokens', 0)
@@ -199,11 +192,29 @@ def run(
         yield TurnDone(assistant_turn.in_tokens, assistant_turn.out_tokens)
 
         if not assistant_turn.tool_calls:
+            # Record assistant turn and stop
+            state.messages.append({
+                "role": "assistant",
+                "content": assistant_turn.text,
+                "tool_calls": assistant_turn.tool_calls,
+            })
             break  # No tools → conversation turn complete
 
         # ── Execute tools (parallel when safe) ────────────────────────────
         tool_calls = assistant_turn.tool_calls
 
+        # Rewrite colliding tool_call ids BEFORE appending to state
+        # (so _collect_used_ids doesn't see this turn's own ids as taken)
+        from id_uniquify import uniquify_tool_call_ids
+        uniquify_tool_call_ids(tool_calls, state)
+
+        # Record assistant turn in neutral format (after uniquify so ids match)
+        state.messages.append({
+            "role": "assistant",
+            "content": assistant_turn.text,
+            "tool_calls": assistant_turn.tool_calls,
+        })
+
         # Check permissions first (must be sequential — may prompt user)
         permissions: dict[str, bool] = {}
         for tc in tool_calls:
diff --git a/checkpoint/store.py b/checkpoint/store.py
index ec770fc..93f3e24 100644
--- a/checkpoint/store.py
+++ b/checkpoint/store.py
@@ -12,6 +12,7 @@
 import json
 import os
 import shutil
+import sys
 import time
 from datetime import datetime, timedelta
 from pathlib import Path
@@ -97,7 +98,7 @@ def track_file_edit(session_id: str, file_path: str) -> str | None:
     except OSError:
         return None
     if size > _MAX_FILE_SIZE:
-        print(f"[checkpoint] skipping large file ({size} bytes): {file_path}")
+        print(f"[checkpoint] skipping large file ({size} bytes): {file_path}", file=sys.stderr)
         return None
 
     # Copy file to backups/
@@ -107,7 +108,7 @@ def track_file_edit(session_id: str, file_path: str) -> str | None:
     try:
         shutil.copy2(str(p), str(backup_path))
     except Exception as e:
-        print(f"[checkpoint] backup failed for {file_path}: {e}")
+        print(f"[checkpoint] backup failed for {file_path}: {e}", file=sys.stderr)
         return None
 
     return backup_name
diff --git a/coercion.py b/coercion.py
new file mode 100644
index 0000000..3d5f7a7
--- /dev/null
+++ b/coercion.py
@@ -0,0 +1,79 @@
+"""Parameter type coercion for LLM tool calls.
+
+LLMs sometimes send typed values as strings (e.g. ``"42"`` instead of ``42``
+for an integer property). This module coerces string parameters to their
+schema-declared types so tool handlers receive the expected Python types.
+
+Coercion failure is intentionally *not* a hard error: the original string
+is kept so the tool handler can surface a clear type mismatch to the model.
+"""
+from __future__ import annotations
+
+import json as _json
+
+
+def coerce_params(params: dict, schema: dict) -> dict:
+    """Coerce string parameter values to their schema-declared types.
+
+    Handles both schema styles:
+    - Top-level ``properties`` (rare, e.g. test fixtures)
+    - Anthropic-style ``input_schema.properties`` (all built-in tools)
+    """
+    props = (
+        schema.get("properties")
+        or schema.get("input_schema", {}).get("properties", {})
+    )
+    if not props:
+        return dict(params)
+    return {k: _coerce_value_for(k, v, props) for k, v in params.items()}
+
+
+def _coerce_value_for(key: str, value, props: dict):
+    """Coerce a single value according to its declared type, else return as-is."""
+    prop_schema = props.get(key)
+    if not prop_schema or not isinstance(value, str):
+        return value
+    coercer = _COERCERS.get(prop_schema.get("type"))
+    if coercer is None:
+        return value
+    return coercer(value)
+
+
+def _coerce_int(value):
+    try:
+        return int(value)
+    except ValueError:
+        return value
+
+
+def _coerce_float(value):
+    try:
+        return float(value)
+    except ValueError:
+        return value
+
+
+def _coerce_bool(value):
+    """Coerce string to bool. Returns original string if unrecognised."""
+    low = value.lower()
+    if low in ("true", "1", "yes"):
+        return True
+    if low in ("false", "0", "no"):
+        return False
+    return value  # tool handler reports the real type mismatch
+
+
+def _coerce_json(value):
+    try:
+        return _json.loads(value)
+    except (ValueError, _json.JSONDecodeError):
+        return value
+
+
+_COERCERS = {
+    "integer": _coerce_int,
+    "number": _coerce_float,
+    "boolean": _coerce_bool,
+    "array": _coerce_json,
+    "object": _coerce_json,
+}
diff --git a/id_uniquify.py b/id_uniquify.py
new file mode 100644
index 0000000..4265fa9
--- /dev/null
+++ b/id_uniquify.py
@@ -0,0 +1,78 @@
+"""Prevent tool_call_id collisions across turns.
+
+LLMs pick short ids (e.g. ``r1``, ``w2``) and reuse them freely across turns.
+When a previous id is reused, it can cause API errors or wrong tool results
+being associated with wrong calls.
+
+The fix: on ingest, rewrite any id that already exists in the message history
+to ``t{turn}_{original}`` (with a numeric suffix if that still collides).
+
+Same-turn ``depends_on`` references are rewritten in lockstep so the DAG
+resolves correctly.
+"""
+from __future__ import annotations
+
+
+def _collect_used_ids(state) -> set[str]:
+    """Gather all tool_call ids already present in state.messages."""
+    used: set[str] = set()
+    for msg in state.messages:
+        role = msg.get("role")
+        if role == "assistant":
+            for tc in msg.get("tool_calls") or []:
+                tid = tc.get("id")
+                if tid:
+                    used.add(tid)
+        elif role == "tool":
+            tid = msg.get("tool_call_id")
+            if tid:
+                used.add(tid)
+    return used
+
+
+def _pick_fresh_id(original: str, turn: int, used: set[str]) -> str:
+    candidate = f"t{turn}_{original}"
+    if candidate not in used:
+        return candidate
+    suffix = 2
+    while f"{candidate}_{suffix}" in used:
+        suffix += 1
+    return f"{candidate}_{suffix}"
+
+
+def uniquify_tool_call_ids(tool_calls: list, state) -> dict[str, str]:
+    """Rewrite colliding tool_call ids in-place and rewrite depends_on refs.
+
+    Only ids that already exist in state.messages are remapped. New ids
+    pass through unchanged, preserving behaviour for simple sessions.
+
+    Returns:
+        Mapping of ``{original_id: new_id}`` for any ids that were remapped.
+    """
+    if not tool_calls:
+        return {}
+
+    used = _collect_used_ids(state)
+    remap: dict[str, str] = {}
+
+    for tc in tool_calls:
+        original = tc.get("id")
+        if not original or original not in used:
+            if original:
+                used.add(original)
+            continue
+        fresh = _pick_fresh_id(original, state.turn_count, used)
+        remap[original] = fresh
+        tc["id"] = fresh
+        used.add(fresh)
+
+    if not remap:
+        return {}
+
+    # Rewrite depends_on references inside same-turn tool calls
+    for tc in tool_calls:
+        params = tc.get("input") or {}
+        deps = params.get("depends_on")
+        if deps:
+            params["depends_on"] = [remap.get(d, d) for d in deps]
+
+    return remap
diff --git a/pyproject.toml b/pyproject.toml
index 02c10ef..6b2cfdd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -44,6 +44,8 @@ py-modules = [
     "cc_config",
     "context",
     "error_classifier",
+    "id_uniquify",
+    "coercion",
     "health",
     "jobs",
     "logging_utils",
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..b8617c0
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,70 @@
+"""Shared pytest fixtures for all tests."""
+
+from __future__ import annotations
+
+import pytest
+
+from tool_registry import ToolDef, register_tool, _registry
+
+
+# --------------- quota stub (avoids ImportError on CI for calc_cost) --------
+
+@pytest.fixture(autouse=True)
+def _no_quota(monkeypatch):
+    """Disable quota.record_usage so tests never hit the real billing path."""
+    import quota
+    monkeypatch.setattr(quota, "record_usage", lambda *a, **kw: None)
+
+
+# --------------- receiver tool fixture -------------------------------------
+
+@pytest.fixture
+def receiver_tool():
+    """Register a tool that captures whatever params it receives."""
+    received = {}
+    had_before = "receiver" in _registry
+    register_tool(ToolDef(
+        name="receiver",
+        schema={
+            "name": "receiver",
+            "description": "records params for assertions",
+            "input_schema": {
+                "type": "object",
+                "properties": {"msg": {"type": "string"}},
+                "required": ["msg"],
+            },
+        },
+        func=lambda params, _cfg: received.setdefault("seen", dict(params)) and "ok",
+        read_only=True, concurrent_safe=True,
+    ))
+    yield received
+    if not had_before:
+        _registry.pop("receiver", None)
+
+
+# --------------- scripted LLM stream factory --------------------------------
+
+@pytest.fixture
+def scripted_stream():
+    """Factory: ``scripted_stream(captured_schemas, turns)`` -> fake agent.stream.
+
+    Each call to the fake stream appends the tool schemas it was given to
+    *captured_schemas* and yields the next scripted AssistantTurn.
+    NOTE(review): assumes agent.stream receives schemas via a ``tools`` kwarg.
+    """
+    from providers import AssistantTurn
+
+    def _make(captured, turns):
+        cursor = iter(turns)
+
+        def fake_stream(**kwargs):
+            captured.append(kwargs.get("tools") or [])
+            spec = next(cursor)
+            yield AssistantTurn(
+                text=spec.get("text", ""),
+                tool_calls=spec.get("tool_calls") or [],
+                in_tokens=1, out_tokens=1,
+            )
+
+        return fake_stream
+
+    return _make
diff --git a/tests/test_checkpoint_e2e.py b/tests/test_checkpoint_e2e.py
new file mode 100644
index 0000000..09ff441
--- /dev/null
+++ b/tests/test_checkpoint_e2e.py
@@ -0,0 +1,112 @@
+"""End-to-end: drive a real agent.run() conversation where the LLM calls Write,
+and verify the checkpoint hook intercepts the call and files a backup to disk.
+
+Only the LLM provider is mocked (via monkeypatching agent.stream). The Write
+tool, checkpoint hooks and checkpoint store all run for real against tmp_path.
+"""
+from __future__ import annotations
+
+import pytest
+
+import tools as _tools_init  # noqa: F401 - force built-in tool registration
+from agent import AgentState, run
+from providers import AssistantTurn
+from checkpoint import hooks as checkpoint_hooks
+from checkpoint import store as checkpoint_store
+
+
+def _scripted_stream(turns):
+    cursor = iter(turns)
+
+    def fake_stream(**_kwargs):
+        spec = next(cursor)
+        yield AssistantTurn(
+            text=spec.get("text", ""),
+            tool_calls=spec.get("tool_calls") or [],
+            in_tokens=1, out_tokens=1,
+        )
+
+    return fake_stream
+
+
+@pytest.fixture
+def sandboxed_checkpoints(tmp_path, monkeypatch):
+    """Run checkpoint store against tmp_path and install hooks on built-in tools."""
+    monkeypatch.setattr(
+        checkpoint_store, "_checkpoints_root", lambda: tmp_path / ".checkpoints"
+    )
+    checkpoint_store.reset_file_versions()
+    checkpoint_hooks.set_session("e2e-session")
+    checkpoint_hooks.reset_tracked()
+    checkpoint_hooks.install_hooks()
+    yield tmp_path
+    checkpoint_hooks.reset_tracked()
+
+
+def test_llm_write_triggers_checkpoint_backup(monkeypatch, sandboxed_checkpoints):
+    """When the LLM calls Write, the checkpoint hook must back the pre-edit file up.
+
+    Pre-populate a small file, then let the LLM overwrite it via the Write
+    tool. The hook should copy the old content into checkpoints/.../backups/
+    before the Write executes, so the backup holds the original bytes.
+    """
+    target = sandboxed_checkpoints / "hello.py"
+    target.write_text("print('before')\n", encoding="utf-8")
+
+    turns = [
+        {"tool_calls": [{
+            "id": "w1",
+            "name": "Write",
+            "input": {"file_path": str(target), "content": "print('after')\n"},
+        }]},
+        {"text": "done"},
+    ]
+    monkeypatch.setattr("agent.stream", _scripted_stream(turns))
+
+    state = AgentState()
+    config = {"model": "test", "permission_mode": "accept-all",
+              "_session_id": "e2e-session"}
+    list(run("overwrite the file", state, config, "system prompt"))
+
+    # After the turn: Write applied the new content
+    assert target.read_text(encoding="utf-8") == "print('after')\n"
+
+    # And the checkpoint hook filed a backup with the pre-edit content
+    backups_dir = sandboxed_checkpoints / ".checkpoints" / "e2e-session" / "backups"
+    backups = list(backups_dir.iterdir())
+    assert backups, "checkpoint hook did not create a backup file"
+    assert any(b.read_text(encoding="utf-8") == "print('before')\n" for b in backups)
+
+
+def test_oversized_write_logs_to_stderr_not_stdout(
+    monkeypatch, sandboxed_checkpoints, capfd
+):
+    """Over the _MAX_FILE_SIZE threshold the hook skips + logs — to stderr only.
+
+    This is the actual user-visible contract of PR #47: checkpoint skips must
+    not pollute stdout (which carries the conversation transcript), they must
+    land on stderr where operators look.
+    """
+    monkeypatch.setattr(checkpoint_store, "_MAX_FILE_SIZE", 20)
+    big = sandboxed_checkpoints / "big.py"
+    big.write_text("x" * 100, encoding="utf-8")
+
+    turns = [
+        {"tool_calls": [{
+            "id": "w1",
+            "name": "Write",
+            "input": {"file_path": str(big), "content": "y" * 100},
+        }]},
+        {"text": "ok"},
+    ]
+    monkeypatch.setattr("agent.stream", _scripted_stream(turns))
+
+    state = AgentState()
+    list(run("rewrite", state, {"model": "test", "permission_mode": "accept-all",
+                                "_session_id": "e2e-session",
+                                "disabled_tools": ["Agent"]},
+             "sys"))
+
+    out, errtxt = capfd.readouterr()
+    assert "[checkpoint] skipping large file" in errtxt
+    assert "[checkpoint] skipping large file" not in out
diff --git a/tests/test_checkpoint_store.py b/tests/test_checkpoint_store.py
new file mode 100644
index 0000000..13160a2
--- /dev/null
+++ b/tests/test_checkpoint_store.py
@@ -0,0 +1,61 @@
+"""Integration tests for checkpoint store: stderr capture + large file skip."""
+from __future__ import annotations
+
+import pytest
+
+import checkpoint.store as store
+
+
+@pytest.fixture(autouse=True)
+def isolate_store(tmp_path, monkeypatch):
+    """Redirect checkpoint root to tmp_path and reset global state."""
+    monkeypatch.setattr(store, "_checkpoints_root", lambda: tmp_path / "checkpoints")
+    store.reset_file_versions()
+
+
+def test_large_file_skipped_and_logged_to_stderr(tmp_path, monkeypatch, capsys):
+    monkeypatch.setattr(store, "_MAX_FILE_SIZE", 50)
+    big_file = tmp_path / "big.txt"
+    big_file.write_bytes(b"x" * 100)
+
+    result = store.track_file_edit("test-session", str(big_file))
+
+    assert result is None
+    captured = capsys.readouterr()
+    assert "[checkpoint] skipping large file" in captured.err
+    assert "100 bytes" in captured.err
+    assert captured.out == ""
+
+
+def test_normal_file_backed_up(tmp_path, capsys):
+    small_file = tmp_path / "small.txt"
+    content = b"hello world"
+    small_file.write_bytes(content)
+
+    result = store.track_file_edit("test-session", str(small_file))
+
+    assert result is not None
+    backup_dir = tmp_path / "checkpoints" / "test-session" / "backups"
+    backup_path = backup_dir / result
+    assert backup_path.exists()
+    assert backup_path.read_bytes() == content
+    captured = capsys.readouterr()
+    assert captured.err == ""
+
+
+def test_backup_failure_logged_to_stderr(tmp_path, monkeypatch, capsys):
+    normal_file = tmp_path / "normal.txt"
+    normal_file.write_bytes(b"some data")
+
+    def failing_copy(*args, **kwargs):
+        raise PermissionError("access denied")
+
+    monkeypatch.setattr(store.shutil, "copy2", failing_copy)
+
+    result = store.track_file_edit("test-session", str(normal_file))
+
+    assert result is None
+    captured = capsys.readouterr()
+    assert "[checkpoint] backup failed" in captured.err
+    assert "access denied" in captured.err
+    assert captured.out == ""
diff --git a/tests/test_coercion.py b/tests/test_coercion.py
new file mode 100644
index 0000000..3f54dcd
--- /dev/null
+++ b/tests/test_coercion.py
@@ -0,0 +1,94 @@
+"""Tests for parameter type coercion (coercion.py)."""
+from coercion import coerce_params, _coerce_bool
+
+
+class TestCoerceParams:
+    def test_int_coercion(self):
+        schema = {"properties": {"limit": {"type": "integer"}}}
+        assert coerce_params({"limit": "42"}, schema) == {"limit": 42}
+
+    def test_float_coercion(self):
+        schema = {"properties": {"rate": {"type": "number"}}}
+        assert coerce_params({"rate": "3.14"}, schema) == {"rate": 3.14}
+
+    def test_bool_true_variants(self):
+        schema = {"properties": {"flag": {"type": "boolean"}}}
+        for val in ("true", "True", "TRUE", "1", "yes", "Yes"):
+            assert coerce_params({"flag": val}, schema) == {"flag": True}, f"Failed for {val!r}"
+
+    def test_bool_false_variants(self):
+        schema = {"properties": {"flag": {"type": "boolean"}}}
+        for val in ("false", "False", "FALSE", "0", "no", "No"):
+            assert coerce_params({"flag": val}, schema) == {"flag": False}, f"Failed for {val!r}"
+
+    def test_bool_garbage_returns_original(self):
+        """Unknown boolean-like strings must pass through for tool handler to report."""
+        schema = {"properties": {"flag": {"type": "boolean"}}}
+        result = coerce_params({"flag": "banana"}, schema)
+        assert result == {"flag": "banana"}
+
+    def test_array_coercion(self):
+        schema = {"properties": {"items": {"type": "array"}}}
+        result = coerce_params({"items": '["a","b"]'}, schema)
+        assert result == {"items": ["a", "b"]}
+
+    def test_object_coercion(self):
+        schema = {"properties": {"meta": {"type": "object"}}}
+        result = coerce_params({"meta": '{"k": 1}'}, schema)
+        assert result == {"meta": {"k": 1}}
+
+    def test_passthrough_string(self):
+        schema = {"properties": {"name": {"type": "string"}}}
+        assert coerce_params({"name": "hello"}, schema) == {"name": "hello"}
+
+    def test_invalid_json_passthrough(self):
+        schema = {"properties": {"items": {"type": "array"}}}
+        assert coerce_params({"items": "not-json"}, schema) == {"items": "not-json"}
+
+    def test_invalid_int_passthrough(self):
+        schema = {"properties": {"limit": {"type": "integer"}}}
+        assert coerce_params({"limit": "abc"}, schema) == {"limit": "abc"}
+
+    def test_unknown_prop_passthrough(self):
+        schema = {"properties": {}}
+        assert coerce_params({"x": "y"}, schema) == {"x": "y"}
+
+    def test_non_string_passthrough(self):
+        """Already-typed values must not be touched."""
+        schema = {"properties": {"limit": {"type": "integer"}}}
+        assert coerce_params({"limit": 42}, schema) == {"limit": 42}
+
+    def test_input_schema_style(self):
+        """Anthropic-style schemas with input_schema must be handled."""
+        schema = {
+            "name": "receiver",
+            "input_schema": {
+                "type": "object",
+                "properties": {
+                    "count": {"type": "integer"},
+                    "msg": {"type": "string"},
+                },
+            },
+        }
+        result = coerce_params({"count": "5", "msg": "hi"}, schema)
+        assert result == {"count": 5, "msg": "hi"}
+
+    def test_empty_schema(self):
+        assert coerce_params({"x": "y"}, {}) == {"x": "y"}
+
+
+class TestCoerceBool:
+    def test_true_values(self):
+        assert _coerce_bool("true") is True
+        assert _coerce_bool("1") is True
+        assert _coerce_bool("yes") is True
+
+    def test_false_values(self):
+        assert _coerce_bool("false") is False
+        assert _coerce_bool("0") is False
+        assert _coerce_bool("no") is False
+
+    def test_garbage_returns_original(self):
+        assert _coerce_bool("banana") == "banana"
+        assert _coerce_bool("maybe") == "maybe"
+        assert _coerce_bool("") == ""
diff --git a/tests/test_tool_scheduling.py b/tests/test_tool_scheduling.py
new file mode 100644
index 0000000..8b322dd
--- /dev/null
+++ b/tests/test_tool_scheduling.py
@@ -0,0 +1,186 @@
+"""Tests for tool scheduling (depends_on, tool_call_alias) and ID uniquification."""
+import pytest
+
+from tool_registry import (
+    get_tool_schemas,
+    execute_tool,
+    register_tool,
+    clear_registry,
+    ToolDef,
+    _SCHEDULING_PROPS,
+)
+from id_uniquify import uniquify_tool_call_ids, _collect_used_ids
+
+import tools  # noqa: F401
+
+
+@pytest.fixture(autouse=True)
+def _ensure_builtins():
+    """Guarantee builtins are registered even after another test cleared the registry."""
+    clear_registry()
+    tools._register_builtins()
+    yield
+    clear_registry()
+
+
+class TestSchedulingPropsInjection:
+    def test_schemas_contain_scheduling_fields(self):
+        schemas = get_tool_schemas()
+        assert len(schemas) > 0
+        for s in schemas:
+            # Handle both schema styles
+            props = s.get("properties") or s.get("input_schema", {}).get("properties", {})
+            assert "tool_call_alias" in props, f"Missing tool_call_alias in {s.get('name')}"
+            assert "depends_on" in props, f"Missing depends_on in {s.get('name')}"
+
+    def test_scheduling_props_have_correct_types(self):
+        schemas = get_tool_schemas()
+        s = schemas[0]
+        props = s.get("properties") or s.get("input_schema", {}).get("properties", {})
+        assert props["tool_call_alias"]["type"] == "string"
+        assert props["depends_on"]["type"] == "array"
+
+    def test_original_schema_not_mutated(self):
+        """Verify deepcopy prevents mutation of registered schemas."""
+        schemas1 = get_tool_schemas()
+        s1 = schemas1[0]
+        props1 = s1.get("properties") or s1.get("input_schema", {}).get("properties", {})
+        props1["tool_call_alias"]["EXTRA"] = True
+        schemas2 = get_tool_schemas()
+        s2 = schemas2[0]
+        props2 = s2.get("properties") or s2.get("input_schema", {}).get("properties", {})
+        assert "EXTRA" not in props2["tool_call_alias"]
+
+    def test_input_schema_style_gets_scheduling_in_right_place(self):
+        """Built-in tools use input_schema; scheduling props must land there."""
+        schemas = get_tool_schemas()
+        read_schema = next(s for s in schemas if s["name"] == "Read")
+        # Must be inside input_schema.properties, NOT top-level properties
+        assert "tool_call_alias" in read_schema["input_schema"]["properties"]
+        assert "depends_on" in read_schema["input_schema"]["properties"]
+        assert "properties" not in read_schema or "tool_call_alias" not in read_schema.get("properties", {})
+
+
+class TestExecuteToolStripsScheduling:
+    def setup_method(self):
+        self._received = {}
+
+        def _handler(params, config=None):
+            self._received = dict(params)
+            return "ok"
+
+        register_tool(ToolDef(
+            name="test_sched_tool",
+            schema={
+                "name": "test_sched_tool",
+                "description": "test tool",
+                "input_schema": {
+                    "type": "object",
+                    "properties": {"msg": {"type": "string"}},
+                },
+            },
+            func=_handler,
+            read_only=True,
+        ))
+
+    def test_scheduling_params_stripped(self):
+        execute_tool(
+            "test_sched_tool",
+            {"msg": "hi", "tool_call_alias": "t1", "depends_on": ["w1"]},
+            config={},
+        )
+        assert "tool_call_alias" not in self._received
+        assert "depends_on" not in self._received
+        assert self._received.get("msg") == "hi"
+
+
+class TestIdUniquify:
+    """Tests for tool_call_id collision prevention."""
+
+    def _make_state(self, messages=None):
+        """Create a minimal state object with messages and turn_count."""
+        class _State:
+            pass
+        s = _State()
+        s.messages = messages or []
+        s.turn_count = 2
+        return s
+
+    def test_fresh_ids_pass_through(self):
+        state = self._make_state()
+        tcs = [{"id": "r1", "name": "Read", "input": {}}]
+        remap = uniquify_tool_call_ids(tcs, state)
+        assert remap == {}
+        assert tcs[0]["id"] == "r1"
+
+    def test_colliding_id_remapped(self):
+        state = self._make_state([
+            {"role": "assistant", "content": "", "tool_calls": [{"id": "r1", "name": "Read", "input": {}}]},
+            {"role": "tool", "tool_call_id": "r1", "name": "Read", "content": "data"},
+        ])
+        tcs = [{"id": "r1", "name": "Read", "input": {}}]
+        remap = uniquify_tool_call_ids(tcs, state)
+        assert remap == {"r1": "t2_r1"}
+        assert tcs[0]["id"] == "t2_r1"
+
+    def test_depends_on_rewritten(self):
+        state = self._make_state([
+            {"role": "assistant", "content": "", "tool_calls": [
+                {"id": "w1", "name": "Write", "input": {}},
+                {"id": "r1", "name": "Read", "input": {}},
+            ]},
+            {"role": "tool", "tool_call_id": "w1", "name": "Write", "content": "ok"},
+            {"role": "tool", "tool_call_id": "r1", "name": "Read", "content": "data"},
+        ])
+        tcs = [
+            {"id": "w1", "name": "Write", "input": {}},
+            {"id": "r1", "name": "Read", "input": {"depends_on": ["w1"]}},
+        ]
+        remap = uniquify_tool_call_ids(tcs, state)
+        assert tcs[0]["id"] == "t2_w1"
+        assert tcs[1]["id"] == "t2_r1"
+        # depends_on must be rewritten to match the new w1 id
+        assert tcs[1]["input"]["depends_on"] == ["t2_w1"]
+
+    def test_multiple_collisions_get_numeric_suffix(self):
+        state = self._make_state([
+            {"role": "assistant", "content": "", "tool_calls": [{"id": "r1", "name": "Read", "input": {}}]},
+            {"role": "tool", "tool_call_id": "r1", "name": "Read", "content": ""},
+            {"role": "assistant", "content": "", "tool_calls": [{"id": "t2_r1", "name": "Read", "input": {}}]},
+            {"role": "tool", "tool_call_id": "t2_r1", "name": "Read", "content": ""},
+        ])
+        tcs = [{"id": "r1", "name": "Read", "input": {}}]
+        remap = uniquify_tool_call_ids(tcs, state)
+        # t2_r1 is taken, so should get t2_r1_2
+        assert tcs[0]["id"] == "t2_r1_2"
+
+    def test_collect_used_ids(self):
+        state = self._make_state([
+            {"role": "user", "content": "hi"},
+            {"role": "assistant", "content": "", "tool_calls": [
+                {"id": "a1", "name": "Read", "input": {}},
+                {"id": "a2", "name": "Write", "input": {}},
+            ]},
+            {"role": "tool", "tool_call_id": "a1", "name": "Read", "content": "x"},
+            {"role": "tool", "tool_call_id": "a2", "name": "Write", "content": "y"},
+        ])
+        used = _collect_used_ids(state)
+        assert used == {"a1", "a2"}
+
+    def test_empty_tool_calls(self):
+        state = self._make_state()
+        assert uniquify_tool_call_ids([], state) == {}
+
+    def test_mixed_fresh_and_colliding(self):
+        state = self._make_state([
+            {"role": "assistant", "content": "", "tool_calls": [{"id": "r1", "name": "Read", "input": {}}]},
+            {"role": "tool", "tool_call_id": "r1", "name": "Read", "content": ""},
+        ])
+        tcs = [
+            {"id": "r1", "name": "Read", "input": {}},
+            {"id": "w1", "name": "Write", "input": {"depends_on": ["r1"]}},
+        ]
+        remap = uniquify_tool_call_ids(tcs, state)
+        assert tcs[0]["id"] == "t2_r1"  # remapped
+        assert tcs[1]["id"] == "w1"  # fresh, untouched
+        assert tcs[1]["input"]["depends_on"] == ["t2_r1"]  # ref rewritten
diff --git a/tests/test_tool_scheduling_e2e.py b/tests/test_tool_scheduling_e2e.py
new file mode 100644
index 0000000..cb1230a
--- /dev/null
+++ b/tests/test_tool_scheduling_e2e.py
@@ -0,0 +1,98 @@
+"""End-to-end: the LLM sees `tool_call_alias` + `depends_on` in every tool
+schema, it uses them in a tool call, and the stripping wrapper removes those
+fields before the tool handler runs.
+
+Only the LLM provider is mocked (via monkeypatching agent.stream). The tool
+registry, schema injection and param stripping all run for real.
+"""
+from __future__ import annotations
+
+import tools as _tools_init  # noqa: F401 - force built-in tool registration
+from agent import AgentState, run
+
+
+def test_schemas_sent_to_llm_include_scheduling_props(monkeypatch, receiver_tool, scripted_stream):
+    """Every schema the LLM sees must carry tool_call_alias + depends_on."""
+    captured = []
+    monkeypatch.setattr(
+        "agent.stream",
+        scripted_stream(captured, [{"text": "nothing to do"}]),
+    )
+
+    list(run("hi", AgentState(), {"model": "test", "permission_mode": "accept-all",
+                                  "_session_id": "sch", "disabled_tools": ["Agent"]},
+             "sys"))
+
+    assert captured, "stream was not called"
+    for schema in captured[0]:
+        props = schema.get("properties") or schema.get("input_schema", {}).get("properties", {})
+        assert "tool_call_alias" in props, f"{schema.get('name')} missing tool_call_alias"
+        assert "depends_on" in props, f"{schema.get('name')} missing depends_on"
+
+
+def test_scheduling_params_stripped_before_reaching_tool(monkeypatch, receiver_tool, scripted_stream):
+    """tool_call_alias + depends_on must be gone by the time the handler runs."""
+    captured_schemas = []
+    turns = [
+        {"tool_calls": [{
+            "id": "r1",
+            "name": "receiver",
+            "input": {
+                "msg": "hello",
+                "tool_call_alias": "step-1",
+                "depends_on": ["w1", "w2"],
+            },
+        }]},
+        {"text": "done"},
+    ]
+    monkeypatch.setattr("agent.stream", scripted_stream(captured_schemas, turns))
+
+    list(run("go", AgentState(), {"model": "test", "permission_mode": "accept-all",
+                                  "_session_id": "sch2", "disabled_tools": ["Agent"]},
+             "sys"))
+
+    assert "seen" in receiver_tool, "receiver handler was never called"
+    seen = receiver_tool["seen"]
+    assert seen.get("msg") == "hello"
+    assert "tool_call_alias" not in seen
+    assert "depends_on" not in seen
+
+
+def test_id_reuse_across_turns_gets_remapped(monkeypatch, receiver_tool, scripted_stream):
+    """When LLM reuses an id from a prior turn, uniquify rewrites it."""
+    captured_schemas = []
+    turns = [
+        {"tool_calls": [{
+            "id": "r1",
+            "name": "receiver",
+            "input": {"msg": "turn1"},
+        }]},
+        {"tool_calls": [{
+            "id": "r1",
+            "name": "receiver",
+            "input": {"msg": "turn2"},
+        }]},
+        {"text": "done"},
+    ]
+    monkeypatch.setattr("agent.stream", scripted_stream(captured_schemas, turns))
+
+    state = AgentState()
+    events = list(run("go", state, {"model": "test", "permission_mode": "accept-all",
+                                    "_session_id": "sch3", "disabled_tools": ["Agent"]},
+                      "sys"))
+
+    assistant_ids = []
+    for msg in state.messages:
+        if msg["role"] == "assistant":
+            for tc in msg.get("tool_calls") or []:
+                assistant_ids.append(tc["id"])
+
+    assert len(assistant_ids) == 2, f"Expected 2 tool calls, got {assistant_ids}"
+    assert assistant_ids[0] != assistant_ids[1], (
+        f"IDs must be unique across turns but got duplicates: {assistant_ids}"
+    )
+
+    tool_results = [m for m in state.messages if m["role"] == "tool"]
+    assert len(tool_results) == 2
+    assert tool_results[0]["tool_call_id"] == assistant_ids[0]
+    assert tool_results[1]["tool_call_id"] == assistant_ids[1]
diff --git a/tool_registry.py b/tool_registry.py
index f0a66c2..9c44cca 100644
--- a/tool_registry.py
+++ b/tool_registry.py
@@ -136,3 +136,71 @@
 def clear_registry() -> None:
     """Remove all registered tools. Intended for testing."""
     _registry.clear()
+
+
+# ── Tool scheduling support ────────────────────────────────────────────────
+
+import copy as _copy
+
+_SCHEDULING_PROPS = {
+    "tool_call_alias": {
+        "type": "string",
+        "description": (
+            "Optional alias for this tool call. "
+            "Other tools can reference it in depends_on."
+        ),
+    },
+    "depends_on": {
+        "type": "array",
+        "items": {"type": "string"},
+        "description": (
+            "List of tool_call IDs or aliases that must complete before this tool runs."
+        ),
+    },
+}
+
+
+# Wrap get_tool_schemas to inject scheduling properties
+_orig_get_tool_schemas = get_tool_schemas
+
+
+def get_tool_schemas():
+    """Return tool schemas with scheduling properties injected.
+
+    Handles both schema styles:
+    - Top-level ``properties`` (rare, e.g. test fixtures)
+    - Anthropic-style ``input_schema.properties`` (all built-in tools)
+    """
+    schemas = _orig_get_tool_schemas()
+    result = []
+    for s in schemas:
+        s = _copy.deepcopy(s)
+        # Detect where properties live
+        if "input_schema" in s:
+            props = s["input_schema"].setdefault("properties", {})
+        else:
+            props = s.setdefault("properties", {})
+        for k, v in _SCHEDULING_PROPS.items():
+            props.setdefault(k, _copy.deepcopy(v))
+        result.append(s)
+    return result
+
+
+# Wrap execute_tool to strip scheduling params and coerce types
+_orig_execute_tool = execute_tool
+
+
+def execute_tool(name, params, *args, **kwargs):
+    """Execute a tool after stripping scheduling params and coercing types.
+
+    Scheduling-only fields (``tool_call_alias``, ``depends_on``) are consumed
+    by the agent loop and must never reach tool handlers. Remaining string
+    params are coerced to their schema-declared types (see coercion.py).
+    """
+    from coercion import coerce_params  # local import: keeps module load order simple
+
+    clean = {k: v for k, v in params.items() if k not in _SCHEDULING_PROPS}
+    tool = _registry.get(name)
+    if tool is not None:
+        clean = coerce_params(clean, tool.schema)
+    return _orig_execute_tool(name, clean, *args, **kwargs)