From e4c78dd1f916517504f8cbec5ba1a652a43369c3 Mon Sep 17 00:00:00 2001 From: Danial Beg Date: Mon, 16 Mar 2026 18:18:22 -0700 Subject: [PATCH] Add 9 providers and update pricing for GPT-5 and Claude 4.5/4.6 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit New providers (7→16 total): - Azure OpenAI (custom parser for /openai/deployments/ URL scheme) - Together AI, Fireworks AI, Perplexity, OpenRouter, xAI, Cerebras, SambaNova (all OpenAI-compatible with path-prefix routing) Pricing updates: - Add full GPT-5 family: gpt-5, 5-mini, 5-nano, 5-pro, 5-codex, 5-chat, 5.1–5.4 variants (19 models) - Add o3-pro and o1-pro ($150/$600 per MTok) - Fix Anthropic 4.5/4.6 pricing: Opus dropped from $15/$75 to $5/$25 - Add claude-opus-4.1 and claude-3.7-sonnet; switch Claude 3.5 keys from hyphen to dot notation (claude-3-5-* replaced by claude-3.5-*, so hyphenated API IDs no longer prefix-match); remove claude-3-sonnet and claude-haiku-4 entries - Add xAI Grok-3, Perplexity Sonar, Together/Fireworks/Cerebras/ SambaNova hosted model pricing Also: Azure api-key header support in ExtractAPIKey --- CLAUDE.md | 4 +- configs/agentledger.example.yaml | 40 ++++++ internal/config/config.go | 32 +++++++ internal/meter/meter_test.go | 57 ++++++++++++ internal/meter/pricing.go | 84 +++++++++++++++-- internal/provider/azure.go | 119 ++++++++++++++++++++++++ internal/provider/azure_test.go | 76 +++++++++++++++ internal/provider/cerebras.go | 10 ++ internal/provider/fireworks.go | 10 ++ internal/provider/openai_compat_test.go | 81 ++++++++++++++-- internal/provider/openrouter.go | 10 ++ internal/provider/perplexity.go | 10 ++ internal/provider/provider.go | 5 +- internal/provider/provider_test.go | 10 ++ internal/provider/registry.go | 2 + internal/provider/sambanova.go | 10 ++ internal/provider/together.go | 10 ++ internal/provider/xai.go | 10 ++ 18 files changed, 562 insertions(+), 18 deletions(-) create mode 100644 internal/provider/azure.go create mode 100644 internal/provider/azure_test.go create mode 100644 internal/provider/cerebras.go create mode 100644 internal/provider/fireworks.go create mode 100644 
internal/provider/openrouter.go create mode 100644 internal/provider/perplexity.go create mode 100644 internal/provider/sambanova.go create mode 100644 internal/provider/together.go create mode 100644 internal/provider/xai.go diff --git a/CLAUDE.md b/CLAUDE.md index 777e446..7a549ee 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -6,7 +6,7 @@ A Go-based open-source reverse proxy that provides real-time cost attribution, b ## Quick Context -- **What:** Transparent reverse proxy between AI agents and LLM APIs (OpenAI, Anthropic, Groq, Mistral, DeepSeek, Gemini, Cohere) +- **What:** Transparent reverse proxy between AI agents and LLM APIs (OpenAI, Anthropic, Azure OpenAI, Groq, Mistral, DeepSeek, Gemini, Cohere, Together AI, Fireworks AI, Perplexity, OpenRouter, xAI, Cerebras, SambaNova) - **How:** `export OPENAI_BASE_URL=http://localhost:8787/v1` — zero code changes - **Why:** No tool tracks per-agent-execution costs, detects loops, or meters MCP calls - **Language:** Go — single binary, zero runtime dependencies @@ -37,7 +37,7 @@ Agents → AgentLedger (Go proxy :8787) → LLM APIs (OpenAI, Anthropic, Groq, M |---------|---------| | `cmd/agentledger/` | CLI entrypoint (cobra): `serve`, `costs`, `version` | | `internal/proxy/` | Core reverse proxy (`httputil.ReverseProxy`), SSE streaming, middleware chain | -| `internal/provider/` | Provider interface + OpenAI/Anthropic/Gemini/Cohere parsers, OpenAI-compatible base type, path-prefix routing | +| `internal/provider/` | Provider interface + OpenAI/Anthropic/Azure/Gemini/Cohere parsers, OpenAI-compatible base type (Groq, Mistral, DeepSeek, Together, Fireworks, Perplexity, OpenRouter, xAI, Cerebras, SambaNova), path-prefix routing | | `internal/meter/` | Cost calculation engine, model pricing table, tiktoken-go fallback | | `internal/ledger/` | Storage interface, SQLite (modernc.org/sqlite, CGO-free) + Postgres impls, multi-tenant queries | | `internal/budget/` | Budget enforcement middleware, circuit breaker | diff --git 
a/configs/agentledger.example.yaml b/configs/agentledger.example.yaml index cff0358..c97ee0c 100644 --- a/configs/agentledger.example.yaml +++ b/configs/agentledger.example.yaml @@ -38,6 +38,46 @@ providers: # upstream: "https://api.cohere.com" # path_prefix: "/cohere" # enabled: true + # azure: + # type: "azure" # Azure OpenAI (custom URL scheme) + # upstream: "https://my-resource.openai.azure.com" + # path_prefix: "/azure" + # enabled: true + # together: + # type: "openai" + # upstream: "https://api.together.xyz" + # path_prefix: "/together" + # enabled: true + # fireworks: + # type: "openai" + # upstream: "https://api.fireworks.ai/inference" + # path_prefix: "/fireworks" + # enabled: true + # perplexity: + # type: "openai" + # upstream: "https://api.perplexity.ai" + # path_prefix: "/perplexity" + # enabled: true + # openrouter: + # type: "openai" + # upstream: "https://openrouter.ai/api" + # path_prefix: "/openrouter" + # enabled: true + # xai: + # type: "openai" # xAI (Grok) + # upstream: "https://api.x.ai" + # path_prefix: "/xai" + # enabled: true + # cerebras: + # type: "openai" + # upstream: "https://api.cerebras.ai" + # path_prefix: "/cerebras" + # enabled: true + # sambanova: + # type: "openai" + # upstream: "https://api.sambanova.ai" + # path_prefix: "/sambanova" + # enabled: true storage: driver: "sqlite" # "sqlite" or "postgres" diff --git a/internal/config/config.go b/internal/config/config.go index 83f386f..a9f6321 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -192,6 +192,38 @@ func Load(path string) (*Config, error) { v.SetDefault("providers.extra.cohere.upstream", "https://api.cohere.com") v.SetDefault("providers.extra.cohere.path_prefix", "/cohere") v.SetDefault("providers.extra.cohere.enabled", false) + v.SetDefault("providers.extra.azure.type", "azure") + v.SetDefault("providers.extra.azure.upstream", "") + v.SetDefault("providers.extra.azure.path_prefix", "/azure") + v.SetDefault("providers.extra.azure.enabled", 
false) + v.SetDefault("providers.extra.together.type", "openai") + v.SetDefault("providers.extra.together.upstream", "https://api.together.xyz") + v.SetDefault("providers.extra.together.path_prefix", "/together") + v.SetDefault("providers.extra.together.enabled", false) + v.SetDefault("providers.extra.fireworks.type", "openai") + v.SetDefault("providers.extra.fireworks.upstream", "https://api.fireworks.ai/inference") + v.SetDefault("providers.extra.fireworks.path_prefix", "/fireworks") + v.SetDefault("providers.extra.fireworks.enabled", false) + v.SetDefault("providers.extra.perplexity.type", "openai") + v.SetDefault("providers.extra.perplexity.upstream", "https://api.perplexity.ai") + v.SetDefault("providers.extra.perplexity.path_prefix", "/perplexity") + v.SetDefault("providers.extra.perplexity.enabled", false) + v.SetDefault("providers.extra.openrouter.type", "openai") + v.SetDefault("providers.extra.openrouter.upstream", "https://openrouter.ai/api") + v.SetDefault("providers.extra.openrouter.path_prefix", "/openrouter") + v.SetDefault("providers.extra.openrouter.enabled", false) + v.SetDefault("providers.extra.xai.type", "openai") + v.SetDefault("providers.extra.xai.upstream", "https://api.x.ai") + v.SetDefault("providers.extra.xai.path_prefix", "/xai") + v.SetDefault("providers.extra.xai.enabled", false) + v.SetDefault("providers.extra.cerebras.type", "openai") + v.SetDefault("providers.extra.cerebras.upstream", "https://api.cerebras.ai") + v.SetDefault("providers.extra.cerebras.path_prefix", "/cerebras") + v.SetDefault("providers.extra.cerebras.enabled", false) + v.SetDefault("providers.extra.sambanova.type", "openai") + v.SetDefault("providers.extra.sambanova.upstream", "https://api.sambanova.ai") + v.SetDefault("providers.extra.sambanova.path_prefix", "/sambanova") + v.SetDefault("providers.extra.sambanova.enabled", false) v.SetDefault("storage.driver", "sqlite") v.SetDefault("storage.dsn", "data/agentledger.db") diff --git a/internal/meter/meter_test.go 
b/internal/meter/meter_test.go index 888ab69..4c5e68a 100644 --- a/internal/meter/meter_test.go +++ b/internal/meter/meter_test.go @@ -86,3 +86,60 @@ func TestPrefixMatchLongestWins(t *testing.T) { t.Errorf("expected $0.15 (gpt-4o-mini pricing), got $%f", cost) } } + +func TestNewestModelsKnown(t *testing.T) { + m := New() + + models := []string{ + // OpenAI — GPT-5 family + "gpt-5", + "gpt-5-mini", + "gpt-5-nano", + "gpt-5-pro", + "gpt-5-codex", + "gpt-5.1", + "gpt-5.2", + "gpt-5.2-pro", + "gpt-5.4", + "gpt-5.4-pro", + // OpenAI — reasoning + "o3-pro", + "o1-pro", + "o4-mini", + // OpenAI — GPT-4.1 + "gpt-4.1", + "gpt-4.1-mini", + "gpt-4.1-nano", + // Anthropic 4.5/4.6 + "claude-opus-4.6", + "claude-sonnet-4.6", + "claude-haiku-4.5", + // Anthropic — dated variants (prefix match) + "claude-opus-4.6-20260101", + "claude-sonnet-4.6-20260101", + "claude-haiku-4.5-20251001", + // Anthropic — thinking variants (prefix match) + "claude-opus-4.6-thinking", + "claude-sonnet-4.5-thinking", + // xAI + "grok-3", + "grok-3-mini", + } + + for _, model := range models { + if !m.KnownModel(model) { + t.Errorf("%q should be a known model", model) + } + } +} + +func TestO3ProNotConfusedWithO3(t *testing.T) { + m := New() + + // o3-pro must NOT use o3 pricing ($10/$40), it should use its own ($150/$600) + cost := m.Calculate("o3-pro", 1_000_000, 0) + // o3-pro input: $150 per MTok + if math.Abs(cost-150.00) > 1e-9 { + t.Errorf("o3-pro: expected $150.00 input cost, got $%f (may have matched o3 instead)", cost) + } +} diff --git a/internal/meter/pricing.go b/internal/meter/pricing.go index e85bfac..bf19dbd 100644 --- a/internal/meter/pricing.go +++ b/internal/meter/pricing.go @@ -21,29 +21,69 @@ func DefaultPricing() map[string]ModelPricing { // OpenAI — reasoning models "o3": {InputPerMTok: 10.00, OutputPerMTok: 40.00}, + "o3-pro": {InputPerMTok: 150.00, OutputPerMTok: 600.00}, "o3-mini": {InputPerMTok: 1.10, OutputPerMTok: 4.40}, "o4-mini": {InputPerMTok: 1.10, OutputPerMTok: 
4.40}, "o1": {InputPerMTok: 15.00, OutputPerMTok: 60.00}, + "o1-pro": {InputPerMTok: 150.00, OutputPerMTok: 600.00}, "o1-mini": {InputPerMTok: 3.00, OutputPerMTok: 12.00}, + // OpenAI — GPT-5.4 family + "gpt-5.4-pro": {InputPerMTok: 30.00, OutputPerMTok: 180.00}, + "gpt-5.4": {InputPerMTok: 2.50, OutputPerMTok: 15.00}, + + // OpenAI — GPT-5.3 family + "gpt-5.3-codex": {InputPerMTok: 1.75, OutputPerMTok: 14.00}, + "gpt-5.3-chat": {InputPerMTok: 1.75, OutputPerMTok: 14.00}, + + // OpenAI — GPT-5.2 family + "gpt-5.2-pro": {InputPerMTok: 10.50, OutputPerMTok: 84.00}, + "gpt-5.2-codex": {InputPerMTok: 1.75, OutputPerMTok: 14.00}, + "gpt-5.2-chat": {InputPerMTok: 0.875, OutputPerMTok: 7.00}, + "gpt-5.2": {InputPerMTok: 1.75, OutputPerMTok: 14.00}, + + // OpenAI — GPT-5.1 family + "gpt-5.1-codex-max": {InputPerMTok: 1.25, OutputPerMTok: 10.00}, + "gpt-5.1-codex-mini": {InputPerMTok: 0.25, OutputPerMTok: 2.00}, + "gpt-5.1-codex": {InputPerMTok: 1.25, OutputPerMTok: 10.00}, + "gpt-5.1-chat": {InputPerMTok: 0.625, OutputPerMTok: 5.00}, + "gpt-5.1": {InputPerMTok: 0.625, OutputPerMTok: 5.00}, + + // OpenAI — GPT-5 family + "gpt-5-pro": {InputPerMTok: 15.00, OutputPerMTok: 120.00}, + "gpt-5-codex": {InputPerMTok: 1.25, OutputPerMTok: 10.00}, + "gpt-5-chat": {InputPerMTok: 1.25, OutputPerMTok: 10.00}, + "gpt-5-mini": {InputPerMTok: 0.125, OutputPerMTok: 1.00}, + "gpt-5-nano": {InputPerMTok: 0.05, OutputPerMTok: 0.40}, + "gpt-5": {InputPerMTok: 1.25, OutputPerMTok: 10.00}, + // OpenAI — legacy "gpt-4-turbo": {InputPerMTok: 10.00, OutputPerMTok: 30.00}, "gpt-4": {InputPerMTok: 30.00, OutputPerMTok: 60.00}, "gpt-3.5-turbo": {InputPerMTok: 0.50, OutputPerMTok: 1.50}, - // Anthropic — Claude 4 family + // Anthropic — Claude 4.5/4.6 family (reduced pricing from 4.0) + "claude-opus-4.6": {InputPerMTok: 5.00, OutputPerMTok: 25.00}, + "claude-opus-4.5": {InputPerMTok: 5.00, OutputPerMTok: 25.00}, + "claude-sonnet-4.6": {InputPerMTok: 3.00, OutputPerMTok: 15.00}, + "claude-sonnet-4.5": 
{InputPerMTok: 3.00, OutputPerMTok: 15.00}, + "claude-haiku-4.5": {InputPerMTok: 1.00, OutputPerMTok: 5.00}, + + // Anthropic — Claude 4.0/4.1 family + "claude-opus-4.1": {InputPerMTok: 15.00, OutputPerMTok: 75.00}, "claude-opus-4": {InputPerMTok: 15.00, OutputPerMTok: 75.00}, "claude-sonnet-4": {InputPerMTok: 3.00, OutputPerMTok: 15.00}, - "claude-haiku-4": {InputPerMTok: 0.80, OutputPerMTok: 4.00}, + + // Anthropic — Claude 3.7 + "claude-3.7-sonnet": {InputPerMTok: 3.00, OutputPerMTok: 15.00}, // Anthropic — Claude 3.5 - "claude-3-5-sonnet": {InputPerMTok: 3.00, OutputPerMTok: 15.00}, - "claude-3-5-haiku": {InputPerMTok: 0.80, OutputPerMTok: 4.00}, + "claude-3.5-sonnet": {InputPerMTok: 3.00, OutputPerMTok: 15.00}, + "claude-3.5-haiku": {InputPerMTok: 0.80, OutputPerMTok: 4.00}, // Anthropic — Claude 3 - "claude-3-opus": {InputPerMTok: 15.00, OutputPerMTok: 75.00}, - "claude-3-sonnet": {InputPerMTok: 3.00, OutputPerMTok: 15.00}, - "claude-3-haiku": {InputPerMTok: 0.25, OutputPerMTok: 1.25}, + "claude-3-opus": {InputPerMTok: 15.00, OutputPerMTok: 75.00}, + "claude-3-haiku": {InputPerMTok: 0.25, OutputPerMTok: 1.25}, // Google Gemini "gemini-2.5-pro": {InputPerMTok: 1.25, OutputPerMTok: 10.00}, @@ -72,5 +112,35 @@ func DefaultPricing() map[string]ModelPricing { "command-r-plus": {InputPerMTok: 2.50, OutputPerMTok: 10.00}, "command-r": {InputPerMTok: 0.15, OutputPerMTok: 0.60}, "command-light": {InputPerMTok: 0.30, OutputPerMTok: 0.60}, + + // xAI (Grok) + "grok-3": {InputPerMTok: 3.00, OutputPerMTok: 15.00}, + "grok-3-mini": {InputPerMTok: 0.30, OutputPerMTok: 0.50}, + "grok-2": {InputPerMTok: 2.00, OutputPerMTok: 10.00}, + + // Perplexity + "sonar-pro": {InputPerMTok: 3.00, OutputPerMTok: 15.00}, + "sonar": {InputPerMTok: 1.00, OutputPerMTok: 1.00}, + "sonar-reasoning": {InputPerMTok: 1.00, OutputPerMTok: 5.00}, + + // Together AI (hosted open-source models) + "meta-llama/Llama-3.3-70B-Instruct-Turbo": {InputPerMTok: 0.88, OutputPerMTok: 0.88}, + 
"meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo": {InputPerMTok: 3.50, OutputPerMTok: 3.50}, + "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo": {InputPerMTok: 0.18, OutputPerMTok: 0.18}, + "Qwen/Qwen2.5-72B-Instruct-Turbo": {InputPerMTok: 1.20, OutputPerMTok: 1.20}, + "deepseek-ai/DeepSeek-V3": {InputPerMTok: 0.90, OutputPerMTok: 0.90}, + + // Fireworks AI + "accounts/fireworks/models/llama-v3p3-70b-instruct": {InputPerMTok: 0.90, OutputPerMTok: 0.90}, + "accounts/fireworks/models/llama-v3p1-8b-instruct": {InputPerMTok: 0.20, OutputPerMTok: 0.20}, + "accounts/fireworks/models/qwen2p5-72b-instruct": {InputPerMTok: 0.90, OutputPerMTok: 0.90}, + + // Cerebras + "llama-3.3-70b": {InputPerMTok: 0.85, OutputPerMTok: 0.85}, + "llama-3.1-8b": {InputPerMTok: 0.10, OutputPerMTok: 0.10}, + + // SambaNova + "Meta-Llama-3.3-70B-Instruct": {InputPerMTok: 0.60, OutputPerMTok: 0.60}, + "Meta-Llama-3.1-8B-Instruct": {InputPerMTok: 0.10, OutputPerMTok: 0.10}, } } diff --git a/internal/provider/azure.go b/internal/provider/azure.go new file mode 100644 index 0000000..79e4056 --- /dev/null +++ b/internal/provider/azure.go @@ -0,0 +1,119 @@ +package provider + +import ( + "encoding/json" + "net/http" + "strings" +) + +// Azure implements the Provider interface for Azure OpenAI Service. +// Azure uses a different URL scheme: /openai/deployments/{deployment}/chat/completions?api-version=... +// Auth is via api-key header instead of Bearer token. +type Azure struct { + upstream string + pathPrefix string +} + +// NewAzure creates an Azure OpenAI provider. The upstream should be the Azure +// resource endpoint (e.g., https://my-resource.openai.azure.com). +// Requests arrive at /azure/openai/deployments/{deployment}/chat/completions. 
+func NewAzure(upstream string) *Azure { + return &Azure{ + upstream: upstream, + pathPrefix: "/azure", + } +} + +func (a *Azure) Name() string { return "azure" } //nolint:goconst +func (a *Azure) UpstreamURL() string { return a.upstream } +func (a *Azure) PathPrefix() string { return a.pathPrefix } + +func (a *Azure) Match(r *http.Request) bool { + p := r.URL.Path + return strings.HasPrefix(p, a.pathPrefix+"/openai/deployments/") +} + +// RewritePath strips the /azure prefix so upstream sees /openai/deployments/... +func (a *Azure) RewritePath(path string) string { + return strings.TrimPrefix(path, a.pathPrefix) +} + +// azureRequest is the minimal subset of an Azure OpenAI request. +type azureRequest struct { + MaxTokens int `json:"max_tokens"` + Stream bool `json:"stream"` +} + +func (a *Azure) ParseRequest(body []byte) (*RequestMeta, error) { + var req azureRequest + if err := json.Unmarshal(body, &req); err != nil { + return nil, err + } + // Azure puts the model (deployment) in the URL path, not the body. + return &RequestMeta{ + Model: "azure-deployment", + MaxTokens: req.MaxTokens, + Stream: req.Stream, + }, nil +} + +// azureResponse matches the Azure OpenAI response (same as OpenAI format). +type azureResponse struct { + Model string `json:"model"` + Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + } `json:"usage"` +} + +func (a *Azure) ParseResponse(body []byte) (*ResponseMeta, error) { + var resp azureResponse + if err := json.Unmarshal(body, &resp); err != nil { + return nil, err + } + return &ResponseMeta{ + Model: resp.Model, + InputTokens: resp.Usage.PromptTokens, + OutputTokens: resp.Usage.CompletionTokens, + TotalTokens: resp.Usage.TotalTokens, + }, nil +} + +// azureStreamChunk matches the Azure OpenAI streaming chunk (same as OpenAI format). 
+type azureStreamChunk struct { + Model string `json:"model"` + Choices []struct { + Delta struct { + Content string `json:"content"` + } `json:"delta"` + } `json:"choices"` + Usage *struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + } `json:"usage,omitempty"` +} + +func (a *Azure) ParseStreamChunk(_ string, data []byte) (*StreamChunkMeta, error) { + var chunk azureStreamChunk + if err := json.Unmarshal(data, &chunk); err != nil { + return nil, err + } + + meta := &StreamChunkMeta{ + Model: chunk.Model, + } + + if len(chunk.Choices) > 0 { + meta.Text = chunk.Choices[0].Delta.Content + } + + if chunk.Usage != nil { + meta.InputTokens = chunk.Usage.PromptTokens + meta.OutputTokens = chunk.Usage.CompletionTokens + meta.Done = true + } + + return meta, nil +} diff --git a/internal/provider/azure_test.go b/internal/provider/azure_test.go new file mode 100644 index 0000000..36301eb --- /dev/null +++ b/internal/provider/azure_test.go @@ -0,0 +1,76 @@ +package provider + +import ( + "net/http" + "testing" +) + +func TestAzureMatch(t *testing.T) { + a := NewAzure("https://my-resource.openai.azure.com") + + tests := []struct { + path string + want bool + }{ + {"/azure/openai/deployments/gpt-4o/chat/completions", true}, + {"/azure/openai/deployments/my-model/completions", true}, + {"/azure/openai/deployments/embed/embeddings", true}, + {"/openai/deployments/gpt-4o/chat/completions", false}, + {"/v1/chat/completions", false}, + } + + for _, tt := range tests { + r := &http.Request{URL: mustParseURL(tt.path), Header: http.Header{}} + if got := a.Match(r); got != tt.want { + t.Errorf("Match(%q) = %v, want %v", tt.path, got, tt.want) + } + } +} + +func TestAzureRewritePath(t *testing.T) { + a := NewAzure("") + got := a.RewritePath("/azure/openai/deployments/gpt-4o/chat/completions") + want := "/openai/deployments/gpt-4o/chat/completions" + if got != want { + t.Errorf("RewritePath = %q, 
want %q", got, want) + } +} + +func TestAzureName(t *testing.T) { + if NewAzure("").Name() != "azure" { + t.Error("name mismatch") + } +} + +func TestAzureParseResponse(t *testing.T) { + a := NewAzure("") + body := []byte(`{"model":"gpt-4o","usage":{"prompt_tokens":100,"completion_tokens":50,"total_tokens":150}}`) + meta, err := a.ParseResponse(body) + if err != nil { + t.Fatal(err) + } + if meta.Model != "gpt-4o" { + t.Errorf("model = %q, want %q", meta.Model, "gpt-4o") + } + if meta.InputTokens != 100 { + t.Errorf("input = %d, want 100", meta.InputTokens) + } + if meta.OutputTokens != 50 { + t.Errorf("output = %d, want 50", meta.OutputTokens) + } +} + +func TestAzureParseRequest(t *testing.T) { + a := NewAzure("") + body := []byte(`{"max_tokens":1000,"stream":true}`) + meta, err := a.ParseRequest(body) + if err != nil { + t.Fatal(err) + } + if meta.MaxTokens != 1000 { + t.Errorf("max_tokens = %d, want 1000", meta.MaxTokens) + } + if !meta.Stream { + t.Error("stream = false, want true") + } +} diff --git a/internal/provider/cerebras.go b/internal/provider/cerebras.go new file mode 100644 index 0000000..6576031 --- /dev/null +++ b/internal/provider/cerebras.go @@ -0,0 +1,10 @@ +package provider + +// NewCerebras creates a Cerebras provider. Cerebras uses the OpenAI-compatible +// API format. Requests arrive at /cerebras/v1/chat/completions. +func NewCerebras(upstream string) *OpenAICompatible { + if upstream == "" { + upstream = "https://api.cerebras.ai" + } + return NewOpenAICompatible("cerebras", upstream, "/cerebras") +} diff --git a/internal/provider/fireworks.go b/internal/provider/fireworks.go new file mode 100644 index 0000000..1d9dedc --- /dev/null +++ b/internal/provider/fireworks.go @@ -0,0 +1,10 @@ +package provider + +// NewFireworks creates a Fireworks AI provider. Fireworks uses the OpenAI-compatible +// API format. Requests arrive at /fireworks/v1/chat/completions. 
+func NewFireworks(upstream string) *OpenAICompatible { + if upstream == "" { + upstream = "https://api.fireworks.ai/inference" + } + return NewOpenAICompatible("fireworks", upstream, "/fireworks") +} diff --git a/internal/provider/openai_compat_test.go b/internal/provider/openai_compat_test.go index 618e01e..ab66860 100644 --- a/internal/provider/openai_compat_test.go +++ b/internal/provider/openai_compat_test.go @@ -62,6 +62,13 @@ func TestOpenAICompatRewritePath(t *testing.T) { {"groq strip", NewGroq(""), "/groq/v1/chat/completions", "/v1/chat/completions"}, {"mistral strip", NewMistral(""), "/mistral/v1/chat/completions", "/v1/chat/completions"}, {"deepseek strip", NewDeepSeek(""), "/deepseek/v1/chat/completions", "/v1/chat/completions"}, + {"together strip", NewTogether(""), "/together/v1/chat/completions", "/v1/chat/completions"}, + {"fireworks strip", NewFireworks(""), "/fireworks/v1/chat/completions", "/v1/chat/completions"}, + {"perplexity strip", NewPerplexity(""), "/perplexity/v1/chat/completions", "/v1/chat/completions"}, + {"openrouter strip", NewOpenRouter(""), "/openrouter/v1/chat/completions", "/v1/chat/completions"}, + {"xai strip", NewXAI(""), "/xai/v1/chat/completions", "/v1/chat/completions"}, + {"cerebras strip", NewCerebras(""), "/cerebras/v1/chat/completions", "/v1/chat/completions"}, + {"sambanova strip", NewSambaNova(""), "/sambanova/v1/chat/completions", "/v1/chat/completions"}, } for _, tt := range tests { @@ -75,16 +82,74 @@ func TestOpenAICompatRewritePath(t *testing.T) { } func TestOpenAICompatProviderNames(t *testing.T) { - if NewOpenAI("").Name() != "openai" { - t.Error("OpenAI name mismatch") + providers := map[string]*OpenAICompatible{ + "openai": NewOpenAI(""), + "groq": NewGroq(""), + "mistral": NewMistral(""), + "deepseek": NewDeepSeek(""), + "together": NewTogether(""), + "fireworks": NewFireworks(""), + "perplexity": NewPerplexity(""), + "openrouter": NewOpenRouter(""), + "xai": NewXAI(""), + "cerebras": NewCerebras(""), + 
"sambanova": NewSambaNova(""), } - if NewGroq("").Name() != "groq" { - t.Error("Groq name mismatch") + for want, p := range providers { + if p.Name() != want { + t.Errorf("%s name = %q, want %q", want, p.Name(), want) + } + } +} + +func TestOpenAICompatNewProviderDefaults(t *testing.T) { + tests := []struct { + name string + prov *OpenAICompatible + upstream string + prefix string + }{ + {"together", NewTogether(""), "https://api.together.xyz", "/together"}, + {"fireworks", NewFireworks(""), "https://api.fireworks.ai/inference", "/fireworks"}, + {"perplexity", NewPerplexity(""), "https://api.perplexity.ai", "/perplexity"}, + {"openrouter", NewOpenRouter(""), "https://openrouter.ai/api", "/openrouter"}, + {"xai", NewXAI(""), "https://api.x.ai", "/xai"}, + {"cerebras", NewCerebras(""), "https://api.cerebras.ai", "/cerebras"}, + {"sambanova", NewSambaNova(""), "https://api.sambanova.ai", "/sambanova"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.prov.UpstreamURL() != tt.upstream { + t.Errorf("upstream = %q, want %q", tt.prov.UpstreamURL(), tt.upstream) + } + if tt.prov.PathPrefix() != tt.prefix { + t.Errorf("prefix = %q, want %q", tt.prov.PathPrefix(), tt.prefix) + } + }) } - if NewMistral("").Name() != "mistral" { - t.Error("Mistral name mismatch") +} + +func TestOpenAICompatMatch_NewProviders(t *testing.T) { + providers := []*OpenAICompatible{ + NewTogether(""), + NewFireworks(""), + NewPerplexity(""), + NewOpenRouter(""), + NewXAI(""), + NewCerebras(""), + NewSambaNova(""), } - if NewDeepSeek("").Name() != "deepseek" { - t.Error("DeepSeek name mismatch") + + for _, p := range providers { + prefix := p.PathPrefix() + r := &http.Request{URL: mustParseURL(prefix + "/v1/chat/completions"), Header: http.Header{}} + if !p.Match(r) { + t.Errorf("%s: should match %s/v1/chat/completions", p.Name(), prefix) + } + r = &http.Request{URL: mustParseURL("/v1/chat/completions"), Header: http.Header{}} + if p.Match(r) { + t.Errorf("%s: should 
not match /v1/chat/completions", p.Name()) + } } } diff --git a/internal/provider/openrouter.go b/internal/provider/openrouter.go new file mode 100644 index 0000000..7ab3a34 --- /dev/null +++ b/internal/provider/openrouter.go @@ -0,0 +1,10 @@ +package provider + +// NewOpenRouter creates an OpenRouter provider. OpenRouter uses the OpenAI-compatible +// API format to route to many models. Requests arrive at /openrouter/v1/chat/completions. +func NewOpenRouter(upstream string) *OpenAICompatible { + if upstream == "" { + upstream = "https://openrouter.ai/api" + } + return NewOpenAICompatible("openrouter", upstream, "/openrouter") +} diff --git a/internal/provider/perplexity.go b/internal/provider/perplexity.go new file mode 100644 index 0000000..ba351e5 --- /dev/null +++ b/internal/provider/perplexity.go @@ -0,0 +1,10 @@ +package provider + +// NewPerplexity creates a Perplexity provider. Perplexity uses the OpenAI-compatible +// API format. Requests arrive at /perplexity/v1/chat/completions. +func NewPerplexity(upstream string) *OpenAICompatible { + if upstream == "" { + upstream = "https://api.perplexity.ai" + } + return NewOpenAICompatible("perplexity", upstream, "/perplexity") +} diff --git a/internal/provider/provider.go b/internal/provider/provider.go index ff2b230..91b1685 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -54,7 +54,7 @@ type StreamChunkMeta struct { // ExtractAPIKey reads the API key from request headers. // OpenAI uses Authorization: Bearer, Anthropic uses x-api-key, -// Google Gemini uses X-Goog-Api-Key. +// Google Gemini uses X-Goog-Api-Key, Azure OpenAI uses api-key. 
func ExtractAPIKey(r *http.Request) string { if auth := r.Header.Get("Authorization"); auth != "" { if strings.HasPrefix(auth, "Bearer ") { @@ -67,6 +67,9 @@ func ExtractAPIKey(r *http.Request) string { if key := r.Header.Get("X-Goog-Api-Key"); key != "" { return key } + if key := r.Header.Get("api-key"); key != "" { + return key + } return "" } diff --git a/internal/provider/provider_test.go b/internal/provider/provider_test.go index cb93d5f..32269b5 100644 --- a/internal/provider/provider_test.go +++ b/internal/provider/provider_test.go @@ -55,6 +55,16 @@ func TestExtractAPIKey(t *testing.T) { http.Header{"X-Api-Key": {"sk-ant-api03-xyz"}}, "sk-ant-api03-xyz", }, + { + "azure api-key", + http.Header{"Api-Key": {"abc123def456"}}, + "abc123def456", + }, + { + "google x-goog-api-key", + http.Header{"X-Goog-Api-Key": {"AIzaSyAbc123"}}, + "AIzaSyAbc123", + }, { "no key", http.Header{}, diff --git a/internal/provider/registry.go b/internal/provider/registry.go index 0434f83..d73e1c3 100644 --- a/internal/provider/registry.go +++ b/internal/provider/registry.go @@ -54,6 +54,8 @@ func NewProviderByType(name, typ, upstream, pathPrefix string) Provider { upstream = "https://api.anthropic.com" } return NewAnthropic(upstream) + case "azure": //nolint:goconst + return NewAzure(upstream) case "gemini": //nolint:goconst return NewGemini(upstream) case "cohere": //nolint:goconst diff --git a/internal/provider/sambanova.go b/internal/provider/sambanova.go new file mode 100644 index 0000000..27fb391 --- /dev/null +++ b/internal/provider/sambanova.go @@ -0,0 +1,10 @@ +package provider + +// NewSambaNova creates a SambaNova provider. SambaNova uses the OpenAI-compatible +// API format. Requests arrive at /sambanova/v1/chat/completions. 
+func NewSambaNova(upstream string) *OpenAICompatible { + if upstream == "" { + upstream = "https://api.sambanova.ai" + } + return NewOpenAICompatible("sambanova", upstream, "/sambanova") +} diff --git a/internal/provider/together.go b/internal/provider/together.go new file mode 100644 index 0000000..4d3e1dc --- /dev/null +++ b/internal/provider/together.go @@ -0,0 +1,10 @@ +package provider + +// NewTogether creates a Together AI provider. Together uses the OpenAI-compatible +// API format. Requests arrive at /together/v1/chat/completions. +func NewTogether(upstream string) *OpenAICompatible { + if upstream == "" { + upstream = "https://api.together.xyz" + } + return NewOpenAICompatible("together", upstream, "/together") +} diff --git a/internal/provider/xai.go b/internal/provider/xai.go new file mode 100644 index 0000000..f3a0b2f --- /dev/null +++ b/internal/provider/xai.go @@ -0,0 +1,10 @@ +package provider + +// NewXAI creates an xAI (Grok) provider. xAI uses the OpenAI-compatible +// API format. Requests arrive at /xai/v1/chat/completions. +func NewXAI(upstream string) *OpenAICompatible { + if upstream == "" { + upstream = "https://api.x.ai" + } + return NewOpenAICompatible("xai", upstream, "/xai") +}