-
Notifications
You must be signed in to change notification settings - Fork 49
fix(security): require API keys for Gemini CLI routes #26
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,7 +1,9 @@ | ||||||||||||||||||
| package api | ||||||||||||||||||
|
|
||||||||||||||||||
| import ( | ||||||||||||||||||
| "bytes" | ||||||||||||||||||
| "encoding/json" | ||||||||||||||||||
| "io" | ||||||||||||||||||
| "net/http" | ||||||||||||||||||
| "net/http/httptest" | ||||||||||||||||||
| "os" | ||||||||||||||||||
|
|
@@ -18,6 +20,22 @@ import ( | |||||||||||||||||
| sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| type roundTripperFunc func(*http.Request) (*http.Response, error) | ||||||||||||||||||
|
|
||||||||||||||||||
| func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { | ||||||||||||||||||
| return f(req) | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| func stubDefaultTransport(t *testing.T, transport http.RoundTripper) { | ||||||||||||||||||
| t.Helper() | ||||||||||||||||||
|
|
||||||||||||||||||
| original := http.DefaultTransport | ||||||||||||||||||
| http.DefaultTransport = transport | ||||||||||||||||||
| t.Cleanup(func() { | ||||||||||||||||||
| http.DefaultTransport = original | ||||||||||||||||||
| }) | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| func newTestServer(t *testing.T) *Server { | ||||||||||||||||||
| t.Helper() | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
@@ -135,6 +153,106 @@ func TestAmpProviderModelRoutes(t *testing.T) { | |||||||||||||||||
| } | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| func TestGeminiCLIRouteRequiresAPIKey(t *testing.T) { | ||||||||||||||||||
| server := newTestServer(t) | ||||||||||||||||||
|
|
||||||||||||||||||
| req := httptest.NewRequest(http.MethodPost, "/v1internal:generateContent", bytes.NewBufferString(`{"model":"gemini-2.5-pro"}`)) | ||||||||||||||||||
| req.Header.Set("Content-Type", "application/json") | ||||||||||||||||||
| req.RemoteAddr = "203.0.113.7:4567" | ||||||||||||||||||
|
|
||||||||||||||||||
| rr := httptest.NewRecorder() | ||||||||||||||||||
| server.engine.ServeHTTP(rr, req) | ||||||||||||||||||
|
|
||||||||||||||||||
| if rr.Code != http.StatusUnauthorized { | ||||||||||||||||||
| t.Fatalf("status = %d, want %d; body=%s", rr.Code, http.StatusUnauthorized, rr.Body.String()) | ||||||||||||||||||
| } | ||||||||||||||||||
| if !strings.Contains(rr.Body.String(), "Missing API key") { | ||||||||||||||||||
| t.Fatalf("body = %q, want missing API key error", rr.Body.String()) | ||||||||||||||||||
| } | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| func TestGeminiCLIRouteAllowsAuthenticatedNonLocalRequest(t *testing.T) { | ||||||||||||||||||
| server := newTestServer(t) | ||||||||||||||||||
|
|
||||||||||||||||||
| req := httptest.NewRequest(http.MethodPost, "/v1internal:generateContent", bytes.NewBufferString(`{"model":"gemini-2.5-pro"}`)) | ||||||||||||||||||
| req.Header.Set("Content-Type", "application/json") | ||||||||||||||||||
| req.Header.Set("Authorization", "Bearer test-key") | ||||||||||||||||||
| req.RemoteAddr = "203.0.113.7:4567" | ||||||||||||||||||
|
|
||||||||||||||||||
| rr := httptest.NewRecorder() | ||||||||||||||||||
| server.engine.ServeHTTP(rr, req) | ||||||||||||||||||
|
|
||||||||||||||||||
| if rr.Code == http.StatusUnauthorized || rr.Code == http.StatusForbidden { | ||||||||||||||||||
| t.Fatalf("status = %d, want request to pass auth and localhost gate; body=%s", rr.Code, rr.Body.String()) | ||||||||||||||||||
| } | ||||||||||||||||||
|
||||||||||||||||||
| } | |
| } | |
| if rr.Code == http.StatusNotFound { | |
| t.Fatalf("status = %d, want request to reach the Gemini CLI route; body=%s", rr.Code, rr.Body.String()) | |
| } | |
| if strings.Contains(rr.Body.String(), "Missing API key") { | |
| t.Fatalf("body = %q, want request not to fail API key auth", rr.Body.String()) | |
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -47,18 +47,9 @@ func (h *GeminiCLIAPIHandler) Models() []map[string]any { | |
| } | ||
|
|
||
| // CLIHandler handles CLI-specific requests for Gemini API operations. | ||
| // It restricts access to localhost only and routes requests to appropriate internal handlers. | ||
| // Access control is enforced at the router level so authenticated clients can | ||
| // use Gemini CLI-compatible routes through the public proxy. | ||
|
Comment on lines
49
to
+51
|
||
| func (h *GeminiCLIAPIHandler) CLIHandler(c *gin.Context) { | ||
| if !strings.HasPrefix(c.Request.RemoteAddr, "127.0.0.1:") { | ||
| c.JSON(http.StatusForbidden, handlers.ErrorResponse{ | ||
| Error: handlers.ErrorDetail{ | ||
| Message: "CLI reply only allow local access", | ||
| Type: "forbidden", | ||
| }, | ||
| }) | ||
| return | ||
| } | ||
|
|
||
| rawJSON, _ := c.GetRawData() | ||
| requestRawURI := c.Request.URL.Path | ||
|
Comment on lines
49
to
54
|
||
|
|
||
|
|
@@ -78,9 +69,8 @@ func (h *GeminiCLIAPIHandler) CLIHandler(c *gin.Context) { | |
| }) | ||
| return | ||
| } | ||
| for key, value := range c.Request.Header { | ||
| req.Header[key] = value | ||
| } | ||
| req.Header = handlers.FilterUpstreamHeaders(c.Request.Header) | ||
| stripConsumedProxyCredential(req, c) | ||
|
|
||
| httpClient := util.SetProxy(h.Cfg, &http.Client{}) | ||
|
|
||
|
|
@@ -130,6 +120,71 @@ func (h *GeminiCLIAPIHandler) CLIHandler(c *gin.Context) { | |
| } | ||
| } | ||
|
|
||
| func stripConsumedProxyCredential(req *http.Request, c *gin.Context) { | ||
| if req == nil || c == nil { | ||
| return | ||
| } | ||
|
|
||
| switch accessMetadataSource(c) { | ||
| case "authorization": | ||
| req.Header.Del("Authorization") | ||
| case "x-goog-api-key": | ||
| req.Header.Del("X-Goog-Api-Key") | ||
| case "x-api-key": | ||
| req.Header.Del("X-Api-Key") | ||
| case "query-key": | ||
| removeQueryValuesMatching(req, "key", strings.TrimSpace(c.GetString("apiKey"))) | ||
| case "query-auth-token": | ||
| removeQueryValuesMatching(req, "auth_token", strings.TrimSpace(c.GetString("apiKey"))) | ||
| } | ||
| } | ||
|
|
||
| func accessMetadataSource(c *gin.Context) string { | ||
| if c == nil { | ||
| return "" | ||
| } | ||
| raw, exists := c.Get("accessMetadata") | ||
| if !exists || raw == nil { | ||
| return "" | ||
| } | ||
| switch typed := raw.(type) { | ||
| case map[string]string: | ||
| return strings.TrimSpace(typed["source"]) | ||
| case map[string]any: | ||
| if source, ok := typed["source"].(string); ok { | ||
| return strings.TrimSpace(source) | ||
| } | ||
| } | ||
| return "" | ||
| } | ||
|
|
||
| func removeQueryValuesMatching(req *http.Request, key string, match string) { | ||
| if req == nil || req.URL == nil || match == "" { | ||
| return | ||
| } | ||
|
|
||
| query := req.URL.Query() | ||
| values, ok := query[key] | ||
| if !ok || len(values) == 0 { | ||
| return | ||
| } | ||
|
|
||
| kept := make([]string, 0, len(values)) | ||
| for _, value := range values { | ||
| if value == match { | ||
| continue | ||
| } | ||
| kept = append(kept, value) | ||
| } | ||
|
|
||
| if len(kept) == 0 { | ||
| query.Del(key) | ||
| } else { | ||
| query[key] = kept | ||
| } | ||
| req.URL.RawQuery = query.Encode() | ||
| } | ||
|
|
||
| // handleInternalStreamGenerateContent handles streaming content generation requests. | ||
| // It sets up a server-sent event stream and forwards the request to the backend client. | ||
| // The function continuously proxies response chunks from the backend to the client. | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
AuthMiddleware() intentionally allows requests when the access manager has no providers configured (legacy behavior). With the localhost-only gate removed from the Gemini CLI handler, this means /v1internal:* becomes publicly reachable whenever API keys aren’t configured, which is a security regression. Consider using a dedicated middleware for /v1internal:* that fails closed when no access providers are configured (or preserves local-only access as a fallback).