diff --git a/internal/api/server.go b/internal/api/server.go index 12205ad6e1..cedb819a76 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -362,7 +362,7 @@ func (s *Server) setupRoutes() { }, }) }) - s.engine.POST("/v1internal:method", geminiCLIHandlers.CLIHandler) + s.engine.POST("/v1internal:method", AuthMiddleware(s.accessManager), geminiCLIHandlers.CLIHandler) // OAuth callback endpoints (reuse main server port) // These endpoints receive provider redirects and persist diff --git a/internal/api/server_test.go b/internal/api/server_test.go index e224c90a32..72fb019931 100644 --- a/internal/api/server_test.go +++ b/internal/api/server_test.go @@ -1,7 +1,9 @@ package api import ( + "bytes" "encoding/json" + "io" "net/http" "net/http/httptest" "os" @@ -18,6 +20,22 @@ import ( sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" ) +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +func stubDefaultTransport(t *testing.T, transport http.RoundTripper) { + t.Helper() + + original := http.DefaultTransport + http.DefaultTransport = transport + t.Cleanup(func() { + http.DefaultTransport = original + }) +} + func newTestServer(t *testing.T) *Server { t.Helper() @@ -135,6 +153,106 @@ func TestAmpProviderModelRoutes(t *testing.T) { } } +func TestGeminiCLIRouteRequiresAPIKey(t *testing.T) { + server := newTestServer(t) + + req := httptest.NewRequest(http.MethodPost, "/v1internal:generateContent", bytes.NewBufferString(`{"model":"gemini-2.5-pro"}`)) + req.Header.Set("Content-Type", "application/json") + req.RemoteAddr = "203.0.113.7:4567" + + rr := httptest.NewRecorder() + server.engine.ServeHTTP(rr, req) + + if rr.Code != http.StatusUnauthorized { + t.Fatalf("status = %d, want %d; body=%s", rr.Code, http.StatusUnauthorized, rr.Body.String()) + } + if !strings.Contains(rr.Body.String(), "Missing API key") { + t.Fatalf("body = %q, want missing API key error", rr.Body.String()) + } +} + +func TestGeminiCLIRouteAllowsAuthenticatedNonLocalRequest(t *testing.T) { + server := newTestServer(t) + + req := httptest.NewRequest(http.MethodPost, "/v1internal:generateContent", bytes.NewBufferString(`{"model":"gemini-2.5-pro"}`)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer test-key") + req.RemoteAddr = "203.0.113.7:4567" + + rr := httptest.NewRecorder() + server.engine.ServeHTTP(rr, req) + + if rr.Code == http.StatusUnauthorized || rr.Code == http.StatusForbidden { + t.Fatalf("status = %d, want request to pass auth and localhost gate; body=%s", rr.Code, rr.Body.String()) + } +} + +func TestGeminiCLIPassthroughStripsAuthorizationUsedForProxyAuth(t *testing.T) { + server := newTestServer(t) + + var upstreamAuthorization string + stubDefaultTransport(t, roundTripperFunc(func(req *http.Request) (*http.Response, error) { + upstreamAuthorization = req.Header.Get("Authorization") + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(`{"status":"ok"}`)), + Request: req, + }, nil + })) + + req := httptest.NewRequest(http.MethodPost, "/v1internal:countTokens", bytes.NewBufferString(`{"model":"gemini-2.5-pro"}`)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer test-key") + + rr := httptest.NewRecorder() + server.engine.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body=%s", rr.Code, http.StatusOK, rr.Body.String()) + } + if upstreamAuthorization != "" { + t.Fatalf("upstream Authorization = %q, want empty", upstreamAuthorization) + } +} + +func TestGeminiCLIPassthroughPreservesUpstreamAuthorizationWhenProxyAuthUsesGoogleAPIKey(t *testing.T) { + server := newTestServer(t) + + var ( + upstreamAuthorization string + upstreamGoogleAPIKey string + ) + stubDefaultTransport(t, roundTripperFunc(func(req *http.Request) (*http.Response, error) { + upstreamAuthorization = req.Header.Get("Authorization") + upstreamGoogleAPIKey = req.Header.Get("X-Goog-Api-Key") + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(`{"status":"ok"}`)), + Request: req, + }, nil + })) + + req := httptest.NewRequest(http.MethodPost, "/v1internal:countTokens", bytes.NewBufferString(`{"model":"gemini-2.5-pro"}`)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Goog-Api-Key", "test-key") + req.Header.Set("Authorization", "Bearer upstream-token") + + rr := httptest.NewRecorder() + server.engine.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body=%s", rr.Code, http.StatusOK, rr.Body.String()) + } + if upstreamAuthorization != "Bearer upstream-token" { + t.Fatalf("upstream Authorization = %q, want %q", upstreamAuthorization, "Bearer upstream-token") + } + if upstreamGoogleAPIKey != "" { + t.Fatalf("upstream X-Goog-Api-Key = %q, want empty", upstreamGoogleAPIKey) + } +} + func TestDefaultRequestLoggerFactory_UsesResolvedLogDirectory(t *testing.T) { t.Setenv("WRITABLE_PATH", "") t.Setenv("writable_path", "") diff --git a/sdk/api/handlers/gemini/gemini-cli_handlers.go b/sdk/api/handlers/gemini/gemini-cli_handlers.go index b5fd494375..80afc6f918 100644 --- a/sdk/api/handlers/gemini/gemini-cli_handlers.go +++ b/sdk/api/handlers/gemini/gemini-cli_handlers.go @@ -47,18 +47,9 @@ func (h *GeminiCLIAPIHandler) Models() []map[string]any { } // CLIHandler handles CLI-specific requests for Gemini API operations. -// It restricts access to localhost only and routes requests to appropriate internal handlers. +// Access control is enforced at the router level so authenticated clients can +// use Gemini CLI-compatible routes through the public proxy. func (h *GeminiCLIAPIHandler) CLIHandler(c *gin.Context) { - if !strings.HasPrefix(c.Request.RemoteAddr, "127.0.0.1:") { - c.JSON(http.StatusForbidden, handlers.ErrorResponse{ - Error: handlers.ErrorDetail{ - Message: "CLI reply only allow local access", - Type: "forbidden", - }, - }) - return - } - rawJSON, _ := c.GetRawData() requestRawURI := c.Request.URL.Path @@ -78,9 +69,8 @@ func (h *GeminiCLIAPIHandler) CLIHandler(c *gin.Context) { }) return } - for key, value := range c.Request.Header { - req.Header[key] = value - } + req.Header = handlers.FilterUpstreamHeaders(c.Request.Header) + stripConsumedProxyCredential(req, c) httpClient := util.SetProxy(h.Cfg, &http.Client{}) @@ -130,6 +120,71 @@ func (h *GeminiCLIAPIHandler) CLIHandler(c *gin.Context) { } } +func stripConsumedProxyCredential(req *http.Request, c *gin.Context) { + if req == nil || c == nil { + return + } + + switch accessMetadataSource(c) { + case "authorization": + req.Header.Del("Authorization") + case "x-goog-api-key": + req.Header.Del("X-Goog-Api-Key") + case "x-api-key": + req.Header.Del("X-Api-Key") + case "query-key": + removeQueryValuesMatching(req, "key", strings.TrimSpace(c.GetString("apiKey"))) + case "query-auth-token": + removeQueryValuesMatching(req, "auth_token", strings.TrimSpace(c.GetString("apiKey"))) + } +} + +func accessMetadataSource(c *gin.Context) string { + if c == nil { + return "" + } + raw, exists := c.Get("accessMetadata") + if !exists || raw == nil { + return "" + } + switch typed := raw.(type) { + case map[string]string: + return strings.TrimSpace(typed["source"]) + case map[string]any: + if source, ok := typed["source"].(string); ok { + return strings.TrimSpace(source) + } + } + return "" +} + +func removeQueryValuesMatching(req *http.Request, key string, match string) { + if req == nil || req.URL == nil || match == "" { + return + } + + query := req.URL.Query() + values, ok := query[key] + if !ok || len(values) == 0 { + return + } + + kept := make([]string, 0, len(values)) + for _, value := range values { + if value == match { + continue + } + kept = append(kept, value) + } + + if len(kept) == 0 { + query.Del(key) + } else { + query[key] = kept + } + req.URL.RawQuery = query.Encode() +} + // handleInternalStreamGenerateContent handles streaming content generation requests. // It sets up a server-sent event stream and forwards the request to the backend client. // The function continuously proxies response chunks from the backend to the client.