Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 123 additions & 0 deletions internal/api/handlers/management/api_tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"time"

"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
Expand Down Expand Up @@ -636,6 +637,11 @@ func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper {
if proxyStr := strings.TrimSpace(auth.ProxyURL); proxyStr != "" {
proxyCandidates = append(proxyCandidates, proxyStr)
}
if h != nil && h.cfg != nil {
if proxyStr := strings.TrimSpace(proxyURLFromAPIKeyConfig(h.cfg, auth)); proxyStr != "" {
proxyCandidates = append(proxyCandidates, proxyStr)
}
}
}
if h != nil && h.cfg != nil {
if proxyStr := strings.TrimSpace(h.cfg.ProxyURL); proxyStr != "" {
Expand All @@ -658,6 +664,123 @@ func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper {
return clone
}

type apiKeyConfigEntry interface {
GetAPIKey() string
GetBaseURL() string
}

func resolveAPIKeyConfig[T apiKeyConfigEntry](entries []T, auth *coreauth.Auth) *T {
if auth == nil || len(entries) == 0 {
return nil
}
attrKey, attrBase := "", ""
if auth.Attributes != nil {
attrKey = strings.TrimSpace(auth.Attributes["api_key"])
attrBase = strings.TrimSpace(auth.Attributes["base_url"])
}
for i := range entries {
entry := &entries[i]
cfgKey := strings.TrimSpace((*entry).GetAPIKey())
cfgBase := strings.TrimSpace((*entry).GetBaseURL())
if attrKey != "" && attrBase != "" {
if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) {
return entry
}
continue
}
if attrKey != "" && strings.EqualFold(cfgKey, attrKey) {
if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) {
return entry
}
}
if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) {
return entry
}
}
if attrKey != "" {
for i := range entries {
entry := &entries[i]
if strings.EqualFold(strings.TrimSpace((*entry).GetAPIKey()), attrKey) {
return entry
}
}
}
return nil
}

func proxyURLFromAPIKeyConfig(cfg *config.Config, auth *coreauth.Auth) string {
if cfg == nil || auth == nil {
return ""
}
authKind, authAccount := auth.AccountInfo()
if !strings.EqualFold(strings.TrimSpace(authKind), "api_key") {
return ""
}

attrs := auth.Attributes
compatName := ""
providerKey := ""
if len(attrs) > 0 {
compatName = strings.TrimSpace(attrs["compat_name"])
providerKey = strings.TrimSpace(attrs["provider_key"])
}
if compatName != "" || strings.EqualFold(strings.TrimSpace(auth.Provider), "openai-compatibility") {
return resolveOpenAICompatAPIKeyProxyURL(cfg, auth, strings.TrimSpace(authAccount), providerKey, compatName)
}

switch strings.ToLower(strings.TrimSpace(auth.Provider)) {
case "gemini":
if entry := resolveAPIKeyConfig(cfg.GeminiKey, auth); entry != nil {
return strings.TrimSpace(entry.ProxyURL)
}
case "claude":
if entry := resolveAPIKeyConfig(cfg.ClaudeKey, auth); entry != nil {
return strings.TrimSpace(entry.ProxyURL)
}
case "codex":
if entry := resolveAPIKeyConfig(cfg.CodexKey, auth); entry != nil {
return strings.TrimSpace(entry.ProxyURL)
}
}
return ""
}

func resolveOpenAICompatAPIKeyProxyURL(cfg *config.Config, auth *coreauth.Auth, apiKey, providerKey, compatName string) string {
if cfg == nil || auth == nil {
return ""
}
apiKey = strings.TrimSpace(apiKey)
if apiKey == "" {
return ""
}
candidates := make([]string, 0, 3)
if v := strings.TrimSpace(compatName); v != "" {
candidates = append(candidates, v)
}
if v := strings.TrimSpace(providerKey); v != "" {
candidates = append(candidates, v)
}
if v := strings.TrimSpace(auth.Provider); v != "" {
candidates = append(candidates, v)
}

for i := range cfg.OpenAICompatibility {
compat := &cfg.OpenAICompatibility[i]
for _, candidate := range candidates {
if candidate != "" && strings.EqualFold(strings.TrimSpace(candidate), compat.Name) {
for j := range compat.APIKeyEntries {
entry := &compat.APIKeyEntries[j]
if strings.EqualFold(strings.TrimSpace(entry.APIKey), apiKey) {
return strings.TrimSpace(entry.ProxyURL)
}
}
return ""
}
}
}
return ""
}

func buildProxyTransport(proxyStr string) *http.Transport {
transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyStr)
if errBuild != nil {
Expand Down
99 changes: 99 additions & 0 deletions internal/api/handlers/management/api_tools_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,105 @@ func TestAPICallTransportInvalidAuthFallsBackToGlobalProxy(t *testing.T) {
}
}

func TestAPICallTransportAPIKeyAuthFallsBackToConfigProxyURL(t *testing.T) {
t.Parallel()

h := &Handler{
cfg: &config.Config{
SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"},
GeminiKey: []config.GeminiKey{{
APIKey: "gemini-key",
ProxyURL: "http://gemini-proxy.example.com:8080",
}},
ClaudeKey: []config.ClaudeKey{{
APIKey: "claude-key",
ProxyURL: "http://claude-proxy.example.com:8080",
}},
CodexKey: []config.CodexKey{{
APIKey: "codex-key",
ProxyURL: "http://codex-proxy.example.com:8080",
}},
OpenAICompatibility: []config.OpenAICompatibility{{
Name: "bohe",
BaseURL: "https://bohe.example.com",
APIKeyEntries: []config.OpenAICompatibilityAPIKey{{
APIKey: "compat-key",
ProxyURL: "http://compat-proxy.example.com:8080",
}},
}},
},
}

cases := []struct {
name string
auth *coreauth.Auth
wantProxy string
}{
{
name: "gemini",
auth: &coreauth.Auth{
Provider: "gemini",
Attributes: map[string]string{"api_key": "gemini-key"},
},
wantProxy: "http://gemini-proxy.example.com:8080",
},
{
name: "claude",
auth: &coreauth.Auth{
Provider: "claude",
Attributes: map[string]string{"api_key": "claude-key"},
},
wantProxy: "http://claude-proxy.example.com:8080",
},
{
name: "codex",
auth: &coreauth.Auth{
Provider: "codex",
Attributes: map[string]string{"api_key": "codex-key"},
},
wantProxy: "http://codex-proxy.example.com:8080",
},
{
name: "openai-compatibility",
auth: &coreauth.Auth{
Provider: "bohe",
Attributes: map[string]string{
"api_key": "compat-key",
"compat_name": "bohe",
"provider_key": "bohe",
},
},
wantProxy: "http://compat-proxy.example.com:8080",
},
}

for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

transport := h.apiCallTransport(tc.auth)
httpTransport, ok := transport.(*http.Transport)
if !ok {
t.Fatalf("transport type = %T, want *http.Transport", transport)
}

req, errRequest := http.NewRequest(http.MethodGet, "https://example.com", nil)
if errRequest != nil {
t.Fatalf("http.NewRequest returned error: %v", errRequest)
}

proxyURL, errProxy := httpTransport.Proxy(req)
if errProxy != nil {
t.Fatalf("httpTransport.Proxy returned error: %v", errProxy)
}
if proxyURL == nil || proxyURL.String() != tc.wantProxy {
t.Fatalf("proxy URL = %v, want %s", proxyURL, tc.wantProxy)
}
})
}
}

func TestAuthByIndexDistinguishesSharedAPIKeysAcrossProviders(t *testing.T) {
t.Parallel()

Expand Down
4 changes: 4 additions & 0 deletions internal/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,10 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
// setupRoutes configures the API routes for the server.
// It defines the endpoints and associates them with their respective handlers.
func (s *Server) setupRoutes() {
s.engine.GET("/healthz", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
})

s.engine.GET("/management.html", s.serveManagementControlPanel)
openaiHandlers := openai.NewOpenAIAPIHandler(s.handlers)
geminiHandlers := gemini.NewGeminiAPIHandler(s.handlers)
Expand Down
23 changes: 23 additions & 0 deletions internal/api/server_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package api

import (
"encoding/json"
"net/http"
"net/http/httptest"
"os"
Expand Down Expand Up @@ -46,6 +47,28 @@ func newTestServer(t *testing.T) *Server {
return NewServer(cfg, authManager, accessManager, configPath)
}

func TestHealthz(t *testing.T) {
server := newTestServer(t)

req := httptest.NewRequest(http.MethodGet, "/healthz", nil)
rr := httptest.NewRecorder()
server.engine.ServeHTTP(rr, req)

if rr.Code != http.StatusOK {
t.Fatalf("unexpected status code: got %d want %d; body=%s", rr.Code, http.StatusOK, rr.Body.String())
}

var resp struct {
Status string `json:"status"`
}
if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to parse response JSON: %v; body=%s", err, rr.Body.String())
}
if resp.Status != "ok" {
t.Fatalf("unexpected response status: got %q want %q", resp.Status, "ok")
}
}

func TestAmpProviderModelRoutes(t *testing.T) {
testCases := []struct {
name string
Expand Down
14 changes: 7 additions & 7 deletions sdk/api/handlers/openai/openai_responses_websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,13 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
}
}

func websocketClientAddress(c *gin.Context) string {
if c == nil || c.Request == nil {
return ""
}
return strings.TrimSpace(c.ClientIP())
}

func websocketUpgradeHeaders(req *http.Request) http.Header {
headers := http.Header{}
if req == nil {
Expand Down Expand Up @@ -488,13 +495,6 @@ func normalizeResponseTranscriptReplacement(rawJSON []byte, lastRequest []byte)
return bytes.Clone(normalized)
}

func websocketClientAddress(c *gin.Context) string {
if c == nil || c.Request == nil {
return ""
}
return strings.TrimSpace(c.ClientIP())
}

func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, metadata map[string]any) bool {
if len(attributes) > 0 {
if raw := strings.TrimSpace(attributes["websockets"]); raw != "" {
Expand Down
25 changes: 25 additions & 0 deletions sdk/api/handlers/openai/openai_responses_websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -968,6 +968,31 @@ func TestResponsesWebsocketPrewarmHandledLocallyForSSEUpstream(t *testing.T) {
}
}

func TestWebsocketClientAddressUsesGinClientIP(t *testing.T) {
gin.SetMode(gin.TestMode)

recorder := httptest.NewRecorder()
c, engine := gin.CreateTestContext(recorder)
if err := engine.SetTrustedProxies([]string{"0.0.0.0/0", "::/0"}); err != nil {
t.Fatalf("SetTrustedProxies: %v", err)
}

req := httptest.NewRequest(http.MethodGet, "/v1/responses/ws", nil)
req.RemoteAddr = "172.18.0.1:34282"
req.Header.Set("X-Forwarded-For", "203.0.113.7")
c.Request = req

if got := websocketClientAddress(c); got != strings.TrimSpace(c.ClientIP()) {
t.Fatalf("websocketClientAddress = %q, ClientIP = %q", got, c.ClientIP())
}
}

func TestWebsocketClientAddressReturnsEmptyForNilContext(t *testing.T) {
if got := websocketClientAddress(nil); got != "" {
t.Fatalf("websocketClientAddress(nil) = %q, want empty", got)
}
}

func TestResponsesWebsocketPinsOnlyWebsocketCapableAuth(t *testing.T) {
gin.SetMode(gin.TestMode)

Expand Down
Loading