diff --git a/config.example.yaml b/config.example.yaml index 9dfca5bc..dd175d5f 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -87,6 +87,26 @@ ws-auth: false # - "*-mini" # wildcard matching suffix (e.g. gpt-5-codex-mini) # - "*codex*" # wildcard matching substring (e.g. gpt-5-codex-low) +# GitHub Copilot account configuration +# Note: Copilot uses OAuth device code authentication, NOT API keys or tokens. +# Do NOT paste your GitHub access token or Copilot bearer token here. +# Tokens are stored only in auth-dir JSON files, never in config.yaml. +# +# To authenticate: +# - CLI: run with -copilot-login flag +# - Web: use the /copilot-auth-url management endpoint +# +# After OAuth login, tokens are managed automatically and stored in auth-dir. +# The entries below only configure account type and optional proxy settings. +#copilot-api-key: +# - account-type: "individual" # Options: individual, business, enterprise +# proxy-url: "socks5://proxy.example.com:1080" # optional: proxy for Copilot requests + +# # When set to true, this flag forces subsequent requests in a session (sharing the same prompt_cache_key) +# # to send the header "X-Initiator: agent" instead of "vscode". This mirrors VS Code's behavior for +# # long-running agent interactions and helps prevent hitting standard rate limits. +# agent-initiator-persist: true + # Claude API keys #claude-api-key: # - api-key: "sk-atSM..." # use the official claude API key, no need to set the base url diff --git a/internal/config/config.go b/internal/config/config.go index 97b5a0c2..f546c643 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -12,6 +12,7 @@ import ( "strings" "syscall" + copilotshared "github.com/router-for-me/CLIProxyAPI/v6/internal/copilot" "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" "golang.org/x/crypto/bcrypt" "gopkg.in/yaml.v3" @@ -66,6 +67,9 @@ type Config struct { // RequestRetry defines the retry times when the request failed. RequestRetry int `yaml:"request-retry" json:"request-retry"` + // ScannerBufferSize defines the buffer size for reading response streams (in bytes). + // If 0, a default of 20MB is used. + ScannerBufferSize int `yaml:"scanner-buffer-size" json:"scanner-buffer-size"` // MaxRetryInterval defines the maximum wait time in seconds before retrying a cooled-down credential. MaxRetryInterval int `yaml:"max-retry-interval" json:"max-retry-interval"` @@ -75,6 +79,9 @@ type Config struct { // Codex defines a list of Codex API key configurations as specified in the YAML configuration file. CodexKey []CodexKey `yaml:"codex-api-key" json:"codex-api-key"` + // CopilotKey defines GitHub Copilot API configurations. + CopilotKey []CopilotKey `yaml:"copilot-api-key" json:"copilot-api-key"` + // OpenAICompatibility defines OpenAI API compatibility configurations for external providers. OpenAICompatibility []OpenAICompatibility `yaml:"openai-compatibility" json:"openai-compatibility"` @@ -194,6 +201,21 @@ type CodexKey struct { ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"` } +// CopilotKey represents the configuration for GitHub Copilot API access. +// Authentication is handled via device code OAuth flow, not API keys. +type CopilotKey struct { + // AccountType is the Copilot subscription type (individual, business, enterprise). + // Defaults to "individual" if not specified. + AccountType string `yaml:"account-type" json:"account-type"` + + // ProxyURL overrides the global proxy setting for Copilot requests if provided. + ProxyURL string `yaml:"proxy-url,omitempty" json:"proxy-url,omitempty"` + + // AgentInitiatorPersist, when true, forces subsequent Copilot requests sharing the + // same prompt_cache_key to send X-Initiator=agent after the first call. Default false. + AgentInitiatorPersist bool `yaml:"agent-initiator-persist" json:"agent-initiator-persist"` +} + // GeminiKey represents the configuration for a Gemini API key, // including optional overrides for upstream base URL, proxy routing, and headers. type GeminiKey struct { @@ -328,6 +350,9 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) { // Sanitize Codex keys: drop entries without base-url cfg.SanitizeCodexKeys() + // Sanitize Copilot keys: normalize account type + cfg.SanitizeCopilotKeys() + // Sanitize Claude key headers cfg.SanitizeClaudeKeys() @@ -383,6 +408,25 @@ func (cfg *Config) SanitizeCodexKeys() { cfg.CodexKey = out } +// SanitizeCopilotKeys normalizes Copilot configurations. +// It sets default account type and trims whitespace. +func (cfg *Config) SanitizeCopilotKeys() { + if cfg == nil || len(cfg.CopilotKey) == 0 { + return + } + for i := range cfg.CopilotKey { + entry := &cfg.CopilotKey[i] + entry.AccountType = strings.TrimSpace(strings.ToLower(entry.AccountType)) + entry.ProxyURL = strings.TrimSpace(entry.ProxyURL) + validation := copilotshared.ValidateAccountType(entry.AccountType) + if validation.Valid { + entry.AccountType = string(validation.AccountType) + } else { + entry.AccountType = string(copilotshared.DefaultAccountType) + } + } +} + // SanitizeClaudeKeys normalizes headers for Claude credentials. func (cfg *Config) SanitizeClaudeKeys() { if cfg == nil || len(cfg.ClaudeKey) == 0 { diff --git a/internal/copilot/types.go b/internal/copilot/types.go new file mode 100644 index 00000000..97f1958a --- /dev/null +++ b/internal/copilot/types.go @@ -0,0 +1,60 @@ +package copilot + +import ( + "fmt" + "strings" +) + +// AccountType is the Copilot subscription type. +type AccountType string + +const ( + AccountTypeIndividual AccountType = "individual" + AccountTypeBusiness AccountType = "business" + AccountTypeEnterprise AccountType = "enterprise" +) + +// ValidAccountTypes is the canonical list of valid Copilot account types. +var ValidAccountTypes = []string{string(AccountTypeIndividual), string(AccountTypeBusiness), string(AccountTypeEnterprise)} + +const DefaultAccountType = AccountTypeIndividual + +// AccountTypeValidationResult contains the result of account type validation. +type AccountTypeValidationResult struct { + AccountType AccountType + Valid bool + ValidValues []string + DefaultValue string + ErrorMessage string +} + +// ParseAccountType parses a string into an AccountType. +// Returns the parsed type and whether the input was a valid account type. +// Empty or invalid strings return (AccountTypeIndividual, false). +func ParseAccountType(s string) (AccountType, bool) { + switch strings.ToLower(strings.TrimSpace(s)) { + case "individual": + return AccountTypeIndividual, true + case "business": + return AccountTypeBusiness, true + case "enterprise": + return AccountTypeEnterprise, true + default: + return AccountTypeIndividual, false + } +} + +// ValidateAccountType validates an account type string and returns details suitable for API responses. +func ValidateAccountType(s string) AccountTypeValidationResult { + accountType, valid := ParseAccountType(s) + result := AccountTypeValidationResult{ + AccountType: accountType, + Valid: valid, + ValidValues: ValidAccountTypes, + DefaultValue: string(DefaultAccountType), + } + if !valid && s != "" { + result.ErrorMessage = fmt.Sprintf("invalid account_type '%s', valid values are: %s", s, strings.Join(ValidAccountTypes, ", ")) + } + return result +} diff --git a/internal/copilot/types_test.go b/internal/copilot/types_test.go new file mode 100644 index 00000000..9627bace --- /dev/null +++ b/internal/copilot/types_test.go @@ -0,0 +1,169 @@ + package copilot + + import ( + "testing" + ) + + func TestParseAccountType(t *testing.T) { + tests := []struct { + name string + input string + wantType AccountType + wantValid bool + }{ + { + name: "individual lowercase", + input: "individual", + wantType: AccountTypeIndividual, + wantValid: true, + }, + { + name: "individual uppercase", + input: "INDIVIDUAL", + wantType: AccountTypeIndividual, + wantValid: true, + }, + { + name: "individual mixed case", + input: "Individual", + wantType: AccountTypeIndividual, + wantValid: true, + }, + { + name: "business", + input: "business", + wantType: AccountTypeBusiness, + wantValid: true, + }, + { + name: "enterprise", + input: "enterprise", + wantType: AccountTypeEnterprise, + wantValid: true, + }, + { + name: "empty string", + input: "", + wantType: AccountTypeIndividual, + wantValid: false, + }, + { + name: "invalid value", + input: "invalid", + wantType: AccountTypeIndividual, + wantValid: false, + }, + { + name: "whitespace", + input: " individual ", + wantType: AccountTypeIndividual, + wantValid: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotType, gotValid := ParseAccountType(tt.input) + if gotType != tt.wantType { + t.Errorf("ParseAccountType(%q) type = %v, want %v", tt.input, gotType, tt.wantType) + } + if gotValid != tt.wantValid { + t.Errorf("ParseAccountType(%q) valid = %v, want %v", tt.input, gotValid, tt.wantValid) + } + }) + } + } + + func TestValidateAccountType(t *testing.T) { + tests := []struct { + name string + input string + wantValid bool + wantHasError bool + wantType AccountType + }{ + { + name: "valid individual", + input: "individual", + wantValid: true, + wantHasError: false, + wantType: AccountTypeIndividual, + }, + { + name: "valid business", + input: "business", + wantValid: true, + wantHasError: false, + wantType: AccountTypeBusiness, + }, + { + name: "valid enterprise", + input: "enterprise", + wantValid: true, + wantHasError: false, + wantType: AccountTypeEnterprise, + }, + { + name: "invalid value", + input: "invalid", + wantValid: false, + wantHasError: true, + wantType: AccountTypeIndividual, + }, + { + name: "empty string", + input: "", + wantValid: false, + wantHasError: false, // empty string doesn't generate error message + wantType: AccountTypeIndividual, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ValidateAccountType(tt.input) + if result.Valid != tt.wantValid { + t.Errorf("ValidateAccountType(%q).Valid = %v, want %v", tt.input, result.Valid, tt.wantValid) + } + if (result.ErrorMessage != "") != tt.wantHasError { + t.Errorf("ValidateAccountType(%q).ErrorMessage = %q, wantHasError %v", tt.input, result.ErrorMessage, tt.wantHasError) + } + if result.AccountType != tt.wantType { + t.Errorf("ValidateAccountType(%q).AccountType = %v, want %v", tt.input, result.AccountType, tt.wantType) + } + if result.DefaultValue != string(DefaultAccountType) { + t.Errorf("ValidateAccountType(%q).DefaultValue = %q, want %q", tt.input, result.DefaultValue, DefaultAccountType) + } + if len(result.ValidValues) != 3 { + t.Errorf("ValidateAccountType(%q).ValidValues has %d items, want 3", tt.input, len(result.ValidValues)) + } + }) + } + } + + func TestAccountTypeConstants(t *testing.T) { + if AccountTypeIndividual != "individual" { + t.Errorf("AccountTypeIndividual = %q, want %q", AccountTypeIndividual, "individual") + } + if AccountTypeBusiness != "business" { + t.Errorf("AccountTypeBusiness = %q, want %q", AccountTypeBusiness, "business") + } + if AccountTypeEnterprise != "enterprise" { + t.Errorf("AccountTypeEnterprise = %q, want %q", AccountTypeEnterprise, "enterprise") + } + if DefaultAccountType != AccountTypeIndividual { + t.Errorf("DefaultAccountType = %q, want %q", DefaultAccountType, AccountTypeIndividual) + } + } + + func TestValidAccountTypes(t *testing.T) { + expected := []string{"individual", "business", "enterprise"} + if len(ValidAccountTypes) != len(expected) { + t.Fatalf("ValidAccountTypes has %d items, want %d", len(ValidAccountTypes), len(expected)) + } + for i, v := range expected { + if ValidAccountTypes[i] != v { + t.Errorf("ValidAccountTypes[%d] = %q, want %q", i, ValidAccountTypes[i], v) + } + } + } diff --git a/internal/util/util.go b/internal/util/util.go index 17536ac1..aa947a6b 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -85,6 +85,35 @@ func CountAuthFiles(authDir string) int { return count } +// EnsureAuthDir ensures the auth directory exists and is a directory. +// If it doesn't exist, it creates it with permissions 0755 (main-branch default). +// Returns the resolved path and any error encountered. +func EnsureAuthDir(authDir string) (string, error) { + dir, err := ResolveAuthDir(authDir) + if err != nil { + return "", fmt.Errorf("failed to resolve auth directory: %w", err) + } + if dir == "" { + return "", fmt.Errorf("auth directory not configured") + } + + info, err := os.Stat(dir) + if err != nil { + if os.IsNotExist(err) { + if mkErr := os.MkdirAll(dir, 0o755); mkErr != nil { + return "", fmt.Errorf("failed to create auth directory %s: %w", dir, mkErr) + } + log.Infof("created auth directory: %s", dir) + return dir, nil + } + return "", fmt.Errorf("failed to access auth directory %s: %w", dir, err) + } + if !info.IsDir() { + return "", fmt.Errorf("auth path exists but is not a directory: %s", dir) + } + return dir, nil +} + // WritablePath returns the cleaned WRITABLE_PATH environment variable when it is set. // It accepts both uppercase and lowercase variants for compatibility with existing conventions. func WritablePath() string { diff --git a/internal/util/util_test.go b/internal/util/util_test.go new file mode 100644 index 00000000..940e2206 --- /dev/null +++ b/internal/util/util_test.go @@ -0,0 +1,181 @@ + package util + + import ( + "os" + "path/filepath" + "testing" + ) + + func TestResolveAuthDir(t *testing.T) { + tests := []struct { + name string + input string + wantErr bool + }{ + { + name: "empty string", + input: "", + wantErr: false, + }, + { + name: "absolute path", + input: "/tmp/auth", + wantErr: false, + }, + { + name: "relative path", + input: "auth", + wantErr: false, + }, + { + name: "tilde path", + input: "~/auth", + wantErr: false, + }, + { + name: "tilde only", + input: "~", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := ResolveAuthDir(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("ResolveAuthDir(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr) + return + } + if tt.input == "" && result != "" { + t.Errorf("ResolveAuthDir(%q) = %q, want empty", tt.input, result) + } + if tt.input != "" && result == "" && !tt.wantErr { + t.Errorf("ResolveAuthDir(%q) returned empty string unexpectedly", tt.input) + } + }) + } + } + + func TestResolveAuthDir_TildeExpansion(t *testing.T) { + home, err := os.UserHomeDir() + if err != nil { + t.Skip("cannot get user home directory") + } + + result, err := ResolveAuthDir("~/test/auth") + if err != nil { + t.Fatalf("ResolveAuthDir(~/test/auth) error = %v", err) + } + + expected := filepath.Join(home, "test", "auth") + if result != expected { + t.Errorf("ResolveAuthDir(~/test/auth) = %q, want %q", result, expected) + } + } + + func TestEnsureAuthDir(t *testing.T) { + // Test with temp directory + tmpDir := t.TempDir() + testDir := filepath.Join(tmpDir, "test-auth") + + result, err := EnsureAuthDir(testDir) + if err != nil { + t.Fatalf("EnsureAuthDir(%q) error = %v", testDir, err) + } + if result != testDir { + t.Errorf("EnsureAuthDir(%q) = %q, want %q", testDir, result, testDir) + } + + // Verify directory was created + info, err := os.Stat(testDir) + if err != nil { + t.Fatalf("os.Stat(%q) error = %v", testDir, err) + } + if !info.IsDir() { + t.Errorf("EnsureAuthDir(%q) did not create a directory", testDir) + } + } + + func TestEnsureAuthDir_ExistingDir(t *testing.T) { + tmpDir := t.TempDir() + + result, err := EnsureAuthDir(tmpDir) + if err != nil { + t.Fatalf("EnsureAuthDir(%q) error = %v", tmpDir, err) + } + if result != tmpDir { + t.Errorf("EnsureAuthDir(%q) = %q, want %q", tmpDir, result, tmpDir) + } + } + + func TestEnsureAuthDir_EmptyString(t *testing.T) { + _, err := EnsureAuthDir("") + if err == nil { + t.Error("EnsureAuthDir(\"\") should return an error") + } + } + + func TestCountAuthFiles(t *testing.T) { + tmpDir := t.TempDir() + + // Create some test files + files := []string{"auth1.json", "auth2.json", "config.yaml", "readme.txt"} + for _, f := range files { + path := filepath.Join(tmpDir, f) + if err := os.WriteFile(path, []byte("{}"), 0644); err != nil { + t.Fatalf("failed to create test file %s: %v", path, err) + } + } + + count := CountAuthFiles(tmpDir) + if count != 2 { + t.Errorf("CountAuthFiles(%q) = %d, want 2", tmpDir, count) + } + } + + func TestCountAuthFiles_EmptyDir(t *testing.T) { + tmpDir := t.TempDir() + count := CountAuthFiles(tmpDir) + if count != 0 { + t.Errorf("CountAuthFiles(%q) = %d, want 0", tmpDir, count) + } + } + + func TestCountAuthFiles_NonExistent(t *testing.T) { + count := CountAuthFiles("/nonexistent/path/that/does/not/exist") + if count != 0 { + t.Errorf("CountAuthFiles(nonexistent) = %d, want 0", count) + } + } + + func TestWritablePath(t *testing.T) { + // Save and restore environment + origUpper := os.Getenv("WRITABLE_PATH") + origLower := os.Getenv("writable_path") + defer func() { + os.Setenv("WRITABLE_PATH", origUpper) + os.Setenv("writable_path", origLower) + }() + + // Clear both + os.Unsetenv("WRITABLE_PATH") + os.Unsetenv("writable_path") + + // Test empty + if result := WritablePath(); result != "" { + t.Errorf("WritablePath() = %q, want empty", result) + } + + // Test uppercase + os.Setenv("WRITABLE_PATH", "/tmp/test") + if result := WritablePath(); result != "/tmp/test" { + t.Errorf("WritablePath() = %q, want /tmp/test", result) + } + os.Unsetenv("WRITABLE_PATH") + + // Test lowercase + os.Setenv("writable_path", "/tmp/lower") + if result := WritablePath(); result != "/tmp/lower" { + t.Errorf("WritablePath() = %q, want /tmp/lower", result) + } + }