From a05ad09e4abb60f98b4892f18d40e26fb87cb15a Mon Sep 17 00:00:00 2001 From: Jeff Haynie Date: Sun, 12 Apr 2026 18:51:34 -0500 Subject: [PATCH 1/4] feat: enhance retry with rate limit headers, add comprehensive test coverage - Add rate limit header support to retry interceptor - Parse Retry-After header (seconds and HTTP date formats) - Parse X-RateLimit-Reset header as fallback - New constructor: NewRetryWithRateLimitHeaders() - Safety: ignore values >24h, fallback to default delay - Add comprehensive test coverage - interceptors: 51.6% -> 93.6% - openai_compatible: 52.6% -> 91.2% - New test files: addheader_test.go, billing_test.go, coverage_test.go - 123 total tests, all passing - Update DESIGN.md with retry rate limit header docs --- DESIGN.md | 16 + examples/basic/main.go | 2 + interceptors/addheader_test.go | 177 ++++++ interceptors/billing_test.go | 147 +++++ interceptors/coverage_test.go | 641 +++++++++++++++++++++ interceptors/retry.go | 81 +-- providers/openai_compatible/parser_test.go | 525 +++++++++++++++-- 7 files changed, 1490 insertions(+), 99 deletions(-) create mode 100644 interceptors/addheader_test.go create mode 100644 interceptors/billing_test.go create mode 100644 interceptors/coverage_test.go diff --git a/DESIGN.md b/DESIGN.md index 4ae5648..db9e653 100644 --- a/DESIGN.md +++ b/DESIGN.md @@ -314,6 +314,22 @@ All fields use `sync/atomic` operations for thread safety. 
The `Metrics` struct - **Body handling:** Reconstructs the request body from raw bytes on each retry attempt - **Custom predicate:** `NewRetryWithPredicate(maxAttempts, delay, predicate)` allows callers to supply a custom function that decides whether a given response should be retried +`NewRetryWithRateLimitHeaders(maxAttempts, defaultDelay)` — Retries with rate limit header support: + +- **Retry-After header:** Parses both seconds (integer) and HTTP date formats +- **X-RateLimit-Reset header:** Fallback if Retry-After not present +- **Max delay:** Values over 24 hours are ignored (fallback to defaultDelay) +- **Precedence:** Retry-After takes precedence over X-RateLimit-Reset + +Example: + +```go +// Use rate limit headers from provider +retry := interceptors.NewRetryWithRateLimitHeaders(3, time.Second) + +// If provider returns 429 with Retry-After: 30, waits 30s instead of 1s +``` + ### Billing `NewBilling(lookup, onResult)` — Calculates the cost of each request: diff --git a/examples/basic/main.go b/examples/basic/main.go index 9086d7a..f6db8a1 100644 --- a/examples/basic/main.go +++ b/examples/basic/main.go @@ -23,6 +23,7 @@ import ( "log" "net/http" "os" + "time" "github.com/agentuity/llmproxy" "github.com/agentuity/llmproxy/interceptors" @@ -230,6 +231,7 @@ func main() { http.HandleFunc("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { provider := openaiProvider opts := []llmproxy.ProxyOption{ + llmproxy.WithInterceptor(interceptors.NewRetry(3, time.Millisecond*250)), llmproxy.WithInterceptor(tracingInterceptor), llmproxy.WithInterceptor(loggingInterceptor), llmproxy.WithInterceptor(interceptors.NewMetrics(metrics)), diff --git a/interceptors/addheader_test.go b/interceptors/addheader_test.go new file mode 100644 index 0000000..9f8079c --- /dev/null +++ b/interceptors/addheader_test.go @@ -0,0 +1,177 @@ +package interceptors + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/agentuity/llmproxy" +) + +func 
TestAddHeaderInterceptor_ResponseHeaders(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + })) + defer upstream.Close() + + add := NewAddResponseHeader( + NewHeader("X-Gateway-Version", "1.0"), + NewHeader("X-Served-By", "llmproxy"), + ) + + req, _ := http.NewRequest("POST", upstream.URL, nil) + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + return resp, llmproxy.ResponseMetadata{}, nil, nil + } + + resp, _, _, err := add.Intercept(req, llmproxy.BodyMetadata{}, nil, next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } + + if got := resp.Header.Get("X-Gateway-Version"); got != "1.0" { + t.Errorf("X-Gateway-Version header = %q, want %q", got, "1.0") + } + if got := resp.Header.Get("X-Served-By"); got != "llmproxy" { + t.Errorf("X-Served-By header = %q, want %q", got, "llmproxy") + } + if got := resp.Header.Get("Content-Type"); got != "application/json" { + t.Errorf("Content-Type header should be preserved, got %q", got) + } +} + +func TestAddHeaderInterceptor_RequestHeaders(t *testing.T) { + var capturedReq *http.Request + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedReq = r + w.WriteHeader(http.StatusOK) + })) + defer upstream.Close() + + add := NewAddRequestHeader( + NewHeader("X-Client-ID", "my-app"), + NewHeader("X-Request-Source", "gateway"), + ) + + req, _ := http.NewRequest("POST", upstream.URL, nil) + req.Header.Set("Content-Type", "application/json") + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + return 
resp, llmproxy.ResponseMetadata{}, nil, nil + } + + _, _, _, err := add.Intercept(req, llmproxy.BodyMetadata{}, nil, next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } + + if got := capturedReq.Header.Get("X-Client-ID"); got != "my-app" { + t.Errorf("X-Client-ID header = %q, want %q", got, "my-app") + } + if got := capturedReq.Header.Get("X-Request-Source"); got != "gateway" { + t.Errorf("X-Request-Source header = %q, want %q", got, "gateway") + } + if got := capturedReq.Header.Get("Content-Type"); got != "application/json" { + t.Errorf("Content-Type header should be preserved, got %q", got) + } +} + +func TestAddHeaderInterceptor_Both(t *testing.T) { + var capturedReq *http.Request + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedReq = r + w.WriteHeader(http.StatusOK) + })) + defer upstream.Close() + + add := NewAddHeader( + []Header{NewHeader("X-Request-ID", "req-123")}, + []Header{NewHeader("X-Response-Time", "50ms")}, + ) + + req, _ := http.NewRequest("POST", upstream.URL, nil) + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + return resp, llmproxy.ResponseMetadata{}, nil, nil + } + + resp, _, _, err := add.Intercept(req, llmproxy.BodyMetadata{}, nil, next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } + + if got := capturedReq.Header.Get("X-Request-ID"); got != "req-123" { + t.Errorf("Request X-Request-ID header = %q, want %q", got, "req-123") + } + if got := resp.Header.Get("X-Response-Time"); got != "50ms" { + t.Errorf("Response X-Response-Time header = %q, want %q", got, "50ms") + } +} + +func TestAddHeaderInterceptor_Empty(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer 
upstream.Close() + + add := &AddHeaderInterceptor{} + + req, _ := http.NewRequest("POST", upstream.URL, nil) + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + return resp, llmproxy.ResponseMetadata{}, nil, nil + } + + _, _, _, err := add.Intercept(req, llmproxy.BodyMetadata{}, nil, next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestAddHeaderInterceptor_ErrorPassthrough(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer upstream.Close() + + add := NewAddResponseHeader(NewHeader("X-Test", "value")) + + req, _ := http.NewRequest("POST", upstream.URL, nil) + expectedErr := http.ErrHandlerTimeout + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + return nil, llmproxy.ResponseMetadata{}, nil, expectedErr + } + + resp, _, _, err := add.Intercept(req, llmproxy.BodyMetadata{}, nil, next) + if err != expectedErr { + t.Errorf("Error should pass through, got %v, want %v", err, expectedErr) + } + if resp != nil { + t.Errorf("Response should be nil on error, got %v", resp) + } +} + +func TestNewHeader(t *testing.T) { + h := NewHeader("X-Custom", "value") + if h.Key != "X-Custom" { + t.Errorf("Key = %q, want %q", h.Key, "X-Custom") + } + if h.Value != "value" { + t.Errorf("Value = %q, want %q", h.Value, "value") + } +} diff --git a/interceptors/billing_test.go b/interceptors/billing_test.go new file mode 100644 index 0000000..3cccacf --- /dev/null +++ b/interceptors/billing_test.go @@ -0,0 +1,147 @@ +package interceptors + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/agentuity/llmproxy" +) + +func TestBillingInterceptor_Success(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w 
http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"id":"chatcmpl-123","model":"gpt-4","usage":{"prompt_tokens":100,"completion_tokens":50,"total_tokens":150}}`)) + })) + defer upstream.Close() + + var result llmproxy.BillingResult + lookup := func(provider, model string) (llmproxy.CostInfo, bool) { + if model == "gpt-4" { + return llmproxy.CostInfo{Input: 30, Output: 60}, true + } + return llmproxy.CostInfo{}, false + } + + billing := NewBilling(lookup, func(r llmproxy.BillingResult) { + result = r + }) + + req, _ := http.NewRequest("POST", upstream.URL, nil) + meta := llmproxy.BodyMetadata{Model: "gpt-4"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + body := []byte(`{"usage":{"prompt_tokens":100,"completion_tokens":50,"total_tokens":150}}`) + respMeta := llmproxy.ResponseMetadata{ + Usage: llmproxy.Usage{PromptTokens: 100, CompletionTokens: 50, TotalTokens: 150}, + } + return resp, respMeta, body, nil + } + + _, _, _, err := billing.Intercept(req, meta, nil, next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } + + if result.Model != "gpt-4" { + t.Errorf("Model = %q, want %q", result.Model, "gpt-4") + } + if result.PromptTokens != 100 { + t.Errorf("PromptTokens = %d, want 100", result.PromptTokens) + } + if result.CompletionTokens != 50 { + t.Errorf("CompletionTokens = %d, want 50", result.CompletionTokens) + } + + expectedInputCost := 30.0 * 100 / 1_000_000 + expectedOutputCost := 60.0 * 50 / 1_000_000 + expectedTotal := expectedInputCost + expectedOutputCost + + if result.TotalCost != expectedTotal { + t.Errorf("TotalCost = %f, want %f", result.TotalCost, expectedTotal) + } +} + +func TestBillingInterceptor_ModelNotFound(t *testing.T) { + upstream := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer upstream.Close() + + called := false + lookup := func(provider, model string) (llmproxy.CostInfo, bool) { + return llmproxy.CostInfo{}, false + } + + billing := NewBilling(lookup, func(r llmproxy.BillingResult) { + called = true + }) + + req, _ := http.NewRequest("POST", upstream.URL, nil) + meta := llmproxy.BodyMetadata{Model: "unknown-model"} + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + return resp, llmproxy.ResponseMetadata{Usage: llmproxy.Usage{PromptTokens: 100, CompletionTokens: 50}}, nil, nil + } + + _, _, _, err := billing.Intercept(req, meta, nil, next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } + + if called { + t.Error("OnResult should not be called when model not found") + } +} + +func TestBillingInterceptor_ErrorPassthrough(t *testing.T) { + billing := NewBilling(nil, nil) + + req, _ := http.NewRequest("POST", "http://example.com", nil) + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + return nil, llmproxy.ResponseMetadata{}, nil, http.ErrHandlerTimeout + } + + _, _, _, err := billing.Intercept(req, llmproxy.BodyMetadata{}, nil, next) + if err != http.ErrHandlerTimeout { + t.Errorf("Error should pass through, got %v", err) + } +} + +func TestDetectProvider(t *testing.T) { + tests := []struct { + model string + expected string + }{ + {"gpt-4", "openai"}, + {"gpt-3.5-turbo", "openai"}, + {"o1-preview", "openai"}, + {"o3-mini", "openai"}, + {"chatgpt-4o", "openai"}, + {"claude-3-opus", "anthropic"}, + {"claude-3-sonnet", "anthropic"}, + {"gemini-pro", "google"}, + {"gemini-1.5-flash", "google"}, + {"llama-3-70b", "groq"}, + {"mixtral-8x7b", "groq"}, + {"unknown-model", ""}, + } + + for _, 
tt := range tests { + t.Run(tt.model, func(t *testing.T) { + got := detectProvider(tt.model) + if got != tt.expected { + t.Errorf("detectProvider(%q) = %q, want %q", tt.model, got, tt.expected) + } + }) + } +} diff --git a/interceptors/coverage_test.go b/interceptors/coverage_test.go new file mode 100644 index 0000000..217e35a --- /dev/null +++ b/interceptors/coverage_test.go @@ -0,0 +1,641 @@ +package interceptors + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/agentuity/llmproxy" +) + +func TestMetricsInterceptor_Success(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer upstream.Close() + + m := &Metrics{} + metrics := NewMetrics(m) + + req, _ := http.NewRequest("POST", upstream.URL, nil) + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + respMeta := llmproxy.ResponseMetadata{ + Usage: llmproxy.Usage{ + PromptTokens: 100, + CompletionTokens: 50, + TotalTokens: 150, + }, + } + return resp, respMeta, nil, nil + } + + _, _, _, err := metrics.Intercept(req, llmproxy.BodyMetadata{}, nil, next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } + + if m.TotalRequests != 1 { + t.Errorf("TotalRequests = %d, want 1", m.TotalRequests) + } + if m.Errors != 0 { + t.Errorf("Errors = %d, want 0", m.Errors) + } + if m.TotalTokens != 150 { + t.Errorf("TotalTokens = %d, want 150", m.TotalTokens) + } + if m.TotalPromptTokens != 100 { + t.Errorf("TotalPromptTokens = %d, want 100", m.TotalPromptTokens) + } + if m.TotalCompletionTokens != 50 { + t.Errorf("TotalCompletionTokens = %d, want 50", m.TotalCompletionTokens) + } + if m.TotalLatency <= 0 { + t.Errorf("TotalLatency = %d, want > 0", m.TotalLatency) + } +} + +func TestMetricsInterceptor_Error(t *testing.T) { + 
m := &Metrics{} + metrics := NewMetrics(m) + + req, _ := http.NewRequest("POST", "http://example.com", nil) + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + return nil, llmproxy.ResponseMetadata{}, nil, http.ErrHandlerTimeout + } + + _, _, _, err := metrics.Intercept(req, llmproxy.BodyMetadata{}, nil, next) + if err == nil { + t.Fatal("expected error") + } + + if m.TotalRequests != 1 { + t.Errorf("TotalRequests = %d, want 1", m.TotalRequests) + } + if m.Errors != 1 { + t.Errorf("Errors = %d, want 1", m.Errors) + } + if m.TotalTokens != 0 { + t.Errorf("TotalTokens = %d, want 0 (no tokens on error)", m.TotalTokens) + } +} + +func TestMetricsInterceptor_MultipleRequests(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer upstream.Close() + + m := &Metrics{} + metrics := NewMetrics(m) + + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + respMeta := llmproxy.ResponseMetadata{ + Usage: llmproxy.Usage{TotalTokens: 100}, + } + return resp, respMeta, nil, nil + } + + for i := 0; i < 5; i++ { + req, _ := http.NewRequest("POST", upstream.URL, nil) + _, _, _, _ = metrics.Intercept(req, llmproxy.BodyMetadata{}, nil, next) + } + + if m.TotalRequests != 5 { + t.Errorf("TotalRequests = %d, want 5", m.TotalRequests) + } + if m.TotalTokens != 500 { + t.Errorf("TotalTokens = %d, want 500", m.TotalTokens) + } +} + +func TestRetryInterceptor_Success(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer upstream.Close() + + retry := NewRetry(3, time.Millisecond) + + req, _ := http.NewRequest("POST", upstream.URL, http.NoBody) + next := func(req *http.Request) 
(*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + return resp, llmproxy.ResponseMetadata{}, nil, nil + } + + _, _, _, err := retry.Intercept(req, llmproxy.BodyMetadata{}, nil, next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestRetryInterceptor_RetriesOn5xx(t *testing.T) { + callCount := 0 + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + if callCount < 3 { + w.WriteHeader(http.StatusInternalServerError) + } else { + w.WriteHeader(http.StatusOK) + } + })) + defer upstream.Close() + + retry := NewRetry(3, time.Millisecond) + + req, _ := http.NewRequest("POST", upstream.URL, http.NoBody) + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + return resp, llmproxy.ResponseMetadata{}, nil, nil + } + + resp, _, _, err := retry.Intercept(req, llmproxy.BodyMetadata{}, nil, next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } + + if callCount != 3 { + t.Errorf("callCount = %d, want 3", callCount) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("StatusCode = %d, want 200", resp.StatusCode) + } +} + +func TestRetryInterceptor_RetriesOn429(t *testing.T) { + callCount := 0 + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + if callCount == 1 { + w.WriteHeader(http.StatusTooManyRequests) + } else { + w.WriteHeader(http.StatusOK) + } + })) + defer upstream.Close() + + retry := NewRetry(3, time.Millisecond) + + req, _ := http.NewRequest("POST", upstream.URL, http.NoBody) + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + 
if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + return resp, llmproxy.ResponseMetadata{}, nil, nil + } + + resp, _, _, err := retry.Intercept(req, llmproxy.BodyMetadata{}, nil, next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } + + if callCount != 2 { + t.Errorf("callCount = %d, want 2", callCount) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("StatusCode = %d, want 200", resp.StatusCode) + } +} + +func TestRetryInterceptor_ExhaustedAttempts(t *testing.T) { + callCount := 0 + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + w.WriteHeader(http.StatusInternalServerError) + })) + defer upstream.Close() + + retry := NewRetry(3, time.Millisecond) + + req, _ := http.NewRequest("POST", upstream.URL, http.NoBody) + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + return resp, llmproxy.ResponseMetadata{}, nil, nil + } + + resp, _, _, _ := retry.Intercept(req, llmproxy.BodyMetadata{}, nil, next) + + if callCount != 3 { + t.Errorf("callCount = %d, want 3", callCount) + } + if resp.StatusCode != http.StatusInternalServerError { + t.Errorf("StatusCode = %d, want 500", resp.StatusCode) + } +} + +func TestRetryInterceptor_NoRetryOn200(t *testing.T) { + callCount := 0 + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + w.WriteHeader(http.StatusOK) + })) + defer upstream.Close() + + retry := NewRetry(3, time.Millisecond) + + req, _ := http.NewRequest("POST", upstream.URL, http.NoBody) + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + return resp, llmproxy.ResponseMetadata{}, nil, nil 
+ } + + _, _, _, _ = retry.Intercept(req, llmproxy.BodyMetadata{}, nil, next) + + if callCount != 1 { + t.Errorf("callCount = %d, want 1 (no retry on success)", callCount) + } +} + +func TestRetryInterceptor_CustomPredicate(t *testing.T) { + callCount := 0 + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + if callCount == 1 { + w.WriteHeader(http.StatusBadGateway) + } else { + w.WriteHeader(http.StatusOK) + } + })) + defer upstream.Close() + + retry := NewRetryWithPredicate(3, time.Millisecond, func(resp *http.Response, err error) bool { + return resp.StatusCode == http.StatusBadGateway + }) + + req, _ := http.NewRequest("POST", upstream.URL, http.NoBody) + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + return resp, llmproxy.ResponseMetadata{}, nil, nil + } + + resp, _, _, _ := retry.Intercept(req, llmproxy.BodyMetadata{}, nil, next) + + if callCount != 2 { + t.Errorf("callCount = %d, want 2", callCount) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("StatusCode = %d, want 200", resp.StatusCode) + } +} + +func TestRetryInterceptor_RetryAfterHeader(t *testing.T) { + callCount := 0 + var retryAfter int + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + if callCount == 1 { + w.Header().Set("Retry-After", "1") + w.WriteHeader(http.StatusTooManyRequests) + } else { + w.WriteHeader(http.StatusOK) + } + })) + defer upstream.Close() + + retry := NewRetryWithRateLimitHeaders(3, 5*time.Second) + + start := time.Now() + req, _ := http.NewRequest("POST", upstream.URL, http.NoBody) + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + 
} + return resp, llmproxy.ResponseMetadata{}, nil, nil + } + + _, _, _, _ = retry.Intercept(req, llmproxy.BodyMetadata{}, nil, next) + elapsed := time.Since(start) + + if callCount != 2 { + t.Errorf("callCount = %d, want 2", callCount) + } + if elapsed < time.Duration(retryAfter)*time.Second { + t.Errorf("elapsed = %v, should have waited at least %v", elapsed, time.Duration(retryAfter)*time.Second) + } +} + +func TestRetryInterceptor_RetryAfterDateHeader(t *testing.T) { + callCount := 0 + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + if callCount == 1 { + retryTime := time.Now().Add(500 * time.Millisecond) + w.Header().Set("Retry-After", retryTime.UTC().Format(http.TimeFormat)) + w.WriteHeader(http.StatusTooManyRequests) + } else { + w.WriteHeader(http.StatusOK) + } + })) + defer upstream.Close() + + retry := NewRetryWithRateLimitHeaders(3, 5*time.Second) + + start := time.Now() + req, _ := http.NewRequest("POST", upstream.URL, http.NoBody) + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + return resp, llmproxy.ResponseMetadata{}, nil, nil + } + + _, _, _, _ = retry.Intercept(req, llmproxy.BodyMetadata{}, nil, next) + elapsed := time.Since(start) + + if callCount != 2 { + t.Errorf("callCount = %d, want 2", callCount) + } + if elapsed < 400*time.Millisecond { + t.Errorf("elapsed = %v, should have waited for Retry-After date", elapsed) + } +} + +func TestRetryInterceptor_XRateLimitReset(t *testing.T) { + callCount := 0 + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + if callCount == 1 { + w.Header().Set("X-RateLimit-Reset", "1") + w.WriteHeader(http.StatusTooManyRequests) + } else { + w.WriteHeader(http.StatusOK) + } + })) + defer upstream.Close() + + retry := 
NewRetryWithRateLimitHeaders(3, 5*time.Second) + + start := time.Now() + req, _ := http.NewRequest("POST", upstream.URL, http.NoBody) + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + return resp, llmproxy.ResponseMetadata{}, nil, nil + } + + _, _, _, _ = retry.Intercept(req, llmproxy.BodyMetadata{}, nil, next) + elapsed := time.Since(start) + + if callCount != 2 { + t.Errorf("callCount = %d, want 2", callCount) + } + if elapsed < 900*time.Millisecond { + t.Errorf("elapsed = %v, should have used X-RateLimit-Reset header", elapsed) + } +} + +func TestRetryInterceptor_RateLimitHeadersFallback(t *testing.T) { + callCount := 0 + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + if callCount == 1 { + w.WriteHeader(http.StatusTooManyRequests) + } else { + w.WriteHeader(http.StatusOK) + } + })) + defer upstream.Close() + + retry := NewRetryWithRateLimitHeaders(3, 10*time.Millisecond) + + start := time.Now() + req, _ := http.NewRequest("POST", upstream.URL, http.NoBody) + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + return resp, llmproxy.ResponseMetadata{}, nil, nil + } + + _, _, _, _ = retry.Intercept(req, llmproxy.BodyMetadata{}, nil, next) + elapsed := time.Since(start) + + if callCount != 2 { + t.Errorf("callCount = %d, want 2", callCount) + } + if elapsed < 8*time.Millisecond { + t.Errorf("elapsed = %v, should have used default delay as fallback", elapsed) + } +} + +func TestParseRetryAfterHeader_Seconds(t *testing.T) { + resp := &http.Response{Header: make(http.Header)} + resp.Header.Set("Retry-After", "30") + + delay := parseRetryAfterHeader(resp) + if delay != 30*time.Second { + 
t.Errorf("delay = %v, want 30s", delay) + } +} + +func TestParseRetryAfterHeader_Date(t *testing.T) { + future := time.Now().Add(60 * time.Second) + resp := &http.Response{Header: make(http.Header)} + resp.Header.Set("Retry-After", future.UTC().Format(http.TimeFormat)) + + delay := parseRetryAfterHeader(resp) + if delay < 50*time.Second || delay > 70*time.Second { + t.Errorf("delay = %v, want ~60s", delay) + } +} + +func TestParseRetryAfterHeader_XRateLimitReset(t *testing.T) { + resp := &http.Response{Header: make(http.Header)} + resp.Header.Set("X-RateLimit-Reset", "45") + + delay := parseRetryAfterHeader(resp) + if delay != 45*time.Second { + t.Errorf("delay = %v, want 45s", delay) + } +} + +func TestParseRetryAfterHeader_Empty(t *testing.T) { + resp := &http.Response{Header: make(http.Header)} + + delay := parseRetryAfterHeader(resp) + if delay != 0 { + t.Errorf("delay = %v, want 0 (no header)", delay) + } +} + +func TestParseRetryAfterHeader_Invalid(t *testing.T) { + resp := &http.Response{Header: make(http.Header)} + resp.Header.Set("Retry-After", "invalid") + + delay := parseRetryAfterHeader(resp) + if delay != 0 { + t.Errorf("delay = %v, want 0 (invalid header)", delay) + } +} + +func TestParseRetryAfterHeader_TooLarge(t *testing.T) { + resp := &http.Response{Header: make(http.Header)} + resp.Header.Set("Retry-After", "86401") + + delay := parseRetryAfterHeader(resp) + if delay != 0 { + t.Errorf("delay = %v, want 0 (>24h is ignored)", delay) + } +} + +func TestParseRetryAfterHeader_RetryAfterPreferred(t *testing.T) { + resp := &http.Response{Header: make(http.Header)} + resp.Header.Set("Retry-After", "10") + resp.Header.Set("X-RateLimit-Reset", "20") + + delay := parseRetryAfterHeader(resp) + if delay != 10*time.Second { + t.Errorf("delay = %v, want 10s (Retry-After takes precedence)", delay) + } +} + +func TestNewRetryWithRateLimitHeaders(t *testing.T) { + retry := NewRetryWithRateLimitHeaders(5, time.Second) + if retry.MaxAttempts != 5 { + 
t.Errorf("MaxAttempts = %d, want 5", retry.MaxAttempts) + } + if retry.Delay != time.Second { + t.Errorf("Delay = %v, want 1s", retry.Delay) + } + if !retry.UseRateLimitHeaders { + t.Error("UseRateLimitHeaders should be true") + } +} + +func TestLoggingInterceptor_Success(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer upstream.Close() + + var loggedMessages []string + logger := llmproxy.LoggerFunc(func(level, msg string, args ...interface{}) { + loggedMessages = append(loggedMessages, level+":"+msg) + }) + + logging := NewLogging(logger) + + req, _ := http.NewRequest("POST", upstream.URL, nil) + meta := llmproxy.BodyMetadata{Model: "gpt-4"} + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + respMeta := llmproxy.ResponseMetadata{ + Usage: llmproxy.Usage{PromptTokens: 100, CompletionTokens: 50}, + } + return resp, respMeta, nil, nil + } + + _, _, _, err := logging.Intercept(req, meta, nil, next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } + + if len(loggedMessages) < 2 { + t.Errorf("expected at least 2 log messages, got %d", len(loggedMessages)) + } +} + +func TestLoggingInterceptor_Error(t *testing.T) { + var loggedMessages []string + logger := llmproxy.LoggerFunc(func(level, msg string, args ...interface{}) { + loggedMessages = append(loggedMessages, level+":"+msg) + }) + + logging := NewLogging(logger) + + req, _ := http.NewRequest("POST", "http://example.com", nil) + meta := llmproxy.BodyMetadata{Model: "gpt-4"} + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + return nil, llmproxy.ResponseMetadata{}, nil, http.ErrHandlerTimeout + } + + _, _, _, err := logging.Intercept(req, meta, nil, next) + if err == nil { + 
t.Fatal("expected error") + } + + var hasErrorLog bool + for _, msg := range loggedMessages { + if len(msg) > 5 && msg[:5] == "error" { + hasErrorLog = true + break + } + } + if !hasErrorLog { + t.Error("expected error log message") + } +} + +func TestLoggingInterceptor_NilLogger(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer upstream.Close() + + logging := NewLogging(nil) + + req, _ := http.NewRequest("POST", upstream.URL, nil) + next := func(req *http.Request) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, llmproxy.ResponseMetadata{}, nil, err + } + return resp, llmproxy.ResponseMetadata{}, nil, nil + } + + _, _, _, err := logging.Intercept(req, llmproxy.BodyMetadata{}, nil, next) + if err != nil { + t.Fatalf("Intercept returned error: %v", err) + } +} + +func TestLoggerFunc(t *testing.T) { + var calls []string + logger := llmproxy.LoggerFunc(func(level, msg string, args ...interface{}) { + calls = append(calls, level) + }) + + logger.Debug("test") + logger.Info("test") + logger.Warn("test") + logger.Error("test") + + expected := []string{"debug", "info", "warn", "error"} + for i, exp := range expected { + if calls[i] != exp { + t.Errorf("call %d = %s, want %s", i, calls[i], exp) + } + } +} diff --git a/interceptors/retry.go b/interceptors/retry.go index 0b48c76..5305edd 100644 --- a/interceptors/retry.go +++ b/interceptors/retry.go @@ -5,30 +5,19 @@ import ( "context" "io" "net/http" + "strconv" "time" "github.com/agentuity/llmproxy" ) -// RetryInterceptor automatically retries failed requests. -// It handles transient failures like rate limits (429) and server errors (5xx). type RetryInterceptor struct { - // MaxAttempts is the maximum number of request attempts (including initial). - MaxAttempts int - // Delay is the wait time between retry attempts. 
- Delay time.Duration - // IsRetryable is a custom predicate to determine if a request should be retried. - // If nil, the default predicate is used (retries on 429 and 5xx responses, - // and on network errors except context cancellation). - IsRetryable func(*http.Response, error) bool + MaxAttempts int + Delay time.Duration + IsRetryable func(*http.Response, error) bool + UseRateLimitHeaders bool } -// Intercept attempts the request up to MaxAttempts times. -// Between retries, it waits for the configured Delay. -// Context cancellation (context.Canceled, context.DeadlineExceeded) is not retried. -// -// The rawBody is used to reconstruct the request body for each retry attempt, -// since HTTP request bodies can only be read once. func (i *RetryInterceptor) Intercept(req *http.Request, meta llmproxy.BodyMetadata, rawBody []byte, next llmproxy.RoundTripFunc) (*http.Response, llmproxy.ResponseMetadata, []byte, error) { var lastErr error var lastResp *http.Response @@ -42,8 +31,15 @@ func (i *RetryInterceptor) Intercept(req *http.Request, meta llmproxy.BodyMetada for attempt := 1; attempt <= i.MaxAttempts; attempt++ { if attempt > 1 { + delay := i.Delay + if i.UseRateLimitHeaders && lastResp != nil { + if headerDelay := parseRetryAfterHeader(lastResp); headerDelay > 0 { + delay = headerDelay + } + } + select { - case <-time.After(i.Delay): + case <-time.After(delay): case <-req.Context().Done(): return nil, lastMeta, lastRawRespBody, req.Context().Err() } @@ -75,21 +71,37 @@ func isContextError(err error) bool { return err == context.Canceled || err == context.DeadlineExceeded } +func parseRetryAfterHeader(resp *http.Response) time.Duration { + retryAfter := resp.Header.Get("Retry-After") + if retryAfter == "" { + retryAfter = resp.Header.Get("X-RateLimit-Reset") + } + if retryAfter == "" { + return 0 + } + + if seconds, err := strconv.Atoi(retryAfter); err == nil { + if seconds > 0 && seconds < 86400 { + return time.Duration(seconds) * time.Second + } + } + + if t, 
err := http.ParseTime(retryAfter); err == nil { + delay := time.Until(t) + if delay > 0 && delay < 24*time.Hour { + return delay + } + } + + return 0 +} + func cloneRequest(req *http.Request, body []byte) *http.Request { cloned := req.Clone(req.Context()) cloned.Body = io.NopCloser(bytes.NewReader(body)) return cloned } -// NewRetry creates a retry interceptor with the given configuration. -// -// Parameters: -// - maxAttempts: Maximum number of attempts (e.g., 3 = initial + 2 retries) -// - delay: Time to wait between retry attempts -// -// Example: -// -// retry := interceptors.NewRetry(3, time.Second) func NewRetry(maxAttempts int, delay time.Duration) *RetryInterceptor { return &RetryInterceptor{ MaxAttempts: maxAttempts, @@ -97,15 +109,14 @@ func NewRetry(maxAttempts int, delay time.Duration) *RetryInterceptor { } } -// NewRetryWithPredicate creates a retry interceptor with a custom retry predicate. -// Use this to customize which responses or errors should be retried. -// -// Example: -// -// retry := interceptors.NewRetryWithPredicate(3, time.Second, func(resp *http.Response, err error) bool { -// // Only retry on specific error conditions -// return err != nil || resp.StatusCode == 503 -// }) +func NewRetryWithRateLimitHeaders(maxAttempts int, defaultDelay time.Duration) *RetryInterceptor { + return &RetryInterceptor{ + MaxAttempts: maxAttempts, + Delay: defaultDelay, + UseRateLimitHeaders: true, + } +} + func NewRetryWithPredicate(maxAttempts int, delay time.Duration, isRetryable func(*http.Response, error) bool) *RetryInterceptor { return &RetryInterceptor{ MaxAttempts: maxAttempts, diff --git a/providers/openai_compatible/parser_test.go b/providers/openai_compatible/parser_test.go index ba83a93..2675cb0 100644 --- a/providers/openai_compatible/parser_test.go +++ b/providers/openai_compatible/parser_test.go @@ -4,84 +4,481 @@ import ( "bytes" "io" "net/http" + "net/http/httptest" + "strings" "testing" "github.com/agentuity/llmproxy" ) -func TestParser(t 
*testing.T) { - t.Run("parses basic request", func(t *testing.T) { - body := `{"model":"gpt-4","messages":[{"role":"user","content":"hello"}]}` - parser := &Parser{} +func TestParser_BasicRequest(t *testing.T) { + body := `{"model":"gpt-4","messages":[{"role":"user","content":"hello"}]}` + parser := &Parser{} - meta, raw, err := parser.Parse(io.NopCloser(bytes.NewReader([]byte(body)))) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if meta.Model != "gpt-4" { - t.Errorf("expected model gpt-4, got %s", meta.Model) - } - if len(meta.Messages) != 1 { - t.Errorf("expected 1 message, got %d", len(meta.Messages)) - } - if string(raw) != body { - t.Error("raw body mismatch") - } - }) + meta, raw, err := parser.Parse(io.NopCloser(bytes.NewReader([]byte(body)))) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if meta.Model != "gpt-4" { + t.Errorf("model = %q, want %q", meta.Model, "gpt-4") + } + if len(meta.Messages) != 1 { + t.Fatalf("messages count = %d, want 1", len(meta.Messages)) + } + if meta.Messages[0].Role != "user" { + t.Errorf("message role = %q, want %q", meta.Messages[0].Role, "user") + } + if meta.Messages[0].Content != "hello" { + t.Errorf("message content = %q, want %q", meta.Messages[0].Content, "hello") + } + if string(raw) != body { + t.Errorf("raw body mismatch") + } +} - t.Run("parses custom fields", func(t *testing.T) { - body := `{"model":"gpt-4","custom_field":"value"}` - parser := &Parser{} +func TestParser_AllFields(t *testing.T) { + body := `{"model":"gpt-4","messages":[{"role":"system","content":"You are helpful"},{"role":"user","content":"hi"},{"role":"assistant","content":"Hello!"}],"max_tokens":1000,"stream":true}` + parser := &Parser{} - meta, _, err := parser.Parse(io.NopCloser(bytes.NewReader([]byte(body)))) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if meta.Custom["custom_field"] != "value" { - t.Errorf("expected custom_field value, got %v", meta.Custom["custom_field"]) - } - }) + meta, 
_, err := parser.Parse(io.NopCloser(bytes.NewReader([]byte(body)))) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } - t.Run("returns error for invalid JSON", func(t *testing.T) { - parser := &Parser{} - _, _, err := parser.Parse(io.NopCloser(bytes.NewReader([]byte("invalid")))) - if err == nil { - t.Fatal("expected error") - } - }) + if meta.Model != "gpt-4" { + t.Errorf("model = %q, want %q", meta.Model, "gpt-4") + } + if len(meta.Messages) != 3 { + t.Errorf("messages count = %d, want 3", len(meta.Messages)) + } + if meta.MaxTokens != 1000 { + t.Errorf("max_tokens = %d, want 1000", meta.MaxTokens) + } + if !meta.Stream { + t.Errorf("stream = %v, want true", meta.Stream) + } } -func TestEnricher(t *testing.T) { - t.Run("sets authorization header", func(t *testing.T) { - enricher := NewEnricher("test-key") - req, _ := http.NewRequest("POST", "http://example.com", nil) - meta := llmproxy.BodyMetadata{} +func TestParser_CustomFields(t *testing.T) { + body := `{"model":"gpt-4","custom_field":"custom_value","another_custom":123,"provider_specific":{"nested":"data"}}` + parser := &Parser{} - err := enricher.Enrich(req, meta, nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if req.Header.Get("Authorization") != "Bearer test-key" { - t.Errorf("expected Bearer token, got %s", req.Header.Get("Authorization")) - } - }) + meta, _, err := parser.Parse(io.NopCloser(bytes.NewReader([]byte(body)))) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if meta.Custom["custom_field"] != "custom_value" { + t.Errorf("custom_field = %v, want custom_value", meta.Custom["custom_field"]) + } + if meta.Custom["another_custom"] != 123.0 { + t.Errorf("another_custom = %v, want 123", meta.Custom["another_custom"]) + } + if meta.Custom["provider_specific"] == nil { + t.Error("provider_specific should be in Custom") + } } -func TestResolver(t *testing.T) { - t.Run("resolves to correct endpoint", func(t *testing.T) { - resolver, err := 
NewResolver("https://api.example.com") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } +func TestParser_KnownFieldsNotInCustom(t *testing.T) { + body := `{"model":"gpt-4","temperature":0.7,"top_p":0.9,"frequency_penalty":0.5,"presence_penalty":0.3,"stop":["stop1","stop2"]}` + parser := &Parser{} - meta := llmproxy.BodyMetadata{Model: "gpt-4"} - u, err := resolver.Resolve(meta) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - expected := "https://api.example.com/v1/chat/completions" - if u.String() != expected { - t.Errorf("expected %s, got %s", expected, u.String()) + meta, _, err := parser.Parse(io.NopCloser(bytes.NewReader([]byte(body)))) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + knownFields := []string{"model", "messages", "max_tokens", "stream", "temperature", "top_p", "n", "stop", "presence_penalty", "frequency_penalty", "logit_bias", "user"} + for _, field := range knownFields { + if _, ok := meta.Custom[field]; ok { + t.Errorf("known field %q should not be in Custom map", field) } - }) + } +} + +func TestParser_EmptyRequest(t *testing.T) { + body := `{}` + parser := &Parser{} + + meta, _, err := parser.Parse(io.NopCloser(bytes.NewReader([]byte(body)))) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if meta.Model != "" { + t.Errorf("model should be empty, got %q", meta.Model) + } + if len(meta.Messages) != 0 { + t.Errorf("messages should be empty, got %d", len(meta.Messages)) + } +} + +func TestParser_InvalidJSON(t *testing.T) { + parser := &Parser{} + + _, _, err := parser.Parse(io.NopCloser(bytes.NewReader([]byte("invalid json")))) + if err == nil { + t.Fatal("expected error for invalid JSON") + } +} + +func TestParser_MultilineContent(t *testing.T) { + body := `{"model":"gpt-4","messages":[{"role":"user","content":"line1\nline2\nline3"}]}` + parser := &Parser{} + + meta, _, err := parser.Parse(io.NopCloser(bytes.NewReader([]byte(body)))) + if err != nil { + t.Fatalf("unexpected error: 
%v", err) + } + + if meta.Messages[0].Content != "line1\nline2\nline3" { + t.Errorf("multiline content not preserved: %q", meta.Messages[0].Content) + } +} + +func TestParser_UnicodeContent(t *testing.T) { + body := `{"model":"gpt-4","messages":[{"role":"user","content":"Hello 世界 🌍"}]}` + parser := &Parser{} + + meta, _, err := parser.Parse(io.NopCloser(bytes.NewReader([]byte(body)))) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if meta.Messages[0].Content != "Hello 世界 🌍" { + t.Errorf("unicode content not preserved: %q", meta.Messages[0].Content) + } +} + +func TestEnricher_SetsHeaders(t *testing.T) { + enricher := NewEnricher("test-api-key") + req := httptest.NewRequest("POST", "https://api.example.com/v1/chat/completions", nil) + + err := enricher.Enrich(req, llmproxy.BodyMetadata{}, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if auth := req.Header.Get("Authorization"); auth != "Bearer test-api-key" { + t.Errorf("Authorization = %q, want %q", auth, "Bearer test-api-key") + } + if ct := req.Header.Get("Content-Type"); ct != "application/json" { + t.Errorf("Content-Type = %q, want %q", ct, "application/json") + } +} + +func TestEnricher_EmptyKey(t *testing.T) { + enricher := NewEnricher("") + req := httptest.NewRequest("POST", "https://example.com", nil) + + err := enricher.Enrich(req, llmproxy.BodyMetadata{}, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + auth := req.Header.Get("Authorization") + if auth != "Bearer " { + t.Errorf("Authorization = %q, want %q", auth, "Bearer ") + } +} + +func TestExtractor_BasicResponse(t *testing.T) { + body := `{"id":"chatcmpl-123","object":"chat.completion","created":1700000000,"model":"gpt-4","usage":{"prompt_tokens":100,"completion_tokens":50,"total_tokens":150},"choices":[{"index":0,"message":{"role":"assistant","content":"Hello!"},"finish_reason":"stop"}]}` + extractor := NewExtractor() + + resp := &http.Response{ + StatusCode: 200, + Header:
make(http.Header), + Body: io.NopCloser(strings.NewReader(body)), + } + + meta, rawBody, err := extractor.Extract(resp) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if meta.ID != "chatcmpl-123" { + t.Errorf("ID = %q, want %q", meta.ID, "chatcmpl-123") + } + if meta.Object != "chat.completion" { + t.Errorf("Object = %q, want %q", meta.Object, "chat.completion") + } + if meta.Model != "gpt-4" { + t.Errorf("Model = %q, want %q", meta.Model, "gpt-4") + } + if meta.Usage.PromptTokens != 100 { + t.Errorf("PromptTokens = %d, want 100", meta.Usage.PromptTokens) + } + if meta.Usage.CompletionTokens != 50 { + t.Errorf("CompletionTokens = %d, want 50", meta.Usage.CompletionTokens) + } + if meta.Usage.TotalTokens != 150 { + t.Errorf("TotalTokens = %d, want 150", meta.Usage.TotalTokens) + } + if len(meta.Choices) != 1 { + t.Fatalf("Choices count = %d, want 1", len(meta.Choices)) + } + if meta.Choices[0].Index != 0 { + t.Errorf("Choice index = %d, want 0", meta.Choices[0].Index) + } + if meta.Choices[0].Message == nil { + t.Fatal("Choice message is nil") + } + if meta.Choices[0].Message.Role != "assistant" { + t.Errorf("Choice message role = %q, want assistant", meta.Choices[0].Message.Role) + } + if meta.Choices[0].Message.Content != "Hello!" 
{ + t.Errorf("Choice message content = %q, want Hello!", meta.Choices[0].Message.Content) + } + if meta.Choices[0].FinishReason != "stop" { + t.Errorf("FinishReason = %q, want stop", meta.Choices[0].FinishReason) + } + if string(rawBody) != body { + t.Error("raw body not preserved") + } +} + +func TestExtractor_MultipleChoices(t *testing.T) { + body := `{"id":"chatcmpl-123","model":"gpt-4","usage":{"prompt_tokens":10,"completion_tokens":20,"total_tokens":30},"choices":[{"index":0,"message":{"role":"assistant","content":"Option A"},"finish_reason":"stop"},{"index":1,"message":{"role":"assistant","content":"Option B"},"finish_reason":"stop"}]}` + extractor := NewExtractor() + + resp := &http.Response{ + StatusCode: 200, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(body)), + } + + meta, _, err := extractor.Extract(resp) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(meta.Choices) != 2 { + t.Fatalf("Choices count = %d, want 2", len(meta.Choices)) + } + if meta.Choices[0].Message.Content != "Option A" { + t.Errorf("Choice 0 content = %q, want Option A", meta.Choices[0].Message.Content) + } + if meta.Choices[1].Message.Content != "Option B" { + t.Errorf("Choice 1 content = %q, want Option B", meta.Choices[1].Message.Content) + } +} + +func TestExtractor_DeltaForStreaming(t *testing.T) { + body := `{"id":"chatcmpl-123","model":"gpt-4","usage":{"prompt_tokens":10,"completion_tokens":0,"total_tokens":10},"choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":""}]}` + extractor := NewExtractor() + + resp := &http.Response{ + StatusCode: 200, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(body)), + } + + meta, _, err := extractor.Extract(resp) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if meta.Choices[0].Delta == nil { + t.Fatal("Delta should not be nil") + } + if meta.Choices[0].Delta.Role != "assistant" { + t.Errorf("Delta role = %q, want assistant", 
meta.Choices[0].Delta.Role) + } + if meta.Choices[0].Delta.Content != "Hello" { + t.Errorf("Delta content = %q, want Hello", meta.Choices[0].Delta.Content) + } +} + +func TestExtractor_EmptyChoices(t *testing.T) { + body := `{"id":"chatcmpl-123","model":"gpt-4","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0},"choices":[]}` + extractor := NewExtractor() + + resp := &http.Response{ + StatusCode: 200, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(body)), + } + + meta, _, err := extractor.Extract(resp) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(meta.Choices) != 0 { + t.Errorf("Choices count = %d, want 0", len(meta.Choices)) + } +} + +func TestExtractor_ZeroUsage(t *testing.T) { + body := `{"id":"chatcmpl-123","model":"gpt-4","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0},"choices":[]}` + extractor := NewExtractor() + + resp := &http.Response{ + StatusCode: 200, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(body)), + } + + meta, _, err := extractor.Extract(resp) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if meta.Usage.PromptTokens != 0 { + t.Errorf("PromptTokens = %d, want 0", meta.Usage.PromptTokens) + } + if meta.Usage.CompletionTokens != 0 { + t.Errorf("CompletionTokens = %d, want 0", meta.Usage.CompletionTokens) + } +} + +func TestExtractor_InvalidJSON(t *testing.T) { + extractor := NewExtractor() + + resp := &http.Response{ + StatusCode: 200, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader("invalid json")), + } + + _, _, err := extractor.Extract(resp) + if err == nil { + t.Fatal("expected error for invalid JSON") + } +} + +func TestResolver_BasicURL(t *testing.T) { + resolver, err := NewResolver("https://api.openai.com") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + meta := llmproxy.BodyMetadata{Model: "gpt-4"} + u, err := resolver.Resolve(meta) + if err != nil { + t.Fatalf("unexpected error: 
%v", err) + } + + expected := "https://api.openai.com/v1/chat/completions" + if u.String() != expected { + t.Errorf("URL = %q, want %q", u.String(), expected) + } +} + +func TestResolver_CustomBaseURL(t *testing.T) { + resolver, err := NewResolver("https://api.groq.com/openai") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + meta := llmproxy.BodyMetadata{Model: "llama-3-70b"} + u, err := resolver.Resolve(meta) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expected := "https://api.groq.com/openai/v1/chat/completions" + if u.String() != expected { + t.Errorf("URL = %q, want %q", u.String(), expected) + } +} + +func TestResolver_InvalidURL(t *testing.T) { + _, err := NewResolver("://invalid-url") + if err == nil { + t.Fatal("expected error for invalid URL") + } +} + +func TestResolver_TrailingSlash(t *testing.T) { + resolver, err := NewResolver("https://api.openai.com/") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + meta := llmproxy.BodyMetadata{} + u, err := resolver.Resolve(meta) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expected := "https://api.openai.com/v1/chat/completions" + if u.String() != expected { + t.Errorf("URL = %q, want %q", u.String(), expected) + } +} + +func TestProvider_New(t *testing.T) { + provider, err := New("test-provider", "test-key", "https://api.test.com") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if provider.Name() != "test-provider" { + t.Errorf("Name = %q, want %q", provider.Name(), "test-provider") + } + if provider.BodyParser() == nil { + t.Error("BodyParser should not be nil") + } + if provider.RequestEnricher() == nil { + t.Error("RequestEnricher should not be nil") + } + if provider.ResponseExtractor() == nil { + t.Error("ResponseExtractor should not be nil") + } + if provider.URLResolver() == nil { + t.Error("URLResolver should not be nil") + } +} + +func TestProvider_NewInvalidURL(t *testing.T) { + _, err := New("test", "key", 
"://invalid") + if err == nil { + t.Fatal("expected error for invalid URL") + } +} + +func TestProvider_NewWithProvider(t *testing.T) { + base := llmproxy.NewBaseProvider("custom", + llmproxy.WithBodyParser(&Parser{}), + ) + + provider := NewWithProvider("custom", base) + if provider.Name() != "custom" { + t.Errorf("Name = %q, want %q", provider.Name(), "custom") + } +} + +func TestParseOpenAIRequestBody(t *testing.T) { + data := []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"test"}]}`) + + meta, err := ParseOpenAIRequestBody(data) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if meta.Model != "gpt-4" { + t.Errorf("Model = %q, want gpt-4", meta.Model) + } + if len(meta.Messages) != 1 { + t.Errorf("Messages count = %d, want 1", len(meta.Messages)) + } +} + +func TestNewExtractor(t *testing.T) { + extractor := NewExtractor() + if extractor == nil { + t.Error("NewExtractor returned nil") + } +} + +func TestNewEnricher(t *testing.T) { + enricher := NewEnricher("test-key") + if enricher == nil { + t.Error("NewEnricher returned nil") + } + if enricher.APIKey != "test-key" { + t.Errorf("APIKey = %q, want test-key", enricher.APIKey) + } } From c2ae1e2b86ef657468952c54fd012c3c49eb3939 Mon Sep 17 00:00:00 2001 From: Jeff Haynie Date: Sun, 12 Apr 2026 19:06:18 -0500 Subject: [PATCH 2/4] fix: address PR review feedback Inline comment fixes: - addheader_test.go: Close response body to prevent connection leaks - coverage_test.go: Fix flaky RetryAfter tests with proper timing tolerances - retry.go: Accept exactly 24h (86400s) in boundary checks, not just < 24h Nitpick fixes: - examples/basic/main.go: Use NewRetryWithRateLimitHeaders in example - billing_test.go: Use epsilon-based float comparison for TotalCost - parser_test.go: Add TotalTokens assertion in TestExtractor_ZeroUsage - Add test for exactly 24h boundary case --- DESIGN.md | 2 +- examples/basic/main.go | 2 +- interceptors/addheader_test.go | 2 ++ interceptors/billing_test.go | 6 ++-- 
interceptors/coverage_test.go | 39 ++++++++++++++++------ interceptors/retry.go | 4 +-- providers/openai_compatible/parser_test.go | 3 ++ 7 files changed, 42 insertions(+), 16 deletions(-) diff --git a/DESIGN.md b/DESIGN.md index db9e653..6d5ef96 100644 --- a/DESIGN.md +++ b/DESIGN.md @@ -318,7 +318,7 @@ All fields use `sync/atomic` operations for thread safety. The `Metrics` struct - **Retry-After header:** Parses both seconds (integer) and HTTP date formats - **X-RateLimit-Reset header:** Fallback if Retry-After not present -- **Max delay:** Values over 24 hours are ignored (fallback to defaultDelay) +- **Max delay:** Values over 24 hours are ignored (fallback to defaultDelay); exactly 24 hours is accepted - **Precedence:** Retry-After takes precedence over X-RateLimit-Reset Example: diff --git a/examples/basic/main.go b/examples/basic/main.go index f6db8a1..eb76627 100644 --- a/examples/basic/main.go +++ b/examples/basic/main.go @@ -231,7 +231,7 @@ func main() { http.HandleFunc("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { provider := openaiProvider opts := []llmproxy.ProxyOption{ - llmproxy.WithInterceptor(interceptors.NewRetry(3, time.Millisecond*250)), + llmproxy.WithInterceptor(interceptors.NewRetryWithRateLimitHeaders(3, time.Millisecond*250)), llmproxy.WithInterceptor(tracingInterceptor), llmproxy.WithInterceptor(loggingInterceptor), llmproxy.WithInterceptor(interceptors.NewMetrics(metrics)), diff --git a/interceptors/addheader_test.go b/interceptors/addheader_test.go index 9f8079c..ed157fc 100644 --- a/interceptors/addheader_test.go +++ b/interceptors/addheader_test.go @@ -33,6 +33,7 @@ func TestAddHeaderInterceptor_ResponseHeaders(t *testing.T) { if err != nil { t.Fatalf("Intercept returned error: %v", err) } + defer resp.Body.Close() if got := resp.Header.Get("X-Gateway-Version"); got != "1.0" { t.Errorf("X-Gateway-Version header = %q, want %q", got, "1.0") @@ -111,6 +112,7 @@ func TestAddHeaderInterceptor_Both(t *testing.T) { 
if err != nil { t.Fatalf("Intercept returned error: %v", err) } + defer resp.Body.Close() if got := capturedReq.Header.Get("X-Request-ID"); got != "req-123" { t.Errorf("Request X-Request-ID header = %q, want %q", got, "req-123") diff --git a/interceptors/billing_test.go b/interceptors/billing_test.go index 3cccacf..2f3e174 100644 --- a/interceptors/billing_test.go +++ b/interceptors/billing_test.go @@ -1,6 +1,7 @@ package interceptors import ( + "math" "net/http" "net/http/httptest" "testing" @@ -62,8 +63,9 @@ func TestBillingInterceptor_Success(t *testing.T) { expectedOutputCost := 60.0 * 50 / 1_000_000 expectedTotal := expectedInputCost + expectedOutputCost - if result.TotalCost != expectedTotal { - t.Errorf("TotalCost = %f, want %f", result.TotalCost, expectedTotal) + epsilon := 1e-9 + if math.Abs(result.TotalCost-expectedTotal) > epsilon { + t.Errorf("TotalCost = %f, want %f (diff: %e)", result.TotalCost, expectedTotal, math.Abs(result.TotalCost-expectedTotal)) } } diff --git a/interceptors/coverage_test.go b/interceptors/coverage_test.go index 217e35a..c122add 100644 --- a/interceptors/coverage_test.go +++ b/interceptors/coverage_test.go @@ -3,6 +3,7 @@ package interceptors import ( "net/http" "net/http/httptest" + "strconv" "testing" "time" @@ -304,11 +305,11 @@ func TestRetryInterceptor_CustomPredicate(t *testing.T) { func TestRetryInterceptor_RetryAfterHeader(t *testing.T) { callCount := 0 - var retryAfter int + retryAfterSeconds := 1 upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { callCount++ if callCount == 1 { - w.Header().Set("Retry-After", "1") + w.Header().Set("Retry-After", strconv.Itoa(retryAfterSeconds)) w.WriteHeader(http.StatusTooManyRequests) } else { w.WriteHeader(http.StatusOK) @@ -334,17 +335,22 @@ func TestRetryInterceptor_RetryAfterHeader(t *testing.T) { if callCount != 2 { t.Errorf("callCount = %d, want 2", callCount) } - if elapsed < time.Duration(retryAfter)*time.Second { - t.Errorf("elapsed 
= %v, should have waited at least %v", elapsed, time.Duration(retryAfter)*time.Second) + expectedMin := time.Duration(retryAfterSeconds) * time.Second + if elapsed < expectedMin { + t.Errorf("elapsed = %v, should have waited at least %v", elapsed, expectedMin) + } + if elapsed > expectedMin+200*time.Millisecond { + t.Errorf("elapsed = %v, should have waited no more than %v", elapsed, expectedMin+200*time.Millisecond) } } func TestRetryInterceptor_RetryAfterDateHeader(t *testing.T) { callCount := 0 + retryAfterSeconds := 2 upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { callCount++ if callCount == 1 { - retryTime := time.Now().Add(500 * time.Millisecond) + retryTime := time.Now().Add(time.Duration(retryAfterSeconds) * time.Second) w.Header().Set("Retry-After", retryTime.UTC().Format(http.TimeFormat)) w.WriteHeader(http.StatusTooManyRequests) } else { @@ -371,17 +377,19 @@ func TestRetryInterceptor_RetryAfterDateHeader(t *testing.T) { if callCount != 2 { t.Errorf("callCount = %d, want 2", callCount) } - if elapsed < 400*time.Millisecond { - t.Errorf("elapsed = %v, should have waited for Retry-After date", elapsed) + expectedMin := time.Duration(retryAfterSeconds)*time.Second - 500*time.Millisecond + if elapsed < expectedMin { + t.Errorf("elapsed = %v, should have waited for Retry-After date (at least %v)", elapsed, expectedMin) } } func TestRetryInterceptor_XRateLimitReset(t *testing.T) { callCount := 0 + resetSeconds := 1 upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { callCount++ if callCount == 1 { - w.Header().Set("X-RateLimit-Reset", "1") + w.Header().Set("X-RateLimit-Reset", strconv.Itoa(resetSeconds)) w.WriteHeader(http.StatusTooManyRequests) } else { w.WriteHeader(http.StatusOK) @@ -407,8 +415,9 @@ func TestRetryInterceptor_XRateLimitReset(t *testing.T) { if callCount != 2 { t.Errorf("callCount = %d, want 2", callCount) } - if elapsed < 900*time.Millisecond { - 
t.Errorf("elapsed = %v, should have used X-RateLimit-Reset header", elapsed) + expectedMin := time.Duration(resetSeconds) * time.Second + if elapsed < expectedMin { + t.Errorf("elapsed = %v, should have used X-RateLimit-Reset header (at least %v)", elapsed, expectedMin) } } @@ -507,6 +516,16 @@ func TestParseRetryAfterHeader_TooLarge(t *testing.T) { } } +func TestParseRetryAfterHeader_Exactly24h(t *testing.T) { + resp := &http.Response{Header: make(http.Header)} + resp.Header.Set("Retry-After", "86400") + + delay := parseRetryAfterHeader(resp) + if delay != 24*time.Hour { + t.Errorf("delay = %v, want 24h (exactly 24h should be accepted)", delay) + } +} + func TestParseRetryAfterHeader_RetryAfterPreferred(t *testing.T) { resp := &http.Response{Header: make(http.Header)} resp.Header.Set("Retry-After", "10") diff --git a/interceptors/retry.go b/interceptors/retry.go index 5305edd..b4e8e54 100644 --- a/interceptors/retry.go +++ b/interceptors/retry.go @@ -81,14 +81,14 @@ func parseRetryAfterHeader(resp *http.Response) time.Duration { } if seconds, err := strconv.Atoi(retryAfter); err == nil { - if seconds > 0 && seconds < 86400 { + if seconds > 0 && seconds <= 86400 { return time.Duration(seconds) * time.Second } } if t, err := http.ParseTime(retryAfter); err == nil { delay := time.Until(t) - if delay > 0 && delay < 24*time.Hour { + if delay > 0 && delay <= 24*time.Hour { return delay } } diff --git a/providers/openai_compatible/parser_test.go b/providers/openai_compatible/parser_test.go index 2675cb0..85d308d 100644 --- a/providers/openai_compatible/parser_test.go +++ b/providers/openai_compatible/parser_test.go @@ -331,6 +331,9 @@ func TestExtractor_ZeroUsage(t *testing.T) { if meta.Usage.CompletionTokens != 0 { t.Errorf("CompletionTokens = %d, want 0", meta.Usage.CompletionTokens) } + if meta.Usage.TotalTokens != 0 { + t.Errorf("TotalTokens = %d, want 0", meta.Usage.TotalTokens) + } } func TestExtractor_InvalidJSON(t *testing.T) { From 
93978e422e9b762e461bec5e8233b399bba6a79c Mon Sep 17 00:00:00 2001 From: Jeff Haynie Date: Sun, 12 Apr 2026 19:14:18 -0500 Subject: [PATCH 3/4] fix: address additional PR review feedback Inline comment fixes: - retry.go: Fix parseRetryAfterHeader to fall back to X-RateLimit-Reset when Retry-After is malformed - retry.go: Drain and close response body in retry loop to prevent connection leaks - enricher.go: Don't set Authorization header when API key is empty - parser_test.go: Verify body parser is preserved in TestProvider_NewWithProvider - parser_test.go: Update TestEnricher_EmptyKey to expect no Authorization header - coverage_test.go: Add test for malformed Retry-After fallback behavior Duplicate comment fixes: - addheader_test.go: Close response body in tests that discard it --- interceptors/addheader_test.go | 6 ++-- interceptors/coverage_test.go | 15 ++++++++-- interceptors/retry.go | 35 ++++++++++++++-------- providers/openai_compatible/enricher.go | 4 ++- providers/openai_compatible/parser_test.go | 13 ++++++-- 5 files changed, 52 insertions(+), 21 deletions(-) diff --git a/interceptors/addheader_test.go b/interceptors/addheader_test.go index ed157fc..4f9601c 100644 --- a/interceptors/addheader_test.go +++ b/interceptors/addheader_test.go @@ -70,10 +70,11 @@ func TestAddHeaderInterceptor_RequestHeaders(t *testing.T) { return resp, llmproxy.ResponseMetadata{}, nil, nil } - _, _, _, err := add.Intercept(req, llmproxy.BodyMetadata{}, nil, next) + resp, _, _, err := add.Intercept(req, llmproxy.BodyMetadata{}, nil, next) if err != nil { t.Fatalf("Intercept returned error: %v", err) } + resp.Body.Close() if got := capturedReq.Header.Get("X-Client-ID"); got != "my-app" { t.Errorf("X-Client-ID header = %q, want %q", got, "my-app") @@ -139,10 +140,11 @@ func TestAddHeaderInterceptor_Empty(t *testing.T) { return resp, llmproxy.ResponseMetadata{}, nil, nil } - _, _, _, err := add.Intercept(req, llmproxy.BodyMetadata{}, nil, next) + resp, _, _, err := add.Intercept(req, 
llmproxy.BodyMetadata{}, nil, next) if err != nil { t.Fatalf("Intercept returned error: %v", err) } + resp.Body.Close() } func TestAddHeaderInterceptor_ErrorPassthrough(t *testing.T) { diff --git a/interceptors/coverage_test.go b/interceptors/coverage_test.go index c122add..4865c8b 100644 --- a/interceptors/coverage_test.go +++ b/interceptors/coverage_test.go @@ -346,7 +346,7 @@ func TestRetryInterceptor_RetryAfterHeader(t *testing.T) { func TestRetryInterceptor_RetryAfterDateHeader(t *testing.T) { callCount := 0 - retryAfterSeconds := 2 + retryAfterSeconds := 3 upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { callCount++ if callCount == 1 { @@ -377,7 +377,7 @@ func TestRetryInterceptor_RetryAfterDateHeader(t *testing.T) { if callCount != 2 { t.Errorf("callCount = %d, want 2", callCount) } - expectedMin := time.Duration(retryAfterSeconds)*time.Second - 500*time.Millisecond + expectedMin := time.Duration(retryAfterSeconds)*time.Second - time.Second if elapsed < expectedMin { t.Errorf("elapsed = %v, should have waited for Retry-After date (at least %v)", elapsed, expectedMin) } @@ -537,6 +537,17 @@ func TestParseRetryAfterHeader_RetryAfterPreferred(t *testing.T) { } } +func TestParseRetryAfterHeader_MalformedRetryAfterFallsBack(t *testing.T) { + resp := &http.Response{Header: make(http.Header)} + resp.Header.Set("Retry-After", "invalid") + resp.Header.Set("X-RateLimit-Reset", "30") + + delay := parseRetryAfterHeader(resp) + if delay != 30*time.Second { + t.Errorf("delay = %v, want 30s (should fall back to X-RateLimit-Reset when Retry-After is malformed)", delay) + } +} + func TestNewRetryWithRateLimitHeaders(t *testing.T) { retry := NewRetryWithRateLimitHeaders(5, time.Second) if retry.MaxAttempts != 5 { diff --git a/interceptors/retry.go b/interceptors/retry.go index b4e8e54..b536750 100644 --- a/interceptors/retry.go +++ b/interceptors/retry.go @@ -52,6 +52,11 @@ func (i *RetryInterceptor) Intercept(req *http.Request, 
meta llmproxy.BodyMetada if !isRetryable(lastResp, lastErr) { return lastResp, lastMeta, lastRawRespBody, lastErr } + + if lastResp != nil && lastResp.Body != nil { + io.Copy(io.Discard, lastResp.Body) + lastResp.Body.Close() + } } return lastResp, lastMeta, lastRawRespBody, lastErr @@ -73,23 +78,27 @@ func isContextError(err error) bool { func parseRetryAfterHeader(resp *http.Response) time.Duration { retryAfter := resp.Header.Get("Retry-After") - if retryAfter == "" { - retryAfter = resp.Header.Get("X-RateLimit-Reset") - } - if retryAfter == "" { - return 0 - } + if retryAfter != "" { + if seconds, err := strconv.Atoi(retryAfter); err == nil { + if seconds > 0 && seconds <= 86400 { + return time.Duration(seconds) * time.Second + } + } - if seconds, err := strconv.Atoi(retryAfter); err == nil { - if seconds > 0 && seconds <= 86400 { - return time.Duration(seconds) * time.Second + if t, err := http.ParseTime(retryAfter); err == nil { + delay := time.Until(t) + if delay > 0 && delay <= 24*time.Hour { + return delay + } } } - if t, err := http.ParseTime(retryAfter); err == nil { - delay := time.Until(t) - if delay > 0 && delay <= 24*time.Hour { - return delay + xRateLimitReset := resp.Header.Get("X-RateLimit-Reset") + if xRateLimitReset != "" { + if seconds, err := strconv.Atoi(xRateLimitReset); err == nil { + if seconds > 0 && seconds <= 86400 { + return time.Duration(seconds) * time.Second + } } } diff --git a/providers/openai_compatible/enricher.go b/providers/openai_compatible/enricher.go index f11b684..f2b7d64 100644 --- a/providers/openai_compatible/enricher.go +++ b/providers/openai_compatible/enricher.go @@ -18,8 +18,10 @@ type Enricher struct { // - Authorization: Bearer // - Content-Type: application/json func (e *Enricher) Enrich(req *http.Request, meta llmproxy.BodyMetadata, rawBody []byte) error { - req.Header.Set("Authorization", "Bearer "+e.APIKey) req.Header.Set("Content-Type", "application/json") + if e.APIKey != "" { + 
req.Header.Set("Authorization", "Bearer "+e.APIKey) + } return nil } diff --git a/providers/openai_compatible/parser_test.go b/providers/openai_compatible/parser_test.go index 85d308d..03438f0 100644 --- a/providers/openai_compatible/parser_test.go +++ b/providers/openai_compatible/parser_test.go @@ -177,8 +177,11 @@ func TestEnricher_EmptyKey(t *testing.T) { } auth := req.Header.Get("Authorization") - if auth != "Bearer " { - t.Errorf("Authorization = %q, want %q", auth, "Bearer ") + if auth != "" { + t.Errorf("Authorization = %q, want empty (no header set for empty key)", auth) + } + if ct := req.Header.Get("Content-Type"); ct != "application/json" { + t.Errorf("Content-Type = %q, want %q", ct, "application/json") } } @@ -443,14 +446,18 @@ func TestProvider_NewInvalidURL(t *testing.T) { } func TestProvider_NewWithProvider(t *testing.T) { + parser := &Parser{} base := llmproxy.NewBaseProvider("custom", - llmproxy.WithBodyParser(&Parser{}), + llmproxy.WithBodyParser(parser), ) provider := NewWithProvider("custom", base) if provider.Name() != "custom" { t.Errorf("Name = %q, want %q", provider.Name(), "custom") } + if provider.BodyParser() != parser { + t.Errorf("BodyParser not preserved: got %v, want %v", provider.BodyParser(), parser) + } } func TestParseOpenAIRequestBody(t *testing.T) { From c202b90c73e93cea729e67429b5135058177c315 Mon Sep 17 00:00:00 2001 From: Jeff Haynie Date: Sun, 12 Apr 2026 19:24:55 -0500 Subject: [PATCH 4/4] fix: address PR review feedback for data race and auth header - addheader_test.go: Fix data race in TestAddHeaderInterceptor_RequestHeaders and TestAddHeaderInterceptor_Both by using buffered channels for goroutine synchronization instead of shared variables - enricher.go: Delete Authorization header when API key is empty to prevent forwarding inbound auth tokens to upstream providers - parser_test.go: Update TestEnricher_EmptyKey to verify that existing Authorization header is removed when enricher has empty key --- 
interceptors/addheader_test.go | 10 ++++++---- providers/openai_compatible/enricher.go | 2 ++ providers/openai_compatible/parser_test.go | 3 ++- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/interceptors/addheader_test.go b/interceptors/addheader_test.go index 4f9601c..bedb1d5 100644 --- a/interceptors/addheader_test.go +++ b/interceptors/addheader_test.go @@ -47,9 +47,9 @@ func TestAddHeaderInterceptor_ResponseHeaders(t *testing.T) { } func TestAddHeaderInterceptor_RequestHeaders(t *testing.T) { - var capturedReq *http.Request + reqCh := make(chan *http.Request, 1) upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - capturedReq = r + reqCh <- r w.WriteHeader(http.StatusOK) })) defer upstream.Close() @@ -76,6 +76,7 @@ func TestAddHeaderInterceptor_RequestHeaders(t *testing.T) { } resp.Body.Close() + capturedReq := <-reqCh if got := capturedReq.Header.Get("X-Client-ID"); got != "my-app" { t.Errorf("X-Client-ID header = %q, want %q", got, "my-app") } @@ -88,9 +89,9 @@ func TestAddHeaderInterceptor_RequestHeaders(t *testing.T) { } func TestAddHeaderInterceptor_Both(t *testing.T) { - var capturedReq *http.Request + reqCh := make(chan *http.Request, 1) upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - capturedReq = r + reqCh <- r w.WriteHeader(http.StatusOK) })) defer upstream.Close() @@ -115,6 +116,7 @@ func TestAddHeaderInterceptor_Both(t *testing.T) { } defer resp.Body.Close() + capturedReq := <-reqCh if got := capturedReq.Header.Get("X-Request-ID"); got != "req-123" { t.Errorf("Request X-Request-ID header = %q, want %q", got, "req-123") } diff --git a/providers/openai_compatible/enricher.go b/providers/openai_compatible/enricher.go index f2b7d64..d4bf251 100644 --- a/providers/openai_compatible/enricher.go +++ b/providers/openai_compatible/enricher.go @@ -21,6 +21,8 @@ func (e *Enricher) Enrich(req *http.Request, meta llmproxy.BodyMetadata, rawBody 
req.Header.Set("Content-Type", "application/json") if e.APIKey != "" { req.Header.Set("Authorization", "Bearer "+e.APIKey) + } else { + req.Header.Del("Authorization") } return nil } diff --git a/providers/openai_compatible/parser_test.go b/providers/openai_compatible/parser_test.go index 03438f0..206e49e 100644 --- a/providers/openai_compatible/parser_test.go +++ b/providers/openai_compatible/parser_test.go @@ -170,6 +170,7 @@ func TestEnricher_SetsHeaders(t *testing.T) { func TestEnricher_EmptyKey(t *testing.T) { enricher := NewEnricher("") req := httptest.NewRequest("POST", "https://example.com", nil) + req.Header.Set("Authorization", "Bearer incoming-token") err := enricher.Enrich(req, llmproxy.BodyMetadata{}, nil) if err != nil { @@ -178,7 +179,7 @@ func TestEnricher_EmptyKey(t *testing.T) { auth := req.Header.Get("Authorization") if auth != "" { - t.Errorf("Authorization = %q, want empty (no header set for empty key)", auth) + t.Errorf("Authorization = %q, want empty (header should be deleted for empty key)", auth) } if ct := req.Header.Get("Content-Type"); ct != "application/json" { t.Errorf("Content-Type = %q, want %q", ct, "application/json")