From 5a11c541b44ba54513dd14c4694e65208f65918f Mon Sep 17 00:00:00 2001 From: Simon FREYBURGER Date: Sat, 18 Apr 2026 09:52:02 +0200 Subject: [PATCH 1/5] wip: add cache token fields to AssistantTurn --- providers.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/providers.py b/providers.py index fbac5bc..7e5340d 100644 --- a/providers.py +++ b/providers.py @@ -467,11 +467,14 @@ def __init__(self, text): self.text = text class AssistantTurn: """Completed assistant turn with text + tool_calls.""" - def __init__(self, text, tool_calls, in_tokens, out_tokens): - self.text = text - self.tool_calls = tool_calls # list of {id, name, input} - self.in_tokens = in_tokens - self.out_tokens = out_tokens + def __init__(self, text, tool_calls, in_tokens, out_tokens, + cache_read_tokens=0, cache_creation_tokens=0): + self.text = text + self.tool_calls = tool_calls # list of {id, name, input} + self.in_tokens = in_tokens + self.out_tokens = out_tokens + self.cache_read_tokens = cache_read_tokens + self.cache_creation_tokens = cache_creation_tokens def stream_anthropic( @@ -528,6 +531,8 @@ def stream_anthropic( text, tool_calls, final.usage.input_tokens, final.usage.output_tokens, + cache_read_tokens=getattr(final.usage, "cache_read_input_tokens", 0) or 0, + cache_creation_tokens=getattr(final.usage, "cache_creation_input_tokens", 0) or 0, ) @@ -584,6 +589,7 @@ def stream_openai_compat( text = "" tool_buf: dict = {} # index → {id, name, args_str} in_tok = out_tok = 0 + cache_read_tok = cache_creation_tok = 0 stream = client.chat.completions.create(**kwargs) for chunk in stream: @@ -592,6 +598,9 @@ def stream_openai_compat( if hasattr(chunk, "usage") and chunk.usage: in_tok = chunk.usage.prompt_tokens out_tok = chunk.usage.completion_tokens + _details = getattr(chunk.usage, "prompt_tokens_details", None) + if _details: + cache_read_tok = getattr(_details, "cached_tokens", 0) or 0 continue choice = chunk.choices[0] @@ -622,6 +631,9 @@ def 
stream_openai_compat( if hasattr(chunk, "usage") and chunk.usage: in_tok = chunk.usage.prompt_tokens or in_tok out_tok = chunk.usage.completion_tokens or out_tok + _details = getattr(chunk.usage, "prompt_tokens_details", None) + if _details: + cache_read_tok = getattr(_details, "cached_tokens", 0) or cache_read_tok tool_calls = [] for idx in sorted(tool_buf): @@ -635,7 +647,7 @@ def stream_openai_compat( tc_entry["extra_content"] = v["extra_content"] tool_calls.append(tc_entry) - yield AssistantTurn(text, tool_calls, in_tok, out_tok) + yield AssistantTurn(text, tool_calls, in_tok, out_tok, cache_read_tok, cache_creation_tok) def stream_ollama( From d0035af306178fb5875d526322672982f7664baf Mon Sep 17 00:00:00 2001 From: Simon FREYBURGER Date: Sat, 18 Apr 2026 09:53:56 +0200 Subject: [PATCH 2/5] feat: end-to-end cache token tracking (providers -> state -> checkpoint) --- agent.py | 2 ++ checkpoint/store.py | 2 ++ providers.py | 16 ++++++----- tests/test_checkpoint.py | 61 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 74 insertions(+), 7 deletions(-) diff --git a/agent.py b/agent.py index dba7fe2..b08f457 100644 --- a/agent.py +++ b/agent.py @@ -181,6 +181,8 @@ def run( state.total_input_tokens += assistant_turn.in_tokens state.total_output_tokens += assistant_turn.out_tokens + state.total_cache_read_tokens += getattr(assistant_turn, 'cache_read_tokens', 0) + state.total_cache_write_tokens += getattr(assistant_turn, 'cache_write_tokens', 0) yield TurnDone(assistant_turn.in_tokens, assistant_turn.out_tokens) if not assistant_turn.tool_calls: diff --git a/checkpoint/store.py b/checkpoint/store.py index 9474720..ec770fc 100644 --- a/checkpoint/store.py +++ b/checkpoint/store.py @@ -182,6 +182,8 @@ def make_snapshot( token_snapshot={ "input": getattr(state, "total_input_tokens", 0), "output": getattr(state, "total_output_tokens", 0), + "cache_read": getattr(state, "total_cache_read_tokens", 0), + "cache_write": getattr(state, "total_cache_write_tokens", 0), }, 
file_backups=new_backups, ) diff --git a/providers.py b/providers.py index 7e5340d..d02a486 100644 --- a/providers.py +++ b/providers.py @@ -468,13 +468,13 @@ def __init__(self, text): self.text = text class AssistantTurn: """Completed assistant turn with text + tool_calls.""" def __init__(self, text, tool_calls, in_tokens, out_tokens, - cache_read_tokens=0, cache_creation_tokens=0): + cache_read_tokens=0, cache_write_tokens=0): self.text = text self.tool_calls = tool_calls # list of {id, name, input} self.in_tokens = in_tokens self.out_tokens = out_tokens self.cache_read_tokens = cache_read_tokens - self.cache_creation_tokens = cache_creation_tokens + self.cache_write_tokens = cache_write_tokens def stream_anthropic( @@ -532,7 +532,7 @@ def stream_anthropic( final.usage.input_tokens, final.usage.output_tokens, cache_read_tokens=getattr(final.usage, "cache_read_input_tokens", 0) or 0, - cache_creation_tokens=getattr(final.usage, "cache_creation_input_tokens", 0) or 0, + cache_write_tokens=getattr(final.usage, "cache_creation_input_tokens", 0) or 0, ) @@ -589,7 +589,7 @@ def stream_openai_compat( text = "" tool_buf: dict = {} # index → {id, name, args_str} in_tok = out_tok = 0 - cache_read_tok = cache_creation_tok = 0 + cache_read_tok = cache_write_tok = 0 stream = client.chat.completions.create(**kwargs) for chunk in stream: @@ -647,7 +647,7 @@ def stream_openai_compat( tc_entry["extra_content"] = v["extra_content"] tool_calls.append(tc_entry) - yield AssistantTurn(text, tool_calls, in_tok, out_tok, cache_read_tok, cache_creation_tok) + yield AssistantTurn(text, tool_calls, in_tok, out_tok, cache_read_tok, cache_write_tok) def stream_ollama( @@ -762,7 +762,7 @@ def _make_request(p): # Ollama doesn't return exact token counts via livestream easily until "done", # but we can do a rough estimate or 0, cheetahclaws handles zero gracefully - yield AssistantTurn(text, tool_calls, 0, 0) + yield AssistantTurn(text, tool_calls, 0, 0, 0, 0) def stream( @@ -835,7 +835,9 @@ 
def stream( breaker.record_success() _log.info("api_call_done", session_id=session_id, provider=provider_name, model=model_name, - in_tokens=event.in_tokens, out_tokens=event.out_tokens) + in_tokens=event.in_tokens, out_tokens=event.out_tokens, + cache_read_tokens=getattr(event, 'cache_read_tokens', 0), + cache_write_tokens=getattr(event, 'cache_write_tokens', 0)) yield event except Exception as exc: breaker.record_failure() diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 16751c3..c41c85e 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -456,3 +456,64 @@ def test_throttle_conversation_rewind_works(self, tmp_home): state.turn_count = snap1.turn_count assert len(state.messages) == 0 assert state.turn_count == 0 + + +def test_cache_tokens_in_snapshot(): + """Cache tokens flow from AgentState to checkpoint snapshot.""" + from agent import AgentState + from checkpoint.store import make_snapshot + import inspect + + state = AgentState() + state.total_input_tokens = 100 + state.total_output_tokens = 50 + state.total_cache_read_tokens = 30 + state.total_cache_write_tokens = 20 + state.messages = [] + state.turn_count = 1 + + # Call with correct signature (inspect to be safe) + sig = inspect.signature(make_snapshot) + params = list(sig.parameters.keys()) + if len(params) == 1: + snapshot = make_snapshot(state) + else: + snapshot = make_snapshot(state, 1) + + tokens = snapshot["token_snapshot"] + + assert tokens["cache_read"] == 30, f"Expected 30, got {tokens.get('cache_read')}" + assert tokens["cache_write"] == 20, f"Expected 20, got {tokens.get('cache_write')}" + assert tokens["input"] == 100 + assert tokens["output"] == 50 + + +def test_assistant_turn_cache_tokens(): + """AssistantTurn carries cache token counts from the provider.""" + from providers import AssistantTurn + + turn = AssistantTurn("hello", [], 100, 50, cache_read_tokens=30, cache_write_tokens=20) + assert turn.cache_read_tokens == 30 + assert 
turn.cache_write_tokens == 20 + + # Defaults to 0 + turn2 = AssistantTurn("hello", [], 100, 50) + assert turn2.cache_read_tokens == 0 + assert turn2.cache_write_tokens == 0 + + +def test_agent_state_accumulates_cache_tokens(): + """AgentState accumulates cache tokens across turns.""" + from agent import AgentState + + state = AgentState() + assert state.total_cache_read_tokens == 0 + assert state.total_cache_write_tokens == 0 + + state.total_cache_read_tokens += 10 + state.total_cache_write_tokens += 5 + state.total_cache_read_tokens += 20 + state.total_cache_write_tokens += 15 + + assert state.total_cache_read_tokens == 30 + assert state.total_cache_write_tokens == 20 From 5b1e1c8e4a6b086a7b9c2745a4621f433d1bc835 Mon Sep 17 00:00:00 2001 From: Simon FREYBURGER Date: Sat, 18 Apr 2026 09:57:04 +0200 Subject: [PATCH 3/5] feat: add cache token fields to AgentState + behavioral tests --- agent.py | 2 ++ tests/test_cache_tokens.py | 59 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+) create mode 100644 tests/test_cache_tokens.py diff --git a/agent.py b/agent.py index b08f457..36cc1ab 100644 --- a/agent.py +++ b/agent.py @@ -31,6 +31,8 @@ class AgentState: messages: list = field(default_factory=list) total_input_tokens: int = 0 total_output_tokens: int = 0 + total_cache_read_tokens: int = 0 + total_cache_write_tokens: int = 0 turn_count: int = 0 diff --git a/tests/test_cache_tokens.py b/tests/test_cache_tokens.py new file mode 100644 index 0000000..afce5cf --- /dev/null +++ b/tests/test_cache_tokens.py @@ -0,0 +1,59 @@ +"""Tests for cache token tracking end-to-end.""" +import dataclasses +import pytest + + +def test_assistant_turn_has_cache_fields(): + """AssistantTurn carries cache_read_tokens and cache_write_tokens.""" + from providers import AssistantTurn + turn = AssistantTurn( + text="hello", tool_calls=[], in_tokens=100, out_tokens=50, + cache_read_tokens=80, cache_write_tokens=20, + ) + assert turn.cache_read_tokens == 80 + assert 
turn.cache_write_tokens == 20 + + +def test_assistant_turn_cache_defaults_zero(): + """Cache fields default to 0 for backward compat.""" + from providers import AssistantTurn + turn = AssistantTurn(text="hi", tool_calls=[], in_tokens=10, out_tokens=5) + assert turn.cache_read_tokens == 0 + assert turn.cache_write_tokens == 0 + + +def test_agent_state_accumulates_cache_tokens(): + """AgentState accumulates cache tokens from AssistantTurn.""" + from agent import AgentState + state = AgentState() + assert state.total_cache_read_tokens == 0 + assert state.total_cache_write_tokens == 0 + + # Simulate what the agent loop does + state.total_cache_read_tokens += 80 + state.total_cache_write_tokens += 20 + state.total_cache_read_tokens += 60 + state.total_cache_write_tokens += 10 + + assert state.total_cache_read_tokens == 140 + assert state.total_cache_write_tokens == 30 + + +def test_checkpoint_snapshot_includes_cache(): + """make_snapshot persists cache tokens in token_snapshot.""" + from checkpoint.store import make_snapshot + from agent import AgentState + + state = AgentState() + state.total_input_tokens = 500 + state.total_output_tokens = 200 + state.total_cache_read_tokens = 300 + state.total_cache_write_tokens = 50 + state.turn_count = 3 + state.messages = [{"role": "user", "content": "test"}] + + snapshot = make_snapshot(state, "test-session", "hello user") + assert snapshot.token_snapshot["cache_read"] == 300 + assert snapshot.token_snapshot["cache_write"] == 50 + assert snapshot.token_snapshot["input"] == 500 + assert snapshot.token_snapshot["output"] == 200 From 1fbd183fc902ad7bf475b02f26a3c0956798cd9b Mon Sep 17 00:00:00 2001 From: Simon FREYBURGER Date: Sat, 18 Apr 2026 09:57:58 +0200 Subject: [PATCH 4/5] fix: test_cache_tokens uses correct make_snapshot signature --- tests/test_checkpoint.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index c41c85e..a44a22c 100644 --- 
a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -458,29 +458,21 @@ def test_throttle_conversation_rewind_works(self, tmp_home): assert state.turn_count == 0 -def test_cache_tokens_in_snapshot(): +def test_cache_tokens_in_snapshot(tmp_home): """Cache tokens flow from AgentState to checkpoint snapshot.""" from agent import AgentState from checkpoint.store import make_snapshot - import inspect state = AgentState() state.total_input_tokens = 100 state.total_output_tokens = 50 state.total_cache_read_tokens = 30 state.total_cache_write_tokens = 20 - state.messages = [] + state.messages = [{"role": "user", "content": "hi"}] state.turn_count = 1 - # Call with correct signature (inspect to be safe) - sig = inspect.signature(make_snapshot) - params = list(sig.parameters.keys()) - if len(params) == 1: - snapshot = make_snapshot(state) - else: - snapshot = make_snapshot(state, 1) - - tokens = snapshot["token_snapshot"] + snap = make_snapshot("test_cache", state, {}, "hi") + tokens = snap.token_snapshot assert tokens["cache_read"] == 30, f"Expected 30, got {tokens.get('cache_read')}" assert tokens["cache_write"] == 20, f"Expected 20, got {tokens.get('cache_write')}" From 90ce4d3e32dc3ba8185ce537be4945157803935a Mon Sep 17 00:00:00 2001 From: Simon FREYBURGER Date: Mon, 20 Apr 2026 09:14:49 +0200 Subject: [PATCH 5/5] refactor: extract provider cache-token helpers + add e2e + multi-provider tests Two new small helpers in providers.py give each provider family one obvious extraction point instead of three sprinkled getattr chains: - _anthropic_cache_tokens(usage) -> (read, write) Reads cache_read_input_tokens / cache_creation_input_tokens. Returns (0, 0) if the fields are missing (older SDKs, Bedrock-via-litellm, non-cached calls) or None (Anthropic occasionally emits JSON null). - _openai_cached_read_tokens(usage) -> int Walks usage.prompt_tokens_details.cached_tokens. 
OpenAI's schema has no separate cache-creation counter (caching is implicit), so the write-side stays 0 for this entire provider family. stream_anthropic and stream_openai_compat now call these helpers instead of inlining the getattr dance. stream_ollama was already 0/0; behaviour unchanged. Any new provider that builds an AssistantTurn without passing cache_read_tokens / cache_write_tokens inherits the constructor's keyword defaults (0) and agent.run's getattr(..., 0) fallbacks, so downstream totals and snapshots stay consistent. Tests (tests/test_cache_tokens.py, rewritten): - AssistantTurn + AgentState defaults and accumulation. - Checkpoint snapshot persists cache_read + cache_write via real make_snapshot against a tmp_path. - TestAnthropicCacheExtraction (3 cases) + TestOpenAICacheExtraction (3 cases) covering populated / missing / None cache fields on the usage object. - Ollama shape check (no-cache path). - test_agent_run_propagates_cache_tokens_from_mocked_stream: one turn through agent.run with a scripted stream; asserts state totals AND the produced snapshot. - test_agent_run_accumulates_cache_across_multi_turn: two consecutive runs with distinct cache values; asserts running totals. Cleanup: - The three duplicate cache-token cases previously appended to tests/test_checkpoint.py are removed; test_cache_tokens.py is the single home for this feature now. - Fix the stale make_snapshot(state, session_id, prompt) call in test_cache_tokens that survived from the earlier signature mismatch.
Co-Authored-By: Claude Opus 4.7 (1M context) --- providers.py | 39 ++++++-- tests/test_cache_tokens.py | 183 ++++++++++++++++++++++++++++++++----- tests/test_checkpoint.py | 53 +---------- 3 files changed, 195 insertions(+), 80 deletions(-) diff --git a/providers.py b/providers.py index d02a486..0781ceb 100644 --- a/providers.py +++ b/providers.py @@ -527,15 +527,42 @@ def stream_anthropic( "input": block.input, }) + cache_r, cache_w = _anthropic_cache_tokens(final.usage) yield AssistantTurn( text, tool_calls, final.usage.input_tokens, final.usage.output_tokens, - cache_read_tokens=getattr(final.usage, "cache_read_input_tokens", 0) or 0, - cache_write_tokens=getattr(final.usage, "cache_creation_input_tokens", 0) or 0, + cache_read_tokens=cache_r, + cache_write_tokens=cache_w, ) +def _anthropic_cache_tokens(usage) -> tuple[int, int]: + """Extract (cache_read, cache_write) token counts from an Anthropic usage object. + + Returns (0, 0) if the fields are missing -- older Anthropic SDKs, non-cached + calls and most downstream wrappers (e.g. Bedrock over litellm) all fall + through to this default rather than raising AttributeError. + """ + read = getattr(usage, "cache_read_input_tokens", 0) or 0 + write = getattr(usage, "cache_creation_input_tokens", 0) or 0 + return int(read), int(write) + + +def _openai_cached_read_tokens(usage) -> int: + """Extract the OpenAI-compatible cached read-token count. + + OpenAI-compatible providers surface cache hits as + `usage.prompt_tokens_details.cached_tokens`; there is no separate + "cache creation" counter in the OpenAI schema (caching is implicit on + their side), so the write-side is always 0 for this family of providers. 
+ """ + details = getattr(usage, "prompt_tokens_details", None) + if details is None: + return 0 + return int(getattr(details, "cached_tokens", 0) or 0) + + def stream_openai_compat( api_key: str, base_url: str, @@ -598,9 +625,7 @@ def stream_openai_compat( if hasattr(chunk, "usage") and chunk.usage: in_tok = chunk.usage.prompt_tokens out_tok = chunk.usage.completion_tokens - _details = getattr(chunk.usage, "prompt_tokens_details", None) - if _details: - cache_read_tok = getattr(_details, "cached_tokens", 0) or 0 + cache_read_tok = _openai_cached_read_tokens(chunk.usage) or cache_read_tok continue choice = chunk.choices[0] @@ -631,9 +656,7 @@ def stream_openai_compat( if hasattr(chunk, "usage") and chunk.usage: in_tok = chunk.usage.prompt_tokens or in_tok out_tok = chunk.usage.completion_tokens or out_tok - _details = getattr(chunk.usage, "prompt_tokens_details", None) - if _details: - cache_read_tok = getattr(_details, "cached_tokens", 0) or cache_read_tok + cache_read_tok = _openai_cached_read_tokens(chunk.usage) or cache_read_tok tool_calls = [] for idx in sorted(tool_buf): diff --git a/tests/test_cache_tokens.py b/tests/test_cache_tokens.py index afce5cf..125ae76 100644 --- a/tests/test_cache_tokens.py +++ b/tests/test_cache_tokens.py @@ -1,10 +1,24 @@ -"""Tests for cache token tracking end-to-end.""" -import dataclasses +"""End-to-end coverage for cache-token tracking. + +Layers covered: +1. The AssistantTurn carries cache_read / cache_write fields (unit). +2. AgentState accumulates them across turns (unit). +3. Checkpoint snapshots persist them (unit, real make_snapshot on tmp_path). +4. Provider extraction helpers work against synthetic usage objects for each + supported family (Anthropic, OpenAI-compatible, Ollama). +5. E2E: agent.run drains a mocked provider stream that emits an AssistantTurn + with cache tokens, and state + checkpoint see the totals. 
+""" +from __future__ import annotations + +from types import SimpleNamespace + import pytest +# ---------- 1 & 2: AssistantTurn + AgentState ---------- + def test_assistant_turn_has_cache_fields(): - """AssistantTurn carries cache_read_tokens and cache_write_tokens.""" from providers import AssistantTurn turn = AssistantTurn( text="hello", tool_calls=[], in_tokens=100, out_tokens=50, @@ -15,7 +29,7 @@ def test_assistant_turn_has_cache_fields(): def test_assistant_turn_cache_defaults_zero(): - """Cache fields default to 0 for backward compat.""" + """Older providers and ad-hoc callers construct AssistantTurn without cache fields.""" from providers import AssistantTurn turn = AssistantTurn(text="hi", tool_calls=[], in_tokens=10, out_tokens=5) assert turn.cache_read_tokens == 0 @@ -23,37 +37,164 @@ def test_assistant_turn_cache_defaults_zero(): def test_agent_state_accumulates_cache_tokens(): - """AgentState accumulates cache tokens from AssistantTurn.""" from agent import AgentState state = AgentState() - assert state.total_cache_read_tokens == 0 - assert state.total_cache_write_tokens == 0 + assert (state.total_cache_read_tokens, state.total_cache_write_tokens) == (0, 0) - # Simulate what the agent loop does - state.total_cache_read_tokens += 80 + state.total_cache_read_tokens += 80 state.total_cache_write_tokens += 20 - state.total_cache_read_tokens += 60 + state.total_cache_read_tokens += 60 state.total_cache_write_tokens += 10 assert state.total_cache_read_tokens == 140 assert state.total_cache_write_tokens == 30 -def test_checkpoint_snapshot_includes_cache(): - """make_snapshot persists cache tokens in token_snapshot.""" - from checkpoint.store import make_snapshot +# ---------- 3: Checkpoint persistence ---------- + +def test_checkpoint_snapshot_includes_cache(tmp_path, monkeypatch): + from checkpoint import store from agent import AgentState + monkeypatch.setattr(store, "_checkpoints_root", lambda: tmp_path / ".checkpoints") + store.reset_file_versions() + 
state = AgentState() - state.total_input_tokens = 500 - state.total_output_tokens = 200 - state.total_cache_read_tokens = 300 + state.total_input_tokens = 500 + state.total_output_tokens = 200 + state.total_cache_read_tokens = 300 state.total_cache_write_tokens = 50 state.turn_count = 3 state.messages = [{"role": "user", "content": "test"}] - snapshot = make_snapshot(state, "test-session", "hello user") - assert snapshot.token_snapshot["cache_read"] == 300 - assert snapshot.token_snapshot["cache_write"] == 50 - assert snapshot.token_snapshot["input"] == 500 - assert snapshot.token_snapshot["output"] == 200 + snap = store.make_snapshot("test-session", state, {}, "hello user") + assert snap.token_snapshot == { + "input": 500, "output": 200, "cache_read": 300, "cache_write": 50, + } + + +# ---------- 4: Provider extraction helpers ---------- + +class TestAnthropicCacheExtraction: + """_anthropic_cache_tokens must read cache_read_input_tokens / cache_creation_input_tokens.""" + + def test_returns_both_when_populated(self): + from providers import _anthropic_cache_tokens + usage = SimpleNamespace( + input_tokens=120, output_tokens=40, + cache_read_input_tokens=77, cache_creation_input_tokens=33, + ) + assert _anthropic_cache_tokens(usage) == (77, 33) + + def test_missing_fields_default_to_zero(self): + """Older Anthropic SDKs and Bedrock-over-litellm wrappers omit the cache fields.""" + from providers import _anthropic_cache_tokens + usage = SimpleNamespace(input_tokens=10, output_tokens=5) + assert _anthropic_cache_tokens(usage) == (0, 0) + + def test_none_fields_coerced_to_zero(self): + """Anthropic occasionally returns None (JSON null) rather than omitting the field.""" + from providers import _anthropic_cache_tokens + usage = SimpleNamespace( + input_tokens=10, output_tokens=5, + cache_read_input_tokens=None, cache_creation_input_tokens=None, + ) + assert _anthropic_cache_tokens(usage) == (0, 0) + + +class TestOpenAICacheExtraction: + """_openai_cached_read_tokens 
must walk prompt_tokens_details.cached_tokens.""" + + def test_reads_cached_tokens_from_details(self): + from providers import _openai_cached_read_tokens + usage = SimpleNamespace( + prompt_tokens=100, completion_tokens=50, + prompt_tokens_details=SimpleNamespace(cached_tokens=42), + ) + assert _openai_cached_read_tokens(usage) == 42 + + def test_missing_details_returns_zero(self): + from providers import _openai_cached_read_tokens + usage = SimpleNamespace(prompt_tokens=100, completion_tokens=50) + assert _openai_cached_read_tokens(usage) == 0 + + def test_none_cached_tokens_returns_zero(self): + from providers import _openai_cached_read_tokens + usage = SimpleNamespace( + prompt_tokens=100, completion_tokens=50, + prompt_tokens_details=SimpleNamespace(cached_tokens=None), + ) + assert _openai_cached_read_tokens(usage) == 0 + + +def test_ollama_stream_never_reports_cache_tokens(): + """Ollama has no prompt-caching; the path must yield 0/0 without raising.""" + from providers import AssistantTurn + # stream_ollama yields AssistantTurn(text, tool_calls, 0, 0, 0, 0) -- we can't + # reach the full HTTP call in a unit test, but we can assert the shape of the + # yielded object the callers rely on. 
+ turn = AssistantTurn("hi", [], 0, 0, 0, 0) + assert turn.cache_read_tokens == 0 + assert turn.cache_write_tokens == 0 + + +# ---------- 5: End-to-end through agent.run ---------- + +def test_agent_run_propagates_cache_tokens_from_mocked_stream(monkeypatch, tmp_path): + """Drive agent.run once with a scripted stream and assert totals + snapshot.""" + import tools as _tools_init # noqa: F401 - register tools + from agent import AgentState, run + from providers import AssistantTurn + from checkpoint import store as ck_store + + monkeypatch.setattr(ck_store, "_checkpoints_root", lambda: tmp_path / ".checkpoints") + ck_store.reset_file_versions() + + def fake_stream(**_kwargs): + yield AssistantTurn( + text="all good", tool_calls=[], + in_tokens=1000, out_tokens=200, + cache_read_tokens=700, cache_write_tokens=50, + ) + + monkeypatch.setattr("agent.stream", fake_stream) + + state = AgentState() + list(run("hello", state, { + "model": "test", "permission_mode": "accept-all", + "_session_id": "cache_e2e", "disabled_tools": ["Agent"], + }, "sys")) + + assert state.total_cache_read_tokens == 700 + assert state.total_cache_write_tokens == 50 + + snap = ck_store.make_snapshot("cache_e2e", state, {}, "hello") + assert snap.token_snapshot["cache_read"] == 700 + assert snap.token_snapshot["cache_write"] == 50 + + +def test_agent_run_accumulates_cache_across_multi_turn(monkeypatch): + """Two consecutive agent.run calls must sum their cache counters in state.""" + import tools as _tools_init # noqa: F401 + from agent import AgentState, run + from providers import AssistantTurn + + emitted = iter([ + AssistantTurn("one", [], 100, 50, cache_read_tokens=40, cache_write_tokens=10), + AssistantTurn("two", [], 120, 60, cache_read_tokens=90, cache_write_tokens=0), + ]) + + def fake_stream(**_kwargs): + yield next(emitted) + + monkeypatch.setattr("agent.stream", fake_stream) + + state = AgentState() + cfg = {"model": "test", "permission_mode": "accept-all", + "_session_id": "multi", 
"disabled_tools": ["Agent"]} + + list(run("first", state, cfg, "sys")) + list(run("second", state, cfg, "sys")) + + assert state.total_cache_read_tokens == 130 + assert state.total_cache_write_tokens == 10 diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index a44a22c..e0c1f82 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -458,54 +458,5 @@ def test_throttle_conversation_rewind_works(self, tmp_home): assert state.turn_count == 0 -def test_cache_tokens_in_snapshot(tmp_home): - """Cache tokens flow from AgentState to checkpoint snapshot.""" - from agent import AgentState - from checkpoint.store import make_snapshot - - state = AgentState() - state.total_input_tokens = 100 - state.total_output_tokens = 50 - state.total_cache_read_tokens = 30 - state.total_cache_write_tokens = 20 - state.messages = [{"role": "user", "content": "hi"}] - state.turn_count = 1 - - snap = make_snapshot("test_cache", state, {}, "hi") - tokens = snap.token_snapshot - - assert tokens["cache_read"] == 30, f"Expected 30, got {tokens.get('cache_read')}" - assert tokens["cache_write"] == 20, f"Expected 20, got {tokens.get('cache_write')}" - assert tokens["input"] == 100 - assert tokens["output"] == 50 - - -def test_assistant_turn_cache_tokens(): - """AssistantTurn carries cache token counts from the provider.""" - from providers import AssistantTurn - - turn = AssistantTurn("hello", [], 100, 50, cache_read_tokens=30, cache_write_tokens=20) - assert turn.cache_read_tokens == 30 - assert turn.cache_write_tokens == 20 - - # Defaults to 0 - turn2 = AssistantTurn("hello", [], 100, 50) - assert turn2.cache_read_tokens == 0 - assert turn2.cache_write_tokens == 0 - - -def test_agent_state_accumulates_cache_tokens(): - """AgentState accumulates cache tokens across turns.""" - from agent import AgentState - - state = AgentState() - assert state.total_cache_read_tokens == 0 - assert state.total_cache_write_tokens == 0 - - state.total_cache_read_tokens += 10 - 
state.total_cache_write_tokens += 5 - state.total_cache_read_tokens += 20 - state.total_cache_write_tokens += 15 - - assert state.total_cache_read_tokens == 30 - assert state.total_cache_write_tokens == 20 +# Cache-token coverage lives in tests/test_cache_tokens.py so this module +# stays focused on snapshot / restore / file-backup behaviour.