diff --git a/autorouter.go b/autorouter.go
index f6ebb53..f3e3d15 100644
--- a/autorouter.go
+++ b/autorouter.go
@@ -9,8 +9,37 @@ import (
 	"io"
 	"net/http"
 	"strings"
+
+	"github.com/agentuity/go-common/slice"
 )
 
+// skipHeaders lists the upstream response headers that copyResponseHeaders
+// does not forward to the client.
+var skipHeaders = []string{"Content-Encoding", "Content-Length"}
+
+// copyResponseHeaders appends every value of each header in headers to w's
+// header map, except for the headers named in skipHeaders, which are matched
+// case-insensitively.
+//
+// Values are added with Header.Add, so values already present on w under the
+// same name are kept rather than replaced.
+//
+// NOTE(review): dropping Content-Encoding assumes the body relayed to the
+// client is no longer encoded. ForwardStreaming requests identity encoding
+// upstream; confirm the ServeHTTP path also never relays a compressed body.
+func copyResponseHeaders(w http.ResponseWriter, headers http.Header) {
+	dst := w.Header()
+
+	for k, vals := range headers {
+		if slice.Contains(skipHeaders, k, slice.WithCaseInsensitive()) {
+			continue
+		}
+		for _, v := range vals {
+			dst.Add(k, v)
+		}
+	}
+}
+
 type AutoRouter struct {
 	registry Registry
 	detector ProviderDetector
@@ -303,6 +332,9 @@ func (a *AutoRouter) ForwardStreaming(ctx context.Context, req *http.Request, w
 		upstreamReq.Header[k] = v
 	}
 
+	// For SSE, explicitly turn off compression.
+	upstreamReq.Header["Accept-Encoding"] = []string{"identity"}
+
 	if err := provider.RequestEnricher().Enrich(upstreamReq, meta, body); err != nil {
 		return ResponseMetadata{}, err
 	}
@@ -340,11 +372,7 @@ func (a *AutoRouter) ForwardStreaming(ctx context.Context, req *http.Request, w
 		w.Header().Set("Trailer", "X-Gateway-Cost,X-Gateway-Prompt-Tokens,X-Gateway-Completion-Tokens")
 	}
 
-	for k, v := range upstreamResp.Header {
-		if k != "Content-Length" {
-			w.Header()[k] = v
-		}
-	}
+	copyResponseHeaders(w, upstreamResp.Header)
 
 	w.WriteHeader(upstreamResp.StatusCode)
 
@@ -469,9 +497,7 @@ func (a *AutoRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	}
 	defer resp.Body.Close()
 
-	for k, v := range resp.Header {
-		w.Header()[k] = v
-	}
+	copyResponseHeaders(w, resp.Header)
 
 	if billing, ok := meta.Custom["billing_result"].(BillingResult); ok {
 		w.Header().Set("X-Gateway-Cost", fmt.Sprintf("%.6f", billing.TotalCost))
diff --git a/autorouter_test.go b/autorouter_test.go
index 5af8a68..5ba7b4e 100644
--- a/autorouter_test.go
+++ b/autorouter_test.go
@@ -822,3 +822,65 @@ func TestAutoRouter_ResponsesAPIStreamingNoStreamOptions(t *testing.T) {
 		}
 	})
 }
+
+// TestAutoRouter_copyResponseHeaders checks that copyResponseHeaders copies
+// every value of ordinary headers and drops Content-Encoding and
+// Content-Length regardless of key casing.
+func TestAutoRouter_copyResponseHeaders(t *testing.T) {
+	tests := []struct {
+		name    string
+		headers http.Header
+		want    string // expected wire form, surrounding whitespace trimmed
+	}{
+		{
+			name:    "empty",
+			headers: http.Header{},
+			want:    "",
+		},
+		{
+			name:    "single value",
+			headers: http.Header{"A": {"B"}},
+			want:    "A: B",
+		},
+		{
+			name:    "multiple values",
+			headers: http.Header{"A": {"B", "C"}},
+			want:    "A: B\r\nA: C",
+		},
+		{
+			name:    "canonical Content-Encoding",
+			headers: http.Header{"A": {"B"}, "Content-Encoding": {"gzip"}},
+			want:    "A: B",
+		},
+		{
+			name:    "lowercase content-encoding",
+			headers: http.Header{"A": {"B"}, "content-encoding": {"gzip"}},
+			want:    "A: B",
+		},
+		{
+			name:    "canonical Content-Length",
+			headers: http.Header{"A": {"B"}, "Content-Length": {"1"}},
+			want:    "A: B",
+		},
+		{
+			name:    "lowercase content-length",
+			headers: http.Header{"A": {"B"}, "content-length": {"1"}},
+			want:    "A: B",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			w := httptest.NewRecorder()
+			copyResponseHeaders(w, tt.headers)
+
+			var sw strings.Builder
+			if err := w.Header().Write(&sw); err != nil {
+				t.Fatalf("writing headers: %v", err)
+			}
+			if got := strings.TrimSpace(sw.String()); got != tt.want {
+				t.Errorf("headers = %q, want %q", got, tt.want)
+			}
+		})
+	}
+}
diff --git a/go.mod b/go.mod
index f867248..03fdb9b 100644
--- a/go.mod
+++ b/go.mod
@@ -3,6 +3,7 @@ module github.com/agentuity/llmproxy
 go 1.26.2
 
 require (
+	github.com/agentuity/go-common v1.0.231
 	github.com/minio/simdjson-go v0.4.5
 	go.opentelemetry.io/otel/trace v1.43.0
 )
@@ -12,5 +13,5 @@ require (
 	github.com/klauspost/compress v1.15.15 // indirect
 	github.com/klauspost/cpuid/v2 v2.2.3 // indirect
 	go.opentelemetry.io/otel v1.43.0 // indirect
-	golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e // indirect
+	golang.org/x/sys v0.42.0 // indirect
 )
diff --git a/go.sum b/go.sum
index 6d4b853..d25c52f 100644
--- a/go.sum
+++ b/go.sum
@@ -1,3 +1,5 @@
+github.com/agentuity/go-common v1.0.231 h1:t5CzJuA+yKv6U9lVSvxmiZoNM60ZeBo8U/Vf8P4ce4E=
+github.com/agentuity/go-common v1.0.231/go.mod h1:/QxgG4qKu9Rik0084BargZ8wG13/3kdWYI+jIRJYUwI=
 github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
 github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
 github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
@@ -18,7 +20,8 @@ go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I=
 go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0=
 go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A=
 go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0=
-golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e h1:CsOuNlbOuf0mzxJIefr6Q4uAUetRUwZE4qt7VfzP+xo=
 golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
+golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
 gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
 gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
diff --git a/providers/anthropic/parser.go b/providers/anthropic/parser.go
index 843bc4a..621b88c 100644
--- a/providers/anthropic/parser.go
+++ b/providers/anthropic/parser.go
@@ -45,6 +45,7 @@ func (p *Parser) Parse(body io.ReadCloser) (llmproxy.BodyMetadata, []byte, error
 		Model:     req.Model,
 		Messages:  make([]llmproxy.Message, len(req.Messages)),
 		MaxTokens: req.MaxTokens,
+		Stream:    req.Stream,
 		Custom:    make(map[string]any),
 	}
 
@@ -87,6 +88,7 @@ type Request struct {
 	Model     string                 `json:"model"`
 	Messages  []Message              `json:"messages"`
 	MaxTokens int                    `json:"max_tokens,omitempty"`
+	Stream    bool                   `json:"stream,omitempty"`
 	System    Content                `json:"system,omitempty"`
 	Custom    map[string]interface{} `json:"-"`
 }
diff --git a/providers/anthropic/parser_test.go b/providers/anthropic/parser_test.go
index 943a1bf..15a409a 100644
--- a/providers/anthropic/parser_test.go
+++ b/providers/anthropic/parser_test.go
@@ -45,6 +45,19 @@ func TestParser(t *testing.T) {
 		}
 	})
 
+	t.Run("parses stream flag", func(t *testing.T) {
+		body := `{"model":"claude-3-opus-20240229","max_tokens":1024,"stream":true,"messages":[{"role":"user","content":"hello"}]}`
+		parser := &Parser{}
+
+		meta, _, err := parser.Parse(io.NopCloser(bytes.NewReader([]byte(body))))
+		if err != nil {
+			t.Fatalf("unexpected error: %v", err)
+		}
+		if !meta.Stream {
+			t.Error("expected stream flag to be true")
+		}
+	})
+
 	t.Run("parses request with system prompt array", func(t *testing.T) {
 		body := `{"model":"anthropic/claude-sonnet-4-6","max_tokens":1024,"system":[{"type":"text","text":"You are helpful."}],"messages":[{"role":"user","content":"hello"}]}`
 		parser := &Parser{}