From 8d0a002e8da5dc75d0bc430bc9531daa47a881dd Mon Sep 17 00:00:00 2001 From: Thomas Date: Tue, 7 Apr 2026 05:09:36 +0800 Subject: [PATCH 1/3] feat: add global and per-provider upstream proxy settings --- docs/en/config-reference.md | 6 + docs/zh/config-reference.md | 6 + internal/config/config.go | 87 ++++++++ internal/config/config_test.go | 134 +++++++++++++ internal/config/proxy_url.go | 34 ++++ internal/proxy/failover_forward.go | 4 +- internal/proxy/failover_manual.go | 2 +- internal/proxy/proxy.go | 135 ++++++++++--- internal/proxy/proxy_test.go | 122 +++++++++++ internal/proxy/reload_state_test.go | 103 ++++++++++ internal/web/api.go | 107 +++++++++- internal/web/api_test.go | 300 ++++++++++++++++++++++++---- internal/web/static/app.js | 91 ++++++++- internal/web/static/app.test.js | 52 +++++ internal/web/static/index.html | 40 ++++ internal/web/types.go | 58 ++++-- internal/web/yaml_format.go | 10 + internal/web/yaml_format_test.go | 31 +++ 18 files changed, 1234 insertions(+), 88 deletions(-) create mode 100644 internal/config/proxy_url.go diff --git a/docs/en/config-reference.md b/docs/en/config-reference.md index bacce62..fc63566 100644 --- a/docs/en/config-reference.md +++ b/docs/en/config-reference.md @@ -60,6 +60,8 @@ providers: | `reactivate_after` | duration | `1h` | Auto-reactivation delay for temporarily deactivated providers; set `0` to disable temporary deactivation for auth, billing, and quota failures | | `upstream_idle_timeout` | duration | `3m` | Abort the current upstream attempt if no response body bytes arrive for too long | | `response_header_timeout` | duration | `2m` | Timeout while waiting for upstream response headers | +| `upstream_proxy_mode` | string | `inherit` | Default upstream proxy mode for providers that use `proxy_mode: inherit`; `inherit` / `direct` / `custom` | +| `upstream_proxy_url` | string | empty | Required when `upstream_proxy_mode: custom`; supports `http://`, `https://`, `socks5://`, and `socks5h://` proxy URLs | | `max_request_body_bytes` | int | `33554432` | Request body size limit, default 32 MiB | | `log_dir` | string | `/logs` | Log directory | | `log_retention_days` | int | `7` | Log retention days; `0` keeps logs forever; default is 7 days | @@ -179,6 +181,8 @@ providers: | `base_url` | string | yes | Upstream API base URL | | `api_key` | string | one of two | Single API key | | `api_keys` | array | one of two | Multiple API keys, used in order | +| `proxy_mode` | string | no | Upstream proxy mode for this provider; `inherit` follows the global default | +| `proxy_url` | string | no | Required when `proxy_mode: custom`; supports `http://`, `https://`, `socks5://`, and `socks5h://` proxy URLs | | `priority` | int | no | Lower number = higher priority; omitted or `0` is treated as `1` | | `enabled` | bool | no | Defaults to `true` | | `model` | string | no | Force this provider to use a specific upstream model name for supported OpenAI and Claude requests | @@ -189,6 +193,8 @@ providers: - Use `api_key` when you only have one key - Use `api_keys` when you want retries across multiple keys within the same provider +- Use global `upstream_proxy_mode` / `upstream_proxy_url` to define the default proxy for inherited providers +- Use provider `proxy_mode: direct` to bypass both the global default proxy and environment proxy settings - Use `model` when different upstream providers expose the same family under different model IDs - Use `reasoning_effort` and `thinking_budget_tokens` only when you want Clipal to override the client-sent defaults for that provider - For long-running background setups, this is a good default: diff --git a/docs/zh/config-reference.md b/docs/zh/config-reference.md index dff57d7..55a67c8 100644 --- a/docs/zh/config-reference.md +++ b/docs/zh/config-reference.md @@ -60,6 +60,8 @@ providers: | `reactivate_after` | duration | `1h` | provider 临时禁用后的自动恢复时间;设为 `0` 可禁用基于鉴权、计费、额度错误的临时禁用 | | `upstream_idle_timeout` | duration | `3m` | 上游响应 body 长时间无字节时中断当前尝试 | | `response_header_timeout` | duration | `2m` | 等待上游响应头的超时 | +| `upstream_proxy_mode` | string | `inherit` | 作为默认值应用到 `proxy_mode: inherit` 的 provider;可选 `inherit` / `direct` / `custom` | +| `upstream_proxy_url` | string | 空 | 当 `upstream_proxy_mode: custom` 时必填;支持 `http://`、`https://`、`socks5://` 和 `socks5h://` 代理 URL | | `max_request_body_bytes` | int | `33554432` | 请求体大小上限,默认 32 MiB | | `log_dir` | string | `/logs` | 日志目录 | | `log_retention_days` | int | `7` | 日志保留天数;`0` 表示永久保留;默认保留 7 天 | @@ -179,6 +181,8 @@ providers: | `base_url` | string | 是 | 上游 API Base URL | | `api_key` | string | 二选一 | 单个 API Key | | `api_keys` | array | 二选一 | 多个 API Key,按顺序使用 | +| `proxy_mode` | string | 否 | 该 provider 的上游代理模式;`inherit` 表示继承全局默认代理 | +| `proxy_url` | string | 否 | 当 `proxy_mode: custom` 时必填;支持 `http://`、`https://`、`socks5://` 和 `socks5h://` 代理 URL | | `priority` | int | 否 | 数字越小优先级越高;省略或 `0` 时按 `1` 处理 | | `enabled` | bool | 否 | 是否启用,默认 `true` | | `model` | string | 否 | 对支持的 OpenAI / Claude 请求强制改写为这个上游模型名 | @@ -189,6 +193,8 @@ providers: - 只有一个 key 时用 `api_key` - 需要同 provider 多 key 轮转时用 `api_keys` +- 需要统一默认代理时,优先配置全局 `upstream_proxy_mode` / `upstream_proxy_url` +- 需要让某个 provider 绕过全局默认代理和环境代理时,用 `proxy_mode: direct` - 不同上游对同一模型族使用不同模型 ID 时,可为该 provider 配置 `model` - 只有在你希望 Clipal 按 provider 覆盖客户端默认思考参数时,才配置 `reasoning_effort` 或 `thinking_budget_tokens` - 常驻后台运行时,建议: diff --git a/internal/config/config.go b/internal/config/config.go index 34f21d9..b9992c2 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -93,6 +93,8 @@ type GlobalConfig struct { // ResponseHeaderTimeout controls how long we wait for the upstream to return // response headers after the request is fully written. Set to "0" to disable. ResponseHeaderTimeout string `yaml:"response_header_timeout"` + UpstreamProxyMode ProviderProxyMode `yaml:"upstream_proxy_mode,omitempty"` + UpstreamProxyURL string `yaml:"upstream_proxy_url,omitempty"` MaxRequestBody int64 `yaml:"max_request_body_bytes"` LogDir string `yaml:"log_dir"` LogRetentionDays int `yaml:"log_retention_days"` @@ -133,11 +135,21 @@ type ClaudeOverrides struct { ThinkingBudgetTokens *int `yaml:"thinking_budget_tokens,omitempty"` } +type ProviderProxyMode string + +const ( + ProviderProxyModeInherit ProviderProxyMode = "inherit" + ProviderProxyModeDirect ProviderProxyMode = "direct" + ProviderProxyModeCustom ProviderProxyMode = "custom" +) + type providerYAML struct { Name string `yaml:"name"` BaseURL string `yaml:"base_url"` APIKey string `yaml:"api_key,omitempty"` APIKeys []string `yaml:"api_keys,omitempty"` + ProxyMode ProviderProxyMode `yaml:"proxy_mode,omitempty"` + ProxyURL string `yaml:"proxy_url,omitempty"` Priority int `yaml:"priority"` Enabled *bool `yaml:"enabled,omitempty"` Overrides *ProviderOverrides `yaml:"overrides,omitempty"` @@ -152,6 +164,8 @@ type Provider struct { BaseURL string `yaml:"base_url"` APIKey string `yaml:"api_key,omitempty"` APIKeys []string `yaml:"api_keys,omitempty"` + ProxyMode ProviderProxyMode `yaml:"proxy_mode,omitempty"` + ProxyURL string `yaml:"proxy_url,omitempty"` Priority int `yaml:"priority"` Enabled *bool `yaml:"enabled,omitempty"` Overrides *ProviderOverrides `yaml:"-"` @@ -190,19 +204,32 @@ func (p *Provider) UnmarshalYAML(value *yaml.Node) error { BaseURL: raw.BaseURL, APIKey: raw.APIKey, APIKeys: append([]string(nil), raw.APIKeys...), + ProxyMode: raw.ProxyMode, + ProxyURL: raw.ProxyURL, Priority: raw.Priority, Enabled: raw.Enabled, Overrides: NormalizeProviderOverrides(overrides), } + NormalizeProviderProxySettings(p) return nil } func (p Provider) MarshalYAML() (any, error) { + proxyMode := p.NormalizedProxyMode() + proxyURL := p.NormalizedProxyURL() + if proxyMode == ProviderProxyModeInherit { + proxyMode = "" + } + if proxyMode != ProviderProxyModeCustom { + proxyURL = "" + } return providerYAML{ Name: p.Name, BaseURL: p.BaseURL, APIKey: p.APIKey, APIKeys: append([]string(nil), p.APIKeys...), + ProxyMode: proxyMode, + ProxyURL: proxyURL, Priority: p.Priority, Enabled: p.Enabled, Overrides: NormalizeProviderOverrides(p.Overrides), @@ -263,6 +290,57 @@ func NormalizeProviderOverrides(overrides *ProviderOverrides) *ProviderOverrides return &normalized } +func NormalizeProviderProxySettings(provider *Provider) { + if provider == nil { + return + } + provider.ProxyMode = provider.NormalizedProxyMode() + provider.ProxyURL = provider.NormalizedProxyURL() +} + +func (g GlobalConfig) NormalizedUpstreamProxyMode() ProviderProxyMode { + mode := strings.ToLower(strings.TrimSpace(string(g.UpstreamProxyMode))) + if mode == "" { + return ProviderProxyModeInherit + } + return ProviderProxyMode(mode) +} + +func (g GlobalConfig) NormalizedUpstreamProxyURL() string { + return strings.TrimSpace(g.UpstreamProxyURL) +} + +func (p Provider) NormalizedProxyMode() ProviderProxyMode { + mode := strings.ToLower(strings.TrimSpace(string(p.ProxyMode))) + if mode == "" { + return ProviderProxyModeInherit + } + return ProviderProxyMode(mode) +} + +func (p Provider) NormalizedProxyURL() string { + return strings.TrimSpace(p.ProxyURL) +} + +func validateProxySettings(scope string, mode ProviderProxyMode, rawURL string) error { + switch mode { + case ProviderProxyModeInherit, ProviderProxyModeDirect: + if rawURL != "" { + return fmt.Errorf("%s: proxy_url requires proxy_mode custom", scope) + } + case ProviderProxyModeCustom: + if rawURL == "" { + return fmt.Errorf("%s: proxy_url is required when proxy_mode=custom", scope) + } + if err := ValidateProxyURL(rawURL); err != nil { + return fmt.Errorf("%s: %w", scope, err) + } + default: + return fmt.Errorf("%s: invalid proxy_mode %q", scope, mode) + } + return nil +} + // IsEnabled returns whether the provider is enabled (default true) func (p *Provider) IsEnabled() bool { if p.Enabled == nil { @@ -363,6 +441,8 @@ func DefaultGlobalConfig() GlobalConfig { ReactivateAfter: "1h", UpstreamIdleTimeout: "3m", ResponseHeaderTimeout: "2m", + UpstreamProxyMode: ProviderProxyModeInherit, + UpstreamProxyURL: "", // Default body limit: 32 MiB. clipal buffers request bodies to support retries, // so a hard cap prevents unbounded memory usage. MaxRequestBody: 32 * 1024 * 1024, @@ -602,6 +682,7 @@ func applyClientDefaults(cc *ClientConfig) { } cc.Providers[i].APIKey = strings.TrimSpace(cc.Providers[i].APIKey) cc.Providers[i].APIKeys = cc.Providers[i].NormalizedAPIKeys() + NormalizeProviderProxySettings(&cc.Providers[i]) cc.Providers[i].Overrides = NormalizeProviderOverrides(cc.Providers[i].Overrides) if len(cc.Providers[i].APIKeys) == 1 { cc.Providers[i].APIKey = cc.Providers[i].APIKeys[0] @@ -700,6 +781,9 @@ func (c *Config) Validate() error { if c.Global.LogRetentionDays < 0 { return fmt.Errorf("invalid log_retention_days: %d", c.Global.LogRetentionDays) } + if err := validateProxySettings("global upstream proxy", c.Global.NormalizedUpstreamProxyMode(), c.Global.NormalizedUpstreamProxyURL()); err != nil { + return err + } // Circuit breaker: // - failure_threshold == 0 disables the circuit breaker entirely. @@ -809,6 +893,9 @@ func validateProviders(clientName string, providers []Provider) error { if p.Priority < 1 { return fmt.Errorf("%s provider %s: priority must be >= 1", clientName, p.Name) } + if err := validateProxySettings(fmt.Sprintf("%s provider %s", clientName, p.Name), p.NormalizedProxyMode(), p.NormalizedProxyURL()); err != nil { + return err + } if !providerOverridesSupportedForClient(clientName, p.Overrides) { return fmt.Errorf("%s provider %s: unsupported overrides for client", clientName, p.Name) } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index d40d25e..ee90b84 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -705,3 +705,137 @@ func TestGetConfigDir_RespectsEnvironmentOverride(t *testing.T) { t.Fatalf("GetConfigDir = %q, want %q", got, want) } } + +func TestLoad_ProviderProxyModeDefaultsToInherit(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + writeClientConfigFile(t, dir, "openai.yaml", ` +providers: + - name: p1 + base_url: https://example.com + api_key: key + priority: 1 +`) + + cfg, err := Load(dir) + if err != nil { + t.Fatalf("Load: %v", err) + } + if got := cfg.OpenAI.Providers[0].NormalizedProxyMode(); got != ProviderProxyModeInherit { + t.Fatalf("proxy mode = %q, want %q", got, ProviderProxyModeInherit) + } + if got := cfg.OpenAI.Providers[0].NormalizedProxyURL(); got != "" { + t.Fatalf("proxy url = %q, want empty", got) + } +} + +func TestValidate_ProviderProxySettings(t *testing.T) { + t.Parallel() + + base := &Config{ + Global: DefaultGlobalConfig(), + Claude: ClientConfig{Mode: ClientModeAuto}, + OpenAI: ClientConfig{Mode: ClientModeAuto}, + Gemini: ClientConfig{Mode: ClientModeAuto}, + } + + makeProvider := func(mode ProviderProxyMode, proxyURL string) Provider { + return Provider{ + Name: "p1", + BaseURL: "https://example.com", + APIKey: "key", + ProxyMode: mode, + ProxyURL: proxyURL, + Priority: 1, + } + } + + t.Run("accepts direct", func(t *testing.T) { + cfg := *base + cfg.OpenAI.Providers = []Provider{makeProvider(ProviderProxyModeDirect, "")} + if err := cfg.Validate(); err != nil { + t.Fatalf("Validate: %v", err) + } + }) + + t.Run("accepts custom http proxy", func(t *testing.T) { + cfg := *base + cfg.OpenAI.Providers = []Provider{makeProvider(ProviderProxyModeCustom, "http://127.0.0.1:7890")} + if err := cfg.Validate(); err != nil { + t.Fatalf("Validate: %v", err) + } + }) + + t.Run("accepts custom socks5 proxy", func(t *testing.T) { + cfg := *base + cfg.OpenAI.Providers = []Provider{makeProvider(ProviderProxyModeCustom, "socks5://127.0.0.1:1080")} + if err := cfg.Validate(); err != nil { + t.Fatalf("Validate: %v", err) + } + }) + + t.Run("accepts custom socks5h proxy", func(t *testing.T) { + cfg := *base + cfg.OpenAI.Providers = []Provider{makeProvider(ProviderProxyModeCustom, "socks5h://127.0.0.1:1080")} + if err := cfg.Validate(); err != nil { + t.Fatalf("Validate: %v", err) + } + }) + + t.Run("rejects custom proxy without url", func(t *testing.T) { + cfg := *base + cfg.OpenAI.Providers = []Provider{makeProvider(ProviderProxyModeCustom, "")} + if err := cfg.Validate(); err == nil || !strings.Contains(err.Error(), "proxy_url is required") { + t.Fatalf("Validate err = %v", err) + } + }) + + t.Run("rejects unsupported proxy scheme", func(t *testing.T) { + cfg := *base + cfg.OpenAI.Providers = []Provider{makeProvider(ProviderProxyModeCustom, "ftp://127.0.0.1:21")} + if err := cfg.Validate(); err == nil || !strings.Contains(err.Error(), "proxy_url scheme must be http, https, socks5, or socks5h") { + t.Fatalf("Validate err = %v", err) + } + }) + + t.Run("rejects proxy url without custom mode", func(t *testing.T) { + cfg := *base + cfg.OpenAI.Providers = []Provider{makeProvider(ProviderProxyModeDirect, "http://127.0.0.1:7890")} + if err := cfg.Validate(); err == nil || !strings.Contains(err.Error(), "proxy_url requires proxy_mode custom") { + t.Fatalf("Validate err = %v", err) + } + }) +} + +func TestValidate_GlobalUpstreamProxySettings(t *testing.T) { + t.Parallel() + + cfg := &Config{ + Global: DefaultGlobalConfig(), + Claude: ClientConfig{Mode: ClientModeAuto}, + OpenAI: ClientConfig{Mode: ClientModeAuto}, + Gemini: ClientConfig{Mode: ClientModeAuto}, + } + + cfg.Global.UpstreamProxyMode = ProviderProxyModeCustom + cfg.Global.UpstreamProxyURL = "http://127.0.0.1:7890" + if err := cfg.Validate(); err != nil { + t.Fatalf("Validate: %v", err) + } + + cfg.Global.UpstreamProxyURL = "socks5://127.0.0.1:1080" + if err := cfg.Validate(); err != nil { + t.Fatalf("Validate: %v", err) + } + + cfg.Global.UpstreamProxyURL = "socks5h://127.0.0.1:1080" + if err := cfg.Validate(); err != nil { + t.Fatalf("Validate: %v", err) + } + + cfg.Global.UpstreamProxyURL = "ftp://127.0.0.1:21" + if err := cfg.Validate(); err == nil || !strings.Contains(err.Error(), "proxy_url scheme must be http, https, socks5, or socks5h") { + t.Fatalf("Validate err = %v", err) + } +} diff --git a/internal/config/proxy_url.go b/internal/config/proxy_url.go new file mode 100644 index 0000000..dbbbc43 --- /dev/null +++ b/internal/config/proxy_url.go @@ -0,0 +1,34 @@ +package config + +import ( + "fmt" + "net/url" + "strings" +) + +const ( + supportedProxyURLSchemeList = "http, https, socks5, or socks5h" + supportedProxyURLPrefixList = "http://, https://, socks5://, or socks5h://" +) + +// ParseProxyURL validates a configured proxy URL and normalizes its scheme. +func ParseProxyURL(raw string) (*url.URL, error) { + parsed, err := url.Parse(strings.TrimSpace(raw)) + if err != nil || !parsed.IsAbs() || parsed.Host == "" { + return nil, fmt.Errorf("proxy_url must be an absolute %s URL", supportedProxyURLPrefixList) + } + + parsed.Scheme = strings.ToLower(parsed.Scheme) + switch parsed.Scheme { + case "http", "https", "socks5", "socks5h": + return parsed, nil + default: + return nil, fmt.Errorf("proxy_url scheme must be %s", supportedProxyURLSchemeList) + } +} + +// ValidateProxyURL reports whether a configured proxy URL is supported. +func ValidateProxyURL(raw string) error { + _, err := ParseProxyURL(raw) + return err +} diff --git a/internal/proxy/failover_forward.go b/internal/proxy/failover_forward.go index a4c39cc..ced20d1 100644 --- a/internal/proxy/failover_forward.go +++ b/internal/proxy/failover_forward.go @@ -236,7 +236,7 @@ func (cp *ClientProxy) forwardWithFailover(w http.ResponseWriter, req *http.Requ } //nolint:gosec // Clipal is a user-configured reverse proxy and intentionally forwards to configured upstream base URLs. - resp, err := cp.httpClient.Do(proxyReq) + resp, err := cp.upstreamHTTPClient(index).Do(proxyReq) if err != nil { if req.Context().Err() != nil { if busyProbeHeld { @@ -565,7 +565,7 @@ func (cp *ClientProxy) forwardCountTokensSingleShot(w http.ResponseWriter, req * } //nolint:gosec // Clipal is a user-configured reverse proxy and intentionally forwards to configured upstream base URLs. - resp, err := cp.httpClient.Do(proxyReq) + resp, err := cp.upstreamHTTPClient(index).Do(proxyReq) if err != nil { if req.Context().Err() != nil { return diff --git a/internal/proxy/failover_manual.go b/internal/proxy/failover_manual.go index eb4a605..73e41ae 100644 --- a/internal/proxy/failover_manual.go +++ b/internal/proxy/failover_manual.go @@ -74,7 +74,7 @@ func (cp *ClientProxy) forwardManual(w http.ResponseWriter, req *http.Request, p } //nolint:gosec // Clipal is a user-configured reverse proxy and intentionally forwards to the pinned upstream provider. - resp, err := cp.httpClient.Do(proxyReq) + resp, err := cp.upstreamHTTPClient(index).Do(proxyReq) if err != nil { if req.Context().Err() != nil { cancelAttempt(nil) diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 528145e..5f2902e 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -145,6 +145,9 @@ type ClientProxy struct { geminiStreamKeyIndex []int mu sync.RWMutex httpClient *http.Client + providerHTTPClients []*http.Client + providerProxyModes []config.ProviderProxyMode + providerProxyURLs []string deactivated []providerDeactivation keyDeactivated [][]providerDeactivation providerBusy []providerBusyState @@ -163,8 +166,20 @@ type ClientProxy struct { // Close releases resources held by the ClientProxy. func (cp *ClientProxy) Close() { - if cp.httpClient != nil { - cp.httpClient.CloseIdleConnections() + seen := make(map[*http.Client]struct{}, len(cp.providerHTTPClients)+1) + closeClient := func(client *http.Client) { + if client == nil { + return + } + if _, ok := seen[client]; ok { + return + } + seen[client] = struct{}{} + client.CloseIdleConnections() + } + closeClient(cp.httpClient) + for _, client := range cp.providerHTTPClients { + closeClient(client) } } @@ -193,19 +208,19 @@ func NewRouter(cfg *config.Config) *Router { // Initialize client proxies claudeProviders := config.GetEnabledProviders(cfg.Claude) if len(claudeProviders) > 0 { - r.proxies[ClientClaude] = newClientProxy(ClientClaude, cfg.Claude.Mode, cfg.Claude.PinnedProvider, claudeProviders, durations.ReactivateAfter, durations.UpstreamIdleTimeout, durations.ResponseHeaderTimeout, cbCfg, telemetryStore) + r.proxies[ClientClaude] = newClientProxyWithGlobalProxy(ClientClaude, cfg.Claude.Mode, cfg.Claude.PinnedProvider, claudeProviders, durations.ReactivateAfter, durations.UpstreamIdleTimeout, durations.ResponseHeaderTimeout, cbCfg, cfg.Global.NormalizedUpstreamProxyMode(), cfg.Global.NormalizedUpstreamProxyURL(), telemetryStore) r.proxies[ClientClaude].applyRoutingRuntimeSettings(routingCfg) } codexProviders := config.GetEnabledProviders(cfg.OpenAI) if len(codexProviders) > 0 { - r.proxies[ClientOpenAI] = newClientProxy(ClientOpenAI, cfg.OpenAI.Mode, cfg.OpenAI.PinnedProvider, codexProviders, durations.ReactivateAfter, durations.UpstreamIdleTimeout, durations.ResponseHeaderTimeout, cbCfg, telemetryStore) + r.proxies[ClientOpenAI] = newClientProxyWithGlobalProxy(ClientOpenAI, cfg.OpenAI.Mode, cfg.OpenAI.PinnedProvider, codexProviders, durations.ReactivateAfter, durations.UpstreamIdleTimeout, durations.ResponseHeaderTimeout, cbCfg, cfg.Global.NormalizedUpstreamProxyMode(), cfg.Global.NormalizedUpstreamProxyURL(), telemetryStore) r.proxies[ClientOpenAI].applyRoutingRuntimeSettings(routingCfg) } geminiProviders := config.GetEnabledProviders(cfg.Gemini) if len(geminiProviders) > 0 { - r.proxies[ClientGemini] = newClientProxy(ClientGemini, cfg.Gemini.Mode, cfg.Gemini.PinnedProvider, geminiProviders, durations.ReactivateAfter, durations.UpstreamIdleTimeout, durations.ResponseHeaderTimeout, cbCfg, telemetryStore) + r.proxies[ClientGemini] = newClientProxyWithGlobalProxy(ClientGemini, cfg.Gemini.Mode, cfg.Gemini.PinnedProvider, geminiProviders, durations.ReactivateAfter, durations.UpstreamIdleTimeout, durations.ResponseHeaderTimeout, cbCfg, cfg.Global.NormalizedUpstreamProxyMode(), cfg.Global.NormalizedUpstreamProxyURL(), telemetryStore) r.proxies[ClientGemini].applyRoutingRuntimeSettings(routingCfg) } @@ -213,6 +228,10 @@ func NewRouter(cfg *config.Config) *Router { } func newClientProxy(clientType ClientType, mode config.ClientMode, pinnedProvider string, providers []config.Provider, reactivateAfter time.Duration, upstreamIdle time.Duration, responseHeaderTimeout time.Duration, cbCfg circuitBreakerConfig, telemetryStore ...*telemetry.Store) *ClientProxy { + return newClientProxyWithGlobalProxy(clientType, mode, pinnedProvider, providers, reactivateAfter, upstreamIdle, responseHeaderTimeout, cbCfg, config.ProviderProxyModeInherit, "", telemetryStore...) +} + +func newClientProxyWithGlobalProxy(clientType ClientType, mode config.ClientMode, pinnedProvider string, providers []config.Provider, reactivateAfter time.Duration, upstreamIdle time.Duration, responseHeaderTimeout time.Duration, cbCfg circuitBreakerConfig, globalProxyMode config.ProviderProxyMode, globalProxyURL string, telemetryStore ...*telemetry.Store) *ClientProxy { var store *telemetry.Store if len(telemetryStore) > 0 { store = telemetryStore[0] @@ -221,6 +240,7 @@ func newClientProxy(clientType ClientType, mode config.ClientMode, pinnedProvide Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, } + sharedClient := newUpstreamHTTPClient(dialer, responseHeaderTimeout, http.ProxyFromEnvironment) pinnedIndex := -1 pinnedProvider = strings.TrimSpace(pinnedProvider) if mode == config.ClientModeManual && pinnedProvider != "" { @@ -238,10 +258,15 @@ func newClientProxy(clientType ClientType, mode config.ClientMode, pinnedProvide countTokensKeyIndex := make([]int, len(providers)) responsesKeyIndex := make([]int, len(providers)) geminiStreamKeyIndex := make([]int, len(providers)) + providerHTTPClients := make([]*http.Client, len(providers)) + providerProxyModes := make([]config.ProviderProxyMode, len(providers)) + providerProxyURLs := make([]string, len(providers)) keyDeactivated := make([][]providerDeactivation, len(providers)) for i := range providers { breakers[i] = newCircuitBreaker(cbCfg) providerKeys[i] = providers[i].NormalizedAPIKeys() + providerProxyModes[i], providerProxyURLs[i] = effectiveProviderProxySettings(providers[i], globalProxyMode, globalProxyURL) + providerHTTPClients[i] = newProviderHTTPClient(providerProxyModes[i], providerProxyURLs[i], providers[i].Name, sharedClient, dialer, responseHeaderTimeout) if len(providerKeys[i]) == 0 { providerKeys[i] = []string{""} } @@ -283,6 +308,9 @@ func newClientProxy(clientType ClientType, mode config.ClientMode, pinnedProvide responsesKeyIndex: responsesKeyIndex, geminiStreamKeyIndex: geminiStreamKeyIndex, telemetry: store, + providerHTTPClients: providerHTTPClients, + providerProxyModes: providerProxyModes, + providerProxyURLs: providerProxyURLs, deactivated: make([]providerDeactivation, len(providers)), keyDeactivated: keyDeactivated, providerBusy: make([]providerBusyState, len(providers)), @@ -293,24 +321,72 @@ func newClientProxy(clientType ClientType, mode config.ClientMode, pinnedProvide dynamicFeatureBindings: make(map[string]stickyLookupEntry), routing: defaultRoutingRuntimeSettings(), breakers: breakers, - httpClient: &http.Client{ - Transport: &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: dialer.DialContext, - ForceAttemptHTTP2: true, - MaxIdleConns: 100, - MaxIdleConnsPerHost: 10, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ResponseHeaderTimeout: responseHeaderTimeout, - ExpectContinueTimeout: 1 * time.Second, - // Keep response bytes unchanged unless the client explicitly asks for compression. - DisableCompression: true, - }, + httpClient: sharedClient, + } +} + +func newUpstreamHTTPClient(dialer *net.Dialer, responseHeaderTimeout time.Duration, proxy func(*http.Request) (*url.URL, error)) *http.Client { + return &http.Client{ + Transport: &http.Transport{ + Proxy: proxy, + DialContext: dialer.DialContext, + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ResponseHeaderTimeout: responseHeaderTimeout, + ExpectContinueTimeout: 1 * time.Second, + // Keep response bytes unchanged unless the client explicitly asks for compression. + DisableCompression: true, }, } } +func effectiveProviderProxySettings(provider config.Provider, globalMode config.ProviderProxyMode, globalURL string) (config.ProviderProxyMode, string) { + switch provider.NormalizedProxyMode() { + case config.ProviderProxyModeDirect: + return config.ProviderProxyModeDirect, "" + case config.ProviderProxyModeCustom: + return config.ProviderProxyModeCustom, provider.NormalizedProxyURL() + default: + switch globalMode { + case config.ProviderProxyModeDirect: + return config.ProviderProxyModeDirect, "" + case config.ProviderProxyModeCustom: + return config.ProviderProxyModeCustom, strings.TrimSpace(globalURL) + default: + return config.ProviderProxyModeInherit, "" + } + } +} + +func newProviderHTTPClient(mode config.ProviderProxyMode, proxyURLRaw string, providerName string, sharedClient *http.Client, dialer *net.Dialer, responseHeaderTimeout time.Duration) *http.Client { + switch mode { + case config.ProviderProxyModeDirect: + return newUpstreamHTTPClient(dialer, responseHeaderTimeout, nil) + case config.ProviderProxyModeCustom: + proxyURL, err := config.ParseProxyURL(proxyURLRaw) + if err != nil { + logger.Warn("invalid custom proxy for provider %s; falling back to inherited proxy settings", providerName) + return sharedClient + } + return newUpstreamHTTPClient(dialer, responseHeaderTimeout, http.ProxyURL(proxyURL)) + default: + return sharedClient + } +} + +func (cp *ClientProxy) upstreamHTTPClient(providerIndex int) *http.Client { + if cp == nil { + return nil + } + if providerIndex >= 0 && providerIndex < len(cp.providerHTTPClients) && cp.providerHTTPClients[providerIndex] != nil { + return cp.providerHTTPClients[providerIndex] + } + return cp.httpClient +} + func (cp *ClientProxy) applyRoutingRuntimeSettings(settings routingRuntimeSettings) { if cp == nil { return @@ -616,13 +692,13 @@ func (r *Router) reloadProviderConfigsLocked() error { newProxies := make(map[ClientType]*ClientProxy) if ps := config.GetEnabledProviders(newCfg.Claude); len(ps) > 0 { - newProxies[ClientClaude] = newReloadedClientProxy(ClientClaude, newCfg.Claude.Mode, newCfg.Claude.PinnedProvider, ps, durations.ReactivateAfter, durations.UpstreamIdleTimeout, durations.ResponseHeaderTimeout, cbCfg, routingRuntimeSettingsFromConfig(newCfg.Global.Routing), oldProxies[ClientClaude], r.telemetry) + newProxies[ClientClaude] = newReloadedClientProxy(ClientClaude, newCfg.Claude.Mode, newCfg.Claude.PinnedProvider, ps, durations.ReactivateAfter, durations.UpstreamIdleTimeout, durations.ResponseHeaderTimeout, cbCfg, routingRuntimeSettingsFromConfig(newCfg.Global.Routing), newCfg.Global.NormalizedUpstreamProxyMode(), newCfg.Global.NormalizedUpstreamProxyURL(), oldProxies[ClientClaude], r.telemetry) } if ps := config.GetEnabledProviders(newCfg.OpenAI); len(ps) > 0 { - newProxies[ClientOpenAI] = newReloadedClientProxy(ClientOpenAI, newCfg.OpenAI.Mode, newCfg.OpenAI.PinnedProvider, ps, durations.ReactivateAfter, durations.UpstreamIdleTimeout, durations.ResponseHeaderTimeout, cbCfg, routingRuntimeSettingsFromConfig(newCfg.Global.Routing), oldProxies[ClientOpenAI], r.telemetry) + newProxies[ClientOpenAI] = newReloadedClientProxy(ClientOpenAI, newCfg.OpenAI.Mode, newCfg.OpenAI.PinnedProvider, ps, durations.ReactivateAfter, durations.UpstreamIdleTimeout, durations.ResponseHeaderTimeout, cbCfg, routingRuntimeSettingsFromConfig(newCfg.Global.Routing), newCfg.Global.NormalizedUpstreamProxyMode(), newCfg.Global.NormalizedUpstreamProxyURL(), oldProxies[ClientOpenAI], r.telemetry) } if ps := config.GetEnabledProviders(newCfg.Gemini); len(ps) > 0 { - newProxies[ClientGemini] = newReloadedClientProxy(ClientGemini, newCfg.Gemini.Mode, newCfg.Gemini.PinnedProvider, ps, durations.ReactivateAfter, durations.UpstreamIdleTimeout, durations.ResponseHeaderTimeout, cbCfg, routingRuntimeSettingsFromConfig(newCfg.Global.Routing), oldProxies[ClientGemini], r.telemetry) + newProxies[ClientGemini] = newReloadedClientProxy(ClientGemini, newCfg.Gemini.Mode, newCfg.Gemini.PinnedProvider, ps, durations.ReactivateAfter, durations.UpstreamIdleTimeout, durations.ResponseHeaderTimeout, cbCfg, routingRuntimeSettingsFromConfig(newCfg.Global.Routing), newCfg.Global.NormalizedUpstreamProxyMode(), newCfg.Global.NormalizedUpstreamProxyURL(), oldProxies[ClientGemini], r.telemetry) } r.reconcileTelemetryUsage(oldCfg, newCfg) @@ -644,8 +720,8 @@ func (r *Router) reloadProviderConfigsLocked() error { return nil } -func newReloadedClientProxy(clientType ClientType, mode config.ClientMode, pinnedProvider string, providers []config.Provider, reactivateAfter time.Duration, upstreamIdle time.Duration, responseHeaderTimeout time.Duration, cbCfg circuitBreakerConfig, routing routingRuntimeSettings, old *ClientProxy, telemetryStore *telemetry.Store) *ClientProxy { - cp := newClientProxy(clientType, mode, pinnedProvider, providers, reactivateAfter, upstreamIdle, responseHeaderTimeout, cbCfg, telemetryStore) +func newReloadedClientProxy(clientType ClientType, mode config.ClientMode, pinnedProvider string, providers []config.Provider, reactivateAfter time.Duration, upstreamIdle time.Duration, responseHeaderTimeout time.Duration, cbCfg circuitBreakerConfig, routing routingRuntimeSettings, globalProxyMode config.ProviderProxyMode, globalProxyURL string, old *ClientProxy, telemetryStore *telemetry.Store) *ClientProxy { + cp := newClientProxyWithGlobalProxy(clientType, mode, pinnedProvider, providers, reactivateAfter, upstreamIdle, responseHeaderTimeout, cbCfg, globalProxyMode, globalProxyURL, telemetryStore) cp.applyRoutingRuntimeSettings(routing) if old != nil { cp.inheritRuntimeState(old) @@ -676,7 +752,7 @@ func (cp *ClientProxy) inheritRuntimeState(old *ClientProxy) { if !ok { continue } - if !sameProviderRuntimeIdentity(cp.providers[newIdx], old.providers[oldIdx]) { + if !sameProviderRuntimeIdentity(cp.providers[newIdx], cp.providerProxyModes[newIdx], cp.providerProxyURLs[newIdx], old.providers[oldIdx], old.providerProxyModes[oldIdx], old.providerProxyURLs[oldIdx]) { continue } cp.deactivated[newIdx] = old.deactivated[oldIdx] @@ -691,7 +767,7 @@ func (cp *ClientProxy) inheritRuntimeState(old *ClientProxy) { if !ok { continue } - if !sameProviderRuntimeIdentity(cp.providers[newIdx], old.providers[oldIdx]) { + if !sameProviderRuntimeIdentity(cp.providers[newIdx], cp.providerProxyModes[newIdx], cp.providerProxyURLs[newIdx], old.providers[oldIdx], old.providerProxyModes[oldIdx], old.providerProxyURLs[oldIdx]) { continue } newByOldIndex[oldIdx] = newIdx @@ -830,8 +906,11 @@ func inheritStickyRuntimeState(dst *ClientProxy, src *ClientProxy, indexMap map[ } } -func sameProviderRuntimeIdentity(a, b config.Provider) bool { - return a.Name == b.Name && strings.TrimSpace(a.BaseURL) == strings.TrimSpace(b.BaseURL) +func sameProviderRuntimeIdentity(a config.Provider, aMode config.ProviderProxyMode, aURL string, b config.Provider, bMode config.ProviderProxyMode, bURL string) bool { + return a.Name == b.Name && + strings.TrimSpace(a.BaseURL) == strings.TrimSpace(b.BaseURL) && + aMode == bMode && + strings.TrimSpace(aURL) == strings.TrimSpace(bURL) } func providerIndexByName(providers []config.Provider, name string) int { diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index d6042a0..04505d6 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -150,6 +150,128 @@ func TestBuildTargetURL(t *testing.T) { } } +func TestNewClientProxy_UsesProviderSpecificProxyModes(t *testing.T) { + t.Setenv("HTTP_PROXY", "http://env-proxy:8080") + + cp := newClientProxy(ClientOpenAI, config.ClientModeAuto, "", []config.Provider{ + {Name: "inherit", BaseURL: "http://inherit.example", APIKey: "k1", Priority: 1}, + {Name: "direct", BaseURL: "http://direct.example", APIKey: "k2", ProxyMode: config.ProviderProxyModeDirect, Priority: 2}, + {Name: "custom", BaseURL: "http://custom.example", APIKey: "k3", ProxyMode: config.ProviderProxyModeCustom, ProxyURL: "http://custom-proxy:9090", Priority: 3}, + }, time.Hour, 0, testResponseHeaderTimeout, circuitBreakerConfig{}) + + req, err := http.NewRequest(http.MethodGet, "http://upstream.example/v1/test", nil) + if err != nil { + t.Fatalf("http.NewRequest: %v", err) + } + + inheritTransport, ok := cp.upstreamHTTPClient(0).Transport.(*http.Transport) + if !ok { + t.Fatalf("inherit transport type = %T", cp.upstreamHTTPClient(0).Transport) + } + inheritProxy, err := inheritTransport.Proxy(req) + if err != nil { + t.Fatalf("inherit proxy: %v", err) + } + if inheritProxy == nil || inheritProxy.String() != "http://env-proxy:8080" { + t.Fatalf("inherit proxy = %v, want http://env-proxy:8080", inheritProxy) + } + + directTransport, ok := cp.upstreamHTTPClient(1).Transport.(*http.Transport) + if !ok { + t.Fatalf("direct transport type = %T", cp.upstreamHTTPClient(1).Transport) + } + if directTransport.Proxy != nil { + directProxy, err := directTransport.Proxy(req) + if err != nil { + t.Fatalf("direct proxy: %v", err) + } + if directProxy != nil { + t.Fatalf("direct proxy = %v, want nil", directProxy) + } + } + + customTransport, ok := cp.upstreamHTTPClient(2).Transport.(*http.Transport) + if !ok { + t.Fatalf("custom transport type = %T", cp.upstreamHTTPClient(2).Transport) + } + customProxy, err := customTransport.Proxy(req) + if err != nil { + t.Fatalf("custom proxy: %v", err) + } + if customProxy == nil || customProxy.String() != "http://custom-proxy:9090" { + t.Fatalf("custom proxy = %v, want http://custom-proxy:9090", customProxy) + } +} + +func TestNewClientProxyWithGlobalProxy_UsesGlobalDefaultForInheritedProviders(t *testing.T) { + t.Parallel() + + req, err := http.NewRequest(http.MethodGet, "http://upstream.example/v1/test", nil) + if err != nil { + t.Fatalf("http.NewRequest: %v", err) + } + + directCP := newClientProxyWithGlobalProxy(ClientOpenAI, config.ClientModeAuto, "", []config.Provider{ + {Name: "inherit", BaseURL: "http://inherit.example", APIKey: "k1", Priority: 1}, + }, time.Hour, 0, testResponseHeaderTimeout, circuitBreakerConfig{}, config.ProviderProxyModeDirect, "") + + directTransport, ok := directCP.upstreamHTTPClient(0).Transport.(*http.Transport) + if !ok { + t.Fatalf("direct transport type = %T", directCP.upstreamHTTPClient(0).Transport) + } + if directTransport.Proxy != nil { + directProxy, err := directTransport.Proxy(req) + if err != nil { + t.Fatalf("direct proxy: %v", err) + } + if directProxy != nil { + t.Fatalf("direct proxy = %v, want nil", directProxy) + } + } + + customCP := newClientProxyWithGlobalProxy(ClientOpenAI, config.ClientModeAuto, "", []config.Provider{ + {Name: "inherit", BaseURL: "http://inherit.example", APIKey: "k1", Priority: 1}, + }, time.Hour, 0, testResponseHeaderTimeout, circuitBreakerConfig{}, config.ProviderProxyModeCustom, "http://global-proxy:8081") + + customTransport, ok := customCP.upstreamHTTPClient(0).Transport.(*http.Transport) + if !ok { + t.Fatalf("custom transport type = %T", customCP.upstreamHTTPClient(0).Transport) + } + customProxy, err := customTransport.Proxy(req) + if err != nil { + t.Fatalf("custom proxy: %v", err) + } + if customProxy == nil || customProxy.String() != "http://global-proxy:8081" { + t.Fatalf("custom proxy = %v, want http://global-proxy:8081", customProxy) + } +} + +func TestNewClientProxy_UsesCustomSocksProxy(t *testing.T) { + t.Parallel() + + cp := newClientProxy(ClientOpenAI, config.ClientModeAuto, "", []config.Provider{ + {Name: "custom-socks", BaseURL: "http://custom.example", APIKey: "k1", ProxyMode: config.ProviderProxyModeCustom, ProxyURL: "socks5://custom-proxy:1080", Priority: 1}, + }, time.Hour, 0, testResponseHeaderTimeout, circuitBreakerConfig{}) + + req, err := http.NewRequest(http.MethodGet, "https://upstream.example/v1/test", nil) + if err != nil { + t.Fatalf("http.NewRequest: %v", err) + } + + transport, ok := cp.upstreamHTTPClient(0).Transport.(*http.Transport) + if !ok { + t.Fatalf("transport type = %T", cp.upstreamHTTPClient(0).Transport) + } + + customProxy, err := transport.Proxy(req) + if err != nil { + t.Fatalf("custom proxy: %v", err) + } + if customProxy == nil || customProxy.String() != "socks5://custom-proxy:1080" { + t.Fatalf("custom proxy = %v, want socks5://custom-proxy:1080", customProxy) + } +} + func TestAddForwardedHeaders(t *testing.T) { t.Parallel() diff --git a/internal/proxy/reload_state_test.go b/internal/proxy/reload_state_test.go index b9773ab..15c5f79 100644 --- a/internal/proxy/reload_state_test.go +++ b/internal/proxy/reload_state_test.go @@ -475,6 +475,109 @@ func TestReloadProviderConfigsLocked_ReconcilesTelemetryFromYAMLChanges(t *testi } } +func TestReloadProviderConfigsLocked_DoesNotPreserveSuppressionStateWhenProxyChanges(t *testing.T) { + router, dir := newReloadTestRouter(t) + oldProxy := router.proxies[ClientOpenAI] + now := time.Now() + + oldProxy.deactivated[0] = providerDeactivation{ + at: now.Add(-time.Second), + until: now.Add(30 * time.Second), + reason: "rate_limit", + status: http.StatusTooManyRequests, + message: "slow down", + } + oldProxy.keyDeactivated[0][0] = providerDeactivation{ + at: now.Add(-time.Second), + until: now.Add(20 * time.Second), + reason: "rate_limit", + status: http.StatusTooManyRequests, + message: "key cooldown", + } + oldProxy.breakers[0].state = circuitOpen + oldProxy.breakers[0].openedAt = now.Add(-5 * time.Second) + + global := config.DefaultGlobalConfig() + global.ListenAddr = "127.0.0.1" + global.Port = 3333 + writeProxyReloadFixture(t, dir, global, config.ClientConfig{ + Mode: config.ClientModeAuto, + Providers: []config.Provider{ + { + Name: "p1", + BaseURL: "https://p1.example", + APIKey: "k1", + ProxyMode: config.ProviderProxyModeDirect, + Priority: 1, + }, + }, + }) + + if err := router.reloadProviderConfigsLocked(); err != nil { + t.Fatalf("reloadProviderConfigsLocked: %v", err) + } + + newProxy := router.proxies[ClientOpenAI] + if !newProxy.deactivated[0].until.IsZero() || newProxy.deactivated[0].reason != "" { + t.Fatalf("provider cooldown should not carry across proxy change: %#v", newProxy.deactivated[0]) + } + if !newProxy.keyDeactivated[0][0].until.IsZero() || newProxy.keyDeactivated[0][0].reason != "" { + t.Fatalf("key cooldown should not carry across proxy change: %#v", newProxy.keyDeactivated[0][0]) + } + if newProxy.breakers[0].state != circuitClosed { + t.Fatalf("breaker state = %s, want closed", newProxy.breakers[0].state) + } +} + +func TestReloadProviderConfigsLocked_DoesNotPreserveSuppressionStateWhenGlobalProxyChangesForInheritedProvider(t *testing.T) { + router, dir := newReloadTestRouter(t) + oldProxy := router.proxies[ClientOpenAI] + now := time.Now() + + oldProxy.deactivated[0] = providerDeactivation{ + at: now.Add(-time.Second), + until: now.Add(30 * time.Second), + reason: "rate_limit", + status: http.StatusTooManyRequests, + message: "slow down", + } + oldProxy.keyDeactivated[0][0] = providerDeactivation{ + at: now.Add(-time.Second), + until: now.Add(20 * time.Second), + reason: "rate_limit", + status: http.StatusTooManyRequests, + message: "key cooldown", + } + oldProxy.breakers[0].state = circuitOpen + oldProxy.breakers[0].openedAt = now.Add(-5 * time.Second) + + global := config.DefaultGlobalConfig() + global.ListenAddr = "127.0.0.1" + global.Port = 3333 + global.UpstreamProxyMode = config.ProviderProxyModeDirect + writeProxyReloadFixture(t, dir, global, config.ClientConfig{ + Mode: config.ClientModeAuto, + Providers: []config.Provider{ + {Name: "p1", BaseURL: "https://p1.example", APIKey: "k1", Priority: 1}, + }, + }) + + if err := router.reloadProviderConfigsLocked(); err != nil { + t.Fatalf("reloadProviderConfigsLocked: %v", err) + } + + newProxy := router.proxies[ClientOpenAI] + if !newProxy.deactivated[0].until.IsZero() || newProxy.deactivated[0].reason != "" { + t.Fatalf("provider cooldown should not carry across global proxy change: %#v", newProxy.deactivated[0]) + } + if !newProxy.keyDeactivated[0][0].until.IsZero() || newProxy.keyDeactivated[0][0].reason != "" { + t.Fatalf("key cooldown should not carry across global proxy change: %#v", newProxy.keyDeactivated[0][0]) + } + if newProxy.breakers[0].state != circuitClosed { + t.Fatalf("breaker state = %s, want closed", newProxy.breakers[0].state) + } +} + func TestTimeUntilNextAvailable_PicksEarliestBlockedSource(t *testing.T) { cbCfg := circuitBreakerConfig{ enabled: true, diff --git a/internal/web/api.go b/internal/web/api.go index 6f00c3d..fd66c6e 100644 --- a/internal/web/api.go +++ b/internal/web/api.go @@ -111,6 +111,14 @@ func (a *API) HandleUpdateGlobalConfig(w http.ResponseWriter, r *http.Request) { if strings.TrimSpace(req.ResponseHeaderTimeout) != "" { cfg.Global.ResponseHeaderTimeout = req.ResponseHeaderTimeout } + if strings.TrimSpace(req.UpstreamProxyMode) != "" { + cfg.Global.UpstreamProxyMode = config.ProviderProxyMode(strings.ToLower(strings.TrimSpace(req.UpstreamProxyMode))) + if cfg.Global.UpstreamProxyMode == config.ProviderProxyModeCustom { + cfg.Global.UpstreamProxyURL = strings.TrimSpace(req.UpstreamProxyURL) + } else { + cfg.Global.UpstreamProxyURL = "" + } + } cfg.Global.MaxRequestBody = req.MaxRequestBodyBytes cfg.Global.LogDir = req.LogDir cfg.Global.LogRetentionDays = req.LogRetentionDays @@ -266,6 +274,7 @@ func (a *API) HandleAddProvider(w http.ResponseWriter, r *http.Request) { return } req.Overrides = normalizeProviderOverrideRequest(req.Overrides) + normalizeProviderProxyRequest(&req) if err := validateProviderOverrideRequest(clientType, req.Overrides); err != nil { writeError(w, err.Error(), http.StatusBadRequest) return @@ -332,6 +341,10 @@ func (a *API) HandleAddProvider(w http.ResponseWriter, r *http.Request) { Enabled: req.Enabled, } applyProviderOverrides(&provider, req) + if err := applyProviderProxySettings(&provider, req, false); err != nil { + writeError(w, err.Error(), http.StatusBadRequest) + return + } assignProviderKeys(&provider, keys) cc.Providers = append(cc.Providers, provider) @@ -362,6 +375,7 @@ func (a *API) HandleUpdateProvider(w http.ResponseWriter, r *http.Request) { return } req.Overrides = normalizeProviderOverrideRequest(req.Overrides) + normalizeProviderProxyRequest(&req) if err := validateProviderOverrideRequest(clientType, req.Overrides); err != nil { writeError(w, err.Error(), http.StatusBadRequest) return @@ -416,7 +430,11 @@ func (a *API) HandleUpdateProvider(w http.ResponseWriter, r *http.Request) { } } - updated := updateProviderInList(cc.Providers, providerName, req, keys) + updated, err := updateProviderInList(cc.Providers, providerName, req, keys) + if err != nil { + writeError(w, err.Error(), http.StatusBadRequest) + return + } if updated { if !a.saveClientConfigOrWriteError(w, clientType, cfg) { return @@ -963,6 +981,14 @@ func normalizeProviderKeys(req ProviderRequest) ([]string, error) { return keys, nil } +func lowerTrimStringPtr(v *string) *string { + if v == nil { + return nil + } + trimmed := strings.ToLower(strings.TrimSpace(*v)) + return &trimmed +} + func trimStringPtr(v *string) *string { if v == nil { return nil @@ -998,6 +1024,14 @@ func providerOverrideSupportForClient(clientType string) providerOverrideSupport return support } +func normalizeProviderProxyRequest(req *ProviderRequest) { + if req == nil { + return + } + req.ProxyMode = lowerTrimStringPtr(req.ProxyMode) + req.ProxyURL = trimStringPtr(req.ProxyURL) +} + func normalizeProviderOverrideRequest(overrides *ProviderOverridesRequest) *ProviderOverridesRequest { if overrides == nil { return nil @@ -1047,6 +1081,68 @@ func validateProviderOverrideRequest(clientType string, overrides *ProviderOverr return nil } +func validateProviderProxyModeValue(mode string) error { + switch config.ProviderProxyMode(strings.TrimSpace(mode)) { + case config.ProviderProxyModeInherit, config.ProviderProxyModeDirect, config.ProviderProxyModeCustom: + return nil + default: + return fmt.Errorf("proxy_mode must be one of inherit, direct, custom") + } +} + +func validateProviderProxyURLValue(raw string) error { + return config.ValidateProxyURL(raw) +} + +func applyProviderProxySettings(provider *config.Provider, req ProviderRequest, isUpdate bool) error { + if provider == nil { + return nil + } + if !isUpdate && req.ProxyMode == nil && req.ProxyURL == nil { + provider.ProxyMode = config.ProviderProxyModeInherit + provider.ProxyURL = "" + return nil + } + if req.ProxyMode == nil { + if req.ProxyURL != nil { + return fmt.Errorf("proxy_url requires proxy_mode") + } + return nil + } + + modeValue := strings.TrimSpace(*req.ProxyMode) + if err := validateProviderProxyModeValue(modeValue); err != nil { + return err + } + + mode := config.ProviderProxyMode(modeValue) + proxyURL := provider.NormalizedProxyURL() + if req.ProxyURL != nil { + proxyURL = strings.TrimSpace(*req.ProxyURL) + } + switch mode { + case config.ProviderProxyModeInherit, config.ProviderProxyModeDirect: + if proxyURL != "" && req.ProxyURL != nil { + return fmt.Errorf("proxy_url requires proxy_mode custom") + } + provider.ProxyMode = mode + provider.ProxyURL = "" + return nil + case config.ProviderProxyModeCustom: + if proxyURL == "" { + return fmt.Errorf("proxy_url is required when proxy_mode=custom") + } + if err := validateProviderProxyURLValue(proxyURL); err != nil { + return err + } + provider.ProxyMode = mode + provider.ProxyURL = proxyURL + return nil + default: + return fmt.Errorf("proxy_mode must be one of inherit, direct, custom") + } +} + func applyProviderOverrides(provider *config.Provider, req ProviderRequest) { if provider == nil || req.Overrides == nil { return @@ -1107,7 +1203,7 @@ func extractClientAndProvider(path string) (string, string) { return "", "" } -func updateProviderInList(providers []config.Provider, name string, req ProviderRequest, keys []string) bool { +func updateProviderInList(providers []config.Provider, name string, req ProviderRequest, keys []string) (bool, error) { for i := range providers { if providers[i].Name == name { if req.Name != "" { @@ -1126,10 +1222,13 @@ func updateProviderInList(providers []config.Provider, name string, req Provider providers[i].Enabled = req.Enabled } applyProviderOverrides(&providers[i], req) - return true + if err := applyProviderProxySettings(&providers[i], req, true); err != nil { + return false, err + } + return true, nil } } - return false + return false, nil } func deleteProviderFromList(providers []config.Provider, name string) ([]config.Provider, bool) { diff --git a/internal/web/api_test.go b/internal/web/api_test.go index 9448b0b..3dcfdbb 100644 --- a/internal/web/api_test.go +++ b/internal/web/api_test.go @@ -77,6 +77,8 @@ providers: - name: p1 base_url: https://example.com api_key: secret + proxy_mode: custom + proxy_url: http://proxy.internal:7890 model: gpt-5.4 reasoning_effort: high priority: 2 @@ -117,6 +119,9 @@ providers: if _, ok := got[0]["base_url"]; !ok { t.Fatalf("expected base_url in provider listing, got keys=%v", keys(got[0])) } + if _, ok := got[0]["proxy_url"]; ok { + t.Fatalf("did not expect proxy_url in provider listing") + } if got[0]["key_count"] != float64(1) { t.Fatalf("expected key_count=1, got %v", got[0]["key_count"]) } @@ -130,6 +135,12 @@ providers: if first == nil { t.Fatalf("expected provider p1 in listing, got %#v", got) } + if first["proxy_mode"] != "custom" { + t.Fatalf("expected proxy_mode=custom, got %v", first["proxy_mode"]) + } + if first["proxy_url_hint"] != "http://proxy.internal:7890" { + t.Fatalf("expected proxy_url_hint redaction, got %v", first["proxy_url_hint"]) + } overrides, ok := first["overrides"].(map[string]any) if !ok { t.Fatalf("expected overrides object in listing, got %T", first["overrides"]) @@ -324,15 +335,20 @@ providers: } func TestHandleUpdateGlobalConfig_AcceptsSnakeCaseNotifications(t *testing.T) { - dir := t.TempDir() - api := NewAPI(dir, "test", nil) + for _, proxyURL := range []string{"http://127.0.0.1:7890", "socks5://127.0.0.1:1080"} { + t.Run(proxyURL, func(t *testing.T) { + dir := t.TempDir() + api := NewAPI(dir, "test", nil) - body := []byte(`{ + body := []byte(`{ "listen_addr": "127.0.0.1", "port": 3333, "log_level": "info", "reactivate_after": "10m", "upstream_idle_timeout": "1m", + "response_header_timeout": "30s", + "upstream_proxy_mode": "custom", + "upstream_proxy_url": "` + proxyURL + `", "max_request_body_bytes": 12345, "log_dir": "", "log_retention_days": 7, @@ -361,41 +377,49 @@ func TestHandleUpdateGlobalConfig_AcceptsSnakeCaseNotifications(t *testing.T) { } }`) - req := httptest.NewRequest(http.MethodPut, "/api/config/global/update", bytes.NewReader(body)) - w := httptest.NewRecorder() - api.HandleUpdateGlobalConfig(w, req) - res := w.Result() - if res.StatusCode != http.StatusOK { - t.Fatalf("status=%d body=%s", res.StatusCode, w.Body.String()) - } + req := httptest.NewRequest(http.MethodPut, "/api/config/global/update", bytes.NewReader(body)) + w := httptest.NewRecorder() + api.HandleUpdateGlobalConfig(w, req) + res := w.Result() + if res.StatusCode != http.StatusOK { + t.Fatalf("status=%d body=%s", res.StatusCode, w.Body.String()) + } - cfg, err := config.Load(dir) - if err != nil { - t.Fatalf("reload config: %v", err) - } - if !cfg.Global.Notifications.Enabled { - t.Fatalf("expected notifications.enabled=true") - } - if cfg.Global.Notifications.MinLevel != config.LogLevelWarn { - t.Fatalf("expected notifications.min_level=warn, got %q", cfg.Global.Notifications.MinLevel) - } - if cfg.Global.Notifications.ProviderSwitch == nil || *cfg.Global.Notifications.ProviderSwitch { - t.Fatalf("expected notifications.provider_switch=false, got %v", cfg.Global.Notifications.ProviderSwitch) - } - if !cfg.Global.Routing.StickySessions.Enabled { - t.Fatalf("expected routing.sticky_sessions.enabled=true") - } - if cfg.Global.Routing.StickySessions.ExplicitTTL != "45m" { - t.Fatalf("expected routing.sticky_sessions.explicit_ttl=45m, got %q", cfg.Global.Routing.StickySessions.ExplicitTTL) - } - if !cfg.Global.Routing.BusyBackpressure.Enabled { - t.Fatalf("expected routing.busy_backpressure.enabled=true") - } - if cfg.Global.Routing.BusyBackpressure.ShortRetryAfterMax != "5s" { - t.Fatalf("expected routing.busy_backpressure.short_retry_after_max=5s, got %q", cfg.Global.Routing.BusyBackpressure.ShortRetryAfterMax) - } - if cfg.Global.Routing.BusyBackpressure.MaxInlineWait != "12s" { - t.Fatalf("expected routing.busy_backpressure.max_inline_wait=12s, got %q", cfg.Global.Routing.BusyBackpressure.MaxInlineWait) + cfg, err := config.Load(dir) + if err != nil { + t.Fatalf("reload config: %v", err) + } + if !cfg.Global.Notifications.Enabled { + t.Fatalf("expected notifications.enabled=true") + } + if cfg.Global.Notifications.MinLevel != config.LogLevelWarn { + t.Fatalf("expected notifications.min_level=warn, got %q", cfg.Global.Notifications.MinLevel) + } + if cfg.Global.Notifications.ProviderSwitch == nil || *cfg.Global.Notifications.ProviderSwitch { + t.Fatalf("expected notifications.provider_switch=false, got %v", cfg.Global.Notifications.ProviderSwitch) + } + if !cfg.Global.Routing.StickySessions.Enabled { + t.Fatalf("expected routing.sticky_sessions.enabled=true") + } + if cfg.Global.Routing.StickySessions.ExplicitTTL != "45m" { + t.Fatalf("expected routing.sticky_sessions.explicit_ttl=45m, got %q", cfg.Global.Routing.StickySessions.ExplicitTTL) + } + if !cfg.Global.Routing.BusyBackpressure.Enabled { + t.Fatalf("expected routing.busy_backpressure.enabled=true") + } + if cfg.Global.Routing.BusyBackpressure.ShortRetryAfterMax != "5s" { + t.Fatalf("expected routing.busy_backpressure.short_retry_after_max=5s, got %q", cfg.Global.Routing.BusyBackpressure.ShortRetryAfterMax) + } + if cfg.Global.Routing.BusyBackpressure.MaxInlineWait != "12s" { + t.Fatalf("expected routing.busy_backpressure.max_inline_wait=12s, got %q", cfg.Global.Routing.BusyBackpressure.MaxInlineWait) + } + if cfg.Global.NormalizedUpstreamProxyMode() != config.ProviderProxyModeCustom { + t.Fatalf("expected upstream_proxy_mode=custom, got %q", cfg.Global.NormalizedUpstreamProxyMode()) + } + if cfg.Global.NormalizedUpstreamProxyURL() != proxyURL { + t.Fatalf("expected upstream_proxy_url to be saved, got %q", cfg.Global.NormalizedUpstreamProxyURL()) + } + }) } } @@ -418,6 +442,8 @@ func TestHandleUpdateGlobalConfig_AllowsClearingRoutingStrings(t *testing.T) { "reactivate_after": "10m", "upstream_idle_timeout": "1m", "response_header_timeout": "30s", + "upstream_proxy_mode": "direct", + "upstream_proxy_url": "", "max_request_body_bytes": 12345, "log_dir": "", "log_retention_days": 7, @@ -467,6 +493,12 @@ func TestHandleUpdateGlobalConfig_AllowsClearingRoutingStrings(t *testing.T) { if cfg.Global.Routing.BusyBackpressure.MaxInlineWait != "" { t.Fatalf("expected routing.busy_backpressure.max_inline_wait to be cleared, got %q", cfg.Global.Routing.BusyBackpressure.MaxInlineWait) } + if cfg.Global.NormalizedUpstreamProxyMode() != config.ProviderProxyModeDirect { + t.Fatalf("expected upstream_proxy_mode to be direct, got %q", cfg.Global.NormalizedUpstreamProxyMode()) + } + if cfg.Global.NormalizedUpstreamProxyURL() != "" { + t.Fatalf("expected upstream_proxy_url to be cleared, got %q", cfg.Global.NormalizedUpstreamProxyURL()) + } } func TestHandleAddProvider_AcceptsAPIKeys(t *testing.T) { @@ -504,6 +536,9 @@ func TestHandleAddProvider_AcceptsAPIKeys(t *testing.T) { if cfg.OpenAI.Providers[0].APIKey != "" { t.Fatalf("expected multi-key provider to be persisted via api_keys") } + if got := cfg.OpenAI.Providers[0].NormalizedProxyMode(); got != config.ProviderProxyModeInherit { + t.Fatalf("proxy_mode = %q, want %q", got, config.ProviderProxyModeInherit) + } if got := cfg.OpenAI.Providers[0].ModelOverride(); got != "gpt-5.4" { t.Fatalf("model = %q", got) } @@ -584,6 +619,197 @@ providers: } } +func TestHandleAddProvider_AcceptsCustomProxy(t *testing.T) { + for _, proxyURL := range []string{"http://127.0.0.1:7890", "socks5://127.0.0.1:1080"} { + t.Run(proxyURL, func(t *testing.T) { + dir := t.TempDir() + api := NewAPI(dir, "test", nil) + + body := []byte(`{ + "name": "p1", + "base_url": "https://example.com", + "api_key": "key1", + "proxy_mode": "custom", + "proxy_url": "` + proxyURL + `", + "priority": 1, + "enabled": true +}`) + + req := httptest.NewRequest(http.MethodPost, "/api/providers/codex", bytes.NewReader(body)) + w := httptest.NewRecorder() + api.HandleAddProvider(w, req) + if w.Result().StatusCode != http.StatusOK { + t.Fatalf("status=%d body=%s", w.Result().StatusCode, w.Body.String()) + } + + cfg, err := config.Load(dir) + if err != nil { + t.Fatalf("load config: %v", err) + } + if got := cfg.OpenAI.Providers[0].NormalizedProxyMode(); got != config.ProviderProxyModeCustom { + t.Fatalf("proxy_mode = %q", got) + } + if got := cfg.OpenAI.Providers[0].NormalizedProxyURL(); got != proxyURL { + t.Fatalf("proxy_url = %q", got) + } + }) + } +} + +func TestHandleAddProvider_RejectsProxyURLWithoutCustomMode(t *testing.T) { + for _, mode := range []string{"direct", "inherit"} { + t.Run(mode, func(t *testing.T) { + dir := t.TempDir() + api := NewAPI(dir, "test", nil) + + body := []byte(`{ + "name": "p1", + "base_url": "https://example.com", + "api_key": "key1", + "proxy_mode": "` + mode + `", + "proxy_url": "http://127.0.0.1:7890", + "priority": 1, + "enabled": true +}`) + + req := httptest.NewRequest(http.MethodPost, "/api/providers/codex", bytes.NewReader(body)) + w := httptest.NewRecorder() + api.HandleAddProvider(w, req) + if w.Result().StatusCode != http.StatusBadRequest { + t.Fatalf("status=%d body=%s", w.Result().StatusCode, w.Body.String()) + } + + got := testutil.DecodeJSONMap(t, w.Body.Bytes()) + if got["error"] != "proxy_url requires proxy_mode custom" { + t.Fatalf("error = %#v", got["error"]) + } + + cfg, err := config.Load(dir) + if err != nil { + t.Fatalf("load config: %v", err) + } + if len(cfg.OpenAI.Providers) != 0 { + t.Fatalf("providers len = %d, want 0", len(cfg.OpenAI.Providers)) + } + }) + } +} + +func TestHandleUpdateProvider_ProxySettings(t *testing.T) { + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, "openai.yaml"), []byte(` +providers: + - name: p1 + base_url: https://example.com + api_key: key1 + proxy_mode: custom + proxy_url: http://127.0.0.1:7890 + priority: 1 +`), 0o600); err != nil { + t.Fatal(err) + } + + api := NewAPI(dir, "test", nil) + + t.Run("retain existing custom proxy url", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPut, "/api/providers/codex/p1", bytes.NewReader([]byte(`{ + "proxy_mode": "custom" +}`))) + w := httptest.NewRecorder() + api.HandleUpdateProvider(w, req) + if w.Result().StatusCode != http.StatusOK { + t.Fatalf("status=%d body=%s", w.Result().StatusCode, w.Body.String()) + } + + cfg, err := config.Load(dir) + if err != nil { + t.Fatalf("load config: %v", err) + } + if got := cfg.OpenAI.Providers[0].NormalizedProxyURL(); got != "http://127.0.0.1:7890" { + t.Fatalf("proxy_url = %q", got) + } + }) + + t.Run("switch to direct clears proxy url", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPut, "/api/providers/codex/p1", bytes.NewReader([]byte(`{ + "proxy_mode": "direct" +}`))) + w := httptest.NewRecorder() + api.HandleUpdateProvider(w, req) + if w.Result().StatusCode != http.StatusOK { + t.Fatalf("status=%d body=%s", w.Result().StatusCode, w.Body.String()) + } + + cfg, err := config.Load(dir) + if err != nil { + t.Fatalf("load config: %v", err) + } + if got := cfg.OpenAI.Providers[0].NormalizedProxyMode(); got != config.ProviderProxyModeDirect { + t.Fatalf("proxy_mode = %q", got) + } + if got := cfg.OpenAI.Providers[0].NormalizedProxyURL(); got != "" { + t.Fatalf("proxy_url = %q, want empty", got) + } + }) + + t.Run("reject proxy url without mode", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPut, "/api/providers/codex/p1", bytes.NewReader([]byte(`{ + "proxy_url": "http://127.0.0.1:8899" +}`))) + w := httptest.NewRecorder() + api.HandleUpdateProvider(w, req) + if w.Result().StatusCode != http.StatusBadRequest { + t.Fatalf("status=%d body=%s", w.Result().StatusCode, w.Body.String()) + } + }) + + t.Run("reject proxy url without custom mode", func(t *testing.T) { + for _, mode := range []string{"direct", "inherit"} { + t.Run(mode, func(t *testing.T) { + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, "openai.yaml"), []byte(` +providers: + - name: p1 + base_url: https://example.com + api_key: key1 + proxy_mode: custom + proxy_url: http://127.0.0.1:7890 + priority: 1 +`), 0o600); err != nil { + t.Fatal(err) + } + + api := NewAPI(dir, "test", nil) + req := httptest.NewRequest(http.MethodPut, "/api/providers/codex/p1", bytes.NewReader([]byte(`{ + "proxy_mode": "`+mode+`", + "proxy_url": "http://127.0.0.1:8899" +}`))) + w := httptest.NewRecorder() + api.HandleUpdateProvider(w, req) + if w.Result().StatusCode != http.StatusBadRequest { + t.Fatalf("status=%d body=%s", w.Result().StatusCode, w.Body.String()) + } + + got := testutil.DecodeJSONMap(t, w.Body.Bytes()) + if got["error"] != "proxy_url requires proxy_mode custom" { + t.Fatalf("error = %#v", got["error"]) + } + + cfg, err := config.Load(dir) + if err != nil { + t.Fatalf("load config: %v", err) + } + if got := cfg.OpenAI.Providers[0].NormalizedProxyMode(); got != config.ProviderProxyModeCustom { + t.Fatalf("proxy_mode = %q, want %q", got, config.ProviderProxyModeCustom) + } + if got := cfg.OpenAI.Providers[0].NormalizedProxyURL(); got != "http://127.0.0.1:7890" { + t.Fatalf("proxy_url = %q", got) + } + }) + } + }) +} + func TestHandleGetClientConfig_ReturnsConfiguredModeAndPin(t *testing.T) { dir := t.TempDir() if err := os.WriteFile(filepath.Join(dir, "openai.yaml"), []byte(` diff --git a/internal/web/static/app.js b/internal/web/static/app.js index f179ed9..cb0a86c 100644 --- a/internal/web/static/app.js +++ b/internal/web/static/app.js @@ -66,6 +66,7 @@ function app() { empty: 'No providers configured for {client}', pinBadge: 'Pinned', baseUrl: 'Base URL', + proxy: 'Proxy', apiKeys: 'API Keys', usageTotal: 'Usage', usageInOut: 'Input / Output', @@ -104,7 +105,9 @@ function app() { deleteConfirm: 'Are you sure you want to delete provider "{name}"?', deletedTitle: 'Deleted provider {name}', deletedMessage: 'It has been removed from {client}\'s provider list.', - clientTypeLabel: 'Client Type' + clientTypeLabel: 'Client Type', + proxyDirect: 'Direct', + proxyCustom: 'Custom' }, modal: { provider: { @@ -114,6 +117,14 @@ function app() { name: 'Name *', nameHint: 'Letters, numbers, dot (.), underscore (_), and hyphen (-).', baseUrl: 'Base URL *', + proxyMode: 'Proxy Mode', + proxyModeInherit: 'Inherit Environment', + proxyModeDirect: 'Direct', + proxyModeCustom: 'Custom Proxy', + proxyUrl: 'Proxy URL', + proxyUrlHint: 'http://127.0.0.1:7890', + proxyUrlHelp: 'Supports http://, https://, socks5://, and socks5h:// proxy URLs.', + keepExistingProxy: 'Leave empty to keep the current proxy ({proxy}).', model: 'Model', modelHint: 'model-id', reasoningEffort: 'Reasoning Effort', @@ -154,6 +165,14 @@ function app() { upstreamIdleTimeoutHint: 'Set to 0 to disable stalled-stream protection.', responseHeaderTimeout: 'Response Header Timeout', responseHeaderTimeoutHint: 'Set to 0 to wait indefinitely for headers.', + upstreamProxyMode: 'Default Upstream Proxy', + upstreamProxyUrl: 'Default Proxy URL', + upstreamProxyHint: 'Used by providers whose proxy mode is Inherit.', + upstreamProxyUrlHelp: 'Supports http://, https://, socks5://, and socks5h:// proxy URLs.', + proxyModeInherit: 'Inherit Environment', + proxyModeDirect: 'Direct', + proxyModeCustom: 'Custom Proxy', + upstreamProxyUrlHint: 'http://127.0.0.1:7890', failureThreshold: 'Failure Threshold', failureThresholdHint: '0 disables the circuit breaker.', successThreshold: 'Success Threshold', @@ -356,6 +375,7 @@ function app() { empty: '{client} 还没有配置任何 Provider', pinBadge: '已固定', baseUrl: 'Base URL', + proxy: '代理', apiKeys: 'API Keys', usageTotal: '用量', usageInOut: '输入 / 输出', @@ -394,7 +414,9 @@ function app() { deletedTitle: '已删除 Provider {name}', deletedMessage: '它已从 {client} 的 Provider 列表中移除。', clientTypeLabel: '客户端类型', - dragToReorder: '拖拽调整优先级' + dragToReorder: '拖拽调整优先级', + proxyDirect: '直连', + proxyCustom: '自定义' }, modal: { provider: { @@ -404,6 +426,14 @@ function app() { name: '名称 *', nameHint: '允许字母、数字、点号 (.)、下划线 (_) 和连字符 (-)。', baseUrl: 'Base URL *', + proxyMode: '代理模式', + proxyModeInherit: '继承环境变量', + proxyModeDirect: '直连', + proxyModeCustom: '自定义代理', + proxyUrl: '代理 URL', + proxyUrlHint: 'http://127.0.0.1:7890', + proxyUrlHelp: '支持 http://、https://、socks5:// 和 socks5h:// 代理 URL。', + keepExistingProxy: '留空则保留当前代理({proxy})。', model: '模型', modelHint: 'model-id', reasoningEffort: '思考强度', @@ -443,6 +473,14 @@ function app() { upstreamIdleTimeoutHint: '设为 0 可关闭流式响应停滞保护。', responseHeaderTimeout: '响应头超时', responseHeaderTimeoutHint: '设为 0 表示无限等待响应头。', + upstreamProxyMode: '默认上游代理', + upstreamProxyUrl: '默认代理 URL', + upstreamProxyHint: '对代理模式为“继承”的 Provider 生效。', + upstreamProxyUrlHelp: '支持 http://、https://、socks5:// 和 socks5h:// 代理 URL。', + proxyModeInherit: '继承环境变量', + proxyModeDirect: '直连', + proxyModeCustom: '自定义代理', + upstreamProxyUrlHint: 'http://127.0.0.1:7890', failureThreshold: '失败阈值', failureThresholdHint: '设为 0 可关闭熔断器。', successThreshold: '成功阈值', @@ -612,6 +650,8 @@ function app() { reactivate_after: '', upstream_idle_timeout: '', response_header_timeout: '', + upstream_proxy_mode: 'inherit', + upstream_proxy_url: '', max_request_body_bytes: 0, log_dir: '', log_retention_days: 7, @@ -672,6 +712,9 @@ function app() { providerForm: { name: '', base_url: '', + proxy_mode: 'inherit', + proxy_url: '', + proxy_url_hint: '', model: '', reasoning_effort: '', thinking_budget_tokens: 0, @@ -1472,6 +1515,34 @@ function app() { return this.tf('modal.provider.keepExistingKeys', { count }); }, + normalizeProviderProxyMode(mode) { + const value = String(mode || '').trim().toLowerCase(); + return ['inherit', 'direct', 'custom'].includes(value) ? value : 'inherit'; + }, + + providerProxySummary(provider) { + const mode = this.normalizeProviderProxyMode(provider && provider.proxy_mode); + if (mode === 'direct') { + return this.t('providers.proxyDirect'); + } + if (mode === 'custom') { + return String((provider && provider.proxy_url_hint) || '').trim() || this.t('providers.proxyCustom'); + } + return ''; + }, + + providerFormUsesCustomProxy() { + return this.normalizeProviderProxyMode(this.providerForm.proxy_mode) === 'custom'; + }, + + providerEditProxyHint() { + const proxy = String(this.providerForm.proxy_url_hint || '').trim(); + if (!proxy) { + return ''; + } + return this.tf('modal.provider.keepExistingProxy', { proxy }); + }, + providerOverrideSupport() { const support = this.clientConfig && this.clientConfig.override_support; if (!support || typeof support !== 'object') { @@ -1877,9 +1948,16 @@ function app() { const payload = { name: this.providerForm.name, base_url: this.providerForm.base_url, + proxy_mode: this.normalizeProviderProxyMode(this.providerForm.proxy_mode), priority: this.providerForm.priority, enabled: this.providerForm.enabled }; + if (payload.proxy_mode === 'custom') { + const proxyURL = String(this.providerForm.proxy_url || '').trim(); + if (proxyURL) { + payload.proxy_url = proxyURL; + } + } const overrides = {}; if (this.providerSupportsModelOverride()) { overrides.model = String(this.providerForm.model || ''); @@ -1948,6 +2026,9 @@ function app() { this.providerForm = { name: provider.name, base_url: provider.base_url, + proxy_mode: this.normalizeProviderProxyMode(provider.proxy_mode), + proxy_url: '', + proxy_url_hint: String(provider.proxy_url_hint || ''), model: String((provider.overrides && provider.overrides.model) || ''), reasoning_effort: String((provider.overrides && provider.overrides.openai && provider.overrides.openai.reasoning_effort) || ''), thinking_budget_tokens: Number((provider.overrides && provider.overrides.claude && provider.overrides.claude.thinking_budget_tokens) || 0), @@ -1990,6 +2071,9 @@ function app() { this.providerForm = { name: '', base_url: '', + proxy_mode: 'inherit', + proxy_url: '', + proxy_url_hint: '', model: '', reasoning_effort: '', thinking_budget_tokens: 0, @@ -2145,6 +2229,9 @@ function app() { this.providerForm = { name: '', base_url: '', + proxy_mode: 'inherit', + proxy_url: '', + proxy_url_hint: '', model: '', reasoning_effort: '', thinking_budget_tokens: 0, diff --git a/internal/web/static/app.test.js b/internal/web/static/app.test.js index 711602e..379441c 100644 --- a/internal/web/static/app.test.js +++ b/internal/web/static/app.test.js @@ -141,6 +141,7 @@ test('saveProvider includes OpenAI override fields in payload', async () => { assert.deepEqual(calls[0].options, { name: 'openai-primary', base_url: 'https://example.com', + proxy_mode: 'inherit', priority: 1, enabled: true, overrides: { @@ -192,6 +193,7 @@ test('saveProvider includes Claude thinking budget override in payload', async ( assert.deepEqual(calls[0].options, { name: 'claude-primary', base_url: 'https://example.com', + proxy_mode: 'inherit', priority: 2, enabled: false, overrides: { @@ -287,6 +289,7 @@ test('openAddProviderModal sets next priority for provider form', () => { assert.equal(state.showAddProviderModal, true); assert.equal(state.providerForm.priority, 4); + assert.equal(state.providerForm.proxy_mode, 'inherit'); }); test('editProvider hydrates override fields directly into the form', () => { @@ -295,6 +298,8 @@ test('editProvider hydrates override fields directly into the form', () => { state.editProvider({ name: 'openai-primary', base_url: 'https://example.com', + proxy_mode: 'custom', + proxy_url_hint: 'http://127.0.0.1:7890', overrides: { model: 'gpt-5.4', openai: { @@ -307,11 +312,54 @@ test('editProvider hydrates override fields directly into the form', () => { }); assert.equal(state.showEditProviderModal, true); + assert.equal(state.providerForm.proxy_mode, 'custom'); + assert.equal(state.providerForm.proxy_url, ''); + assert.equal(state.providerForm.proxy_url_hint, 'http://127.0.0.1:7890'); assert.equal(state.providerForm.model, 'gpt-5.4'); assert.equal(state.providerForm.reasoning_effort, 'high'); assert.equal(state.providerForm.thinking_budget_tokens, 0); }); +test('saveProvider includes custom proxy settings when configured', async () => { + const state = loadApp(); + const calls = []; + state.selectedClient = 'openai'; + state.providerForm = { + name: 'openai-proxy', + base_url: 'https://example.com', + proxy_mode: 'custom', + proxy_url: 'http://127.0.0.1:7890', + proxy_url_hint: '', + model: '', + reasoning_effort: '', + thinking_budget_tokens: 0, + api_keys_text: 'key-1', + priority: 1, + enabled: true + }; + state.apiCall = async (url, options) => { + calls.push({ url, options: JSON.parse(options.body) }); + return {}; + }; + state.showAlert = () => {}; + state.closeModals = () => {}; + state.loadProviders = async () => {}; + state.refreshStatus = async () => {}; + + await state.saveProvider(); + + assert.equal(calls.length, 1); + assert.deepEqual(calls[0].options, { + name: 'openai-proxy', + base_url: 'https://example.com', + proxy_mode: 'custom', + proxy_url: 'http://127.0.0.1:7890', + priority: 1, + enabled: true, + api_key: 'key-1' + }); +}); + test('saveProvider omits unsupported override fields for gemini', async () => { const state = loadApp(); const calls = []; @@ -319,6 +367,9 @@ test('saveProvider omits unsupported override fields for gemini', async () => { state.providerForm = { name: 'gemini-primary', base_url: 'https://example.com', + proxy_mode: 'inherit', + proxy_url: '', + proxy_url_hint: '', model: 'gemini-2.5-pro', reasoning_effort: 'high', thinking_budget_tokens: 2048, @@ -342,6 +393,7 @@ test('saveProvider omits unsupported override fields for gemini', async () => { assert.deepEqual(calls[0].options, { name: 'gemini-primary', base_url: 'https://example.com', + proxy_mode: 'inherit', priority: 1, enabled: true, api_key: 'key-1' diff --git a/internal/web/static/index.html b/internal/web/static/index.html index 4e7d74d..f35263a 100644 --- a/internal/web/static/index.html +++ b/internal/web/static/index.html @@ -214,6 +214,11 @@

+
+ + +
required>
+
+ + +
+
+
+ + +
+
placeholder="https://api.example.com">
+ +
+ +
+ + +
+ +
+
+
+
diff --git a/internal/web/types.go b/internal/web/types.go index 7adb100..3c2401a 100644 --- a/internal/web/types.go +++ b/internal/web/types.go @@ -5,6 +5,8 @@ package web // and lets us redact sensitive fields like API keys. import ( + "net/url" + "strings" "time" "github.com/lansespirit/Clipal/internal/config" @@ -20,6 +22,8 @@ type GlobalConfigRequest struct { ReactivateAfter string `json:"reactivate_after"` UpstreamIdleTimeout string `json:"upstream_idle_timeout"` ResponseHeaderTimeout string `json:"response_header_timeout"` + UpstreamProxyMode string `json:"upstream_proxy_mode"` + UpstreamProxyURL string `json:"upstream_proxy_url"` MaxRequestBodyBytes int64 `json:"max_request_body_bytes"` LogDir string `json:"log_dir"` LogRetentionDays int `json:"log_retention_days"` @@ -66,6 +70,8 @@ type GlobalConfigResponse struct { ReactivateAfter string `json:"reactivate_after"` UpstreamIdleTimeout string `json:"upstream_idle_timeout"` ResponseHeaderTimeout string `json:"response_header_timeout"` + UpstreamProxyMode string `json:"upstream_proxy_mode"` + UpstreamProxyURL string `json:"upstream_proxy_url"` MaxRequestBodyBytes int64 `json:"max_request_body_bytes"` LogDir string `json:"log_dir"` LogRetentionDays int `json:"log_retention_days"` @@ -163,6 +169,8 @@ type ProviderRequest struct { BaseURL string `json:"base_url"` APIKey string `json:"api_key,omitempty"` APIKeys []string `json:"api_keys,omitempty"` + ProxyMode *string `json:"proxy_mode,omitempty"` + ProxyURL *string `json:"proxy_url,omitempty"` Overrides *ProviderOverridesRequest `json:"overrides,omitempty"` // Priority is 1-based. Omit to keep existing value (on updates) or to // auto-assign the next priority (on create). @@ -172,13 +180,15 @@ type ProviderRequest struct { // ProviderResponse is returned for provider listings (never includes api_key). type ProviderResponse struct { - Name string `json:"name"` - BaseURL string `json:"base_url"` - Priority int `json:"priority"` - Enabled bool `json:"enabled"` - KeyCount int `json:"key_count"` - Usage *ProviderUsageResponse `json:"usage,omitempty"` - Overrides *ProviderOverridesResponse `json:"overrides,omitempty"` + Name string `json:"name"` + BaseURL string `json:"base_url"` + ProxyMode string `json:"proxy_mode"` + ProxyURLHint string `json:"proxy_url_hint,omitempty"` + Priority int `json:"priority"` + Enabled bool `json:"enabled"` + KeyCount int `json:"key_count"` + Usage *ProviderUsageResponse `json:"usage,omitempty"` + Overrides *ProviderOverridesResponse `json:"overrides,omitempty"` } type ProviderUsageResponse struct { @@ -216,6 +226,8 @@ type ProviderExport struct { BaseURL string `json:"base_url"` APIKey string `json:"api_key,omitempty"` APIKeys []string `json:"api_keys,omitempty"` + ProxyMode string `json:"proxy_mode,omitempty"` + ProxyURL string `json:"proxy_url,omitempty"` Priority int `json:"priority"` Enabled *bool `json:"enabled,omitempty"` Overrides *ProviderOverridesResponse `json:"overrides,omitempty"` @@ -360,6 +372,8 @@ func toGlobalConfigResponse(gc config.GlobalConfig) GlobalConfigResponse { ReactivateAfter: gc.ReactivateAfter, UpstreamIdleTimeout: gc.UpstreamIdleTimeout, ResponseHeaderTimeout: gc.ResponseHeaderTimeout, + UpstreamProxyMode: string(gc.NormalizedUpstreamProxyMode()), + UpstreamProxyURL: gc.NormalizedUpstreamProxyURL(), MaxRequestBodyBytes: gc.MaxRequestBody, LogDir: gc.LogDir, LogRetentionDays: gc.LogRetentionDays, @@ -418,13 +432,15 @@ func toProviderResponses(providers []config.Provider, usageByProvider map[string out := make([]ProviderResponse, 0, len(providers)) for _, p := range providers { out = append(out, ProviderResponse{ - Name: p.Name, - BaseURL: p.BaseURL, - Priority: p.Priority, - Enabled: p.IsEnabled(), - KeyCount: p.KeyCount(), - Usage: mapProviderUsageResponse(usageByProvider[p.Name]), - Overrides: mapProviderOverridesResponse(p), + Name: p.Name, + BaseURL: p.BaseURL, + ProxyMode: string(p.NormalizedProxyMode()), + ProxyURLHint: proxyURLHint(p.NormalizedProxyURL()), + Priority: p.Priority, + Enabled: p.IsEnabled(), + KeyCount: p.KeyCount(), + Usage: mapProviderUsageResponse(usageByProvider[p.Name]), + Overrides: mapProviderOverridesResponse(p), }) } return out @@ -461,6 +477,8 @@ func toClientConfigExport(cc config.ClientConfig) ClientConfigExport { BaseURL: p.BaseURL, APIKey: p.APIKey, APIKeys: append([]string(nil), p.APIKeys...), + ProxyMode: string(p.NormalizedProxyMode()), + ProxyURL: p.NormalizedProxyURL(), Priority: p.Priority, Enabled: p.Enabled, Overrides: mapProviderOverridesResponse(p), @@ -473,6 +491,18 @@ func toClientConfigExport(cc config.ClientConfig) ClientConfigExport { } } +func proxyURLHint(raw string) string { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return "" + } + parsed, err := url.Parse(trimmed) + if err != nil || parsed.Scheme == "" || parsed.Host == "" { + return "" + } + return parsed.Scheme + "://" + parsed.Host +} + func toProviderOverrideSupport(s providerOverrideSupport) ProviderOverrideSupport { return ProviderOverrideSupport{ Model: s.Model, diff --git a/internal/web/yaml_format.go b/internal/web/yaml_format.go index 5dd9a2d..4b0f2ec 100644 --- a/internal/web/yaml_format.go +++ b/internal/web/yaml_format.go @@ -62,6 +62,12 @@ func formatClientConfigYAML(clientType string, cc config.ClientConfig) []byte { writeBufferString(&b, fmt.Sprintf(" - name: %s\n", yamlDoubleQuote(p.Name))) writeBufferString(&b, fmt.Sprintf(" base_url: %s\n", yamlDoubleQuote(p.BaseURL))) + if p.NormalizedProxyMode() != config.ProviderProxyModeInherit { + writeBufferString(&b, fmt.Sprintf(" proxy_mode: %s\n", yamlDoubleQuote(string(p.NormalizedProxyMode())))) + } + if p.NormalizedProxyMode() == config.ProviderProxyModeCustom && p.NormalizedProxyURL() != "" { + writeBufferString(&b, fmt.Sprintf(" proxy_url: %s\n", yamlDoubleQuote(p.NormalizedProxyURL()))) + } keys := p.NormalizedAPIKeys() if len(keys) <= 1 { writeBufferString(&b, fmt.Sprintf(" api_key: %s\n", yamlDoubleQuote(p.PrimaryAPIKey()))) @@ -111,6 +117,10 @@ func formatGlobalConfigYAML(gc config.GlobalConfig) []byte { writeBufferString(&b, "# How long to wait for the upstream to return response headers.\n") writeBufferString(&b, "# Set to 0 to disable.\n") writeBufferString(&b, fmt.Sprintf("response_header_timeout: %s\n", yamlDoubleQuote(strings.TrimSpace(gc.ResponseHeaderTimeout)))) + writeBufferString(&b, "# Default upstream proxy for providers whose proxy_mode is inherit.\n") + writeBufferString(&b, fmt.Sprintf("upstream_proxy_mode: %s # inherit | direct | custom\n", yamlDoubleQuote(string(gc.NormalizedUpstreamProxyMode())))) + writeBufferString(&b, "# Supported proxy URLs: http://, https://, socks5://, socks5h://\n") + writeBufferString(&b, fmt.Sprintf("upstream_proxy_url: %s\n", yamlDoubleQuote(gc.NormalizedUpstreamProxyURL()))) writeBufferString(&b, "# Max request body size in bytes (clipal buffers request bodies for retries).\n") writeBufferString(&b, fmt.Sprintf("max_request_body_bytes: %d\n\n", gc.MaxRequestBody)) diff --git a/internal/web/yaml_format_test.go b/internal/web/yaml_format_test.go index 843d391..39415e3 100644 --- a/internal/web/yaml_format_test.go +++ b/internal/web/yaml_format_test.go @@ -150,6 +150,27 @@ func TestFormatClientConfigYAML_SingleNormalizedKeyUsesAPIKeyField(t *testing.T) } } +func TestFormatClientConfigYAML_WritesProxySettingsOnlyWhenNeeded(t *testing.T) { + cc := config.ClientConfig{ + Providers: []config.Provider{ + {Name: "inherit", BaseURL: "https://a.example", APIKey: "k1", Priority: 1, Enabled: boolPtr(true)}, + {Name: "direct", BaseURL: "https://b.example", APIKey: "k2", ProxyMode: config.ProviderProxyModeDirect, Priority: 2, Enabled: boolPtr(true)}, + {Name: "custom", BaseURL: "https://c.example", APIKey: "k3", ProxyMode: config.ProviderProxyModeCustom, ProxyURL: "http://127.0.0.1:7890", Priority: 3, Enabled: boolPtr(true)}, + }, + } + + got := string(formatClientConfigYAML("codex", cc)) + if strings.Count(got, "proxy_mode:") != 2 { + t.Fatalf("expected exactly two proxy_mode entries, got:\n%s", got) + } + if !strings.Contains(got, `name: "direct"`) || !strings.Contains(got, `proxy_mode: "direct"`) { + t.Fatalf("expected direct proxy_mode, got:\n%s", got) + } + if !strings.Contains(got, `name: "custom"`) || !strings.Contains(got, `proxy_mode: "custom"`) || !strings.Contains(got, `proxy_url: "http://127.0.0.1:7890"`) { + t.Fatalf("expected custom proxy settings, got:\n%s", got) + } +} + func TestFormatClientConfigYAML_RoundTripAndEscapesSpecialCharacters(t *testing.T) { cc := config.ClientConfig{ Providers: []config.Provider{ @@ -331,12 +352,16 @@ func TestFormatGlobalConfigYAML_RoundTripAndEscapesSpecialCharacters(t *testing. gc.LogDir = "logs\r\nfolder\tcontrol\x01" gc.LogRetentionDays = 7 gc.LogStdout = boolPtr(false) + gc.UpstreamProxyMode = config.ProviderProxyModeCustom + gc.UpstreamProxyURL = "http://127.0.0.1:7890" gc.Notifications.ProviderSwitch = boolPtr(false) got := string(formatGlobalConfigYAML(gc)) for _, want := range []string{ `listen_addr: "host\"quoted\"\\path"`, `log_dir: "logs\r\nfolder\tcontrol\x01"`, + `upstream_proxy_mode: "custom" # inherit | direct | custom`, + `upstream_proxy_url: "http://127.0.0.1:7890"`, `log_retention_days: 7 # default 7 days`, `log_stdout: false`, `provider_switch: false`, @@ -360,6 +385,12 @@ func TestFormatGlobalConfigYAML_RoundTripAndEscapesSpecialCharacters(t *testing. if loaded.Global.LogDir != gc.LogDir { t.Fatalf("log_dir = %q, want %q", loaded.Global.LogDir, gc.LogDir) } + if loaded.Global.NormalizedUpstreamProxyMode() != config.ProviderProxyModeCustom { + t.Fatalf("upstream_proxy_mode = %q, want custom", loaded.Global.NormalizedUpstreamProxyMode()) + } + if loaded.Global.NormalizedUpstreamProxyURL() != "http://127.0.0.1:7890" { + t.Fatalf("upstream_proxy_url = %q", loaded.Global.NormalizedUpstreamProxyURL()) + } if loaded.Global.LogRetentionDays != 7 { t.Fatalf("log_retention_days = %d, want 7", loaded.Global.LogRetentionDays) } From 6250c8b7ab3e18948f9476c061975d35fb6e1b33 Mon Sep 17 00:00:00 2001 From: Thomas Date: Sat, 11 Apr 2026 22:25:45 +0800 Subject: [PATCH 2/3] refactor(proxy): split proxy mode semantics and centralize config logic Address PR #11 review feedback: - Split global/proxy mode enums: global uses `environment|direct|custom`, provider uses `default|direct|custom` (no shared `inherit` value) - Move proxy apply/normalize/validate logic from web API into config layer - Share http.Client by effective proxy policy key for connection pool reuse --- docs/en/config-reference.md | 6 +- docs/zh/config-reference.md | 6 +- internal/config/config.go | 207 +++++++++++++++++++++++++--- internal/config/config_test.go | 130 ++++++++++++++++- internal/proxy/proxy.go | 122 +++++++++------- internal/proxy/proxy_test.go | 54 ++++++-- internal/proxy/reload_state_test.go | 4 +- internal/web/api.go | 103 ++------------ internal/web/api_test.go | 12 +- internal/web/static/app.js | 39 ++++-- internal/web/static/app.test.js | 33 ++++- internal/web/static/index.html | 8 +- internal/web/types.go | 4 +- internal/web/yaml_format.go | 6 +- internal/web/yaml_format_test.go | 8 +- 15 files changed, 520 insertions(+), 222 deletions(-) diff --git a/docs/en/config-reference.md b/docs/en/config-reference.md index fc63566..9ac33c4 100644 --- a/docs/en/config-reference.md +++ b/docs/en/config-reference.md @@ -60,7 +60,7 @@ providers: | `reactivate_after` | duration | `1h` | Auto-reactivation delay for temporarily deactivated providers; set `0` to disable temporary deactivation for auth, billing, and quota failures | | `upstream_idle_timeout` | duration | `3m` | Abort the current upstream attempt if no response body bytes arrive for too long | | `response_header_timeout` | duration | `2m` | Timeout while waiting for upstream response headers | -| `upstream_proxy_mode` | string | `inherit` | Default upstream proxy mode for providers that use `proxy_mode: inherit`; `inherit` / `direct` / `custom` | +| `upstream_proxy_mode` | string | `environment` | Default upstream proxy mode for providers that use `proxy_mode: default`; `environment` / `direct` / `custom` | | `upstream_proxy_url` | string | empty | Required when `upstream_proxy_mode: custom`; supports `http://`, `https://`, `socks5://`, and `socks5h://` proxy URLs | | `max_request_body_bytes` | int | `33554432` | Request body size limit, default 32 MiB | | `log_dir` | string | `/logs` | Log directory | @@ -181,7 +181,7 @@ providers: | `base_url` | string | yes | Upstream API base URL | | `api_key` | string | one of two | Single API key | | `api_keys` | array | one of two | Multiple API keys, used in order | -| `proxy_mode` | string | no | Upstream proxy mode for this provider; `inherit` follows the global default | +| `proxy_mode` | string | no | Upstream proxy mode for this provider; `default` follows the global default | | `proxy_url` | string | no | Required when `proxy_mode: custom`; supports `http://`, `https://`, `socks5://`, and `socks5h://` proxy URLs | | `priority` | int | no | Lower number = higher priority; omitted or `0` is treated as `1` | | `enabled` | bool | no | Defaults to `true` | @@ -193,7 +193,7 @@ providers: - Use `api_key` when you only have one key - Use `api_keys` when you want retries across multiple keys within the same provider -- Use global `upstream_proxy_mode` / `upstream_proxy_url` to define the default proxy for inherited providers +- Use global `upstream_proxy_mode` / `upstream_proxy_url` to define the default proxy for providers that use `proxy_mode: default` - Use provider `proxy_mode: direct` to bypass both the global default proxy and environment proxy settings - Use `model` when different upstream providers expose the same family under different model IDs - Use `reasoning_effort` and `thinking_budget_tokens` only when you want Clipal to override the client-sent defaults for that provider diff --git a/docs/zh/config-reference.md b/docs/zh/config-reference.md index 55a67c8..f626dfb 100644 --- a/docs/zh/config-reference.md +++ b/docs/zh/config-reference.md @@ -60,7 +60,7 @@ providers: | `reactivate_after` | duration | `1h` | provider 临时禁用后的自动恢复时间;设为 `0` 可禁用基于鉴权、计费、额度错误的临时禁用 | | `upstream_idle_timeout` | duration | `3m` | 上游响应 body 长时间无字节时中断当前尝试 | | `response_header_timeout` | duration | `2m` | 等待上游响应头的超时 | -| `upstream_proxy_mode` | string | `inherit` | 作为默认值应用到 `proxy_mode: inherit` 的 provider;可选 `inherit` / `direct` / `custom` | +| `upstream_proxy_mode` | string | `environment` | 作为默认值应用到 `proxy_mode: default` 的 provider;可选 `environment` / `direct` / `custom` | | `upstream_proxy_url` | string | 空 | 当 `upstream_proxy_mode: custom` 时必填;支持 `http://`、`https://`、`socks5://` 和 `socks5h://` 代理 URL | | `max_request_body_bytes` | int | `33554432` | 请求体大小上限,默认 32 MiB | | `log_dir` | string | `/logs` | 日志目录 | @@ -181,7 +181,7 @@ providers: | `base_url` | string | 是 | 上游 API Base URL | | `api_key` | string | 二选一 | 单个 API Key | | `api_keys` | array | 二选一 | 多个 API Key,按顺序使用 | -| `proxy_mode` | string | 否 | 该 provider 的上游代理模式;`inherit` 表示继承全局默认代理 | +| `proxy_mode` | string | 否 | 该 provider 的上游代理模式;`default` 表示使用全局默认代理 | | `proxy_url` | string | 否 | 当 `proxy_mode: custom` 时必填;支持 `http://`、`https://`、`socks5://` 和 `socks5h://` 代理 URL | | `priority` | int | 否 | 数字越小优先级越高;省略或 `0` 时按 `1` 处理 | | `enabled` | bool | 否 | 是否启用,默认 `true` | @@ -193,7 +193,7 @@ providers: - 只有一个 key 时用 `api_key` - 需要同 provider 多 key 轮转时用 `api_keys` -- 需要统一默认代理时,优先配置全局 `upstream_proxy_mode` / `upstream_proxy_url` +- 需要统一默认代理时,优先配置全局 `upstream_proxy_mode` / `upstream_proxy_url`,并让 provider 使用 `proxy_mode: default` - 需要让某个 provider 绕过全局默认代理和环境代理时,用 `proxy_mode: direct` - 不同上游对同一模型族使用不同模型 ID 时,可为该 provider 配置 `model` - 只有在你希望 Clipal 按 provider 覆盖客户端默认思考参数时,才配置 `reasoning_effort` 或 `thinking_budget_tokens` diff --git a/internal/config/config.go b/internal/config/config.go index b9992c2..10b7469 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -92,16 +92,16 @@ type GlobalConfig struct { UpstreamIdleTimeout string `yaml:"upstream_idle_timeout"` // ResponseHeaderTimeout controls how long we wait for the upstream to return // response headers after the request is fully written. Set to "0" to disable. - ResponseHeaderTimeout string `yaml:"response_header_timeout"` - UpstreamProxyMode ProviderProxyMode `yaml:"upstream_proxy_mode,omitempty"` - UpstreamProxyURL string `yaml:"upstream_proxy_url,omitempty"` - MaxRequestBody int64 `yaml:"max_request_body_bytes"` - LogDir string `yaml:"log_dir"` - LogRetentionDays int `yaml:"log_retention_days"` - LogStdout *bool `yaml:"log_stdout"` - Notifications NotificationsConfig `yaml:"notifications"` - CircuitBreaker CircuitBreakerConfig `yaml:"circuit_breaker"` - Routing RoutingConfig `yaml:"routing"` + ResponseHeaderTimeout string `yaml:"response_header_timeout"` + UpstreamProxyMode GlobalUpstreamProxyMode `yaml:"upstream_proxy_mode,omitempty"` + UpstreamProxyURL string `yaml:"upstream_proxy_url,omitempty"` + MaxRequestBody int64 `yaml:"max_request_body_bytes"` + LogDir string `yaml:"log_dir"` + LogRetentionDays int `yaml:"log_retention_days"` + LogStdout *bool `yaml:"log_stdout"` + Notifications NotificationsConfig `yaml:"notifications"` + CircuitBreaker CircuitBreakerConfig `yaml:"circuit_breaker"` + Routing RoutingConfig `yaml:"routing"` // Deprecated: retained only so older config.yaml files still load under // strict KnownFields decoding. Runtime no longer reads this field. IgnoreCountTokensFailover bool `yaml:"ignore_count_tokens_failover"` @@ -135,14 +135,32 @@ type ClaudeOverrides struct { ThinkingBudgetTokens *int `yaml:"thinking_budget_tokens,omitempty"` } +type GlobalUpstreamProxyMode string + +const ( + GlobalUpstreamProxyModeEnvironment GlobalUpstreamProxyMode = "environment" + GlobalUpstreamProxyModeDirect GlobalUpstreamProxyMode = "direct" + GlobalUpstreamProxyModeCustom GlobalUpstreamProxyMode = "custom" +) + type ProviderProxyMode string const ( - ProviderProxyModeInherit ProviderProxyMode = "inherit" + ProviderProxyModeDefault ProviderProxyMode = "default" ProviderProxyModeDirect ProviderProxyMode = "direct" ProviderProxyModeCustom ProviderProxyMode = "custom" ) +type UpstreamProxySettingsPatch struct { + Mode *string + URL *string +} + +type ProviderProxySettingsPatch struct { + Mode *string + URL *string +} + type providerYAML struct { Name string `yaml:"name"` BaseURL string `yaml:"base_url"` @@ -217,7 +235,7 @@ func (p *Provider) UnmarshalYAML(value *yaml.Node) error { func (p Provider) MarshalYAML() (any, error) { proxyMode := p.NormalizedProxyMode() proxyURL := p.NormalizedProxyURL() - if proxyMode == ProviderProxyModeInherit { + if proxyMode == ProviderProxyModeDefault { proxyMode = "" } if proxyMode != ProviderProxyModeCustom { @@ -298,12 +316,36 @@ func NormalizeProviderProxySettings(provider *Provider) { provider.ProxyURL = provider.NormalizedProxyURL() } -func (g GlobalConfig) NormalizedUpstreamProxyMode() ProviderProxyMode { +func NormalizeUpstreamProxySettings(global *GlobalConfig) { + if global == nil { + return + } + global.UpstreamProxyMode = global.NormalizedUpstreamProxyMode() + global.UpstreamProxyURL = global.NormalizedUpstreamProxyURL() +} + +func NormalizeUpstreamProxySettingsPatch(patch *UpstreamProxySettingsPatch) { + if patch == nil { + return + } + patch.Mode = normalizeLowerTrimStringPtr(patch.Mode) + patch.URL = normalizeTrimStringPtr(patch.URL) +} + +func NormalizeProviderProxySettingsPatch(patch *ProviderProxySettingsPatch) { + if patch == nil { + return + } + patch.Mode = normalizeLowerTrimStringPtr(patch.Mode) + patch.URL = normalizeTrimStringPtr(patch.URL) +} + +func (g GlobalConfig) NormalizedUpstreamProxyMode() GlobalUpstreamProxyMode { mode := strings.ToLower(strings.TrimSpace(string(g.UpstreamProxyMode))) if mode == "" { - return ProviderProxyModeInherit + return GlobalUpstreamProxyModeEnvironment } - return ProviderProxyMode(mode) + return GlobalUpstreamProxyMode(mode) } func (g GlobalConfig) NormalizedUpstreamProxyURL() string { @@ -313,7 +355,7 @@ func (g GlobalConfig) NormalizedUpstreamProxyURL() string { func (p Provider) NormalizedProxyMode() ProviderProxyMode { mode := strings.ToLower(strings.TrimSpace(string(p.ProxyMode))) if mode == "" { - return ProviderProxyModeInherit + return ProviderProxyModeDefault } return ProviderProxyMode(mode) } @@ -322,9 +364,115 @@ func (p Provider) NormalizedProxyURL() string { return strings.TrimSpace(p.ProxyURL) } -func validateProxySettings(scope string, mode ProviderProxyMode, rawURL string) error { +func ApplyUpstreamProxySettings(global *GlobalConfig, patch UpstreamProxySettingsPatch) error { + if global == nil { + return nil + } + NormalizeUpstreamProxySettingsPatch(&patch) + if patch.Mode == nil { + if patch.URL != nil { + return fmt.Errorf("upstream_proxy_url requires upstream_proxy_mode") + } + return nil + } + + mode := GlobalUpstreamProxyMode(*patch.Mode) + proxyURL := global.NormalizedUpstreamProxyURL() + if patch.URL != nil { + proxyURL = *patch.URL + } + + switch mode { + case GlobalUpstreamProxyModeEnvironment, GlobalUpstreamProxyModeDirect: + if patch.URL != nil && proxyURL != "" { + return fmt.Errorf("upstream_proxy_url requires upstream_proxy_mode custom") + } + global.UpstreamProxyMode = mode + global.UpstreamProxyURL = "" + return nil + case GlobalUpstreamProxyModeCustom: + if proxyURL == "" { + return fmt.Errorf("upstream_proxy_url is required when upstream_proxy_mode=custom") + } + if err := ValidateProxyURL(proxyURL); err != nil { + return err + } + global.UpstreamProxyMode = mode + global.UpstreamProxyURL = proxyURL + return nil + default: + return fmt.Errorf("upstream_proxy_mode must be one of environment, direct, custom") + } +} + +func ApplyProviderProxySettings(provider *Provider, patch ProviderProxySettingsPatch, isUpdate bool) error { + if provider == nil { + return nil + } + NormalizeProviderProxySettingsPatch(&patch) + if !isUpdate && patch.Mode == nil && patch.URL == nil { + provider.ProxyMode = ProviderProxyModeDefault + provider.ProxyURL = "" + return nil + } + if patch.Mode == nil { + if patch.URL != nil { + return fmt.Errorf("proxy_url requires proxy_mode") + } + return nil + } + + mode := ProviderProxyMode(*patch.Mode) + proxyURL := provider.NormalizedProxyURL() + if patch.URL != nil { + proxyURL = *patch.URL + } + + switch mode { + case ProviderProxyModeDefault, ProviderProxyModeDirect: + if patch.URL != nil && proxyURL != "" { + return fmt.Errorf("proxy_url requires proxy_mode custom") + } + provider.ProxyMode = mode + provider.ProxyURL = "" + return nil + case ProviderProxyModeCustom: + if proxyURL == "" { + return fmt.Errorf("proxy_url is required when proxy_mode=custom") + } + if err := ValidateProxyURL(proxyURL); err != nil { + return err + } + provider.ProxyMode = mode + provider.ProxyURL = proxyURL + return nil + default: + return fmt.Errorf("proxy_mode must be one of default, direct, custom") + } +} + +func validateGlobalProxySettings(scope string, mode GlobalUpstreamProxyMode, rawURL string) error { switch mode { - case ProviderProxyModeInherit, ProviderProxyModeDirect: + case GlobalUpstreamProxyModeEnvironment, GlobalUpstreamProxyModeDirect: + if rawURL != "" { + return fmt.Errorf("%s: upstream_proxy_url requires upstream_proxy_mode custom", scope) + } + case GlobalUpstreamProxyModeCustom: + if rawURL == "" { + return fmt.Errorf("%s: upstream_proxy_url is required when upstream_proxy_mode=custom", scope) + } + if err := ValidateProxyURL(rawURL); err != nil { + return fmt.Errorf("%s: %w", scope, err) + } + default: + return fmt.Errorf("%s: invalid upstream_proxy_mode %q", scope, mode) + } + return nil +} + +func validateProviderProxySettings(scope string, mode ProviderProxyMode, rawURL string) error { + switch mode { + case ProviderProxyModeDefault, ProviderProxyModeDirect: if rawURL != "" { return fmt.Errorf("%s: proxy_url requires proxy_mode custom", scope) } @@ -341,6 +489,22 @@ func validateProxySettings(scope string, mode ProviderProxyMode, rawURL string) return nil } +func normalizeLowerTrimStringPtr(v *string) *string { + if v == nil { + return nil + } + trimmed := strings.ToLower(strings.TrimSpace(*v)) + return &trimmed +} + +func normalizeTrimStringPtr(v *string) *string { + if v == nil { + return nil + } + trimmed := strings.TrimSpace(*v) + return &trimmed +} + // IsEnabled returns whether the provider is enabled (default true) func (p *Provider) IsEnabled() bool { if p.Enabled == nil { @@ -441,7 +605,7 @@ func DefaultGlobalConfig() GlobalConfig { ReactivateAfter: "1h", UpstreamIdleTimeout: "3m", ResponseHeaderTimeout: "2m", - UpstreamProxyMode: ProviderProxyModeInherit, + UpstreamProxyMode: GlobalUpstreamProxyModeEnvironment, UpstreamProxyURL: "", // Default body limit: 32 MiB. clipal buffers request bodies to support retries, // so a hard cap prevents unbounded memory usage. @@ -502,6 +666,7 @@ func Load(configDir string) (*Config, error) { if err := loadYAML(globalPath, &cfg.Global); err != nil && !os.IsNotExist(err) { return nil, fmt.Errorf("failed to load global config: %w", err) } + NormalizeUpstreamProxySettings(&cfg.Global) if err := migrateLegacyClientConfigFiles(configDir); err != nil { return nil, err @@ -781,7 +946,7 @@ func (c *Config) Validate() error { if c.Global.LogRetentionDays < 0 { return fmt.Errorf("invalid log_retention_days: %d", c.Global.LogRetentionDays) } - if err := validateProxySettings("global upstream proxy", c.Global.NormalizedUpstreamProxyMode(), c.Global.NormalizedUpstreamProxyURL()); err != nil { + if err := validateGlobalProxySettings("global upstream proxy", c.Global.NormalizedUpstreamProxyMode(), c.Global.NormalizedUpstreamProxyURL()); err != nil { return err } @@ -893,7 +1058,7 @@ func validateProviders(clientName string, providers []Provider) error { if p.Priority < 1 { return fmt.Errorf("%s provider %s: priority must be >= 1", clientName, p.Name) } - if err := validateProxySettings(fmt.Sprintf("%s provider %s", clientName, p.Name), p.NormalizedProxyMode(), p.NormalizedProxyURL()); err != nil { + if err := validateProviderProxySettings(fmt.Sprintf("%s provider %s", clientName, p.Name), p.NormalizedProxyMode(), p.NormalizedProxyURL()); err != nil { return err } if !providerOverridesSupportedForClient(clientName, p.Overrides) { diff --git a/internal/config/config_test.go b/internal/config/config_test.go index ee90b84..1c01c2e 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -706,7 +706,7 @@ func TestGetConfigDir_RespectsEnvironmentOverride(t *testing.T) { } } -func TestLoad_ProviderProxyModeDefaultsToInherit(t *testing.T) { +func TestLoad_ProviderProxyModeDefaultsToDefault(t *testing.T) { t.Parallel() dir := t.TempDir() @@ -722,14 +722,26 @@ providers: if err != nil { t.Fatalf("Load: %v", err) } - if got := cfg.OpenAI.Providers[0].NormalizedProxyMode(); got != ProviderProxyModeInherit { - t.Fatalf("proxy mode = %q, want %q", got, ProviderProxyModeInherit) + if got := cfg.OpenAI.Providers[0].NormalizedProxyMode(); got != ProviderProxyModeDefault { + t.Fatalf("proxy mode = %q, want %q", got, ProviderProxyModeDefault) } if got := cfg.OpenAI.Providers[0].NormalizedProxyURL(); got != "" { t.Fatalf("proxy url = %q, want empty", got) } } +func TestLoad_GlobalUpstreamProxyModeDefaultsToEnvironment(t *testing.T) { + t.Parallel() + + cfg, err := Load(t.TempDir()) + if err != nil { + t.Fatalf("Load: %v", err) + } + if got := cfg.Global.NormalizedUpstreamProxyMode(); got != GlobalUpstreamProxyModeEnvironment { + t.Fatalf("upstream proxy mode = %q, want %q", got, GlobalUpstreamProxyModeEnvironment) + } +} + func TestValidate_ProviderProxySettings(t *testing.T) { t.Parallel() @@ -806,6 +818,14 @@ func TestValidate_ProviderProxySettings(t *testing.T) { t.Fatalf("Validate err = %v", err) } }) + + t.Run("rejects removed inherit mode", func(t *testing.T) { + cfg := *base + cfg.OpenAI.Providers = []Provider{makeProvider(ProviderProxyMode("inherit"), "")} + if err := cfg.Validate(); err == nil || !strings.Contains(err.Error(), `invalid proxy_mode "inherit"`) { + t.Fatalf("Validate err = %v", err) + } + }) } func TestValidate_GlobalUpstreamProxySettings(t *testing.T) { @@ -818,7 +838,7 @@ func TestValidate_GlobalUpstreamProxySettings(t *testing.T) { Gemini: ClientConfig{Mode: ClientModeAuto}, } - cfg.Global.UpstreamProxyMode = ProviderProxyModeCustom + cfg.Global.UpstreamProxyMode = GlobalUpstreamProxyModeCustom cfg.Global.UpstreamProxyURL = "http://127.0.0.1:7890" if err := cfg.Validate(); err != nil { t.Fatalf("Validate: %v", err) @@ -838,4 +858,106 @@ func TestValidate_GlobalUpstreamProxySettings(t *testing.T) { if err := cfg.Validate(); err == nil || !strings.Contains(err.Error(), "proxy_url scheme must be http, https, socks5, or socks5h") { t.Fatalf("Validate err = %v", err) } + + cfg.Global.UpstreamProxyMode = GlobalUpstreamProxyMode("inherit") + cfg.Global.UpstreamProxyURL = "" + if err := cfg.Validate(); err == nil || !strings.Contains(err.Error(), `invalid upstream_proxy_mode "inherit"`) { + t.Fatalf("Validate err = %v", err) + } +} + +func TestApplyProviderProxySettings(t *testing.T) { + t.Parallel() + + t.Run("create defaults to default mode", func(t *testing.T) { + var provider Provider + if err := ApplyProviderProxySettings(&provider, ProviderProxySettingsPatch{}, false); err != nil { + t.Fatalf("ApplyProviderProxySettings: %v", err) + } + if provider.ProxyMode != ProviderProxyModeDefault || provider.ProxyURL != "" { + t.Fatalf("provider = %#v", provider) + } + }) + + t.Run("update retains existing custom url", func(t *testing.T) { + provider := Provider{ProxyMode: ProviderProxyModeCustom, ProxyURL: "http://127.0.0.1:7890"} + mode := "custom" + if err := ApplyProviderProxySettings(&provider, ProviderProxySettingsPatch{Mode: &mode}, true); err != nil { + t.Fatalf("ApplyProviderProxySettings: %v", err) + } + if provider.ProxyMode != ProviderProxyModeCustom || provider.ProxyURL != "http://127.0.0.1:7890" { + t.Fatalf("provider = %#v", provider) + } + }) + + t.Run("update clears url on direct", func(t *testing.T) { + provider := Provider{ProxyMode: ProviderProxyModeCustom, ProxyURL: "http://127.0.0.1:7890"} + mode := "direct" + if err := ApplyProviderProxySettings(&provider, ProviderProxySettingsPatch{Mode: &mode}, true); err != nil { + t.Fatalf("ApplyProviderProxySettings: %v", err) + } + if provider.ProxyMode != ProviderProxyModeDirect || provider.ProxyURL != "" { + t.Fatalf("provider = %#v", provider) + } + }) + + t.Run("rejects default with url", func(t *testing.T) { + mode := "default" + rawURL := "http://127.0.0.1:7890" + if err := ApplyProviderProxySettings(&Provider{}, ProviderProxySettingsPatch{Mode: &mode, URL: &rawURL}, false); err == nil || !strings.Contains(err.Error(), "proxy_url requires proxy_mode custom") { + t.Fatalf("ApplyProviderProxySettings err = %v", err) + } + }) + + t.Run("rejects removed inherit mode", func(t *testing.T) { + mode := "inherit" + if err := ApplyProviderProxySettings(&Provider{}, ProviderProxySettingsPatch{Mode: &mode}, false); err == nil || !strings.Contains(err.Error(), "proxy_mode must be one of default, direct, custom") { + t.Fatalf("ApplyProviderProxySettings err = %v", err) + } + }) +} + +func TestApplyUpstreamProxySettings(t *testing.T) { + t.Parallel() + + t.Run("omitted patch preserves existing values", func(t *testing.T) { + global := GlobalConfig{ + UpstreamProxyMode: GlobalUpstreamProxyModeCustom, + UpstreamProxyURL: "http://127.0.0.1:7890", + } + if err := ApplyUpstreamProxySettings(&global, UpstreamProxySettingsPatch{}); err != nil { + t.Fatalf("ApplyUpstreamProxySettings: %v", err) + } + if global.UpstreamProxyMode != GlobalUpstreamProxyModeCustom || global.UpstreamProxyURL != "http://127.0.0.1:7890" { + t.Fatalf("global = %#v", global) + } + }) + + t.Run("switches to environment and clears url", func(t *testing.T) { + global := GlobalConfig{ + UpstreamProxyMode: GlobalUpstreamProxyModeCustom, + UpstreamProxyURL: "http://127.0.0.1:7890", + } + mode := "environment" + if err := ApplyUpstreamProxySettings(&global, UpstreamProxySettingsPatch{Mode: &mode}); err != nil { + t.Fatalf("ApplyUpstreamProxySettings: %v", err) + } + if global.UpstreamProxyMode != GlobalUpstreamProxyModeEnvironment || global.UpstreamProxyURL != "" { + t.Fatalf("global = %#v", global) + } + }) + + t.Run("rejects url without mode", func(t *testing.T) { + rawURL := "http://127.0.0.1:7890" + if err := ApplyUpstreamProxySettings(&GlobalConfig{}, UpstreamProxySettingsPatch{URL: &rawURL}); err == nil || !strings.Contains(err.Error(), "upstream_proxy_url requires upstream_proxy_mode") { + t.Fatalf("ApplyUpstreamProxySettings err = %v", err) + } + }) + + t.Run("rejects removed inherit mode", func(t *testing.T) { + mode := "inherit" + if err := ApplyUpstreamProxySettings(&GlobalConfig{}, UpstreamProxySettingsPatch{Mode: &mode}); err == nil || !strings.Contains(err.Error(), "upstream_proxy_mode must be one of environment, direct, custom") { + t.Fatalf("ApplyUpstreamProxySettings err = %v", err) + } + }) } diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 5f2902e..4c9fa06 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -94,6 +94,19 @@ type routingRuntimeSettings struct { maxInlineWait time.Duration } +type upstreamProxyPolicyMode string + +const ( + upstreamProxyPolicyEnvironment upstreamProxyPolicyMode = "environment" + upstreamProxyPolicyDirect upstreamProxyPolicyMode = "direct" + upstreamProxyPolicyCustom upstreamProxyPolicyMode = "custom" +) + +type upstreamProxyPolicyKey struct { + mode upstreamProxyPolicyMode + url string +} + // Router manages multiple client proxies type Router struct { cfg *config.Config @@ -129,30 +142,29 @@ func (r *Router) TelemetryStore() *telemetry.Store { // ClientProxy handles requests for a specific client type type ClientProxy struct { - clientType ClientType - mode config.ClientMode - pinnedProvider string - pinnedIndex int - providers []config.Provider - providerKeys [][]string - currentIndex int - countTokensIndex int - responsesIndex int - geminiStreamIndex int - currentKeyIndex []int - countTokensKeyIndex []int - responsesKeyIndex []int - geminiStreamKeyIndex []int - mu sync.RWMutex - httpClient *http.Client - providerHTTPClients []*http.Client - providerProxyModes []config.ProviderProxyMode - providerProxyURLs []string - deactivated []providerDeactivation - keyDeactivated [][]providerDeactivation - providerBusy []providerBusyState - reactivateAfter time.Duration - upstreamIdle time.Duration + clientType ClientType + mode config.ClientMode + pinnedProvider string + pinnedIndex int + providers []config.Provider + providerKeys [][]string + currentIndex int + countTokensIndex int + responsesIndex int + geminiStreamIndex int + currentKeyIndex []int + countTokensKeyIndex []int + responsesKeyIndex []int + geminiStreamKeyIndex []int + mu sync.RWMutex + httpClient *http.Client + providerHTTPClients []*http.Client + providerProxyPolicies []upstreamProxyPolicyKey + deactivated []providerDeactivation + keyDeactivated [][]providerDeactivation + providerBusy []providerBusyState + reactivateAfter time.Duration + upstreamIdle time.Duration stickyBindings map[string]stickyBinding responseLookup map[string]stickyLookupEntry @@ -228,10 +240,10 @@ func NewRouter(cfg *config.Config) *Router { } func newClientProxy(clientType ClientType, mode config.ClientMode, pinnedProvider string, providers []config.Provider, reactivateAfter time.Duration, upstreamIdle time.Duration, responseHeaderTimeout time.Duration, cbCfg circuitBreakerConfig, telemetryStore ...*telemetry.Store) *ClientProxy { - return newClientProxyWithGlobalProxy(clientType, mode, pinnedProvider, providers, reactivateAfter, upstreamIdle, responseHeaderTimeout, cbCfg, config.ProviderProxyModeInherit, "", telemetryStore...) + return newClientProxyWithGlobalProxy(clientType, mode, pinnedProvider, providers, reactivateAfter, upstreamIdle, responseHeaderTimeout, cbCfg, config.GlobalUpstreamProxyModeEnvironment, "", telemetryStore...) } -func newClientProxyWithGlobalProxy(clientType ClientType, mode config.ClientMode, pinnedProvider string, providers []config.Provider, reactivateAfter time.Duration, upstreamIdle time.Duration, responseHeaderTimeout time.Duration, cbCfg circuitBreakerConfig, globalProxyMode config.ProviderProxyMode, globalProxyURL string, telemetryStore ...*telemetry.Store) *ClientProxy { +func newClientProxyWithGlobalProxy(clientType ClientType, mode config.ClientMode, pinnedProvider string, providers []config.Provider, reactivateAfter time.Duration, upstreamIdle time.Duration, responseHeaderTimeout time.Duration, cbCfg circuitBreakerConfig, globalProxyMode config.GlobalUpstreamProxyMode, globalProxyURL string, telemetryStore ...*telemetry.Store) *ClientProxy { var store *telemetry.Store if len(telemetryStore) > 0 { store = telemetryStore[0] @@ -259,14 +271,22 @@ func newClientProxyWithGlobalProxy(clientType ClientType, mode config.ClientMode responsesKeyIndex := make([]int, len(providers)) geminiStreamKeyIndex := make([]int, len(providers)) providerHTTPClients := make([]*http.Client, len(providers)) - providerProxyModes := make([]config.ProviderProxyMode, len(providers)) - providerProxyURLs := make([]string, len(providers)) + providerProxyPolicies := make([]upstreamProxyPolicyKey, len(providers)) + policyClients := map[upstreamProxyPolicyKey]*http.Client{ + {mode: upstreamProxyPolicyEnvironment}: sharedClient, + } keyDeactivated := make([][]providerDeactivation, len(providers)) for i := range providers { breakers[i] = newCircuitBreaker(cbCfg) providerKeys[i] = providers[i].NormalizedAPIKeys() - providerProxyModes[i], providerProxyURLs[i] = effectiveProviderProxySettings(providers[i], globalProxyMode, globalProxyURL) - providerHTTPClients[i] = newProviderHTTPClient(providerProxyModes[i], providerProxyURLs[i], providers[i].Name, sharedClient, dialer, responseHeaderTimeout) + providerProxyPolicies[i] = effectiveProviderProxyPolicy(providers[i], globalProxyMode, globalProxyURL) + if client, ok := policyClients[providerProxyPolicies[i]]; ok { + providerHTTPClients[i] = client + } else { + client = newProviderHTTPClient(providerProxyPolicies[i], providers[i].Name, sharedClient, dialer, responseHeaderTimeout) + policyClients[providerProxyPolicies[i]] = client + providerHTTPClients[i] = client + } if len(providerKeys[i]) == 0 { providerKeys[i] = []string{""} } @@ -309,8 +329,7 @@ func newClientProxyWithGlobalProxy(clientType ClientType, mode config.ClientMode geminiStreamKeyIndex: geminiStreamKeyIndex, telemetry: store, providerHTTPClients: providerHTTPClients, - providerProxyModes: providerProxyModes, - providerProxyURLs: providerProxyURLs, + providerProxyPolicies: providerProxyPolicies, deactivated: make([]providerDeactivation, len(providers)), keyDeactivated: keyDeactivated, providerBusy: make([]providerBusyState, len(providers)), @@ -343,32 +362,32 @@ func newUpstreamHTTPClient(dialer *net.Dialer, responseHeaderTimeout time.Durati } } -func effectiveProviderProxySettings(provider config.Provider, globalMode config.ProviderProxyMode, globalURL string) (config.ProviderProxyMode, string) { +func effectiveProviderProxyPolicy(provider config.Provider, globalMode config.GlobalUpstreamProxyMode, globalURL string) upstreamProxyPolicyKey { switch provider.NormalizedProxyMode() { case config.ProviderProxyModeDirect: - return config.ProviderProxyModeDirect, "" + return upstreamProxyPolicyKey{mode: upstreamProxyPolicyDirect} case config.ProviderProxyModeCustom: - return config.ProviderProxyModeCustom, provider.NormalizedProxyURL() + return upstreamProxyPolicyKey{mode: upstreamProxyPolicyCustom, url: provider.NormalizedProxyURL()} default: switch globalMode { - case config.ProviderProxyModeDirect: - return config.ProviderProxyModeDirect, "" - case config.ProviderProxyModeCustom: - return config.ProviderProxyModeCustom, strings.TrimSpace(globalURL) + case config.GlobalUpstreamProxyModeDirect: + return upstreamProxyPolicyKey{mode: upstreamProxyPolicyDirect} + case config.GlobalUpstreamProxyModeCustom: + return upstreamProxyPolicyKey{mode: upstreamProxyPolicyCustom, url: strings.TrimSpace(globalURL)} default: - return config.ProviderProxyModeInherit, "" + return upstreamProxyPolicyKey{mode: upstreamProxyPolicyEnvironment} } } } -func newProviderHTTPClient(mode config.ProviderProxyMode, proxyURLRaw string, providerName string, sharedClient *http.Client, dialer *net.Dialer, responseHeaderTimeout time.Duration) *http.Client { - switch mode { - case config.ProviderProxyModeDirect: +func newProviderHTTPClient(policy upstreamProxyPolicyKey, providerName string, sharedClient *http.Client, dialer *net.Dialer, responseHeaderTimeout time.Duration) *http.Client { + switch policy.mode { + case upstreamProxyPolicyDirect: return newUpstreamHTTPClient(dialer, responseHeaderTimeout, nil) - case config.ProviderProxyModeCustom: - proxyURL, err := config.ParseProxyURL(proxyURLRaw) + case upstreamProxyPolicyCustom: + proxyURL, err := config.ParseProxyURL(policy.url) if err != nil { - logger.Warn("invalid custom proxy for provider %s; falling back to inherited proxy settings", providerName) + logger.Warn("invalid custom proxy for provider %s; falling back to environment proxy settings", providerName) return sharedClient } return newUpstreamHTTPClient(dialer, responseHeaderTimeout, http.ProxyURL(proxyURL)) @@ -720,7 +739,7 @@ func (r *Router) reloadProviderConfigsLocked() error { return nil } -func newReloadedClientProxy(clientType ClientType, mode config.ClientMode, pinnedProvider string, providers []config.Provider, reactivateAfter time.Duration, upstreamIdle time.Duration, responseHeaderTimeout time.Duration, cbCfg circuitBreakerConfig, routing routingRuntimeSettings, globalProxyMode config.ProviderProxyMode, globalProxyURL string, old *ClientProxy, telemetryStore *telemetry.Store) *ClientProxy { +func newReloadedClientProxy(clientType ClientType, mode config.ClientMode, pinnedProvider string, providers []config.Provider, reactivateAfter time.Duration, upstreamIdle time.Duration, responseHeaderTimeout time.Duration, cbCfg circuitBreakerConfig, routing routingRuntimeSettings, globalProxyMode config.GlobalUpstreamProxyMode, globalProxyURL string, old *ClientProxy, telemetryStore *telemetry.Store) *ClientProxy { cp := newClientProxyWithGlobalProxy(clientType, mode, pinnedProvider, providers, reactivateAfter, upstreamIdle, responseHeaderTimeout, cbCfg, globalProxyMode, globalProxyURL, telemetryStore) cp.applyRoutingRuntimeSettings(routing) if old != nil { @@ -752,7 +771,7 @@ func (cp *ClientProxy) inheritRuntimeState(old *ClientProxy) { if !ok { continue } - if !sameProviderRuntimeIdentity(cp.providers[newIdx], cp.providerProxyModes[newIdx], cp.providerProxyURLs[newIdx], old.providers[oldIdx], old.providerProxyModes[oldIdx], old.providerProxyURLs[oldIdx]) { + if !sameProviderRuntimeIdentity(cp.providers[newIdx], cp.providerProxyPolicies[newIdx], old.providers[oldIdx], old.providerProxyPolicies[oldIdx]) { continue } cp.deactivated[newIdx] = old.deactivated[oldIdx] @@ -767,7 +786,7 @@ func (cp *ClientProxy) inheritRuntimeState(old *ClientProxy) { if !ok { continue } - if !sameProviderRuntimeIdentity(cp.providers[newIdx], cp.providerProxyModes[newIdx], cp.providerProxyURLs[newIdx], old.providers[oldIdx], old.providerProxyModes[oldIdx], old.providerProxyURLs[oldIdx]) { + if !sameProviderRuntimeIdentity(cp.providers[newIdx], cp.providerProxyPolicies[newIdx], old.providers[oldIdx], old.providerProxyPolicies[oldIdx]) { continue } newByOldIndex[oldIdx] = newIdx @@ -906,11 +925,10 @@ func inheritStickyRuntimeState(dst *ClientProxy, src *ClientProxy, indexMap map[ } } -func sameProviderRuntimeIdentity(a config.Provider, aMode config.ProviderProxyMode, aURL string, b config.Provider, bMode config.ProviderProxyMode, bURL string) bool { +func sameProviderRuntimeIdentity(a config.Provider, aPolicy upstreamProxyPolicyKey, b config.Provider, bPolicy upstreamProxyPolicyKey) bool { return a.Name == b.Name && strings.TrimSpace(a.BaseURL) == strings.TrimSpace(b.BaseURL) && - aMode == bMode && - strings.TrimSpace(aURL) == strings.TrimSpace(bURL) + aPolicy == bPolicy } func providerIndexByName(providers []config.Provider, name string) int { diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 04505d6..2bf016b 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -154,7 +154,7 @@ func TestNewClientProxy_UsesProviderSpecificProxyModes(t *testing.T) { t.Setenv("HTTP_PROXY", "http://env-proxy:8080") cp := newClientProxy(ClientOpenAI, config.ClientModeAuto, "", []config.Provider{ - {Name: "inherit", BaseURL: "http://inherit.example", APIKey: "k1", Priority: 1}, + {Name: "default", BaseURL: "http://default.example", APIKey: "k1", Priority: 1}, {Name: "direct", BaseURL: "http://direct.example", APIKey: "k2", ProxyMode: config.ProviderProxyModeDirect, Priority: 2}, {Name: "custom", BaseURL: "http://custom.example", APIKey: "k3", ProxyMode: config.ProviderProxyModeCustom, ProxyURL: "http://custom-proxy:9090", Priority: 3}, }, time.Hour, 0, testResponseHeaderTimeout, circuitBreakerConfig{}) @@ -164,16 +164,16 @@ func TestNewClientProxy_UsesProviderSpecificProxyModes(t *testing.T) { t.Fatalf("http.NewRequest: %v", err) } - inheritTransport, ok := cp.upstreamHTTPClient(0).Transport.(*http.Transport) + defaultTransport, ok := cp.upstreamHTTPClient(0).Transport.(*http.Transport) if !ok { - t.Fatalf("inherit transport type = %T", cp.upstreamHTTPClient(0).Transport) + t.Fatalf("default transport type = %T", cp.upstreamHTTPClient(0).Transport) } - inheritProxy, err := inheritTransport.Proxy(req) + defaultProxy, err := defaultTransport.Proxy(req) if err != nil { - t.Fatalf("inherit proxy: %v", err) + t.Fatalf("default proxy: %v", err) } - if inheritProxy == nil || inheritProxy.String() != "http://env-proxy:8080" { - t.Fatalf("inherit proxy = %v, want http://env-proxy:8080", inheritProxy) + if defaultProxy == nil || defaultProxy.String() != "http://env-proxy:8080" { + t.Fatalf("default proxy = %v, want http://env-proxy:8080", defaultProxy) } directTransport, ok := cp.upstreamHTTPClient(1).Transport.(*http.Transport) @@ -203,7 +203,7 @@ func TestNewClientProxy_UsesProviderSpecificProxyModes(t *testing.T) { } } -func TestNewClientProxyWithGlobalProxy_UsesGlobalDefaultForInheritedProviders(t *testing.T) { +func TestNewClientProxyWithGlobalProxy_UsesGlobalDefaultForDefaultProviders(t *testing.T) { t.Parallel() req, err := http.NewRequest(http.MethodGet, "http://upstream.example/v1/test", nil) @@ -212,8 +212,8 @@ func TestNewClientProxyWithGlobalProxy_UsesGlobalDefaultForInheritedProviders(t } directCP := newClientProxyWithGlobalProxy(ClientOpenAI, config.ClientModeAuto, "", []config.Provider{ - {Name: "inherit", BaseURL: "http://inherit.example", APIKey: "k1", Priority: 1}, - }, time.Hour, 0, testResponseHeaderTimeout, circuitBreakerConfig{}, config.ProviderProxyModeDirect, "") + {Name: "default", BaseURL: "http://default.example", APIKey: "k1", Priority: 1}, + }, time.Hour, 0, testResponseHeaderTimeout, circuitBreakerConfig{}, config.GlobalUpstreamProxyModeDirect, "") directTransport, ok := directCP.upstreamHTTPClient(0).Transport.(*http.Transport) if !ok { @@ -230,8 +230,8 @@ func TestNewClientProxyWithGlobalProxy_UsesGlobalDefaultForInheritedProviders(t } customCP := newClientProxyWithGlobalProxy(ClientOpenAI, config.ClientModeAuto, "", []config.Provider{ - {Name: "inherit", BaseURL: "http://inherit.example", APIKey: "k1", Priority: 1}, - }, time.Hour, 0, testResponseHeaderTimeout, circuitBreakerConfig{}, config.ProviderProxyModeCustom, "http://global-proxy:8081") + {Name: "default", BaseURL: "http://default.example", APIKey: "k1", Priority: 1}, + }, time.Hour, 0, testResponseHeaderTimeout, circuitBreakerConfig{}, config.GlobalUpstreamProxyModeCustom, "http://global-proxy:8081") customTransport, ok := customCP.upstreamHTTPClient(0).Transport.(*http.Transport) if !ok { @@ -246,6 +246,36 @@ func TestNewClientProxyWithGlobalProxy_UsesGlobalDefaultForInheritedProviders(t } } +func TestNewClientProxyWithGlobalProxy_SharesHTTPClientsByEffectivePolicy(t *testing.T) { + t.Parallel() + + cp := newClientProxyWithGlobalProxy(ClientOpenAI, config.ClientModeAuto, "", []config.Provider{ + {Name: "default-a", BaseURL: "http://default-a.example", APIKey: "k1", Priority: 1}, + {Name: "default-b", BaseURL: "http://default-b.example", APIKey: "k2", Priority: 2}, + {Name: "direct-a", BaseURL: "http://direct-a.example", APIKey: "k3", ProxyMode: config.ProviderProxyModeDirect, Priority: 3}, + {Name: "direct-b", BaseURL: "http://direct-b.example", APIKey: "k4", ProxyMode: config.ProviderProxyModeDirect, Priority: 4}, + {Name: "custom-a", BaseURL: "http://custom-a.example", APIKey: "k5", ProxyMode: config.ProviderProxyModeCustom, ProxyURL: "http://custom-proxy:9090", Priority: 5}, + {Name: "custom-b", BaseURL: "http://custom-b.example", APIKey: "k6", ProxyMode: config.ProviderProxyModeCustom, ProxyURL: "http://custom-proxy:9090", Priority: 6}, + {Name: "custom-c", BaseURL: "http://custom-c.example", APIKey: "k7", ProxyMode: config.ProviderProxyModeCustom, ProxyURL: "http://other-proxy:9090", Priority: 7}, + }, time.Hour, 0, testResponseHeaderTimeout, circuitBreakerConfig{}, config.GlobalUpstreamProxyModeEnvironment, "") + + if cp.upstreamHTTPClient(0) != cp.upstreamHTTPClient(1) { + t.Fatalf("default providers should share the environment client") + } + if cp.upstreamHTTPClient(2) != cp.upstreamHTTPClient(3) { + t.Fatalf("direct providers should share the direct client") + } + if cp.upstreamHTTPClient(4) != cp.upstreamHTTPClient(5) { + t.Fatalf("matching custom proxy URLs should share the same client") + } + if cp.upstreamHTTPClient(4) == cp.upstreamHTTPClient(6) { + t.Fatalf("different custom proxy URLs should not share the same client") + } + if cp.upstreamHTTPClient(0) == cp.upstreamHTTPClient(2) { + t.Fatalf("environment and direct policies should not share the same client") + } +} + func TestNewClientProxy_UsesCustomSocksProxy(t *testing.T) { t.Parallel() diff --git a/internal/proxy/reload_state_test.go b/internal/proxy/reload_state_test.go index 15c5f79..44dfa2c 100644 --- a/internal/proxy/reload_state_test.go +++ b/internal/proxy/reload_state_test.go @@ -529,7 +529,7 @@ func TestReloadProviderConfigsLocked_DoesNotPreserveSuppressionStateWhenProxyCha } } -func TestReloadProviderConfigsLocked_DoesNotPreserveSuppressionStateWhenGlobalProxyChangesForInheritedProvider(t *testing.T) { +func TestReloadProviderConfigsLocked_DoesNotPreserveSuppressionStateWhenGlobalProxyChangesForDefaultProvider(t *testing.T) { router, dir := newReloadTestRouter(t) oldProxy := router.proxies[ClientOpenAI] now := time.Now() @@ -554,7 +554,7 @@ func TestReloadProviderConfigsLocked_DoesNotPreserveSuppressionStateWhenGlobalPr global := config.DefaultGlobalConfig() global.ListenAddr = "127.0.0.1" global.Port = 3333 - global.UpstreamProxyMode = config.ProviderProxyModeDirect + global.UpstreamProxyMode = config.GlobalUpstreamProxyModeDirect writeProxyReloadFixture(t, dir, global, config.ClientConfig{ Mode: config.ClientModeAuto, Providers: []config.Provider{ diff --git a/internal/web/api.go b/internal/web/api.go index fd66c6e..eb338d3 100644 --- a/internal/web/api.go +++ b/internal/web/api.go @@ -111,13 +111,12 @@ func (a *API) HandleUpdateGlobalConfig(w http.ResponseWriter, r *http.Request) { if strings.TrimSpace(req.ResponseHeaderTimeout) != "" { cfg.Global.ResponseHeaderTimeout = req.ResponseHeaderTimeout } - if strings.TrimSpace(req.UpstreamProxyMode) != "" { - cfg.Global.UpstreamProxyMode = config.ProviderProxyMode(strings.ToLower(strings.TrimSpace(req.UpstreamProxyMode))) - if cfg.Global.UpstreamProxyMode == config.ProviderProxyModeCustom { - cfg.Global.UpstreamProxyURL = strings.TrimSpace(req.UpstreamProxyURL) - } else { - cfg.Global.UpstreamProxyURL = "" - } + if err := config.ApplyUpstreamProxySettings(&cfg.Global, config.UpstreamProxySettingsPatch{ + Mode: req.UpstreamProxyMode, + URL: req.UpstreamProxyURL, + }); err != nil { + writeError(w, err.Error(), http.StatusBadRequest) + return } cfg.Global.MaxRequestBody = req.MaxRequestBodyBytes cfg.Global.LogDir = req.LogDir @@ -274,7 +273,6 @@ func (a *API) HandleAddProvider(w http.ResponseWriter, r *http.Request) { return } req.Overrides = normalizeProviderOverrideRequest(req.Overrides) - normalizeProviderProxyRequest(&req) if err := validateProviderOverrideRequest(clientType, req.Overrides); err != nil { writeError(w, err.Error(), http.StatusBadRequest) return @@ -341,7 +339,10 @@ func (a *API) HandleAddProvider(w http.ResponseWriter, r *http.Request) { Enabled: req.Enabled, } applyProviderOverrides(&provider, req) - if err := applyProviderProxySettings(&provider, req, false); err != nil { + if err := config.ApplyProviderProxySettings(&provider, config.ProviderProxySettingsPatch{ + Mode: req.ProxyMode, + URL: req.ProxyURL, + }, false); err != nil { writeError(w, err.Error(), http.StatusBadRequest) return } @@ -375,7 +376,6 @@ func (a *API) HandleUpdateProvider(w http.ResponseWriter, r *http.Request) { return } req.Overrides = normalizeProviderOverrideRequest(req.Overrides) - normalizeProviderProxyRequest(&req) if err := validateProviderOverrideRequest(clientType, req.Overrides); err != nil { writeError(w, err.Error(), http.StatusBadRequest) return @@ -981,14 +981,6 @@ func normalizeProviderKeys(req ProviderRequest) ([]string, error) { return keys, nil } -func lowerTrimStringPtr(v *string) *string { - if v == nil { - return nil - } - trimmed := strings.ToLower(strings.TrimSpace(*v)) - return &trimmed -} - func trimStringPtr(v *string) *string { if v == nil { return nil @@ -1024,14 +1016,6 @@ func providerOverrideSupportForClient(clientType string) providerOverrideSupport return support } -func normalizeProviderProxyRequest(req *ProviderRequest) { - if req == nil { - return - } - req.ProxyMode = lowerTrimStringPtr(req.ProxyMode) - req.ProxyURL = trimStringPtr(req.ProxyURL) -} - func normalizeProviderOverrideRequest(overrides *ProviderOverridesRequest) *ProviderOverridesRequest { if overrides == nil { return nil @@ -1081,68 +1065,6 @@ func validateProviderOverrideRequest(clientType string, overrides *ProviderOverr return nil } -func validateProviderProxyModeValue(mode string) error { - switch config.ProviderProxyMode(strings.TrimSpace(mode)) { - case config.ProviderProxyModeInherit, config.ProviderProxyModeDirect, config.ProviderProxyModeCustom: - return nil - default: - return fmt.Errorf("proxy_mode must be one of inherit, direct, custom") - } -} - -func validateProviderProxyURLValue(raw string) error { - return config.ValidateProxyURL(raw) -} - -func applyProviderProxySettings(provider *config.Provider, req ProviderRequest, isUpdate bool) error { - if provider == nil { - return nil - } - if !isUpdate && req.ProxyMode == nil && req.ProxyURL == nil { - provider.ProxyMode = config.ProviderProxyModeInherit - provider.ProxyURL = "" - return nil - } - if req.ProxyMode == nil { - if req.ProxyURL != nil { - return fmt.Errorf("proxy_url requires proxy_mode") - } - return nil - } - - modeValue := strings.TrimSpace(*req.ProxyMode) - if err := validateProviderProxyModeValue(modeValue); err != nil { - return err - } - - mode := config.ProviderProxyMode(modeValue) - proxyURL := provider.NormalizedProxyURL() - if req.ProxyURL != nil { - proxyURL = strings.TrimSpace(*req.ProxyURL) - } - switch mode { - case config.ProviderProxyModeInherit, config.ProviderProxyModeDirect: - if proxyURL != "" && req.ProxyURL != nil { - return fmt.Errorf("proxy_url requires proxy_mode custom") - } - provider.ProxyMode = mode - provider.ProxyURL = "" - return nil - case config.ProviderProxyModeCustom: - if proxyURL == "" { - return fmt.Errorf("proxy_url is required when proxy_mode=custom") - } - if err := validateProviderProxyURLValue(proxyURL); err != nil { - return err - } - provider.ProxyMode = mode - provider.ProxyURL = proxyURL - return nil - default: - return fmt.Errorf("proxy_mode must be one of inherit, direct, custom") - } -} - func applyProviderOverrides(provider *config.Provider, req ProviderRequest) { if provider == nil || req.Overrides == nil { return @@ -1222,7 +1144,10 @@ func updateProviderInList(providers []config.Provider, name string, req Provider providers[i].Enabled = req.Enabled } applyProviderOverrides(&providers[i], req) - if err := applyProviderProxySettings(&providers[i], req, true); err != nil { + if err := config.ApplyProviderProxySettings(&providers[i], config.ProviderProxySettingsPatch{ + Mode: req.ProxyMode, + URL: req.ProxyURL, + }, true); err != nil { return false, err } return true, nil diff --git a/internal/web/api_test.go b/internal/web/api_test.go index 3dcfdbb..af3e6cc 100644 --- a/internal/web/api_test.go +++ b/internal/web/api_test.go @@ -413,7 +413,7 @@ func TestHandleUpdateGlobalConfig_AcceptsSnakeCaseNotifications(t *testing.T) { if cfg.Global.Routing.BusyBackpressure.MaxInlineWait != "12s" { t.Fatalf("expected routing.busy_backpressure.max_inline_wait=12s, got %q", cfg.Global.Routing.BusyBackpressure.MaxInlineWait) } - if cfg.Global.NormalizedUpstreamProxyMode() != config.ProviderProxyModeCustom { + if cfg.Global.NormalizedUpstreamProxyMode() != config.GlobalUpstreamProxyModeCustom { t.Fatalf("expected upstream_proxy_mode=custom, got %q", cfg.Global.NormalizedUpstreamProxyMode()) } if cfg.Global.NormalizedUpstreamProxyURL() != proxyURL { @@ -493,7 +493,7 @@ func TestHandleUpdateGlobalConfig_AllowsClearingRoutingStrings(t *testing.T) { if cfg.Global.Routing.BusyBackpressure.MaxInlineWait != "" { t.Fatalf("expected routing.busy_backpressure.max_inline_wait to be cleared, got %q", cfg.Global.Routing.BusyBackpressure.MaxInlineWait) } - if cfg.Global.NormalizedUpstreamProxyMode() != config.ProviderProxyModeDirect { + if cfg.Global.NormalizedUpstreamProxyMode() != config.GlobalUpstreamProxyModeDirect { t.Fatalf("expected upstream_proxy_mode to be direct, got %q", cfg.Global.NormalizedUpstreamProxyMode()) } if cfg.Global.NormalizedUpstreamProxyURL() != "" { @@ -536,8 +536,8 @@ func TestHandleAddProvider_AcceptsAPIKeys(t *testing.T) { if cfg.OpenAI.Providers[0].APIKey != "" { t.Fatalf("expected multi-key provider to be persisted via api_keys") } - if got := cfg.OpenAI.Providers[0].NormalizedProxyMode(); got != config.ProviderProxyModeInherit { - t.Fatalf("proxy_mode = %q, want %q", got, config.ProviderProxyModeInherit) + if got := cfg.OpenAI.Providers[0].NormalizedProxyMode(); got != config.ProviderProxyModeDefault { + t.Fatalf("proxy_mode = %q, want %q", got, config.ProviderProxyModeDefault) } if got := cfg.OpenAI.Providers[0].ModelOverride(); got != "gpt-5.4" { t.Fatalf("model = %q", got) @@ -657,7 +657,7 @@ func TestHandleAddProvider_AcceptsCustomProxy(t *testing.T) { } func TestHandleAddProvider_RejectsProxyURLWithoutCustomMode(t *testing.T) { - for _, mode := range []string{"direct", "inherit"} { + for _, mode := range []string{"direct", "default"} { t.Run(mode, func(t *testing.T) { dir := t.TempDir() api := NewAPI(dir, "test", nil) @@ -764,7 +764,7 @@ providers: }) t.Run("reject proxy url without custom mode", func(t *testing.T) { - for _, mode := range []string{"direct", "inherit"} { + for _, mode := range []string{"direct", "default"} { t.Run(mode, func(t *testing.T) { dir := t.TempDir() if err := os.WriteFile(filepath.Join(dir, "openai.yaml"), []byte(` diff --git a/internal/web/static/app.js b/internal/web/static/app.js index cb0a86c..f9c9ac7 100644 --- a/internal/web/static/app.js +++ b/internal/web/static/app.js @@ -118,7 +118,7 @@ function app() { nameHint: 'Letters, numbers, dot (.), underscore (_), and hyphen (-).', baseUrl: 'Base URL *', proxyMode: 'Proxy Mode', - proxyModeInherit: 'Inherit Environment', + proxyModeDefault: 'Use Default', proxyModeDirect: 'Direct', proxyModeCustom: 'Custom Proxy', proxyUrl: 'Proxy URL', @@ -167,9 +167,9 @@ function app() { responseHeaderTimeoutHint: 'Set to 0 to wait indefinitely for headers.', upstreamProxyMode: 'Default Upstream Proxy', upstreamProxyUrl: 'Default Proxy URL', - upstreamProxyHint: 'Used by providers whose proxy mode is Inherit.', + upstreamProxyHint: 'Used by providers whose proxy mode is Default.', upstreamProxyUrlHelp: 'Supports http://, https://, socks5://, and socks5h:// proxy URLs.', - proxyModeInherit: 'Inherit Environment', + proxyModeEnvironment: 'Use Environment', proxyModeDirect: 'Direct', proxyModeCustom: 'Custom Proxy', upstreamProxyUrlHint: 'http://127.0.0.1:7890', @@ -427,7 +427,7 @@ function app() { nameHint: '允许字母、数字、点号 (.)、下划线 (_) 和连字符 (-)。', baseUrl: 'Base URL *', proxyMode: '代理模式', - proxyModeInherit: '继承环境变量', + proxyModeDefault: '使用默认值', proxyModeDirect: '直连', proxyModeCustom: '自定义代理', proxyUrl: '代理 URL', @@ -475,9 +475,9 @@ function app() { responseHeaderTimeoutHint: '设为 0 表示无限等待响应头。', upstreamProxyMode: '默认上游代理', upstreamProxyUrl: '默认代理 URL', - upstreamProxyHint: '对代理模式为“继承”的 Provider 生效。', + upstreamProxyHint: '对代理模式为“使用默认值”的 Provider 生效。', upstreamProxyUrlHelp: '支持 http://、https://、socks5:// 和 socks5h:// 代理 URL。', - proxyModeInherit: '继承环境变量', + proxyModeEnvironment: '使用环境变量', proxyModeDirect: '直连', proxyModeCustom: '自定义代理', upstreamProxyUrlHint: 'http://127.0.0.1:7890', @@ -650,7 +650,7 @@ function app() { reactivate_after: '', upstream_idle_timeout: '', response_header_timeout: '', - upstream_proxy_mode: 'inherit', + upstream_proxy_mode: 'environment', upstream_proxy_url: '', max_request_body_bytes: 0, log_dir: '', @@ -712,7 +712,7 @@ function app() { providerForm: { name: '', base_url: '', - proxy_mode: 'inherit', + proxy_mode: 'default', proxy_url: '', proxy_url_hint: '', model: '', @@ -1515,9 +1515,14 @@ function app() { return this.tf('modal.provider.keepExistingKeys', { count }); }, + normalizeGlobalProxyMode(mode) { + const value = String(mode || '').trim().toLowerCase(); + return value || 'environment'; + }, + normalizeProviderProxyMode(mode) { const value = String(mode || '').trim().toLowerCase(); - return ['inherit', 'direct', 'custom'].includes(value) ? value : 'inherit'; + return value || 'default'; }, providerProxySummary(provider) { @@ -1528,6 +1533,9 @@ function app() { if (mode === 'custom') { return String((provider && provider.proxy_url_hint) || '').trim() || this.t('providers.proxyCustom'); } + if (mode !== 'default') { + return mode; + } return ''; }, @@ -2071,7 +2079,7 @@ function app() { this.providerForm = { name: '', base_url: '', - proxy_mode: 'inherit', + proxy_mode: 'default', proxy_url: '', proxy_url_hint: '', model: '', @@ -2098,9 +2106,16 @@ function app() { async saveGlobalConfig() { try { + const payload = { + ...this.globalConfig, + upstream_proxy_mode: this.normalizeGlobalProxyMode(this.globalConfig.upstream_proxy_mode) + }; + if (payload.upstream_proxy_mode !== 'custom') { + payload.upstream_proxy_url = ''; + } await this.apiCall('/api/config/global/update', { method: 'PUT', - body: JSON.stringify(this.globalConfig) + body: JSON.stringify(payload) }); this.showAlert('success', this.t('settings.saveSuccess')); await this.refreshStatus(); @@ -2229,7 +2244,7 @@ function app() { this.providerForm = { name: '', base_url: '', - proxy_mode: 'inherit', + proxy_mode: 'default', proxy_url: '', proxy_url_hint: '', model: '', diff --git a/internal/web/static/app.test.js b/internal/web/static/app.test.js index 379441c..6bbad92 100644 --- a/internal/web/static/app.test.js +++ b/internal/web/static/app.test.js @@ -141,7 +141,7 @@ test('saveProvider includes OpenAI override fields in payload', async () => { assert.deepEqual(calls[0].options, { name: 'openai-primary', base_url: 'https://example.com', - proxy_mode: 'inherit', + proxy_mode: 'default', priority: 1, enabled: true, overrides: { @@ -193,7 +193,7 @@ test('saveProvider includes Claude thinking budget override in payload', async ( assert.deepEqual(calls[0].options, { name: 'claude-primary', base_url: 'https://example.com', - proxy_mode: 'inherit', + proxy_mode: 'default', priority: 2, enabled: false, overrides: { @@ -289,7 +289,7 @@ test('openAddProviderModal sets next priority for provider form', () => { assert.equal(state.showAddProviderModal, true); assert.equal(state.providerForm.priority, 4); - assert.equal(state.providerForm.proxy_mode, 'inherit'); + assert.equal(state.providerForm.proxy_mode, 'default'); }); test('editProvider hydrates override fields directly into the form', () => { @@ -367,7 +367,7 @@ test('saveProvider omits unsupported override fields for gemini', async () => { state.providerForm = { name: 'gemini-primary', base_url: 'https://example.com', - proxy_mode: 'inherit', + proxy_mode: 'default', proxy_url: '', proxy_url_hint: '', model: 'gemini-2.5-pro', @@ -393,9 +393,32 @@ test('saveProvider omits unsupported override fields for gemini', async () => { assert.deepEqual(calls[0].options, { name: 'gemini-primary', base_url: 'https://example.com', - proxy_mode: 'inherit', + proxy_mode: 'default', priority: 1, enabled: true, api_key: 'key-1' }); }); + +test('saveGlobalConfig normalizes and clears non-custom upstream proxy settings', async () => { + const state = loadApp(); + const calls = []; + state.globalConfig = { + ...state.globalConfig, + upstream_proxy_mode: 'DIRECT', + upstream_proxy_url: 'http://127.0.0.1:7890' + }; + state.apiCall = async (url, options) => { + calls.push({ url, options: JSON.parse(options.body) }); + return {}; + }; + state.showAlert = () => {}; + state.refreshStatus = async () => {}; + + await state.saveGlobalConfig(); + + assert.equal(calls.length, 1); + assert.equal(calls[0].url, '/api/config/global/update'); + assert.equal(calls[0].options.upstream_proxy_mode, 'direct'); + assert.equal(calls[0].options.upstream_proxy_url, ''); +}); diff --git a/internal/web/static/index.html b/internal/web/static/index.html index f35263a..5f1f255 100644 --- a/internal/web/static/index.html +++ b/internal/web/static/index.html @@ -215,7 +215,7 @@

+ x-show="normalizeProviderProxyMode(provider.proxy_mode) !== 'default'">
@@ -323,13 +323,13 @@

-
+
@@ -859,7 +859,7 @@

diff --git a/internal/web/types.go b/internal/web/types.go index 3c2401a..a64ae19 100644 --- a/internal/web/types.go +++ b/internal/web/types.go @@ -22,8 +22,8 @@ type GlobalConfigRequest struct { ReactivateAfter string `json:"reactivate_after"` UpstreamIdleTimeout string `json:"upstream_idle_timeout"` ResponseHeaderTimeout string `json:"response_header_timeout"` - UpstreamProxyMode string `json:"upstream_proxy_mode"` - UpstreamProxyURL string `json:"upstream_proxy_url"` + UpstreamProxyMode *string `json:"upstream_proxy_mode,omitempty"` + UpstreamProxyURL *string `json:"upstream_proxy_url,omitempty"` MaxRequestBodyBytes int64 `json:"max_request_body_bytes"` LogDir string `json:"log_dir"` LogRetentionDays int `json:"log_retention_days"` diff --git a/internal/web/yaml_format.go b/internal/web/yaml_format.go index 4b0f2ec..f24f8b0 100644 --- a/internal/web/yaml_format.go +++ b/internal/web/yaml_format.go @@ -62,7 +62,7 @@ func formatClientConfigYAML(clientType string, cc config.ClientConfig) []byte { writeBufferString(&b, fmt.Sprintf(" - name: %s\n", yamlDoubleQuote(p.Name))) writeBufferString(&b, fmt.Sprintf(" base_url: %s\n", yamlDoubleQuote(p.BaseURL))) - if p.NormalizedProxyMode() != config.ProviderProxyModeInherit { + if p.NormalizedProxyMode() != config.ProviderProxyModeDefault { writeBufferString(&b, fmt.Sprintf(" proxy_mode: %s\n", yamlDoubleQuote(string(p.NormalizedProxyMode())))) } if p.NormalizedProxyMode() == config.ProviderProxyModeCustom && p.NormalizedProxyURL() != "" { @@ -117,8 +117,8 @@ func formatGlobalConfigYAML(gc config.GlobalConfig) []byte { writeBufferString(&b, "# How long to wait for the upstream to return response headers.\n") writeBufferString(&b, "# Set to 0 to disable.\n") writeBufferString(&b, fmt.Sprintf("response_header_timeout: %s\n", yamlDoubleQuote(strings.TrimSpace(gc.ResponseHeaderTimeout)))) - writeBufferString(&b, "# Default upstream proxy for providers whose proxy_mode is inherit.\n") - writeBufferString(&b, fmt.Sprintf("upstream_proxy_mode: %s # inherit | direct | custom\n", yamlDoubleQuote(string(gc.NormalizedUpstreamProxyMode())))) + writeBufferString(&b, "# Default upstream proxy for providers whose proxy_mode is default.\n") + writeBufferString(&b, fmt.Sprintf("upstream_proxy_mode: %s # environment | direct | custom\n", yamlDoubleQuote(string(gc.NormalizedUpstreamProxyMode())))) writeBufferString(&b, "# Supported proxy URLs: http://, https://, socks5://, socks5h://\n") writeBufferString(&b, fmt.Sprintf("upstream_proxy_url: %s\n", yamlDoubleQuote(gc.NormalizedUpstreamProxyURL()))) writeBufferString(&b, "# Max request body size in bytes (clipal buffers request bodies for retries).\n") diff --git a/internal/web/yaml_format_test.go b/internal/web/yaml_format_test.go index 39415e3..06f9130 100644 --- a/internal/web/yaml_format_test.go +++ b/internal/web/yaml_format_test.go @@ -153,7 +153,7 @@ func TestFormatClientConfigYAML_SingleNormalizedKeyUsesAPIKeyField(t *testing.T) func TestFormatClientConfigYAML_WritesProxySettingsOnlyWhenNeeded(t *testing.T) { cc := config.ClientConfig{ Providers: []config.Provider{ - {Name: "inherit", BaseURL: "https://a.example", APIKey: "k1", Priority: 1, Enabled: boolPtr(true)}, + {Name: "default", BaseURL: "https://a.example", APIKey: "k1", Priority: 1, Enabled: boolPtr(true)}, {Name: "direct", BaseURL: "https://b.example", APIKey: "k2", ProxyMode: config.ProviderProxyModeDirect, Priority: 2, Enabled: boolPtr(true)}, {Name: "custom", BaseURL: "https://c.example", APIKey: "k3", ProxyMode: config.ProviderProxyModeCustom, ProxyURL: "http://127.0.0.1:7890", Priority: 3, Enabled: boolPtr(true)}, }, @@ -352,7 +352,7 @@ func TestFormatGlobalConfigYAML_RoundTripAndEscapesSpecialCharacters(t *testing. gc.LogDir = "logs\r\nfolder\tcontrol\x01" gc.LogRetentionDays = 7 gc.LogStdout = boolPtr(false) - gc.UpstreamProxyMode = config.ProviderProxyModeCustom + gc.UpstreamProxyMode = config.GlobalUpstreamProxyModeCustom gc.UpstreamProxyURL = "http://127.0.0.1:7890" gc.Notifications.ProviderSwitch = boolPtr(false) @@ -360,7 +360,7 @@ func TestFormatGlobalConfigYAML_RoundTripAndEscapesSpecialCharacters(t *testing. for _, want := range []string{ `listen_addr: "host\"quoted\"\\path"`, `log_dir: "logs\r\nfolder\tcontrol\x01"`, - `upstream_proxy_mode: "custom" # inherit | direct | custom`, + `upstream_proxy_mode: "custom" # environment | direct | custom`, `upstream_proxy_url: "http://127.0.0.1:7890"`, `log_retention_days: 7 # default 7 days`, `log_stdout: false`, @@ -385,7 +385,7 @@ func TestFormatGlobalConfigYAML_RoundTripAndEscapesSpecialCharacters(t *testing. if loaded.Global.LogDir != gc.LogDir { t.Fatalf("log_dir = %q, want %q", loaded.Global.LogDir, gc.LogDir) } - if loaded.Global.NormalizedUpstreamProxyMode() != config.ProviderProxyModeCustom { + if loaded.Global.NormalizedUpstreamProxyMode() != config.GlobalUpstreamProxyModeCustom { t.Fatalf("upstream_proxy_mode = %q, want custom", loaded.Global.NormalizedUpstreamProxyMode()) } if loaded.Global.NormalizedUpstreamProxyURL() != "http://127.0.0.1:7890" { From 463dc07413b0caf0befc049a1aeacc0659c41732 Mon Sep 17 00:00:00 2001 From: Thomas Date: Mon, 13 Apr 2026 02:14:53 +0800 Subject: [PATCH 3/3] refactor(proxy): canonicalize proxy URLs for policy key and identity comparison Add CanonicalProxyURL in config layer to normalize host case and strip default ports, so semantically equivalent proxy URLs produce identical policy keys. This ensures connection pool reuse and runtime state inheritance work correctly across reload even when URL text differs (e.g. host case, default port presence). --- internal/config/canon_edge_test.go | 17 +++++ internal/config/config.go | 12 ++++ internal/config/proxy_url.go | 74 ++++++++++++++++++++ internal/config/proxy_url_test.go | 102 ++++++++++++++++++++++++++++ internal/proxy/proxy.go | 18 ++--- internal/proxy/proxy_test.go | 4 +- internal/proxy/reload_state_test.go | 68 +++++++++++++++++++ 7 files changed, 285 insertions(+), 10 deletions(-) create mode 100644 internal/config/canon_edge_test.go create mode 100644 internal/config/proxy_url_test.go diff --git a/internal/config/canon_edge_test.go b/internal/config/canon_edge_test.go new file mode 100644 index 0000000..89bd554 --- /dev/null +++ b/internal/config/canon_edge_test.go @@ -0,0 +1,17 @@ +package config + +import "testing" + +func TestXXXCanonTrailingSlash(t *testing.T) { + cases := []struct{ in, want string }{ + {"http://proxy:8080/", "http://proxy:8080/"}, + {"http://proxy:8080", "http://proxy:8080"}, + {"http://user:pass@proxy:8080", "http://user:pass@proxy:8080"}, + {"http://USER:PASS@PROXY:8080", "http://USER:PASS@proxy:8080"}, + {"http://proxy:8080/path", "http://proxy:8080/path"}, + } + for _, c := range cases { + got := CanonicalProxyURL(c.in) + t.Logf("CanonicalProxyURL(%q) = %q (want %q, match=%v)", c.in, got, c.want, got == c.want) + } +} diff --git a/internal/config/config.go b/internal/config/config.go index 10b7469..44f5448 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -352,6 +352,12 @@ func (g GlobalConfig) NormalizedUpstreamProxyURL() string { return strings.TrimSpace(g.UpstreamProxyURL) } +// CanonicalUpstreamProxyURL returns a canonicalized upstream proxy URL suitable +// for policy key construction and identity comparison. +func (g GlobalConfig) CanonicalUpstreamProxyURL() string { + return CanonicalProxyURL(g.UpstreamProxyURL) +} + func (p Provider) NormalizedProxyMode() ProviderProxyMode { mode := strings.ToLower(strings.TrimSpace(string(p.ProxyMode))) if mode == "" { @@ -364,6 +370,12 @@ func (p Provider) NormalizedProxyURL() string { return strings.TrimSpace(p.ProxyURL) } +// CanonicalProxyURL returns a canonicalized provider proxy URL suitable +// for policy key construction and identity comparison. +func (p Provider) CanonicalProxyURL() string { + return CanonicalProxyURL(p.ProxyURL) +} + func ApplyUpstreamProxySettings(global *GlobalConfig, patch UpstreamProxySettingsPatch) error { if global == nil { return nil diff --git a/internal/config/proxy_url.go b/internal/config/proxy_url.go index dbbbc43..435ee5e 100644 --- a/internal/config/proxy_url.go +++ b/internal/config/proxy_url.go @@ -2,6 +2,7 @@ package config import ( "fmt" + "net" "net/url" "strings" ) @@ -27,6 +28,79 @@ func ParseProxyURL(raw string) (*url.URL, error) { } } +// CanonicalProxyURL returns a canonical form of the given proxy URL. +// It trims whitespace, parses and validates the URL via ParseProxyURL, and +// normalizes the result so that semantically equivalent URLs +// (e.g. differing only in host case or default port) produce identical strings. +// If the URL is empty or cannot be parsed, the trimmed input is returned unchanged. +func CanonicalProxyURL(raw string) string { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return "" + } + parsed, err := ParseProxyURL(trimmed) + if err != nil { + // Unparseable URLs are returned as-is. This is safe because + // ParseProxyURL would also reject them when building the HTTP client, + // so a malformed URL can never actually be used as a proxy. + return trimmed + } + parsed.Host = canonicalHost(parsed.Host, parsed.Scheme) + return parsed.String() +} + +// canonicalHost lowercases the hostname and strips the port when it matches +// the default port for the given scheme. +func canonicalHost(host, scheme string) string { + hostname, port, err := net.SplitHostPort(host) + if err != nil { + return canonicalHostWithoutPort(host) + } + hostname = canonicalHostname(hostname) + defaultPort := defaultPortForScheme(scheme) + if port == defaultPort { + if strings.Contains(hostname, ":") { + return "[" + hostname + "]" + } + return hostname + } + return net.JoinHostPort(hostname, port) +} + +// canonicalHostWithoutPort handles bare hosts (no port) from net.SplitHostPort +// failure. For bracketed IPv6 literals like "[::1]", it strips the brackets, +// lowercases the address, and re-wraps so the result stays valid. +func canonicalHostWithoutPort(host string) string { + if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { + return "[" + canonicalHostname(host[1:len(host)-1]) + "]" + } + return canonicalHostname(host) +} + +// canonicalHostname lowercases the address or DNS name while preserving the +// case of any IPv6 zone identifier, which is interface-name text. +func canonicalHostname(hostname string) string { + if zoneSep := strings.Index(hostname, "%"); zoneSep >= 0 { + return strings.ToLower(hostname[:zoneSep]) + hostname[zoneSep:] + } + return strings.ToLower(hostname) +} + +// defaultPortForScheme returns the conventional default port for a proxy scheme. +// SOCKS proxies default to 1080; HTTP/HTTPS fall back to their standard ports. +func defaultPortForScheme(scheme string) string { + switch scheme { + case "http": + return "80" + case "https": + return "443" + case "socks5", "socks5h": + return "1080" + default: + return "" + } +} + // ValidateProxyURL reports whether a configured proxy URL is supported. func ValidateProxyURL(raw string) error { _, err := ParseProxyURL(raw) diff --git a/internal/config/proxy_url_test.go b/internal/config/proxy_url_test.go new file mode 100644 index 0000000..04702e4 --- /dev/null +++ b/internal/config/proxy_url_test.go @@ -0,0 +1,102 @@ +package config + +import ( + "testing" +) + +func TestCanonicalProxyURL(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + want string + }{ + {name: "empty", input: "", want: ""}, + {name: "whitespace only", input: " ", want: ""}, + {name: "basic http", input: "http://proxy.example:8080", want: "http://proxy.example:8080"}, + {name: "trims whitespace", input: " http://proxy.example:8080 ", want: "http://proxy.example:8080"}, + {name: "lowercases host", input: "http://PROXY.EXAMPLE:8080", want: "http://proxy.example:8080"}, + {name: "mixed case host", input: "http://Proxy.Example:8080", want: "http://proxy.example:8080"}, + {name: "removes default http port", input: "http://proxy.example:80", want: "http://proxy.example"}, + {name: "removes default https port", input: "https://proxy.example:443", want: "https://proxy.example"}, + {name: "removes default socks5 port", input: "socks5://proxy.example:1080", want: "socks5://proxy.example"}, + {name: "preserves non-default port", input: "http://proxy.example:8080", want: "http://proxy.example:8080"}, + {name: "scheme already lowercase", input: "http://proxy.example:8080", want: "http://proxy.example:8080"}, + {name: "socks5h", input: "socks5h://proxy.example:1080", want: "socks5h://proxy.example"}, + {name: "socks5h non-default port", input: "socks5h://proxy.example:9050", want: "socks5h://proxy.example:9050"}, + {name: "lowercases scheme", input: "HTTP://proxy.example:8080", want: "http://proxy.example:8080"}, + {name: "lowercases scheme and host together", input: "HTTPS://PROXY.EXAMPLE:443", want: "https://proxy.example"}, + {name: "invalid scheme returns trimmed input", input: "ftp://proxy.example:21", want: "ftp://proxy.example:21"}, + {name: "equivalent URLs produce identical canonical form", input: " HTTP://PROXY.EXAMPLE:80 ", want: "http://proxy.example"}, + {name: "ip address host", input: "http://127.0.0.1:7890", want: "http://127.0.0.1:7890"}, + {name: "ipv6 with non-default port", input: "http://[::1]:8080", want: "http://[::1]:8080"}, + {name: "ipv6 strips default http port", input: "http://[::1]:80", want: "http://[::1]"}, + {name: "ipv6 strips default https port", input: "https://[::1]:443", want: "https://[::1]"}, + {name: "ipv6 without port", input: "http://[::1]", want: "http://[::1]"}, + {name: "ipv6 full address with default port", input: "http://[2001:db8::1]:80", want: "http://[2001:db8::1]"}, + {name: "ipv6 zone with non-default port preserves zone case", input: "http://[fe80::1%25En0]:8080", want: "http://[fe80::1%25En0]:8080"}, + {name: "ipv6 zone without port preserves zone case", input: "http://[fe80::1%25En0]", want: "http://[fe80::1%25En0]"}, + {name: "ipv6 zone strips default port without lowercasing zone", input: "http://[FE80::ABCD%25En0]:80", want: "http://[fe80::abcd%25En0]"}, + {name: "preserves userinfo case", input: "http://User:Pass@PROXY.EXAMPLE:8080", want: "http://User:Pass@proxy.example:8080"}, + {name: "userinfo with default port stripped", input: "http://user:pass@proxy.example:80", want: "http://user:pass@proxy.example"}, + {name: "preserves path", input: "http://proxy.example:8080/path", want: "http://proxy.example:8080/path"}, + {name: "preserves trailing slash", input: "http://proxy.example:8080/", want: "http://proxy.example:8080/"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := CanonicalProxyURL(tt.input) + if got != tt.want { + t.Errorf("CanonicalProxyURL(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestCanonicalProxyURL_Equivalence(t *testing.T) { + t.Parallel() + + pairs := []struct{ a, b string }{ + {"http://PROXY.example:8080", "http://proxy.example:8080"}, + {"http://proxy.example:80", "http://proxy.example"}, + {" http://proxy.example:8080 ", "http://proxy.example:8080"}, + {"HTTP://PROXY.EXAMPLE:80", "http://proxy.example"}, + {"http://user:pass@PROXY.EXAMPLE:80", "http://user:pass@proxy.example"}, + } + + for _, p := range pairs { + ca := CanonicalProxyURL(p.a) + cb := CanonicalProxyURL(p.b) + if ca != cb { + t.Errorf("CanonicalProxyURL(%q) = %q, CanonicalProxyURL(%q) = %q; want equal", p.a, ca, p.b, cb) + } + } +} + +func TestProvider_CanonicalProxyURL(t *testing.T) { + t.Parallel() + + p := Provider{ProxyURL: " HTTP://PROXY.EXAMPLE:80 "} + if got := p.CanonicalProxyURL(); got != "http://proxy.example" { + t.Errorf("Provider.CanonicalProxyURL() = %q, want %q", got, "http://proxy.example") + } + // Normalized accessor preserves original text + if got := p.NormalizedProxyURL(); got != "HTTP://PROXY.EXAMPLE:80" { + t.Errorf("Provider.NormalizedProxyURL() = %q, want %q", got, "HTTP://PROXY.EXAMPLE:80") + } +} + +func TestGlobalConfig_CanonicalUpstreamProxyURL(t *testing.T) { + t.Parallel() + + g := GlobalConfig{UpstreamProxyURL: " HTTP://PROXY.EXAMPLE:80 "} + if got := g.CanonicalUpstreamProxyURL(); got != "http://proxy.example" { + t.Errorf("GlobalConfig.CanonicalUpstreamProxyURL() = %q, want %q", got, "http://proxy.example") + } + // Normalized accessor preserves original text + if got := g.NormalizedUpstreamProxyURL(); got != "HTTP://PROXY.EXAMPLE:80" { + t.Errorf("GlobalConfig.NormalizedUpstreamProxyURL() = %q, want %q", got, "HTTP://PROXY.EXAMPLE:80") + } +} diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 4c9fa06..abd9c8b 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -220,19 +220,19 @@ func NewRouter(cfg *config.Config) *Router { // Initialize client proxies claudeProviders := config.GetEnabledProviders(cfg.Claude) if len(claudeProviders) > 0 { - r.proxies[ClientClaude] = newClientProxyWithGlobalProxy(ClientClaude, cfg.Claude.Mode, cfg.Claude.PinnedProvider, claudeProviders, durations.ReactivateAfter, durations.UpstreamIdleTimeout, durations.ResponseHeaderTimeout, cbCfg, cfg.Global.NormalizedUpstreamProxyMode(), cfg.Global.NormalizedUpstreamProxyURL(), telemetryStore) + r.proxies[ClientClaude] = newClientProxyWithGlobalProxy(ClientClaude, cfg.Claude.Mode, cfg.Claude.PinnedProvider, claudeProviders, durations.ReactivateAfter, durations.UpstreamIdleTimeout, durations.ResponseHeaderTimeout, cbCfg, cfg.Global.NormalizedUpstreamProxyMode(), cfg.Global.CanonicalUpstreamProxyURL(), telemetryStore) r.proxies[ClientClaude].applyRoutingRuntimeSettings(routingCfg) } codexProviders := config.GetEnabledProviders(cfg.OpenAI) if len(codexProviders) > 0 { - r.proxies[ClientOpenAI] = newClientProxyWithGlobalProxy(ClientOpenAI, cfg.OpenAI.Mode, cfg.OpenAI.PinnedProvider, codexProviders, durations.ReactivateAfter, durations.UpstreamIdleTimeout, durations.ResponseHeaderTimeout, cbCfg, cfg.Global.NormalizedUpstreamProxyMode(), cfg.Global.NormalizedUpstreamProxyURL(), telemetryStore) + r.proxies[ClientOpenAI] = newClientProxyWithGlobalProxy(ClientOpenAI, cfg.OpenAI.Mode, cfg.OpenAI.PinnedProvider, codexProviders, durations.ReactivateAfter, durations.UpstreamIdleTimeout, durations.ResponseHeaderTimeout, cbCfg, cfg.Global.NormalizedUpstreamProxyMode(), cfg.Global.CanonicalUpstreamProxyURL(), telemetryStore) r.proxies[ClientOpenAI].applyRoutingRuntimeSettings(routingCfg) } geminiProviders := config.GetEnabledProviders(cfg.Gemini) if len(geminiProviders) > 0 { - r.proxies[ClientGemini] = newClientProxyWithGlobalProxy(ClientGemini, cfg.Gemini.Mode, cfg.Gemini.PinnedProvider, geminiProviders, durations.ReactivateAfter, durations.UpstreamIdleTimeout, durations.ResponseHeaderTimeout, cbCfg, cfg.Global.NormalizedUpstreamProxyMode(), cfg.Global.NormalizedUpstreamProxyURL(), telemetryStore) + r.proxies[ClientGemini] = newClientProxyWithGlobalProxy(ClientGemini, cfg.Gemini.Mode, cfg.Gemini.PinnedProvider, geminiProviders, durations.ReactivateAfter, durations.UpstreamIdleTimeout, durations.ResponseHeaderTimeout, cbCfg, cfg.Global.NormalizedUpstreamProxyMode(), cfg.Global.CanonicalUpstreamProxyURL(), telemetryStore) r.proxies[ClientGemini].applyRoutingRuntimeSettings(routingCfg) } @@ -367,13 +367,13 @@ func effectiveProviderProxyPolicy(provider config.Provider, globalMode config.Gl case config.ProviderProxyModeDirect: return upstreamProxyPolicyKey{mode: upstreamProxyPolicyDirect} case config.ProviderProxyModeCustom: - return upstreamProxyPolicyKey{mode: upstreamProxyPolicyCustom, url: provider.NormalizedProxyURL()} + return upstreamProxyPolicyKey{mode: upstreamProxyPolicyCustom, url: provider.CanonicalProxyURL()} default: switch globalMode { case config.GlobalUpstreamProxyModeDirect: return upstreamProxyPolicyKey{mode: upstreamProxyPolicyDirect} case config.GlobalUpstreamProxyModeCustom: - return upstreamProxyPolicyKey{mode: upstreamProxyPolicyCustom, url: strings.TrimSpace(globalURL)} + return upstreamProxyPolicyKey{mode: upstreamProxyPolicyCustom, url: globalURL} default: return upstreamProxyPolicyKey{mode: upstreamProxyPolicyEnvironment} } @@ -708,16 +708,18 @@ func (r *Router) reloadProviderConfigsLocked() error { logger.Warn("invalid runtime durations; defaulting to reactivate_after=1h upstream_idle_timeout=3m response_header_timeout=2m: %v", err) } cbCfg := normalizeCircuitBreakerConfig(newCfg.Global.CircuitBreaker) + globalProxyMode := newCfg.Global.NormalizedUpstreamProxyMode() + globalProxyURL := newCfg.Global.CanonicalUpstreamProxyURL() newProxies := make(map[ClientType]*ClientProxy) if ps := config.GetEnabledProviders(newCfg.Claude); len(ps) > 0 { - newProxies[ClientClaude] = newReloadedClientProxy(ClientClaude, newCfg.Claude.Mode, newCfg.Claude.PinnedProvider, ps, durations.ReactivateAfter, durations.UpstreamIdleTimeout, durations.ResponseHeaderTimeout, cbCfg, routingRuntimeSettingsFromConfig(newCfg.Global.Routing), newCfg.Global.NormalizedUpstreamProxyMode(), newCfg.Global.NormalizedUpstreamProxyURL(), oldProxies[ClientClaude], r.telemetry) + newProxies[ClientClaude] = newReloadedClientProxy(ClientClaude, newCfg.Claude.Mode, newCfg.Claude.PinnedProvider, ps, durations.ReactivateAfter, durations.UpstreamIdleTimeout, durations.ResponseHeaderTimeout, cbCfg, routingRuntimeSettingsFromConfig(newCfg.Global.Routing), globalProxyMode, globalProxyURL, oldProxies[ClientClaude], r.telemetry) } if ps := config.GetEnabledProviders(newCfg.OpenAI); len(ps) > 0 { - newProxies[ClientOpenAI] = newReloadedClientProxy(ClientOpenAI, newCfg.OpenAI.Mode, newCfg.OpenAI.PinnedProvider, ps, durations.ReactivateAfter, durations.UpstreamIdleTimeout, durations.ResponseHeaderTimeout, cbCfg, routingRuntimeSettingsFromConfig(newCfg.Global.Routing), newCfg.Global.NormalizedUpstreamProxyMode(), newCfg.Global.NormalizedUpstreamProxyURL(), oldProxies[ClientOpenAI], r.telemetry) + newProxies[ClientOpenAI] = newReloadedClientProxy(ClientOpenAI, newCfg.OpenAI.Mode, newCfg.OpenAI.PinnedProvider, ps, durations.ReactivateAfter, durations.UpstreamIdleTimeout, durations.ResponseHeaderTimeout, cbCfg, routingRuntimeSettingsFromConfig(newCfg.Global.Routing), globalProxyMode, globalProxyURL, oldProxies[ClientOpenAI], r.telemetry) } if ps := config.GetEnabledProviders(newCfg.Gemini); len(ps) > 0 { - newProxies[ClientGemini] = newReloadedClientProxy(ClientGemini, newCfg.Gemini.Mode, newCfg.Gemini.PinnedProvider, ps, durations.ReactivateAfter, durations.UpstreamIdleTimeout, durations.ResponseHeaderTimeout, cbCfg, routingRuntimeSettingsFromConfig(newCfg.Global.Routing), newCfg.Global.NormalizedUpstreamProxyMode(), newCfg.Global.NormalizedUpstreamProxyURL(), oldProxies[ClientGemini], r.telemetry) + newProxies[ClientGemini] = newReloadedClientProxy(ClientGemini, newCfg.Gemini.Mode, newCfg.Gemini.PinnedProvider, ps, durations.ReactivateAfter, durations.UpstreamIdleTimeout, durations.ResponseHeaderTimeout, cbCfg, routingRuntimeSettingsFromConfig(newCfg.Global.Routing), globalProxyMode, globalProxyURL, oldProxies[ClientGemini], r.telemetry) } r.reconcileTelemetryUsage(oldCfg, newCfg) diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 2bf016b..e5c3667 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -297,8 +297,8 @@ func TestNewClientProxy_UsesCustomSocksProxy(t *testing.T) { if err != nil { t.Fatalf("custom proxy: %v", err) } - if customProxy == nil || customProxy.String() != "socks5://custom-proxy:1080" { - t.Fatalf("custom proxy = %v, want socks5://custom-proxy:1080", customProxy) + if customProxy == nil || customProxy.String() != "socks5://custom-proxy" { + t.Fatalf("custom proxy = %v, want socks5://custom-proxy", customProxy) } } diff --git a/internal/proxy/reload_state_test.go b/internal/proxy/reload_state_test.go index 44dfa2c..48a3b8b 100644 --- a/internal/proxy/reload_state_test.go +++ b/internal/proxy/reload_state_test.go @@ -342,6 +342,74 @@ func TestReloadProviderConfigsLocked_PreservesRuntimeStateAcrossHarmlessReload(t } } +func TestReloadProviderConfigsLocked_PreservesRuntimeStateAcrossEquivalentGlobalCustomProxyForms(t *testing.T) { + dir := t.TempDir() + global := config.DefaultGlobalConfig() + global.ListenAddr = "127.0.0.1" + global.Port = 3333 + global.UpstreamProxyMode = config.GlobalUpstreamProxyModeCustom + global.UpstreamProxyURL = "socks5://proxy.example" + + codex := config.ClientConfig{ + Mode: config.ClientModeAuto, + Providers: []config.Provider{ + {Name: "p1", BaseURL: "https://p1.example", APIKey: "k1", Priority: 1}, + }, + } + writeProxyReloadFixture(t, dir, global, codex) + + cfg, err := config.Load(dir) + if err != nil { + t.Fatalf("config.Load: %v", err) + } + if err := cfg.Validate(); err != nil { + t.Fatalf("Validate: %v", err) + } + + router := NewRouter(cfg) + oldProxy := router.proxies[ClientOpenAI] + now := time.Now() + + oldProxy.deactivated[0] = providerDeactivation{ + at: now.Add(-time.Second), + until: now.Add(30 * time.Second), + reason: "rate_limit", + status: http.StatusTooManyRequests, + message: "slow down", + } + oldProxy.keyDeactivated[0][0] = providerDeactivation{ + at: now.Add(-time.Second), + until: now.Add(20 * time.Second), + reason: "rate_limit", + status: http.StatusTooManyRequests, + message: "key cooldown", + } + oldProxy.breakers[0].state = circuitOpen + oldProxy.breakers[0].openedAt = now.Add(-5 * time.Second) + + global.UpstreamProxyURL = "SOCKS5://PROXY.EXAMPLE:1080" + writeProxyReloadFixture(t, dir, global, codex) + + if err := router.reloadProviderConfigsLocked(); err != nil { + t.Fatalf("reloadProviderConfigsLocked: %v", err) + } + + newProxy := router.proxies[ClientOpenAI] + wantPolicy := upstreamProxyPolicyKey{mode: upstreamProxyPolicyCustom, url: "socks5://proxy.example"} + if newProxy.providerProxyPolicies[0] != wantPolicy { + t.Fatalf("proxy policy = %#v, want %#v", newProxy.providerProxyPolicies[0], wantPolicy) + } + if newProxy.deactivated[0].reason != "rate_limit" || newProxy.deactivated[0].message != "slow down" { + t.Fatalf("deactivation = %#v", newProxy.deactivated[0]) + } + if newProxy.keyDeactivated[0][0].message != "key cooldown" { + t.Fatalf("key deactivation = %#v", newProxy.keyDeactivated[0][0]) + } + if newProxy.breakers[0].state != circuitOpen { + t.Fatalf("breaker state = %s, want open", newProxy.breakers[0].state) + } +} + func TestReloadProviderConfigsLocked_DoesNotPreserveSuppressionStateWhenBaseURLChanges(t *testing.T) { router, dir := newReloadTestRouter(t) oldProxy := router.proxies[ClientOpenAI]