Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions DESIGN.md
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,22 @@ All fields use `sync/atomic` operations for thread safety. The `Metrics` struct
- **Body handling:** Reconstructs the request body from raw bytes on each retry attempt
- **Custom predicate:** `NewRetryWithPredicate(maxAttempts, delay, predicate)` allows callers to supply a custom function that decides whether a given response should be retried

`NewRetryWithRateLimitHeaders(maxAttempts, defaultDelay)` — Retries with rate limit header support:

- **Retry-After header:** Parses both seconds (integer) and HTTP date formats
- **X-RateLimit-Reset header:** Fallback if Retry-After not present
- **Max delay:** Values over 24 hours are ignored (fallback to defaultDelay); exactly 24 hours is accepted
- **Precedence:** Retry-After takes precedence over X-RateLimit-Reset

Example:

```go
// Use rate limit headers from provider
retry := interceptors.NewRetryWithRateLimitHeaders(3, time.Second)

// If provider returns 429 with Retry-After: 30, waits 30s instead of 1s
```

### Billing

`NewBilling(lookup, onResult)` — Calculates the cost of each request:
Expand Down
2 changes: 2 additions & 0 deletions examples/basic/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"log"
"net/http"
"os"
"time"

"github.com/agentuity/llmproxy"
"github.com/agentuity/llmproxy/interceptors"
Expand Down Expand Up @@ -230,6 +231,7 @@ func main() {
http.HandleFunc("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
provider := openaiProvider
opts := []llmproxy.ProxyOption{
llmproxy.WithInterceptor(interceptors.NewRetryWithRateLimitHeaders(3, time.Millisecond*250)),
llmproxy.WithInterceptor(tracingInterceptor),
llmproxy.WithInterceptor(loggingInterceptor),
llmproxy.WithInterceptor(interceptors.NewMetrics(metrics)),
Expand Down
183 changes: 183 additions & 0 deletions interceptors/addheader_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
package interceptors

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/agentuity/llmproxy"
)

func TestAddHeaderInterceptor_ResponseHeaders(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
}))
defer upstream.Close()

add := NewAddResponseHeader(
NewHeader("X-Gateway-Version", "1.0"),
NewHeader("X-Served-By", "llmproxy"),
)

req, _ := http.NewRequest("POST", upstream.URL, nil)
next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) {
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, llmproxy.ResponseMetadata{}, nil, err
}
return resp, llmproxy.ResponseMetadata{}, nil, nil
}

resp, _, _, err := add.Intercept(req, llmproxy.BodyMetadata{}, nil, next)
if err != nil {
t.Fatalf("Intercept returned error: %v", err)
}
defer resp.Body.Close()

if got := resp.Header.Get("X-Gateway-Version"); got != "1.0" {
t.Errorf("X-Gateway-Version header = %q, want %q", got, "1.0")
}
if got := resp.Header.Get("X-Served-By"); got != "llmproxy" {
t.Errorf("X-Served-By header = %q, want %q", got, "llmproxy")
}
if got := resp.Header.Get("Content-Type"); got != "application/json" {
t.Errorf("Content-Type header should be preserved, got %q", got)
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}

// TestAddHeaderInterceptor_RequestHeaders verifies that NewAddRequestHeader
// attaches its configured headers to the outgoing request and leaves the
// caller's existing headers untouched.
func TestAddHeaderInterceptor_RequestHeaders(t *testing.T) {
	captured := make(chan *http.Request, 1)
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		captured <- r
		w.WriteHeader(http.StatusOK)
	}))
	defer upstream.Close()

	add := NewAddRequestHeader(
		NewHeader("X-Client-ID", "my-app"),
		NewHeader("X-Request-Source", "gateway"),
	)

	req, _ := http.NewRequest("POST", upstream.URL, nil)
	req.Header.Set("Content-Type", "application/json")

	// next forwards to the test server so the handler can capture the request.
	next := func(r *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) {
		upstreamResp, doErr := http.DefaultClient.Do(r)
		if doErr != nil {
			return nil, llmproxy.ResponseMetadata{}, nil, doErr
		}
		return upstreamResp, llmproxy.ResponseMetadata{}, nil, nil
	}

	resp, _, _, err := add.Intercept(req, llmproxy.BodyMetadata{}, nil, next)
	if err != nil {
		t.Fatalf("Intercept returned error: %v", err)
	}
	resp.Body.Close()

	seen := <-captured
	if got := seen.Header.Get("X-Client-ID"); got != "my-app" {
		t.Errorf("X-Client-ID header = %q, want %q", got, "my-app")
	}
	if got := seen.Header.Get("X-Request-Source"); got != "gateway" {
		t.Errorf("X-Request-Source header = %q, want %q", got, "gateway")
	}
	if got := seen.Header.Get("Content-Type"); got != "application/json" {
		t.Errorf("Content-Type header should be preserved, got %q", got)
	}
}

// TestAddHeaderInterceptor_Both verifies that NewAddHeader can decorate the
// request and the response from a single interceptor.
func TestAddHeaderInterceptor_Both(t *testing.T) {
	captured := make(chan *http.Request, 1)
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		captured <- r
		w.WriteHeader(http.StatusOK)
	}))
	defer upstream.Close()

	add := NewAddHeader(
		[]Header{NewHeader("X-Request-ID", "req-123")},
		[]Header{NewHeader("X-Response-Time", "50ms")},
	)

	req, _ := http.NewRequest("POST", upstream.URL, nil)
	next := func(r *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) {
		upstreamResp, doErr := http.DefaultClient.Do(r)
		if doErr != nil {
			return nil, llmproxy.ResponseMetadata{}, nil, doErr
		}
		return upstreamResp, llmproxy.ResponseMetadata{}, nil, nil
	}

	resp, _, _, err := add.Intercept(req, llmproxy.BodyMetadata{}, nil, next)
	if err != nil {
		t.Fatalf("Intercept returned error: %v", err)
	}
	defer resp.Body.Close()

	seen := <-captured
	if got := seen.Header.Get("X-Request-ID"); got != "req-123" {
		t.Errorf("Request X-Request-ID header = %q, want %q", got, "req-123")
	}
	if got := resp.Header.Get("X-Response-Time"); got != "50ms" {
		t.Errorf("Response X-Response-Time header = %q, want %q", got, "50ms")
	}
}

// TestAddHeaderInterceptor_Empty verifies that a zero-value
// AddHeaderInterceptor acts as a safe pass-through and does not error.
func TestAddHeaderInterceptor_Empty(t *testing.T) {
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer upstream.Close()

	add := new(AddHeaderInterceptor)

	req, _ := http.NewRequest("POST", upstream.URL, nil)
	next := func(r *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) {
		upstreamResp, doErr := http.DefaultClient.Do(r)
		if doErr != nil {
			return nil, llmproxy.ResponseMetadata{}, nil, doErr
		}
		return upstreamResp, llmproxy.ResponseMetadata{}, nil, nil
	}

	resp, _, _, err := add.Intercept(req, llmproxy.BodyMetadata{}, nil, next)
	if err != nil {
		t.Fatalf("Intercept returned error: %v", err)
	}
	resp.Body.Close()
}

func TestAddHeaderInterceptor_ErrorPassthrough(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer upstream.Close()

add := NewAddResponseHeader(NewHeader("X-Test", "value"))

req, _ := http.NewRequest("POST", upstream.URL, nil)
expectedErr := http.ErrHandlerTimeout
next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) {
return nil, llmproxy.ResponseMetadata{}, nil, expectedErr
}

resp, _, _, err := add.Intercept(req, llmproxy.BodyMetadata{}, nil, next)
if err != expectedErr {
t.Errorf("Error should pass through, got %v, want %v", err, expectedErr)
}
if resp != nil {
t.Errorf("Response should be nil on error, got %v", resp)
}
}

// TestNewHeader verifies that NewHeader stores the key and value exactly as
// given.
func TestNewHeader(t *testing.T) {
	got := NewHeader("X-Custom", "value")
	if got.Key != "X-Custom" {
		t.Errorf("Key = %q, want %q", got.Key, "X-Custom")
	}
	if got.Value != "value" {
		t.Errorf("Value = %q, want %q", got.Value, "value")
	}
}
149 changes: 149 additions & 0 deletions interceptors/billing_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
package interceptors

import (
"math"
"net/http"
"net/http/httptest"
"testing"

"github.com/agentuity/llmproxy"
)

// TestBillingInterceptor_Success verifies that the billing interceptor feeds
// the usage reported by next through the cost lookup and delivers a fully
// populated BillingResult to the OnResult callback.
// Fix: the response returned by Intercept was discarded without closing its
// body, leaking the upstream connection; it is now captured and closed.
func TestBillingInterceptor_Success(t *testing.T) {
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(`{"id":"chatcmpl-123","model":"gpt-4","usage":{"prompt_tokens":100,"completion_tokens":50,"total_tokens":150}}`))
	}))
	defer upstream.Close()

	var result llmproxy.BillingResult
	// Cost table: $30 / $60 per million input/output tokens, for gpt-4 only.
	lookup := func(provider, model string) (llmproxy.CostInfo, bool) {
		if model == "gpt-4" {
			return llmproxy.CostInfo{Input: 30, Output: 60}, true
		}
		return llmproxy.CostInfo{}, false
	}

	billing := NewBilling(lookup, func(r llmproxy.BillingResult) {
		result = r
	})

	req, _ := http.NewRequest("POST", upstream.URL, nil)
	meta := llmproxy.BodyMetadata{Model: "gpt-4"}

	next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) {
		resp, err := http.DefaultClient.Do(req)
		if err != nil {
			return nil, llmproxy.ResponseMetadata{}, nil, err
		}
		body := []byte(`{"usage":{"prompt_tokens":100,"completion_tokens":50,"total_tokens":150}}`)
		respMeta := llmproxy.ResponseMetadata{
			Usage: llmproxy.Usage{PromptTokens: 100, CompletionTokens: 50, TotalTokens: 150},
		}
		return resp, respMeta, body, nil
	}

	resp, _, _, err := billing.Intercept(req, meta, nil, next)
	if err != nil {
		t.Fatalf("Intercept returned error: %v", err)
	}
	// Close the upstream body so the test does not leak the connection.
	defer resp.Body.Close()

	if result.Model != "gpt-4" {
		t.Errorf("Model = %q, want %q", result.Model, "gpt-4")
	}
	if result.PromptTokens != 100 {
		t.Errorf("PromptTokens = %d, want 100", result.PromptTokens)
	}
	if result.CompletionTokens != 50 {
		t.Errorf("CompletionTokens = %d, want 50", result.CompletionTokens)
	}

	// Per-million-token pricing: cost = rate * tokens / 1e6.
	expectedInputCost := 30.0 * 100 / 1_000_000
	expectedOutputCost := 60.0 * 50 / 1_000_000
	expectedTotal := expectedInputCost + expectedOutputCost

	// Compare floats with a tolerance rather than exact equality.
	epsilon := 1e-9
	if math.Abs(result.TotalCost-expectedTotal) > epsilon {
		t.Errorf("TotalCost = %f, want %f (diff: %e)", result.TotalCost, expectedTotal, math.Abs(result.TotalCost-expectedTotal))
	}
}

// TestBillingInterceptor_ModelNotFound verifies that the OnResult callback is
// skipped when the cost lookup does not recognize the model.
// Fix: the response returned by Intercept was discarded without closing its
// body, leaking the upstream connection; it is now captured and closed.
func TestBillingInterceptor_ModelNotFound(t *testing.T) {
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer upstream.Close()

	called := false
	// Lookup that recognizes no models at all.
	lookup := func(provider, model string) (llmproxy.CostInfo, bool) {
		return llmproxy.CostInfo{}, false
	}

	billing := NewBilling(lookup, func(r llmproxy.BillingResult) {
		called = true
	})

	req, _ := http.NewRequest("POST", upstream.URL, nil)
	meta := llmproxy.BodyMetadata{Model: "unknown-model"}

	next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) {
		resp, err := http.DefaultClient.Do(req)
		if err != nil {
			return nil, llmproxy.ResponseMetadata{}, nil, err
		}
		return resp, llmproxy.ResponseMetadata{Usage: llmproxy.Usage{PromptTokens: 100, CompletionTokens: 50}}, nil, nil
	}

	resp, _, _, err := billing.Intercept(req, meta, nil, next)
	if err != nil {
		t.Fatalf("Intercept returned error: %v", err)
	}
	// Close the upstream body so the test does not leak the connection.
	resp.Body.Close()

	if called {
		t.Error("OnResult should not be called when model not found")
	}
}

// TestBillingInterceptor_ErrorPassthrough verifies that an error from next is
// returned unchanged even when the interceptor has no lookup or callback.
func TestBillingInterceptor_ErrorPassthrough(t *testing.T) {
	billing := NewBilling(nil, nil)

	req, _ := http.NewRequest("POST", "http://example.com", nil)
	failing := func(r *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) {
		return nil, llmproxy.ResponseMetadata{}, nil, http.ErrHandlerTimeout
	}

	_, _, _, err := billing.Intercept(req, llmproxy.BodyMetadata{}, nil, failing)
	if err != http.ErrHandlerTimeout {
		t.Errorf("Error should pass through, got %v", err)
	}
}

// TestDetectProvider verifies the model-name-to-provider mapping, including
// the empty-string fallback for unrecognized models.
func TestDetectProvider(t *testing.T) {
	cases := []struct {
		model string
		want  string
	}{
		{"gpt-4", "openai"},
		{"gpt-3.5-turbo", "openai"},
		{"o1-preview", "openai"},
		{"o3-mini", "openai"},
		{"chatgpt-4o", "openai"},
		{"claude-3-opus", "anthropic"},
		{"claude-3-sonnet", "anthropic"},
		{"gemini-pro", "google"},
		{"gemini-1.5-flash", "google"},
		{"llama-3-70b", "groq"},
		{"mixtral-8x7b", "groq"},
		{"unknown-model", ""},
	}

	for _, tc := range cases {
		t.Run(tc.model, func(t *testing.T) {
			if got := detectProvider(tc.model); got != tc.want {
				t.Errorf("detectProvider(%q) = %q, want %q", tc.model, got, tc.want)
			}
		})
	}
}
Loading
Loading