diff --git a/billing.go b/billing.go
index d20888d..b53ddea 100644
--- a/billing.go
+++ b/billing.go
@@ -29,19 +29,57 @@ type BillingResult struct {
 	PromptTokens int
 	// CompletionTokens is the number of output tokens.
 	CompletionTokens int
+	// CachedTokens is the number of prompt tokens served from cache.
+	CachedTokens int
 	// TotalTokens is the sum of prompt and completion tokens.
 	TotalTokens int
-	// InputCost is the calculated input cost in USD.
+	// InputCost is the calculated input cost in USD (non-cached prompt tokens).
 	InputCost float64
+	// CachedInputCost is the cost for cached prompt tokens in USD.
+	CachedInputCost float64
 	// OutputCost is the calculated output cost in USD.
 	OutputCost float64
-	// TotalCost is the sum of input and output cost in USD.
+	// TotalCost is the sum of all costs in USD.
 	TotalCost float64
 }
 
-// CalculateCost computes the billing result from cost info and token usage.
-func CalculateCost(provider, model string, costInfo CostInfo, promptTokens, completionTokens int) BillingResult {
-	inputCost := costInfo.Input * float64(promptTokens) / 1_000_000
+// CalculateCost computes the billing result from cost info, token usage, and cache usage.
+// Cached tokens are billed at the CacheRead rate (if available), and non-cached prompt
+// tokens are billed at the full Input rate.
+func CalculateCost(provider, model string, costInfo CostInfo, promptTokens, completionTokens int, cacheUsage *CacheUsage) BillingResult {
+	cachedTokens := 0
+	if cacheUsage != nil {
+		// Providers populate only one of these fields — OpenAI/Fireworks/Bedrock
+		// set CachedTokens while Anthropic sets CacheReadInputTokens. We sum them
+		// so the same code path works for any provider. If a provider ever sets
+		// both, the sum may overcount; the clamp below only caps it at promptTokens.
+		cachedTokens = cacheUsage.CachedTokens + cacheUsage.CacheReadInputTokens
+	}
+
+	// Clamp to [0, promptTokens] as a safety net — cached tokens can never
+	// exceed the prompt tokens the provider reported, nor be negative.
+	if cachedTokens > promptTokens {
+		cachedTokens = promptTokens
+	}
+	if cachedTokens < 0 {
+		cachedTokens = 0
+	}
+
+	nonCachedTokens := promptTokens - cachedTokens
+
+	// Non-cached prompt tokens at full input rate
+	inputCost := costInfo.Input * float64(nonCachedTokens) / 1_000_000
+
+	// Cached tokens at cache read rate (falls back to full input rate if no cache pricing)
+	var cachedInputCost float64
+	if cachedTokens > 0 {
+		cacheRate := costInfo.CacheRead
+		if cacheRate <= 0 {
+			cacheRate = costInfo.Input // fallback to full rate
+		}
+		cachedInputCost = cacheRate * float64(cachedTokens) / 1_000_000
+	}
+
 	outputCost := costInfo.Output * float64(completionTokens) / 1_000_000
 
 	return BillingResult{
@@ -49,9 +87,11 @@ func CalculateCost(provider, model string, costInfo CostInfo, promptTokens, comp
 		Model:            model,
 		PromptTokens:     promptTokens,
 		CompletionTokens: completionTokens,
+		CachedTokens:     cachedTokens,
 		TotalTokens:      promptTokens + completionTokens,
 		InputCost:        inputCost,
+		CachedInputCost:  cachedInputCost,
 		OutputCost:       outputCost,
-		TotalCost:        inputCost + outputCost,
+		TotalCost:        inputCost + cachedInputCost + outputCost,
 	}
 }
diff --git a/billing_test.go b/billing_test.go
new file mode 100644
index 0000000..de448d2
--- /dev/null
+++ b/billing_test.go
@@ -0,0 +1,210 @@
+package llmproxy
+
+import (
+	"math"
+	"testing"
+)
+
+const epsilon = 1e-9
+
+func assertFloat(t *testing.T, name string, got, want float64) {
+	t.Helper()
+	if math.Abs(got-want) > epsilon {
+		t.Errorf("%s = %f, want %f (diff: %e)", name, got, want, math.Abs(got-want))
+	}
+}
+
+func TestCalculateCost_NoCacheUsage(t *testing.T) {
+	costInfo := CostInfo{Input: 3.0, Output: 15.0, CacheRead: 1.5}
+	result := CalculateCost("openai", "gpt-4o", costInfo, 1000, 500, nil)
+
+	if result.Provider != "openai" {
+		t.Errorf("Provider = %q, want %q", result.Provider, "openai")
+	}
+	if result.Model != "gpt-4o" {
+		t.Errorf("Model = %q, want %q", result.Model, "gpt-4o")
+	}
+	if result.PromptTokens != 1000 {
+		t.Errorf("PromptTokens = %d, want 1000", result.PromptTokens)
+	}
+	if result.CompletionTokens != 500 {
+		t.Errorf("CompletionTokens = %d, want 500", result.CompletionTokens)
+	}
+	if result.CachedTokens != 0 {
+		t.Errorf("CachedTokens = %d, want 0", result.CachedTokens)
+	}
+	if result.TotalTokens != 1500 {
+		t.Errorf("TotalTokens = %d, want 1500", result.TotalTokens)
+	}
+
+	expectedInput := 3.0 * 1000 / 1_000_000
+	expectedOutput := 15.0 * 500 / 1_000_000
+	assertFloat(t, "InputCost", result.InputCost, expectedInput)
+	assertFloat(t, "CachedInputCost", result.CachedInputCost, 0)
+	assertFloat(t, "OutputCost", result.OutputCost, expectedOutput)
+	assertFloat(t, "TotalCost", result.TotalCost, expectedInput+expectedOutput)
+}
+
+func TestCalculateCost_WithOpenAICacheHit(t *testing.T) {
+	costInfo := CostInfo{Input: 3.0, Output: 15.0, CacheRead: 1.5}
+	cacheUsage := &CacheUsage{CachedTokens: 800}
+
+	result := CalculateCost("openai", "gpt-4o", costInfo, 1000, 500, cacheUsage)
+
+	if result.CachedTokens != 800 {
+		t.Errorf("CachedTokens = %d, want 800", result.CachedTokens)
+	}
+
+	// 200 non-cached at full rate, 800 cached at cache rate
+	expectedInput := 3.0 * 200 / 1_000_000
+	expectedCached := 1.5 * 800 / 1_000_000
+	expectedOutput := 15.0 * 500 / 1_000_000
+
+	assertFloat(t, "InputCost", result.InputCost, expectedInput)
+	assertFloat(t, "CachedInputCost", result.CachedInputCost, expectedCached)
+	assertFloat(t, "OutputCost", result.OutputCost, expectedOutput)
+	assertFloat(t, "TotalCost", result.TotalCost, expectedInput+expectedCached+expectedOutput)
+}
+
+func TestCalculateCost_WithAnthropicCacheHit(t *testing.T) {
+	costInfo := CostInfo{Input: 3.0, Output: 15.0, CacheRead: 0.3}
+	cacheUsage := &CacheUsage{CacheReadInputTokens: 2000}
+
+	result := CalculateCost("anthropic", "claude-sonnet-4", costInfo, 2500, 100, cacheUsage)
+
+	if result.CachedTokens != 2000 {
+		t.Errorf("CachedTokens = %d, want 2000", result.CachedTokens)
+	}
+
+	// 500 non-cached at full rate, 2000 cached at cache rate
+	expectedInput := 3.0 * 500 / 1_000_000
+	expectedCached := 0.3 * 2000 / 1_000_000
+	expectedOutput := 15.0 * 100 / 1_000_000
+
+	assertFloat(t, "InputCost", result.InputCost, expectedInput)
+	assertFloat(t, "CachedInputCost", result.CachedInputCost, expectedCached)
+	assertFloat(t, "TotalCost", result.TotalCost, expectedInput+expectedCached+expectedOutput)
+}
+
+func TestCalculateCost_CacheUsageWithZeroTokens(t *testing.T) {
+	costInfo := CostInfo{Input: 3.0, Output: 15.0, CacheRead: 1.5}
+	// CacheUsage present but all fields are zero
+	cacheUsage := &CacheUsage{}
+
+	result := CalculateCost("openai", "gpt-4o", costInfo, 1000, 500, cacheUsage)
+
+	// Should behave exactly like no cache usage
+	if result.CachedTokens != 0 {
+		t.Errorf("CachedTokens = %d, want 0", result.CachedTokens)
+	}
+	assertFloat(t, "CachedInputCost", result.CachedInputCost, 0)
+
+	expectedInput := 3.0 * 1000 / 1_000_000
+	expectedOutput := 15.0 * 500 / 1_000_000
+	assertFloat(t, "InputCost", result.InputCost, expectedInput)
+	assertFloat(t, "TotalCost", result.TotalCost, expectedInput+expectedOutput)
+}
+
+func TestCalculateCost_CacheUsageExceedsPromptTokens(t *testing.T) {
+	costInfo := CostInfo{Input: 3.0, Output: 15.0, CacheRead: 1.5}
+	// More cached tokens than prompt tokens (shouldn't happen, but defensive)
+	cacheUsage := &CacheUsage{CachedTokens: 5000}
+
+	result := CalculateCost("openai", "gpt-4o", costInfo, 1000, 500, cacheUsage)
+
+	// Cached tokens should be clamped to prompt tokens
+	if result.CachedTokens != 1000 {
+		t.Errorf("CachedTokens = %d, want 1000 (clamped)", result.CachedTokens)
+	}
+
+	// All prompt tokens at cache rate, none at full rate
+	assertFloat(t, "InputCost", result.InputCost, 0)
+	expectedCached := 1.5 * 1000 / 1_000_000
+	assertFloat(t, "CachedInputCost", result.CachedInputCost, expectedCached)
+}
+
+func TestCalculateCost_NoCacheReadPrice(t *testing.T) {
+	// Provider doesn't have cache pricing — should fall back to full input rate
+	costInfo := CostInfo{Input: 3.0, Output: 15.0}
+	cacheUsage := &CacheUsage{CachedTokens: 800}
+
+	result := CalculateCost("groq", "llama-3.3-70b", costInfo, 1000, 500, cacheUsage)
+
+	if result.CachedTokens != 800 {
+		t.Errorf("CachedTokens = %d, want 800", result.CachedTokens)
+	}
+
+	// Cached tokens should fall back to full input rate
+	expectedInput := 3.0 * 200 / 1_000_000
+	expectedCached := 3.0 * 800 / 1_000_000 // same as input rate
+	expectedOutput := 15.0 * 500 / 1_000_000
+
+	assertFloat(t, "InputCost", result.InputCost, expectedInput)
+	assertFloat(t, "CachedInputCost", result.CachedInputCost, expectedCached)
+	assertFloat(t, "TotalCost", result.TotalCost, expectedInput+expectedCached+expectedOutput)
+}
+
+func TestCalculateCost_AllTokensCached(t *testing.T) {
+	costInfo := CostInfo{Input: 3.0, Output: 15.0, CacheRead: 1.5}
+	cacheUsage := &CacheUsage{CachedTokens: 1000}
+
+	result := CalculateCost("openai", "gpt-4o", costInfo, 1000, 500, cacheUsage)
+
+	if result.CachedTokens != 1000 {
+		t.Errorf("CachedTokens = %d, want 1000", result.CachedTokens)
+	}
+
+	// All prompt tokens cached — zero non-cached input cost
+	assertFloat(t, "InputCost", result.InputCost, 0)
+	expectedCached := 1.5 * 1000 / 1_000_000
+	assertFloat(t, "CachedInputCost", result.CachedInputCost, expectedCached)
+}
+
+func TestCalculateCost_ZeroTokens(t *testing.T) {
+	costInfo := CostInfo{Input: 3.0, Output: 15.0, CacheRead: 1.5}
+	result := CalculateCost("openai", "gpt-4o", costInfo, 0, 0, nil)
+
+	assertFloat(t, "InputCost", result.InputCost, 0)
+	assertFloat(t, "CachedInputCost", result.CachedInputCost, 0)
+	assertFloat(t, "OutputCost", result.OutputCost, 0)
+	assertFloat(t, "TotalCost", result.TotalCost, 0)
+}
+
+func TestCalculateCost_MixedProviderCacheFields(t *testing.T) {
+	// Both CachedTokens and CacheReadInputTokens set (shouldn't happen, but test summing)
+	costInfo := CostInfo{Input: 3.0, Output: 15.0, CacheRead: 1.5}
+	cacheUsage := &CacheUsage{
+		CachedTokens:         300,
+		CacheReadInputTokens: 200,
+	}
+
+	result := CalculateCost("test", "model", costInfo, 1000, 100, cacheUsage)
+
+	// Should sum both fields: 300 + 200 = 500
+	if result.CachedTokens != 500 {
+		t.Errorf("CachedTokens = %d, want 500", result.CachedTokens)
+	}
+
+	expectedInput := 3.0 * 500 / 1_000_000
+	expectedCached := 1.5 * 500 / 1_000_000
+	assertFloat(t, "InputCost", result.InputCost, expectedInput)
+	assertFloat(t, "CachedInputCost", result.CachedInputCost, expectedCached)
+}
+
+func TestCalculateCost_NegativeCachedTokens(t *testing.T) {
+	costInfo := CostInfo{Input: 3.0, Output: 15.0, CacheRead: 1.5}
+	// Negative cached tokens (malformed provider data) must not inflate billing
+	cacheUsage := &CacheUsage{CachedTokens: -50}
+
+	result := CalculateCost("openai", "gpt-4o", costInfo, 1000, 500, cacheUsage)
+
+	// Cached tokens should be clamped to zero
+	if result.CachedTokens != 0 {
+		t.Errorf("CachedTokens = %d, want 0 (clamped)", result.CachedTokens)
+	}
+
+	// All prompt tokens at full rate, none at cache rate
+	expectedInput := 3.0 * 1000 / 1_000_000
+	assertFloat(t, "InputCost", result.InputCost, expectedInput)
+	assertFloat(t, "CachedInputCost", result.CachedInputCost, 0)
+}
diff --git a/interceptors/billing.go b/interceptors/billing.go
index ab93472..56d8419 100644
--- a/interceptors/billing.go
+++ b/interceptors/billing.go
@@ -36,7 +36,17 @@ func (i *BillingInterceptor) Intercept(req *http.Request, meta llmproxy.BodyMeta
 	}
 
 	if found && i.OnResult != nil {
-		result := llmproxy.CalculateCost(provider, meta.Model, costInfo, respMeta.Usage.PromptTokens, respMeta.Usage.CompletionTokens)
+		// Extract cache usage from response metadata if available. Both the
+		// value and pointer forms are accepted so that a pointer stored in
+		// Custom is not silently ignored and billed at the full input rate.
+		var cacheUsage *llmproxy.CacheUsage
+		switch cu := respMeta.Custom["cache_usage"].(type) {
+		case llmproxy.CacheUsage:
+			cacheUsage = &cu
+		case *llmproxy.CacheUsage:
+			cacheUsage = cu
+		}
+		result := llmproxy.CalculateCost(provider, meta.Model, costInfo, respMeta.Usage.PromptTokens, respMeta.Usage.CompletionTokens, cacheUsage)
 		i.OnResult(result)
 	}
 
diff --git a/interceptors/billing_test.go b/interceptors/billing_test.go
index 2f3e174..c8f6470 100644
--- a/interceptors/billing_test.go
+++ b/interceptors/billing_test.go
@@ -119,6 +119,201 @@ func TestBillingInterceptor_ErrorPassthrough(t *testing.T) {
 	}
 }
 
+func TestBillingInterceptor_WithCacheUsage(t *testing.T) {
+	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("Content-Type", "application/json")
+		w.WriteHeader(http.StatusOK)
+		w.Write([]byte(`{}`))
+	}))
+	defer upstream.Close()
+
+	var result llmproxy.BillingResult
+	lookup := func(provider, model string) (llmproxy.CostInfo, bool) {
+		return llmproxy.CostInfo{Input: 3.0, Output: 15.0, CacheRead: 1.5}, true
+	}
+
+	billing := NewBilling(lookup, func(r llmproxy.BillingResult) {
+		result = r
+	})
+
+	req, _ := http.NewRequest("POST", upstream.URL, nil)
+	meta := llmproxy.BodyMetadata{Model: "gpt-4o"}
+
+	next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) {
+		resp, err := http.DefaultClient.Do(req)
+		if err != nil {
+			return nil, llmproxy.ResponseMetadata{}, nil, err
+		}
+		respMeta := llmproxy.ResponseMetadata{
+			Usage:  llmproxy.Usage{PromptTokens: 2000, CompletionTokens: 100, TotalTokens: 2100},
+			Custom: map[string]any{"cache_usage": llmproxy.CacheUsage{CachedTokens: 1920}},
+		}
+		return resp, respMeta, nil, nil
+	}
+
+	_, _, _, err := billing.Intercept(req, meta, nil, next)
+	if err != nil {
+		t.Fatalf("Intercept returned error: %v", err)
+	}
+
+	if result.CachedTokens != 1920 {
+		t.Errorf("CachedTokens = %d, want 1920", result.CachedTokens)
+	}
+
+	// 80 non-cached at $3/M, 1920 cached at $1.5/M
+	expectedInput := 3.0 * 80 / 1_000_000
+	expectedCached := 1.5 * 1920 / 1_000_000
+	expectedOutput := 15.0 * 100 / 1_000_000
+
+	if math.Abs(result.InputCost-expectedInput) > 1e-9 {
+		t.Errorf("InputCost = %f, want %f", result.InputCost, expectedInput)
+	}
+	if math.Abs(result.CachedInputCost-expectedCached) > 1e-9 {
+		t.Errorf("CachedInputCost = %f, want %f", result.CachedInputCost, expectedCached)
+	}
+	if math.Abs(result.TotalCost-(expectedInput+expectedCached+expectedOutput)) > 1e-9 {
+		t.Errorf("TotalCost = %f, want %f", result.TotalCost, expectedInput+expectedCached+expectedOutput)
+	}
+}
+
+func TestBillingInterceptor_CacheUsageZeroTokens(t *testing.T) {
+	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(http.StatusOK)
+	}))
+	defer upstream.Close()
+
+	var result llmproxy.BillingResult
+	lookup := func(provider, model string) (llmproxy.CostInfo, bool) {
+		return llmproxy.CostInfo{Input: 3.0, Output: 15.0, CacheRead: 1.5}, true
+	}
+
+	billing := NewBilling(lookup, func(r llmproxy.BillingResult) {
+		result = r
+	})
+
+	req, _ := http.NewRequest("POST", upstream.URL, nil)
+	meta := llmproxy.BodyMetadata{Model: "gpt-4o"}
+
+	next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) {
+		resp, err := http.DefaultClient.Do(req)
+		if err != nil {
+			return nil, llmproxy.ResponseMetadata{}, nil, err
+		}
+		// Cache usage present but zero tokens
+		respMeta := llmproxy.ResponseMetadata{
+			Usage:  llmproxy.Usage{PromptTokens: 1000, CompletionTokens: 50},
+			Custom: map[string]any{"cache_usage": llmproxy.CacheUsage{}},
+		}
+		return resp, respMeta, nil, nil
+	}
+
+	_, _, _, err := billing.Intercept(req, meta, nil, next)
+	if err != nil {
+		t.Fatalf("Intercept returned error: %v", err)
+	}
+
+	if result.CachedTokens != 0 {
+		t.Errorf("CachedTokens = %d, want 0", result.CachedTokens)
+	}
+	if result.CachedInputCost != 0 {
+		t.Errorf("CachedInputCost = %f, want 0", result.CachedInputCost)
+	}
+	// All tokens at full input rate
+	expectedInput := 3.0 * 1000 / 1_000_000
+	if math.Abs(result.InputCost-expectedInput) > 1e-9 {
+		t.Errorf("InputCost = %f, want %f", result.InputCost, expectedInput)
+	}
+}
+
+func TestBillingInterceptor_NoCacheUsageInCustom(t *testing.T) {
+	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(http.StatusOK)
+	}))
+	defer upstream.Close()
+
+	var result llmproxy.BillingResult
+	lookup := func(provider, model string) (llmproxy.CostInfo, bool) {
+		return llmproxy.CostInfo{Input: 3.0, Output: 15.0, CacheRead: 1.5}, true
+	}
+
+	billing := NewBilling(lookup, func(r llmproxy.BillingResult) {
+		result = r
+	})
+
+	req, _ := http.NewRequest("POST", upstream.URL, nil)
+	meta := llmproxy.BodyMetadata{Model: "gpt-4o"}
+
+	next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) {
+		resp, err := http.DefaultClient.Do(req)
+		if err != nil {
+			return nil, llmproxy.ResponseMetadata{}, nil, err
+		}
+		// No Custom map at all
+		respMeta := llmproxy.ResponseMetadata{
+			Usage: llmproxy.Usage{PromptTokens: 1000, CompletionTokens: 50},
+		}
+		return resp, respMeta, nil, nil
+	}
+
+	_, _, _, err := billing.Intercept(req, meta, nil, next)
+	if err != nil {
+		t.Fatalf("Intercept returned error: %v", err)
+	}
+
+	if result.CachedTokens != 0 {
+		t.Errorf("CachedTokens = %d, want 0", result.CachedTokens)
+	}
+	expectedInput := 3.0 * 1000 / 1_000_000
+	if math.Abs(result.InputCost-expectedInput) > 1e-9 {
+		t.Errorf("InputCost = %f, want %f", result.InputCost, expectedInput)
+	}
+}
+
+func TestBillingInterceptor_PointerCacheUsage(t *testing.T) {
+	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(http.StatusOK)
+	}))
+	defer upstream.Close()
+
+	var result llmproxy.BillingResult
+	lookup := func(provider, model string) (llmproxy.CostInfo, bool) {
+		return llmproxy.CostInfo{Input: 3.0, Output: 15.0, CacheRead: 1.5}, true
+	}
+
+	billing := NewBilling(lookup, func(r llmproxy.BillingResult) {
+		result = r
+	})
+
+	req, _ := http.NewRequest("POST", upstream.URL, nil)
+	meta := llmproxy.BodyMetadata{Model: "gpt-4o"}
+
+	next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) {
+		resp, err := http.DefaultClient.Do(req)
+		if err != nil {
+			return nil, llmproxy.ResponseMetadata{}, nil, err
+		}
+		// Cache usage stored as a pointer rather than a value
+		respMeta := llmproxy.ResponseMetadata{
+			Usage:  llmproxy.Usage{PromptTokens: 1000, CompletionTokens: 50},
+			Custom: map[string]any{"cache_usage": &llmproxy.CacheUsage{CachedTokens: 600}},
+		}
+		return resp, respMeta, nil, nil
+	}
+
+	_, _, _, err := billing.Intercept(req, meta, nil, next)
+	if err != nil {
+		t.Fatalf("Intercept returned error: %v", err)
+	}
+
+	if result.CachedTokens != 600 {
+		t.Errorf("CachedTokens = %d, want 600", result.CachedTokens)
+	}
+	expectedCached := 1.5 * 600 / 1_000_000
+	if math.Abs(result.CachedInputCost-expectedCached) > 1e-9 {
+		t.Errorf("CachedInputCost = %f, want %f", result.CachedInputCost, expectedCached)
+	}
+}
+
 func TestDetectProvider(t *testing.T) {
 	tests := []struct {
 		model string