diff --git a/internal/api/handlers/management/api_tools.go b/internal/api/handlers/management/api_tools.go index de546ea820..cb4805e9ef 100644 --- a/internal/api/handlers/management/api_tools.go +++ b/internal/api/handlers/management/api_tools.go @@ -11,6 +11,7 @@ import ( "time" "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" "github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil" @@ -636,6 +637,11 @@ func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper { if proxyStr := strings.TrimSpace(auth.ProxyURL); proxyStr != "" { proxyCandidates = append(proxyCandidates, proxyStr) } + if h != nil && h.cfg != nil { + if proxyStr := strings.TrimSpace(proxyURLFromAPIKeyConfig(h.cfg, auth)); proxyStr != "" { + proxyCandidates = append(proxyCandidates, proxyStr) + } + } } if h != nil && h.cfg != nil { if proxyStr := strings.TrimSpace(h.cfg.ProxyURL); proxyStr != "" { @@ -658,6 +664,123 @@ func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper { return clone } +type apiKeyConfigEntry interface { + GetAPIKey() string + GetBaseURL() string +} + +func resolveAPIKeyConfig[T apiKeyConfigEntry](entries []T, auth *coreauth.Auth) *T { + if auth == nil || len(entries) == 0 { + return nil + } + attrKey, attrBase := "", "" + if auth.Attributes != nil { + attrKey = strings.TrimSpace(auth.Attributes["api_key"]) + attrBase = strings.TrimSpace(auth.Attributes["base_url"]) + } + for i := range entries { + entry := &entries[i] + cfgKey := strings.TrimSpace((*entry).GetAPIKey()) + cfgBase := strings.TrimSpace((*entry).GetBaseURL()) + if attrKey != "" && attrBase != "" { + if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) { + return entry + } + continue + } + if attrKey != "" && strings.EqualFold(cfgKey, attrKey) { + if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) { + return entry + } + } + if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) { + return entry + } + } + if attrKey != "" { + for i := range entries { + entry := &entries[i] + if strings.EqualFold(strings.TrimSpace((*entry).GetAPIKey()), attrKey) { + return entry + } + } + } + return nil +} + +func proxyURLFromAPIKeyConfig(cfg *config.Config, auth *coreauth.Auth) string { + if cfg == nil || auth == nil { + return "" + } + authKind, authAccount := auth.AccountInfo() + if !strings.EqualFold(strings.TrimSpace(authKind), "api_key") { + return "" + } + + attrs := auth.Attributes + compatName := "" + providerKey := "" + if len(attrs) > 0 { + compatName = strings.TrimSpace(attrs["compat_name"]) + providerKey = strings.TrimSpace(attrs["provider_key"]) + } + if compatName != "" || strings.EqualFold(strings.TrimSpace(auth.Provider), "openai-compatibility") { + return resolveOpenAICompatAPIKeyProxyURL(cfg, auth, strings.TrimSpace(authAccount), providerKey, compatName) + } + + switch strings.ToLower(strings.TrimSpace(auth.Provider)) { + case "gemini": + if entry := resolveAPIKeyConfig(cfg.GeminiKey, auth); entry != nil { + return strings.TrimSpace(entry.ProxyURL) + } + case "claude": + if entry := resolveAPIKeyConfig(cfg.ClaudeKey, auth); entry != nil { + return strings.TrimSpace(entry.ProxyURL) + } + case "codex": + if entry := resolveAPIKeyConfig(cfg.CodexKey, auth); entry != nil { + return strings.TrimSpace(entry.ProxyURL) + } + } + return "" +} + +func resolveOpenAICompatAPIKeyProxyURL(cfg *config.Config, auth *coreauth.Auth, apiKey, providerKey, compatName string) string { + if cfg == nil || auth == nil { + return "" + } + apiKey = strings.TrimSpace(apiKey) + if apiKey == "" { + return "" + } + candidates := make([]string, 0, 3) + if v := strings.TrimSpace(compatName); v != "" { + candidates = append(candidates, v) + } + if v := strings.TrimSpace(providerKey); v != "" { + candidates = append(candidates, v) + } + if v := strings.TrimSpace(auth.Provider); v != "" { + candidates = append(candidates, v) + } + + for i := range cfg.OpenAICompatibility { + compat := &cfg.OpenAICompatibility[i] + for _, candidate := range candidates { + if candidate != "" && strings.EqualFold(strings.TrimSpace(candidate), compat.Name) { + for j := range compat.APIKeyEntries { + entry := &compat.APIKeyEntries[j] + if strings.EqualFold(strings.TrimSpace(entry.APIKey), apiKey) { + return strings.TrimSpace(entry.ProxyURL) + } + } + return "" + } + } + } + return "" +} + func buildProxyTransport(proxyStr string) *http.Transport { transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyStr) if errBuild != nil { diff --git a/internal/api/handlers/management/api_tools_test.go b/internal/api/handlers/management/api_tools_test.go index 6ed98c6e77..b27fe6395a 100644 --- a/internal/api/handlers/management/api_tools_test.go +++ b/internal/api/handlers/management/api_tools_test.go @@ -58,6 +58,105 @@ func TestAPICallTransportInvalidAuthFallsBackToGlobalProxy(t *testing.T) { } } +func TestAPICallTransportAPIKeyAuthFallsBackToConfigProxyURL(t *testing.T) { + t.Parallel() + + h := &Handler{ + cfg: &config.Config{ + SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"}, + GeminiKey: []config.GeminiKey{{ + APIKey: "gemini-key", + ProxyURL: "http://gemini-proxy.example.com:8080", + }}, + ClaudeKey: []config.ClaudeKey{{ + APIKey: "claude-key", + ProxyURL: "http://claude-proxy.example.com:8080", + }}, + CodexKey: []config.CodexKey{{ + APIKey: "codex-key", + ProxyURL: "http://codex-proxy.example.com:8080", + }}, + OpenAICompatibility: []config.OpenAICompatibility{{ + Name: "bohe", + BaseURL: "https://bohe.example.com", + APIKeyEntries: []config.OpenAICompatibilityAPIKey{{ + APIKey: "compat-key", + ProxyURL: "http://compat-proxy.example.com:8080", + }}, + }}, + }, + } + + cases := []struct { + name string + auth *coreauth.Auth + wantProxy string + }{ + { + name: "gemini", + auth: &coreauth.Auth{ + Provider: "gemini", + Attributes: map[string]string{"api_key": "gemini-key"}, + }, + wantProxy: "http://gemini-proxy.example.com:8080", + }, + { + name: "claude", + auth: &coreauth.Auth{ + Provider: "claude", + Attributes: map[string]string{"api_key": "claude-key"}, + }, + wantProxy: "http://claude-proxy.example.com:8080", + }, + { + name: "codex", + auth: &coreauth.Auth{ + Provider: "codex", + Attributes: map[string]string{"api_key": "codex-key"}, + }, + wantProxy: "http://codex-proxy.example.com:8080", + }, + { + name: "openai-compatibility", + auth: &coreauth.Auth{ + Provider: "bohe", + Attributes: map[string]string{ + "api_key": "compat-key", + "compat_name": "bohe", + "provider_key": "bohe", + }, + }, + wantProxy: "http://compat-proxy.example.com:8080", + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + transport := h.apiCallTransport(tc.auth) + httpTransport, ok := transport.(*http.Transport) + if !ok { + t.Fatalf("transport type = %T, want *http.Transport", transport) + } + + req, errRequest := http.NewRequest(http.MethodGet, "https://example.com", nil) + if errRequest != nil { + t.Fatalf("http.NewRequest returned error: %v", errRequest) + } + + proxyURL, errProxy := httpTransport.Proxy(req) + if errProxy != nil { + t.Fatalf("httpTransport.Proxy returned error: %v", errProxy) + } + if proxyURL == nil || proxyURL.String() != tc.wantProxy { + t.Fatalf("proxy URL = %v, want %s", proxyURL, tc.wantProxy) + } + }) + } +} + func TestAuthByIndexDistinguishesSharedAPIKeysAcrossProviders(t *testing.T) { t.Parallel() diff --git a/internal/api/server.go b/internal/api/server.go index 6126ba2c12..12205ad6e1 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -317,6 +317,10 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk // setupRoutes configures the API routes for the server. // It defines the endpoints and associates them with their respective handlers. func (s *Server) setupRoutes() { + s.engine.GET("/healthz", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"status": "ok"}) + }) + s.engine.GET("/management.html", s.serveManagementControlPanel) openaiHandlers := openai.NewOpenAIAPIHandler(s.handlers) geminiHandlers := gemini.NewGeminiAPIHandler(s.handlers) diff --git a/internal/api/server_test.go b/internal/api/server_test.go index f5c18aa167..e224c90a32 100644 --- a/internal/api/server_test.go +++ b/internal/api/server_test.go @@ -1,6 +1,7 @@ package api import ( + "encoding/json" "net/http" "net/http/httptest" "os" @@ -46,6 +47,28 @@ func newTestServer(t *testing.T) *Server { return NewServer(cfg, authManager, accessManager, configPath) } +func TestHealthz(t *testing.T) { + server := newTestServer(t) + + req := httptest.NewRequest(http.MethodGet, "/healthz", nil) + rr := httptest.NewRecorder() + server.engine.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("unexpected status code: got %d want %d; body=%s", rr.Code, http.StatusOK, rr.Body.String()) + } + + var resp struct { + Status string `json:"status"` + } + if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to parse response JSON: %v; body=%s", err, rr.Body.String()) + } + if resp.Status != "ok" { + t.Fatalf("unexpected response status: got %q want %q", resp.Status, "ok") + } +} + func TestAmpProviderModelRoutes(t *testing.T) { testCases := []struct { name string diff --git a/sdk/api/handlers/openai/openai_responses_websocket.go b/sdk/api/handlers/openai/openai_responses_websocket.go index f930020303..c1f79cc9d7 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket.go +++ b/sdk/api/handlers/openai/openai_responses_websocket.go @@ -246,6 +246,13 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { } } +func websocketClientAddress(c *gin.Context) string { + if c == nil || c.Request == nil { + return "" + } + return strings.TrimSpace(c.ClientIP()) +} + func websocketUpgradeHeaders(req *http.Request) http.Header { headers := http.Header{} if req == nil { @@ -488,13 +495,6 @@ func normalizeResponseTranscriptReplacement(rawJSON []byte, lastRequest []byte) return bytes.Clone(normalized) } -func websocketClientAddress(c *gin.Context) string { - if c == nil || c.Request == nil { - return "" - } - return strings.TrimSpace(c.ClientIP()) -} - func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, metadata map[string]any) bool { if len(attributes) > 0 { if raw := strings.TrimSpace(attributes["websockets"]); raw != "" { diff --git a/sdk/api/handlers/openai/openai_responses_websocket_test.go b/sdk/api/handlers/openai/openai_responses_websocket_test.go index bd4ba62e2d..1095e610fc 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket_test.go +++ b/sdk/api/handlers/openai/openai_responses_websocket_test.go @@ -968,6 +968,31 @@ func TestResponsesWebsocketPrewarmHandledLocallyForSSEUpstream(t *testing.T) { } } +func TestWebsocketClientAddressUsesGinClientIP(t *testing.T) { + gin.SetMode(gin.TestMode) + + recorder := httptest.NewRecorder() + c, engine := gin.CreateTestContext(recorder) + if err := engine.SetTrustedProxies([]string{"0.0.0.0/0", "::/0"}); err != nil { + t.Fatalf("SetTrustedProxies: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "/v1/responses/ws", nil) + req.RemoteAddr = "172.18.0.1:34282" + req.Header.Set("X-Forwarded-For", "203.0.113.7") + c.Request = req + + if got := websocketClientAddress(c); got != strings.TrimSpace(c.ClientIP()) { + t.Fatalf("websocketClientAddress = %q, ClientIP = %q", got, c.ClientIP()) + } +} + +func TestWebsocketClientAddressReturnsEmptyForNilContext(t *testing.T) { + if got := websocketClientAddress(nil); got != "" { + t.Fatalf("websocketClientAddress(nil) = %q, want empty", got) + } +} + func TestResponsesWebsocketPinsOnlyWebsocketCapableAuth(t *testing.T) { gin.SetMode(gin.TestMode)