Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion internal/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ func (s *Server) setupRoutes() {
},
})
})
s.engine.POST("/v1internal:method", geminiCLIHandlers.CLIHandler)
s.engine.POST("/v1internal:method", AuthMiddleware(s.accessManager), geminiCLIHandlers.CLIHandler)
Copy link

Copilot AI Apr 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AuthMiddleware() intentionally allows requests when the access manager has no providers configured (legacy behavior). With the localhost-only gate removed from the Gemini CLI handler, this means /v1internal:* becomes publicly reachable whenever API keys aren’t configured, which is a security regression. Consider using a dedicated middleware for /v1internal:* that fails closed when no access providers are configured (or preserves local-only access as a fallback).

Copilot uses AI. Check for mistakes.

// OAuth callback endpoints (reuse main server port)
// These endpoints receive provider redirects and persist
Expand Down
118 changes: 118 additions & 0 deletions internal/api/server_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package api

import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"os"
Expand All @@ -18,6 +20,22 @@ import (
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)

type roundTripperFunc func(*http.Request) (*http.Response, error)

func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}

func stubDefaultTransport(t *testing.T, transport http.RoundTripper) {
t.Helper()

original := http.DefaultTransport
http.DefaultTransport = transport
t.Cleanup(func() {
http.DefaultTransport = original
})
}

func newTestServer(t *testing.T) *Server {
t.Helper()

Expand Down Expand Up @@ -135,6 +153,106 @@ func TestAmpProviderModelRoutes(t *testing.T) {
}
}

func TestGeminiCLIRouteRequiresAPIKey(t *testing.T) {
server := newTestServer(t)

req := httptest.NewRequest(http.MethodPost, "/v1internal:generateContent", bytes.NewBufferString(`{"model":"gemini-2.5-pro"}`))
req.Header.Set("Content-Type", "application/json")
req.RemoteAddr = "203.0.113.7:4567"

rr := httptest.NewRecorder()
server.engine.ServeHTTP(rr, req)

if rr.Code != http.StatusUnauthorized {
t.Fatalf("status = %d, want %d; body=%s", rr.Code, http.StatusUnauthorized, rr.Body.String())
}
if !strings.Contains(rr.Body.String(), "Missing API key") {
t.Fatalf("body = %q, want missing API key error", rr.Body.String())
}
}

func TestGeminiCLIRouteAllowsAuthenticatedNonLocalRequest(t *testing.T) {
server := newTestServer(t)

req := httptest.NewRequest(http.MethodPost, "/v1internal:generateContent", bytes.NewBufferString(`{"model":"gemini-2.5-pro"}`))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer test-key")
req.RemoteAddr = "203.0.113.7:4567"

rr := httptest.NewRecorder()
server.engine.ServeHTTP(rr, req)

if rr.Code == http.StatusUnauthorized || rr.Code == http.StatusForbidden {
t.Fatalf("status = %d, want request to pass auth and localhost gate; body=%s", rr.Code, rr.Body.String())
}
Copy link

Copilot AI Apr 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TestGeminiCLIRouteAllowsAuthenticatedNonLocalRequest only asserts the response is not 401/403, so it would still pass if the route were mis-registered and returned 404 (or if some other unrelated error occurred). Tighten the assertion to prove the request actually reached the intended route/handler (e.g., fail on 404, and/or assert the body is not the auth error).

Suggested change
}
}
if rr.Code == http.StatusNotFound {
t.Fatalf("status = %d, want request to reach the Gemini CLI route; body=%s", rr.Code, rr.Body.String())
}
if strings.Contains(rr.Body.String(), "Missing API key") {
t.Fatalf("body = %q, want request not to fail API key auth", rr.Body.String())
}

Copilot uses AI. Check for mistakes.
}

func TestGeminiCLIPassthroughStripsAuthorizationUsedForProxyAuth(t *testing.T) {
server := newTestServer(t)

var upstreamAuthorization string
stubDefaultTransport(t, roundTripperFunc(func(req *http.Request) (*http.Response, error) {
upstreamAuthorization = req.Header.Get("Authorization")
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader(`{"status":"ok"}`)),
Request: req,
}, nil
}))

req := httptest.NewRequest(http.MethodPost, "/v1internal:countTokens", bytes.NewBufferString(`{"model":"gemini-2.5-pro"}`))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer test-key")

rr := httptest.NewRecorder()
server.engine.ServeHTTP(rr, req)

if rr.Code != http.StatusOK {
t.Fatalf("status = %d, want %d; body=%s", rr.Code, http.StatusOK, rr.Body.String())
}
if upstreamAuthorization != "" {
t.Fatalf("upstream Authorization = %q, want empty", upstreamAuthorization)
}
}

func TestGeminiCLIPassthroughPreservesUpstreamAuthorizationWhenProxyAuthUsesGoogleAPIKey(t *testing.T) {
server := newTestServer(t)

var (
upstreamAuthorization string
upstreamGoogleAPIKey string
)
stubDefaultTransport(t, roundTripperFunc(func(req *http.Request) (*http.Response, error) {
upstreamAuthorization = req.Header.Get("Authorization")
upstreamGoogleAPIKey = req.Header.Get("X-Goog-Api-Key")
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader(`{"status":"ok"}`)),
Request: req,
}, nil
}))

req := httptest.NewRequest(http.MethodPost, "/v1internal:countTokens", bytes.NewBufferString(`{"model":"gemini-2.5-pro"}`))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-Goog-Api-Key", "test-key")
req.Header.Set("Authorization", "Bearer upstream-token")

rr := httptest.NewRecorder()
server.engine.ServeHTTP(rr, req)

if rr.Code != http.StatusOK {
t.Fatalf("status = %d, want %d; body=%s", rr.Code, http.StatusOK, rr.Body.String())
}
if upstreamAuthorization != "Bearer upstream-token" {
t.Fatalf("upstream Authorization = %q, want %q", upstreamAuthorization, "Bearer upstream-token")
}
if upstreamGoogleAPIKey != "" {
t.Fatalf("upstream X-Goog-Api-Key = %q, want empty", upstreamGoogleAPIKey)
}
}

func TestDefaultRequestLoggerFactory_UsesResolvedLogDirectory(t *testing.T) {
t.Setenv("WRITABLE_PATH", "")
t.Setenv("writable_path", "")
Expand Down
83 changes: 69 additions & 14 deletions sdk/api/handlers/gemini/gemini-cli_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,9 @@ func (h *GeminiCLIAPIHandler) Models() []map[string]any {
}

// CLIHandler handles CLI-specific requests for Gemini API operations.
// It restricts access to localhost only and routes requests to appropriate internal handlers.
// Access control is enforced at the router level so authenticated clients can
// use Gemini CLI-compatible routes through the public proxy.
Comment on lines 49 to +51
Copy link

Copilot AI Apr 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The package-level comment at the top of this file still says the Gemini CLI handlers “restrict access to localhost only”, but this PR moves access control to router-level API key auth. Update/remove that file header comment to avoid misleading documentation.

Copilot uses AI. Check for mistakes.
func (h *GeminiCLIAPIHandler) CLIHandler(c *gin.Context) {
if !strings.HasPrefix(c.Request.RemoteAddr, "127.0.0.1:") {
c.JSON(http.StatusForbidden, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: "CLI reply only allow local access",
Type: "forbidden",
},
})
return
}

rawJSON, _ := c.GetRawData()
requestRawURI := c.Request.URL.Path
Comment on lines 49 to 54
Copy link

Copilot AI Apr 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that this handler is reachable through the public proxy, the fallback branch in CLIHandler that proxies to cloudcode-pa.googleapis.com copies all incoming headers to the upstream request. That includes the client’s Authorization header (which in this setup is the proxy API key), causing credential leakage to the upstream Google endpoint. Strip sensitive auth headers (Authorization / X-Api-Key / X-Goog-Api-Key, etc.) before forwarding, or explicitly whitelist safe headers to forward.

Copilot uses AI. Check for mistakes.

Expand All @@ -78,9 +69,8 @@ func (h *GeminiCLIAPIHandler) CLIHandler(c *gin.Context) {
})
return
}
for key, value := range c.Request.Header {
req.Header[key] = value
}
req.Header = handlers.FilterUpstreamHeaders(c.Request.Header)
stripConsumedProxyCredential(req, c)

httpClient := util.SetProxy(h.Cfg, &http.Client{})

Expand Down Expand Up @@ -130,6 +120,71 @@ func (h *GeminiCLIAPIHandler) CLIHandler(c *gin.Context) {
}
}

func stripConsumedProxyCredential(req *http.Request, c *gin.Context) {
if req == nil || c == nil {
return
}

switch accessMetadataSource(c) {
case "authorization":
req.Header.Del("Authorization")
case "x-goog-api-key":
req.Header.Del("X-Goog-Api-Key")
case "x-api-key":
req.Header.Del("X-Api-Key")
case "query-key":
removeQueryValuesMatching(req, "key", strings.TrimSpace(c.GetString("apiKey")))
case "query-auth-token":
removeQueryValuesMatching(req, "auth_token", strings.TrimSpace(c.GetString("apiKey")))
}
}

func accessMetadataSource(c *gin.Context) string {
if c == nil {
return ""
}
raw, exists := c.Get("accessMetadata")
if !exists || raw == nil {
return ""
}
switch typed := raw.(type) {
case map[string]string:
return strings.TrimSpace(typed["source"])
case map[string]any:
if source, ok := typed["source"].(string); ok {
return strings.TrimSpace(source)
}
}
return ""
}

func removeQueryValuesMatching(req *http.Request, key string, match string) {
if req == nil || req.URL == nil || match == "" {
return
}

query := req.URL.Query()
values, ok := query[key]
if !ok || len(values) == 0 {
return
}

kept := make([]string, 0, len(values))
for _, value := range values {
if value == match {
continue
}
kept = append(kept, value)
}

if len(kept) == 0 {
query.Del(key)
} else {
query[key] = kept
}
req.URL.RawQuery = query.Encode()
}

// handleInternalStreamGenerateContent handles streaming content generation requests.
// It sets up a server-sent event stream and forwards the request to the backend client.
// The function continuously proxies response chunks from the backend to the client.
Expand Down
Loading