4 changes: 4 additions & 0 deletions agent.py
@@ -31,6 +31,8 @@ class AgentState:
messages: list = field(default_factory=list)
total_input_tokens: int = 0
total_output_tokens: int = 0
total_cache_read_tokens: int = 0
total_cache_write_tokens: int = 0
turn_count: int = 0


@@ -181,6 +183,8 @@ def run(

state.total_input_tokens += assistant_turn.in_tokens
state.total_output_tokens += assistant_turn.out_tokens
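# getattr with a default keeps turns from older provider builds (which lack
# the cache fields) counting as 0 instead of raising AttributeError.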
state.total_cache_read_tokens += getattr(assistant_turn, 'cache_read_tokens', 0)
state.total_cache_write_tokens += getattr(assistant_turn, 'cache_write_tokens', 0)
yield TurnDone(assistant_turn.in_tokens, assistant_turn.out_tokens)

if not assistant_turn.tool_calls:
2 changes: 2 additions & 0 deletions checkpoint/store.py
@@ -182,6 +182,8 @@ def make_snapshot(
token_snapshot={
"input": getattr(state, "total_input_tokens", 0),
"output": getattr(state, "total_output_tokens", 0),
"cache_read": getattr(state, "total_cache_read_tokens", 0),
"cache_write": getattr(state, "total_cache_write_tokens", 0),
},
file_backups=new_backups,
)
53 changes: 45 additions & 8 deletions providers.py
@@ -467,11 +467,14 @@ def __init__(self, text): self.text = text

class AssistantTurn:
"""Completed assistant turn with text + tool_calls."""
def __init__(self, text, tool_calls, in_tokens, out_tokens):
self.text = text
self.tool_calls = tool_calls # list of {id, name, input}
self.in_tokens = in_tokens
self.out_tokens = out_tokens
def __init__(self, text, tool_calls, in_tokens, out_tokens,
cache_read_tokens=0, cache_write_tokens=0):
self.text = text
self.tool_calls = tool_calls # list of {id, name, input}
self.in_tokens = in_tokens
self.out_tokens = out_tokens
self.cache_read_tokens = cache_read_tokens
self.cache_write_tokens = cache_write_tokens


def stream_anthropic(
@@ -524,13 +527,42 @@ def stream_anthropic(
"input": block.input,
})

cache_r, cache_w = _anthropic_cache_tokens(final.usage)
yield AssistantTurn(
text, tool_calls,
final.usage.input_tokens,
final.usage.output_tokens,
cache_read_tokens=cache_r,
cache_write_tokens=cache_w,
)


def _anthropic_cache_tokens(usage) -> tuple[int, int]:
"""Extract (cache_read, cache_write) token counts from an Anthropic usage object.

Returns (0, 0) if the fields are missing -- older Anthropic SDKs, non-cached
calls and most downstream wrappers (e.g. Bedrock over litellm) all fall
through to this default rather than raising AttributeError.
"""
read = getattr(usage, "cache_read_input_tokens", 0) or 0
write = getattr(usage, "cache_creation_input_tokens", 0) or 0
return int(read), int(write)


def _openai_cached_read_tokens(usage) -> int:
"""Extract the OpenAI-compatible cached read-token count.

OpenAI-compatible providers surface cache hits as
`usage.prompt_tokens_details.cached_tokens`; there is no separate
"cache creation" counter in the OpenAI schema (caching is implicit on
their side), so the write-side is always 0 for this family of providers.
"""
details = getattr(usage, "prompt_tokens_details", None)
if details is None:
return 0
return int(getattr(details, "cached_tokens", 0) or 0)


def stream_openai_compat(
api_key: str,
base_url: str,
Expand Down Expand Up @@ -584,6 +616,7 @@ def stream_openai_compat(
text = ""
tool_buf: dict = {} # index → {id, name, args_str}
in_tok = out_tok = 0
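# cache_write_tok stays 0 on this path: the OpenAI schema has no
# cache-creation counter (caching is implicit on the provider side).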
cache_read_tok = cache_write_tok = 0

stream = client.chat.completions.create(**kwargs)
for chunk in stream:
@@ -592,6 +625,7 @@
if hasattr(chunk, "usage") and chunk.usage:
in_tok = chunk.usage.prompt_tokens
out_tok = chunk.usage.completion_tokens
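# usage can surface on more than one chunk; `or` keeps the last non-zero count.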
cache_read_tok = _openai_cached_read_tokens(chunk.usage) or cache_read_tok
continue

choice = chunk.choices[0]
@@ -622,6 +656,7 @@
if hasattr(chunk, "usage") and chunk.usage:
in_tok = chunk.usage.prompt_tokens or in_tok
out_tok = chunk.usage.completion_tokens or out_tok
cache_read_tok = _openai_cached_read_tokens(chunk.usage) or cache_read_tok

tool_calls = []
for idx in sorted(tool_buf):
@@ -635,7 +670,7 @@
tc_entry["extra_content"] = v["extra_content"]
tool_calls.append(tc_entry)

yield AssistantTurn(text, tool_calls, in_tok, out_tok)
yield AssistantTurn(text, tool_calls, in_tok, out_tok, cache_read_tok, cache_write_tok)


def stream_ollama(
@@ -750,7 +785,7 @@ def _make_request(p):

# Ollama doesn't report exact token counts mid-stream until the final "done"
# message, so we report 0 for all counters; cheetahclaws handles zeros gracefully.
yield AssistantTurn(text, tool_calls, 0, 0)
yield AssistantTurn(text, tool_calls, 0, 0, 0, 0)


def stream(
Expand Down Expand Up @@ -823,7 +858,9 @@ def stream(
breaker.record_success()
_log.info("api_call_done", session_id=session_id,
provider=provider_name, model=model_name,
in_tokens=event.in_tokens, out_tokens=event.out_tokens)
in_tokens=event.in_tokens, out_tokens=event.out_tokens,
cache_read_tokens=getattr(event, 'cache_read_tokens', 0),
cache_write_tokens=getattr(event, 'cache_write_tokens', 0))
yield event
except Exception as exc:
breaker.record_failure()
200 changes: 200 additions & 0 deletions tests/test_cache_tokens.py
@@ -0,0 +1,200 @@
"""End-to-end coverage for cache-token tracking.

Layers covered:
1. The AssistantTurn carries cache_read / cache_write fields (unit).
2. AgentState accumulates them across turns (unit).
3. Checkpoint snapshots persist them (unit, real make_snapshot on tmp_path).
4. Provider extraction helpers work against synthetic usage objects for each
supported family (Anthropic, OpenAI-compatible, Ollama).
5. E2E: agent.run drains a mocked provider stream that emits an AssistantTurn
with cache tokens, and state + checkpoint see the totals.
"""
from __future__ import annotations

from types import SimpleNamespace

import pytest


# ---------- 1 & 2: AssistantTurn + AgentState ----------

def test_assistant_turn_has_cache_fields():
from providers import AssistantTurn
turn = AssistantTurn(
text="hello", tool_calls=[], in_tokens=100, out_tokens=50,
cache_read_tokens=80, cache_write_tokens=20,
)
assert turn.cache_read_tokens == 80
assert turn.cache_write_tokens == 20


def test_assistant_turn_cache_defaults_zero():
"""Older providers and ad-hoc callers construct AssistantTurn without cache fields."""
from providers import AssistantTurn
turn = AssistantTurn(text="hi", tool_calls=[], in_tokens=10, out_tokens=5)
assert turn.cache_read_tokens == 0
assert turn.cache_write_tokens == 0


def test_agent_state_accumulates_cache_tokens():
from agent import AgentState
state = AgentState()
assert (state.total_cache_read_tokens, state.total_cache_write_tokens) == (0, 0)

state.total_cache_read_tokens += 80
state.total_cache_write_tokens += 20
state.total_cache_read_tokens += 60
state.total_cache_write_tokens += 10

assert state.total_cache_read_tokens == 140
assert state.total_cache_write_tokens == 30


# ---------- 3: Checkpoint persistence ----------

def test_checkpoint_snapshot_includes_cache(tmp_path, monkeypatch):
from checkpoint import store
from agent import AgentState

monkeypatch.setattr(store, "_checkpoints_root", lambda: tmp_path / ".checkpoints")
store.reset_file_versions()

state = AgentState()
state.total_input_tokens = 500
state.total_output_tokens = 200
state.total_cache_read_tokens = 300
state.total_cache_write_tokens = 50
state.turn_count = 3
state.messages = [{"role": "user", "content": "test"}]

snap = store.make_snapshot("test-session", state, {}, "hello user")
assert snap.token_snapshot == {
"input": 500, "output": 200, "cache_read": 300, "cache_write": 50,
}


# ---------- 4: Provider extraction helpers ----------

class TestAnthropicCacheExtraction:
"""_anthropic_cache_tokens must read cache_read_input_tokens / cache_creation_input_tokens."""

def test_returns_both_when_populated(self):
from providers import _anthropic_cache_tokens
usage = SimpleNamespace(
input_tokens=120, output_tokens=40,
cache_read_input_tokens=77, cache_creation_input_tokens=33,
)
assert _anthropic_cache_tokens(usage) == (77, 33)

def test_missing_fields_default_to_zero(self):
"""Older Anthropic SDKs and Bedrock-over-litellm wrappers omit the cache fields."""
from providers import _anthropic_cache_tokens
usage = SimpleNamespace(input_tokens=10, output_tokens=5)
assert _anthropic_cache_tokens(usage) == (0, 0)

def test_none_fields_coerced_to_zero(self):
"""Anthropic occasionally returns None (JSON null) rather than omitting the field."""
from providers import _anthropic_cache_tokens
usage = SimpleNamespace(
input_tokens=10, output_tokens=5,
cache_read_input_tokens=None, cache_creation_input_tokens=None,
)
assert _anthropic_cache_tokens(usage) == (0, 0)


class TestOpenAICacheExtraction:
"""_openai_cached_read_tokens must walk prompt_tokens_details.cached_tokens."""

def test_reads_cached_tokens_from_details(self):
from providers import _openai_cached_read_tokens
usage = SimpleNamespace(
prompt_tokens=100, completion_tokens=50,
prompt_tokens_details=SimpleNamespace(cached_tokens=42),
)
assert _openai_cached_read_tokens(usage) == 42

def test_missing_details_returns_zero(self):
from providers import _openai_cached_read_tokens
usage = SimpleNamespace(prompt_tokens=100, completion_tokens=50)
assert _openai_cached_read_tokens(usage) == 0

def test_none_cached_tokens_returns_zero(self):
from providers import _openai_cached_read_tokens
usage = SimpleNamespace(
prompt_tokens=100, completion_tokens=50,
prompt_tokens_details=SimpleNamespace(cached_tokens=None),
)
assert _openai_cached_read_tokens(usage) == 0


def test_ollama_stream_never_reports_cache_tokens():
"""Ollama has no prompt-caching; the path must yield 0/0 without raising."""
from providers import AssistantTurn
# stream_ollama yields AssistantTurn(text, tool_calls, 0, 0, 0, 0) -- we don't
# exercise the real HTTP call in a unit test, but we can assert the shape of
# the yielded object that callers rely on.
turn = AssistantTurn("hi", [], 0, 0, 0, 0)
assert turn.cache_read_tokens == 0
assert turn.cache_write_tokens == 0


# ---------- 5: End-to-end through agent.run ----------

def test_agent_run_propagates_cache_tokens_from_mocked_stream(monkeypatch, tmp_path):
"""Drive agent.run once with a scripted stream and assert totals + snapshot."""
import tools as _tools_init # noqa: F401 - register tools
from agent import AgentState, run
from providers import AssistantTurn
from checkpoint import store as ck_store

monkeypatch.setattr(ck_store, "_checkpoints_root", lambda: tmp_path / ".checkpoints")
ck_store.reset_file_versions()

def fake_stream(**_kwargs):
yield AssistantTurn(
text="all good", tool_calls=[],
in_tokens=1000, out_tokens=200,
cache_read_tokens=700, cache_write_tokens=50,
)

monkeypatch.setattr("agent.stream", fake_stream)

state = AgentState()
list(run("hello", state, {
"model": "test", "permission_mode": "accept-all",
"_session_id": "cache_e2e", "disabled_tools": ["Agent"],
}, "sys"))

assert state.total_cache_read_tokens == 700
assert state.total_cache_write_tokens == 50

snap = ck_store.make_snapshot("cache_e2e", state, {}, "hello")
assert snap.token_snapshot["cache_read"] == 700
assert snap.token_snapshot["cache_write"] == 50


def test_agent_run_accumulates_cache_across_multi_turn(monkeypatch):
"""Two consecutive agent.run calls must sum their cache counters in state."""
import tools as _tools_init # noqa: F401
from agent import AgentState, run
from providers import AssistantTurn

emitted = iter([
AssistantTurn("one", [], 100, 50, cache_read_tokens=40, cache_write_tokens=10),
AssistantTurn("two", [], 120, 60, cache_read_tokens=90, cache_write_tokens=0),
])

def fake_stream(**_kwargs):
yield next(emitted)

monkeypatch.setattr("agent.stream", fake_stream)

state = AgentState()
cfg = {"model": "test", "permission_mode": "accept-all",
"_session_id": "multi", "disabled_tools": ["Agent"]}

list(run("first", state, cfg, "sys"))
list(run("second", state, cfg, "sys"))

assert state.total_cache_read_tokens == 130
assert state.total_cache_write_tokens == 10
4 changes: 4 additions & 0 deletions tests/test_checkpoint.py
Original file line number Diff line number Diff line change
@@ -456,3 +456,7 @@ def test_throttle_conversation_rewind_works(self, tmp_home):
state.turn_count = snap1.turn_count
assert len(state.messages) == 0
assert state.turn_count == 0


# Cache-token coverage lives in tests/test_cache_tokens.py so this module
# stays focused on snapshot / restore / file-backup behaviour.