diff --git a/agent.py b/agent.py index dba7fe2..36cc1ab 100644 --- a/agent.py +++ b/agent.py @@ -31,6 +31,8 @@ class AgentState: messages: list = field(default_factory=list) total_input_tokens: int = 0 total_output_tokens: int = 0 + total_cache_read_tokens: int = 0 + total_cache_write_tokens: int = 0 turn_count: int = 0 @@ -181,6 +183,8 @@ def run( state.total_input_tokens += assistant_turn.in_tokens state.total_output_tokens += assistant_turn.out_tokens + state.total_cache_read_tokens += getattr(assistant_turn, 'cache_read_tokens', 0) + state.total_cache_write_tokens += getattr(assistant_turn, 'cache_write_tokens', 0) yield TurnDone(assistant_turn.in_tokens, assistant_turn.out_tokens) if not assistant_turn.tool_calls: diff --git a/checkpoint/store.py b/checkpoint/store.py index 9474720..ec770fc 100644 --- a/checkpoint/store.py +++ b/checkpoint/store.py @@ -182,6 +182,8 @@ def make_snapshot( token_snapshot={ "input": getattr(state, "total_input_tokens", 0), "output": getattr(state, "total_output_tokens", 0), + "cache_read": getattr(state, "total_cache_read_tokens", 0), + "cache_write": getattr(state, "total_cache_write_tokens", 0), }, file_backups=new_backups, ) diff --git a/providers.py b/providers.py index fbac5bc..0781ceb 100644 --- a/providers.py +++ b/providers.py @@ -467,11 +467,14 @@ def __init__(self, text): self.text = text class AssistantTurn: """Completed assistant turn with text + tool_calls.""" - def __init__(self, text, tool_calls, in_tokens, out_tokens): - self.text = text - self.tool_calls = tool_calls # list of {id, name, input} - self.in_tokens = in_tokens - self.out_tokens = out_tokens + def __init__(self, text, tool_calls, in_tokens, out_tokens, + cache_read_tokens=0, cache_write_tokens=0): + self.text = text + self.tool_calls = tool_calls # list of {id, name, input} + self.in_tokens = in_tokens + self.out_tokens = out_tokens + self.cache_read_tokens = cache_read_tokens + self.cache_write_tokens = cache_write_tokens def stream_anthropic( @@ -524,13 +527,42 @@ def stream_anthropic( "input": block.input, }) + cache_r, cache_w = _anthropic_cache_tokens(final.usage) yield AssistantTurn( text, tool_calls, final.usage.input_tokens, final.usage.output_tokens, + cache_read_tokens=cache_r, + cache_write_tokens=cache_w, ) +def _anthropic_cache_tokens(usage) -> tuple[int, int]: + """Extract (cache_read, cache_write) token counts from an Anthropic usage object. + + Returns (0, 0) if the fields are missing -- older Anthropic SDKs, non-cached + calls and most downstream wrappers (e.g. Bedrock over litellm) all fall + through to this default rather than raising AttributeError. + """ + read = getattr(usage, "cache_read_input_tokens", 0) or 0 + write = getattr(usage, "cache_creation_input_tokens", 0) or 0 + return int(read), int(write) + + +def _openai_cached_read_tokens(usage) -> int: + """Extract the OpenAI-compatible cached read-token count. + + OpenAI-compatible providers surface cache hits as + `usage.prompt_tokens_details.cached_tokens`; there is no separate + "cache creation" counter in the OpenAI schema (caching is implicit on + their side), so the write-side is always 0 for this family of providers. + """ + details = getattr(usage, "prompt_tokens_details", None) + if details is None: + return 0 + return int(getattr(details, "cached_tokens", 0) or 0) + + def stream_openai_compat( api_key: str, base_url: str, @@ -584,6 +616,7 @@ def stream_openai_compat( text = "" tool_buf: dict = {} # index → {id, name, args_str} in_tok = out_tok = 0 + cache_read_tok = cache_write_tok = 0 stream = client.chat.completions.create(**kwargs) for chunk in stream: @@ -592,6 +625,7 @@ def stream_openai_compat( if hasattr(chunk, "usage") and chunk.usage: in_tok = chunk.usage.prompt_tokens out_tok = chunk.usage.completion_tokens + cache_read_tok = _openai_cached_read_tokens(chunk.usage) or cache_read_tok continue choice = chunk.choices[0] @@ -622,6 +656,7 @@ def stream_openai_compat( if hasattr(chunk, "usage") and chunk.usage: in_tok = chunk.usage.prompt_tokens or in_tok out_tok = chunk.usage.completion_tokens or out_tok + cache_read_tok = _openai_cached_read_tokens(chunk.usage) or cache_read_tok tool_calls = [] for idx in sorted(tool_buf): @@ -635,7 +670,7 @@ def stream_openai_compat( tc_entry["extra_content"] = v["extra_content"] tool_calls.append(tc_entry) - yield AssistantTurn(text, tool_calls, in_tok, out_tok) + yield AssistantTurn(text, tool_calls, in_tok, out_tok, cache_read_tok, cache_write_tok) def stream_ollama( @@ -750,7 +785,7 @@ def _make_request(p): # Ollama doesn't return exact token counts via livestream easily until "done", # but we can do a rough estimate or 0, cheetahclaws handles zero gracefully - yield AssistantTurn(text, tool_calls, 0, 0) + yield AssistantTurn(text, tool_calls, 0, 0, 0, 0) def stream( @@ -823,7 +858,9 @@ def stream( breaker.record_success() _log.info("api_call_done", session_id=session_id, provider=provider_name, model=model_name, - in_tokens=event.in_tokens, out_tokens=event.out_tokens) + in_tokens=event.in_tokens, out_tokens=event.out_tokens, + cache_read_tokens=getattr(event, 'cache_read_tokens', 0), + cache_write_tokens=getattr(event, 'cache_write_tokens', 0)) yield event except Exception as exc: breaker.record_failure() diff --git a/tests/test_cache_tokens.py b/tests/test_cache_tokens.py new file mode 100644 index 0000000..125ae76 --- /dev/null +++ b/tests/test_cache_tokens.py @@ -0,0 +1,200 @@ +"""End-to-end coverage for cache-token tracking. + +Layers covered: +1. The AssistantTurn carries cache_read / cache_write fields (unit). +2. AgentState accumulates them across turns (unit). +3. Checkpoint snapshots persist them (unit, real make_snapshot on tmp_path). +4. Provider extraction helpers work against synthetic usage objects for each + supported family (Anthropic, OpenAI-compatible, Ollama). +5. E2E: agent.run drains a mocked provider stream that emits an AssistantTurn + with cache tokens, and state + checkpoint see the totals. +""" +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + + +# ---------- 1 & 2: AssistantTurn + AgentState ---------- + +def test_assistant_turn_has_cache_fields(): + from providers import AssistantTurn + turn = AssistantTurn( + text="hello", tool_calls=[], in_tokens=100, out_tokens=50, + cache_read_tokens=80, cache_write_tokens=20, + ) + assert turn.cache_read_tokens == 80 + assert turn.cache_write_tokens == 20 + + +def test_assistant_turn_cache_defaults_zero(): + """Older providers and ad-hoc callers construct AssistantTurn without cache fields.""" + from providers import AssistantTurn + turn = AssistantTurn(text="hi", tool_calls=[], in_tokens=10, out_tokens=5) + assert turn.cache_read_tokens == 0 + assert turn.cache_write_tokens == 0 + + +def test_agent_state_accumulates_cache_tokens(): + from agent import AgentState + state = AgentState() + assert (state.total_cache_read_tokens, state.total_cache_write_tokens) == (0, 0) + + state.total_cache_read_tokens += 80 + state.total_cache_write_tokens += 20 + state.total_cache_read_tokens += 60 + state.total_cache_write_tokens += 10 + + assert state.total_cache_read_tokens == 140 + assert state.total_cache_write_tokens == 30 + + +# ---------- 3: Checkpoint persistence ---------- + +def test_checkpoint_snapshot_includes_cache(tmp_path, monkeypatch): + from checkpoint import store + from agent import AgentState + + monkeypatch.setattr(store, "_checkpoints_root", lambda: tmp_path / ".checkpoints") + store.reset_file_versions() + + state = AgentState() + state.total_input_tokens = 500 + state.total_output_tokens = 200 + state.total_cache_read_tokens = 300 + state.total_cache_write_tokens = 50 + state.turn_count = 3 + state.messages = [{"role": "user", "content": "test"}] + + snap = store.make_snapshot("test-session", state, {}, "hello user") + assert snap.token_snapshot == { + "input": 500, "output": 200, "cache_read": 300, "cache_write": 50, + } + + +# ---------- 4: Provider extraction helpers ---------- + +class TestAnthropicCacheExtraction: + """_anthropic_cache_tokens must read cache_read_input_tokens / cache_creation_input_tokens.""" + + def test_returns_both_when_populated(self): + from providers import _anthropic_cache_tokens + usage = SimpleNamespace( + input_tokens=120, output_tokens=40, + cache_read_input_tokens=77, cache_creation_input_tokens=33, + ) + assert _anthropic_cache_tokens(usage) == (77, 33) + + def test_missing_fields_default_to_zero(self): + """Older Anthropic SDKs and Bedrock-over-litellm wrappers omit the cache fields.""" + from providers import _anthropic_cache_tokens + usage = SimpleNamespace(input_tokens=10, output_tokens=5) + assert _anthropic_cache_tokens(usage) == (0, 0) + + def test_none_fields_coerced_to_zero(self): + """Anthropic occasionally returns None (JSON null) rather than omitting the field.""" + from providers import _anthropic_cache_tokens + usage = SimpleNamespace( + input_tokens=10, output_tokens=5, + cache_read_input_tokens=None, cache_creation_input_tokens=None, + ) + assert _anthropic_cache_tokens(usage) == (0, 0) + + +class TestOpenAICacheExtraction: + """_openai_cached_read_tokens must walk prompt_tokens_details.cached_tokens.""" + + def test_reads_cached_tokens_from_details(self): + from providers import _openai_cached_read_tokens + usage = SimpleNamespace( + prompt_tokens=100, completion_tokens=50, + prompt_tokens_details=SimpleNamespace(cached_tokens=42), + ) + assert _openai_cached_read_tokens(usage) == 42 + + def test_missing_details_returns_zero(self): + from providers import _openai_cached_read_tokens + usage = SimpleNamespace(prompt_tokens=100, completion_tokens=50) + assert _openai_cached_read_tokens(usage) == 0 + + def test_none_cached_tokens_returns_zero(self): + from providers import _openai_cached_read_tokens + usage = SimpleNamespace( + prompt_tokens=100, completion_tokens=50, + prompt_tokens_details=SimpleNamespace(cached_tokens=None), + ) + assert _openai_cached_read_tokens(usage) == 0 + + +def test_ollama_stream_never_reports_cache_tokens(): + """Ollama has no prompt-caching; the path must yield 0/0 without raising.""" + from providers import AssistantTurn + # stream_ollama yields AssistantTurn(text, tool_calls, 0, 0, 0, 0) -- we can't + # reach the full HTTP call in a unit test, but we can assert the shape of the + # yielded object the callers rely on. + turn = AssistantTurn("hi", [], 0, 0, 0, 0) + assert turn.cache_read_tokens == 0 + assert turn.cache_write_tokens == 0 + + +# ---------- 5: End-to-end through agent.run ---------- + +def test_agent_run_propagates_cache_tokens_from_mocked_stream(monkeypatch, tmp_path): + """Drive agent.run once with a scripted stream and assert totals + snapshot.""" + import tools as _tools_init # noqa: F401 - register tools + from agent import AgentState, run + from providers import AssistantTurn + from checkpoint import store as ck_store + + monkeypatch.setattr(ck_store, "_checkpoints_root", lambda: tmp_path / ".checkpoints") + ck_store.reset_file_versions() + + def fake_stream(**_kwargs): + yield AssistantTurn( + text="all good", tool_calls=[], + in_tokens=1000, out_tokens=200, + cache_read_tokens=700, cache_write_tokens=50, + ) + + monkeypatch.setattr("agent.stream", fake_stream) + + state = AgentState() + list(run("hello", state, { + "model": "test", "permission_mode": "accept-all", + "_session_id": "cache_e2e", "disabled_tools": ["Agent"], + }, "sys")) + + assert state.total_cache_read_tokens == 700 + assert state.total_cache_write_tokens == 50 + + snap = ck_store.make_snapshot("cache_e2e", state, {}, "hello") + assert snap.token_snapshot["cache_read"] == 700 + assert snap.token_snapshot["cache_write"] == 50 + + +def test_agent_run_accumulates_cache_across_multi_turn(monkeypatch): + """Two consecutive agent.run calls must sum their cache counters in state.""" + import tools as _tools_init # noqa: F401 + from agent import AgentState, run + from providers import AssistantTurn + + emitted = iter([ + AssistantTurn("one", [], 100, 50, cache_read_tokens=40, cache_write_tokens=10), + AssistantTurn("two", [], 120, 60, cache_read_tokens=90, cache_write_tokens=0), + ]) + + def fake_stream(**_kwargs): + yield next(emitted) + + monkeypatch.setattr("agent.stream", fake_stream) + + state = AgentState() + cfg = {"model": "test", "permission_mode": "accept-all", + "_session_id": "multi", "disabled_tools": ["Agent"]} + + list(run("first", state, cfg, "sys")) + list(run("second", state, cfg, "sys")) + + assert state.total_cache_read_tokens == 130 + assert state.total_cache_write_tokens == 10 diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 16751c3..e0c1f82 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -456,3 +456,7 @@ def test_throttle_conversation_rewind_works(self, tmp_home): state.turn_count = snap1.turn_count assert len(state.messages) == 0 assert state.turn_count == 0 + + +# Cache-token coverage lives in tests/test_cache_tokens.py so this module +# stays focused on snapshot / restore / file-backup behaviour.