diff --git a/docs/en/config-reference.md b/docs/en/config-reference.md index bacce62..9ac33c4 100644 --- a/docs/en/config-reference.md +++ b/docs/en/config-reference.md @@ -60,6 +60,8 @@ providers: | `reactivate_after` | duration | `1h` | Auto-reactivation delay for temporarily deactivated providers; set `0` to disable temporary deactivation for auth, billing, and quota failures | | `upstream_idle_timeout` | duration | `3m` | Abort the current upstream attempt if no response body bytes arrive for too long | | `response_header_timeout` | duration | `2m` | Timeout while waiting for upstream response headers | +| `upstream_proxy_mode` | string | `environment` | Default upstream proxy mode for providers that use `proxy_mode: default`; `environment` / `direct` / `custom` | +| `upstream_proxy_url` | string | empty | Required when `upstream_proxy_mode: custom`; supports `http://`, `https://`, `socks5://`, and `socks5h://` proxy URLs | | `max_request_body_bytes` | int | `33554432` | Request body size limit, default 32 MiB | | `log_dir` | string | `/logs` | Log directory | | `log_retention_days` | int | `7` | Log retention days; `0` keeps logs forever; default is 7 days | @@ -179,6 +181,8 @@ providers: | `base_url` | string | yes | Upstream API base URL | | `api_key` | string | one of two | Single API key | | `api_keys` | array | one of two | Multiple API keys, used in order | +| `proxy_mode` | string | no | Upstream proxy mode for this provider; `default` follows the global default | +| `proxy_url` | string | no | Required when `proxy_mode: custom`; supports `http://`, `https://`, `socks5://`, and `socks5h://` proxy URLs | | `priority` | int | no | Lower number = higher priority; omitted or `0` is treated as `1` | | `enabled` | bool | no | Defaults to `true` | | `model` | string | no | Force this provider to use a specific upstream model name for supported OpenAI and Claude requests | @@ -189,6 +193,8 @@ providers: - Use `api_key` when you only have one key - Use `api_keys` when you want retries across multiple keys within the same provider +- Use global `upstream_proxy_mode` / `upstream_proxy_url` to define the default proxy for providers that use `proxy_mode: default` +- Use provider `proxy_mode: direct` to bypass both the global default proxy and environment proxy settings - Use `model` when different upstream providers expose the same family under different model IDs - Use `reasoning_effort` and `thinking_budget_tokens` only when you want Clipal to override the client-sent defaults for that provider - For long-running background setups, this is a good default: diff --git a/docs/zh/config-reference.md b/docs/zh/config-reference.md index dff57d7..f626dfb 100644 --- a/docs/zh/config-reference.md +++ b/docs/zh/config-reference.md @@ -60,6 +60,8 @@ providers: | `reactivate_after` | duration | `1h` | provider 临时禁用后的自动恢复时间;设为 `0` 可禁用基于鉴权、计费、额度错误的临时禁用 | | `upstream_idle_timeout` | duration | `3m` | 上游响应 body 长时间无字节时中断当前尝试 | | `response_header_timeout` | duration | `2m` | 等待上游响应头的超时 | +| `upstream_proxy_mode` | string | `environment` | 作为默认值应用到 `proxy_mode: default` 的 provider;可选 `environment` / `direct` / `custom` | +| `upstream_proxy_url` | string | 空 | 当 `upstream_proxy_mode: custom` 时必填;支持 `http://`、`https://`、`socks5://` 和 `socks5h://` 代理 URL | | `max_request_body_bytes` | int | `33554432` | 请求体大小上限,默认 32 MiB | | `log_dir` | string | `/logs` | 日志目录 | | `log_retention_days` | int | `7` | 日志保留天数;`0` 表示永久保留;默认保留 7 天 | @@ -179,6 +181,8 @@ providers: | `base_url` | string | 是 | 上游 API Base URL | | `api_key` | string | 二选一 | 单个 API Key | | `api_keys` | array | 二选一 | 多个 API Key,按顺序使用 | +| `proxy_mode` | string | 否 | 该 provider 的上游代理模式;`default` 表示使用全局默认代理 | +| `proxy_url` | string | 否 | 当 `proxy_mode: custom` 时必填;支持 `http://`、`https://`、`socks5://` 和 `socks5h://` 代理 URL | | `priority` | int | 否 | 数字越小优先级越高;省略或 `0` 时按 `1` 处理 | | `enabled` | bool | 否 | 是否启用,默认 `true` | | `model` | string | 否 | 对支持的 OpenAI / Claude 请求强制改写为这个上游模型名 | @@ -189,6 +193,8 @@ providers: - 只有一个 key 时用 `api_key` - 需要同 provider 多 key 轮转时用 `api_keys` +- 需要统一默认代理时,优先配置全局 `upstream_proxy_mode` / `upstream_proxy_url`,并让 provider 使用 `proxy_mode: default` +- 需要让某个 provider 绕过全局默认代理和环境代理时,用 `proxy_mode: direct` - 不同上游对同一模型族使用不同模型 ID 时,可为该 provider 配置 `model` - 只有在你希望 Clipal 按 provider 覆盖客户端默认思考参数时,才配置 `reasoning_effort` 或 `thinking_budget_tokens` - 常驻后台运行时,建议: diff --git a/internal/config/canon_edge_test.go b/internal/config/canon_edge_test.go new file mode 100644 index 0000000..89bd554 --- /dev/null +++ b/internal/config/canon_edge_test.go @@ -0,0 +1,17 @@ +package config + +import "testing" + +func TestXXXCanonTrailingSlash(t *testing.T) { + cases := []struct{ in, want string }{ + {"http://proxy:8080/", "http://proxy:8080/"}, + {"http://proxy:8080", "http://proxy:8080"}, + {"http://user:pass@proxy:8080", "http://user:pass@proxy:8080"}, + {"http://USER:PASS@PROXY:8080", "http://USER:PASS@proxy:8080"}, + {"http://proxy:8080/path", "http://proxy:8080/path"}, + } + for _, c := range cases { + got := CanonicalProxyURL(c.in) + t.Logf("CanonicalProxyURL(%q) = %q (want %q, match=%v)", c.in, got, c.want, got == c.want) + } +} diff --git a/internal/config/config.go b/internal/config/config.go index 34f21d9..44f5448 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -92,14 +92,16 @@ type GlobalConfig struct { UpstreamIdleTimeout string `yaml:"upstream_idle_timeout"` // ResponseHeaderTimeout controls how long we wait for the upstream to return // response headers after the request is fully written. Set to "0" to disable. - ResponseHeaderTimeout string `yaml:"response_header_timeout"` - MaxRequestBody int64 `yaml:"max_request_body_bytes"` - LogDir string `yaml:"log_dir"` - LogRetentionDays int `yaml:"log_retention_days"` - LogStdout *bool `yaml:"log_stdout"` - Notifications NotificationsConfig `yaml:"notifications"` - CircuitBreaker CircuitBreakerConfig `yaml:"circuit_breaker"` - Routing RoutingConfig `yaml:"routing"` + ResponseHeaderTimeout string `yaml:"response_header_timeout"` + UpstreamProxyMode GlobalUpstreamProxyMode `yaml:"upstream_proxy_mode,omitempty"` + UpstreamProxyURL string `yaml:"upstream_proxy_url,omitempty"` + MaxRequestBody int64 `yaml:"max_request_body_bytes"` + LogDir string `yaml:"log_dir"` + LogRetentionDays int `yaml:"log_retention_days"` + LogStdout *bool `yaml:"log_stdout"` + Notifications NotificationsConfig `yaml:"notifications"` + CircuitBreaker CircuitBreakerConfig `yaml:"circuit_breaker"` + Routing RoutingConfig `yaml:"routing"` // Deprecated: retained only so older config.yaml files still load under // strict KnownFields decoding. Runtime no longer reads this field. IgnoreCountTokensFailover bool `yaml:"ignore_count_tokens_failover"` @@ -133,11 +135,39 @@ type ClaudeOverrides struct { ThinkingBudgetTokens *int `yaml:"thinking_budget_tokens,omitempty"` } +type GlobalUpstreamProxyMode string + +const ( + GlobalUpstreamProxyModeEnvironment GlobalUpstreamProxyMode = "environment" + GlobalUpstreamProxyModeDirect GlobalUpstreamProxyMode = "direct" + GlobalUpstreamProxyModeCustom GlobalUpstreamProxyMode = "custom" +) + +type ProviderProxyMode string + +const ( + ProviderProxyModeDefault ProviderProxyMode = "default" + ProviderProxyModeDirect ProviderProxyMode = "direct" + ProviderProxyModeCustom ProviderProxyMode = "custom" +) + +type UpstreamProxySettingsPatch struct { + Mode *string + URL *string +} + +type ProviderProxySettingsPatch struct { + Mode *string + URL *string +} + type providerYAML struct { Name string `yaml:"name"` BaseURL string `yaml:"base_url"` APIKey string `yaml:"api_key,omitempty"` APIKeys []string `yaml:"api_keys,omitempty"` + ProxyMode ProviderProxyMode `yaml:"proxy_mode,omitempty"` + ProxyURL string `yaml:"proxy_url,omitempty"` Priority int `yaml:"priority"` Enabled *bool `yaml:"enabled,omitempty"` Overrides *ProviderOverrides `yaml:"overrides,omitempty"` @@ -152,6 +182,8 @@ type Provider struct { BaseURL string `yaml:"base_url"` APIKey string `yaml:"api_key,omitempty"` APIKeys []string `yaml:"api_keys,omitempty"` + ProxyMode ProviderProxyMode `yaml:"proxy_mode,omitempty"` + ProxyURL string `yaml:"proxy_url,omitempty"` Priority int `yaml:"priority"` Enabled *bool `yaml:"enabled,omitempty"` Overrides *ProviderOverrides `yaml:"-"` @@ -190,19 +222,32 @@ func (p *Provider) UnmarshalYAML(value *yaml.Node) error { BaseURL: raw.BaseURL, APIKey: raw.APIKey, APIKeys: append([]string(nil), raw.APIKeys...), + ProxyMode: raw.ProxyMode, + ProxyURL: raw.ProxyURL, Priority: raw.Priority, Enabled: raw.Enabled, Overrides: NormalizeProviderOverrides(overrides), } + NormalizeProviderProxySettings(p) return nil } func (p Provider) MarshalYAML() (any, error) { + proxyMode := p.NormalizedProxyMode() + proxyURL := p.NormalizedProxyURL() + if proxyMode == ProviderProxyModeDefault { + proxyMode = "" + } + if proxyMode != ProviderProxyModeCustom { + proxyURL = "" + } return providerYAML{ Name: p.Name, BaseURL: p.BaseURL, APIKey: p.APIKey, APIKeys: append([]string(nil), p.APIKeys...), + ProxyMode: proxyMode, + ProxyURL: proxyURL, Priority: p.Priority, Enabled: p.Enabled, Overrides: NormalizeProviderOverrides(p.Overrides), @@ -263,6 +308,215 @@ func NormalizeProviderOverrides(overrides *ProviderOverrides) *ProviderOverrides return &normalized } +func NormalizeProviderProxySettings(provider *Provider) { + if provider == nil { + return + } + provider.ProxyMode = provider.NormalizedProxyMode() + provider.ProxyURL = provider.NormalizedProxyURL() +} + +func NormalizeUpstreamProxySettings(global *GlobalConfig) { + if global == nil { + return + } + global.UpstreamProxyMode = global.NormalizedUpstreamProxyMode() + global.UpstreamProxyURL = global.NormalizedUpstreamProxyURL() +} + +func NormalizeUpstreamProxySettingsPatch(patch *UpstreamProxySettingsPatch) { + if patch == nil { + return + } + patch.Mode = normalizeLowerTrimStringPtr(patch.Mode) + patch.URL = normalizeTrimStringPtr(patch.URL) +} + +func NormalizeProviderProxySettingsPatch(patch *ProviderProxySettingsPatch) { + if patch == nil { + return + } + patch.Mode = normalizeLowerTrimStringPtr(patch.Mode) + patch.URL = normalizeTrimStringPtr(patch.URL) +} + +func (g GlobalConfig) NormalizedUpstreamProxyMode() GlobalUpstreamProxyMode { + mode := strings.ToLower(strings.TrimSpace(string(g.UpstreamProxyMode))) + if mode == "" { + return GlobalUpstreamProxyModeEnvironment + } + return GlobalUpstreamProxyMode(mode) +} + +func (g GlobalConfig) NormalizedUpstreamProxyURL() string { + return strings.TrimSpace(g.UpstreamProxyURL) +} + +// CanonicalUpstreamProxyURL returns a canonicalized upstream proxy URL suitable +// for policy key construction and identity comparison. +func (g GlobalConfig) CanonicalUpstreamProxyURL() string { + return CanonicalProxyURL(g.UpstreamProxyURL) +} + +func (p Provider) NormalizedProxyMode() ProviderProxyMode { + mode := strings.ToLower(strings.TrimSpace(string(p.ProxyMode))) + if mode == "" { + return ProviderProxyModeDefault + } + return ProviderProxyMode(mode) +} + +func (p Provider) NormalizedProxyURL() string { + return strings.TrimSpace(p.ProxyURL) +} + +// CanonicalProxyURL returns a canonicalized provider proxy URL suitable +// for policy key construction and identity comparison. +func (p Provider) CanonicalProxyURL() string { + return CanonicalProxyURL(p.ProxyURL) +} + +func ApplyUpstreamProxySettings(global *GlobalConfig, patch UpstreamProxySettingsPatch) error { + if global == nil { + return nil + } + NormalizeUpstreamProxySettingsPatch(&patch) + if patch.Mode == nil { + if patch.URL != nil { + return fmt.Errorf("upstream_proxy_url requires upstream_proxy_mode") + } + return nil + } + + mode := GlobalUpstreamProxyMode(*patch.Mode) + proxyURL := global.NormalizedUpstreamProxyURL() + if patch.URL != nil { + proxyURL = *patch.URL + } + + switch mode { + case GlobalUpstreamProxyModeEnvironment, GlobalUpstreamProxyModeDirect: + if patch.URL != nil && proxyURL != "" { + return fmt.Errorf("upstream_proxy_url requires upstream_proxy_mode custom") + } + global.UpstreamProxyMode = mode + global.UpstreamProxyURL = "" + return nil + case GlobalUpstreamProxyModeCustom: + if proxyURL == "" { + return fmt.Errorf("upstream_proxy_url is required when upstream_proxy_mode=custom") + } + if err := ValidateProxyURL(proxyURL); err != nil { + return err + } + global.UpstreamProxyMode = mode + global.UpstreamProxyURL = proxyURL + return nil + default: + return fmt.Errorf("upstream_proxy_mode must be one of environment, direct, custom") + } +} + +func ApplyProviderProxySettings(provider *Provider, patch ProviderProxySettingsPatch, isUpdate bool) error { + if provider == nil { + return nil + } + NormalizeProviderProxySettingsPatch(&patch) + if !isUpdate && patch.Mode == nil && patch.URL == nil { + provider.ProxyMode = ProviderProxyModeDefault + provider.ProxyURL = "" + return nil + } + if patch.Mode == nil { + if patch.URL != nil { + return fmt.Errorf("proxy_url requires proxy_mode") + } + return nil + } + + mode := ProviderProxyMode(*patch.Mode) + proxyURL := provider.NormalizedProxyURL() + if patch.URL != nil { + proxyURL = *patch.URL + } + + switch mode { + case ProviderProxyModeDefault, ProviderProxyModeDirect: + if patch.URL != nil && proxyURL != "" { + return fmt.Errorf("proxy_url requires proxy_mode custom") + } + provider.ProxyMode = mode + provider.ProxyURL = "" + return nil + case ProviderProxyModeCustom: + if proxyURL == "" { + return fmt.Errorf("proxy_url is required when proxy_mode=custom") + } + if err := ValidateProxyURL(proxyURL); err != nil { + return err + } + provider.ProxyMode = mode + provider.ProxyURL = proxyURL + return nil + default: + return fmt.Errorf("proxy_mode must be one of default, direct, custom") + } +} + +func validateGlobalProxySettings(scope string, mode GlobalUpstreamProxyMode, rawURL string) error { + switch mode { + case GlobalUpstreamProxyModeEnvironment, GlobalUpstreamProxyModeDirect: + if rawURL != "" { + return fmt.Errorf("%s: upstream_proxy_url requires upstream_proxy_mode custom", scope) + } + case GlobalUpstreamProxyModeCustom: + if rawURL == "" { + return fmt.Errorf("%s: upstream_proxy_url is required when upstream_proxy_mode=custom", scope) + } + if err := ValidateProxyURL(rawURL); err != nil { + return fmt.Errorf("%s: %w", scope, err) + } + default: + return fmt.Errorf("%s: invalid upstream_proxy_mode %q", scope, mode) + } + return nil +} + +func validateProviderProxySettings(scope string, mode ProviderProxyMode, rawURL string) error { + switch mode { + case ProviderProxyModeDefault, ProviderProxyModeDirect: + if rawURL != "" { + return fmt.Errorf("%s: proxy_url requires proxy_mode custom", scope) + } + case ProviderProxyModeCustom: + if rawURL == "" { + return fmt.Errorf("%s: proxy_url is required when proxy_mode=custom", scope) + } + if err := ValidateProxyURL(rawURL); err != nil { + return fmt.Errorf("%s: %w", scope, err) + } + default: + return fmt.Errorf("%s: invalid proxy_mode %q", scope, mode) + } + return nil +} + +func normalizeLowerTrimStringPtr(v *string) *string { + if v == nil { + return nil + } + trimmed := strings.ToLower(strings.TrimSpace(*v)) + return &trimmed +} + +func normalizeTrimStringPtr(v *string) *string { + if v == nil { + return nil + } + trimmed := strings.TrimSpace(*v) + return &trimmed +} + // IsEnabled returns whether the provider is enabled (default true) func (p *Provider) IsEnabled() bool { if p.Enabled == nil { @@ -363,6 +617,8 @@ func DefaultGlobalConfig() GlobalConfig { ReactivateAfter: "1h", UpstreamIdleTimeout: "3m", ResponseHeaderTimeout: "2m", + UpstreamProxyMode: GlobalUpstreamProxyModeEnvironment, + UpstreamProxyURL: "", // Default body limit: 32 MiB. clipal buffers request bodies to support retries, // so a hard cap prevents unbounded memory usage. MaxRequestBody: 32 * 1024 * 1024, @@ -422,6 +678,7 @@ func Load(configDir string) (*Config, error) { if err := loadYAML(globalPath, &cfg.Global); err != nil && !os.IsNotExist(err) { return nil, fmt.Errorf("failed to load global config: %w", err) } + NormalizeUpstreamProxySettings(&cfg.Global) if err := migrateLegacyClientConfigFiles(configDir); err != nil { return nil, err @@ -602,6 +859,7 @@ func applyClientDefaults(cc *ClientConfig) { } cc.Providers[i].APIKey = strings.TrimSpace(cc.Providers[i].APIKey) cc.Providers[i].APIKeys = cc.Providers[i].NormalizedAPIKeys() + NormalizeProviderProxySettings(&cc.Providers[i]) cc.Providers[i].Overrides = NormalizeProviderOverrides(cc.Providers[i].Overrides) if len(cc.Providers[i].APIKeys) == 1 { cc.Providers[i].APIKey = cc.Providers[i].APIKeys[0] @@ -700,6 +958,9 @@ func (c *Config) Validate() error { if c.Global.LogRetentionDays < 0 { return fmt.Errorf("invalid log_retention_days: %d", c.Global.LogRetentionDays) } + if err := validateGlobalProxySettings("global upstream proxy", c.Global.NormalizedUpstreamProxyMode(), c.Global.NormalizedUpstreamProxyURL()); err != nil { + return err + } // Circuit breaker: // - failure_threshold == 0 disables the circuit breaker entirely. @@ -809,6 +1070,9 @@ func validateProviders(clientName string, providers []Provider) error { if p.Priority < 1 { return fmt.Errorf("%s provider %s: priority must be >= 1", clientName, p.Name) } + if err := validateProviderProxySettings(fmt.Sprintf("%s provider %s", clientName, p.Name), p.NormalizedProxyMode(), p.NormalizedProxyURL()); err != nil { + return err + } if !providerOverridesSupportedForClient(clientName, p.Overrides) { return fmt.Errorf("%s provider %s: unsupported overrides for client", clientName, p.Name) } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index d40d25e..1c01c2e 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -705,3 +705,259 @@ func TestGetConfigDir_RespectsEnvironmentOverride(t *testing.T) { t.Fatalf("GetConfigDir = %q, want %q", got, want) } } + +func TestLoad_ProviderProxyModeDefaultsToDefault(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + writeClientConfigFile(t, dir, "openai.yaml", ` +providers: + - name: p1 + base_url: https://example.com + api_key: key + priority: 1 +`) + + cfg, err := Load(dir) + if err != nil { + t.Fatalf("Load: %v", err) + } + if got := cfg.OpenAI.Providers[0].NormalizedProxyMode(); got != ProviderProxyModeDefault { + t.Fatalf("proxy mode = %q, want %q", got, ProviderProxyModeDefault) + } + if got := cfg.OpenAI.Providers[0].NormalizedProxyURL(); got != "" { + t.Fatalf("proxy url = %q, want empty", got) + } +} + +func TestLoad_GlobalUpstreamProxyModeDefaultsToEnvironment(t *testing.T) { + t.Parallel() + + cfg, err := Load(t.TempDir()) + if err != nil { + t.Fatalf("Load: %v", err) + } + if got := cfg.Global.NormalizedUpstreamProxyMode(); got != GlobalUpstreamProxyModeEnvironment { + t.Fatalf("upstream proxy mode = %q, want %q", got, GlobalUpstreamProxyModeEnvironment) + } +} + +func TestValidate_ProviderProxySettings(t *testing.T) { + t.Parallel() + + base := &Config{ + Global: DefaultGlobalConfig(), + Claude: ClientConfig{Mode: ClientModeAuto}, + OpenAI: ClientConfig{Mode: ClientModeAuto}, + Gemini: ClientConfig{Mode: ClientModeAuto}, + } + + makeProvider := func(mode ProviderProxyMode, proxyURL string) Provider { + return Provider{ + Name: "p1", + BaseURL: "https://example.com", + APIKey: "key", + ProxyMode: mode, + ProxyURL: proxyURL, + Priority: 1, + } + } + + t.Run("accepts direct", func(t *testing.T) { + cfg := *base + cfg.OpenAI.Providers = []Provider{makeProvider(ProviderProxyModeDirect, "")} + if err := cfg.Validate(); err != nil { + t.Fatalf("Validate: %v", err) + } + }) + + t.Run("accepts custom http proxy", func(t *testing.T) { + cfg := *base + cfg.OpenAI.Providers = []Provider{makeProvider(ProviderProxyModeCustom, "http://127.0.0.1:7890")} + if err := cfg.Validate(); err != nil { + t.Fatalf("Validate: %v", err) + } + }) + + t.Run("accepts custom socks5 proxy", func(t *testing.T) { + cfg := *base + cfg.OpenAI.Providers = []Provider{makeProvider(ProviderProxyModeCustom, "socks5://127.0.0.1:1080")} + if err := cfg.Validate(); err != nil { + t.Fatalf("Validate: %v", err) + } + }) + + t.Run("accepts custom socks5h proxy", func(t *testing.T) { + cfg := *base + cfg.OpenAI.Providers = []Provider{makeProvider(ProviderProxyModeCustom, "socks5h://127.0.0.1:1080")} + if err := cfg.Validate(); err != nil { + t.Fatalf("Validate: %v", err) + } + }) + + t.Run("rejects custom proxy without url", func(t *testing.T) { + cfg := *base + cfg.OpenAI.Providers = []Provider{makeProvider(ProviderProxyModeCustom, "")} + if err := cfg.Validate(); err == nil || !strings.Contains(err.Error(), "proxy_url is required") { + t.Fatalf("Validate err = %v", err) + } + }) + + t.Run("rejects unsupported proxy scheme", func(t *testing.T) { + cfg := *base + cfg.OpenAI.Providers = []Provider{makeProvider(ProviderProxyModeCustom, "ftp://127.0.0.1:21")} + if err := cfg.Validate(); err == nil || !strings.Contains(err.Error(), "proxy_url scheme must be http, https, socks5, or socks5h") { + t.Fatalf("Validate err = %v", err) + } + }) + + t.Run("rejects proxy url without custom mode", func(t *testing.T) { + cfg := *base + cfg.OpenAI.Providers = []Provider{makeProvider(ProviderProxyModeDirect, "http://127.0.0.1:7890")} + if err := cfg.Validate(); err == nil || !strings.Contains(err.Error(), "proxy_url requires proxy_mode custom") { + t.Fatalf("Validate err = %v", err) + } + }) + + t.Run("rejects removed inherit mode", func(t *testing.T) { + cfg := *base + cfg.OpenAI.Providers = []Provider{makeProvider(ProviderProxyMode("inherit"), "")} + if err := cfg.Validate(); err == nil || !strings.Contains(err.Error(), `invalid proxy_mode "inherit"`) { + t.Fatalf("Validate err = %v", err) + } + }) +} + +func TestValidate_GlobalUpstreamProxySettings(t *testing.T) { + t.Parallel() + + cfg := &Config{ + Global: DefaultGlobalConfig(), + Claude: ClientConfig{Mode: ClientModeAuto}, + OpenAI: ClientConfig{Mode: ClientModeAuto}, + Gemini: ClientConfig{Mode: ClientModeAuto}, + } + + cfg.Global.UpstreamProxyMode = GlobalUpstreamProxyModeCustom + cfg.Global.UpstreamProxyURL = "http://127.0.0.1:7890" + if err := cfg.Validate(); err != nil { + t.Fatalf("Validate: %v", err) + } + + cfg.Global.UpstreamProxyURL = "socks5://127.0.0.1:1080" + if err := cfg.Validate(); err != nil { + t.Fatalf("Validate: %v", err) + } + + cfg.Global.UpstreamProxyURL = "socks5h://127.0.0.1:1080" + if err := cfg.Validate(); err != nil { + t.Fatalf("Validate: %v", err) + } + + cfg.Global.UpstreamProxyURL = "ftp://127.0.0.1:21" + if err := cfg.Validate(); err == nil || !strings.Contains(err.Error(), "proxy_url scheme must be http, https, socks5, or socks5h") { + t.Fatalf("Validate err = %v", err) + } + + cfg.Global.UpstreamProxyMode = GlobalUpstreamProxyMode("inherit") + cfg.Global.UpstreamProxyURL = "" + if err := cfg.Validate(); err == nil || !strings.Contains(err.Error(), `invalid upstream_proxy_mode "inherit"`) { + t.Fatalf("Validate err = %v", err) + } +} + +func TestApplyProviderProxySettings(t *testing.T) { + t.Parallel() + + t.Run("create defaults to default mode", func(t *testing.T) { + var provider Provider + if err := ApplyProviderProxySettings(&provider, ProviderProxySettingsPatch{}, false); err != nil { + t.Fatalf("ApplyProviderProxySettings: %v", err) + } + if provider.ProxyMode != ProviderProxyModeDefault || provider.ProxyURL != "" { + t.Fatalf("provider = %#v", provider) + } + }) + + t.Run("update retains existing custom url", func(t *testing.T) { + provider := Provider{ProxyMode: ProviderProxyModeCustom, ProxyURL: "http://127.0.0.1:7890"} + mode := "custom" + if err := ApplyProviderProxySettings(&provider, ProviderProxySettingsPatch{Mode: &mode}, true); err != nil { + t.Fatalf("ApplyProviderProxySettings: %v", err) + } + if provider.ProxyMode != ProviderProxyModeCustom || provider.ProxyURL != "http://127.0.0.1:7890" { + t.Fatalf("provider = %#v", provider) + } + }) + + t.Run("update clears url on direct", func(t *testing.T) { + provider := Provider{ProxyMode: ProviderProxyModeCustom, ProxyURL: "http://127.0.0.1:7890"} + mode := "direct" + if err := ApplyProviderProxySettings(&provider, ProviderProxySettingsPatch{Mode: &mode}, true); err != nil { + t.Fatalf("ApplyProviderProxySettings: %v", err) + } + if provider.ProxyMode != ProviderProxyModeDirect || provider.ProxyURL != "" { + t.Fatalf("provider = %#v", provider) + } + }) + + t.Run("rejects default with url", func(t *testing.T) { + mode := "default" + rawURL := "http://127.0.0.1:7890" + if err := ApplyProviderProxySettings(&Provider{}, ProviderProxySettingsPatch{Mode: &mode, URL: &rawURL}, false); err == nil || !strings.Contains(err.Error(), "proxy_url requires proxy_mode custom") { + t.Fatalf("ApplyProviderProxySettings err = %v", err) + } + }) + + t.Run("rejects removed inherit mode", func(t *testing.T) { + mode := "inherit" + if err := ApplyProviderProxySettings(&Provider{}, ProviderProxySettingsPatch{Mode: &mode}, false); err == nil || !strings.Contains(err.Error(), "proxy_mode must be one of default, direct, custom") { + t.Fatalf("ApplyProviderProxySettings err = %v", err) + } + }) +} + +func TestApplyUpstreamProxySettings(t *testing.T) { + t.Parallel() + + t.Run("omitted patch preserves existing values", func(t *testing.T) { + global := GlobalConfig{ + UpstreamProxyMode: GlobalUpstreamProxyModeCustom, + UpstreamProxyURL: "http://127.0.0.1:7890", + } + if err := ApplyUpstreamProxySettings(&global, UpstreamProxySettingsPatch{}); err != nil { + t.Fatalf("ApplyUpstreamProxySettings: %v", err) + } + if global.UpstreamProxyMode != GlobalUpstreamProxyModeCustom || global.UpstreamProxyURL != "http://127.0.0.1:7890" { + t.Fatalf("global = %#v", global) + } + }) + + t.Run("switches to environment and clears url", func(t *testing.T) { + global := GlobalConfig{ + UpstreamProxyMode: GlobalUpstreamProxyModeCustom, + UpstreamProxyURL: "http://127.0.0.1:7890", + } + mode := "environment" + if err := ApplyUpstreamProxySettings(&global, UpstreamProxySettingsPatch{Mode: &mode}); err != nil { + t.Fatalf("ApplyUpstreamProxySettings: %v", err) + } + if global.UpstreamProxyMode != GlobalUpstreamProxyModeEnvironment || global.UpstreamProxyURL != "" { + t.Fatalf("global = %#v", global) + } + }) + + t.Run("rejects url without mode", func(t *testing.T) { + rawURL := "http://127.0.0.1:7890" + if err := ApplyUpstreamProxySettings(&GlobalConfig{}, UpstreamProxySettingsPatch{URL: &rawURL}); err == nil || !strings.Contains(err.Error(), "upstream_proxy_url requires upstream_proxy_mode") { + t.Fatalf("ApplyUpstreamProxySettings err = %v", err) + } + }) + + t.Run("rejects removed inherit mode", func(t *testing.T) { + mode := "inherit" + if err := ApplyUpstreamProxySettings(&GlobalConfig{}, UpstreamProxySettingsPatch{Mode: &mode}); err == nil || !strings.Contains(err.Error(), "upstream_proxy_mode must be one of environment, direct, custom") { + t.Fatalf("ApplyUpstreamProxySettings err = %v", err) + } + }) +} diff --git a/internal/config/proxy_url.go b/internal/config/proxy_url.go new file mode 100644 index 0000000..435ee5e --- /dev/null +++ b/internal/config/proxy_url.go @@ -0,0 +1,108 @@ +package config + +import ( + "fmt" + "net" + "net/url" + "strings" +) + +const ( + supportedProxyURLSchemeList = "http, https, socks5, or socks5h" + supportedProxyURLPrefixList = "http://, https://, socks5://, or socks5h://" +) + +// ParseProxyURL validates a configured proxy URL and normalizes its scheme. +func ParseProxyURL(raw string) (*url.URL, error) { + parsed, err := url.Parse(strings.TrimSpace(raw)) + if err != nil || !parsed.IsAbs() || parsed.Host == "" { + return nil, fmt.Errorf("proxy_url must be an absolute %s URL", supportedProxyURLPrefixList) + } + + parsed.Scheme = strings.ToLower(parsed.Scheme) + switch parsed.Scheme { + case "http", "https", "socks5", "socks5h": + return parsed, nil + default: + return nil, fmt.Errorf("proxy_url scheme must be %s", supportedProxyURLSchemeList) + } +} + +// CanonicalProxyURL returns a canonical form of the given proxy URL. +// It trims whitespace, parses and validates the URL via ParseProxyURL, and +// normalizes the result so that semantically equivalent URLs +// (e.g. differing only in host case or default port) produce identical strings. +// If the URL is empty or cannot be parsed, the trimmed input is returned unchanged. +func CanonicalProxyURL(raw string) string { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return "" + } + parsed, err := ParseProxyURL(trimmed) + if err != nil { + // Unparseable URLs are returned as-is. This is safe because + // ParseProxyURL would also reject them when building the HTTP client, + // so a malformed URL can never actually be used as a proxy. + return trimmed + } + parsed.Host = canonicalHost(parsed.Host, parsed.Scheme) + return parsed.String() +} + +// canonicalHost lowercases the hostname and strips the port when it matches +// the default port for the given scheme. +func canonicalHost(host, scheme string) string { + hostname, port, err := net.SplitHostPort(host) + if err != nil { + return canonicalHostWithoutPort(host) + } + hostname = canonicalHostname(hostname) + defaultPort := defaultPortForScheme(scheme) + if port == defaultPort { + if strings.Contains(hostname, ":") { + return "[" + hostname + "]" + } + return hostname + } + return net.JoinHostPort(hostname, port) +} + +// canonicalHostWithoutPort handles bare hosts (no port) from net.SplitHostPort +// failure. For bracketed IPv6 literals like "[::1]", it strips the brackets, +// lowercases the address, and re-wraps so the result stays valid. +func canonicalHostWithoutPort(host string) string { + if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { + return "[" + canonicalHostname(host[1:len(host)-1]) + "]" + } + return canonicalHostname(host) +} + +// canonicalHostname lowercases the address or DNS name while preserving the +// case of any IPv6 zone identifier, which is interface-name text. +func canonicalHostname(hostname string) string { + if zoneSep := strings.Index(hostname, "%"); zoneSep >= 0 { + return strings.ToLower(hostname[:zoneSep]) + hostname[zoneSep:] + } + return strings.ToLower(hostname) +} + +// defaultPortForScheme returns the conventional default port for a proxy scheme. +// SOCKS proxies default to 1080; HTTP/HTTPS fall back to their standard ports. +func defaultPortForScheme(scheme string) string { + switch scheme { + case "http": + return "80" + case "https": + return "443" + case "socks5", "socks5h": + return "1080" + default: + return "" + } +} + +// ValidateProxyURL reports whether a configured proxy URL is supported. +func ValidateProxyURL(raw string) error { + _, err := ParseProxyURL(raw) + return err +} diff --git a/internal/config/proxy_url_test.go b/internal/config/proxy_url_test.go new file mode 100644 index 0000000..04702e4 --- /dev/null +++ b/internal/config/proxy_url_test.go @@ -0,0 +1,102 @@ +package config + +import ( + "testing" +) + +func TestCanonicalProxyURL(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + want string + }{ + {name: "empty", input: "", want: ""}, + {name: "whitespace only", input: " ", want: ""}, + {name: "basic http", input: "http://proxy.example:8080", want: "http://proxy.example:8080"}, + {name: "trims whitespace", input: " http://proxy.example:8080 ", want: "http://proxy.example:8080"}, + {name: "lowercases host", input: "http://PROXY.EXAMPLE:8080", want: "http://proxy.example:8080"}, + {name: "mixed case host", input: "http://Proxy.Example:8080", want: "http://proxy.example:8080"}, + {name: "removes default http port", input: "http://proxy.example:80", want: "http://proxy.example"}, + {name: "removes default https port", input: "https://proxy.example:443", want: "https://proxy.example"}, + {name: "removes default socks5 port", input: "socks5://proxy.example:1080", want: "socks5://proxy.example"}, + {name: "preserves non-default port", input: "http://proxy.example:8080", want: "http://proxy.example:8080"}, + {name: "scheme already lowercase", input: "http://proxy.example:8080", want: "http://proxy.example:8080"}, + {name: "socks5h", input: "socks5h://proxy.example:1080", want: "socks5h://proxy.example"}, + {name: "socks5h non-default port", input: "socks5h://proxy.example:9050", want: "socks5h://proxy.example:9050"}, + {name: "lowercases scheme", input: "HTTP://proxy.example:8080", want: "http://proxy.example:8080"}, + {name: "lowercases scheme and host together", input: "HTTPS://PROXY.EXAMPLE:443", want: "https://proxy.example"}, + {name: "invalid scheme returns trimmed input", input: "ftp://proxy.example:21", want: "ftp://proxy.example:21"}, + {name: "equivalent URLs produce identical canonical form", input: " HTTP://PROXY.EXAMPLE:80 ", want: "http://proxy.example"}, + {name: "ip address host", input: "http://127.0.0.1:7890", want: "http://127.0.0.1:7890"}, + {name: "ipv6 with non-default port", input: "http://[::1]:8080", want: "http://[::1]:8080"}, + {name: "ipv6 strips default http port", input: "http://[::1]:80", want: "http://[::1]"}, + {name: "ipv6 strips default https port", input: "https://[::1]:443", want: "https://[::1]"}, + {name: "ipv6 without port", input: "http://[::1]", want: "http://[::1]"}, + {name: "ipv6 full address with default port", input: "http://[2001:db8::1]:80", want: "http://[2001:db8::1]"}, + {name: "ipv6 zone with non-default port preserves zone case", input: "http://[fe80::1%25En0]:8080", want: "http://[fe80::1%25En0]:8080"}, + {name: "ipv6 zone without port preserves zone case", input: "http://[fe80::1%25En0]", want: "http://[fe80::1%25En0]"}, + {name: "ipv6 zone strips default port without lowercasing zone", input: "http://[FE80::ABCD%25En0]:80", want: "http://[fe80::abcd%25En0]"}, + {name: "preserves userinfo case", input: "http://User:Pass@PROXY.EXAMPLE:8080", want: "http://User:Pass@proxy.example:8080"}, + {name: "userinfo with default port stripped", input: "http://user:pass@proxy.example:80", want: "http://user:pass@proxy.example"}, + {name: "preserves path", input: "http://proxy.example:8080/path", want: "http://proxy.example:8080/path"}, + {name: "preserves trailing slash", input: "http://proxy.example:8080/", want: "http://proxy.example:8080/"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := CanonicalProxyURL(tt.input) + if got != tt.want { + t.Errorf("CanonicalProxyURL(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestCanonicalProxyURL_Equivalence(t *testing.T) { + t.Parallel() + + pairs := []struct{ a, b string }{ + {"http://PROXY.example:8080", "http://proxy.example:8080"}, + {"http://proxy.example:80", "http://proxy.example"}, + {" http://proxy.example:8080 ", "http://proxy.example:8080"}, + {"HTTP://PROXY.EXAMPLE:80", "http://proxy.example"}, + {"http://user:pass@PROXY.EXAMPLE:80", "http://user:pass@proxy.example"}, + } + + for _, p := range pairs { + ca := CanonicalProxyURL(p.a) + cb := CanonicalProxyURL(p.b) + if ca != cb { + t.Errorf("CanonicalProxyURL(%q) = %q, CanonicalProxyURL(%q) = %q; want equal", p.a, ca, p.b, cb) + } + } +} + +func TestProvider_CanonicalProxyURL(t *testing.T) { + t.Parallel() + + p := Provider{ProxyURL: " HTTP://PROXY.EXAMPLE:80 "} + if got := p.CanonicalProxyURL(); got != "http://proxy.example" { + t.Errorf("Provider.CanonicalProxyURL() = %q, want %q", got, "http://proxy.example") + } + // Normalized accessor preserves original text + if got := p.NormalizedProxyURL(); got != "HTTP://PROXY.EXAMPLE:80" { + t.Errorf("Provider.NormalizedProxyURL() = %q, want %q", got, "HTTP://PROXY.EXAMPLE:80") + } +} + +func TestGlobalConfig_CanonicalUpstreamProxyURL(t *testing.T) { + t.Parallel() + + g := GlobalConfig{UpstreamProxyURL: " HTTP://PROXY.EXAMPLE:80 "} + if got := g.CanonicalUpstreamProxyURL(); got != "http://proxy.example" { + t.Errorf("GlobalConfig.CanonicalUpstreamProxyURL() = %q, want %q", got, "http://proxy.example") + } + // Normalized accessor preserves original text + if got := g.NormalizedUpstreamProxyURL(); got != "HTTP://PROXY.EXAMPLE:80" { + t.Errorf("GlobalConfig.NormalizedUpstreamProxyURL() = %q, want %q", got, "HTTP://PROXY.EXAMPLE:80") + } +} diff --git a/internal/proxy/failover_forward.go b/internal/proxy/failover_forward.go index a4c39cc..ced20d1 100644 --- a/internal/proxy/failover_forward.go +++ b/internal/proxy/failover_forward.go @@ -236,7 +236,7 @@ func (cp *ClientProxy) forwardWithFailover(w http.ResponseWriter, req *http.Requ } //nolint:gosec // Clipal is a user-configured reverse proxy and intentionally forwards to configured upstream base URLs. - resp, err := cp.httpClient.Do(proxyReq) + resp, err := cp.upstreamHTTPClient(index).Do(proxyReq) if err != nil { if req.Context().Err() != nil { if busyProbeHeld { @@ -565,7 +565,7 @@ func (cp *ClientProxy) forwardCountTokensSingleShot(w http.ResponseWriter, req * } //nolint:gosec // Clipal is a user-configured reverse proxy and intentionally forwards to configured upstream base URLs. - resp, err := cp.httpClient.Do(proxyReq) + resp, err := cp.upstreamHTTPClient(index).Do(proxyReq) if err != nil { if req.Context().Err() != nil { return diff --git a/internal/proxy/failover_manual.go b/internal/proxy/failover_manual.go index eb4a605..73e41ae 100644 --- a/internal/proxy/failover_manual.go +++ b/internal/proxy/failover_manual.go @@ -74,7 +74,7 @@ func (cp *ClientProxy) forwardManual(w http.ResponseWriter, req *http.Request, p } //nolint:gosec // Clipal is a user-configured reverse proxy and intentionally forwards to the pinned upstream provider. - resp, err := cp.httpClient.Do(proxyReq) + resp, err := cp.upstreamHTTPClient(index).Do(proxyReq) if err != nil { if req.Context().Err() != nil { cancelAttempt(nil) diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 528145e..abd9c8b 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -94,6 +94,19 @@ type routingRuntimeSettings struct { maxInlineWait time.Duration } +type upstreamProxyPolicyMode string + +const ( + upstreamProxyPolicyEnvironment upstreamProxyPolicyMode = "environment" + upstreamProxyPolicyDirect upstreamProxyPolicyMode = "direct" + upstreamProxyPolicyCustom upstreamProxyPolicyMode = "custom" +) + +type upstreamProxyPolicyKey struct { + mode upstreamProxyPolicyMode + url string +} + // Router manages multiple client proxies type Router struct { cfg *config.Config @@ -129,27 +142,29 @@ func (r *Router) TelemetryStore() *telemetry.Store { // ClientProxy handles requests for a specific client type type ClientProxy struct { - clientType ClientType - mode config.ClientMode - pinnedProvider string - pinnedIndex int - providers []config.Provider - providerKeys [][]string - currentIndex int - countTokensIndex int - responsesIndex int - geminiStreamIndex int - currentKeyIndex []int - countTokensKeyIndex []int - responsesKeyIndex []int - geminiStreamKeyIndex []int - mu sync.RWMutex - httpClient *http.Client - deactivated []providerDeactivation - keyDeactivated [][]providerDeactivation - providerBusy []providerBusyState - reactivateAfter time.Duration - upstreamIdle time.Duration + clientType ClientType + mode config.ClientMode + pinnedProvider string + pinnedIndex int + providers []config.Provider + providerKeys [][]string + currentIndex int + countTokensIndex int + responsesIndex int + geminiStreamIndex int + currentKeyIndex []int + countTokensKeyIndex []int + responsesKeyIndex []int + geminiStreamKeyIndex []int + mu sync.RWMutex + httpClient *http.Client + providerHTTPClients []*http.Client + providerProxyPolicies []upstreamProxyPolicyKey + deactivated []providerDeactivation + keyDeactivated [][]providerDeactivation + providerBusy []providerBusyState + reactivateAfter time.Duration + upstreamIdle time.Duration stickyBindings map[string]stickyBinding responseLookup map[string]stickyLookupEntry @@ -163,8 +178,20 @@ type ClientProxy struct { // Close releases resources held by the ClientProxy. func (cp *ClientProxy) Close() { - if cp.httpClient != nil { - cp.httpClient.CloseIdleConnections() + seen := make(map[*http.Client]struct{}, len(cp.providerHTTPClients)+1) + closeClient := func(client *http.Client) { + if client == nil { + return + } + if _, ok := seen[client]; ok { + return + } + seen[client] = struct{}{} + client.CloseIdleConnections() + } + closeClient(cp.httpClient) + for _, client := range cp.providerHTTPClients { + closeClient(client) } } @@ -193,19 +220,19 @@ func NewRouter(cfg *config.Config) *Router { // Initialize client proxies claudeProviders := config.GetEnabledProviders(cfg.Claude) if len(claudeProviders) > 0 { - r.proxies[ClientClaude] = newClientProxy(ClientClaude, cfg.Claude.Mode, cfg.Claude.PinnedProvider, claudeProviders, durations.ReactivateAfter, durations.UpstreamIdleTimeout, durations.ResponseHeaderTimeout, cbCfg, telemetryStore) + r.proxies[ClientClaude] = newClientProxyWithGlobalProxy(ClientClaude, cfg.Claude.Mode, cfg.Claude.PinnedProvider, claudeProviders, durations.ReactivateAfter, durations.UpstreamIdleTimeout, durations.ResponseHeaderTimeout, cbCfg, cfg.Global.NormalizedUpstreamProxyMode(), cfg.Global.CanonicalUpstreamProxyURL(), telemetryStore) r.proxies[ClientClaude].applyRoutingRuntimeSettings(routingCfg) } codexProviders := config.GetEnabledProviders(cfg.OpenAI) if len(codexProviders) > 0 { - r.proxies[ClientOpenAI] = newClientProxy(ClientOpenAI, cfg.OpenAI.Mode, cfg.OpenAI.PinnedProvider, codexProviders, durations.ReactivateAfter, durations.UpstreamIdleTimeout, durations.ResponseHeaderTimeout, cbCfg, telemetryStore) + r.proxies[ClientOpenAI] = newClientProxyWithGlobalProxy(ClientOpenAI, cfg.OpenAI.Mode, cfg.OpenAI.PinnedProvider, codexProviders, durations.ReactivateAfter, durations.UpstreamIdleTimeout, durations.ResponseHeaderTimeout, cbCfg, cfg.Global.NormalizedUpstreamProxyMode(), cfg.Global.CanonicalUpstreamProxyURL(), telemetryStore) r.proxies[ClientOpenAI].applyRoutingRuntimeSettings(routingCfg) } geminiProviders := config.GetEnabledProviders(cfg.Gemini) if len(geminiProviders) > 0 { - r.proxies[ClientGemini] = newClientProxy(ClientGemini, cfg.Gemini.Mode, cfg.Gemini.PinnedProvider, geminiProviders, durations.ReactivateAfter, durations.UpstreamIdleTimeout, durations.ResponseHeaderTimeout, cbCfg, telemetryStore) + r.proxies[ClientGemini] = newClientProxyWithGlobalProxy(ClientGemini, cfg.Gemini.Mode, cfg.Gemini.PinnedProvider, geminiProviders, durations.ReactivateAfter, durations.UpstreamIdleTimeout, durations.ResponseHeaderTimeout, cbCfg, cfg.Global.NormalizedUpstreamProxyMode(), cfg.Global.CanonicalUpstreamProxyURL(), telemetryStore) r.proxies[ClientGemini].applyRoutingRuntimeSettings(routingCfg) } @@ -213,6 +240,10 @@ func NewRouter(cfg *config.Config) *Router { } func newClientProxy(clientType ClientType, mode config.ClientMode, pinnedProvider string, providers []config.Provider, reactivateAfter time.Duration, upstreamIdle time.Duration, responseHeaderTimeout time.Duration, cbCfg circuitBreakerConfig, telemetryStore ...*telemetry.Store) *ClientProxy { + return newClientProxyWithGlobalProxy(clientType, mode, pinnedProvider, providers, reactivateAfter, upstreamIdle, responseHeaderTimeout, cbCfg, config.GlobalUpstreamProxyModeEnvironment, "", telemetryStore...) +} + +func newClientProxyWithGlobalProxy(clientType ClientType, mode config.ClientMode, pinnedProvider string, providers []config.Provider, reactivateAfter time.Duration, upstreamIdle time.Duration, responseHeaderTimeout time.Duration, cbCfg circuitBreakerConfig, globalProxyMode config.GlobalUpstreamProxyMode, globalProxyURL string, telemetryStore ...*telemetry.Store) *ClientProxy { var store *telemetry.Store if len(telemetryStore) > 0 { store = telemetryStore[0] @@ -221,6 +252,7 @@ func newClientProxy(clientType ClientType, mode config.ClientMode, pinnedProvide Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, } + sharedClient := newUpstreamHTTPClient(dialer, responseHeaderTimeout, http.ProxyFromEnvironment) pinnedIndex := -1 pinnedProvider = strings.TrimSpace(pinnedProvider) if mode == config.ClientModeManual && pinnedProvider != "" { @@ -238,10 +270,23 @@ func newClientProxy(clientType ClientType, mode config.ClientMode, pinnedProvide countTokensKeyIndex := make([]int, len(providers)) responsesKeyIndex := make([]int, len(providers)) geminiStreamKeyIndex := make([]int, len(providers)) + providerHTTPClients := make([]*http.Client, len(providers)) + providerProxyPolicies := make([]upstreamProxyPolicyKey, len(providers)) + policyClients := map[upstreamProxyPolicyKey]*http.Client{ + {mode: upstreamProxyPolicyEnvironment}: sharedClient, + } keyDeactivated := make([][]providerDeactivation, len(providers)) for i := range providers { breakers[i] = newCircuitBreaker(cbCfg) providerKeys[i] = providers[i].NormalizedAPIKeys() + providerProxyPolicies[i] = effectiveProviderProxyPolicy(providers[i], globalProxyMode, globalProxyURL) + if client, ok := policyClients[providerProxyPolicies[i]]; ok { + providerHTTPClients[i] = client + } else { + client = newProviderHTTPClient(providerProxyPolicies[i], providers[i].Name, sharedClient, dialer, responseHeaderTimeout) + policyClients[providerProxyPolicies[i]] = client + providerHTTPClients[i] = client + } if len(providerKeys[i]) == 0 { providerKeys[i] = []string{""} } @@ -283,6 +328,8 @@ func newClientProxy(clientType ClientType, mode config.ClientMode, pinnedProvide responsesKeyIndex: responsesKeyIndex, geminiStreamKeyIndex: geminiStreamKeyIndex, telemetry: store, + providerHTTPClients: providerHTTPClients, + providerProxyPolicies: providerProxyPolicies, deactivated: make([]providerDeactivation, len(providers)), keyDeactivated: keyDeactivated, providerBusy: make([]providerBusyState, len(providers)), @@ -293,24 +340,72 @@ func newClientProxy(clientType ClientType, mode config.ClientMode, pinnedProvide dynamicFeatureBindings: make(map[string]stickyLookupEntry), routing: defaultRoutingRuntimeSettings(), breakers: breakers, - httpClient: &http.Client{ - Transport: &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: dialer.DialContext, - ForceAttemptHTTP2: true, - MaxIdleConns: 100, - MaxIdleConnsPerHost: 10, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ResponseHeaderTimeout: responseHeaderTimeout, - ExpectContinueTimeout: 1 * time.Second, - // Keep response bytes unchanged unless the client explicitly asks for compression. - DisableCompression: true, - }, + httpClient: sharedClient, + } +} + +func newUpstreamHTTPClient(dialer *net.Dialer, responseHeaderTimeout time.Duration, proxy func(*http.Request) (*url.URL, error)) *http.Client { + return &http.Client{ + Transport: &http.Transport{ + Proxy: proxy, + DialContext: dialer.DialContext, + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ResponseHeaderTimeout: responseHeaderTimeout, + ExpectContinueTimeout: 1 * time.Second, + // Keep response bytes unchanged unless the client explicitly asks for compression. + DisableCompression: true, }, } } +func effectiveProviderProxyPolicy(provider config.Provider, globalMode config.GlobalUpstreamProxyMode, globalURL string) upstreamProxyPolicyKey { + switch provider.NormalizedProxyMode() { + case config.ProviderProxyModeDirect: + return upstreamProxyPolicyKey{mode: upstreamProxyPolicyDirect} + case config.ProviderProxyModeCustom: + return upstreamProxyPolicyKey{mode: upstreamProxyPolicyCustom, url: provider.CanonicalProxyURL()} + default: + switch globalMode { + case config.GlobalUpstreamProxyModeDirect: + return upstreamProxyPolicyKey{mode: upstreamProxyPolicyDirect} + case config.GlobalUpstreamProxyModeCustom: + return upstreamProxyPolicyKey{mode: upstreamProxyPolicyCustom, url: globalURL} + default: + return upstreamProxyPolicyKey{mode: upstreamProxyPolicyEnvironment} + } + } +} + +func newProviderHTTPClient(policy upstreamProxyPolicyKey, providerName string, sharedClient *http.Client, dialer *net.Dialer, responseHeaderTimeout time.Duration) *http.Client { + switch policy.mode { + case upstreamProxyPolicyDirect: + return newUpstreamHTTPClient(dialer, responseHeaderTimeout, nil) + case upstreamProxyPolicyCustom: + proxyURL, err := config.ParseProxyURL(policy.url) + if err != nil { + logger.Warn("invalid custom proxy for provider %s; falling back to environment proxy settings", providerName) + return sharedClient + } + return newUpstreamHTTPClient(dialer, responseHeaderTimeout, http.ProxyURL(proxyURL)) + default: + return sharedClient + } +} + +func (cp *ClientProxy) upstreamHTTPClient(providerIndex int) *http.Client { + if cp == nil { + return nil + } + if providerIndex >= 0 && providerIndex < len(cp.providerHTTPClients) && cp.providerHTTPClients[providerIndex] != nil { + return cp.providerHTTPClients[providerIndex] + } + return cp.httpClient +} + func (cp *ClientProxy) applyRoutingRuntimeSettings(settings routingRuntimeSettings) { if cp == nil { return @@ -613,16 +708,18 @@ func (r *Router) reloadProviderConfigsLocked() error { logger.Warn("invalid runtime durations; defaulting to reactivate_after=1h upstream_idle_timeout=3m response_header_timeout=2m: %v", err) } cbCfg := normalizeCircuitBreakerConfig(newCfg.Global.CircuitBreaker) + globalProxyMode := newCfg.Global.NormalizedUpstreamProxyMode() + globalProxyURL := newCfg.Global.CanonicalUpstreamProxyURL() newProxies := make(map[ClientType]*ClientProxy) if ps := config.GetEnabledProviders(newCfg.Claude); len(ps) > 0 { - newProxies[ClientClaude] = newReloadedClientProxy(ClientClaude, newCfg.Claude.Mode, newCfg.Claude.PinnedProvider, ps, durations.ReactivateAfter, durations.UpstreamIdleTimeout, durations.ResponseHeaderTimeout, cbCfg, routingRuntimeSettingsFromConfig(newCfg.Global.Routing), oldProxies[ClientClaude], r.telemetry) + newProxies[ClientClaude] = newReloadedClientProxy(ClientClaude, newCfg.Claude.Mode, newCfg.Claude.PinnedProvider, ps, durations.ReactivateAfter, durations.UpstreamIdleTimeout, durations.ResponseHeaderTimeout, cbCfg, routingRuntimeSettingsFromConfig(newCfg.Global.Routing), globalProxyMode, globalProxyURL, oldProxies[ClientClaude], r.telemetry) } if ps := config.GetEnabledProviders(newCfg.OpenAI); len(ps) > 0 { - newProxies[ClientOpenAI] = newReloadedClientProxy(ClientOpenAI, newCfg.OpenAI.Mode, newCfg.OpenAI.PinnedProvider, ps, durations.ReactivateAfter, durations.UpstreamIdleTimeout, durations.ResponseHeaderTimeout, cbCfg, routingRuntimeSettingsFromConfig(newCfg.Global.Routing), oldProxies[ClientOpenAI], r.telemetry) + newProxies[ClientOpenAI] = newReloadedClientProxy(ClientOpenAI, newCfg.OpenAI.Mode, newCfg.OpenAI.PinnedProvider, ps, durations.ReactivateAfter, durations.UpstreamIdleTimeout, durations.ResponseHeaderTimeout, cbCfg, routingRuntimeSettingsFromConfig(newCfg.Global.Routing), globalProxyMode, globalProxyURL, oldProxies[ClientOpenAI], r.telemetry) } if ps := config.GetEnabledProviders(newCfg.Gemini); len(ps) > 0 { - newProxies[ClientGemini] = newReloadedClientProxy(ClientGemini, newCfg.Gemini.Mode, newCfg.Gemini.PinnedProvider, ps, durations.ReactivateAfter, durations.UpstreamIdleTimeout, durations.ResponseHeaderTimeout, cbCfg, routingRuntimeSettingsFromConfig(newCfg.Global.Routing), oldProxies[ClientGemini], r.telemetry) + newProxies[ClientGemini] = newReloadedClientProxy(ClientGemini, newCfg.Gemini.Mode, newCfg.Gemini.PinnedProvider, ps, durations.ReactivateAfter, durations.UpstreamIdleTimeout, durations.ResponseHeaderTimeout, cbCfg, routingRuntimeSettingsFromConfig(newCfg.Global.Routing), globalProxyMode, globalProxyURL, oldProxies[ClientGemini], r.telemetry) } r.reconcileTelemetryUsage(oldCfg, newCfg) @@ -644,8 +741,8 @@ func (r *Router) reloadProviderConfigsLocked() error { return nil } -func newReloadedClientProxy(clientType ClientType, mode config.ClientMode, pinnedProvider string, providers []config.Provider, reactivateAfter time.Duration, upstreamIdle time.Duration, responseHeaderTimeout time.Duration, cbCfg circuitBreakerConfig, routing routingRuntimeSettings, old *ClientProxy, telemetryStore *telemetry.Store) *ClientProxy { - cp := newClientProxy(clientType, mode, pinnedProvider, providers, reactivateAfter, upstreamIdle, responseHeaderTimeout, cbCfg, telemetryStore) +func newReloadedClientProxy(clientType ClientType, mode config.ClientMode, pinnedProvider string, providers []config.Provider, reactivateAfter time.Duration, upstreamIdle time.Duration, responseHeaderTimeout time.Duration, cbCfg circuitBreakerConfig, routing routingRuntimeSettings, globalProxyMode config.GlobalUpstreamProxyMode, globalProxyURL string, old *ClientProxy, telemetryStore *telemetry.Store) *ClientProxy { + cp := newClientProxyWithGlobalProxy(clientType, mode, pinnedProvider, providers, reactivateAfter, upstreamIdle, responseHeaderTimeout, cbCfg, globalProxyMode, globalProxyURL, telemetryStore) cp.applyRoutingRuntimeSettings(routing) if old != nil { cp.inheritRuntimeState(old) @@ -676,7 +773,7 @@ func (cp *ClientProxy) inheritRuntimeState(old *ClientProxy) { if !ok { continue } - if !sameProviderRuntimeIdentity(cp.providers[newIdx], old.providers[oldIdx]) { + if !sameProviderRuntimeIdentity(cp.providers[newIdx], cp.providerProxyPolicies[newIdx], old.providers[oldIdx], old.providerProxyPolicies[oldIdx]) { continue } cp.deactivated[newIdx] = old.deactivated[oldIdx] @@ -691,7 +788,7 @@ func (cp *ClientProxy) inheritRuntimeState(old *ClientProxy) { if !ok { continue } - if !sameProviderRuntimeIdentity(cp.providers[newIdx], old.providers[oldIdx]) { + if !sameProviderRuntimeIdentity(cp.providers[newIdx], cp.providerProxyPolicies[newIdx], old.providers[oldIdx], old.providerProxyPolicies[oldIdx]) { continue } newByOldIndex[oldIdx] = newIdx @@ -830,8 +927,10 @@ func inheritStickyRuntimeState(dst *ClientProxy, src *ClientProxy, indexMap map[ } } -func sameProviderRuntimeIdentity(a, b config.Provider) bool { - return a.Name == b.Name && strings.TrimSpace(a.BaseURL) == strings.TrimSpace(b.BaseURL) +func sameProviderRuntimeIdentity(a config.Provider, aPolicy upstreamProxyPolicyKey, b config.Provider, bPolicy upstreamProxyPolicyKey) bool { + return a.Name == b.Name && + strings.TrimSpace(a.BaseURL) == strings.TrimSpace(b.BaseURL) && + aPolicy == bPolicy } func providerIndexByName(providers []config.Provider, name string) int { diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index d6042a0..e5c3667 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -150,6 +150,158 @@ func TestBuildTargetURL(t *testing.T) { } } +func TestNewClientProxy_UsesProviderSpecificProxyModes(t *testing.T) { + t.Setenv("HTTP_PROXY", "http://env-proxy:8080") + + cp := newClientProxy(ClientOpenAI, config.ClientModeAuto, "", []config.Provider{ + {Name: "default", BaseURL: "http://default.example", APIKey: "k1", Priority: 1}, + {Name: "direct", BaseURL: "http://direct.example", APIKey: "k2", ProxyMode: config.ProviderProxyModeDirect, Priority: 2}, + {Name: "custom", BaseURL: "http://custom.example", APIKey: "k3", ProxyMode: config.ProviderProxyModeCustom, ProxyURL: "http://custom-proxy:9090", Priority: 3}, + }, time.Hour, 0, testResponseHeaderTimeout, circuitBreakerConfig{}) + + req, err := http.NewRequest(http.MethodGet, "http://upstream.example/v1/test", nil) + if err != nil { + t.Fatalf("http.NewRequest: %v", err) + } + + defaultTransport, ok := cp.upstreamHTTPClient(0).Transport.(*http.Transport) + if !ok { + t.Fatalf("default transport type = %T", cp.upstreamHTTPClient(0).Transport) + } + defaultProxy, err := defaultTransport.Proxy(req) + if err != nil { + t.Fatalf("default proxy: %v", err) + } + if defaultProxy == nil || defaultProxy.String() != "http://env-proxy:8080" { + t.Fatalf("default proxy = %v, want http://env-proxy:8080", defaultProxy) + } + + directTransport, ok := cp.upstreamHTTPClient(1).Transport.(*http.Transport) + if !ok { + t.Fatalf("direct transport type = %T", cp.upstreamHTTPClient(1).Transport) + } + if directTransport.Proxy != nil { + directProxy, err := directTransport.Proxy(req) + if err != nil { + t.Fatalf("direct proxy: %v", err) + } + if directProxy != nil { + t.Fatalf("direct proxy = %v, want nil", directProxy) + } + } + + customTransport, ok := cp.upstreamHTTPClient(2).Transport.(*http.Transport) + if !ok { + t.Fatalf("custom transport type = %T", cp.upstreamHTTPClient(2).Transport) + } + customProxy, err := customTransport.Proxy(req) + if err != nil { + t.Fatalf("custom proxy: %v", err) + } + if customProxy == nil || customProxy.String() != "http://custom-proxy:9090" { + t.Fatalf("custom proxy = %v, want http://custom-proxy:9090", customProxy) + } +} + +func TestNewClientProxyWithGlobalProxy_UsesGlobalDefaultForDefaultProviders(t *testing.T) { + t.Parallel() + + req, err := http.NewRequest(http.MethodGet, "http://upstream.example/v1/test", nil) + if err != nil { + t.Fatalf("http.NewRequest: %v", err) + } + + directCP := newClientProxyWithGlobalProxy(ClientOpenAI, config.ClientModeAuto, "", []config.Provider{ + {Name: "default", BaseURL: "http://default.example", APIKey: "k1", Priority: 1}, + }, time.Hour, 0, testResponseHeaderTimeout, circuitBreakerConfig{}, config.GlobalUpstreamProxyModeDirect, "") + + directTransport, ok := directCP.upstreamHTTPClient(0).Transport.(*http.Transport) + if !ok { + t.Fatalf("direct transport type = %T", directCP.upstreamHTTPClient(0).Transport) + } + if directTransport.Proxy != nil { + directProxy, err := directTransport.Proxy(req) + if err != nil { + t.Fatalf("direct proxy: %v", err) + } + if directProxy != nil { + t.Fatalf("direct proxy = %v, want nil", directProxy) + } + } + + customCP := newClientProxyWithGlobalProxy(ClientOpenAI, config.ClientModeAuto, "", []config.Provider{ + {Name: "default", BaseURL: "http://default.example", APIKey: "k1", Priority: 1}, + }, time.Hour, 0, testResponseHeaderTimeout, circuitBreakerConfig{}, config.GlobalUpstreamProxyModeCustom, "http://global-proxy:8081") + + customTransport, ok := customCP.upstreamHTTPClient(0).Transport.(*http.Transport) + if !ok { + t.Fatalf("custom transport type = %T", customCP.upstreamHTTPClient(0).Transport) + } + customProxy, err := customTransport.Proxy(req) + if err != nil { + t.Fatalf("custom proxy: %v", err) + } + if customProxy == nil || customProxy.String() != "http://global-proxy:8081" { + t.Fatalf("custom proxy = %v, want http://global-proxy:8081", customProxy) + } +} + +func TestNewClientProxyWithGlobalProxy_SharesHTTPClientsByEffectivePolicy(t *testing.T) { + t.Parallel() + + cp := newClientProxyWithGlobalProxy(ClientOpenAI, config.ClientModeAuto, "", []config.Provider{ + {Name: "default-a", BaseURL: "http://default-a.example", APIKey: "k1", Priority: 1}, + {Name: "default-b", BaseURL: "http://default-b.example", APIKey: "k2", Priority: 2}, + {Name: "direct-a", BaseURL: "http://direct-a.example", APIKey: "k3", ProxyMode: config.ProviderProxyModeDirect, Priority: 3}, + {Name: "direct-b", BaseURL: "http://direct-b.example", APIKey: "k4", ProxyMode: config.ProviderProxyModeDirect, Priority: 4}, + {Name: "custom-a", BaseURL: "http://custom-a.example", APIKey: "k5", ProxyMode: config.ProviderProxyModeCustom, ProxyURL: "http://custom-proxy:9090", Priority: 5}, + {Name: "custom-b", BaseURL: "http://custom-b.example", APIKey: "k6", ProxyMode: config.ProviderProxyModeCustom, ProxyURL: "http://custom-proxy:9090", Priority: 6}, + {Name: "custom-c", BaseURL: "http://custom-c.example", APIKey: "k7", ProxyMode: config.ProviderProxyModeCustom, ProxyURL: "http://other-proxy:9090", Priority: 7}, + }, time.Hour, 0, testResponseHeaderTimeout, circuitBreakerConfig{}, config.GlobalUpstreamProxyModeEnvironment, "") + + if cp.upstreamHTTPClient(0) != cp.upstreamHTTPClient(1) { + t.Fatalf("default providers should share the environment client") + } + if cp.upstreamHTTPClient(2) != cp.upstreamHTTPClient(3) { + t.Fatalf("direct providers should share the direct client") + } + if cp.upstreamHTTPClient(4) != cp.upstreamHTTPClient(5) { + t.Fatalf("matching custom proxy URLs should share the same client") + } + if cp.upstreamHTTPClient(4) == cp.upstreamHTTPClient(6) { + t.Fatalf("different custom proxy URLs should not share the same client") + } + if cp.upstreamHTTPClient(0) == cp.upstreamHTTPClient(2) { + t.Fatalf("environment and direct policies should not share the same client") + } +} + +func TestNewClientProxy_UsesCustomSocksProxy(t *testing.T) { + t.Parallel() + + cp := newClientProxy(ClientOpenAI, config.ClientModeAuto, "", []config.Provider{ + {Name: "custom-socks", BaseURL: "http://custom.example", APIKey: "k1", ProxyMode: config.ProviderProxyModeCustom, ProxyURL: "socks5://custom-proxy:1080", Priority: 1}, + }, time.Hour, 0, testResponseHeaderTimeout, circuitBreakerConfig{}) + + req, err := http.NewRequest(http.MethodGet, "https://upstream.example/v1/test", nil) + if err != nil { + t.Fatalf("http.NewRequest: %v", err) + } + + transport, ok := cp.upstreamHTTPClient(0).Transport.(*http.Transport) + if !ok { + t.Fatalf("transport type = %T", cp.upstreamHTTPClient(0).Transport) + } + + customProxy, err := transport.Proxy(req) + if err != nil { + t.Fatalf("custom proxy: %v", err) + } + if customProxy == nil || customProxy.String() != "socks5://custom-proxy" { + t.Fatalf("custom proxy = %v, want socks5://custom-proxy", customProxy) + } +} + func TestAddForwardedHeaders(t *testing.T) { t.Parallel() diff --git a/internal/proxy/reload_state_test.go b/internal/proxy/reload_state_test.go index b9773ab..48a3b8b 100644 --- a/internal/proxy/reload_state_test.go +++ b/internal/proxy/reload_state_test.go @@ -342,6 +342,74 @@ func TestReloadProviderConfigsLocked_PreservesRuntimeStateAcrossHarmlessReload(t } } +func TestReloadProviderConfigsLocked_PreservesRuntimeStateAcrossEquivalentGlobalCustomProxyForms(t *testing.T) { + dir := t.TempDir() + global := config.DefaultGlobalConfig() + global.ListenAddr = "127.0.0.1" + global.Port = 3333 + global.UpstreamProxyMode = config.GlobalUpstreamProxyModeCustom + global.UpstreamProxyURL = "socks5://proxy.example" + + codex := config.ClientConfig{ + Mode: config.ClientModeAuto, + Providers: []config.Provider{ + {Name: "p1", BaseURL: "https://p1.example", APIKey: "k1", Priority: 1}, + }, + } + writeProxyReloadFixture(t, dir, global, codex) + + cfg, err := config.Load(dir) + if err != nil { + t.Fatalf("config.Load: %v", err) + } + if err := cfg.Validate(); err != nil { + t.Fatalf("Validate: %v", err) + } + + router := NewRouter(cfg) + oldProxy := router.proxies[ClientOpenAI] + now := time.Now() + + oldProxy.deactivated[0] = providerDeactivation{ + at: now.Add(-time.Second), + until: now.Add(30 * time.Second), + reason: "rate_limit", + status: http.StatusTooManyRequests, + message: "slow down", + } + oldProxy.keyDeactivated[0][0] = providerDeactivation{ + at: now.Add(-time.Second), + until: now.Add(20 * time.Second), + reason: "rate_limit", + status: http.StatusTooManyRequests, + message: "key cooldown", + } + oldProxy.breakers[0].state = circuitOpen + oldProxy.breakers[0].openedAt = now.Add(-5 * time.Second) + + global.UpstreamProxyURL = "SOCKS5://PROXY.EXAMPLE:1080" + writeProxyReloadFixture(t, dir, global, codex) + + if err := router.reloadProviderConfigsLocked(); err != nil { + t.Fatalf("reloadProviderConfigsLocked: %v", err) + } + + newProxy := router.proxies[ClientOpenAI] + wantPolicy := upstreamProxyPolicyKey{mode: upstreamProxyPolicyCustom, url: "socks5://proxy.example"} + if newProxy.providerProxyPolicies[0] != wantPolicy { + t.Fatalf("proxy policy = %#v, want %#v", newProxy.providerProxyPolicies[0], wantPolicy) + } + if newProxy.deactivated[0].reason != "rate_limit" || newProxy.deactivated[0].message != "slow down" { + t.Fatalf("deactivation = %#v", newProxy.deactivated[0]) + } + if newProxy.keyDeactivated[0][0].message != "key cooldown" { + t.Fatalf("key deactivation = %#v", newProxy.keyDeactivated[0][0]) + } + if newProxy.breakers[0].state != circuitOpen { + t.Fatalf("breaker state = %s, want open", newProxy.breakers[0].state) + } +} + func TestReloadProviderConfigsLocked_DoesNotPreserveSuppressionStateWhenBaseURLChanges(t *testing.T) { router, dir := newReloadTestRouter(t) oldProxy := router.proxies[ClientOpenAI] @@ -475,6 +543,109 @@ func TestReloadProviderConfigsLocked_ReconcilesTelemetryFromYAMLChanges(t *testi } } +func TestReloadProviderConfigsLocked_DoesNotPreserveSuppressionStateWhenProxyChanges(t *testing.T) { + router, dir := newReloadTestRouter(t) + oldProxy := router.proxies[ClientOpenAI] + now := time.Now() + + oldProxy.deactivated[0] = providerDeactivation{ + at: now.Add(-time.Second), + until: now.Add(30 * time.Second), + reason: "rate_limit", + status: http.StatusTooManyRequests, + message: "slow down", + } + oldProxy.keyDeactivated[0][0] = providerDeactivation{ + at: now.Add(-time.Second), + until: now.Add(20 * time.Second), + reason: "rate_limit", + status: http.StatusTooManyRequests, + message: "key cooldown", + } + oldProxy.breakers[0].state = circuitOpen + oldProxy.breakers[0].openedAt = now.Add(-5 * time.Second) + + global := config.DefaultGlobalConfig() + global.ListenAddr = "127.0.0.1" + global.Port = 3333 + writeProxyReloadFixture(t, dir, global, config.ClientConfig{ + Mode: config.ClientModeAuto, + Providers: []config.Provider{ + { + Name: "p1", + BaseURL: "https://p1.example", + APIKey: "k1", + ProxyMode: config.ProviderProxyModeDirect, + Priority: 1, + }, + }, + }) + + if err := router.reloadProviderConfigsLocked(); err != nil { + t.Fatalf("reloadProviderConfigsLocked: %v", err) + } + + newProxy := router.proxies[ClientOpenAI] + if !newProxy.deactivated[0].until.IsZero() || newProxy.deactivated[0].reason != "" { + t.Fatalf("provider cooldown should not carry across proxy change: %#v", newProxy.deactivated[0]) + } + if !newProxy.keyDeactivated[0][0].until.IsZero() || newProxy.keyDeactivated[0][0].reason != "" { + t.Fatalf("key cooldown should not carry across proxy change: %#v", newProxy.keyDeactivated[0][0]) + } + if newProxy.breakers[0].state != circuitClosed { + t.Fatalf("breaker state = %s, want closed", newProxy.breakers[0].state) + } +} + +func TestReloadProviderConfigsLocked_DoesNotPreserveSuppressionStateWhenGlobalProxyChangesForDefaultProvider(t *testing.T) { + router, dir := newReloadTestRouter(t) + oldProxy := router.proxies[ClientOpenAI] + now := time.Now() + + oldProxy.deactivated[0] = providerDeactivation{ + at: now.Add(-time.Second), + until: now.Add(30 * time.Second), + reason: "rate_limit", + status: http.StatusTooManyRequests, + message: "slow down", + } + oldProxy.keyDeactivated[0][0] = providerDeactivation{ + at: now.Add(-time.Second), + until: now.Add(20 * time.Second), + reason: "rate_limit", + status: http.StatusTooManyRequests, + message: "key cooldown", + } + oldProxy.breakers[0].state = circuitOpen + oldProxy.breakers[0].openedAt = now.Add(-5 * time.Second) + + global := config.DefaultGlobalConfig() + global.ListenAddr = "127.0.0.1" + global.Port = 3333 + global.UpstreamProxyMode = config.GlobalUpstreamProxyModeDirect + writeProxyReloadFixture(t, dir, global, config.ClientConfig{ + Mode: config.ClientModeAuto, + Providers: []config.Provider{ + {Name: "p1", BaseURL: "https://p1.example", APIKey: "k1", Priority: 1}, + }, + }) + + if err := router.reloadProviderConfigsLocked(); err != nil { + t.Fatalf("reloadProviderConfigsLocked: %v", err) + } + + newProxy := router.proxies[ClientOpenAI] + if !newProxy.deactivated[0].until.IsZero() || newProxy.deactivated[0].reason != "" { + t.Fatalf("provider cooldown should not carry across global proxy change: %#v", newProxy.deactivated[0]) + } + if !newProxy.keyDeactivated[0][0].until.IsZero() || newProxy.keyDeactivated[0][0].reason != "" { + t.Fatalf("key cooldown should not carry across global proxy change: %#v", newProxy.keyDeactivated[0][0]) + } + if newProxy.breakers[0].state != circuitClosed { + t.Fatalf("breaker state = %s, want closed", newProxy.breakers[0].state) + } +} + func TestTimeUntilNextAvailable_PicksEarliestBlockedSource(t *testing.T) { cbCfg := circuitBreakerConfig{ enabled: true, diff --git a/internal/web/api.go b/internal/web/api.go index 6f00c3d..eb338d3 100644 --- a/internal/web/api.go +++ b/internal/web/api.go @@ -111,6 +111,13 @@ func (a *API) HandleUpdateGlobalConfig(w http.ResponseWriter, r *http.Request) { if strings.TrimSpace(req.ResponseHeaderTimeout) != "" { cfg.Global.ResponseHeaderTimeout = req.ResponseHeaderTimeout } + if err := config.ApplyUpstreamProxySettings(&cfg.Global, config.UpstreamProxySettingsPatch{ + Mode: req.UpstreamProxyMode, + URL: req.UpstreamProxyURL, + }); err != nil { + writeError(w, err.Error(), http.StatusBadRequest) + return + } cfg.Global.MaxRequestBody = req.MaxRequestBodyBytes cfg.Global.LogDir = req.LogDir cfg.Global.LogRetentionDays = req.LogRetentionDays @@ -332,6 +339,13 @@ func (a *API) HandleAddProvider(w http.ResponseWriter, r *http.Request) { Enabled: req.Enabled, } applyProviderOverrides(&provider, req) + if err := config.ApplyProviderProxySettings(&provider, config.ProviderProxySettingsPatch{ + Mode: req.ProxyMode, + URL: req.ProxyURL, + }, false); err != nil { + writeError(w, err.Error(), http.StatusBadRequest) + return + } assignProviderKeys(&provider, keys) cc.Providers = append(cc.Providers, provider) @@ -416,7 +430,11 @@ func (a *API) HandleUpdateProvider(w http.ResponseWriter, r *http.Request) { } } - updated := updateProviderInList(cc.Providers, providerName, req, keys) + updated, err := updateProviderInList(cc.Providers, providerName, req, keys) + if err != nil { + writeError(w, err.Error(), http.StatusBadRequest) + return + } if updated { if !a.saveClientConfigOrWriteError(w, clientType, cfg) { return @@ -1107,7 +1125,7 @@ func extractClientAndProvider(path string) (string, string) { return "", "" } -func updateProviderInList(providers []config.Provider, name string, req ProviderRequest, keys []string) bool { +func updateProviderInList(providers []config.Provider, name string, req ProviderRequest, keys []string) (bool, error) { for i := range providers { if providers[i].Name == name { if req.Name != "" { @@ -1126,10 +1144,16 @@ func updateProviderInList(providers []config.Provider, name string, req Provider providers[i].Enabled = req.Enabled } applyProviderOverrides(&providers[i], req) - return true + if err := config.ApplyProviderProxySettings(&providers[i], config.ProviderProxySettingsPatch{ + Mode: req.ProxyMode, + URL: req.ProxyURL, + }, true); err != nil { + return false, err + } + return true, nil } } - return false + return false, nil } func deleteProviderFromList(providers []config.Provider, name string) ([]config.Provider, bool) { diff --git a/internal/web/api_test.go b/internal/web/api_test.go index 9448b0b..af3e6cc 100644 --- a/internal/web/api_test.go +++ b/internal/web/api_test.go @@ -77,6 +77,8 @@ providers: - name: p1 base_url: https://example.com api_key: secret + proxy_mode: custom + proxy_url: http://proxy.internal:7890 model: gpt-5.4 reasoning_effort: high priority: 2 @@ -117,6 +119,9 @@ providers: if _, ok := got[0]["base_url"]; !ok { t.Fatalf("expected base_url in provider listing, got keys=%v", keys(got[0])) } + if _, ok := got[0]["proxy_url"]; ok { + t.Fatalf("did not expect proxy_url in provider listing") + } if got[0]["key_count"] != float64(1) { t.Fatalf("expected key_count=1, got %v", got[0]["key_count"]) } @@ -130,6 +135,12 @@ providers: if first == nil { t.Fatalf("expected provider p1 in listing, got %#v", got) } + if first["proxy_mode"] != "custom" { + t.Fatalf("expected proxy_mode=custom, got %v", first["proxy_mode"]) + } + if first["proxy_url_hint"] != "http://proxy.internal:7890" { + t.Fatalf("expected proxy_url_hint redaction, got %v", first["proxy_url_hint"]) + } overrides, ok := first["overrides"].(map[string]any) if !ok { t.Fatalf("expected overrides object in listing, got %T", first["overrides"]) @@ -324,15 +335,20 @@ providers: } func TestHandleUpdateGlobalConfig_AcceptsSnakeCaseNotifications(t *testing.T) { - dir := t.TempDir() - api := NewAPI(dir, "test", nil) + for _, proxyURL := range []string{"http://127.0.0.1:7890", "socks5://127.0.0.1:1080"} { + t.Run(proxyURL, func(t *testing.T) { + dir := t.TempDir() + api := NewAPI(dir, "test", nil) - body := []byte(`{ + body := []byte(`{ "listen_addr": "127.0.0.1", "port": 3333, "log_level": "info", "reactivate_after": "10m", "upstream_idle_timeout": "1m", + "response_header_timeout": "30s", + "upstream_proxy_mode": "custom", + "upstream_proxy_url": "` + proxyURL + `", "max_request_body_bytes": 12345, "log_dir": "", "log_retention_days": 7, @@ -361,41 +377,49 @@ func TestHandleUpdateGlobalConfig_AcceptsSnakeCaseNotifications(t *testing.T) { } }`) - req := httptest.NewRequest(http.MethodPut, "/api/config/global/update", bytes.NewReader(body)) - w := httptest.NewRecorder() - api.HandleUpdateGlobalConfig(w, req) - res := w.Result() - if res.StatusCode != http.StatusOK { - t.Fatalf("status=%d body=%s", res.StatusCode, w.Body.String()) - } + req := httptest.NewRequest(http.MethodPut, "/api/config/global/update", bytes.NewReader(body)) + w := httptest.NewRecorder() + api.HandleUpdateGlobalConfig(w, req) + res := w.Result() + if res.StatusCode != http.StatusOK { + t.Fatalf("status=%d body=%s", res.StatusCode, w.Body.String()) + } - cfg, err := config.Load(dir) - if err != nil { - t.Fatalf("reload config: %v", err) - } - if !cfg.Global.Notifications.Enabled { - t.Fatalf("expected notifications.enabled=true") - } - if cfg.Global.Notifications.MinLevel != config.LogLevelWarn { - t.Fatalf("expected notifications.min_level=warn, got %q", cfg.Global.Notifications.MinLevel) - } - if cfg.Global.Notifications.ProviderSwitch == nil || *cfg.Global.Notifications.ProviderSwitch { - t.Fatalf("expected notifications.provider_switch=false, got %v", cfg.Global.Notifications.ProviderSwitch) - } - if !cfg.Global.Routing.StickySessions.Enabled { - t.Fatalf("expected routing.sticky_sessions.enabled=true") - } - if cfg.Global.Routing.StickySessions.ExplicitTTL != "45m" { - t.Fatalf("expected routing.sticky_sessions.explicit_ttl=45m, got %q", cfg.Global.Routing.StickySessions.ExplicitTTL) - } - if !cfg.Global.Routing.BusyBackpressure.Enabled { - t.Fatalf("expected routing.busy_backpressure.enabled=true") - } - if cfg.Global.Routing.BusyBackpressure.ShortRetryAfterMax != "5s" { - t.Fatalf("expected routing.busy_backpressure.short_retry_after_max=5s, got %q", cfg.Global.Routing.BusyBackpressure.ShortRetryAfterMax) - } - if cfg.Global.Routing.BusyBackpressure.MaxInlineWait != "12s" { - t.Fatalf("expected routing.busy_backpressure.max_inline_wait=12s, got %q", cfg.Global.Routing.BusyBackpressure.MaxInlineWait) + cfg, err := config.Load(dir) + if err != nil { + t.Fatalf("reload config: %v", err) + } + if !cfg.Global.Notifications.Enabled { + t.Fatalf("expected notifications.enabled=true") + } + if cfg.Global.Notifications.MinLevel != config.LogLevelWarn { + t.Fatalf("expected notifications.min_level=warn, got %q", cfg.Global.Notifications.MinLevel) + } + if cfg.Global.Notifications.ProviderSwitch == nil || *cfg.Global.Notifications.ProviderSwitch { + t.Fatalf("expected notifications.provider_switch=false, got %v", cfg.Global.Notifications.ProviderSwitch) + } + if !cfg.Global.Routing.StickySessions.Enabled { + t.Fatalf("expected routing.sticky_sessions.enabled=true") + } + if cfg.Global.Routing.StickySessions.ExplicitTTL != "45m" { + t.Fatalf("expected routing.sticky_sessions.explicit_ttl=45m, got %q", cfg.Global.Routing.StickySessions.ExplicitTTL) + } + if !cfg.Global.Routing.BusyBackpressure.Enabled { + t.Fatalf("expected routing.busy_backpressure.enabled=true") + } + if cfg.Global.Routing.BusyBackpressure.ShortRetryAfterMax != "5s" { + t.Fatalf("expected routing.busy_backpressure.short_retry_after_max=5s, got %q", cfg.Global.Routing.BusyBackpressure.ShortRetryAfterMax) + } + if cfg.Global.Routing.BusyBackpressure.MaxInlineWait != "12s" { + t.Fatalf("expected routing.busy_backpressure.max_inline_wait=12s, got %q", cfg.Global.Routing.BusyBackpressure.MaxInlineWait) + } + if cfg.Global.NormalizedUpstreamProxyMode() != config.GlobalUpstreamProxyModeCustom { + t.Fatalf("expected upstream_proxy_mode=custom, got %q", cfg.Global.NormalizedUpstreamProxyMode()) + } + if cfg.Global.NormalizedUpstreamProxyURL() != proxyURL { + t.Fatalf("expected upstream_proxy_url to be saved, got %q", cfg.Global.NormalizedUpstreamProxyURL()) + } + }) } } @@ -418,6 +442,8 @@ func TestHandleUpdateGlobalConfig_AllowsClearingRoutingStrings(t *testing.T) { "reactivate_after": "10m", "upstream_idle_timeout": "1m", "response_header_timeout": "30s", + "upstream_proxy_mode": "direct", + "upstream_proxy_url": "", "max_request_body_bytes": 12345, "log_dir": "", "log_retention_days": 7, @@ -467,6 +493,12 @@ func TestHandleUpdateGlobalConfig_AllowsClearingRoutingStrings(t *testing.T) { if cfg.Global.Routing.BusyBackpressure.MaxInlineWait != "" { t.Fatalf("expected routing.busy_backpressure.max_inline_wait to be cleared, got %q", cfg.Global.Routing.BusyBackpressure.MaxInlineWait) } + if cfg.Global.NormalizedUpstreamProxyMode() != config.GlobalUpstreamProxyModeDirect { + t.Fatalf("expected upstream_proxy_mode to be direct, got %q", cfg.Global.NormalizedUpstreamProxyMode()) + } + if cfg.Global.NormalizedUpstreamProxyURL() != "" { + t.Fatalf("expected upstream_proxy_url to be cleared, got %q", cfg.Global.NormalizedUpstreamProxyURL()) + } } func TestHandleAddProvider_AcceptsAPIKeys(t *testing.T) { @@ -504,6 +536,9 @@ func TestHandleAddProvider_AcceptsAPIKeys(t *testing.T) { if cfg.OpenAI.Providers[0].APIKey != "" { t.Fatalf("expected multi-key provider to be persisted via api_keys") } + if got := cfg.OpenAI.Providers[0].NormalizedProxyMode(); got != config.ProviderProxyModeDefault { + t.Fatalf("proxy_mode = %q, want %q", got, config.ProviderProxyModeDefault) + } if got := cfg.OpenAI.Providers[0].ModelOverride(); got != "gpt-5.4" { t.Fatalf("model = %q", got) } @@ -584,6 +619,197 @@ providers: } } +func TestHandleAddProvider_AcceptsCustomProxy(t *testing.T) { + for _, proxyURL := range []string{"http://127.0.0.1:7890", "socks5://127.0.0.1:1080"} { + t.Run(proxyURL, func(t *testing.T) { + dir := t.TempDir() + api := NewAPI(dir, "test", nil) + + body := []byte(`{ + "name": "p1", + "base_url": "https://example.com", + "api_key": "key1", + "proxy_mode": "custom", + "proxy_url": "` + proxyURL + `", + "priority": 1, + "enabled": true +}`) + + req := httptest.NewRequest(http.MethodPost, "/api/providers/codex", bytes.NewReader(body)) + w := httptest.NewRecorder() + api.HandleAddProvider(w, req) + if w.Result().StatusCode != http.StatusOK { + t.Fatalf("status=%d body=%s", w.Result().StatusCode, w.Body.String()) + } + + cfg, err := config.Load(dir) + if err != nil { + t.Fatalf("load config: %v", err) + } + if got := cfg.OpenAI.Providers[0].NormalizedProxyMode(); got != config.ProviderProxyModeCustom { + t.Fatalf("proxy_mode = %q", got) + } + if got := cfg.OpenAI.Providers[0].NormalizedProxyURL(); got != proxyURL { + t.Fatalf("proxy_url = %q", got) + } + }) + } +} + +func TestHandleAddProvider_RejectsProxyURLWithoutCustomMode(t *testing.T) { + for _, mode := range []string{"direct", "default"} { + t.Run(mode, func(t *testing.T) { + dir := t.TempDir() + api := NewAPI(dir, "test", nil) + + body := []byte(`{ + "name": "p1", + "base_url": "https://example.com", + "api_key": "key1", + "proxy_mode": "` + mode + `", + "proxy_url": "http://127.0.0.1:7890", + "priority": 1, + "enabled": true +}`) + + req := httptest.NewRequest(http.MethodPost, "/api/providers/codex", bytes.NewReader(body)) + w := httptest.NewRecorder() + api.HandleAddProvider(w, req) + if w.Result().StatusCode != http.StatusBadRequest { + t.Fatalf("status=%d body=%s", w.Result().StatusCode, w.Body.String()) + } + + got := testutil.DecodeJSONMap(t, w.Body.Bytes()) + if got["error"] != "proxy_url requires proxy_mode custom" { + t.Fatalf("error = %#v", got["error"]) + } + + cfg, err := config.Load(dir) + if err != nil { + t.Fatalf("load config: %v", err) + } + if len(cfg.OpenAI.Providers) != 0 { + t.Fatalf("providers len = %d, want 0", len(cfg.OpenAI.Providers)) + } + }) + } +} + +func TestHandleUpdateProvider_ProxySettings(t *testing.T) { + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, "openai.yaml"), []byte(` +providers: + - name: p1 + base_url: https://example.com + api_key: key1 + proxy_mode: custom + proxy_url: http://127.0.0.1:7890 + priority: 1 +`), 0o600); err != nil { + t.Fatal(err) + } + + api := NewAPI(dir, "test", nil) + + t.Run("retain existing custom proxy url", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPut, "/api/providers/codex/p1", bytes.NewReader([]byte(`{ + "proxy_mode": "custom" +}`))) + w := httptest.NewRecorder() + api.HandleUpdateProvider(w, req) + if w.Result().StatusCode != http.StatusOK { + t.Fatalf("status=%d body=%s", w.Result().StatusCode, w.Body.String()) + } + + cfg, err := config.Load(dir) + if err != nil { + t.Fatalf("load config: %v", err) + } + if got := cfg.OpenAI.Providers[0].NormalizedProxyURL(); got != "http://127.0.0.1:7890" { + t.Fatalf("proxy_url = %q", got) + } + }) + + t.Run("switch to direct clears proxy url", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPut, "/api/providers/codex/p1", bytes.NewReader([]byte(`{ + "proxy_mode": "direct" +}`))) + w := httptest.NewRecorder() + api.HandleUpdateProvider(w, req) + if w.Result().StatusCode != http.StatusOK { + t.Fatalf("status=%d body=%s", w.Result().StatusCode, w.Body.String()) + } + + cfg, err := config.Load(dir) + if err != nil { + t.Fatalf("load config: %v", err) + } + if got := cfg.OpenAI.Providers[0].NormalizedProxyMode(); got != config.ProviderProxyModeDirect { + t.Fatalf("proxy_mode = %q", got) + } + if got := cfg.OpenAI.Providers[0].NormalizedProxyURL(); got != "" { + t.Fatalf("proxy_url = %q, want empty", got) + } + }) + + t.Run("reject proxy url without mode", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPut, "/api/providers/codex/p1", bytes.NewReader([]byte(`{ + "proxy_url": "http://127.0.0.1:8899" +}`))) + w := httptest.NewRecorder() + api.HandleUpdateProvider(w, req) + if w.Result().StatusCode != http.StatusBadRequest { + t.Fatalf("status=%d body=%s", w.Result().StatusCode, w.Body.String()) + } + }) + + t.Run("reject proxy url without custom mode", func(t *testing.T) { + for _, mode := range []string{"direct", "default"} { + t.Run(mode, func(t *testing.T) { + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, "openai.yaml"), []byte(` +providers: + - name: p1 + base_url: https://example.com + api_key: key1 + proxy_mode: custom + proxy_url: http://127.0.0.1:7890 + priority: 1 +`), 0o600); err != nil { + t.Fatal(err) + } + + api := NewAPI(dir, "test", nil) + req := httptest.NewRequest(http.MethodPut, "/api/providers/codex/p1", bytes.NewReader([]byte(`{ + "proxy_mode": "`+mode+`", + "proxy_url": "http://127.0.0.1:8899" +}`))) + w := httptest.NewRecorder() + api.HandleUpdateProvider(w, req) + if w.Result().StatusCode != http.StatusBadRequest { + t.Fatalf("status=%d body=%s", w.Result().StatusCode, w.Body.String()) + } + + got := testutil.DecodeJSONMap(t, w.Body.Bytes()) + if got["error"] != "proxy_url requires proxy_mode custom" { + t.Fatalf("error = %#v", got["error"]) + } + + cfg, err := config.Load(dir) + if err != nil { + t.Fatalf("load config: %v", err) + } + if got := cfg.OpenAI.Providers[0].NormalizedProxyMode(); got != config.ProviderProxyModeCustom { + t.Fatalf("proxy_mode = %q, want %q", got, config.ProviderProxyModeCustom) + } + if got := cfg.OpenAI.Providers[0].NormalizedProxyURL(); got != "http://127.0.0.1:7890" { + t.Fatalf("proxy_url = %q", got) + } + }) + } + }) +} + func TestHandleGetClientConfig_ReturnsConfiguredModeAndPin(t *testing.T) { dir := t.TempDir() if err := os.WriteFile(filepath.Join(dir, "openai.yaml"), []byte(` diff --git a/internal/web/static/app.js b/internal/web/static/app.js index f179ed9..f9c9ac7 100644 --- a/internal/web/static/app.js +++ b/internal/web/static/app.js @@ -66,6 +66,7 @@ function app() { empty: 'No providers configured for {client}', pinBadge: 'Pinned', baseUrl: 'Base URL', + proxy: 'Proxy', apiKeys: 'API Keys', usageTotal: 'Usage', usageInOut: 'Input / Output', @@ -104,7 +105,9 @@ function app() { deleteConfirm: 'Are you sure you want to delete provider "{name}"?', deletedTitle: 'Deleted provider {name}', deletedMessage: 'It has been removed from {client}\'s provider list.', - clientTypeLabel: 'Client Type' + clientTypeLabel: 'Client Type', + proxyDirect: 'Direct', + proxyCustom: 'Custom' }, modal: { provider: { @@ -114,6 +117,14 @@ function app() { name: 'Name *', nameHint: 'Letters, numbers, dot (.), underscore (_), and hyphen (-).', baseUrl: 'Base URL *', + proxyMode: 'Proxy Mode', + proxyModeDefault: 'Use Default', + proxyModeDirect: 'Direct', + proxyModeCustom: 'Custom Proxy', + proxyUrl: 'Proxy URL', + proxyUrlHint: 'http://127.0.0.1:7890', + proxyUrlHelp: 'Supports http://, https://, socks5://, and socks5h:// proxy URLs.', + keepExistingProxy: 'Leave empty to keep the current proxy ({proxy}).', model: 'Model', modelHint: 'model-id', reasoningEffort: 'Reasoning Effort', @@ -154,6 +165,14 @@ function app() { upstreamIdleTimeoutHint: 'Set to 0 to disable stalled-stream protection.', responseHeaderTimeout: 'Response Header Timeout', responseHeaderTimeoutHint: 'Set to 0 to wait indefinitely for headers.', + upstreamProxyMode: 'Default Upstream Proxy', + upstreamProxyUrl: 'Default Proxy URL', + upstreamProxyHint: 'Used by providers whose proxy mode is Default.', + upstreamProxyUrlHelp: 'Supports http://, https://, socks5://, and socks5h:// proxy URLs.', + proxyModeEnvironment: 'Use Environment', + proxyModeDirect: 'Direct', + proxyModeCustom: 'Custom Proxy', + upstreamProxyUrlHint: 'http://127.0.0.1:7890', failureThreshold: 'Failure Threshold', failureThresholdHint: '0 disables the circuit breaker.', successThreshold: 'Success Threshold', @@ -356,6 +375,7 @@ function app() { empty: '{client} 还没有配置任何 Provider', pinBadge: '已固定', baseUrl: 'Base URL', + proxy: '代理', apiKeys: 'API Keys', usageTotal: '用量', usageInOut: '输入 / 输出', @@ -394,7 +414,9 @@ function app() { deletedTitle: '已删除 Provider {name}', deletedMessage: '它已从 {client} 的 Provider 列表中移除。', clientTypeLabel: '客户端类型', - dragToReorder: '拖拽调整优先级' + dragToReorder: '拖拽调整优先级', + proxyDirect: '直连', + proxyCustom: '自定义' }, modal: { provider: { @@ -404,6 +426,14 @@ function app() { name: '名称 *', nameHint: '允许字母、数字、点号 (.)、下划线 (_) 和连字符 (-)。', baseUrl: 'Base URL *', + proxyMode: '代理模式', + proxyModeDefault: '使用默认值', + proxyModeDirect: '直连', + proxyModeCustom: '自定义代理', + proxyUrl: '代理 URL', + proxyUrlHint: 'http://127.0.0.1:7890', + proxyUrlHelp: '支持 http://、https://、socks5:// 和 socks5h:// 代理 URL。', + keepExistingProxy: '留空则保留当前代理({proxy})。', model: '模型', modelHint: 'model-id', reasoningEffort: '思考强度', @@ -443,6 +473,14 @@ function app() { upstreamIdleTimeoutHint: '设为 0 可关闭流式响应停滞保护。', responseHeaderTimeout: '响应头超时', responseHeaderTimeoutHint: '设为 0 表示无限等待响应头。', + upstreamProxyMode: '默认上游代理', + upstreamProxyUrl: '默认代理 URL', + upstreamProxyHint: '对代理模式为“使用默认值”的 Provider 生效。', + upstreamProxyUrlHelp: '支持 http://、https://、socks5:// 和 socks5h:// 代理 URL。', + proxyModeEnvironment: '使用环境变量', + proxyModeDirect: '直连', + proxyModeCustom: '自定义代理', + upstreamProxyUrlHint: 'http://127.0.0.1:7890', failureThreshold: '失败阈值', failureThresholdHint: '设为 0 可关闭熔断器。', successThreshold: '成功阈值', @@ -612,6 +650,8 @@ function app() { reactivate_after: '', upstream_idle_timeout: '', response_header_timeout: '', + upstream_proxy_mode: 'environment', + upstream_proxy_url: '', max_request_body_bytes: 0, log_dir: '', log_retention_days: 7, @@ -672,6 +712,9 @@ function app() { providerForm: { name: '', base_url: '', + proxy_mode: 'default', + proxy_url: '', + proxy_url_hint: '', model: '', reasoning_effort: '', thinking_budget_tokens: 0, @@ -1472,6 +1515,42 @@ function app() { return this.tf('modal.provider.keepExistingKeys', { count }); }, + normalizeGlobalProxyMode(mode) { + const value = String(mode || '').trim().toLowerCase(); + return value || 'environment'; + }, + + normalizeProviderProxyMode(mode) { + const value = String(mode || '').trim().toLowerCase(); + return value || 'default'; + }, + + providerProxySummary(provider) { + const mode = this.normalizeProviderProxyMode(provider && provider.proxy_mode); + if (mode === 'direct') { + return this.t('providers.proxyDirect'); + } + if (mode === 'custom') { + return String((provider && provider.proxy_url_hint) || '').trim() || this.t('providers.proxyCustom'); + } + if (mode !== 'default') { + return mode; + } + return ''; + }, + + providerFormUsesCustomProxy() { + return this.normalizeProviderProxyMode(this.providerForm.proxy_mode) === 'custom'; + }, + + providerEditProxyHint() { + const proxy = String(this.providerForm.proxy_url_hint || '').trim(); + if (!proxy) { + return ''; + } + return this.tf('modal.provider.keepExistingProxy', { proxy }); + }, + providerOverrideSupport() { const support = this.clientConfig && this.clientConfig.override_support; if (!support || typeof support !== 'object') { @@ -1877,9 +1956,16 @@ function app() { const payload = { name: this.providerForm.name, base_url: this.providerForm.base_url, + proxy_mode: this.normalizeProviderProxyMode(this.providerForm.proxy_mode), priority: this.providerForm.priority, enabled: this.providerForm.enabled }; + if (payload.proxy_mode === 'custom') { + const proxyURL = String(this.providerForm.proxy_url || '').trim(); + if (proxyURL) { + payload.proxy_url = proxyURL; + } + } const overrides = {}; if (this.providerSupportsModelOverride()) { overrides.model = String(this.providerForm.model || ''); @@ -1948,6 +2034,9 @@ function app() { this.providerForm = { name: provider.name, base_url: provider.base_url, + proxy_mode: this.normalizeProviderProxyMode(provider.proxy_mode), + proxy_url: '', + proxy_url_hint: String(provider.proxy_url_hint || ''), model: String((provider.overrides && provider.overrides.model) || ''), reasoning_effort: String((provider.overrides && provider.overrides.openai && provider.overrides.openai.reasoning_effort) || ''), thinking_budget_tokens: Number((provider.overrides && provider.overrides.claude && provider.overrides.claude.thinking_budget_tokens) || 0), @@ -1990,6 +2079,9 @@ function app() { this.providerForm = { name: '', base_url: '', + proxy_mode: 'default', + proxy_url: '', + proxy_url_hint: '', model: '', reasoning_effort: '', thinking_budget_tokens: 0, @@ -2014,9 +2106,16 @@ function app() { async saveGlobalConfig() { try { + const payload = { + ...this.globalConfig, + upstream_proxy_mode: this.normalizeGlobalProxyMode(this.globalConfig.upstream_proxy_mode) + }; + if (payload.upstream_proxy_mode !== 'custom') { + payload.upstream_proxy_url = ''; + } await this.apiCall('/api/config/global/update', { method: 'PUT', - body: JSON.stringify(this.globalConfig) + body: JSON.stringify(payload) }); this.showAlert('success', this.t('settings.saveSuccess')); await this.refreshStatus(); @@ -2145,6 +2244,9 @@ function app() { this.providerForm = { name: '', base_url: '', + proxy_mode: 'default', + proxy_url: '', + proxy_url_hint: '', model: '', reasoning_effort: '', thinking_budget_tokens: 0, diff --git a/internal/web/static/app.test.js b/internal/web/static/app.test.js index 711602e..6bbad92 100644 --- a/internal/web/static/app.test.js +++ b/internal/web/static/app.test.js @@ -141,6 +141,7 @@ test('saveProvider includes OpenAI override fields in payload', async () => { assert.deepEqual(calls[0].options, { name: 'openai-primary', base_url: 'https://example.com', + proxy_mode: 'default', priority: 1, enabled: true, overrides: { @@ -192,6 +193,7 @@ test('saveProvider includes Claude thinking budget override in payload', async ( assert.deepEqual(calls[0].options, { name: 'claude-primary', base_url: 'https://example.com', + proxy_mode: 'default', priority: 2, enabled: false, overrides: { @@ -287,6 +289,7 @@ test('openAddProviderModal sets next priority for provider form', () => { assert.equal(state.showAddProviderModal, true); assert.equal(state.providerForm.priority, 4); + assert.equal(state.providerForm.proxy_mode, 'default'); }); test('editProvider hydrates override fields directly into the form', () => { @@ -295,6 +298,8 @@ test('editProvider hydrates override fields directly into the form', () => { state.editProvider({ name: 'openai-primary', base_url: 'https://example.com', + proxy_mode: 'custom', + proxy_url_hint: 'http://127.0.0.1:7890', overrides: { model: 'gpt-5.4', openai: { @@ -307,11 +312,54 @@ test('editProvider hydrates override fields directly into the form', () => { }); assert.equal(state.showEditProviderModal, true); + assert.equal(state.providerForm.proxy_mode, 'custom'); + assert.equal(state.providerForm.proxy_url, ''); + assert.equal(state.providerForm.proxy_url_hint, 'http://127.0.0.1:7890'); assert.equal(state.providerForm.model, 'gpt-5.4'); assert.equal(state.providerForm.reasoning_effort, 'high'); assert.equal(state.providerForm.thinking_budget_tokens, 0); }); +test('saveProvider includes custom proxy settings when configured', async () => { + const state = loadApp(); + const calls = []; + state.selectedClient = 'openai'; + state.providerForm = { + name: 'openai-proxy', + base_url: 'https://example.com', + proxy_mode: 'custom', + proxy_url: 'http://127.0.0.1:7890', + proxy_url_hint: '', + model: '', + reasoning_effort: '', + thinking_budget_tokens: 0, + api_keys_text: 'key-1', + priority: 1, + enabled: true + }; + state.apiCall = async (url, options) => { + calls.push({ url, options: JSON.parse(options.body) }); + return {}; + }; + state.showAlert = () => {}; + state.closeModals = () => {}; + state.loadProviders = async () => {}; + state.refreshStatus = async () => {}; + + await state.saveProvider(); + + assert.equal(calls.length, 1); + assert.deepEqual(calls[0].options, { + name: 'openai-proxy', + base_url: 'https://example.com', + proxy_mode: 'custom', + proxy_url: 'http://127.0.0.1:7890', + priority: 1, + enabled: true, + api_key: 'key-1' + }); +}); + test('saveProvider omits unsupported override fields for gemini', async () => { const state = loadApp(); const calls = []; @@ -319,6 +367,9 @@ test('saveProvider omits unsupported override fields for gemini', async () => { state.providerForm = { name: 'gemini-primary', base_url: 'https://example.com', + proxy_mode: 'default', + proxy_url: '', + proxy_url_hint: '', model: 'gemini-2.5-pro', reasoning_effort: 'high', thinking_budget_tokens: 2048, @@ -342,8 +393,32 @@ test('saveProvider omits unsupported override fields for gemini', async () => { assert.deepEqual(calls[0].options, { name: 'gemini-primary', base_url: 'https://example.com', + proxy_mode: 'default', priority: 1, enabled: true, api_key: 'key-1' }); }); + +test('saveGlobalConfig normalizes and clears non-custom upstream proxy settings', async () => { + const state = loadApp(); + const calls = []; + state.globalConfig = { + ...state.globalConfig, + upstream_proxy_mode: 'DIRECT', + upstream_proxy_url: 'http://127.0.0.1:7890' + }; + state.apiCall = async (url, options) => { + calls.push({ url, options: JSON.parse(options.body) }); + return {}; + }; + state.showAlert = () => {}; + state.refreshStatus = async () => {}; + + await state.saveGlobalConfig(); + + assert.equal(calls.length, 1); + assert.equal(calls[0].url, '/api/config/global/update'); + assert.equal(calls[0].options.upstream_proxy_mode, 'direct'); + assert.equal(calls[0].options.upstream_proxy_url, ''); +}); diff --git a/internal/web/static/index.html b/internal/web/static/index.html index 4e7d74d..5f1f255 100644 --- a/internal/web/static/index.html +++ b/internal/web/static/index.html @@ -214,6 +214,11 @@

+
+ + +
required>
+
+ + +
+
+
+ + +
+
placeholder="https://api.example.com">
+ +
+ +
+ + +
+ +
+
+
+
diff --git a/internal/web/types.go b/internal/web/types.go index 7adb100..a64ae19 100644 --- a/internal/web/types.go +++ b/internal/web/types.go @@ -5,6 +5,8 @@ package web // and lets us redact sensitive fields like API keys. import ( + "net/url" + "strings" "time" "github.com/lansespirit/Clipal/internal/config" @@ -20,6 +22,8 @@ type GlobalConfigRequest struct { ReactivateAfter string `json:"reactivate_after"` UpstreamIdleTimeout string `json:"upstream_idle_timeout"` ResponseHeaderTimeout string `json:"response_header_timeout"` + UpstreamProxyMode *string `json:"upstream_proxy_mode,omitempty"` + UpstreamProxyURL *string `json:"upstream_proxy_url,omitempty"` MaxRequestBodyBytes int64 `json:"max_request_body_bytes"` LogDir string `json:"log_dir"` LogRetentionDays int `json:"log_retention_days"` @@ -66,6 +70,8 @@ type GlobalConfigResponse struct { ReactivateAfter string `json:"reactivate_after"` UpstreamIdleTimeout string `json:"upstream_idle_timeout"` ResponseHeaderTimeout string `json:"response_header_timeout"` + UpstreamProxyMode string `json:"upstream_proxy_mode"` + UpstreamProxyURL string `json:"upstream_proxy_url"` MaxRequestBodyBytes int64 `json:"max_request_body_bytes"` LogDir string `json:"log_dir"` LogRetentionDays int `json:"log_retention_days"` @@ -163,6 +169,8 @@ type ProviderRequest struct { BaseURL string `json:"base_url"` APIKey string `json:"api_key,omitempty"` APIKeys []string `json:"api_keys,omitempty"` + ProxyMode *string `json:"proxy_mode,omitempty"` + ProxyURL *string `json:"proxy_url,omitempty"` Overrides *ProviderOverridesRequest `json:"overrides,omitempty"` // Priority is 1-based. Omit to keep existing value (on updates) or to // auto-assign the next priority (on create). @@ -172,13 +180,15 @@ type ProviderRequest struct { // ProviderResponse is returned for provider listings (never includes api_key). type ProviderResponse struct { - Name string `json:"name"` - BaseURL string `json:"base_url"` - Priority int `json:"priority"` - Enabled bool `json:"enabled"` - KeyCount int `json:"key_count"` - Usage *ProviderUsageResponse `json:"usage,omitempty"` - Overrides *ProviderOverridesResponse `json:"overrides,omitempty"` + Name string `json:"name"` + BaseURL string `json:"base_url"` + ProxyMode string `json:"proxy_mode"` + ProxyURLHint string `json:"proxy_url_hint,omitempty"` + Priority int `json:"priority"` + Enabled bool `json:"enabled"` + KeyCount int `json:"key_count"` + Usage *ProviderUsageResponse `json:"usage,omitempty"` + Overrides *ProviderOverridesResponse `json:"overrides,omitempty"` } type ProviderUsageResponse struct { @@ -216,6 +226,8 @@ type ProviderExport struct { BaseURL string `json:"base_url"` APIKey string `json:"api_key,omitempty"` APIKeys []string `json:"api_keys,omitempty"` + ProxyMode string `json:"proxy_mode,omitempty"` + ProxyURL string `json:"proxy_url,omitempty"` Priority int `json:"priority"` Enabled *bool `json:"enabled,omitempty"` Overrides *ProviderOverridesResponse `json:"overrides,omitempty"` @@ -360,6 +372,8 @@ func toGlobalConfigResponse(gc config.GlobalConfig) GlobalConfigResponse { ReactivateAfter: gc.ReactivateAfter, UpstreamIdleTimeout: gc.UpstreamIdleTimeout, ResponseHeaderTimeout: gc.ResponseHeaderTimeout, + UpstreamProxyMode: string(gc.NormalizedUpstreamProxyMode()), + UpstreamProxyURL: gc.NormalizedUpstreamProxyURL(), MaxRequestBodyBytes: gc.MaxRequestBody, LogDir: gc.LogDir, LogRetentionDays: gc.LogRetentionDays, @@ -418,13 +432,15 @@ func toProviderResponses(providers []config.Provider, usageByProvider map[string out := make([]ProviderResponse, 0, len(providers)) for _, p := range providers { out = append(out, ProviderResponse{ - Name: p.Name, - BaseURL: p.BaseURL, - Priority: p.Priority, - Enabled: p.IsEnabled(), - KeyCount: p.KeyCount(), - Usage: mapProviderUsageResponse(usageByProvider[p.Name]), - Overrides: mapProviderOverridesResponse(p), + Name: p.Name, + BaseURL: p.BaseURL, + ProxyMode: string(p.NormalizedProxyMode()), + ProxyURLHint: proxyURLHint(p.NormalizedProxyURL()), + Priority: p.Priority, + Enabled: p.IsEnabled(), + KeyCount: p.KeyCount(), + Usage: mapProviderUsageResponse(usageByProvider[p.Name]), + Overrides: mapProviderOverridesResponse(p), }) } return out @@ -461,6 +477,8 @@ func toClientConfigExport(cc config.ClientConfig) ClientConfigExport { BaseURL: p.BaseURL, APIKey: p.APIKey, APIKeys: append([]string(nil), p.APIKeys...), + ProxyMode: string(p.NormalizedProxyMode()), + ProxyURL: p.NormalizedProxyURL(), Priority: p.Priority, Enabled: p.Enabled, Overrides: mapProviderOverridesResponse(p), @@ -473,6 +491,18 @@ func toClientConfigExport(cc config.ClientConfig) ClientConfigExport { } } +func proxyURLHint(raw string) string { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return "" + } + parsed, err := url.Parse(trimmed) + if err != nil || parsed.Scheme == "" || parsed.Host == "" { + return "" + } + return parsed.Scheme + "://" + parsed.Host +} + func toProviderOverrideSupport(s providerOverrideSupport) ProviderOverrideSupport { return ProviderOverrideSupport{ Model: s.Model, diff --git a/internal/web/yaml_format.go b/internal/web/yaml_format.go index 5dd9a2d..f24f8b0 100644 --- a/internal/web/yaml_format.go +++ b/internal/web/yaml_format.go @@ -62,6 +62,12 @@ func formatClientConfigYAML(clientType string, cc config.ClientConfig) []byte { writeBufferString(&b, fmt.Sprintf(" - name: %s\n", yamlDoubleQuote(p.Name))) writeBufferString(&b, fmt.Sprintf(" base_url: %s\n", yamlDoubleQuote(p.BaseURL))) + if p.NormalizedProxyMode() != config.ProviderProxyModeDefault { + writeBufferString(&b, fmt.Sprintf(" proxy_mode: %s\n", yamlDoubleQuote(string(p.NormalizedProxyMode())))) + } + if p.NormalizedProxyMode() == config.ProviderProxyModeCustom && p.NormalizedProxyURL() != "" { + writeBufferString(&b, fmt.Sprintf(" proxy_url: %s\n", yamlDoubleQuote(p.NormalizedProxyURL()))) + } keys := p.NormalizedAPIKeys() if len(keys) <= 1 { writeBufferString(&b, fmt.Sprintf(" api_key: %s\n", yamlDoubleQuote(p.PrimaryAPIKey()))) @@ -111,6 +117,10 @@ func formatGlobalConfigYAML(gc config.GlobalConfig) []byte { writeBufferString(&b, "# How long to wait for the upstream to return response headers.\n") writeBufferString(&b, "# Set to 0 to disable.\n") writeBufferString(&b, fmt.Sprintf("response_header_timeout: %s\n", yamlDoubleQuote(strings.TrimSpace(gc.ResponseHeaderTimeout)))) + writeBufferString(&b, "# Default upstream proxy for providers whose proxy_mode is default.\n") + writeBufferString(&b, fmt.Sprintf("upstream_proxy_mode: %s # environment | direct | custom\n", yamlDoubleQuote(string(gc.NormalizedUpstreamProxyMode())))) + writeBufferString(&b, "# Supported proxy URLs: http://, https://, socks5://, socks5h://\n") + writeBufferString(&b, fmt.Sprintf("upstream_proxy_url: %s\n", yamlDoubleQuote(gc.NormalizedUpstreamProxyURL()))) writeBufferString(&b, "# Max request body size in bytes (clipal buffers request bodies for retries).\n") writeBufferString(&b, fmt.Sprintf("max_request_body_bytes: %d\n\n", gc.MaxRequestBody)) diff --git a/internal/web/yaml_format_test.go b/internal/web/yaml_format_test.go index 843d391..06f9130 100644 --- a/internal/web/yaml_format_test.go +++ b/internal/web/yaml_format_test.go @@ -150,6 +150,27 @@ func TestFormatClientConfigYAML_SingleNormalizedKeyUsesAPIKeyField(t *testing.T) } } +func TestFormatClientConfigYAML_WritesProxySettingsOnlyWhenNeeded(t *testing.T) { + cc := config.ClientConfig{ + Providers: []config.Provider{ + {Name: "default", BaseURL: "https://a.example", APIKey: "k1", Priority: 1, Enabled: boolPtr(true)}, + {Name: "direct", BaseURL: "https://b.example", APIKey: "k2", ProxyMode: config.ProviderProxyModeDirect, Priority: 2, Enabled: boolPtr(true)}, + {Name: "custom", BaseURL: "https://c.example", APIKey: "k3", ProxyMode: config.ProviderProxyModeCustom, ProxyURL: "http://127.0.0.1:7890", Priority: 3, Enabled: boolPtr(true)}, + }, + } + + got := string(formatClientConfigYAML("codex", cc)) + if strings.Count(got, "proxy_mode:") != 2 { + t.Fatalf("expected exactly two proxy_mode entries, got:\n%s", got) + } + if !strings.Contains(got, `name: "direct"`) || !strings.Contains(got, `proxy_mode: "direct"`) { + t.Fatalf("expected direct proxy_mode, got:\n%s", got) + } + if !strings.Contains(got, `name: "custom"`) || !strings.Contains(got, `proxy_mode: "custom"`) || !strings.Contains(got, `proxy_url: "http://127.0.0.1:7890"`) { + t.Fatalf("expected custom proxy settings, got:\n%s", got) + } +} + func TestFormatClientConfigYAML_RoundTripAndEscapesSpecialCharacters(t *testing.T) { cc := config.ClientConfig{ Providers: []config.Provider{ @@ -331,12 +352,16 @@ func TestFormatGlobalConfigYAML_RoundTripAndEscapesSpecialCharacters(t *testing. gc.LogDir = "logs\r\nfolder\tcontrol\x01" gc.LogRetentionDays = 7 gc.LogStdout = boolPtr(false) + gc.UpstreamProxyMode = config.GlobalUpstreamProxyModeCustom + gc.UpstreamProxyURL = "http://127.0.0.1:7890" gc.Notifications.ProviderSwitch = boolPtr(false) got := string(formatGlobalConfigYAML(gc)) for _, want := range []string{ `listen_addr: "host\"quoted\"\\path"`, `log_dir: "logs\r\nfolder\tcontrol\x01"`, + `upstream_proxy_mode: "custom" # environment | direct | custom`, + `upstream_proxy_url: "http://127.0.0.1:7890"`, `log_retention_days: 7 # default 7 days`, `log_stdout: false`, `provider_switch: false`, @@ -360,6 +385,12 @@ func TestFormatGlobalConfigYAML_RoundTripAndEscapesSpecialCharacters(t *testing. if loaded.Global.LogDir != gc.LogDir { t.Fatalf("log_dir = %q, want %q", loaded.Global.LogDir, gc.LogDir) } + if loaded.Global.NormalizedUpstreamProxyMode() != config.GlobalUpstreamProxyModeCustom { + t.Fatalf("upstream_proxy_mode = %q, want custom", loaded.Global.NormalizedUpstreamProxyMode()) + } + if loaded.Global.NormalizedUpstreamProxyURL() != "http://127.0.0.1:7890" { + t.Fatalf("upstream_proxy_url = %q", loaded.Global.NormalizedUpstreamProxyURL()) + } if loaded.Global.LogRetentionDays != 7 { t.Fatalf("log_retention_days = %d, want 7", loaded.Global.LogRetentionDays) }