From 26e86a9cfd64a0cbb7cc906de54564f6af2d03a4 Mon Sep 17 00:00:00 2001 From: benjamin Date: Mon, 6 Apr 2026 16:32:48 +0800 Subject: [PATCH 1/2] fix(security): require API keys for Gemini CLI routes --- internal/api/server.go | 2 +- internal/api/server_test.go | 35 +++++++++++++++++++ .../handlers/gemini/gemini-cli_handlers.go | 14 ++------ 3 files changed, 38 insertions(+), 13 deletions(-) diff --git a/internal/api/server.go b/internal/api/server.go index 6126ba2c12..8c259ed8dc 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -358,7 +358,7 @@ func (s *Server) setupRoutes() { }, }) }) - s.engine.POST("/v1internal:method", geminiCLIHandlers.CLIHandler) + s.engine.POST("/v1internal:method", AuthMiddleware(s.accessManager), geminiCLIHandlers.CLIHandler) // OAuth callback endpoints (reuse main server port) // These endpoints receive provider redirects and persist diff --git a/internal/api/server_test.go b/internal/api/server_test.go index f5c18aa167..6a4c8ef636 100644 --- a/internal/api/server_test.go +++ b/internal/api/server_test.go @@ -1,6 +1,7 @@ package api import ( + "bytes" "net/http" "net/http/httptest" "os" @@ -112,6 +113,40 @@ func TestAmpProviderModelRoutes(t *testing.T) { } } +func TestGeminiCLIRouteRequiresAPIKey(t *testing.T) { + server := newTestServer(t) + + req := httptest.NewRequest(http.MethodPost, "/v1internal:generateContent", bytes.NewBufferString(`{"model":"gemini-2.5-pro"}`)) + req.Header.Set("Content-Type", "application/json") + req.RemoteAddr = "203.0.113.7:4567" + + rr := httptest.NewRecorder() + server.engine.ServeHTTP(rr, req) + + if rr.Code != http.StatusUnauthorized { + t.Fatalf("status = %d, want %d; body=%s", rr.Code, http.StatusUnauthorized, rr.Body.String()) + } + if !strings.Contains(rr.Body.String(), "Missing API key") { + t.Fatalf("body = %q, want missing API key error", rr.Body.String()) + } +} + +func TestGeminiCLIRouteAllowsAuthenticatedNonLocalRequest(t *testing.T) { + server := newTestServer(t) + + req := httptest.NewRequest(http.MethodPost, "/v1internal:generateContent", bytes.NewBufferString(`{"model":"gemini-2.5-pro"}`)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer test-key") + req.RemoteAddr = "203.0.113.7:4567" + + rr := httptest.NewRecorder() + server.engine.ServeHTTP(rr, req) + + if rr.Code == http.StatusUnauthorized || rr.Code == http.StatusForbidden { + t.Fatalf("status = %d, want request to pass auth and localhost gate; body=%s", rr.Code, rr.Body.String()) + } +} + func TestDefaultRequestLoggerFactory_UsesResolvedLogDirectory(t *testing.T) { t.Setenv("WRITABLE_PATH", "") t.Setenv("writable_path", "") diff --git a/sdk/api/handlers/gemini/gemini-cli_handlers.go b/sdk/api/handlers/gemini/gemini-cli_handlers.go index b5fd494375..0c61424336 100644 --- a/sdk/api/handlers/gemini/gemini-cli_handlers.go +++ b/sdk/api/handlers/gemini/gemini-cli_handlers.go @@ -10,7 +10,6 @@ import ( "fmt" "io" "net/http" - "strings" "time" "github.com/gin-gonic/gin" @@ -47,18 +46,9 @@ func (h *GeminiCLIAPIHandler) Models() []map[string]any { } // CLIHandler handles CLI-specific requests for Gemini API operations. -// It restricts access to localhost only and routes requests to appropriate internal handlers. +// Access control is enforced at the router level so authenticated clients can +// use Gemini CLI-compatible routes through the public proxy. func (h *GeminiCLIAPIHandler) CLIHandler(c *gin.Context) { - if !strings.HasPrefix(c.Request.RemoteAddr, "127.0.0.1:") { - c.JSON(http.StatusForbidden, handlers.ErrorResponse{ - Error: handlers.ErrorDetail{ - Message: "CLI reply only allow local access", - Type: "forbidden", - }, - }) - return - } - rawJSON, _ := c.GetRawData() requestRawURI := c.Request.URL.Path From 6edb5c52734b3bbd9fa2679088a443a9e125e76b Mon Sep 17 00:00:00 2001 From: Jlypx Date: Tue, 7 Apr 2026 15:28:11 +0800 Subject: [PATCH 2/2] fix(gemini-cli): strip proxy auth before passthrough --- internal/api/server_test.go | 83 +++++++++++++++++++ .../handlers/gemini/gemini-cli_handlers.go | 71 +++++++++++++++- 2 files changed, 151 insertions(+), 3 deletions(-) diff --git a/internal/api/server_test.go b/internal/api/server_test.go index 6a4c8ef636..d4afe811ae 100644 --- a/internal/api/server_test.go +++ b/internal/api/server_test.go @@ -2,6 +2,7 @@ package api import ( "bytes" + "io" "net/http" "net/http/httptest" "os" @@ -18,6 +19,22 @@ import ( sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" ) +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +func stubDefaultTransport(t *testing.T, transport http.RoundTripper) { + t.Helper() + + original := http.DefaultTransport + http.DefaultTransport = transport + t.Cleanup(func() { + http.DefaultTransport = original + }) +} + func newTestServer(t *testing.T) *Server { t.Helper() @@ -147,6 +164,72 @@ func TestGeminiCLIRouteAllowsAuthenticatedNonLocalRequest(t *testing.T) { } } +func TestGeminiCLIPassthroughStripsAuthorizationUsedForProxyAuth(t *testing.T) { + server := newTestServer(t) + + var upstreamAuthorization string + stubDefaultTransport(t, roundTripperFunc(func(req *http.Request) (*http.Response, error) { + upstreamAuthorization = req.Header.Get("Authorization") + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(`{"status":"ok"}`)), + Request: req, + }, nil + })) + + req := httptest.NewRequest(http.MethodPost, "/v1internal:countTokens", bytes.NewBufferString(`{"model":"gemini-2.5-pro"}`)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer test-key") + + rr := httptest.NewRecorder() + server.engine.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body=%s", rr.Code, http.StatusOK, rr.Body.String()) + } + if upstreamAuthorization != "" { + t.Fatalf("upstream Authorization = %q, want empty", upstreamAuthorization) + } +} + +func TestGeminiCLIPassthroughPreservesUpstreamAuthorizationWhenProxyAuthUsesGoogleAPIKey(t *testing.T) { + server := newTestServer(t) + + var ( + upstreamAuthorization string + upstreamGoogleAPIKey string + ) + stubDefaultTransport(t, roundTripperFunc(func(req *http.Request) (*http.Response, error) { + upstreamAuthorization = req.Header.Get("Authorization") + upstreamGoogleAPIKey = req.Header.Get("X-Goog-Api-Key") + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(`{"status":"ok"}`)), + Request: req, + }, nil + })) + + req := httptest.NewRequest(http.MethodPost, "/v1internal:countTokens", bytes.NewBufferString(`{"model":"gemini-2.5-pro"}`)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Goog-Api-Key", "test-key") + req.Header.Set("Authorization", "Bearer upstream-token") + + rr := httptest.NewRecorder() + server.engine.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body=%s", rr.Code, http.StatusOK, rr.Body.String()) + } + if upstreamAuthorization != "Bearer upstream-token" { + t.Fatalf("upstream Authorization = %q, want %q", upstreamAuthorization, "Bearer upstream-token") + } + if upstreamGoogleAPIKey != "" { + t.Fatalf("upstream X-Goog-Api-Key = %q, want empty", upstreamGoogleAPIKey) + } +} + func TestDefaultRequestLoggerFactory_UsesResolvedLogDirectory(t *testing.T) { t.Setenv("WRITABLE_PATH", "") t.Setenv("writable_path", "") diff --git a/sdk/api/handlers/gemini/gemini-cli_handlers.go b/sdk/api/handlers/gemini/gemini-cli_handlers.go index 0c61424336..80afc6f918 100644 --- a/sdk/api/handlers/gemini/gemini-cli_handlers.go +++ b/sdk/api/handlers/gemini/gemini-cli_handlers.go @@ -10,6 +10,7 @@ import ( "fmt" "io" "net/http" + "strings" "time" "github.com/gin-gonic/gin" @@ -68,9 +69,8 @@ func (h *GeminiCLIAPIHandler) CLIHandler(c *gin.Context) { }) return } - for key, value := range c.Request.Header { - req.Header[key] = value - } + req.Header = handlers.FilterUpstreamHeaders(c.Request.Header) + stripConsumedProxyCredential(req, c) httpClient := util.SetProxy(h.Cfg, &http.Client{}) @@ -120,6 +120,71 @@ func (h *GeminiCLIAPIHandler) CLIHandler(c *gin.Context) { } } +func stripConsumedProxyCredential(req *http.Request, c *gin.Context) { + if req == nil || c == nil { + return + } + + switch accessMetadataSource(c) { + case "authorization": + req.Header.Del("Authorization") + case "x-goog-api-key": + req.Header.Del("X-Goog-Api-Key") + case "x-api-key": + req.Header.Del("X-Api-Key") + case "query-key": + removeQueryValuesMatching(req, "key", strings.TrimSpace(c.GetString("apiKey"))) + case "query-auth-token": + removeQueryValuesMatching(req, "auth_token", strings.TrimSpace(c.GetString("apiKey"))) + } +} + +func accessMetadataSource(c *gin.Context) string { + if c == nil { + return "" + } + raw, exists := c.Get("accessMetadata") + if !exists || raw == nil { + return "" + } + switch typed := raw.(type) { + case map[string]string: + return strings.TrimSpace(typed["source"]) + case map[string]any: + if source, ok := typed["source"].(string); ok { + return strings.TrimSpace(source) + } + } + return "" +} + +func removeQueryValuesMatching(req *http.Request, key string, match string) { + if req == nil || req.URL == nil || match == "" { + return + } + + query := req.URL.Query() + values, ok := query[key] + if !ok || len(values) == 0 { + return + } + + kept := make([]string, 0, len(values)) + for _, value := range values { + if value == match { + continue + } + kept = append(kept, value) + } + + if len(kept) == 0 { + query.Del(key) + } else { + query[key] = kept + } + req.URL.RawQuery = query.Encode() +} + // handleInternalStreamGenerateContent handles streaming content generation requests. // It sets up a server-sent event stream and forwards the request to the backend client. // The function continuously proxies response chunks from the backend to the client.