diff --git a/billing.go b/billing.go index b53ddea..194617e 100644 --- a/billing.go +++ b/billing.go @@ -56,14 +56,25 @@ func CalculateCost(provider, model string, costInfo CostInfo, promptTokens, comp cachedTokens = cacheUsage.CachedTokens + cacheUsage.CacheReadInputTokens } - // Clamp to prompt tokens as a safety net — cached tokens can never - // exceed the total prompt tokens the provider reported. + // Providers report prompt tokens differently: + // - OpenAI/Fireworks/Bedrock: promptTokens INCLUDES cached tokens + // → non-cached = promptTokens - cachedTokens + // - Anthropic: input_tokens EXCLUDES cached tokens (only new tokens) + // → non-cached = promptTokens (as reported), cached is additional + // + // We detect the style by comparing: if cached > prompt, the provider + // must be reporting non-cached only (Anthropic style). Caveat: an Anthropic-style response with cached <= prompt is not detected and falls through to the OpenAI branch below. + var nonCachedTokens int if cachedTokens > promptTokens { - cachedTokens = promptTokens + // Anthropic style: promptTokens = non-cached only, cached is separate + nonCachedTokens = promptTokens + // Adjust promptTokens to reflect the true total for the BillingResult + promptTokens = promptTokens + cachedTokens + } else { + // OpenAI style: promptTokens includes cached + nonCachedTokens = promptTokens - cachedTokens } - nonCachedTokens := promptTokens - cachedTokens - // Non-cached prompt tokens at full input rate inputCost := costInfo.Input * float64(nonCachedTokens) / 1_000_000 diff --git a/billing_test.go b/billing_test.go index de448d2..72ce02a 100644 --- a/billing_test.go +++ b/billing_test.go @@ -67,19 +67,27 @@ func TestCalculateCost_WithOpenAICacheHit(t *testing.T) { } func TestCalculateCost_WithAnthropicCacheHit(t *testing.T) { + // Anthropic reports input_tokens as non-cached only (excludes cached). 
+ // Real example: input_tokens=3, cache_read_input_tokens=2022 costInfo := CostInfo{Input: 3.0, Output: 15.0, CacheRead: 0.3} - cacheUsage := &CacheUsage{CacheReadInputTokens: 2000} + cacheUsage := &CacheUsage{CacheReadInputTokens: 2022} - result := CalculateCost("anthropic", "claude-sonnet-4", costInfo, 2500, 100, cacheUsage) + // promptTokens=3 (Anthropic's input_tokens, non-cached only) + result := CalculateCost("anthropic", "claude-sonnet-4", costInfo, 3, 10, cacheUsage) - if result.CachedTokens != 2000 { - t.Errorf("CachedTokens = %d, want 2000", result.CachedTokens) + if result.CachedTokens != 2022 { + t.Errorf("CachedTokens = %d, want 2022", result.CachedTokens) } - // 500 non-cached at full rate, 2000 cached at cache rate - expectedInput := 3.0 * 500 / 1_000_000 - expectedCached := 0.3 * 2000 / 1_000_000 - expectedOutput := 15.0 * 100 / 1_000_000 + // PromptTokens should be adjusted to include cached: 3 + 2022 = 2025 + if result.PromptTokens != 2025 { + t.Errorf("PromptTokens = %d, want 2025 (non-cached + cached)", result.PromptTokens) + } + + // 3 non-cached at full rate, 2022 cached at cache rate + expectedInput := 3.0 * 3 / 1_000_000 + expectedCached := 0.3 * 2022 / 1_000_000 + expectedOutput := 15.0 * 10 / 1_000_000 assertFloat(t, "InputCost", result.InputCost, expectedInput) assertFloat(t, "CachedInputCost", result.CachedInputCost, expectedCached) @@ -105,21 +113,26 @@ func TestCalculateCost_CacheUsageWithZeroTokens(t *testing.T) { assertFloat(t, "TotalCost", result.TotalCost, expectedInput+expectedOutput) } -func TestCalculateCost_CacheUsageExceedsPromptTokens(t *testing.T) { +func TestCalculateCost_CacheExceedsPromptTokens_AnthropicStyle(t *testing.T) { + // Anthropic style: prompt_tokens=10 (non-cached), cache_read=5000 (cached). + // cached > prompt means the provider reports non-cached only. 
costInfo := CostInfo{Input: 3.0, Output: 15.0, CacheRead: 1.5} - // More cached tokens than prompt tokens (shouldn't happen, but defensive) - cacheUsage := &CacheUsage{CachedTokens: 5000} + cacheUsage := &CacheUsage{CacheReadInputTokens: 5000} - result := CalculateCost("openai", "gpt-4o", costInfo, 1000, 500, cacheUsage) + result := CalculateCost("anthropic", "claude-sonnet-4", costInfo, 10, 500, cacheUsage) - // Cached tokens should be clamped to prompt tokens - if result.CachedTokens != 1000 { - t.Errorf("CachedTokens = %d, want 1000 (clamped)", result.CachedTokens) + if result.CachedTokens != 5000 { + t.Errorf("CachedTokens = %d, want 5000", result.CachedTokens) + } + // PromptTokens adjusted to true total: 10 + 5000 + if result.PromptTokens != 5010 { + t.Errorf("PromptTokens = %d, want 5010", result.PromptTokens) } - // All prompt tokens at cache rate, none at full rate - assertFloat(t, "InputCost", result.InputCost, 0) - expectedCached := 1.5 * 1000 / 1_000_000 + // 10 non-cached at full rate, 5000 cached at cache rate + expectedInput := 3.0 * 10 / 1_000_000 + expectedCached := 1.5 * 5000 / 1_000_000 + assertFloat(t, "InputCost", result.InputCost, expectedInput) assertFloat(t, "CachedInputCost", result.CachedInputCost, expectedCached) }