From ed83477046dbbd877aff489bc1bb76691fd6cbde Mon Sep 17 00:00:00 2001 From: Jeff Haynie Date: Sun, 12 Apr 2026 23:11:40 -0500 Subject: [PATCH 1/3] feat: split billing for cached vs non-cached prompt tokens CalculateCost now accepts CacheUsage and bills cached tokens at the CacheRead rate instead of the full Input rate. This correctly reflects the cost savings from prompt caching (typically 50-90% cheaper). Changes: - BillingResult: added CachedTokens and CachedInputCost fields - CalculateCost: accepts *CacheUsage, splits cached/non-cached pricing - BillingInterceptor: extracts cache_usage from ResponseMetadata.Custom and passes it to CalculateCost --- billing.go | 47 +++++++++++++++++++++++++++++++++++------ interceptors/billing.go | 9 +++++++- 2 files changed, 49 insertions(+), 7 deletions(-) diff --git a/billing.go b/billing.go index d20888d..4ec52b4 100644 --- a/billing.go +++ b/billing.go @@ -29,19 +29,52 @@ type BillingResult struct { PromptTokens int // CompletionTokens is the number of output tokens. CompletionTokens int + // CachedTokens is the number of prompt tokens served from cache. + CachedTokens int // TotalTokens is the sum of prompt and completion tokens. TotalTokens int - // InputCost is the calculated input cost in USD. + // InputCost is the calculated input cost in USD (non-cached prompt tokens). InputCost float64 + // CachedInputCost is the cost for cached prompt tokens in USD. + CachedInputCost float64 // OutputCost is the calculated output cost in USD. OutputCost float64 - // TotalCost is the sum of input and output cost in USD. + // TotalCost is the sum of all costs in USD. TotalCost float64 } -// CalculateCost computes the billing result from cost info and token usage. -func CalculateCost(provider, model string, costInfo CostInfo, promptTokens, completionTokens int) BillingResult { - inputCost := costInfo.Input * float64(promptTokens) / 1_000_000 +// CalculateCost computes the billing result from cost info, token usage, and cache usage. +// Cached tokens are billed at the CacheRead rate (if available), and non-cached prompt +// tokens are billed at the full Input rate. 
+func CalculateCost(provider, model string, costInfo CostInfo, promptTokens, completionTokens int, cacheUsage *CacheUsage) BillingResult { + cachedTokens := 0 + if cacheUsage != nil { + // Normalize cached token count across providers: + // - OpenAI/Fireworks/Bedrock: CachedTokens + // - Anthropic: CacheReadInputTokens + cachedTokens = cacheUsage.CachedTokens + cacheUsage.CacheReadInputTokens + } + + // Ensure cached tokens don't exceed prompt tokens + if cachedTokens > promptTokens { + cachedTokens = promptTokens + } + + nonCachedTokens := promptTokens - cachedTokens + + // Non-cached prompt tokens at full input rate + inputCost := costInfo.Input * float64(nonCachedTokens) / 1_000_000 + + // Cached tokens at cache read rate (falls back to full input rate if no cache pricing) + var cachedInputCost float64 + if cachedTokens > 0 { + cacheRate := costInfo.CacheRead + if cacheRate <= 0 { + cacheRate = costInfo.Input // fallback to full rate + } + cachedInputCost = cacheRate * float64(cachedTokens) / 1_000_000 + } + outputCost := costInfo.Output * float64(completionTokens) / 1_000_000 return BillingResult{ @@ -49,9 +82,11 @@ func CalculateCost(provider, model string, costInfo CostInfo, promptTokens, comp Model: model, PromptTokens: promptTokens, CompletionTokens: completionTokens, + CachedTokens: cachedTokens, TotalTokens: promptTokens + completionTokens, InputCost: inputCost, + CachedInputCost: cachedInputCost, OutputCost: outputCost, - TotalCost: inputCost + outputCost, + TotalCost: inputCost + cachedInputCost + outputCost, } } diff --git a/interceptors/billing.go b/interceptors/billing.go index ab93472..56d8419 100644 --- a/interceptors/billing.go +++ b/interceptors/billing.go @@ -36,7 +36,14 @@ func (i *BillingInterceptor) Intercept(req *http.Request, meta llmproxy.BodyMeta } if found && i.OnResult != nil { - result := llmproxy.CalculateCost(provider, meta.Model, costInfo, respMeta.Usage.PromptTokens, respMeta.Usage.CompletionTokens) + // Extract cache usage from response metadata if available + var cacheUsage *llmproxy.CacheUsage + if cu, ok := respMeta.Custom["cache_usage"]; ok { + if usage, ok := cu.(llmproxy.CacheUsage); ok { + cacheUsage = &usage + } + } + result := llmproxy.CalculateCost(provider, meta.Model, costInfo, respMeta.Usage.PromptTokens, respMeta.Usage.CompletionTokens, cacheUsage) i.OnResult(result) } From 0db9e515bb88349930d32ca54cece3501511a6fa Mon Sep 17 00:00:00 2001 From: Jeff Haynie Date: Sun, 12 Apr 2026 23:14:26 -0500 Subject: [PATCH 2/3] test: add comprehensive unit tests for cached token billing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Tests cover: - No cache usage (nil) — all tokens at full rate - OpenAI cache hit (CachedTokens field) - Anthropic cache hit (CacheReadInputTokens field) - Cache usage present but zero tokens — treated as no caching - Cached tokens exceeding prompt tokens — clamped - No CacheRead price — falls back to full input rate - All tokens cached — zero non-cached cost - Zero tokens — zero cost - Mixed provider cache fields (summed) - Interceptor extracting cache_usage from ResponseMetadata.Custom - Interceptor with nil/empty Custom map --- billing_test.go | 192 +++++++++++++++++++++++++++++++++++ interceptors/billing_test.go | 150 +++++++++++++++++++++++++++ 2 files changed, 342 insertions(+) create mode 100644 billing_test.go diff --git a/billing_test.go b/billing_test.go new file mode 100644 index 0000000..de448d2 --- /dev/null +++ b/billing_test.go @@ -0,0 +1,192 @@ +package llmproxy + 
+import ( + "math" + "testing" +) + +const epsilon = 1e-9 + +func assertFloat(t *testing.T, name string, got, want float64) { + t.Helper() + if math.Abs(got-want) > epsilon { + t.Errorf("%s = %f, want %f (diff: %e)", name, got, want, math.Abs(got-want)) + } +} + +func TestCalculateCost_NoCacheUsage(t *testing.T) { + costInfo := CostInfo{Input: 3.0, Output: 15.0, CacheRead: 1.5} + result := CalculateCost("openai", "gpt-4o", costInfo, 1000, 500, nil) + + if result.Provider != "openai" { + t.Errorf("Provider = %q, want %q", result.Provider, "openai") + } + if result.Model != "gpt-4o" { + t.Errorf("Model = %q, want %q", result.Model, "gpt-4o") + } + if result.PromptTokens != 1000 { + t.Errorf("PromptTokens = %d, want 1000", result.PromptTokens) + } + if result.CompletionTokens != 500 { + t.Errorf("CompletionTokens = %d, want 500", result.CompletionTokens) + } + if result.CachedTokens != 0 { + t.Errorf("CachedTokens = %d, want 0", result.CachedTokens) + } + if result.TotalTokens != 1500 { + t.Errorf("TotalTokens = %d, want 1500", result.TotalTokens) + } + + expectedInput := 3.0 * 1000 / 1_000_000 + expectedOutput := 15.0 * 500 / 1_000_000 + assertFloat(t, "InputCost", result.InputCost, expectedInput) + assertFloat(t, "CachedInputCost", result.CachedInputCost, 0) + assertFloat(t, "OutputCost", result.OutputCost, expectedOutput) + assertFloat(t, "TotalCost", result.TotalCost, expectedInput+expectedOutput) +} + +func TestCalculateCost_WithOpenAICacheHit(t *testing.T) { + costInfo := CostInfo{Input: 3.0, Output: 15.0, CacheRead: 1.5} + cacheUsage := &CacheUsage{CachedTokens: 800} + + result := CalculateCost("openai", "gpt-4o", costInfo, 1000, 500, cacheUsage) + + if result.CachedTokens != 800 { + t.Errorf("CachedTokens = %d, want 800", result.CachedTokens) + } + + // 200 non-cached at full rate, 800 cached at cache rate + expectedInput := 3.0 * 200 / 1_000_000 + expectedCached := 1.5 * 800 / 1_000_000 + expectedOutput := 15.0 * 500 / 1_000_000 + + assertFloat(t, "InputCost", result.InputCost, expectedInput) + assertFloat(t, "CachedInputCost", result.CachedInputCost, expectedCached) + assertFloat(t, "OutputCost", result.OutputCost, expectedOutput) + assertFloat(t, "TotalCost", result.TotalCost, expectedInput+expectedCached+expectedOutput) +} + +func TestCalculateCost_WithAnthropicCacheHit(t *testing.T) { + costInfo := CostInfo{Input: 3.0, Output: 15.0, CacheRead: 0.3} + cacheUsage := &CacheUsage{CacheReadInputTokens: 2000} + + result := CalculateCost("anthropic", "claude-sonnet-4", costInfo, 2500, 100, cacheUsage) + + if result.CachedTokens != 2000 { + t.Errorf("CachedTokens = %d, want 2000", result.CachedTokens) + } + + // 500 non-cached at full rate, 2000 cached at cache rate + expectedInput := 3.0 * 500 / 1_000_000 + expectedCached := 0.3 * 2000 / 1_000_000 + expectedOutput := 15.0 * 100 / 1_000_000 + + assertFloat(t, "InputCost", result.InputCost, expectedInput) + assertFloat(t, "CachedInputCost", result.CachedInputCost, expectedCached) + assertFloat(t, "TotalCost", result.TotalCost, expectedInput+expectedCached+expectedOutput) +} + +func TestCalculateCost_CacheUsageWithZeroTokens(t *testing.T) { + costInfo := CostInfo{Input: 3.0, Output: 15.0, CacheRead: 1.5} + // CacheUsage present but all fields are zero + cacheUsage := &CacheUsage{} + + result := CalculateCost("openai", "gpt-4o", costInfo, 1000, 500, cacheUsage) + + // Should behave exactly like no cache usage + if result.CachedTokens != 0 { + t.Errorf("CachedTokens = %d, want 0", result.CachedTokens) + } + assertFloat(t, "CachedInputCost", 
result.CachedInputCost, 0) + + expectedInput := 3.0 * 1000 / 1_000_000 + expectedOutput := 15.0 * 500 / 1_000_000 + assertFloat(t, "InputCost", result.InputCost, expectedInput) + assertFloat(t, "TotalCost", result.TotalCost, expectedInput+expectedOutput) +} + +func TestCalculateCost_CacheUsageExceedsPromptTokens(t *testing.T) { + costInfo := CostInfo{Input: 3.0, Output: 15.0, CacheRead: 1.5} + // More cached tokens than prompt tokens (shouldn't happen, but defensive) + cacheUsage := &CacheUsage{CachedTokens: 5000} + + result := CalculateCost("openai", "gpt-4o", costInfo, 1000, 500, cacheUsage) + + // Cached tokens should be clamped to prompt tokens + if result.CachedTokens != 1000 { + t.Errorf("CachedTokens = %d, want 1000 (clamped)", result.CachedTokens) + } + + // All prompt tokens at cache rate, none at full rate + assertFloat(t, "InputCost", result.InputCost, 0) + expectedCached := 1.5 * 1000 / 1_000_000 + assertFloat(t, "CachedInputCost", result.CachedInputCost, expectedCached) +} + +func TestCalculateCost_NoCacheReadPrice(t *testing.T) { + // Provider doesn't have cache pricing — should fall back to full input rate + costInfo := CostInfo{Input: 3.0, Output: 15.0} + cacheUsage := &CacheUsage{CachedTokens: 800} + + result := CalculateCost("groq", "llama-3.3-70b", costInfo, 1000, 500, cacheUsage) + + if result.CachedTokens != 800 { + t.Errorf("CachedTokens = %d, want 800", result.CachedTokens) + } + + // Cached tokens should fall back to full input rate + expectedInput := 3.0 * 200 / 1_000_000 + expectedCached := 3.0 * 800 / 1_000_000 // same as input rate + expectedOutput := 15.0 * 500 / 1_000_000 + + assertFloat(t, "InputCost", result.InputCost, expectedInput) + assertFloat(t, "CachedInputCost", result.CachedInputCost, expectedCached) + assertFloat(t, "TotalCost", result.TotalCost, expectedInput+expectedCached+expectedOutput) +} + +func TestCalculateCost_AllTokensCached(t *testing.T) { + costInfo := CostInfo{Input: 3.0, Output: 15.0, CacheRead: 1.5} + cacheUsage := &CacheUsage{CachedTokens: 1000} + + result := CalculateCost("openai", "gpt-4o", costInfo, 1000, 500, cacheUsage) + + if result.CachedTokens != 1000 { + t.Errorf("CachedTokens = %d, want 1000", result.CachedTokens) + } + + // All prompt tokens cached — zero non-cached input cost + assertFloat(t, "InputCost", result.InputCost, 0) + expectedCached := 1.5 * 1000 / 1_000_000 + assertFloat(t, "CachedInputCost", result.CachedInputCost, expectedCached) +} + +func TestCalculateCost_ZeroTokens(t *testing.T) { + costInfo := CostInfo{Input: 3.0, Output: 15.0, CacheRead: 1.5} + result := CalculateCost("openai", "gpt-4o", costInfo, 0, 0, nil) + + assertFloat(t, "InputCost", result.InputCost, 0) + assertFloat(t, "CachedInputCost", result.CachedInputCost, 0) + assertFloat(t, "OutputCost", result.OutputCost, 0) + assertFloat(t, "TotalCost", result.TotalCost, 0) +} + +func TestCalculateCost_MixedProviderCacheFields(t *testing.T) { + // Both CachedTokens and CacheReadInputTokens set (shouldn't happen, but test summing) + costInfo := CostInfo{Input: 3.0, Output: 15.0, CacheRead: 1.5} + cacheUsage := &CacheUsage{ + CachedTokens: 300, + CacheReadInputTokens: 200, + } + + result := CalculateCost("test", "model", costInfo, 1000, 100, cacheUsage) + + // Should sum both fields: 300 + 200 = 500 + if result.CachedTokens != 500 { + t.Errorf("CachedTokens = %d, want 500", result.CachedTokens) + } + + expectedInput := 3.0 * 500 / 1_000_000 + expectedCached := 1.5 * 500 / 1_000_000 + assertFloat(t, "InputCost", result.InputCost, expectedInput) + 
assertFloat(t, "CachedInputCost", result.CachedInputCost, expectedCached) +} diff --git a/interceptors/billing_test.go b/interceptors/billing_test.go index 2f3e174..c8f6470 100644 --- a/interceptors/billing_test.go +++ b/interceptors/billing_test.go @@ -119,6 +119,156 @@ func TestBillingInterceptor_ErrorPassthrough(t *testing.T) { } } +func TestBillingInterceptor_WithCacheUsage(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + var result llmproxy.BillingResult + lookup := func(provider, model string) (llmproxy.CostInfo, bool) { + return llmproxy.CostInfo{Input: 3.0, Output: 15.0, CacheRead: 1.5}, true + } + + billing := NewBilling(lookup, func(r llmproxy.BillingResult) { + result = r + }) + + req, _ := http.NewRequest("POST", upstream.URL, nil) + meta := llmproxy.BodyMetadata{Model: "gpt-4o"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + respMeta := llmproxy.ResponseMetadata{ + Usage: llmproxy.Usage{PromptTokens: 2000, CompletionTokens: 100, TotalTokens: 2100}, + Custom: map[string]any{"cache_usage": llmproxy.CacheUsage{CachedTokens: 1920}}, + } + return resp, respMeta, nil, nil + } + + _, _, _, err := billing.Intercept(req, meta, nil, next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } + + if result.CachedTokens != 1920 { + t.Errorf("CachedTokens = %d, want 1920", result.CachedTokens) + } + + // 80 non-cached at $3/M, 1920 cached at $1.5/M + expectedInput := 3.0 * 80 / 1_000_000 + expectedCached := 1.5 * 1920 / 1_000_000 + expectedOutput := 15.0 * 100 / 1_000_000 + + if math.Abs(result.InputCost-expectedInput) > 1e-9 { + t.Errorf("InputCost = %f, want %f", result.InputCost, expectedInput) + } + if math.Abs(result.CachedInputCost-expectedCached) > 1e-9 { + t.Errorf("CachedInputCost = %f, want %f", result.CachedInputCost, expectedCached) + } + if math.Abs(result.TotalCost-(expectedInput+expectedCached+expectedOutput)) > 1e-9 { + t.Errorf("TotalCost = %f, want %f", result.TotalCost, expectedInput+expectedCached+expectedOutput) + } +} + +func TestBillingInterceptor_CacheUsageZeroTokens(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer upstream.Close() + + var result llmproxy.BillingResult + lookup := func(provider, model string) (llmproxy.CostInfo, bool) { + return llmproxy.CostInfo{Input: 3.0, Output: 15.0, CacheRead: 1.5}, true + } + + billing := NewBilling(lookup, func(r llmproxy.BillingResult) { + result = r + }) + + req, _ := http.NewRequest("POST", upstream.URL, nil) + meta := llmproxy.BodyMetadata{Model: "gpt-4o"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + // Cache usage present but zero tokens + respMeta := llmproxy.ResponseMetadata{ + Usage: llmproxy.Usage{PromptTokens: 1000, CompletionTokens: 50}, + Custom: map[string]any{"cache_usage": llmproxy.CacheUsage{}}, + } + return resp, respMeta, nil, nil + } + + _, _, _, err := billing.Intercept(req, meta, nil, next) + if err != nil { + t.Fatalf("Intercept returned error: 
%v", err) + } + + if result.CachedTokens != 0 { + t.Errorf("CachedTokens = %d, want 0", result.CachedTokens) + } + if result.CachedInputCost != 0 { + t.Errorf("CachedInputCost = %f, want 0", result.CachedInputCost) + } + // All tokens at full input rate + expectedInput := 3.0 * 1000 / 1_000_000 + if math.Abs(result.InputCost-expectedInput) > 1e-9 { + t.Errorf("InputCost = %f, want %f", result.InputCost, expectedInput) + } +} + +func TestBillingInterceptor_NoCacheUsageInCustom(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer upstream.Close() + + var result llmproxy.BillingResult + lookup := func(provider, model string) (llmproxy.CostInfo, bool) { + return llmproxy.CostInfo{Input: 3.0, Output: 15.0, CacheRead: 1.5}, true + } + + billing := NewBilling(lookup, func(r llmproxy.BillingResult) { + result = r + }) + + req, _ := http.NewRequest("POST", upstream.URL, nil) + meta := llmproxy.BodyMetadata{Model: "gpt-4o"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + // No Custom map at all + respMeta := llmproxy.ResponseMetadata{ + Usage: llmproxy.Usage{PromptTokens: 1000, CompletionTokens: 50}, + } + return resp, respMeta, nil, nil + } + + _, _, _, err := billing.Intercept(req, meta, nil, next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } + + if result.CachedTokens != 0 { + t.Errorf("CachedTokens = %d, want 0", result.CachedTokens) + } + expectedInput := 3.0 * 1000 / 1_000_000 + if math.Abs(result.InputCost-expectedInput) > 1e-9 { + t.Errorf("InputCost = %f, want %f", result.InputCost, expectedInput) + } +} + func TestDetectProvider(t *testing.T) { tests := []struct { model string From b984fbe3de07d6bc9409c9d8f0f2ee43daf9e9e0 Mon Sep 17 00:00:00 2001 From: Jeff Haynie Date: Sun, 12 Apr 2026 23:15:06 -0500 Subject: [PATCH 3/3] docs: clarify mutual-exclusivity assumption and clamping rationale in CalculateCost --- billing.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/billing.go b/billing.go index 4ec52b4..b53ddea 100644 --- a/billing.go +++ b/billing.go @@ -49,13 +49,15 @@ type BillingResult struct { func CalculateCost(provider, model string, costInfo CostInfo, promptTokens, completionTokens int, cacheUsage *CacheUsage) BillingResult { cachedTokens := 0 if cacheUsage != nil { - // Normalize cached token count across providers: - // - OpenAI/Fireworks/Bedrock: CachedTokens - // - Anthropic: CacheReadInputTokens + // Providers populate only one of these fields — OpenAI/Fireworks/Bedrock + // set CachedTokens while Anthropic sets CacheReadInputTokens. We sum them + // so the same code path works for any provider. The clamp below guards + // against overcounting if a future provider ever sets both fields. cachedTokens = cacheUsage.CachedTokens + cacheUsage.CacheReadInputTokens } - // Ensure cached tokens don't exceed prompt tokens + // Clamp to prompt tokens as a safety net — cached tokens can never + // exceed the total prompt tokens the provider reported. if cachedTokens > promptTokens { cachedTokens = promptTokens }