diff --git a/DESIGN.md b/DESIGN.md index 6d5ef96..69b6759 100644 --- a/DESIGN.md +++ b/DESIGN.md @@ -119,6 +119,25 @@ The main entry point. `Forward(ctx, req)` orchestrates the full request lifecycl | completionTokens | int | Output tokens generated | | totalTokens | int | Sum of prompt + completion | +### CacheUsage + +| Field | Type | Description | +|-------|------|-------------| +| cachedTokens | int | Tokens served from cache (OpenAI, Azure) | +| cacheCreationInputTokens | int | Tokens written to cache (Anthropic) | +| cacheReadInputTokens | int | Tokens read from cache (Anthropic) | +| ephemeral5mInputTokens | int | 5-minute cache write tokens (Anthropic) | +| ephemeral1hInputTokens | int | 1-hour cache write tokens (Anthropic) | +| cacheWriteTokens | int | Tokens written to cache (Bedrock) | +| cacheDetails | []CacheDetail | TTL-based cache write breakdown (Bedrock) | + +### CacheDetail + +| Field | Type | Description | +|-------|------|-------------| +| ttl | string | Time-to-live for cache entry (e.g., "5m", "1h") | +| cacheWriteTokens | int | Tokens written to cache at this TTL | + ### Choice | Field | Type | Description | @@ -276,7 +295,7 @@ URL format: `https://{resource}.openai.azure.com/openai/deployments/{deployment} ## Interceptors -Seven built-in interceptors are provided in the `interceptors/` package. +Eight built-in interceptors are provided in the `interceptors/` package. ### Logging @@ -415,6 +434,220 @@ add := interceptors.NewAddResponseHeader( proxy := llmproxy.NewProxy(provider, llmproxy.WithInterceptor(add)) ``` +### PromptCaching + +Provider-specific prompt caching interceptors for Anthropic, OpenAI, xAI, Fireworks, and AWS Bedrock. + +#### Common Behavior + +- **Cache-Control header:** If the incoming request has `Cache-Control: no-cache`, the interceptor skips entirely — letting clients disable caching per-request +- **Provider detection:** Only applies to matching models: + - Anthropic: `claude-*` + - OpenAI: `gpt-*`, `o1-*`, `o3-*`, `o4-*`, `chatgpt-*` + - xAI: `grok-*` + - Fireworks: `accounts/fireworks/*`, `fireworks*` + - Bedrock: `anthropic.claude-*`, `amazon.nova-*`, `amazon.titan-*` +- **Cache usage tracking:** Response metadata includes `CacheUsage` in `Custom["cache_usage"]` + +#### Anthropic + +`NewAnthropicPromptCaching(retention)` — Enables Anthropic prompt caching: + +- **Automatic caching:** Adds `cache_control` at the top level of requests +- **Retention options:** + - `CacheRetentionDefault` (default, 5 min) — no TTL field, free, auto-refreshed on use + - `CacheRetention1h` — adds `ttl: "1h"`, costs more, longer cache lifetime +- **User-controlled caching:** If request already has `cache_control`, the interceptor skips entirely — letting you control caching explicitly via block-level breakpoints + +Example: + +```go +// Enable prompt caching for Anthropic (default 5 min, free) +caching := interceptors.NewAnthropicPromptCaching(interceptors.CacheRetentionDefault) +proxy := llmproxy.NewProxy(provider, llmproxy.WithInterceptor(caching)) + +// With 1h retention (costs more) and cache usage callback +caching := interceptors.NewAnthropicPromptCachingWithResult(interceptors.CacheRetention1h, func(u llmproxy.CacheUsage) { + log.Printf("Cache read: %d tokens, Cache write: %d tokens", u.CacheReadInputTokens, u.CacheCreationInputTokens) +}) +``` + +#### OpenAI + +`NewOpenAIPromptCaching(retention, cacheKey)` — Enables OpenAI prompt caching: + +- **Automatic caching:** OpenAI caches prompts ≥ 1024 tokens automatically +- **Cache routing:** Adds 
`prompt_cache_key` to improve cache hit rates for requests with common prefixes +- **Retention options:** + - `CacheRetentionDefault` (default, in-memory, 5-10 min) — no retention field + - `CacheRetention24h` — adds `prompt_cache_retention: "24h"` for GPT-5.x and GPT-4.1 +- **Cache key sources (in priority order):** + 1. `X-Cache-Key` header from incoming request + 2. Configured `CacheKey` in PromptCachingConfig + 3. Auto-derived from static content prefix via `DeriveCacheKeyFromPrefix()` +- **Tenant namespacing:** Cache keys are automatically prefixed with org/tenant ID from: + 1. Custom `OrgIDExtractor` function + 2. `OrgID` in `MetaContextValue` stored in request context + 3. `X-Org-ID` header + 4. `org_id` in `BodyMetadata.Custom` + 5. Configured `Namespace` fallback + +Example: + +```go +// Enable prompt caching for OpenAI with a cache key (default retention) +caching := interceptors.NewOpenAIPromptCaching(interceptors.CacheRetentionDefault, "my-app-session-123") +proxy := llmproxy.NewProxy(provider, llmproxy.WithInterceptor(caching)) + +// With 24h retention and cache usage callback +caching := interceptors.NewOpenAIPromptCachingWithResult(interceptors.CacheRetention24h, "my-key", func(u llmproxy.CacheUsage) { + log.Printf("Cached tokens: %d", u.CachedTokens) +}) + +// Auto-derive cache key from static content, namespace by tenant +caching := interceptors.NewOpenAIPromptCachingAuto("tenant-123", interceptors.CacheRetentionDefault) + +// Custom org ID extractor (e.g., from auth context) +caching := interceptors.NewOpenAIPromptCachingWithOrgExtractor( + interceptors.CacheRetentionDefault, + "my-key", + func(ctx context.Context, req *http.Request, meta llmproxy.BodyMetadata) string { + return getOrgFromAuthContext(ctx) + }, +) +``` + +#### xAI (Grok) + +`NewXAIPromptCaching(convID)` — Enables xAI/Grok prompt caching: + +- **Automatic prefix caching:** xAI caches from the start of the messages array automatically +- **Cache routing:** Adds `x-grok-conv-id` HTTP header to route requests to the same server where cache lives +- **Conversation ID:** Use a stable value (conversation ID, session ID, or deterministic hash of static content) +- **Key rule:** Never reorder or modify earlier messages — only append + +Example: + +```go +// Enable prompt caching for xAI with a conversation ID +caching := interceptors.NewXAIPromptCaching("conv-abc123-tenant456") +proxy := llmproxy.NewProxy(provider, llmproxy.WithInterceptor(caching)) + +// With cache usage callback +caching := interceptors.NewXAIPromptCachingWithResult("my-conv-id", func(u llmproxy.CacheUsage) { + log.Printf("Cached tokens: %d", u.CachedTokens) +}) +``` + +#### Fireworks + +`NewFireworksPromptCaching(sessionID)` — Enables Fireworks prompt caching: + +- **Automatic caching:** Fireworks caches prompts with shared prefixes automatically (enabled by default) +- **Cache routing:** Adds `x-session-affinity` HTTP header to route requests to the same replica +- **Tenant isolation:** Adds `x-prompt-cache-isolation-key` header set to org/tenant ID for multi-tenant isolation +- **Cache usage:** Reads `fireworks-cached-prompt-tokens` response header for cache hit tracking + +Example: + +```go +// Enable prompt caching for Fireworks with session affinity +caching := interceptors.NewFireworksPromptCaching("session-abc123") +proxy := llmproxy.NewProxy(provider, llmproxy.WithInterceptor(caching)) + +// With org ID extractor for tenant isolation +caching := interceptors.NewFireworksPromptCachingWithOrgExtractor("session-abc123", func(ctx 
context.Context, req *http.Request, meta llmproxy.BodyMetadata) string { + return getOrgFromAuthContext(ctx) +}) + +// With cache usage callback +caching := interceptors.NewFireworksPromptCachingWithResult("session-abc123", func(u llmproxy.CacheUsage) { + log.Printf("Cached tokens: %d", u.CachedTokens) +}) +``` + +#### AWS Bedrock + +`NewBedrockPromptCaching(retention)` — Enables AWS Bedrock prompt caching via the Converse API: + +- **Cache checkpoints:** Adds `cachePoint` objects to system, messages, and toolConfig +- **Retention options:** + - `CacheRetentionDefault` (default, 5 min) — no TTL field + - `CacheRetention1h` — adds `ttl: "1h"` for Claude Opus 4.5, Haiku 4.5, and Sonnet 4.5 +- **Minimum tokens:** 1,024 tokens per cache checkpoint (varies by model) +- **Maximum checkpoints:** 4 per request +- **Supported models:** Claude models (anthropic.claude-*), Nova models (amazon.nova-*), Titan models (amazon.titan-*) +- **Cache usage:** Reads `cacheReadInputTokens`, `cacheWriteInputTokens`, and `cacheDetails` from response + +Example: + +```go +// Enable prompt caching for Bedrock (default 5 min) +caching := interceptors.NewBedrockPromptCaching(interceptors.CacheRetentionDefault) +proxy := llmproxy.NewProxy(bedrockProvider, llmproxy.WithInterceptor(caching)) + +// With 1h retention for Claude Opus 4.5 +caching := interceptors.NewBedrockPromptCaching(interceptors.CacheRetention1h) + +// With cache usage callback +caching := interceptors.NewBedrockPromptCachingWithResult(interceptors.CacheRetentionDefault, func(u llmproxy.CacheUsage) { + log.Printf("Cache read: %d tokens, Cache write: %d tokens", u.CachedTokens, u.CacheWriteTokens) + for _, detail := range u.CacheDetails { + log.Printf(" TTL %s: %d tokens written", detail.TTL, detail.CacheWriteTokens) + } +}) +``` + +#### Azure OpenAI + +Azure OpenAI uses the same `prompt_cache_key` body parameter as OpenAI. **Use the OpenAI interceptor** for Azure OpenAI: + +```go +// Azure OpenAI prompt caching uses the OpenAI interceptor +caching := interceptors.NewOpenAIPromptCaching(interceptors.CacheRetentionDefault, "my-cache-key") +proxy := llmproxy.NewProxy(azureProvider, llmproxy.WithInterceptor(caching)) +``` + +**Note:** Azure OpenAI caches prompts ≥ 1,024 tokens automatically. The `prompt_cache_key` parameter is combined with the prefix hash to improve cache hit rates. Cache hits appear as `cached_tokens` in `prompt_tokens_details` in the response. 
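+
+For example, the cache-hit count can be observed through the same result callback used for OpenAI. This is a minimal sketch, assuming `azureProvider` is configured as shown in the Providers section and that the Azure adapter populates `cache_usage` in the response metadata:
+
+```go
+// Azure OpenAI: log cached_tokens reported in prompt_tokens_details
+caching := interceptors.NewOpenAIPromptCachingWithResult(interceptors.CacheRetentionDefault, "my-cache-key", func(u llmproxy.CacheUsage) {
+	// CachedTokens carries the cached_tokens value surfaced by the provider
+	log.Printf("Azure cache hit: %d tokens", u.CachedTokens)
+})
+proxy := llmproxy.NewProxy(azureProvider, llmproxy.WithInterceptor(caching))
+```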
+ +#### Generic constructor + +`NewPromptCaching(provider, config)` — Creates a caching interceptor for any provider: + +```go +// Anthropic with 1h retention +caching := interceptors.NewPromptCaching("anthropic", interceptors.PromptCachingConfig{ + Enabled: true, + Retention: interceptors.CacheRetention1h, +}) + +// OpenAI with 24h retention +caching := interceptors.NewPromptCaching("openai", interceptors.PromptCachingConfig{ + Enabled: true, + Retention: interceptors.CacheRetention24h, + CacheKey: "my-cache-key", +}) + +// xAI with conversation ID +caching := interceptors.NewPromptCaching("xai", interceptors.PromptCachingConfig{ + Enabled: true, + CacheKey: "my-conv-id", +}) + +// Fireworks with session ID and org extractor +caching := interceptors.NewPromptCaching("fireworks", interceptors.PromptCachingConfig{ + Enabled: true, + CacheKey: "my-session-id", + OrgIDExtractor: interceptors.DefaultOrgIDExtractor, +}) + +// Bedrock with 1h retention +caching := interceptors.NewPromptCaching("bedrock", interceptors.PromptCachingConfig{ + Enabled: true, + Retention: interceptors.CacheRetention1h, +}) +``` + --- ## Pricing System @@ -503,6 +736,7 @@ llmproxy/ │ ├── headerban.go # HeaderBanInterceptor │ ├── logging.go # LoggingInterceptor │ ├── metrics.go # MetricsInterceptor, Metrics +│ ├── promptcaching.go # PromptCachingInterceptor │ ├── retry.go # RetryInterceptor │ └── tracing.go # TracingInterceptor ├── pricing/ diff --git a/README.md b/README.md index 322e133..c4d869e 100644 --- a/README.md +++ b/README.md @@ -54,8 +54,9 @@ func main() { ## Features - **9 Provider Implementations**: OpenAI, Anthropic, Groq, Fireworks, x.AI, Google AI, AWS Bedrock, Azure OpenAI, OpenAI-compatible base -- **7 Built-in Interceptors**: Logging, Metrics, Retry, Billing, Tracing (OTel), HeaderBan, AddHeader +- **8 Built-in Interceptors**: Logging, Metrics, Retry, Billing, Tracing (OTel), HeaderBan, AddHeader, PromptCaching - **Pricing Integration**: models.dev adapter with markup support +- **Prompt Caching**: prompt caching support for Anthropic, OpenAI, xAI, Fireworks, and Bedrock - **Raw Body Preservation**: Custom JSON fields pass through unchanged ## Providers @@ -100,6 +101,24 @@ llmproxy.WithInterceptor(interceptors.NewResponseHeaderBan("Openai-Organization" llmproxy.WithInterceptor(interceptors.NewAddResponseHeader( interceptors.NewHeader("X-Gateway", "llmproxy"), )) + +// Anthropic prompt caching (default 5 min, free) +llmproxy.WithInterceptor(interceptors.NewAnthropicPromptCaching(interceptors.CacheRetentionDefault)) + +// Anthropic prompt caching with 1h retention (costs more) +llmproxy.WithInterceptor(interceptors.NewAnthropicPromptCaching(interceptors.CacheRetention1h)) + +// OpenAI prompt caching with explicit cache key +llmproxy.WithInterceptor(interceptors.NewOpenAIPromptCaching(interceptors.CacheRetention24h, "my-cache-key")) + +// OpenAI prompt caching with auto-derived key and tenant namespace +llmproxy.WithInterceptor(interceptors.NewOpenAIPromptCachingAuto("tenant-123", interceptors.CacheRetentionDefault)) + +// xAI/Grok prompt caching (uses x-grok-conv-id header) +llmproxy.WithInterceptor(interceptors.NewXAIPromptCaching("conv-abc123")) + +// Fireworks prompt caching (uses x-session-affinity and x-prompt-cache-isolation-key headers) +llmproxy.WithInterceptor(interceptors.NewFireworksPromptCaching("session-123")) ``` ## Architecture diff --git a/interceptor.go b/interceptor.go index 0d6aa94..cdf2844 100644 --- a/interceptor.go +++ b/interceptor.go @@ -70,6 +70,7 @@ type MetaContextKey 
struct{} type MetaContextValue struct { Meta BodyMetadata RawBody []byte + OrgID string } // GetMetaFromContext retrieves the metadata stored in a context. diff --git a/interceptors/promptcaching.go b/interceptors/promptcaching.go new file mode 100644 index 0000000..3b7be27 --- /dev/null +++ b/interceptors/promptcaching.go @@ -0,0 +1,921 @@ +package interceptors + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "io" + "net/http" + "strings" + + "github.com/agentuity/llmproxy" +) + +type CacheRetention string + +const ( + CacheRetentionDefault CacheRetention = "" + CacheRetention1h CacheRetention = "1h" + CacheRetention24h CacheRetention = "24h" +) + +const ( + HeaderCacheKey = "X-Cache-Key" + HeaderOrgID = "X-Org-ID" + HeaderFireworksSessionAffinity = "X-Session-Affinity" + HeaderFireworksPromptCacheIsolation = "X-Prompt-Cache-Isolation-Key" +) + +type CacheKeyFunc func(meta llmproxy.BodyMetadata, rawBody []byte) string + +type CacheKeyExtractor func(ctx context.Context, req *http.Request, meta llmproxy.BodyMetadata, rawBody []byte) string + +type OrgIDExtractor func(ctx context.Context, req *http.Request, meta llmproxy.BodyMetadata) string + +type PromptCachingConfig struct { + Enabled bool + Retention CacheRetention + CacheKey string + Namespace string + CacheKeyFn CacheKeyFunc + CacheKeyExtractor CacheKeyExtractor + OrgIDExtractor OrgIDExtractor +} + +type PromptCachingInterceptor struct { + provider string + config PromptCachingConfig + onResult func(llmproxy.CacheUsage) +} + +func (i *PromptCachingInterceptor) Intercept(req *http.Request, meta llmproxy.BodyMetadata, rawBody []byte, next llmproxy.RoundTripFunc) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + if !i.config.Enabled { + return next(req) + } + + if cacheControl := req.Header.Get("Cache-Control"); strings.Contains(cacheControl, "no-cache") { + return next(req) + } + + if i.provider != "" { + modelLower := strings.ToLower(meta.Model) + shouldApply := false + switch i.provider { + case "anthropic": + shouldApply = strings.Contains(modelLower, "claude") + case "openai": + shouldApply = isOpenAIModel(modelLower) + case "xai": + shouldApply = isXAIModel(modelLower) + case "fireworks": + shouldApply = isFireworksModel(modelLower) + case "bedrock": + shouldApply = isBedrockModel(modelLower) + default: + shouldApply = strings.Contains(modelLower, i.provider) + } + if !shouldApply { + return next(req) + } + } + + if i.provider == "xai" { + return i.interceptXAI(req, meta, rawBody, next) + } + + if i.provider == "fireworks" { + return i.interceptFireworks(req, meta, rawBody, next) + } + + if i.provider == "bedrock" { + return i.interceptBedrock(req, meta, rawBody, next) + } + + modifiedBody, shouldSkip := i.checkSkipOrModify(req, meta, rawBody) + if shouldSkip { + return next(req) + } + + if req.Body != nil { + req.Body.Close() + } + req = cloneRequestWithBody(req, modifiedBody) + + resp, respMeta, rawRespBody, err := next(req) + if err != nil { + return resp, respMeta, rawRespBody, err + } + + if i.onResult != nil { + if cacheUsage, ok := respMeta.Custom["cache_usage"].(llmproxy.CacheUsage); ok { + i.onResult(cacheUsage) + } + } + + return resp, respMeta, rawRespBody, err +} + +func (i *PromptCachingInterceptor) interceptXAI(req *http.Request, meta llmproxy.BodyMetadata, rawBody []byte, next llmproxy.RoundTripFunc) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + if req.Header.Get("x-grok-conv-id") != "" { + return next(req) + } + + cacheKey := 
i.resolveDynamicCacheKey(req, meta, rawBody) + if cacheKey == "" { + return next(req) + } + + req.Header.Set("x-grok-conv-id", cacheKey) + + resp, respMeta, rawRespBody, err := next(req) + if err != nil { + return resp, respMeta, rawRespBody, err + } + + if i.onResult != nil { + if cacheUsage, ok := respMeta.Custom["cache_usage"].(llmproxy.CacheUsage); ok { + i.onResult(cacheUsage) + } + } + + return resp, respMeta, rawRespBody, err +} + +func (i *PromptCachingInterceptor) interceptFireworks(req *http.Request, meta llmproxy.BodyMetadata, rawBody []byte, next llmproxy.RoundTripFunc) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + orgID := i.extractOrgID(req, meta) + + if req.Header.Get(HeaderFireworksPromptCacheIsolation) == "" && orgID != "" { + req.Header.Set(HeaderFireworksPromptCacheIsolation, orgID) + } + + if req.Header.Get(HeaderFireworksSessionAffinity) == "" { + if sessionID := i.resolveDynamicCacheKey(req, meta, rawBody); sessionID != "" { + req.Header.Set(HeaderFireworksSessionAffinity, sessionID) + } + } + + resp, respMeta, rawRespBody, err := next(req) + if err != nil { + return resp, respMeta, rawRespBody, err + } + + if cached := resp.Header.Get("fireworks-cached-prompt-tokens"); cached != "" { + if respMeta.Custom == nil { + respMeta.Custom = make(map[string]any) + } + var cachedTokens int + if err := json.Unmarshal([]byte(cached), &cachedTokens); err == nil && cachedTokens > 0 { + respMeta.Custom["cache_usage"] = llmproxy.CacheUsage{ + CachedTokens: cachedTokens, + } + } + } + + if i.onResult != nil { + if cacheUsage, ok := respMeta.Custom["cache_usage"].(llmproxy.CacheUsage); ok { + i.onResult(cacheUsage) + } + } + + return resp, respMeta, rawRespBody, err +} + +func (i *PromptCachingInterceptor) interceptBedrock(req *http.Request, meta llmproxy.BodyMetadata, rawBody []byte, next llmproxy.RoundTripFunc) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + modifiedBody, shouldSkip := i.checkBedrock(rawBody) + if shouldSkip { + return next(req) + } + + if req.Body != nil { + req.Body.Close() + } + req = cloneRequestWithBody(req, modifiedBody) + + resp, respMeta, rawRespBody, err := next(req) + if err != nil { + return resp, respMeta, rawRespBody, err + } + + if i.onResult != nil { + if cacheUsage, ok := respMeta.Custom["cache_usage"].(llmproxy.CacheUsage); ok { + i.onResult(cacheUsage) + } + } + + return resp, respMeta, rawRespBody, err +} + +func (i *PromptCachingInterceptor) checkSkipOrModify(req *http.Request, meta llmproxy.BodyMetadata, rawBody []byte) ([]byte, bool) { + switch i.provider { + case "anthropic": + return i.checkAnthropic(rawBody) + case "openai": + return i.checkOpenAI(req, meta, rawBody) + case "bedrock": + return i.checkBedrock(rawBody) + default: + return rawBody, true + } +} + +func (i *PromptCachingInterceptor) checkAnthropic(rawBody []byte) ([]byte, bool) { + var req map[string]interface{} + if err := json.Unmarshal(rawBody, &req); err != nil { + return rawBody, true + } + + if i.hasExistingCacheControl(req) { + return rawBody, true + } + + modified := false + + if system, exists := req["system"]; exists { + switch s := system.(type) { + case string: + if s != "" { + req["system"] = []interface{}{ + map[string]interface{}{ + "type": "text", + "text": s, + "cache_control": i.buildCacheControl(), + }, + } + modified = true + } + case []interface{}: + if len(s) > 0 { + lastIdx := len(s) - 1 + if block, ok := s[lastIdx].(map[string]interface{}); ok { + if _, hasCC := block["cache_control"]; !hasCC { + block["cache_control"] = 
i.buildCacheControl() + s[lastIdx] = block + modified = true + } + } + } + } + } + + if messages, exists := req["messages"]; exists { + if msgSlice, ok := messages.([]interface{}); ok && len(msgSlice) > 0 { + for idx := len(msgSlice) - 1; idx >= 0; idx-- { + if msg, ok := msgSlice[idx].(map[string]interface{}); ok { + if role, ok := msg["role"].(string); ok && (role == "user" || role == "assistant") { + if content := msg["content"]; content != nil { + switch c := content.(type) { + case string: + if c != "" { + msg["content"] = []interface{}{ + map[string]interface{}{ + "type": "text", + "text": c, + "cache_control": i.buildCacheControl(), + }, + } + modified = true + } + case []interface{}: + if len(c) > 0 { + lastBlock, ok := c[len(c)-1].(map[string]interface{}) + if ok { + if _, hasCC := lastBlock["cache_control"]; !hasCC { + lastBlock["cache_control"] = i.buildCacheControl() + c[len(c)-1] = lastBlock + modified = true + } + } + } + } + } + break + } + } + } + } + } + + if !modified { + return rawBody, true + } + + result, err := json.Marshal(req) + if err != nil { + return rawBody, true + } + + return result, false +} + +func (i *PromptCachingInterceptor) hasExistingCacheControl(req map[string]interface{}) bool { + if system, exists := req["system"]; exists { + if blocks, ok := system.([]interface{}); ok { + for _, b := range blocks { + if block, ok := b.(map[string]interface{}); ok { + if _, has := block["cache_control"]; has { + return true + } + } + } + } + } + + if messages, exists := req["messages"]; exists { + if msgSlice, ok := messages.([]interface{}); ok { + for _, m := range msgSlice { + if msg, ok := m.(map[string]interface{}); ok { + if content, ok := msg["content"].([]interface{}); ok { + for _, c := range content { + if block, ok := c.(map[string]interface{}); ok { + if _, has := block["cache_control"]; has { + return true + } + } + } + } + } + } + } + } + + return false +} + +func (i *PromptCachingInterceptor) buildCacheControl() map[string]interface{} { + cc := map[string]interface{}{ + "type": "ephemeral", + } + if i.config.Retention == CacheRetention1h { + cc["ttl"] = "1h" + } + return cc +} + +func (i *PromptCachingInterceptor) checkOpenAI(req *http.Request, meta llmproxy.BodyMetadata, rawBody []byte) ([]byte, bool) { + var body map[string]interface{} + if err := json.Unmarshal(rawBody, &body); err != nil { + return rawBody, true + } + + if _, exists := body["prompt_cache_key"]; exists { + return rawBody, true + } + + modified := false + + cacheKey := i.resolveCacheKey(req, meta, rawBody, body) + if cacheKey != "" { + body["prompt_cache_key"] = cacheKey + modified = true + } + + if i.config.Retention != "" { + if _, exists := body["prompt_cache_retention"]; !exists { + body["prompt_cache_retention"] = string(i.config.Retention) + modified = true + } + } + + if !modified { + return rawBody, true + } + + result, err := json.Marshal(body) + if err != nil { + return rawBody, true + } + + return result, false +} + +func (i *PromptCachingInterceptor) checkBedrock(rawBody []byte) ([]byte, bool) { + var req map[string]interface{} + if err := json.Unmarshal(rawBody, &req); err != nil { + return rawBody, true + } + + if i.hasExistingCachePoint(req) { + return rawBody, true + } + + modified := false + + if system, exists := req["system"]; exists { + if sysSlice, ok := system.([]interface{}); ok && len(sysSlice) > 0 { + lastIdx := len(sysSlice) - 1 + if block, ok := sysSlice[lastIdx].(map[string]interface{}); ok { + if _, hasCP := block["cachePoint"]; !hasCP { + sysSlice = 
append(sysSlice, map[string]interface{}{ + "cachePoint": i.buildCachePoint(), + }) + req["system"] = sysSlice + modified = true + } + } + } + } + + if messages, exists := req["messages"]; exists { + if msgSlice, ok := messages.([]interface{}); ok && len(msgSlice) > 0 { + for idx := len(msgSlice) - 1; idx >= 0; idx-- { + if msg, ok := msgSlice[idx].(map[string]interface{}); ok { + if role, ok := msg["role"].(string); ok && (role == "user" || role == "assistant") { + if content := msg["content"]; content != nil { + if contentSlice, ok := content.([]interface{}); ok && len(contentSlice) > 0 { + lastBlock, ok := contentSlice[len(contentSlice)-1].(map[string]interface{}) + if ok { + if _, hasCP := lastBlock["cachePoint"]; !hasCP { + contentSlice = append(contentSlice, map[string]interface{}{ + "cachePoint": i.buildCachePoint(), + }) + msg["content"] = contentSlice + modified = true + } + } + } + } + break + } + } + } + } + } + + if toolConfig, exists := req["toolConfig"]; exists { + if tc, ok := toolConfig.(map[string]interface{}); ok { + if tools, exists := tc["tools"]; exists { + if toolSlice, ok := tools.([]interface{}); ok && len(toolSlice) > 0 { + lastBlock, ok := toolSlice[len(toolSlice)-1].(map[string]interface{}) + if ok { + if _, hasCP := lastBlock["cachePoint"]; !hasCP { + toolSlice = append(toolSlice, map[string]interface{}{ + "cachePoint": i.buildCachePoint(), + }) + tc["tools"] = toolSlice + modified = true + } + } + } + } + } + } + + if !modified { + return rawBody, true + } + + result, err := json.Marshal(req) + if err != nil { + return rawBody, true + } + + return result, false +} + +func (i *PromptCachingInterceptor) hasExistingCachePoint(req map[string]interface{}) bool { + if system, exists := req["system"]; exists { + if blocks, ok := system.([]interface{}); ok { + for _, b := range blocks { + if block, ok := b.(map[string]interface{}); ok { + if _, has := block["cachePoint"]; has { + return true + } + } + } + } + } + + if messages, exists := req["messages"]; exists { + if msgSlice, ok := messages.([]interface{}); ok { + for _, m := range msgSlice { + if msg, ok := m.(map[string]interface{}); ok { + if content, ok := msg["content"].([]interface{}); ok { + for _, c := range content { + if block, ok := c.(map[string]interface{}); ok { + if _, has := block["cachePoint"]; has { + return true + } + } + } + } + } + } + } + } + + if toolConfig, exists := req["toolConfig"]; exists { + if tc, ok := toolConfig.(map[string]interface{}); ok { + if tools, ok := tc["tools"].([]interface{}); ok { + for _, t := range tools { + if tool, ok := t.(map[string]interface{}); ok { + if _, has := tool["cachePoint"]; has { + return true + } + } + } + } + } + } + + return false +} + +func (i *PromptCachingInterceptor) buildCachePoint() map[string]interface{} { + cp := map[string]interface{}{ + "type": "default", + } + if i.config.Retention == CacheRetention1h { + cp["ttl"] = "1h" + } + return cp +} + +func (i *PromptCachingInterceptor) resolveCacheKey(req *http.Request, meta llmproxy.BodyMetadata, rawBody []byte, body map[string]interface{}) string { + orgID := i.extractOrgID(req, meta) + + if headerKey := req.Header.Get(HeaderCacheKey); headerKey != "" { + return i.buildNamespacedKey(orgID, headerKey) + } + + if i.config.CacheKey != "" { + return i.buildNamespacedKey(orgID, i.config.CacheKey) + } + + if i.config.CacheKeyFn != nil { + derived := i.config.CacheKeyFn(meta, rawBody) + if derived != "" { + return i.buildNamespacedKey(orgID, derived) + } + } + + return "" +} + +func (i 
*PromptCachingInterceptor) extractOrgID(req *http.Request, meta llmproxy.BodyMetadata) string { + if i.config.OrgIDExtractor != nil { + if orgID := i.config.OrgIDExtractor(req.Context(), req, meta); orgID != "" { + return orgID + } + } + + if metaCtx := llmproxy.GetMetaFromContext(req.Context()); metaCtx.OrgID != "" { + return metaCtx.OrgID + } + + if orgID := req.Header.Get(HeaderOrgID); orgID != "" { + return orgID + } + + if orgID, ok := meta.Custom["org_id"].(string); ok && orgID != "" { + return orgID + } + + return i.config.Namespace +} + +func (i *PromptCachingInterceptor) buildNamespacedKey(orgID, key string) string { + if orgID != "" { + return orgID + ":" + key + } + if i.config.Namespace != "" { + return i.config.Namespace + ":" + key + } + return key +} + +func (i *PromptCachingInterceptor) resolveDynamicCacheKey(req *http.Request, meta llmproxy.BodyMetadata, rawBody []byte) string { + if headerKey := req.Header.Get(HeaderCacheKey); headerKey != "" { + return headerKey + } + + if i.config.CacheKeyExtractor != nil { + if key := i.config.CacheKeyExtractor(req.Context(), req, meta, rawBody); key != "" { + return key + } + } + + if i.config.CacheKeyFn != nil { + if key := i.config.CacheKeyFn(meta, rawBody); key != "" { + return key + } + } + + return i.config.CacheKey +} + +func DeriveCacheKeyFromPrefix(meta llmproxy.BodyMetadata, rawBody []byte) string { + var body struct { + System interface{} `json:"system"` + Messages []struct { + Role string `json:"role"` + Content interface{} `json:"content"` + } `json:"messages"` + Tools interface{} `json:"tools"` + } + json.Unmarshal(rawBody, &body) + + var prefix bytes.Buffer + if body.System != nil { + sysBytes, _ := json.Marshal(body.System) + prefix.Write(sysBytes) + } + if body.Tools != nil { + toolsBytes, _ := json.Marshal(body.Tools) + prefix.Write(toolsBytes) + } + for i, msg := range body.Messages { + if i < len(body.Messages)-1 { + msgBytes, _ := json.Marshal(msg) + prefix.Write(msgBytes) + } + } + + if prefix.Len() == 0 { + return "" + } + + hash := sha256.Sum256(prefix.Bytes()) + return hex.EncodeToString(hash[:16]) +} + +func isOpenAIModel(modelLower string) bool { + return strings.Contains(modelLower, "gpt-") || + strings.Contains(modelLower, "o1-") || + strings.Contains(modelLower, "o3-") || + strings.Contains(modelLower, "o4-") || + strings.Contains(modelLower, "chatgpt") +} + +func isXAIModel(modelLower string) bool { + return strings.Contains(modelLower, "grok") +} + +func isFireworksModel(modelLower string) bool { + return strings.Contains(modelLower, "fireworks") || + strings.Contains(modelLower, "accounts/fireworks") +} + +func isBedrockModel(modelLower string) bool { + return strings.Contains(modelLower, "anthropic.claude") || + strings.Contains(modelLower, "amazon.nova") || + strings.Contains(modelLower, "amazon.titan") +} + +func cloneRequestWithBody(req *http.Request, body []byte) *http.Request { + cloned := req.Clone(req.Context()) + cloned.Body = io.NopCloser(bytes.NewReader(body)) + cloned.ContentLength = int64(len(body)) + return cloned +} + +func NewPromptCaching(provider string, config PromptCachingConfig) *PromptCachingInterceptor { + return &PromptCachingInterceptor{ + provider: provider, + config: config, + } +} + +func NewPromptCachingWithResult(provider string, config PromptCachingConfig, onResult func(llmproxy.CacheUsage)) *PromptCachingInterceptor { + return &PromptCachingInterceptor{ + provider: provider, + config: config, + onResult: onResult, + } +} + +func NewAnthropicPromptCaching(retention 
CacheRetention) *PromptCachingInterceptor { + return NewPromptCaching("anthropic", PromptCachingConfig{ + Enabled: true, + Retention: retention, + }) +} + +func NewAnthropicPromptCachingWithResult(retention CacheRetention, onResult func(llmproxy.CacheUsage)) *PromptCachingInterceptor { + return NewPromptCachingWithResult("anthropic", PromptCachingConfig{ + Enabled: true, + Retention: retention, + }, onResult) +} + +func NewOpenAIPromptCaching(retention CacheRetention, cacheKey string) *PromptCachingInterceptor { + return NewPromptCaching("openai", PromptCachingConfig{ + Enabled: true, + Retention: retention, + CacheKey: cacheKey, + }) +} + +func NewOpenAIPromptCachingWithResult(retention CacheRetention, cacheKey string, onResult func(llmproxy.CacheUsage)) *PromptCachingInterceptor { + return NewPromptCachingWithResult("openai", PromptCachingConfig{ + Enabled: true, + Retention: retention, + CacheKey: cacheKey, + }, onResult) +} + +func NewOpenAIPromptCachingWithNamespace(namespace string, retention CacheRetention, cacheKey string) *PromptCachingInterceptor { + return NewPromptCaching("openai", PromptCachingConfig{ + Enabled: true, + Retention: retention, + CacheKey: cacheKey, + Namespace: namespace, + }) +} + +func NewOpenAIPromptCachingAuto(namespace string, retention CacheRetention) *PromptCachingInterceptor { + return NewPromptCaching("openai", PromptCachingConfig{ + Enabled: true, + Retention: retention, + Namespace: namespace, + CacheKeyFn: DeriveCacheKeyFromPrefix, + }) +} + +func NewOpenAIPromptCachingAutoWithResult(namespace string, retention CacheRetention, onResult func(llmproxy.CacheUsage)) *PromptCachingInterceptor { + return NewPromptCachingWithResult("openai", PromptCachingConfig{ + Enabled: true, + Retention: retention, + Namespace: namespace, + CacheKeyFn: DeriveCacheKeyFromPrefix, + }, onResult) +} + +func NewXAIPromptCaching(convID string) *PromptCachingInterceptor { + return NewPromptCaching("xai", PromptCachingConfig{ + Enabled: true, + CacheKey: convID, + }) +} + +func NewXAIPromptCachingWithResult(convID string, onResult func(llmproxy.CacheUsage)) *PromptCachingInterceptor { + return NewPromptCachingWithResult("xai", PromptCachingConfig{ + Enabled: true, + CacheKey: convID, + }, onResult) +} + +func NewXAIPromptCachingAuto() *PromptCachingInterceptor { + return NewPromptCaching("xai", PromptCachingConfig{ + Enabled: true, + CacheKeyFn: DeriveCacheKeyFromPrefix, + }) +} + +func NewXAIPromptCachingAutoWithResult(onResult func(llmproxy.CacheUsage)) *PromptCachingInterceptor { + return NewPromptCachingWithResult("xai", PromptCachingConfig{ + Enabled: true, + CacheKeyFn: DeriveCacheKeyFromPrefix, + }, onResult) +} + +func NewXAIPromptCachingWithExtractor(extractor CacheKeyExtractor) *PromptCachingInterceptor { + return NewPromptCaching("xai", PromptCachingConfig{ + Enabled: true, + CacheKeyExtractor: extractor, + }) +} + +func NewXAIPromptCachingWithExtractorAndResult(extractor CacheKeyExtractor, onResult func(llmproxy.CacheUsage)) *PromptCachingInterceptor { + return NewPromptCachingWithResult("xai", PromptCachingConfig{ + Enabled: true, + CacheKeyExtractor: extractor, + }, onResult) +} + +func NewXAIPromptCachingWithTraceID(traceExtractor TraceExtractor) *PromptCachingInterceptor { + return NewPromptCaching("xai", PromptCachingConfig{ + Enabled: true, + CacheKeyExtractor: TraceIDCacheKeyExtractor(traceExtractor), + }) +} + +func NewXAIPromptCachingWithTraceIDAndResult(traceExtractor TraceExtractor, onResult func(llmproxy.CacheUsage)) *PromptCachingInterceptor { + return 
NewPromptCachingWithResult("xai", PromptCachingConfig{ + Enabled: true, + CacheKeyExtractor: TraceIDCacheKeyExtractor(traceExtractor), + }, onResult) +} + +func NewFireworksPromptCaching(sessionID string) *PromptCachingInterceptor { + return NewPromptCaching("fireworks", PromptCachingConfig{ + Enabled: true, + CacheKey: sessionID, + }) +} + +func NewFireworksPromptCachingWithResult(sessionID string, onResult func(llmproxy.CacheUsage)) *PromptCachingInterceptor { + return NewPromptCachingWithResult("fireworks", PromptCachingConfig{ + Enabled: true, + CacheKey: sessionID, + }, onResult) +} + +func NewFireworksPromptCachingAuto() *PromptCachingInterceptor { + return NewPromptCaching("fireworks", PromptCachingConfig{ + Enabled: true, + CacheKeyFn: DeriveCacheKeyFromPrefix, + }) +} + +func NewFireworksPromptCachingAutoWithResult(onResult func(llmproxy.CacheUsage)) *PromptCachingInterceptor { + return NewPromptCachingWithResult("fireworks", PromptCachingConfig{ + Enabled: true, + CacheKeyFn: DeriveCacheKeyFromPrefix, + }, onResult) +} + +func NewFireworksPromptCachingWithExtractor(extractor CacheKeyExtractor) *PromptCachingInterceptor { + return NewPromptCaching("fireworks", PromptCachingConfig{ + Enabled: true, + CacheKeyExtractor: extractor, + }) +} + +func NewFireworksPromptCachingWithExtractorAndResult(extractor CacheKeyExtractor, onResult func(llmproxy.CacheUsage)) *PromptCachingInterceptor { + return NewPromptCachingWithResult("fireworks", PromptCachingConfig{ + Enabled: true, + CacheKeyExtractor: extractor, + }, onResult) +} + +func NewFireworksPromptCachingWithTraceID(traceExtractor TraceExtractor) *PromptCachingInterceptor { + return NewPromptCaching("fireworks", PromptCachingConfig{ + Enabled: true, + CacheKeyExtractor: TraceIDCacheKeyExtractor(traceExtractor), + }) +} + +func NewFireworksPromptCachingWithTraceIDAndResult(traceExtractor TraceExtractor, onResult func(llmproxy.CacheUsage)) *PromptCachingInterceptor { + return NewPromptCachingWithResult("fireworks", PromptCachingConfig{ + Enabled: true, + CacheKeyExtractor: TraceIDCacheKeyExtractor(traceExtractor), + }, onResult) +} + +func NewFireworksPromptCachingWithOrgExtractor(sessionID string, orgExtractor OrgIDExtractor) *PromptCachingInterceptor { + return NewPromptCaching("fireworks", PromptCachingConfig{ + Enabled: true, + CacheKey: sessionID, + OrgIDExtractor: orgExtractor, + }) +} + +func NewBedrockPromptCaching(retention CacheRetention) *PromptCachingInterceptor { + return NewPromptCaching("bedrock", PromptCachingConfig{ + Enabled: true, + Retention: retention, + }) +} + +func NewBedrockPromptCachingWithResult(retention CacheRetention, onResult func(llmproxy.CacheUsage)) *PromptCachingInterceptor { + return NewPromptCachingWithResult("bedrock", PromptCachingConfig{ + Enabled: true, + Retention: retention, + }, onResult) +} + +func NewOpenAIPromptCachingWithOrgExtractor(retention CacheRetention, cacheKey string, orgExtractor OrgIDExtractor) *PromptCachingInterceptor { + return NewPromptCaching("openai", PromptCachingConfig{ + Enabled: true, + Retention: retention, + CacheKey: cacheKey, + OrgIDExtractor: orgExtractor, + }) +} + +func NewOpenAIPromptCachingAutoWithOrgExtractor(retention CacheRetention, orgExtractor OrgIDExtractor) *PromptCachingInterceptor { + return NewPromptCaching("openai", PromptCachingConfig{ + Enabled: true, + Retention: retention, + CacheKeyFn: DeriveCacheKeyFromPrefix, + OrgIDExtractor: orgExtractor, + }) +} + +func DefaultOrgIDExtractor(ctx context.Context, req *http.Request, meta llmproxy.BodyMetadata) 
string { + if metaCtx := llmproxy.GetMetaFromContext(ctx); metaCtx.OrgID != "" { + return metaCtx.OrgID + } + if orgID := req.Header.Get(HeaderOrgID); orgID != "" { + return orgID + } + if orgID, ok := meta.Custom["org_id"].(string); ok { + return orgID + } + return "" +} + +func TraceIDCacheKeyExtractor(traceExtractor TraceExtractor) CacheKeyExtractor { + return func(ctx context.Context, req *http.Request, meta llmproxy.BodyMetadata, rawBody []byte) string { + if traceExtractor == nil { + return "" + } + traceInfo := traceExtractor(ctx) + if traceInfo.TraceID != [16]byte{} { + return hex.EncodeToString(traceInfo.TraceID[:]) + } + return "" + } +} diff --git a/interceptors/promptcaching_test.go b/interceptors/promptcaching_test.go new file mode 100644 index 0000000..b674de8 --- /dev/null +++ b/interceptors/promptcaching_test.go @@ -0,0 +1,2047 @@ +package interceptors + +import ( + "bytes" + "context" + "encoding/hex" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/agentuity/llmproxy" +) + +func TestPromptCachingInterceptor_AnthropicSystemString(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + if !bytes.Contains(body, []byte(`"cache_control"`)) { + t.Error("Request body should contain cache_control") + } + if !bytes.Contains(body, []byte(`"ephemeral"`)) { + t.Error("Request body should contain type ephemeral") + } + if !bytes.Contains(body, []byte(`"system"`)) { + t.Error("Request body should contain system field") + } + var req map[string]interface{} + if err := json.Unmarshal(body, &req); err != nil { + t.Fatalf("Failed to parse request body: %v", err) + } + system, ok := req["system"].([]interface{}) + if !ok { + t.Fatal("System should be an array") + } else if len(system) != 1 { + t.Fatalf("System array should have 1 block, got %d", len(system)) + } + block, ok := system[0].(map[string]interface{}) + if !ok { + t.Fatal("System block should be an object") + } + if block["text"] != "You are helpful." 
{ + t.Errorf("System block text = %q, want %q", block["text"], "You are helpful.") + } + if _, has := block["cache_control"]; !has { + t.Error("System block should have cache_control directly on the content block") + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + caching := NewAnthropicPromptCaching(CacheRetentionDefault) + + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(`{"model":"claude-3-opus","system":"You are helpful.","messages":[{"role":"user","content":"Hello"}]}`))) + meta := llmproxy.BodyMetadata{Model: "claude-3-opus"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(`{"model":"claude-3-opus","system":"You are helpful.","messages":[{"role":"user","content":"Hello"}]}`), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestPromptCachingInterceptor_AnthropicSystemArray(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + var req map[string]interface{} + if err := json.Unmarshal(body, &req); err != nil { + t.Fatalf("Failed to parse: %v", err) + } + system, ok := req["system"].([]interface{}) + if !ok { + t.Fatal("System should be an array") + } + lastBlock, ok := system[len(system)-1].(map[string]interface{}) + if !ok { + t.Fatal("Last block should be an object") + } + if _, has := lastBlock["cache_control"]; !has { + t.Error("Last system block should have cache_control") + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + caching := NewAnthropicPromptCaching(CacheRetentionDefault) + + reqBody := `{"model":"claude-3-opus","system":[{"type":"text","text":"You are helpful."}],"messages":[{"role":"user","content":"Hello"}]}` + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(reqBody))) + meta := llmproxy.BodyMetadata{Model: "claude-3-opus"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(reqBody), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestPromptCachingInterceptor_AnthropicLastMessage(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + var req map[string]interface{} + if err := json.Unmarshal(body, &req); err != nil { + t.Fatalf("Failed to parse: %v", err) + } + messages, ok := req["messages"].([]interface{}) + if !ok { + t.Fatal("Messages should be an array") + } + lastMsg, ok := messages[len(messages)-1].(map[string]interface{}) + if !ok { + t.Fatal("Last message should be an object") + } + content, ok := lastMsg["content"].([]interface{}) + if !ok { + t.Fatal("Last message content should be an array") + } + lastBlock, ok := content[len(content)-1].(map[string]interface{}) + if !ok { + t.Fatal("Last content block should be an object") + } + if _, has := 
lastBlock["cache_control"]; !has { + t.Error("Last message content block should have cache_control") + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + caching := NewAnthropicPromptCaching(CacheRetentionDefault) + + reqBody := `{"model":"claude-3-opus","messages":[{"role":"user","content":"Hello"}]}` + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(reqBody))) + meta := llmproxy.BodyMetadata{Model: "claude-3-opus"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(reqBody), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestPromptCachingInterceptor_Anthropic1hRetention(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + if !bytes.Contains(body, []byte(`"ttl":"1h"`)) { + t.Error("Request body should contain ttl 1h") + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + caching := NewAnthropicPromptCaching(CacheRetention1h) + + reqBody := `{"model":"claude-3-opus","system":"You are helpful.","messages":[{"role":"user","content":"Hello"}]}` + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(reqBody))) + meta := llmproxy.BodyMetadata{Model: "claude-3-opus"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(reqBody), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestPromptCachingInterceptor_OpenAIAddsCacheKey(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + if !bytes.Contains(body, []byte(`"prompt_cache_key"`)) { + t.Error("Request body should contain prompt_cache_key") + } + if !bytes.Contains(body, []byte(`"my-cache-key"`)) { + t.Error("Request body should contain the cache key value") + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + caching := NewOpenAIPromptCaching(CacheRetentionDefault, "my-cache-key") + + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(`{"model":"gpt-4","messages":[]}`))) + meta := llmproxy.BodyMetadata{Model: "gpt-4"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(`{"model":"gpt-4","messages":[]}`), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestPromptCachingInterceptor_OpenAI24hRetention(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + if 
!bytes.Contains(body, []byte(`"prompt_cache_retention":"24h"`)) { + t.Error("Request body should contain prompt_cache_retention 24h") + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + caching := NewOpenAIPromptCaching(CacheRetention24h, "my-key") + + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(`{"model":"gpt-5.1","messages":[]}`))) + meta := llmproxy.BodyMetadata{Model: "gpt-5.1"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(`{"model":"gpt-5.1","messages":[]}`), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestPromptCachingInterceptor_OpenAICacheUsage(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + var cacheUsage llmproxy.CacheUsage + caching := NewOpenAIPromptCachingWithResult(CacheRetentionDefault, "test-key", func(u llmproxy.CacheUsage) { + cacheUsage = u + }) + + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(`{"model":"gpt-4","messages":[]}`))) + meta := llmproxy.BodyMetadata{Model: "gpt-4"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + respMeta := llmproxy.ResponseMetadata{ + Usage: llmproxy.Usage{PromptTokens: 2006, CompletionTokens: 300}, + Custom: map[string]any{ + "cache_usage": llmproxy.CacheUsage{ + CachedTokens: 1920, + }, + }, + } + return resp, respMeta, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(`{"model":"gpt-4","messages":[]}`), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } + + if cacheUsage.CachedTokens != 1920 { + t.Errorf("CachedTokens = %d, want 1920", cacheUsage.CachedTokens) + } +} + +func TestPromptCachingInterceptor_SkipsNonAnthropic(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + if bytes.Contains(body, []byte(`"cache_control"`)) { + t.Error("Request body should NOT contain cache_control for non-Anthropic model") + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + caching := NewAnthropicPromptCaching(CacheRetentionDefault) + + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(`{"model":"gpt-4","messages":[]}`))) + meta := llmproxy.BodyMetadata{Model: "gpt-4"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(`{"model":"gpt-4","messages":[]}`), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestPromptCachingInterceptor_SkipsNonOpenAI(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w 
http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + if bytes.Contains(body, []byte(`"prompt_cache_key"`)) { + t.Error("Request body should NOT contain prompt_cache_key for non-OpenAI model") + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + caching := NewOpenAIPromptCaching(CacheRetentionDefault, "test-key") + + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(`{"model":"claude-3-opus","messages":[]}`))) + meta := llmproxy.BodyMetadata{Model: "claude-3-opus"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(`{"model":"claude-3-opus","messages":[]}`), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestPromptCachingInterceptor_AnthropicExistingCacheControl(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + if bytes.Count(body, []byte(`"cache_control"`)) > 1 { + t.Error("Request body should not have additional cache_control added") + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + caching := NewAnthropicPromptCaching(CacheRetentionDefault) + + reqBody := `{"model":"claude-3-opus","system":[{"type":"text","text":"You are helpful.","cache_control":{"type":"ephemeral"}}],"messages":[{"role":"user","content":"Hello"}]}` + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(reqBody))) + meta := llmproxy.BodyMetadata{Model: "claude-3-opus"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(reqBody), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestPromptCachingInterceptor_OpenAIExistingCacheKey(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + if bytes.Count(body, []byte(`"prompt_cache_key"`)) > 1 { + t.Error("Request body should not have duplicate prompt_cache_key") + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + caching := NewOpenAIPromptCaching(CacheRetentionDefault, "new-key") + + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(`{"model":"gpt-4","prompt_cache_key":"existing-key","messages":[]}`))) + meta := llmproxy.BodyMetadata{Model: "gpt-4"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(`{"model":"gpt-4","prompt_cache_key":"existing-key","messages":[]}`), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestPromptCachingInterceptor_Disabled(t 
*testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + if bytes.Contains(body, []byte(`"cache_control"`)) { + t.Error("Request body should NOT contain cache_control when disabled") + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + caching := NewPromptCaching("anthropic", PromptCachingConfig{Enabled: false}) + + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(`{"model":"claude-3-opus","messages":[]}`))) + meta := llmproxy.BodyMetadata{Model: "claude-3-opus"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(`{"model":"claude-3-opus","messages":[]}`), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestPromptCachingInterceptor_ErrorPassthrough(t *testing.T) { + caching := NewAnthropicPromptCaching(CacheRetentionDefault) + + req, _ := http.NewRequest("POST", "http://example.com", nil) + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + return nil, llmproxy.ResponseMetadata{}, nil, http.ErrHandlerTimeout + } + + _, _, _, err := caching.Intercept(req, llmproxy.BodyMetadata{Model: "claude-3-opus"}, []byte(`{"model":"claude-3-opus"}`), next) + if err != http.ErrHandlerTimeout { + t.Errorf("Error should pass through, got %v", err) + } +} + +func TestPromptCachingInterceptor_CacheControlNoCache(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + if bytes.Contains(body, []byte(`"cache_control"`)) { + t.Error("Request body should NOT contain cache_control when Cache-Control: no-cache is set") + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + caching := NewAnthropicPromptCaching(CacheRetentionDefault) + + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(`{"model":"claude-3-opus","messages":[]}`))) + req.Header.Set("Cache-Control", "no-cache") + meta := llmproxy.BodyMetadata{Model: "claude-3-opus"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(`{"model":"claude-3-opus","messages":[]}`), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestPromptCachingInterceptor_CacheControlNoCacheOpenAI(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + if bytes.Contains(body, []byte(`"prompt_cache_key"`)) { + t.Error("Request body should NOT contain prompt_cache_key when Cache-Control: no-cache is set") + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + caching := NewOpenAIPromptCaching(CacheRetentionDefault, "my-key") + + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(`{"model":"gpt-4","messages":[]}`))) + 
req.Header.Set("Cache-Control", "no-cache") + meta := llmproxy.BodyMetadata{Model: "gpt-4"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(`{"model":"gpt-4","messages":[]}`), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestIsOpenAIModel(t *testing.T) { + tests := []struct { + model string + expected bool + }{ + {"gpt-4", true}, + {"gpt-3.5-turbo", true}, + {"gpt-5.1", true}, + {"gpt-5-codex", true}, + {"o1-preview", true}, + {"o3-mini", true}, + {"chatgpt-4o", true}, + {"claude-3-opus", false}, + {"gemini-pro", false}, + {"llama-3", false}, + } + + for _, tt := range tests { + t.Run(tt.model, func(t *testing.T) { + result := isOpenAIModel(tt.model) + if result != tt.expected { + t.Errorf("isOpenAIModel(%q) = %v, want %v", tt.model, result, tt.expected) + } + }) + } +} + +func TestHasExistingCacheControl(t *testing.T) { + caching := &PromptCachingInterceptor{config: PromptCachingConfig{Enabled: true}} + + tests := []struct { + name string + req map[string]interface{} + expected bool + }{ + { + name: "no cache_control", + req: map[string]interface{}{"model": "claude-3-opus"}, + expected: false, + }, + { + name: "cache_control in system array", + req: map[string]interface{}{ + "system": []interface{}{ + map[string]interface{}{"type": "text", "text": "Hello", "cache_control": map[string]interface{}{"type": "ephemeral"}}, + }, + }, + expected: true, + }, + { + name: "cache_control in message content", + req: map[string]interface{}{ + "messages": []interface{}{ + map[string]interface{}{ + "role": "user", + "content": []interface{}{ + map[string]interface{}{"type": "text", "text": "Hi", "cache_control": map[string]interface{}{"type": "ephemeral"}}, + }, + }, + }, + }, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := caching.hasExistingCacheControl(tt.req) + if result != tt.expected { + t.Errorf("hasExistingCacheControl() = %v, want %v", result, tt.expected) + } + }) + } +} + +func TestBuildCacheControl(t *testing.T) { + tests := []struct { + name string + retention CacheRetention + wantTTL bool + }{ + { + name: "default retention", + retention: CacheRetentionDefault, + wantTTL: false, + }, + { + name: "1h retention", + retention: CacheRetention1h, + wantTTL: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + caching := &PromptCachingInterceptor{config: PromptCachingConfig{Enabled: true, Retention: tt.retention}} + cc := caching.buildCacheControl() + if cc["type"] != "ephemeral" { + t.Error("cache_control should have type ephemeral") + } + if tt.wantTTL { + if cc["ttl"] != "1h" { + t.Error("cache_control should have ttl 1h") + } + } else { + if _, has := cc["ttl"]; has { + t.Error("cache_control should not have ttl for default retention") + } + } + }) + } +} + +func TestCheckOpenAI(t *testing.T) { + tests := []struct { + name string + input string + cacheKey string + retention CacheRetention + wantKey bool + wantRet bool + }{ + { + name: "with cache key only", + input: `{"model":"gpt-4","messages":[]}`, + cacheKey: "my-key", + retention: "", + wantKey: true, + wantRet: false, + }, + { + name: "with retention only", + input: 
`{"model":"gpt-4","messages":[]}`, + cacheKey: "", + retention: CacheRetention24h, + wantKey: false, + wantRet: true, + }, + { + name: "with both", + input: `{"model":"gpt-4","messages":[]}`, + cacheKey: "my-key", + retention: CacheRetention24h, + wantKey: true, + wantRet: true, + }, + { + name: "with neither", + input: `{"model":"gpt-4","messages":[]}`, + cacheKey: "", + retention: "", + wantKey: false, + wantRet: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + caching := &PromptCachingInterceptor{ + provider: "openai", + config: PromptCachingConfig{ + Enabled: true, + CacheKey: tt.cacheKey, + Retention: tt.retention, + }, + } + req, _ := http.NewRequest("POST", "http://example.com", bytes.NewReader([]byte(tt.input))) + meta := llmproxy.BodyMetadata{Model: "gpt-4"} + modified, shouldSkip := caching.checkOpenAI(req, meta, []byte(tt.input)) + + if tt.wantKey || tt.wantRet { + if shouldSkip { + t.Error("checkOpenAI should return false when modifications are needed (should not skip)") + } + } else { + if !shouldSkip { + t.Error("checkOpenAI should return true when no modifications needed (should skip)") + } + return + } + + if tt.wantKey { + if !bytes.Contains(modified, []byte(`"prompt_cache_key"`)) { + t.Error("Modified body should contain prompt_cache_key") + } + } + if tt.wantRet { + if !bytes.Contains(modified, []byte(`"prompt_cache_retention"`)) { + t.Error("Modified body should contain prompt_cache_retention") + } + } + }) + } +} + +func TestPromptCachingInterceptor_OpenAICacheKeyHeader(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + var req map[string]interface{} + if err := json.Unmarshal(body, &req); err != nil { + t.Fatalf("Failed to parse body: %v", err) + } + if req["prompt_cache_key"] != "header-key" { + t.Errorf("prompt_cache_key = %v, want header-key", req["prompt_cache_key"]) + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + caching := NewOpenAIPromptCaching(CacheRetentionDefault, "config-key") + + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(`{"model":"gpt-4","messages":[]}`))) + req.Header.Set(HeaderCacheKey, "header-key") + meta := llmproxy.BodyMetadata{Model: "gpt-4"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(`{"model":"gpt-4","messages":[]}`), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestPromptCachingInterceptor_OpenAINamespace(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + var req map[string]interface{} + if err := json.Unmarshal(body, &req); err != nil { + t.Fatalf("Failed to parse body: %v", err) + } + if req["prompt_cache_key"] != "tenant123:my-key" { + t.Errorf("prompt_cache_key = %v, want tenant123:my-key", req["prompt_cache_key"]) + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + caching := NewOpenAIPromptCachingWithNamespace("tenant123", CacheRetentionDefault, "my-key") + + req, _ := http.NewRequest("POST", upstream.URL, 
bytes.NewReader([]byte(`{"model":"gpt-4","messages":[]}`))) + meta := llmproxy.BodyMetadata{Model: "gpt-4"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(`{"model":"gpt-4","messages":[]}`), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestPromptCachingInterceptor_OpenAIAutoDerive(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + var req map[string]interface{} + if err := json.Unmarshal(body, &req); err != nil { + t.Fatalf("Failed to parse body: %v", err) + } + key, ok := req["prompt_cache_key"].(string) + if !ok || key == "" { + t.Error("prompt_cache_key should be auto-derived and not empty") + } + if !strings.HasPrefix(key, "tenant:") { + t.Errorf("prompt_cache_key should have namespace prefix, got %q", key) + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + caching := NewOpenAIPromptCachingAuto("tenant", CacheRetentionDefault) + + reqBody := `{"model":"gpt-4","system":"You are helpful.","messages":[{"role":"user","content":"Hello"}]}` + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(reqBody))) + meta := llmproxy.BodyMetadata{Model: "gpt-4"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(reqBody), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestDeriveCacheKeyFromPrefix(t *testing.T) { + key1 := DeriveCacheKeyFromPrefix(llmproxy.BodyMetadata{}, []byte(`{"model":"gpt-4","system":"You are helpful.","messages":[{"role":"user","content":"Hello"}]}`)) + key2 := DeriveCacheKeyFromPrefix(llmproxy.BodyMetadata{}, []byte(`{"model":"gpt-4","system":"You are helpful.","messages":[{"role":"user","content":"Hello"}]}`)) + key3 := DeriveCacheKeyFromPrefix(llmproxy.BodyMetadata{}, []byte(`{"model":"gpt-4","system":"Different system.","messages":[{"role":"user","content":"Hello"}]}`)) + + if key1 == "" { + t.Error("DeriveCacheKeyFromPrefix should return non-empty key for valid input") + } + if key1 != key2 { + t.Error("Same prefix should derive same key") + } + if key1 == key3 { + t.Error("Different prefix should derive different key") + } +} + +func TestPromptCachingInterceptor_OpenAIExistingKeyInBody(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + var req map[string]interface{} + if err := json.Unmarshal(body, &req); err != nil { + t.Fatalf("Failed to parse body: %v", err) + } + if req["prompt_cache_key"] != "existing-key" { + t.Errorf("prompt_cache_key = %v, want existing-key (should not be modified)", req["prompt_cache_key"]) + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + caching := NewOpenAIPromptCaching(CacheRetentionDefault, "new-key") + + req, _ := http.NewRequest("POST", upstream.URL, 
bytes.NewReader([]byte(`{"model":"gpt-4","prompt_cache_key":"existing-key","messages":[]}`))) + meta := llmproxy.BodyMetadata{Model: "gpt-4"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(`{"model":"gpt-4","prompt_cache_key":"existing-key","messages":[]}`), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestPromptCachingInterceptor_XAIAddsHeader(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("x-grok-conv-id") != "my-conv-123" { + t.Errorf("x-grok-conv-id header = %q, want my-conv-123", r.Header.Get("x-grok-conv-id")) + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + caching := NewXAIPromptCaching("my-conv-123") + + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(`{"model":"grok-2-1212","messages":[]}`))) + meta := llmproxy.BodyMetadata{Model: "grok-2-1212"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(`{"model":"grok-2-1212","messages":[]}`), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestPromptCachingInterceptor_XAISkipsNonXAI(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("x-grok-conv-id") != "" { + t.Error("x-grok-conv-id header should NOT be set for non-xAI model") + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + caching := NewXAIPromptCaching("my-conv-123") + + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(`{"model":"gpt-4","messages":[]}`))) + meta := llmproxy.BodyMetadata{Model: "gpt-4"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(`{"model":"gpt-4","messages":[]}`), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestPromptCachingInterceptor_XAIExistingHeader(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("x-grok-conv-id") != "existing-conv-id" { + t.Errorf("x-grok-conv-id header = %q, want existing-conv-id", r.Header.Get("x-grok-conv-id")) + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + caching := NewXAIPromptCaching("new-conv-id") + + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(`{"model":"grok-2-1212","messages":[]}`))) + req.Header.Set("x-grok-conv-id", "existing-conv-id") + meta := llmproxy.BodyMetadata{Model: "grok-2-1212"} + + next := func(req *http.Request) 
(*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(`{"model":"grok-2-1212","messages":[]}`), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestPromptCachingInterceptor_XAINoCacheKey(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("x-grok-conv-id") != "" { + t.Error("x-grok-conv-id header should NOT be set when no cache key provided") + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + caching := NewXAIPromptCaching("") + + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(`{"model":"grok-2-1212","messages":[]}`))) + meta := llmproxy.BodyMetadata{Model: "grok-2-1212"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(`{"model":"grok-2-1212","messages":[]}`), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestPromptCachingInterceptor_XAICacheControlNoCache(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("x-grok-conv-id") != "" { + t.Error("x-grok-conv-id header should NOT be set when Cache-Control: no-cache") + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + caching := NewXAIPromptCaching("my-conv-123") + + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(`{"model":"grok-2-1212","messages":[]}`))) + req.Header.Set("Cache-Control", "no-cache") + meta := llmproxy.BodyMetadata{Model: "grok-2-1212"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(`{"model":"grok-2-1212","messages":[]}`), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestIsXAIModel(t *testing.T) { + tests := []struct { + model string + expected bool + }{ + {"grok-2-1212", true}, + {"grok-3", true}, + {"grok-beta", true}, + {"grok-2-latest", true}, + {"gpt-4", false}, + {"claude-3-opus", false}, + {"gemini-pro", false}, + } + + for _, tt := range tests { + t.Run(tt.model, func(t *testing.T) { + result := isXAIModel(tt.model) + if result != tt.expected { + t.Errorf("isXAIModel(%q) = %v, want %v", tt.model, result, tt.expected) + } + }) + } +} + +func TestPromptCachingInterceptor_OpenAIOrgIDFromHeader(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + var req map[string]interface{} + if err := json.Unmarshal(body, &req); err != nil { + t.Fatalf("Failed to parse body: %v", err) + } + if req["prompt_cache_key"] != "org-abc:my-key" { + 
t.Errorf("prompt_cache_key = %v, want org-abc:my-key", req["prompt_cache_key"]) + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + caching := NewOpenAIPromptCaching(CacheRetentionDefault, "my-key") + + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(`{"model":"gpt-4","messages":[]}`))) + req.Header.Set(HeaderOrgID, "org-abc") + meta := llmproxy.BodyMetadata{Model: "gpt-4"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(`{"model":"gpt-4","messages":[]}`), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestPromptCachingInterceptor_OpenAIOrgIDFromMetaCustom(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + var req map[string]interface{} + if err := json.Unmarshal(body, &req); err != nil { + t.Fatalf("Failed to parse body: %v", err) + } + if req["prompt_cache_key"] != "tenant-xyz:my-key" { + t.Errorf("prompt_cache_key = %v, want tenant-xyz:my-key", req["prompt_cache_key"]) + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + caching := NewOpenAIPromptCaching(CacheRetentionDefault, "my-key") + + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(`{"model":"gpt-4","messages":[]}`))) + meta := llmproxy.BodyMetadata{ + Model: "gpt-4", + Custom: map[string]any{ + "org_id": "tenant-xyz", + }, + } + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(`{"model":"gpt-4","messages":[]}`), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestPromptCachingInterceptor_OpenAIOrgIDExtractor(t *testing.T) { + customExtractor := func(ctx context.Context, req *http.Request, meta llmproxy.BodyMetadata) string { + return "custom-org-123" + } + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + var req map[string]interface{} + if err := json.Unmarshal(body, &req); err != nil { + t.Fatalf("Failed to parse body: %v", err) + } + if req["prompt_cache_key"] != "custom-org-123:my-key" { + t.Errorf("prompt_cache_key = %v, want custom-org-123:my-key", req["prompt_cache_key"]) + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + caching := NewOpenAIPromptCachingWithOrgExtractor(CacheRetentionDefault, "my-key", customExtractor) + + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(`{"model":"gpt-4","messages":[]}`))) + meta := llmproxy.BodyMetadata{Model: "gpt-4"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, 
err := caching.Intercept(req, meta, []byte(`{"model":"gpt-4","messages":[]}`), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestPromptCachingInterceptor_OpenAICacheKeyHeaderOverridesOrgID(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + var req map[string]interface{} + if err := json.Unmarshal(body, &req); err != nil { + t.Fatalf("Failed to parse body: %v", err) + } + if req["prompt_cache_key"] != "org-abc:header-key" { + t.Errorf("prompt_cache_key = %v, want org-abc:header-key", req["prompt_cache_key"]) + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + caching := NewOpenAIPromptCaching(CacheRetentionDefault, "config-key") + + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(`{"model":"gpt-4","messages":[]}`))) + req.Header.Set(HeaderOrgID, "org-abc") + req.Header.Set(HeaderCacheKey, "header-key") + meta := llmproxy.BodyMetadata{Model: "gpt-4"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(`{"model":"gpt-4","messages":[]}`), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestDefaultOrgIDExtractor(t *testing.T) { + tests := []struct { + name string + setup func(*http.Request, *llmproxy.BodyMetadata) + expected string + }{ + { + name: "from header", + setup: func(req *http.Request, _ *llmproxy.BodyMetadata) { req.Header.Set(HeaderOrgID, "org-header") }, + expected: "org-header", + }, + { + name: "from meta custom", + setup: func(_ *http.Request, meta *llmproxy.BodyMetadata) { meta.Custom = map[string]any{"org_id": "org-meta"} }, + expected: "org-meta", + }, + { + name: "none", + setup: func(*http.Request, *llmproxy.BodyMetadata) {}, + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req, _ := http.NewRequest("POST", "http://example.com", nil) + meta := llmproxy.BodyMetadata{} + tt.setup(req, &meta) + result := DefaultOrgIDExtractor(req.Context(), req, meta) + if result != tt.expected { + t.Errorf("DefaultOrgIDExtractor() = %q, want %q", result, tt.expected) + } + }) + } +} + +func TestPromptCachingInterceptor_FireworksAddsHeaders(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get(HeaderFireworksSessionAffinity) != "session-123" { + t.Errorf("x-session-affinity header = %q, want session-123", r.Header.Get(HeaderFireworksSessionAffinity)) + } + if r.Header.Get(HeaderFireworksPromptCacheIsolation) != "org-abc" { + t.Errorf("x-prompt-cache-isolation-key header = %q, want org-abc", r.Header.Get(HeaderFireworksPromptCacheIsolation)) + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + caching := NewFireworksPromptCachingWithOrgExtractor("session-123", func(ctx context.Context, req *http.Request, meta llmproxy.BodyMetadata) string { + return "org-abc" + }) + + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(`{"model":"accounts/fireworks/models/llama-v3-70b-instruct","messages":[]}`))) + meta := llmproxy.BodyMetadata{Model: 
"accounts/fireworks/models/llama-v3-70b-instruct"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(`{"model":"accounts/fireworks/models/llama-v3-70b-instruct","messages":[]}`), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestPromptCachingInterceptor_FireworksSkipsNonFireworks(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get(HeaderFireworksSessionAffinity) != "" { + t.Error("x-session-affinity header should NOT be set for non-Fireworks model") + } + if r.Header.Get(HeaderFireworksPromptCacheIsolation) != "" { + t.Error("x-prompt-cache-isolation-key header should NOT be set for non-Fireworks model") + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + caching := NewFireworksPromptCaching("session-123") + + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(`{"model":"gpt-4","messages":[]}`))) + meta := llmproxy.BodyMetadata{Model: "gpt-4"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(`{"model":"gpt-4","messages":[]}`), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestPromptCachingInterceptor_FireworksExistingHeaders(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get(HeaderFireworksSessionAffinity) != "existing-session" { + t.Errorf("x-session-affinity header = %q, want existing-session", r.Header.Get(HeaderFireworksSessionAffinity)) + } + if r.Header.Get(HeaderFireworksPromptCacheIsolation) != "existing-org" { + t.Errorf("x-prompt-cache-isolation-key header = %q, want existing-org", r.Header.Get(HeaderFireworksPromptCacheIsolation)) + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + caching := NewFireworksPromptCachingWithOrgExtractor("new-session", func(ctx context.Context, req *http.Request, meta llmproxy.BodyMetadata) string { + return "new-org" + }) + + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(`{"model":"accounts/fireworks/models/llama-v3-70b-instruct","messages":[]}`))) + req.Header.Set(HeaderFireworksSessionAffinity, "existing-session") + req.Header.Set(HeaderFireworksPromptCacheIsolation, "existing-org") + meta := llmproxy.BodyMetadata{Model: "accounts/fireworks/models/llama-v3-70b-instruct"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(`{"model":"accounts/fireworks/models/llama-v3-70b-instruct","messages":[]}`), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", 
err) + } +} + +func TestPromptCachingInterceptor_FireworksNoSessionID(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get(HeaderFireworksSessionAffinity) != "" { + t.Error("x-session-affinity header should NOT be set when no session ID provided") + } + if r.Header.Get(HeaderFireworksPromptCacheIsolation) != "" { + t.Error("x-prompt-cache-isolation-key should NOT be set when no org ID") + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + caching := NewFireworksPromptCaching("") + + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(`{"model":"accounts/fireworks/models/llama-v3-70b-instruct","messages":[]}`))) + meta := llmproxy.BodyMetadata{Model: "accounts/fireworks/models/llama-v3-70b-instruct"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(`{"model":"accounts/fireworks/models/llama-v3-70b-instruct","messages":[]}`), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestPromptCachingInterceptor_FireworksCacheUsage(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("fireworks-prompt-tokens", "2006") + w.Header().Set("fireworks-cached-prompt-tokens", "1920") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + var cacheUsage llmproxy.CacheUsage + caching := NewFireworksPromptCachingWithResult("session-123", func(u llmproxy.CacheUsage) { + cacheUsage = u + }) + + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(`{"model":"accounts/fireworks/models/llama-v3-70b-instruct","messages":[]}`))) + meta := llmproxy.BodyMetadata{Model: "accounts/fireworks/models/llama-v3-70b-instruct"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(`{"model":"accounts/fireworks/models/llama-v3-70b-instruct","messages":[]}`), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } + + if cacheUsage.CachedTokens != 1920 { + t.Errorf("CachedTokens = %d, want 1920", cacheUsage.CachedTokens) + } +} + +func TestPromptCachingInterceptor_FireworksCacheControlNoCache(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get(HeaderFireworksSessionAffinity) != "" { + t.Error("x-session-affinity header should NOT be set when Cache-Control: no-cache") + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + caching := NewFireworksPromptCaching("session-123") + + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(`{"model":"accounts/fireworks/models/llama-v3-70b-instruct","messages":[]}`))) + req.Header.Set("Cache-Control", "no-cache") + meta := llmproxy.BodyMetadata{Model: "accounts/fireworks/models/llama-v3-70b-instruct"} + + next := func(req 
*http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(`{"model":"accounts/fireworks/models/llama-v3-70b-instruct","messages":[]}`), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestIsFireworksModel(t *testing.T) { + tests := []struct { + model string + expected bool + }{ + {"accounts/fireworks/models/llama-v3-70b-instruct", true}, + {"accounts/fireworks/models/qwen2p5-72b", true}, + {"fireworks-model", true}, + {"gpt-4", false}, + {"claude-3-opus", false}, + {"grok-2", false}, + } + + for _, tt := range tests { + t.Run(tt.model, func(t *testing.T) { + result := isFireworksModel(tt.model) + if result != tt.expected { + t.Errorf("isFireworksModel(%q) = %v, want %v", tt.model, result, tt.expected) + } + }) + } +} + +func TestIsBedrockModel(t *testing.T) { + tests := []struct { + model string + expected bool + }{ + {"anthropic.claude-3-sonnet-20240229-v1:0", true}, + {"anthropic.claude-3-opus-20240229-v1:0", true}, + {"anthropic.claude-opus-4-5-20251101-v1:0", true}, + {"amazon.nova-micro-v1:0", true}, + {"amazon.nova-lite-v1:0", true}, + {"amazon.nova-pro-v1:0", true}, + {"amazon.titan-text-express-v1", true}, + {"gpt-4", false}, + {"claude-3-opus", false}, + {"grok-2", false}, + } + + for _, tt := range tests { + t.Run(tt.model, func(t *testing.T) { + result := isBedrockModel(tt.model) + if result != tt.expected { + t.Errorf("isBedrockModel(%q) = %v, want %v", tt.model, result, tt.expected) + } + }) + } +} + +func TestPromptCachingInterceptor_BedrockSystemCachePoint(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + var req map[string]interface{} + if err := json.Unmarshal(body, &req); err != nil { + t.Fatalf("Failed to parse: %v", err) + } + system, ok := req["system"].([]interface{}) + if !ok { + t.Fatal("System should be an array") + } + lastBlock, ok := system[len(system)-1].(map[string]interface{}) + if !ok { + t.Fatal("Last block should be an object") + } + if cp, ok := lastBlock["cachePoint"].(map[string]interface{}); ok { + if cp["type"] != "default" { + t.Errorf("cachePoint type = %v, want default", cp["type"]) + } + } else { + t.Error("Last system block should have cachePoint") + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + caching := NewBedrockPromptCaching(CacheRetentionDefault) + + reqBody := `{"modelId":"anthropic.claude-3-sonnet-20240229-v1:0","system":[{"text":"You are helpful."}],"messages":[{"role":"user","content":[{"text":"Hello"}]}]}` + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(reqBody))) + meta := llmproxy.BodyMetadata{Model: "anthropic.claude-3-sonnet-20240229-v1:0"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(reqBody), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestPromptCachingInterceptor_BedrockMessagesCachePoint(t 
*testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + var req map[string]interface{} + if err := json.Unmarshal(body, &req); err != nil { + t.Fatalf("Failed to parse: %v", err) + } + messages, ok := req["messages"].([]interface{}) + if !ok { + t.Fatal("Messages should be an array") + } + lastMsg, ok := messages[len(messages)-1].(map[string]interface{}) + if !ok { + t.Fatal("Last message should be an object") + } + content, ok := lastMsg["content"].([]interface{}) + if !ok { + t.Fatal("Last message content should be an array") + } + lastBlock, ok := content[len(content)-1].(map[string]interface{}) + if !ok { + t.Fatal("Last content block should be an object") + } + if _, has := lastBlock["cachePoint"]; !has { + t.Error("Last message content block should have cachePoint") + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + caching := NewBedrockPromptCaching(CacheRetentionDefault) + + reqBody := `{"modelId":"anthropic.claude-3-sonnet-20240229-v1:0","messages":[{"role":"user","content":[{"text":"Hello"}]}]}` + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(reqBody))) + meta := llmproxy.BodyMetadata{Model: "anthropic.claude-3-sonnet-20240229-v1:0"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(reqBody), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestPromptCachingInterceptor_BedrockToolsCachePoint(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + var req map[string]interface{} + if err := json.Unmarshal(body, &req); err != nil { + t.Fatalf("Failed to parse: %v", err) + } + toolConfig, ok := req["toolConfig"].(map[string]interface{}) + if !ok { + t.Fatal("toolConfig should be an object") + } + tools, ok := toolConfig["tools"].([]interface{}) + if !ok { + t.Fatal("tools should be an array") + } + lastBlock, ok := tools[len(tools)-1].(map[string]interface{}) + if !ok { + t.Fatal("Last tool block should be an object") + } + if _, has := lastBlock["cachePoint"]; !has { + t.Error("Last tool block should have cachePoint") + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + caching := NewBedrockPromptCaching(CacheRetentionDefault) + + reqBody := `{"modelId":"anthropic.claude-3-sonnet-20240229-v1:0","messages":[{"role":"user","content":[{"text":"Hello"}]}],"toolConfig":{"tools":[{"toolSpec":{"name":"get_weather","description":"Get weather","inputSchema":{"json":{"type":"object"}}}}]}}` + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(reqBody))) + meta := llmproxy.BodyMetadata{Model: "anthropic.claude-3-sonnet-20240229-v1:0"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(reqBody), next) + if err != nil { + t.Fatalf("Intercept 
returned error: %v", err) + } +} + +func TestPromptCachingInterceptor_Bedrock1hRetention(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + if !bytes.Contains(body, []byte(`"ttl":"1h"`)) { + t.Error("Request body should contain ttl 1h") + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + caching := NewBedrockPromptCaching(CacheRetention1h) + + reqBody := `{"modelId":"anthropic.claude-opus-4-5-20251101-v1:0","system":[{"text":"You are helpful."}],"messages":[{"role":"user","content":[{"text":"Hello"}]}]}` + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(reqBody))) + meta := llmproxy.BodyMetadata{Model: "anthropic.claude-opus-4-5-20251101-v1:0"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(reqBody), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestPromptCachingInterceptor_BedrockExistingCachePoint(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + if bytes.Count(body, []byte(`"cachePoint"`)) > 1 { + t.Error("Request body should not have additional cachePoint added") + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + caching := NewBedrockPromptCaching(CacheRetentionDefault) + + reqBody := `{"modelId":"anthropic.claude-3-sonnet-20240229-v1:0","system":[{"text":"You are helpful.","cachePoint":{"type":"default"}}],"messages":[{"role":"user","content":[{"text":"Hello"}]}]}` + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(reqBody))) + meta := llmproxy.BodyMetadata{Model: "anthropic.claude-3-sonnet-20240229-v1:0"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(reqBody), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestPromptCachingInterceptor_BedrockSkipsNonBedrock(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + if bytes.Contains(body, []byte(`"cachePoint"`)) { + t.Error("Request body should NOT contain cachePoint for non-Bedrock model") + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + caching := NewBedrockPromptCaching(CacheRetentionDefault) + + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(`{"model":"gpt-4","messages":[]}`))) + meta := llmproxy.BodyMetadata{Model: "gpt-4"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := 
caching.Intercept(req, meta, []byte(`{"model":"gpt-4","messages":[]}`), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestPromptCachingInterceptor_BedrockCacheControlNoCache(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + if bytes.Contains(body, []byte(`"cachePoint"`)) { + t.Error("Request body should NOT contain cachePoint when Cache-Control: no-cache is set") + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + caching := NewBedrockPromptCaching(CacheRetentionDefault) + + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(`{"modelId":"anthropic.claude-3-sonnet-20240229-v1:0","messages":[]}`))) + req.Header.Set("Cache-Control", "no-cache") + meta := llmproxy.BodyMetadata{Model: "anthropic.claude-3-sonnet-20240229-v1:0"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(`{"modelId":"anthropic.claude-3-sonnet-20240229-v1:0","messages":[]}`), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestPromptCachingInterceptor_BedrockCacheUsage(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"usage":{"inputTokens":100,"outputTokens":50,"totalTokens":150,"cacheReadInputTokens":80,"cacheWriteInputTokens":20,"cacheDetails":[{"ttl":"5m","cacheWriteInputTokens":20}]}}`)) + })) + defer upstream.Close() + + var cacheUsage llmproxy.CacheUsage + caching := NewBedrockPromptCachingWithResult(CacheRetentionDefault, func(u llmproxy.CacheUsage) { + cacheUsage = u + }) + + reqBody := `{"modelId":"anthropic.claude-3-sonnet-20240229-v1:0","system":[{"text":"You are helpful."}],"messages":[{"role":"user","content":[{"text":"Hello"}]}]}` + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(reqBody))) + meta := llmproxy.BodyMetadata{Model: "anthropic.claude-3-sonnet-20240229-v1:0"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + respMeta := llmproxy.ResponseMetadata{ + Custom: map[string]any{ + "cache_usage": llmproxy.CacheUsage{ + CachedTokens: 80, + CacheWriteTokens: 20, + CacheDetails: []llmproxy.CacheDetail{ + {TTL: "5m", CacheWriteTokens: 20}, + }, + }, + }, + } + return resp, respMeta, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(reqBody), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } + + if cacheUsage.CachedTokens != 80 { + t.Errorf("CachedTokens = %d, want 80", cacheUsage.CachedTokens) + } + if cacheUsage.CacheWriteTokens != 20 { + t.Errorf("CacheWriteTokens = %d, want 20", cacheUsage.CacheWriteTokens) + } +} + +func TestPromptCachingInterceptor_XAIAutoDerive(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + convID := r.Header.Get("x-grok-conv-id") + if convID == "" { + t.Error("x-grok-conv-id header should be 
auto-derived") + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + caching := NewXAIPromptCachingAuto() + + reqBody := `{"model":"grok-2-1212","system":"You are helpful.","messages":[{"role":"user","content":"Hello"}]}` + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(reqBody))) + meta := llmproxy.BodyMetadata{Model: "grok-2-1212"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(reqBody), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestPromptCachingInterceptor_XAIWithTraceID(t *testing.T) { + traceID := [16]byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10} + expectedTraceIDHex := hex.EncodeToString(traceID[:]) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + convID := r.Header.Get("x-grok-conv-id") + if convID != expectedTraceIDHex { + t.Errorf("x-grok-conv-id = %q, want %q", convID, expectedTraceIDHex) + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + traceExtractor := func(ctx context.Context) TraceInfo { + return TraceInfo{TraceID: traceID} + } + + caching := NewXAIPromptCachingWithTraceID(traceExtractor) + + reqBody := `{"model":"grok-2-1212","messages":[]}` + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(reqBody))) + meta := llmproxy.BodyMetadata{Model: "grok-2-1212"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(reqBody), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestPromptCachingInterceptor_FireworksAutoDerive(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + sessionID := r.Header.Get(HeaderFireworksSessionAffinity) + if sessionID == "" { + t.Error("x-session-affinity header should be auto-derived") + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + caching := NewFireworksPromptCachingAuto() + + reqBody := `{"model":"accounts/fireworks/models/llama-v3-70b-instruct","system":"You are helpful.","messages":[{"role":"user","content":"Hello"}]}` + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(reqBody))) + meta := llmproxy.BodyMetadata{Model: "accounts/fireworks/models/llama-v3-70b-instruct"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(reqBody), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestPromptCachingInterceptor_FireworksWithTraceID(t *testing.T) { + traceID := 
[16]byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10} + expectedTraceIDHex := hex.EncodeToString(traceID[:]) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + sessionID := r.Header.Get(HeaderFireworksSessionAffinity) + if sessionID != expectedTraceIDHex { + t.Errorf("x-session-affinity = %q, want %q", sessionID, expectedTraceIDHex) + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + traceExtractor := func(ctx context.Context) TraceInfo { + return TraceInfo{TraceID: traceID} + } + + caching := NewFireworksPromptCachingWithTraceID(traceExtractor) + + reqBody := `{"model":"accounts/fireworks/models/llama-v3-70b-instruct","messages":[]}` + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(reqBody))) + meta := llmproxy.BodyMetadata{Model: "accounts/fireworks/models/llama-v3-70b-instruct"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(reqBody), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestPromptCachingInterceptor_XAICacheKeyHeader(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + convID := r.Header.Get("x-grok-conv-id") + if convID != "header-key" { + t.Errorf("x-grok-conv-id = %q, want header-key", convID) + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + caching := NewXAIPromptCaching("config-key") + + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(`{"model":"grok-2-1212","messages":[]}`))) + req.Header.Set(HeaderCacheKey, "header-key") + meta := llmproxy.BodyMetadata{Model: "grok-2-1212"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(`{"model":"grok-2-1212","messages":[]}`), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestPromptCachingInterceptor_FireworksCacheKeyHeader(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + sessionID := r.Header.Get(HeaderFireworksSessionAffinity) + if sessionID != "header-key" { + t.Errorf("x-session-affinity = %q, want header-key", sessionID) + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + })) + defer upstream.Close() + + caching := NewFireworksPromptCaching("config-key") + + req, _ := http.NewRequest("POST", upstream.URL, bytes.NewReader([]byte(`{"model":"accounts/fireworks/models/llama-v3-70b-instruct","messages":[]}`))) + req.Header.Set(HeaderCacheKey, "header-key") + meta := llmproxy.BodyMetadata{Model: "accounts/fireworks/models/llama-v3-70b-instruct"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body, _ := 
io.ReadAll(resp.Body) + return resp, llmproxy.ResponseMetadata{}, body, nil + } + + _, _, _, err := caching.Intercept(req, meta, []byte(`{"model":"accounts/fireworks/models/llama-v3-70b-instruct","messages":[]}`), next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} diff --git a/metadata.go b/metadata.go index eefefee..15bef47 100644 --- a/metadata.go +++ b/metadata.go @@ -56,6 +56,32 @@ type Usage struct { TotalTokens int `json:"total_tokens"` } +// CacheUsage tracks prompt caching token consumption. +type CacheUsage struct { + // CachedTokens is the number of tokens served from cache (OpenAI). + CachedTokens int `json:"cached_tokens,omitempty"` + // CacheCreationInputTokens is the number of tokens written to cache (Anthropic). + CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"` + // CacheReadInputTokens is the number of tokens read from cache (Anthropic). + CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"` + // Ephemeral5mInputTokens is the number of 5-minute cache write tokens (Anthropic). + Ephemeral5mInputTokens int `json:"ephemeral_5m_input_tokens,omitempty"` + // Ephemeral1hInputTokens is the number of 1-hour cache write tokens (Anthropic). + Ephemeral1hInputTokens int `json:"ephemeral_1h_input_tokens,omitempty"` + // CacheWriteTokens is the number of tokens written to cache (Bedrock). + CacheWriteTokens int `json:"cache_write_tokens,omitempty"` + // CacheDetails contains TTL-based cache write breakdown (Bedrock). + CacheDetails []CacheDetail `json:"cache_details,omitempty"` +} + +// CacheDetail contains cache details for a checkpoint (Bedrock). +type CacheDetail struct { + // TTL is the time-to-live for the cache entry (e.g., "5m", "1h"). + TTL string `json:"ttl,omitempty"` + // CacheWriteTokens is the number of tokens written to cache at this TTL. + CacheWriteTokens int `json:"cache_write_tokens,omitempty"` +} + // Choice represents a single completion choice in the response. type Choice struct { // Index is the position of this choice in the choices array. diff --git a/providers/anthropic/extractor.go b/providers/anthropic/extractor.go index a3295a1..67cd48c 100644 --- a/providers/anthropic/extractor.go +++ b/providers/anthropic/extractor.go @@ -42,6 +42,18 @@ func (e *Extractor) Extract(resp *http.Response) (llmproxy.ResponseMetadata, []b Custom: make(map[string]any), } + cacheUsage := llmproxy.CacheUsage{ + CacheCreationInputTokens: anthropicResp.Usage.CacheCreationInputTokens, + CacheReadInputTokens: anthropicResp.Usage.CacheReadInputTokens, + } + if anthropicResp.CacheCreation != nil { + cacheUsage.Ephemeral5mInputTokens = anthropicResp.CacheCreation.Ephemeral5mInputTokens + cacheUsage.Ephemeral1hInputTokens = anthropicResp.CacheCreation.Ephemeral1hInputTokens + } + if cacheUsage.CacheCreationInputTokens > 0 || cacheUsage.CacheReadInputTokens > 0 { + meta.Custom["cache_usage"] = cacheUsage + } + if len(anthropicResp.Content) > 0 { var content string var role string @@ -70,14 +82,15 @@ func (e *Extractor) Extract(resp *http.Response) (llmproxy.ResponseMetadata, []b // Response represents an Anthropic messages API response. 
type Response struct { - ID string `json:"id"` - Type string `json:"type"` - Role string `json:"role"` - Model string `json:"model"` - Content []ContentBlock `json:"content"` - StopReason string `json:"stop_reason"` - StopSequence string `json:"stop_sequence,omitempty"` - Usage UsageInfo `json:"usage"` + ID string `json:"id"` + Type string `json:"type"` + Role string `json:"role"` + Model string `json:"model"` + Content []ContentBlock `json:"content"` + StopReason string `json:"stop_reason"` + StopSequence string `json:"stop_sequence,omitempty"` + Usage UsageInfo `json:"usage"` + CacheCreation *CacheCreationInfo `json:"cache_creation,omitempty"` } // ContentBlock represents a content block in an Anthropic response. @@ -88,8 +101,16 @@ type ContentBlock struct { // UsageInfo tracks token usage in an Anthropic response. type UsageInfo struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + CacheCreationInputTokens int `json:"cache_creation_input_tokens"` + CacheReadInputTokens int `json:"cache_read_input_tokens"` +} + +// CacheCreationInfo tracks cache creation token breakdown. +type CacheCreationInfo struct { + Ephemeral5mInputTokens int `json:"ephemeral_5m_input_tokens,omitempty"` + Ephemeral1hInputTokens int `json:"ephemeral_1h_input_tokens,omitempty"` } // NewExtractor creates a new Anthropic response extractor. diff --git a/providers/anthropic/parser_test.go b/providers/anthropic/parser_test.go index d74c65c..aed02cd 100644 --- a/providers/anthropic/parser_test.go +++ b/providers/anthropic/parser_test.go @@ -154,4 +154,50 @@ func TestExtractor(t *testing.T) { t.Error("raw body mismatch") } }) + + t.Run("extracts cache usage", func(t *testing.T) { + extractor := &Extractor{} + respBody := `{"id":"msg_123","type":"message","role":"assistant","model":"claude-3-opus-20240229","content":[{"type":"text","text":"Hello!"}],"stop_reason":"end_turn","usage":{"input_tokens":50,"output_tokens":5,"cache_creation_input_tokens":500,"cache_read_input_tokens":1000},"cache_creation":{"ephemeral_5m_input_tokens":500,"ephemeral_1h_input_tokens":0}}` + + resp := &http.Response{ + Body: io.NopCloser(bytes.NewReader([]byte(respBody))), + } + + meta, _, err := extractor.Extract(resp) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + cacheUsage, ok := meta.Custom["cache_usage"].(llmproxy.CacheUsage) + if !ok { + t.Fatal("expected cache_usage in Custom map") + } + if cacheUsage.CacheCreationInputTokens != 500 { + t.Errorf("expected 500 cache creation tokens, got %d", cacheUsage.CacheCreationInputTokens) + } + if cacheUsage.CacheReadInputTokens != 1000 { + t.Errorf("expected 1000 cache read tokens, got %d", cacheUsage.CacheReadInputTokens) + } + if cacheUsage.Ephemeral5mInputTokens != 500 { + t.Errorf("expected 500 5m tokens, got %d", cacheUsage.Ephemeral5mInputTokens) + } + }) + + t.Run("no cache usage when not present", func(t *testing.T) { + extractor := &Extractor{} + respBody := `{"id":"msg_123","type":"message","role":"assistant","model":"claude-3-opus-20240229","content":[{"type":"text","text":"Hello!"}],"stop_reason":"end_turn","usage":{"input_tokens":10,"output_tokens":5}}` + + resp := &http.Response{ + Body: io.NopCloser(bytes.NewReader([]byte(respBody))), + } + + meta, _, err := extractor.Extract(resp) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if _, ok := meta.Custom["cache_usage"]; ok { + t.Error("expected no cache_usage in Custom map when not 
+			t.Error("expected no cache_usage in Custom map when not present in response")
+		}
+	})
 }
diff --git a/providers/bedrock/extractor.go b/providers/bedrock/extractor.go
index 070dddf..43afbfe 100644
--- a/providers/bedrock/extractor.go
+++ b/providers/bedrock/extractor.go
@@ -52,14 +52,35 @@ func (e *Extractor) Extract(resp *http.Response) (llmproxy.ResponseMetadata, []b
 		})
 	}
 
-	// Store additional metrics
 	if bedrockResp.Metrics != nil {
 		meta.Custom["latency_ms"] = bedrockResp.Metrics.LatencyMs
 	}
 
+	if bedrockResp.Usage.CacheReadInputTokens > 0 || bedrockResp.Usage.CacheWriteInputTokens > 0 || len(bedrockResp.Usage.CacheDetails) > 0 {
+		meta.Custom["cache_usage"] = llmproxy.CacheUsage{
+			CachedTokens:     bedrockResp.Usage.CacheReadInputTokens,
+			CacheWriteTokens: bedrockResp.Usage.CacheWriteInputTokens,
+			CacheDetails:     extractCacheDetails(bedrockResp.Usage.CacheDetails),
+		}
+	}
+
 	return meta, body, nil
 }
 
+func extractCacheDetails(details []CacheDetail) []llmproxy.CacheDetail {
+	if len(details) == 0 {
+		return nil
+	}
+	result := make([]llmproxy.CacheDetail, len(details))
+	for i, d := range details {
+		result[i] = llmproxy.CacheDetail{
+			TTL:              d.TTL,
+			CacheWriteTokens: d.CacheWriteInputTokens,
+		}
+	}
+	return result
+}
+
 func extractOutputText(content []ContentBlock) string {
 	var text string
 	for _, block := range content {
@@ -93,9 +114,18 @@ type OutputMessage struct {
 
 // ResponseUsage contains token usage information.
 type ResponseUsage struct {
-	InputTokens  int `json:"inputTokens"`
-	OutputTokens int `json:"outputTokens"`
-	TotalTokens  int `json:"totalTokens"`
+	InputTokens           int           `json:"inputTokens"`
+	OutputTokens          int           `json:"outputTokens"`
+	TotalTokens           int           `json:"totalTokens"`
+	CacheReadInputTokens  int           `json:"cacheReadInputTokens,omitempty"`
+	CacheWriteInputTokens int           `json:"cacheWriteInputTokens,omitempty"`
+	CacheDetails          []CacheDetail `json:"cacheDetails,omitempty"`
+}
+
+// CacheDetail contains cache details for a checkpoint.
+type CacheDetail struct {
+	TTL                   string `json:"ttl"`
+	CacheWriteInputTokens int    `json:"cacheWriteInputTokens"`
 }
 
 // ResponseMetrics contains performance metrics.
diff --git a/providers/bedrock/parser.go b/providers/bedrock/parser.go
index e842ba8..c9d54f7 100644
--- a/providers/bedrock/parser.go
+++ b/providers/bedrock/parser.go
@@ -105,6 +105,13 @@ type ContentBlock struct {
 	Image      *ImageSource `json:"image,omitempty"`
 	ToolUse    *ToolUse     `json:"toolUse,omitempty"`
 	ToolResult *ToolResult  `json:"toolResult,omitempty"`
+	CachePoint *CachePoint  `json:"cachePoint,omitempty"`
+}
+
+// CachePoint represents a cache checkpoint for prompt caching.
+type CachePoint struct {
+	Type string `json:"type"`
+	TTL  string `json:"ttl,omitempty"`
 }
 
 // ImageSource represents an image in a content block.
@@ -140,7 +147,8 @@ type ToolResult struct {
 
 // SystemBlock represents a system message block.
 type SystemBlock struct {
-	Text string `json:"text"`
+	Text       string      `json:"text"`
+	CachePoint *CachePoint `json:"cachePoint,omitempty"`
 }
 
 // InferenceConfig contains inference parameters.
@@ -159,7 +167,8 @@ type ToolConfig struct {
 
 // Tool represents a tool definition.
 type Tool struct {
-	ToolSpec *ToolSpec `json:"toolSpec,omitempty"`
+	ToolSpec   *ToolSpec   `json:"toolSpec,omitempty"`
+	CachePoint *CachePoint `json:"cachePoint,omitempty"`
 }
 
 // ToolSpec contains tool specification.
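A hedged sketch of the usage payload the new Bedrock fields decode; the structs are repeated from the extractor diff above only so the snippet stands alone, and the token counts are invented for illustration. The extractor then maps `cacheReadInputTokens` onto `CachedTokens`, `cacheWriteInputTokens` onto `CacheWriteTokens`, and each `cacheDetails` entry onto a `llmproxy.CacheDetail`.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Copied from providers/bedrock/extractor.go above for self-containment.
type CacheDetail struct {
	TTL                   string `json:"ttl"`
	CacheWriteInputTokens int    `json:"cacheWriteInputTokens"`
}

type ResponseUsage struct {
	InputTokens           int           `json:"inputTokens"`
	OutputTokens          int           `json:"outputTokens"`
	TotalTokens           int           `json:"totalTokens"`
	CacheReadInputTokens  int           `json:"cacheReadInputTokens,omitempty"`
	CacheWriteInputTokens int           `json:"cacheWriteInputTokens,omitempty"`
	CacheDetails          []CacheDetail `json:"cacheDetails,omitempty"`
}

func main() {
	// Illustrative numbers; only the field names and nesting mirror the diff.
	raw := `{"inputTokens":40,"outputTokens":12,"totalTokens":52,
	         "cacheReadInputTokens":1024,"cacheWriteInputTokens":512,
	         "cacheDetails":[{"ttl":"5m","cacheWriteInputTokens":512}]}`

	var u ResponseUsage
	if err := json.Unmarshal([]byte(raw), &u); err != nil {
		panic(err)
	}
	fmt.Println(u.CacheReadInputTokens, u.CacheWriteInputTokens, u.CacheDetails[0].TTL)
	// prints: 1024 512 5m
}
```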
diff --git a/providers/openai_compatible/extractor.go b/providers/openai_compatible/extractor.go
index 7e6893f..fc92a9b 100644
--- a/providers/openai_compatible/extractor.go
+++ b/providers/openai_compatible/extractor.go
@@ -46,6 +46,12 @@ func (e *Extractor) Extract(resp *http.Response) (llmproxy.ResponseMetadata, []b
 		Custom: make(map[string]any),
 	}
 
+	if openaiResp.Usage.PromptTokensDetails != nil && openaiResp.Usage.PromptTokensDetails.CachedTokens > 0 {
+		meta.Custom["cache_usage"] = llmproxy.CacheUsage{
+			CachedTokens: openaiResp.Usage.PromptTokensDetails.CachedTokens,
+		}
+	}
+
 	for i, c := range openaiResp.Choices {
 		meta.Choices[i] = llmproxy.Choice{
 			Index: c.Index,
@@ -86,9 +92,25 @@ type OpenAIResponse struct {
 
 // UsageInfo tracks token usage in an OpenAI-compatible response.
 type UsageInfo struct {
-	PromptTokens     int `json:"prompt_tokens"`
-	CompletionTokens int `json:"completion_tokens"`
-	TotalTokens      int `json:"total_tokens"`
+	PromptTokens            int                      `json:"prompt_tokens"`
+	CompletionTokens        int                      `json:"completion_tokens"`
+	TotalTokens             int                      `json:"total_tokens"`
+	PromptTokensDetails     *PromptTokensDetails     `json:"prompt_tokens_details,omitempty"`
+	CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details,omitempty"`
+}
+
+// PromptTokensDetails contains detailed prompt token breakdown.
+type PromptTokensDetails struct {
+	CachedTokens int `json:"cached_tokens,omitempty"`
+	AudioTokens  int `json:"audio_tokens,omitempty"`
+}
+
+// CompletionTokensDetails contains detailed completion token breakdown.
+type CompletionTokensDetails struct {
+	ReasoningTokens          int `json:"reasoning_tokens,omitempty"`
+	AudioTokens              int `json:"audio_tokens,omitempty"`
+	AcceptedPredictionTokens int `json:"accepted_prediction_tokens,omitempty"`
+	RejectedPredictionTokens int `json:"rejected_prediction_tokens,omitempty"`
 }
 
 // ResponseChoice represents a single completion choice.
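The tests in the next file exercise three usage shapes; this standalone sketch (structs trimmed from the extractor diff above, payload values illustrative) shows why only the first one yields a `cache_usage` entry: the extractor requires a non-nil `prompt_tokens_details` object and `cached_tokens > 0`.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Trimmed copies of the structs added above, for self-containment.
type PromptTokensDetails struct {
	CachedTokens int `json:"cached_tokens,omitempty"`
	AudioTokens  int `json:"audio_tokens,omitempty"`
}

type UsageInfo struct {
	PromptTokens        int                  `json:"prompt_tokens"`
	CompletionTokens    int                  `json:"completion_tokens"`
	TotalTokens         int                  `json:"total_tokens"`
	PromptTokensDetails *PromptTokensDetails `json:"prompt_tokens_details,omitempty"`
}

func main() {
	payloads := []string{
		`{"prompt_tokens":2006,"completion_tokens":300,"total_tokens":2306,"prompt_tokens_details":{"cached_tokens":1920}}`, // cache hit
		`{"prompt_tokens":100,"completion_tokens":50,"total_tokens":150}`,                                                   // no details object
		`{"prompt_tokens":100,"completion_tokens":50,"total_tokens":150,"prompt_tokens_details":{"cached_tokens":0}}`,       // zero cached tokens
	}
	for _, raw := range payloads {
		var u UsageInfo
		if err := json.Unmarshal([]byte(raw), &u); err != nil {
			panic(err)
		}
		// Same guard as the extractor above.
		records := u.PromptTokensDetails != nil && u.PromptTokensDetails.CachedTokens > 0
		fmt.Println(records) // true, then false, then false
	}
}
```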
diff --git a/providers/openai_compatible/parser_test.go b/providers/openai_compatible/parser_test.go
index 206e49e..7be75dc 100644
--- a/providers/openai_compatible/parser_test.go
+++ b/providers/openai_compatible/parser_test.go
@@ -493,3 +493,67 @@ func TestNewEnricher(t *testing.T) {
 		t.Errorf("APIKey = %q, want test-key", enricher.APIKey)
 	}
 }
+
+func TestExtractor_CacheUsage(t *testing.T) {
+	body := `{"id":"chatcmpl-123","model":"gpt-4","usage":{"prompt_tokens":2006,"completion_tokens":300,"total_tokens":2306,"prompt_tokens_details":{"cached_tokens":1920}},"choices":[{"index":0,"message":{"role":"assistant","content":"Hello!"},"finish_reason":"stop"}]}`
+	extractor := NewExtractor()
+
+	resp := &http.Response{
+		StatusCode: 200,
+		Header:     make(http.Header),
+		Body:       io.NopCloser(strings.NewReader(body)),
+	}
+
+	meta, _, err := extractor.Extract(resp)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	cacheUsage, ok := meta.Custom["cache_usage"].(llmproxy.CacheUsage)
+	if !ok {
+		t.Fatal("expected cache_usage in Custom map")
+	}
+	if cacheUsage.CachedTokens != 1920 {
+		t.Errorf("CachedTokens = %d, want 1920", cacheUsage.CachedTokens)
+	}
+}
+
+func TestExtractor_NoCacheUsage(t *testing.T) {
+	body := `{"id":"chatcmpl-123","model":"gpt-4","usage":{"prompt_tokens":100,"completion_tokens":50,"total_tokens":150},"choices":[{"index":0,"message":{"role":"assistant","content":"Hello!"},"finish_reason":"stop"}]}`
+	extractor := NewExtractor()
+
+	resp := &http.Response{
+		StatusCode: 200,
+		Header:     make(http.Header),
+		Body:       io.NopCloser(strings.NewReader(body)),
+	}
+
+	meta, _, err := extractor.Extract(resp)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	if _, ok := meta.Custom["cache_usage"]; ok {
+		t.Error("expected no cache_usage in Custom map when not present in response")
+	}
+}
+
+func TestExtractor_ZeroCachedTokens(t *testing.T) {
+	body := `{"id":"chatcmpl-123","model":"gpt-4","usage":{"prompt_tokens":100,"completion_tokens":50,"total_tokens":150,"prompt_tokens_details":{"cached_tokens":0}},"choices":[]}`
+	extractor := NewExtractor()
+
+	resp := &http.Response{
+		StatusCode: 200,
+		Header:     make(http.Header),
+		Body:       io.NopCloser(strings.NewReader(body)),
+	}
+
+	meta, _, err := extractor.Extract(resp)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	if _, ok := meta.Custom["cache_usage"]; ok {
+		t.Error("expected no cache_usage when cached_tokens is 0")
+	}
+}
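A possible consumer-side metric, offered only as a sketch: in the fixture above, 1,920 of 2,006 prompt tokens were cached, which is consistent with `cached_tokens` counting a subset of `prompt_tokens`. Under that assumption (mine, not stated by the patch), a per-request hit ratio can be estimated as below; the function name and import path are illustrative.

```go
package main

import (
	"fmt"

	"example.com/llmproxy" // placeholder import path
)

// cacheHitRatio estimates the fraction of prompt tokens served from cache,
// assuming cached_tokens is a subset of prompt_tokens (as in the fixture above).
func cacheHitRatio(promptTokens int, u llmproxy.CacheUsage) float64 {
	if promptTokens <= 0 {
		return 0
	}
	return float64(u.CachedTokens) / float64(promptTokens)
}

func main() {
	u := llmproxy.CacheUsage{CachedTokens: 1920}
	fmt.Printf("hit ratio: %.2f\n", cacheHitRatio(2006, u)) // ~0.96
}
```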