From 06a8e1060a4e8a3bdb4121f191a6f99abbc1c9ce Mon Sep 17 00:00:00 2001 From: Jeff Nash <9919536+jeffnash@users.noreply.github.com> Date: Sun, 30 Nov 2025 12:35:35 -0800 Subject: [PATCH 1/5] feat(util): add shared infrastructure helpers - Add auth directory management helper (GetAuthDir) - Add random hex string generator for request IDs - Add helper for generating unique machine identifiers --- internal/util/util.go | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/internal/util/util.go b/internal/util/util.go index 17536ac1..aa947a6b 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -85,6 +85,35 @@ func CountAuthFiles(authDir string) int { return count } +// EnsureAuthDir ensures the auth directory exists and is a directory. +// If it doesn't exist, it creates it with permissions 0755 (main-branch default). +// Returns the resolved path and any error encountered. +func EnsureAuthDir(authDir string) (string, error) { + dir, err := ResolveAuthDir(authDir) + if err != nil { + return "", fmt.Errorf("failed to resolve auth directory: %w", err) + } + if dir == "" { + return "", fmt.Errorf("auth directory not configured") + } + + info, err := os.Stat(dir) + if err != nil { + if os.IsNotExist(err) { + if mkErr := os.MkdirAll(dir, 0o755); mkErr != nil { + return "", fmt.Errorf("failed to create auth directory %s: %w", dir, mkErr) + } + log.Infof("created auth directory: %s", dir) + return dir, nil + } + return "", fmt.Errorf("failed to access auth directory %s: %w", dir, err) + } + if !info.IsDir() { + return "", fmt.Errorf("auth path exists but is not a directory: %s", dir) + } + return dir, nil +} + // WritablePath returns the cleaned WRITABLE_PATH environment variable when it is set. // It accepts both uppercase and lowercase variants for compatibility with existing conventions. func WritablePath() string { From a895864f4d57170e505e95aa1acda512617598d5 Mon Sep 17 00:00:00 2001 From: Jeff Nash <9919536+jeffnash@users.noreply.github.com> Date: Sun, 30 Nov 2025 13:24:18 -0800 Subject: [PATCH 2/5] feat(config): add Copilot configuration options - Add Copilot section with agent-initiator-persist flag - Add scanner buffer size configuration - Add account type configuration option - Add copilot types with account type validation - Document configuration options in example config - Add tests for util and copilot types --- config.example.yaml | 20 ++++ internal/config/config.go | 44 ++++++++ internal/copilot/types.go | 60 +++++++++++ internal/copilot/types_test.go | 169 ++++++++++++++++++++++++++++++ internal/util/util_test.go | 181 +++++++++++++++++++++++++++++++++ 5 files changed, 474 insertions(+) create mode 100644 internal/copilot/types.go create mode 100644 internal/copilot/types_test.go create mode 100644 internal/util/util_test.go diff --git a/config.example.yaml b/config.example.yaml index 9dfca5bc..dd175d5f 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -87,6 +87,26 @@ ws-auth: false # - "*-mini" # wildcard matching suffix (e.g. gpt-5-codex-mini) # - "*codex*" # wildcard matching substring (e.g. gpt-5-codex-low) +# GitHub Copilot account configuration +# Note: Copilot uses OAuth device code authentication, NOT API keys or tokens. +# Do NOT paste your GitHub access token or Copilot bearer token here. +# Tokens are stored only in auth-dir JSON files, never in config.yaml. +# +# To authenticate: +# - CLI: run with -copilot-login flag +# - Web: use the /copilot-auth-url management endpoint +# +# After OAuth login, tokens are managed automatically and stored in auth-dir. +# The entries below only configure account type and optional proxy settings. +#copilot-api-key: +# - account-type: "individual" # Options: individual, business, enterprise +# proxy-url: "socks5://proxy.example.com:1080" # optional: proxy for Copilot requests + +# # When set to true, this flag forces subsequent requests in a session (sharing the same prompt_cache_key) +# # to send the header "X-Initiator: agent" instead of "vscode". This mirrors VS Code's behavior for +# # long-running agent interactions and helps prevent hitting standard rate limits. +# agent-initiator-persist: true + # Claude API keys #claude-api-key: # - api-key: "sk-atSM..." # use the official claude API key, no need to set the base url diff --git a/internal/config/config.go b/internal/config/config.go index 97b5a0c2..f546c643 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -12,6 +12,7 @@ import ( "strings" "syscall" + copilotshared "github.com/router-for-me/CLIProxyAPI/v6/internal/copilot" "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" "golang.org/x/crypto/bcrypt" "gopkg.in/yaml.v3" @@ -66,6 +67,9 @@ type Config struct { // RequestRetry defines the retry times when the request failed. RequestRetry int `yaml:"request-retry" json:"request-retry"` + // ScannerBufferSize defines the buffer size for reading response streams (in bytes). + // If 0, a default of 20MB is used. + ScannerBufferSize int `yaml:"scanner-buffer-size" json:"scanner-buffer-size"` // MaxRetryInterval defines the maximum wait time in seconds before retrying a cooled-down credential. MaxRetryInterval int `yaml:"max-retry-interval" json:"max-retry-interval"` @@ -75,6 +79,9 @@ type Config struct { // Codex defines a list of Codex API key configurations as specified in the YAML configuration file. CodexKey []CodexKey `yaml:"codex-api-key" json:"codex-api-key"` + // CopilotKey defines GitHub Copilot API configurations. + CopilotKey []CopilotKey `yaml:"copilot-api-key" json:"copilot-api-key"` + // OpenAICompatibility defines OpenAI API compatibility configurations for external providers. OpenAICompatibility []OpenAICompatibility `yaml:"openai-compatibility" json:"openai-compatibility"` @@ -194,6 +201,21 @@ type CodexKey struct { ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"` } +// CopilotKey represents the configuration for GitHub Copilot API access. +// Authentication is handled via device code OAuth flow, not API keys. +type CopilotKey struct { + // AccountType is the Copilot subscription type (individual, business, enterprise). + // Defaults to "individual" if not specified. + AccountType string `yaml:"account-type" json:"account-type"` + + // ProxyURL overrides the global proxy setting for Copilot requests if provided. + ProxyURL string `yaml:"proxy-url,omitempty" json:"proxy-url,omitempty"` + + // AgentInitiatorPersist, when true, forces subsequent Copilot requests sharing the + // same prompt_cache_key to send X-Initiator=agent after the first call. Default false. + AgentInitiatorPersist bool `yaml:"agent-initiator-persist" json:"agent-initiator-persist"` +} + // GeminiKey represents the configuration for a Gemini API key, // including optional overrides for upstream base URL, proxy routing, and headers. type GeminiKey struct { @@ -328,6 +350,9 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) { // Sanitize Codex keys: drop entries without base-url cfg.SanitizeCodexKeys() + // Sanitize Copilot keys: normalize account type + cfg.SanitizeCopilotKeys() + // Sanitize Claude key headers cfg.SanitizeClaudeKeys() @@ -383,6 +408,25 @@ func (cfg *Config) SanitizeCodexKeys() { cfg.CodexKey = out } +// SanitizeCopilotKeys normalizes Copilot configurations. +// It sets default account type and trims whitespace. +func (cfg *Config) SanitizeCopilotKeys() { + if cfg == nil || len(cfg.CopilotKey) == 0 { + return + } + for i := range cfg.CopilotKey { + entry := &cfg.CopilotKey[i] + entry.AccountType = strings.TrimSpace(strings.ToLower(entry.AccountType)) + entry.ProxyURL = strings.TrimSpace(entry.ProxyURL) + validation := copilotshared.ValidateAccountType(entry.AccountType) + if validation.Valid { + entry.AccountType = string(validation.AccountType) + } else { + entry.AccountType = string(copilotshared.DefaultAccountType) + } + } +} + // SanitizeClaudeKeys normalizes headers for Claude credentials. func (cfg *Config) SanitizeClaudeKeys() { if cfg == nil || len(cfg.ClaudeKey) == 0 { diff --git a/internal/copilot/types.go b/internal/copilot/types.go new file mode 100644 index 00000000..97f1958a --- /dev/null +++ b/internal/copilot/types.go @@ -0,0 +1,60 @@ +package copilot + +import ( + "fmt" + "strings" +) + +// AccountType is the Copilot subscription type. +type AccountType string + +const ( + AccountTypeIndividual AccountType = "individual" + AccountTypeBusiness AccountType = "business" + AccountTypeEnterprise AccountType = "enterprise" +) + +// ValidAccountTypes is the canonical list of valid Copilot account types. +var ValidAccountTypes = []string{string(AccountTypeIndividual), string(AccountTypeBusiness), string(AccountTypeEnterprise)} + +const DefaultAccountType = AccountTypeIndividual + +// AccountTypeValidationResult contains the result of account type validation. +type AccountTypeValidationResult struct { + AccountType AccountType + Valid bool + ValidValues []string + DefaultValue string + ErrorMessage string +} + +// ParseAccountType parses a string into an AccountType. +// Returns the parsed type and whether the input was a valid account type. +// Empty or invalid strings return (AccountTypeIndividual, false). +func ParseAccountType(s string) (AccountType, bool) { + switch strings.ToLower(strings.TrimSpace(s)) { + case "individual": + return AccountTypeIndividual, true + case "business": + return AccountTypeBusiness, true + case "enterprise": + return AccountTypeEnterprise, true + default: + return AccountTypeIndividual, false + } +} + +// ValidateAccountType validates an account type string and returns details suitable for API responses. +func ValidateAccountType(s string) AccountTypeValidationResult { + accountType, valid := ParseAccountType(s) + result := AccountTypeValidationResult{ + AccountType: accountType, + Valid: valid, + ValidValues: ValidAccountTypes, + DefaultValue: string(DefaultAccountType), + } + if !valid && s != "" { + result.ErrorMessage = fmt.Sprintf("invalid account_type '%s', valid values are: %s", s, strings.Join(ValidAccountTypes, ", ")) + } + return result +} diff --git a/internal/copilot/types_test.go b/internal/copilot/types_test.go new file mode 100644 index 00000000..9627bace --- /dev/null +++ b/internal/copilot/types_test.go @@ -0,0 +1,169 @@ + package copilot + + import ( + "testing" + ) + + func TestParseAccountType(t *testing.T) { + tests := []struct { + name string + input string + wantType AccountType + wantValid bool + }{ + { + name: "individual lowercase", + input: "individual", + wantType: AccountTypeIndividual, + wantValid: true, + }, + { + name: "individual uppercase", + input: "INDIVIDUAL", + wantType: AccountTypeIndividual, + wantValid: true, + }, + { + name: "individual mixed case", + input: "Individual", + wantType: AccountTypeIndividual, + wantValid: true, + }, + { + name: "business", + input: "business", + wantType: AccountTypeBusiness, + wantValid: true, + }, + { + name: "enterprise", + input: "enterprise", + wantType: AccountTypeEnterprise, + wantValid: true, + }, + { + name: "empty string", + input: "", + wantType: AccountTypeIndividual, + wantValid: false, + }, + { + name: "invalid value", + input: "invalid", + wantType: AccountTypeIndividual, + wantValid: false, + }, + { + name: "whitespace", + input: " individual ", + wantType: AccountTypeIndividual, + wantValid: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotType, gotValid := ParseAccountType(tt.input) + if gotType != tt.wantType { + t.Errorf("ParseAccountType(%q) type = %v, want %v", tt.input, gotType, tt.wantType) + } + if gotValid != tt.wantValid { + t.Errorf("ParseAccountType(%q) valid = %v, want %v", tt.input, gotValid, tt.wantValid) + } + }) + } + } + + func TestValidateAccountType(t *testing.T) { + tests := []struct { + name string + input string + wantValid bool + wantHasError bool + wantType AccountType + }{ + { + name: "valid individual", + input: "individual", + wantValid: true, + wantHasError: false, + wantType: AccountTypeIndividual, + }, + { + name: "valid business", + input: "business", + wantValid: true, + wantHasError: false, + wantType: AccountTypeBusiness, + }, + { + name: "valid enterprise", + input: "enterprise", + wantValid: true, + wantHasError: false, + wantType: AccountTypeEnterprise, + }, + { + name: "invalid value", + input: "invalid", + wantValid: false, + wantHasError: true, + wantType: AccountTypeIndividual, + }, + { + name: "empty string", + input: "", + wantValid: false, + wantHasError: false, // empty string doesn't generate error message + wantType: AccountTypeIndividual, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ValidateAccountType(tt.input) + if result.Valid != tt.wantValid { + t.Errorf("ValidateAccountType(%q).Valid = %v, want %v", tt.input, result.Valid, tt.wantValid) + } + if (result.ErrorMessage != "") != tt.wantHasError { + t.Errorf("ValidateAccountType(%q).ErrorMessage = %q, wantHasError %v", tt.input, result.ErrorMessage, tt.wantHasError) + } + if result.AccountType != tt.wantType { + t.Errorf("ValidateAccountType(%q).AccountType = %v, want %v", tt.input, result.AccountType, tt.wantType) + } + if result.DefaultValue != string(DefaultAccountType) { + t.Errorf("ValidateAccountType(%q).DefaultValue = %q, want %q", tt.input, result.DefaultValue, DefaultAccountType) + } + if len(result.ValidValues) != 3 { + t.Errorf("ValidateAccountType(%q).ValidValues has %d items, want 3", tt.input, len(result.ValidValues)) + } + }) + } + } + + func TestAccountTypeConstants(t *testing.T) { + if AccountTypeIndividual != "individual" { + t.Errorf("AccountTypeIndividual = %q, want %q", AccountTypeIndividual, "individual") + } + if AccountTypeBusiness != "business" { + t.Errorf("AccountTypeBusiness = %q, want %q", AccountTypeBusiness, "business") + } + if AccountTypeEnterprise != "enterprise" { + t.Errorf("AccountTypeEnterprise = %q, want %q", AccountTypeEnterprise, "enterprise") + } + if DefaultAccountType != AccountTypeIndividual { + t.Errorf("DefaultAccountType = %q, want %q", DefaultAccountType, AccountTypeIndividual) + } + } + + func TestValidAccountTypes(t *testing.T) { + expected := []string{"individual", "business", "enterprise"} + if len(ValidAccountTypes) != len(expected) { + t.Fatalf("ValidAccountTypes has %d items, want %d", len(ValidAccountTypes), len(expected)) + } + for i, v := range expected { + if ValidAccountTypes[i] != v { + t.Errorf("ValidAccountTypes[%d] = %q, want %q", i, ValidAccountTypes[i], v) + } + } + } diff --git a/internal/util/util_test.go b/internal/util/util_test.go new file mode 100644 index 00000000..940e2206 --- /dev/null +++ b/internal/util/util_test.go @@ -0,0 +1,181 @@ + package util + + import ( + "os" + "path/filepath" + "testing" + ) + + func TestResolveAuthDir(t *testing.T) { + tests := []struct { + name string + input string + wantErr bool + }{ + { + name: "empty string", + input: "", + wantErr: false, + }, + { + name: "absolute path", + input: "/tmp/auth", + wantErr: false, + }, + { + name: "relative path", + input: "auth", + wantErr: false, + }, + { + name: "tilde path", + input: "~/auth", + wantErr: false, + }, + { + name: "tilde only", + input: "~", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := ResolveAuthDir(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("ResolveAuthDir(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr) + return + } + if tt.input == "" && result != "" { + t.Errorf("ResolveAuthDir(%q) = %q, want empty", tt.input, result) + } + if tt.input != "" && result == "" && !tt.wantErr { + t.Errorf("ResolveAuthDir(%q) returned empty string unexpectedly", tt.input) + } + }) + } + } + + func TestResolveAuthDir_TildeExpansion(t *testing.T) { + home, err := os.UserHomeDir() + if err != nil { + t.Skip("cannot get user home directory") + } + + result, err := ResolveAuthDir("~/test/auth") + if err != nil { + t.Fatalf("ResolveAuthDir(~/test/auth) error = %v", err) + } + + expected := filepath.Join(home, "test", "auth") + if result != expected { + t.Errorf("ResolveAuthDir(~/test/auth) = %q, want %q", result, expected) + } + } + + func TestEnsureAuthDir(t *testing.T) { + // Test with temp directory + tmpDir := t.TempDir() + testDir := filepath.Join(tmpDir, "test-auth") + + result, err := EnsureAuthDir(testDir) + if err != nil { + t.Fatalf("EnsureAuthDir(%q) error = %v", testDir, err) + } + if result != testDir { + t.Errorf("EnsureAuthDir(%q) = %q, want %q", testDir, result, testDir) + } + + // Verify directory was created + info, err := os.Stat(testDir) + if err != nil { + t.Fatalf("os.Stat(%q) error = %v", testDir, err) + } + if !info.IsDir() { + t.Errorf("EnsureAuthDir(%q) did not create a directory", testDir) + } + } + + func TestEnsureAuthDir_ExistingDir(t *testing.T) { + tmpDir := t.TempDir() + + result, err := EnsureAuthDir(tmpDir) + if err != nil { + t.Fatalf("EnsureAuthDir(%q) error = %v", tmpDir, err) + } + if result != tmpDir { + t.Errorf("EnsureAuthDir(%q) = %q, want %q", tmpDir, result, tmpDir) + } + } + + func TestEnsureAuthDir_EmptyString(t *testing.T) { + _, err := EnsureAuthDir("") + if err == nil { + t.Error("EnsureAuthDir(\"\") should return an error") + } + } + + func TestCountAuthFiles(t *testing.T) { + tmpDir := t.TempDir() + + // Create some test files + files := []string{"auth1.json", "auth2.json", "config.yaml", "readme.txt"} + for _, f := range files { + path := filepath.Join(tmpDir, f) + if err := os.WriteFile(path, []byte("{}"), 0644); err != nil { + t.Fatalf("failed to create test file %s: %v", path, err) + } + } + + count := CountAuthFiles(tmpDir) + if count != 2 { + t.Errorf("CountAuthFiles(%q) = %d, want 2", tmpDir, count) + } + } + + func TestCountAuthFiles_EmptyDir(t *testing.T) { + tmpDir := t.TempDir() + count := CountAuthFiles(tmpDir) + if count != 0 { + t.Errorf("CountAuthFiles(%q) = %d, want 0", tmpDir, count) + } + } + + func TestCountAuthFiles_NonExistent(t *testing.T) { + count := CountAuthFiles("/nonexistent/path/that/does/not/exist") + if count != 0 { + t.Errorf("CountAuthFiles(nonexistent) = %d, want 0", count) + } + } + + func TestWritablePath(t *testing.T) { + // Save and restore environment + origUpper := os.Getenv("WRITABLE_PATH") + origLower := os.Getenv("writable_path") + defer func() { + os.Setenv("WRITABLE_PATH", origUpper) + os.Setenv("writable_path", origLower) + }() + + // Clear both + os.Unsetenv("WRITABLE_PATH") + os.Unsetenv("writable_path") + + // Test empty + if result := WritablePath(); result != "" { + t.Errorf("WritablePath() = %q, want empty", result) + } + + // Test uppercase + os.Setenv("WRITABLE_PATH", "/tmp/test") + if result := WritablePath(); result != "/tmp/test" { + t.Errorf("WritablePath() = %q, want /tmp/test", result) + } + os.Unsetenv("WRITABLE_PATH") + + // Test lowercase + os.Setenv("writable_path", "/tmp/lower") + if result := WritablePath(); result != "/tmp/lower" { + t.Errorf("WritablePath() = %q, want /tmp/lower", result) + } + } From 0b7e617ee58b2258b2338b240b2ff7f3b0b1c59d Mon Sep 17 00:00:00 2001 From: Jeff Nash <9919536+jeffnash@users.noreply.github.com> Date: Sun, 30 Nov 2025 12:35:48 -0800 Subject: [PATCH 3/5] feat(copilot/auth): implement GitHub Copilot authentication flow - Add device-code OAuth flow with GitHub token exchange - Implement Copilot token acquisition and refresh logic - Add account type handling (individual/business/enterprise) - Add token persistence and storage management - Add CLI login command (cliproxy-api copilot login) - Register Copilot refresh handler in SDK - Validate account_type with warning for invalid values --- internal/auth/copilot/api_config.go | 129 +++++ internal/auth/copilot/auth.go | 497 ++++++++++++++++++++ internal/auth/copilot/errors.go | 74 +++ internal/auth/copilot/errors_test.go | 122 +++++ internal/auth/copilot/storage.go | 84 ++++ internal/auth/copilot/token_helpers.go | 108 +++++ internal/auth/copilot/token_helpers_test.go | 271 +++++++++++ internal/cmd/auth_manager.go | 3 +- internal/cmd/copilot_login.go | 69 +++ sdk/auth/copilot.go | 175 +++++++ sdk/auth/refresh_registry.go | 1 + 11 files changed, 1532 insertions(+), 1 deletion(-) create mode 100644 internal/auth/copilot/api_config.go create mode 100644 internal/auth/copilot/auth.go create mode 100644 internal/auth/copilot/errors.go create mode 100644 internal/auth/copilot/errors_test.go create mode 100644 internal/auth/copilot/storage.go create mode 100644 internal/auth/copilot/token_helpers.go create mode 100644 internal/auth/copilot/token_helpers_test.go create mode 100644 internal/cmd/copilot_login.go create mode 100644 sdk/auth/copilot.go diff --git a/internal/auth/copilot/api_config.go b/internal/auth/copilot/api_config.go new file mode 100644 index 00000000..d81bc6cf --- /dev/null +++ b/internal/auth/copilot/api_config.go @@ -0,0 +1,129 @@ +// Package copilot provides authentication and token management for GitHub Copilot API. +// It handles the OAuth2 device code flow, token exchange, and automatic token refresh. +package copilot + +import ( + "crypto/rand" + "encoding/hex" + "fmt" + + copilotshared "github.com/router-for-me/CLIProxyAPI/v6/internal/copilot" +) + +const ( + GitHubBaseURL = "https://github.com" + GitHubAPIBaseURL = "https://api.github.com" + // GitHubClientID is the PUBLIC OAuth client ID for GitHub Copilot's VS Code extension. + // This is NOT a secret - it's the same client ID used by the official Copilot CLI and + // VS Code extension, publicly visible in their source code and network requests. + GitHubClientID = "Iv1.b507a08c87ecfe98" + GitHubAppScopes = "read:user" + DeviceCodePath = "/login/device/code" + AccessTokenPath = "/login/oauth/access_token" + CopilotTokenPath = "/copilot_internal/v2/token" + CopilotUserPath = "/copilot_internal/user" + UserInfoPath = "/user" + + CopilotVersion = "0.0.363" + EditorPluginVersion = "copilot/" + CopilotVersion + CopilotUserAgent = "copilot/" + CopilotVersion + " (linux v22.15.0)" + CopilotAPIVersion = "2025-05-01" + CopilotIntegrationID = "copilot-developer-cli" + DefaultVSCodeVersion = "1.95.0" +) + +type AccountType = copilotshared.AccountType + +const ( + AccountTypeIndividual AccountType = copilotshared.AccountTypeIndividual + AccountTypeBusiness AccountType = copilotshared.AccountTypeBusiness + AccountTypeEnterprise AccountType = copilotshared.AccountTypeEnterprise +) + +var ValidAccountTypes = copilotshared.ValidAccountTypes + +const DefaultAccountType = copilotshared.DefaultAccountType + +func CopilotBaseURL(accountType AccountType) string { + switch accountType { + case AccountTypeBusiness: + return "https://api.business.githubcopilot.com" + case AccountTypeEnterprise: + return "https://api.enterprise.githubcopilot.com" + default: + // Individual accounts use the individual Copilot endpoint. + return "https://api.individual.githubcopilot.com" + } +} + +func StandardHeaders() map[string]string { + return map[string]string{ + "Content-Type": "application/json", + "Accept": "application/json", + } +} + +func GitHubHeaders(githubToken, vsCodeVersion string) map[string]string { + if vsCodeVersion == "" { + vsCodeVersion = DefaultVSCodeVersion + } + return map[string]string{ + "Content-Type": "application/json", + "Accept": "application/json", + "Authorization": fmt.Sprintf("token %s", githubToken), + "Editor-Version": fmt.Sprintf("vscode/%s", vsCodeVersion), + "Editor-Plugin-Version": EditorPluginVersion, + "User-Agent": CopilotUserAgent, + "X-Github-Api-Version": CopilotAPIVersion, + "X-Vscode-User-Agent-Library-Version": "electron-fetch", + } +} + +func CopilotHeaders(copilotToken, vsCodeVersion string, enableVision bool) map[string]string { + if vsCodeVersion == "" { + vsCodeVersion = DefaultVSCodeVersion + } + headers := map[string]string{ + "Content-Type": "application/json", + "Authorization": fmt.Sprintf("Bearer %s", copilotToken), + "Copilot-Integration-Id": CopilotIntegrationID, + "Editor-Version": fmt.Sprintf("vscode/%s", vsCodeVersion), + "Editor-Plugin-Version": EditorPluginVersion, + "User-Agent": CopilotUserAgent, + "Openai-Intent": "conversation-agent", + "X-Github-Api-Version": CopilotAPIVersion, + "X-Request-Id": generateRequestID(), + "X-Interaction-Id": generateRequestID(), + "X-Vscode-User-Agent-Library-Version": "electron-fetch", + } + if enableVision { + headers["Copilot-Vision-Request"] = "true" + } + return headers +} + +func generateRequestID() string { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + panic(fmt.Sprintf("failed to generate random bytes for request ID: %v", err)) + } + return fmt.Sprintf("%s-%s-%s-%s-%s", + hex.EncodeToString(b[0:4]), + hex.EncodeToString(b[4:6]), + hex.EncodeToString(b[6:8]), + hex.EncodeToString(b[8:10]), + hex.EncodeToString(b[10:16])) +} + +// MaskToken returns a masked version of a token for safe logging. +// Shows first 2 and last 2 characters with asterisks in between. +// Returns "" for empty tokens and "" for tokens under 5 chars. +func MaskToken(token string) string { + if token == "" { + return "" + } + if len(token) < 5 { + return "" + } + return token[:2] + "****" + token[len(token)-2:] +} diff --git a/internal/auth/copilot/auth.go b/internal/auth/copilot/auth.go new file mode 100644 index 00000000..8c463fb0 --- /dev/null +++ b/internal/auth/copilot/auth.go @@ -0,0 +1,497 @@ +package copilot + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + copilotshared "github.com/router-for-me/CLIProxyAPI/v6/internal/copilot" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + log "github.com/sirupsen/logrus" + "io" + "net/http" + "net/url" + "time" +) + +// Copilot uses a two-step OAuth (GitHub device code -> Copilot token) plus account-type-specific +// base URLs and strict header requirements. This file centralizes that multi-hop flow so both +// CLI and management endpoints can trigger auth without duplicating device code polling, +// token exchange, or account-type handling. +// DeviceCodeResponse represents the response from GitHub's device code endpoint. +type DeviceCodeResponse struct { + DeviceCode string `json:"device_code"` + UserCode string `json:"user_code"` + VerificationURI string `json:"verification_uri"` + ExpiresIn int `json:"expires_in"` + Interval int `json:"interval"` +} + +// AccessTokenResponse represents the response from GitHub's access token endpoint. +type AccessTokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + Scope string `json:"scope"` + Error string `json:"error"` +} + +// CopilotTokenResponse represents the response from GitHub's Copilot token endpoint. +type CopilotTokenResponse struct { + Token string `json:"token"` + ExpiresAt int64 `json:"expires_at"` + RefreshIn int `json:"refresh_in"` +} + +// GitHubUserResponse represents the response from GitHub's user endpoint. +type GitHubUserResponse struct { + Login string `json:"login"` + Email string `json:"email"` + Name string `json:"name"` +} + +// CopilotAuth handles the GitHub Copilot OAuth2 device code authentication flow. +type CopilotAuth struct { + httpClient *http.Client + vsCodeVersion string +} + +// NewCopilotAuth creates a new CopilotAuth service instance. +// It initializes an HTTP client with proxy settings from the provided configuration. +func NewCopilotAuth(cfg *config.Config) *CopilotAuth { + return &CopilotAuth{ + httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{Timeout: 30 * time.Second}), + vsCodeVersion: DefaultVSCodeVersion, + } +} + +// GetDeviceCode initiates the device code flow by requesting a device code from GitHub. +func (a *CopilotAuth) GetDeviceCode(ctx context.Context) (*DeviceCodeResponse, error) { + reqBody := map[string]string{ + "client_id": GitHubClientID, + "scope": GitHubAppScopes, + } + bodyBytes, _ := json.Marshal(reqBody) + + req, err := http.NewRequestWithContext(ctx, "POST", GitHubBaseURL+DeviceCodePath, bytes.NewReader(bodyBytes)) + if err != nil { + return nil, fmt.Errorf("%w: %v", ErrDeviceCodeFailed, err) + } + + for k, v := range StandardHeaders() { + req.Header.Set(k, v) + } + + resp, err := a.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("%w: %v", ErrDeviceCodeFailed, err) + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("%w: failed to read response: %v", ErrDeviceCodeFailed, err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("%w: status %d: %s", ErrDeviceCodeFailed, resp.StatusCode, string(body)) + } + + var deviceCode DeviceCodeResponse + if err = json.Unmarshal(body, &deviceCode); err != nil { + return nil, fmt.Errorf("%w: failed to parse response: %v", ErrDeviceCodeFailed, err) + } + + log.Debugf("Device code response: %+v", deviceCode) + return &deviceCode, nil +} + +// PollAccessToken polls GitHub for an access token after the user has entered the device code. +// It implements exponential backoff and handles various error conditions. +func (a *CopilotAuth) PollAccessToken(ctx context.Context, deviceCode *DeviceCodeResponse) (string, error) { + if deviceCode == nil { + return "", fmt.Errorf("%w: device code is nil", ErrAccessTokenFailed) + } + + // Add 1 second buffer to the interval + interval := time.Duration(deviceCode.Interval+1) * time.Second + expiry := time.Now().Add(time.Duration(deviceCode.ExpiresIn) * time.Second) + + log.Debugf("Polling access token with interval %v", interval) + + for time.Now().Before(expiry) { + select { + case <-ctx.Done(): + return "", ctx.Err() + case <-time.After(interval): + } + + token, err := a.tryGetAccessToken(ctx, deviceCode.DeviceCode) + if err == nil && token != "" { + return token, nil + } + + switch err { + case ErrAuthorizationPending: + log.Debug("Authorization pending, continuing to poll...") + continue + case ErrSlowDown: + interval += 5 * time.Second + log.Debugf("Slowing down, new interval: %v", interval) + continue + case ErrAccessDenied, ErrExpiredToken: + return "", err + default: + if err != nil { + log.Warnf("Error polling access token: %v", err) + } + } + } + + return "", ErrExpiredToken +} + +func (a *CopilotAuth) tryGetAccessToken(ctx context.Context, deviceCode string) (string, error) { + reqBody := map[string]string{ + "client_id": GitHubClientID, + "device_code": deviceCode, + "grant_type": "urn:ietf:params:oauth:grant-type:device_code", + } + bodyBytes, _ := json.Marshal(reqBody) + + req, err := http.NewRequestWithContext(ctx, "POST", GitHubBaseURL+AccessTokenPath, bytes.NewReader(bodyBytes)) + if err != nil { + return "", err + } + + for k, v := range StandardHeaders() { + req.Header.Set(k, v) + } + + resp, err := a.httpClient.Do(req) + if err != nil { + return "", err + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + + var tokenResp AccessTokenResponse + if err = json.Unmarshal(body, &tokenResp); err != nil { + // Try parsing as URL-encoded form (GitHub sometimes returns this format) + values, parseErr := url.ParseQuery(string(body)) + if parseErr != nil { + return "", fmt.Errorf("failed to parse token response as JSON (%v) or form-urlencoded (%w)", err, parseErr) + } + tokenResp.AccessToken = values.Get("access_token") + tokenResp.Error = values.Get("error") + } + + log.Debugf("Access token response received (token: %s, error: %s)", MaskToken(tokenResp.AccessToken), tokenResp.Error) + + switch tokenResp.Error { + case "": + if tokenResp.AccessToken != "" { + return tokenResp.AccessToken, nil + } + return "", ErrAuthorizationPending + case "authorization_pending": + return "", ErrAuthorizationPending + case "slow_down": + return "", ErrSlowDown + case "access_denied": + return "", ErrAccessDenied + case "expired_token": + return "", ErrExpiredToken + default: + return "", fmt.Errorf("%w: %s", ErrAccessTokenFailed, tokenResp.Error) + } +} + +// GetCopilotToken exchanges a GitHub access token for a Copilot API token. +func (a *CopilotAuth) GetCopilotToken(ctx context.Context, githubToken string) (*CopilotTokenResponse, error) { + if githubToken == "" { + return nil, ErrNoGitHubToken + } + + req, err := http.NewRequestWithContext(ctx, "GET", GitHubAPIBaseURL+CopilotTokenPath, nil) + if err != nil { + return nil, fmt.Errorf("%w: %v", ErrCopilotTokenFailed, err) + } + + for k, v := range GitHubHeaders(githubToken, a.vsCodeVersion) { + req.Header.Set(k, v) + } + + resp, err := a.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("%w: %v", ErrCopilotTokenFailed, err) + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("%w: failed to read response: %v", ErrCopilotTokenFailed, err) + } + + // Return structured HTTP errors for auth-related status codes + if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden { + return nil, NewHTTPStatusError(resp.StatusCode, "no Copilot subscription or access denied", ErrNoCopilotSubscription) + } + + if resp.StatusCode != http.StatusOK { + return nil, NewHTTPStatusError(resp.StatusCode, string(body), ErrCopilotTokenFailed) + } + + var tokenResp CopilotTokenResponse + if err = json.Unmarshal(body, &tokenResp); err != nil { + return nil, fmt.Errorf("%w: failed to parse response: %v", ErrCopilotTokenFailed, err) + } + + log.Debug("Copilot token fetched successfully") + return &tokenResp, nil +} + +// GetGitHubUser fetches the authenticated user's information from GitHub. +func (a *CopilotAuth) GetGitHubUser(ctx context.Context, githubToken string) (*GitHubUserResponse, error) { + if githubToken == "" { + return nil, ErrNoGitHubToken + } + + req, err := http.NewRequestWithContext(ctx, "GET", GitHubAPIBaseURL+UserInfoPath, nil) + if err != nil { + return nil, err + } + + // Use simpler headers for the GitHub user API - only authorization and standard headers + req.Header.Set("Authorization", fmt.Sprintf("token %s", githubToken)) + req.Header.Set("Accept", "application/json") + req.Header.Set("User-Agent", CopilotUserAgent) + + resp, err := a.httpClient.Do(req) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("GitHub user API error response: %s", string(body)) + return nil, fmt.Errorf("failed to get user info: status %d, body: %s", resp.StatusCode, string(body)) + } + + var user GitHubUserResponse + if err = json.Unmarshal(body, &user); err != nil { + return nil, err + } + + return &user, nil +} + +// CopilotModel represents a model available through the Copilot API. +type CopilotModel struct { + ID string `json:"id"` + Name string `json:"name"` + Object string `json:"object"` + Version string `json:"version"` + Vendor string `json:"vendor"` + Preview bool `json:"preview"` + ModelPickerEnabled bool `json:"model_picker_enabled"` + Capabilities CopilotCapabilities `json:"capabilities"` +} + +// CopilotCapabilities describes the capabilities of a Copilot model. +type CopilotCapabilities struct { + Family string `json:"family"` + Type string `json:"type"` + Tokenizer string `json:"tokenizer"` + Limits CopilotLimits `json:"limits"` + Supports CopilotSupports `json:"supports"` +} + +// CopilotLimits describes the token limits for a Copilot model. +type CopilotLimits struct { + MaxContextWindowTokens int `json:"max_context_window_tokens"` + MaxOutputTokens int `json:"max_output_tokens"` + MaxPromptTokens int `json:"max_prompt_tokens"` +} + +// CopilotSupports describes the features supported by a Copilot model. +type CopilotSupports struct { + ToolCalls bool `json:"tool_calls"` + ParallelToolCalls bool `json:"parallel_tool_calls"` +} + +// CopilotModelsResponse represents the response from the Copilot models endpoint. +type CopilotModelsResponse struct { + Data []CopilotModel `json:"data"` + Object string `json:"object"` +} + +// GetModels fetches the available models from the Copilot API. +func (a *CopilotAuth) GetModels(ctx context.Context, copilotToken string, accountType AccountType) (*CopilotModelsResponse, error) { + if copilotToken == "" { + return nil, ErrNoCopilotToken + } + + baseURL := CopilotBaseURL(accountType) + req, err := http.NewRequestWithContext(ctx, "GET", baseURL+"/models", nil) + if err != nil { + return nil, fmt.Errorf("failed to create models request: %w", err) + } + + for k, v := range CopilotHeaders(copilotToken, a.vsCodeVersion, false) { + req.Header.Set(k, v) + } + + resp, err := a.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("models request failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read models response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("models request failed with status %d: %s", resp.StatusCode, string(body)) + } + + var modelsResp CopilotModelsResponse + if err = json.Unmarshal(body, &modelsResp); err != nil { + return nil, fmt.Errorf("failed to parse models response: %w", err) + } + + log.Debugf("Fetched %d models from Copilot API", len(modelsResp.Data)) + return &modelsResp, nil +} + +// RefreshCopilotToken refreshes the Copilot token using the stored GitHub token. +func (a *CopilotAuth) RefreshCopilotToken(ctx context.Context, storage *CopilotTokenStorage) error { + if storage == nil || storage.GitHubToken == "" { + return ErrNoGitHubToken + } + + tokenResp, err := a.GetCopilotToken(ctx, storage.GitHubToken) + if err != nil { + return err + } + + storage.CopilotToken = tokenResp.Token + storage.CopilotTokenExpiry = time.Unix(tokenResp.ExpiresAt, 0).Format(time.RFC3339) + storage.RefreshIn = tokenResp.RefreshIn + storage.LastRefresh = time.Now().Format(time.RFC3339) + + return nil +} + +// PerformFullAuth performs the complete authentication flow. +func (a *CopilotAuth) PerformFullAuth(ctx context.Context, accountType AccountType, onDeviceCode func(*DeviceCodeResponse)) (*CopilotTokenStorage, error) { + deviceCode, err := a.GetDeviceCode(ctx) + if err != nil { + return nil, err + } + + if onDeviceCode != nil { + onDeviceCode(deviceCode) + } + + result, err := a.finalizeAuth(ctx, deviceCode, accountType) + if err != nil { + return nil, err + } + return result.Storage, nil +} + +// AccountTypeValidationResult aliases the shared validation result type. +type AccountTypeValidationResult = copilotshared.AccountTypeValidationResult + +// ParseAccountType delegates to the shared Copilot account type parser. +func ParseAccountType(s string) (AccountType, bool) { return copilotshared.ParseAccountType(s) } + +// ValidateAccountType delegates to the shared Copilot account type validator. +func ValidateAccountType(s string) AccountTypeValidationResult { + return copilotshared.ValidateAccountType(s) +} + +type AuthResult struct { + // Storage contains the token data to be persisted. + Storage *CopilotTokenStorage + // SuggestedFilename is the recommended filename for saving the token. + SuggestedFilename string +} + +func (a *CopilotAuth) PerformFullAuthWithFilename(ctx context.Context, accountType AccountType, onDeviceCode func(*DeviceCodeResponse)) (*AuthResult, error) { + deviceCode, err := a.GetDeviceCode(ctx) + if err != nil { + return nil, err + } + + if onDeviceCode != nil { + onDeviceCode(deviceCode) + } + + return a.finalizeAuth(ctx, deviceCode, accountType) +} + +func (a *CopilotAuth) CompleteAuthWithDeviceCode(ctx context.Context, deviceCode *DeviceCodeResponse, accountType AccountType) (*AuthResult, error) { + return a.finalizeAuth(ctx, deviceCode, accountType) +} + +// finalizeAuth performs: Poll -> Exchange -> User Info -> Storage Build -> Filename Gen +func (a *CopilotAuth) finalizeAuth(ctx context.Context, deviceCode *DeviceCodeResponse, accountType AccountType) (*AuthResult, error) { + // 1. Poll GitHub Token + githubToken, err := a.PollAccessToken(ctx, deviceCode) + if err != nil { + return nil, fmt.Errorf("failed to obtain GitHub token: %w", err) + } + log.Info("GitHub authentication successful") + + // 2. Exchange for Copilot Token + copilotTokenResp, err := a.GetCopilotToken(ctx, githubToken) + if err != nil { + return nil, fmt.Errorf("failed to obtain Copilot token: %w", err) + } + log.Info("Copilot token obtained successfully") + + // 3. Get User Info (best effort) + userInfo, err := a.GetGitHubUser(ctx, githubToken) + if err != nil { + log.Warnf("Failed to get user info: %v", err) + userInfo = &GitHubUserResponse{} + } + if userInfo == nil { + userInfo = &GitHubUserResponse{} + } + + // 4. Build Storage + storage := &CopilotTokenStorage{ + GitHubToken: githubToken, + CopilotToken: copilotTokenResp.Token, + CopilotTokenExpiry: time.Unix(copilotTokenResp.ExpiresAt, 0).Format(time.RFC3339), + AccountType: string(accountType), + Username: userInfo.Login, + Email: userInfo.Email, + RefreshIn: copilotTokenResp.RefreshIn, + Type: "copilot", + LastRefresh: time.Now().Format(time.RFC3339), + } + + if userInfo.Login != "" { + log.Infof("Logged in as %s", userInfo.Login) + } + + // 5. Generate Filename + filename := fmt.Sprintf("copilot_%s_%s.json", accountType, userInfo.Login) + + return &AuthResult{Storage: storage, SuggestedFilename: filename}, nil +} diff --git a/internal/auth/copilot/errors.go b/internal/auth/copilot/errors.go new file mode 100644 index 00000000..1ba93003 --- /dev/null +++ b/internal/auth/copilot/errors.go @@ -0,0 +1,74 @@ +package copilot + +import ( + "errors" + "fmt" +) + +var ( + // ErrDeviceCodeFailed indicates failure to obtain a device code. + ErrDeviceCodeFailed = errors.New("failed to get device code") + + // ErrAccessTokenFailed indicates failure to obtain an access token. + ErrAccessTokenFailed = errors.New("failed to get access token") + + // ErrCopilotTokenFailed indicates failure to obtain a Copilot token. + ErrCopilotTokenFailed = errors.New("failed to get Copilot token") + + // ErrTokenExpired indicates the Copilot token has expired. + ErrTokenExpired = errors.New("copilot token has expired") + + // ErrNoGitHubToken indicates no GitHub token is available. + ErrNoGitHubToken = errors.New("no GitHub token available") + + // ErrNoCopilotToken indicates no Copilot token is available. + ErrNoCopilotToken = errors.New("no Copilot token available") + + // ErrAuthorizationPending indicates the user has not yet completed authorization. + ErrAuthorizationPending = errors.New("authorization pending") + + // ErrSlowDown indicates the polling interval should be increased. + ErrSlowDown = errors.New("slow down polling") + + // ErrAccessDenied indicates the user denied access. + ErrAccessDenied = errors.New("access denied by user") + + // ErrExpiredToken indicates the device code has expired. + ErrExpiredToken = errors.New("device code expired") + + // ErrNoCopilotSubscription indicates the user does not have a Copilot subscription. + ErrNoCopilotSubscription = errors.New("no Copilot subscription found") +) + +// HTTPStatusError wraps an error with an HTTP status code for structured error handling. +// This allows callers to inspect the status code without parsing error message strings. +type HTTPStatusError struct { + StatusCode int + Message string + Cause error +} + +func (e *HTTPStatusError) Error() string { + if e.Cause != nil { + return fmt.Sprintf("status %d: %s: %v", e.StatusCode, e.Message, e.Cause) + } + return fmt.Sprintf("status %d: %s", e.StatusCode, e.Message) +} + +func (e *HTTPStatusError) Unwrap() error { + return e.Cause +} + +// NewHTTPStatusError creates a new HTTPStatusError with the given status code and message. +func NewHTTPStatusError(statusCode int, message string, cause error) *HTTPStatusError { + return &HTTPStatusError{StatusCode: statusCode, Message: message, Cause: cause} +} + +// StatusCode extracts the HTTP status code from an HTTPStatusError, or returns 0 if not applicable. +func StatusCode(err error) int { + var httpErr *HTTPStatusError + if errors.As(err, &httpErr) { + return httpErr.StatusCode + } + return 0 +} diff --git a/internal/auth/copilot/errors_test.go b/internal/auth/copilot/errors_test.go new file mode 100644 index 00000000..9786e7a4 --- /dev/null +++ b/internal/auth/copilot/errors_test.go @@ -0,0 +1,122 @@ + package copilot + + import ( + "errors" + "testing" + ) + + func TestHTTPStatusError_Error(t *testing.T) { + tests := []struct { + name string + err *HTTPStatusError + contains string + }{ + { + name: "without cause", + err: NewHTTPStatusError(401, "unauthorized", nil), + contains: "status 401: unauthorized", + }, + { + name: "with cause", + err: NewHTTPStatusError(500, "internal error", errors.New("database error")), + contains: "database error", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errStr := tt.err.Error() + if !contains(errStr, tt.contains) { + t.Errorf("HTTPStatusError.Error() = %q, want to contain %q", errStr, tt.contains) + } + }) + } + } + + func TestHTTPStatusError_Unwrap(t *testing.T) { + cause := errors.New("original error") + err := NewHTTPStatusError(500, "wrapped", cause) + + unwrapped := err.Unwrap() + if unwrapped != cause { + t.Errorf("HTTPStatusError.Unwrap() = %v, want %v", unwrapped, cause) + } + } + + func TestStatusCode(t *testing.T) { + tests := []struct { + name string + err error + want int + }{ + { + name: "HTTPStatusError", + err: NewHTTPStatusError(404, "not found", nil), + want: 404, + }, + { + name: "wrapped HTTPStatusError", + err: errors.New("outer: " + NewHTTPStatusError(403, "forbidden", nil).Error()), + want: 0, // Can't unwrap from string concatenation + }, + { + name: "regular error", + err: errors.New("regular error"), + want: 0, + }, + { + name: "nil error", + err: nil, + want: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := StatusCode(tt.err) + if got != tt.want { + t.Errorf("StatusCode() = %d, want %d", got, tt.want) + } + }) + } + } + + func TestSentinelErrors(t *testing.T) { + // Verify sentinel errors are non-nil and have meaningful messages + sentinels := []error{ + ErrDeviceCodeFailed, + ErrAccessTokenFailed, + ErrCopilotTokenFailed, + ErrTokenExpired, + ErrNoGitHubToken, + ErrNoCopilotToken, + ErrAuthorizationPending, + ErrSlowDown, + ErrAccessDenied, + ErrExpiredToken, + ErrNoCopilotSubscription, + } + + for _, err := range sentinels { + if err == nil { + t.Error("sentinel error is nil") + } + if err.Error() == "" { + t.Error("sentinel error has empty message") + } + } + } + + func contains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(substr) == 0 || + (len(s) > 0 && len(substr) > 0 && findSubstring(s, substr))) + } + + func findSubstring(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false + } diff --git a/internal/auth/copilot/storage.go b/internal/auth/copilot/storage.go new file mode 100644 index 00000000..01adea71 --- /dev/null +++ b/internal/auth/copilot/storage.go @@ -0,0 +1,84 @@ +package copilot + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" +) + +// CopilotTokenStorage stores authentication tokens for GitHub Copilot API. +// It maintains both the GitHub OAuth token and the Copilot-specific token, +// along with metadata for token refresh and user identification. +// +// Note on AccountType: This field is used to persist the account type to disk and +// to seed coreauth.Auth.Attributes["account_type"] during login. At runtime, executor +// logic should read from Attributes["account_type"], not from this storage field directly. +// See sdk/auth/copilot.go for the canonical source of truth on precedence and runtime contracts. +type CopilotTokenStorage struct { + // GitHubToken is the OAuth access token from GitHub device code flow. + GitHubToken string `json:"github_token"` + + // CopilotToken is the bearer token for Copilot API requests. + // Note: marked as "-" to prevent persistence to disk. + CopilotToken string `json:"-"` + + // CopilotTokenExpiry is the RFC3339 timestamp when the Copilot token expires. + // Note: marked as "-" to prevent persistence to disk. + CopilotTokenExpiry string `json:"-"` + + // RefreshIn is the number of seconds after which the token should be refreshed. + // Note: marked as "-" to prevent persistence to disk. + RefreshIn int `json:"-"` + + // AccountType is the Copilot subscription type (individual, business, enterprise). + // This is persisted for storage but Attributes["account_type"] is authoritative at runtime. + AccountType string `json:"account_type"` + + // Email is the GitHub account email address. + Email string `json:"email"` + + // Username is the GitHub username. + Username string `json:"username"` + + // LastRefresh is the RFC3339 timestamp of the last token refresh. + // Note: marked as "-" to prevent persistence to disk. + LastRefresh string `json:"-"` + + // Type indicates the authentication provider type, always "copilot" for this storage. + Type string `json:"type"` +} + +// SaveTokenToFile serializes the Copilot token storage to a JSON file. +// This method creates the necessary directory structure and writes the token +// data in JSON format to the specified file path for persistent storage. +// +// Parameters: +// - authFilePath: The full path where the token file should be saved +// +// Returns: +// - error: An error if the operation fails, nil otherwise +func (ts *CopilotTokenStorage) SaveTokenToFile(authFilePath string) error { + misc.LogSavingCredentials(authFilePath) + ts.Type = "copilot" + if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil { + return fmt.Errorf("failed to create directory: %v", err) + } + + f, err := os.Create(authFilePath) + if err != nil { + return fmt.Errorf("failed to create token file: %w", err) + } + defer func() { + _ = f.Close() + }() + + encoder := json.NewEncoder(f) + encoder.SetIndent("", " ") + if err = encoder.Encode(ts); err != nil { + return fmt.Errorf("failed to write token to file: %w", err) + } + return nil +} diff --git a/internal/auth/copilot/token_helpers.go b/internal/auth/copilot/token_helpers.go new file mode 100644 index 00000000..92112a04 --- /dev/null +++ b/internal/auth/copilot/token_helpers.go @@ -0,0 +1,108 @@ +package copilot + +import ( + "time" + + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" +) + +// ApplyTokenRefresh updates auth metadata and storage with new Copilot token data. +// This is the single source of truth for post-refresh mutations. +func ApplyTokenRefresh(auth *coreauth.Auth, tokenResp *CopilotTokenResponse, now time.Time) { + if auth == nil || tokenResp == nil { + return + } + + expiryStr := time.Unix(tokenResp.ExpiresAt, 0).Format(time.RFC3339) + lastRefreshStr := now.Format(time.RFC3339) + + // Update metadata (runtime cache) + if auth.Metadata == nil { + auth.Metadata = make(map[string]any) + } + auth.Metadata["copilot_token"] = tokenResp.Token + auth.Metadata["copilot_token_expiry"] = expiryStr + auth.Metadata["type"] = "copilot" + + // Update storage (persistence layer) + if storage, ok := auth.Storage.(*CopilotTokenStorage); ok && storage != nil { + storage.CopilotToken = tokenResp.Token + storage.CopilotTokenExpiry = expiryStr + storage.RefreshIn = tokenResp.RefreshIn + storage.LastRefresh = lastRefreshStr + } + + auth.LastRefreshedAt = now +} + +// ResolveAccountType extracts the account type using canonical precedence: +// 1. Attributes (Runtime) 2. Storage (Fallback) 3. Default +func ResolveAccountType(auth *coreauth.Auth) AccountType { + if auth == nil { + return AccountTypeIndividual + } + + if auth.Attributes != nil { + if at, ok := auth.Attributes["account_type"]; ok && at != "" { + parsed, _ := ParseAccountType(at) + return parsed + } + } + if storage, ok := auth.Storage.(*CopilotTokenStorage); ok && storage != nil && storage.AccountType != "" { + parsed, _ := ParseAccountType(storage.AccountType) + return parsed + } + return AccountTypeIndividual +} + +// ResolveGitHubToken extracts the GitHub OAuth token (Metadata > Storage). +func ResolveGitHubToken(auth *coreauth.Auth) string { + if auth == nil { + return "" + } + if auth.Metadata != nil { + if v, ok := auth.Metadata["github_token"].(string); ok && v != "" { + return v + } + } + if storage, ok := auth.Storage.(*CopilotTokenStorage); ok && storage != nil { + return storage.GitHubToken + } + return "" +} + +// ResolveCopilotToken extracts the cached Copilot token and expiry from metadata. +func ResolveCopilotToken(auth *coreauth.Auth) (token string, expiry time.Time, ok bool) { + if auth == nil || auth.Metadata == nil { + return "", time.Time{}, false + } + + token, _ = auth.Metadata["copilot_token"].(string) + expiryStr, _ := auth.Metadata["copilot_token_expiry"].(string) + + if token == "" || expiryStr == "" { + return "", time.Time{}, false + } + + expiry, err := time.Parse(time.RFC3339, expiryStr) + if err != nil { + return "", time.Time{}, false + } + return token, expiry, true +} + +// EnsureMetadataHydrated hydrates metadata from storage if needed. +func EnsureMetadataHydrated(auth *coreauth.Auth) { + if auth == nil { + return + } + if auth.Metadata == nil { + auth.Metadata = make(map[string]any) + } + + if _, ok := auth.Metadata["github_token"].(string); !ok { + if storage, ok := auth.Storage.(*CopilotTokenStorage); ok && storage != nil && storage.GitHubToken != "" { + auth.Metadata["github_token"] = storage.GitHubToken + } + } +} diff --git a/internal/auth/copilot/token_helpers_test.go b/internal/auth/copilot/token_helpers_test.go new file mode 100644 index 00000000..6f104f16 --- /dev/null +++ b/internal/auth/copilot/token_helpers_test.go @@ -0,0 +1,271 @@ + package copilot + + import ( + "testing" + "time" + + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + ) + + func TestResolveAccountType(t *testing.T) { + tests := []struct { + name string + auth *coreauth.Auth + want AccountType + }{ + { + name: "nil auth", + auth: nil, + want: AccountTypeIndividual, + }, + { + name: "empty auth", + auth: &coreauth.Auth{}, + want: AccountTypeIndividual, + }, + { + name: "from attributes - individual", + auth: &coreauth.Auth{ + Attributes: map[string]string{"account_type": "individual"}, + }, + want: AccountTypeIndividual, + }, + { + name: "from attributes - business", + auth: &coreauth.Auth{ + Attributes: map[string]string{"account_type": "business"}, + }, + want: AccountTypeBusiness, + }, + { + name: "from attributes - enterprise", + auth: &coreauth.Auth{ + Attributes: map[string]string{"account_type": "enterprise"}, + }, + want: AccountTypeEnterprise, + }, + { + name: "from storage when attributes empty", + auth: &coreauth.Auth{ + Storage: &CopilotTokenStorage{AccountType: "business"}, + }, + want: AccountTypeBusiness, + }, + { + name: "attributes take precedence over storage", + auth: &coreauth.Auth{ + Attributes: map[string]string{"account_type": "enterprise"}, + Storage: &CopilotTokenStorage{AccountType: "business"}, + }, + want: AccountTypeEnterprise, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ResolveAccountType(tt.auth) + if got != tt.want { + t.Errorf("ResolveAccountType() = %v, want %v", got, tt.want) + } + }) + } + } + + func TestResolveGitHubToken(t *testing.T) { + tests := []struct { + name string + auth *coreauth.Auth + want string + }{ + { + name: "nil auth", + auth: nil, + want: "", + }, + { + name: "empty auth", + auth: &coreauth.Auth{}, + want: "", + }, + { + name: "from metadata", + auth: &coreauth.Auth{ + Metadata: map[string]any{"github_token": "ghp_test123"}, + }, + want: "ghp_test123", + }, + { + name: "from storage when metadata empty", + auth: &coreauth.Auth{ + Storage: &CopilotTokenStorage{GitHubToken: "ghp_storage456"}, + }, + want: "ghp_storage456", + }, + { + name: "metadata takes precedence over storage", + auth: &coreauth.Auth{ + Metadata: map[string]any{"github_token": "ghp_metadata"}, + Storage: &CopilotTokenStorage{GitHubToken: "ghp_storage"}, + }, + want: "ghp_metadata", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ResolveGitHubToken(tt.auth) + if got != tt.want { + t.Errorf("ResolveGitHubToken() = %q, want %q", got, tt.want) + } + }) + } + } + + func TestResolveCopilotToken(t *testing.T) { + now := time.Now() + expiryStr := now.Add(time.Hour).Format(time.RFC3339) + + tests := []struct { + name string + auth *coreauth.Auth + wantToken string + wantOK bool + }{ + { + name: "nil auth", + auth: nil, + wantToken: "", + wantOK: false, + }, + { + name: "nil metadata", + auth: &coreauth.Auth{}, + wantToken: "", + wantOK: false, + }, + { + name: "valid token and expiry", + auth: &coreauth.Auth{ + Metadata: map[string]any{ + "copilot_token": "test_token", + "copilot_token_expiry": expiryStr, + }, + }, + wantToken: "test_token", + wantOK: true, + }, + { + name: "missing token", + auth: &coreauth.Auth{ + Metadata: map[string]any{ + "copilot_token_expiry": expiryStr, + }, + }, + wantToken: "", + wantOK: false, + }, + { + name: "missing expiry", + auth: &coreauth.Auth{ + Metadata: map[string]any{ + "copilot_token": "test_token", + }, + }, + wantToken: "", + wantOK: false, + }, + { + name: "invalid expiry format", + auth: &coreauth.Auth{ + Metadata: map[string]any{ + "copilot_token": "test_token", + "copilot_token_expiry": "invalid", + }, + }, + wantToken: "", + wantOK: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + token, _, ok := ResolveCopilotToken(tt.auth) + if token != tt.wantToken { + t.Errorf("ResolveCopilotToken() token = %q, want %q", token, tt.wantToken) + } + if ok != tt.wantOK { + t.Errorf("ResolveCopilotToken() ok = %v, want %v", ok, tt.wantOK) + } + }) + } + } + + func TestEnsureMetadataHydrated(t *testing.T) { + t.Run("nil auth", func(t *testing.T) { + // Should not panic + EnsureMetadataHydrated(nil) + }) + + t.Run("creates metadata if nil", func(t *testing.T) { + auth := &coreauth.Auth{ + Storage: &CopilotTokenStorage{GitHubToken: "ghp_test"}, + } + EnsureMetadataHydrated(auth) + if auth.Metadata == nil { + t.Error("EnsureMetadataHydrated() did not create metadata map") + } + if auth.Metadata["github_token"] != "ghp_test" { + t.Errorf("EnsureMetadataHydrated() github_token = %v, want ghp_test", auth.Metadata["github_token"]) + } + }) + + t.Run("does not overwrite existing token", func(t *testing.T) { + auth := &coreauth.Auth{ + Metadata: map[string]any{"github_token": "existing"}, + Storage: &CopilotTokenStorage{GitHubToken: "from_storage"}, + } + EnsureMetadataHydrated(auth) + if auth.Metadata["github_token"] != "existing" { + t.Errorf("EnsureMetadataHydrated() overwrote existing token") + } + }) + } + + func TestApplyTokenRefresh(t *testing.T) { + now := time.Now() + tokenResp := &CopilotTokenResponse{ + Token: "new_token", + ExpiresAt: now.Add(time.Hour).Unix(), + RefreshIn: 3600, + } + + t.Run("nil auth", func(t *testing.T) { + // Should not panic + ApplyTokenRefresh(nil, tokenResp, now) + }) + + t.Run("nil token response", func(t *testing.T) { + auth := &coreauth.Auth{} + // Should not panic + ApplyTokenRefresh(auth, nil, now) + }) + + t.Run("updates metadata and storage", func(t *testing.T) { + storage := &CopilotTokenStorage{} + auth := &coreauth.Auth{ + Storage: storage, + } + + ApplyTokenRefresh(auth, tokenResp, now) + + if auth.Metadata["copilot_token"] != "new_token" { + t.Errorf("ApplyTokenRefresh() metadata copilot_token = %v, want new_token", auth.Metadata["copilot_token"]) + } + if storage.CopilotToken != "new_token" { + t.Errorf("ApplyTokenRefresh() storage CopilotToken = %v, want new_token", storage.CopilotToken) + } + if auth.LastRefreshedAt != now { + t.Errorf("ApplyTokenRefresh() LastRefreshedAt = %v, want %v", auth.LastRefreshedAt, now) + } + }) + } diff --git a/internal/cmd/auth_manager.go b/internal/cmd/auth_manager.go index e6caa954..7b919435 100644 --- a/internal/cmd/auth_manager.go +++ b/internal/cmd/auth_manager.go @@ -6,7 +6,7 @@ import ( // newAuthManager creates a new authentication manager instance with all supported // authenticators and a file-based token store. It initializes authenticators for -// Gemini, Codex, Claude, and Qwen providers. +// Gemini, Codex, Claude, Qwen, Copilot, and other providers. // // Returns: // - *sdkAuth.Manager: A configured authentication manager instance @@ -19,6 +19,7 @@ func newAuthManager() *sdkAuth.Manager { sdkAuth.NewQwenAuthenticator(), sdkAuth.NewIFlowAuthenticator(), sdkAuth.NewAntigravityAuthenticator(), + sdkAuth.NewCopilotAuthenticator(), ) return manager } diff --git a/internal/cmd/copilot_login.go b/internal/cmd/copilot_login.go new file mode 100644 index 00000000..17d3d60f --- /dev/null +++ b/internal/cmd/copilot_login.go @@ -0,0 +1,69 @@ +package cmd + +import ( + "context" + "fmt" + "os" + "os/signal" + "syscall" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth" +) + +// DoCopilotLogin triggers the GitHub Copilot device code OAuth flow. +// It initiates the OAuth authentication process for GitHub Copilot services and saves +// the authentication tokens to the configured auth directory. +// +// Account type selection: When cfg.CopilotKey is configured, this function uses the +// account_type from the FIRST entry (cfg.CopilotKey[0].AccountType) to determine +// whether to authenticate as individual, business, or enterprise. If no CopilotKey +// entries exist or the first entry has no account_type, it defaults to "individual". +// +// To use a different account type, either: +// - Reorder the copilot-api-key entries in config.yaml so the desired one is first +// - Or use the management API /copilot-auth-url endpoint with explicit account_type param +// +// Parameters: +// - cfg: The application configuration +// - options: Login options including browser behavior and prompts +func DoCopilotLogin(cfg *config.Config, options *LoginOptions) { + if options == nil { + options = &LoginOptions{} + } + + manager := newAuthManager() + + authOpts := &sdkAuth.LoginOptions{ + NoBrowser: options.NoBrowser, + Metadata: map[string]string{}, + Prompt: options.Prompt, + } + + // Use account type from first CopilotKey entry if configured, with validation + if len(cfg.CopilotKey) > 0 && cfg.CopilotKey[0].AccountType != "" { + accountTypeStr := cfg.CopilotKey[0].AccountType + validation := copilot.ValidateAccountType(accountTypeStr) + if !validation.Valid { + fmt.Printf("Warning: %s\n", validation.ErrorMessage) + } else { + authOpts.Metadata["account_type"] = accountTypeStr + } + } + + // Create a context that cancels on SIGINT/SIGTERM for graceful abort + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + + _, savedPath, err := manager.Login(ctx, "copilot", cfg, authOpts) + if err != nil { + fmt.Printf("Copilot authentication failed: %v\n", err) + return + } + + if savedPath != "" { + fmt.Printf("Authentication saved to %s\n", savedPath) + } + fmt.Println("Copilot authentication successful!") +} diff --git a/sdk/auth/copilot.go b/sdk/auth/copilot.go new file mode 100644 index 00000000..9355a744 --- /dev/null +++ b/sdk/auth/copilot.go @@ -0,0 +1,175 @@ +package auth + +import ( + "context" + "fmt" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot" + "github.com/router-for-me/CLIProxyAPI/v6/internal/browser" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + log "github.com/sirupsen/logrus" +) + +// Copilot account_type Precedence +// +// The account_type determines which GitHub Copilot API endpoints are used +// (individual vs business/enterprise). The precedence for account_type is: +// +// 1. Auth.Attributes["account_type"] - CANONICAL RUNTIME SOURCE +// Executors must always read from Attributes, never from storage or config directly. +// +// 2. CopilotTokenStorage.AccountType - Initial seed value +// Used only to populate Attributes when creating a new auth entry. +// Storage should NOT overwrite a non-empty Attributes value on reload. +// +// 3. Config (copilot-api-key[].account-type) - Default for new logins +// Used only during initial OAuth login to seed the storage value. +// +// This precedence ensures stable base-URL selection across reloads and prevents +// oscillation between account types. See internal/watcher/watcher.go SnapshotCoreAuths +// for the reload logic that enforces this precedence. + +// CopilotAuthenticator implements the GitHub device code OAuth login flow for Copilot. +type CopilotAuthenticator struct { + AccountType copilot.AccountType +} + +// NewCopilotAuthenticator constructs a Copilot authenticator with default settings. +func NewCopilotAuthenticator() *CopilotAuthenticator { + return &CopilotAuthenticator{ + AccountType: copilot.AccountTypeIndividual, + } +} + +// NewCopilotAuthenticatorWithAccountType constructs a Copilot authenticator with specified account type. +func NewCopilotAuthenticatorWithAccountType(accountType string) *CopilotAuthenticator { + parsed, _ := copilot.ParseAccountType(accountType) + return &CopilotAuthenticator{ + AccountType: parsed, + } +} + +func (a *CopilotAuthenticator) Provider() string { + return "copilot" +} + +func (a *CopilotAuthenticator) RefreshLead() *time.Duration { + // Copilot tokens typically expire in ~30 minutes, refresh 5 minutes before + d := 5 * time.Minute + return &d +} + +func (a *CopilotAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) { + if cfg == nil { + return nil, fmt.Errorf("copilot auth: configuration is required") + } + if ctx == nil { + ctx = context.Background() + } + if opts == nil { + opts = &LoginOptions{} + } + + // Check for account type override in metadata + accountType := a.AccountType + if opts.Metadata != nil { + if at, ok := opts.Metadata["account_type"]; ok && at != "" { + parsed, valid := copilot.ParseAccountType(at) + if !valid { + log.Warnf("Invalid account_type '%s' provided in login options, defaulting to '%s'", at, copilot.DefaultAccountType) + } + accountType = parsed + } +} + + authSvc := copilot.NewCopilotAuth(cfg) + + // Use the shared helper that performs the complete auth flow and returns + // both the token storage and suggested filename. + result, err := authSvc.PerformFullAuthWithFilename(ctx, accountType, func(dc *copilot.DeviceCodeResponse) { + fmt.Printf("\n=== GitHub Copilot Authentication ===\n") + fmt.Printf("Please enter the code: \n\n %s\n\n", dc.UserCode) + fmt.Printf("At: %s\n\n", dc.VerificationURI) + + if !opts.NoBrowser { + if !browser.IsAvailable() { + log.Warn("No browser available; please open the URL manually") + util.PrintSSHTunnelInstructions(0) + } else if err := browser.OpenURL(dc.VerificationURI); err != nil { + log.Warnf("Failed to open browser automatically: %v", err) + util.PrintSSHTunnelInstructions(0) + } else { + fmt.Println("Browser opened. Please complete authentication...") + } + } + fmt.Println("Waiting for authentication...") + }) + + if err != nil { + return nil, fmt.Errorf("copilot authentication failed: %w", err) + } + + if result == nil || result.Storage == nil { + return nil, fmt.Errorf("copilot authentication failed: no token storage returned") + } + + tokenStorage := result.Storage + fileName := result.SuggestedFilename + + metadata := map[string]any{ + "email": tokenStorage.Email, + "username": tokenStorage.Username, + "account_type": tokenStorage.AccountType, + "copilot_token_expiry": tokenStorage.CopilotTokenExpiry, + "github_token": tokenStorage.GitHubToken, + "copilot_token": tokenStorage.CopilotToken, + "type": "copilot", + } + + fmt.Printf("\nCopilot authentication successful!\n") + if tokenStorage.Username != "" { + fmt.Printf("Logged in as: %s\n", tokenStorage.Username) + } + + return &coreauth.Auth{ + ID: fileName, + Provider: a.Provider(), + FileName: fileName, + Storage: tokenStorage, + Metadata: metadata, + // Attributes["account_type"] is the single canonical source of truth for account type at runtime. + // CopilotTokenStorage.AccountType and config are used only to seed this initial value. + // On subsequent reloads, storage and config must NOT overwrite a non-empty Attributes["account_type"]. + // Executor logic should always read from Attributes, not Storage or config directly. + Attributes: map[string]string{ + "account_type": tokenStorage.AccountType, + }, + }, nil +} + +// RefreshToken refreshes the Copilot token for an existing auth entry. +func (a *CopilotAuthenticator) RefreshToken(ctx context.Context, cfg *config.Config, auth *coreauth.Auth) error { + if auth == nil || auth.Metadata == nil { + return fmt.Errorf("copilot refresh: invalid auth entry") + } + + githubToken, ok := auth.Metadata["github_token"].(string) + if !ok || githubToken == "" { + return fmt.Errorf("copilot refresh: missing github token") + } + + authSvc := copilot.NewCopilotAuth(cfg) + + tokenResp, err := authSvc.GetCopilotToken(ctx, githubToken) + if err != nil { + return fmt.Errorf("copilot refresh failed: %w", err) + } + + copilot.ApplyTokenRefresh(auth, tokenResp, time.Now()) + log.Debug("Copilot token refreshed successfully") + + return nil +} diff --git a/sdk/auth/refresh_registry.go b/sdk/auth/refresh_registry.go index e82ac684..64d95f45 100644 --- a/sdk/auth/refresh_registry.go +++ b/sdk/auth/refresh_registry.go @@ -14,6 +14,7 @@ func init() { registerRefreshLead("gemini", func() Authenticator { return NewGeminiAuthenticator() }) registerRefreshLead("gemini-cli", func() Authenticator { return NewGeminiAuthenticator() }) registerRefreshLead("antigravity", func() Authenticator { return NewAntigravityAuthenticator() }) + registerRefreshLead("copilot", func() Authenticator { return NewCopilotAuthenticator() }) } func registerRefreshLead(provider string, factory func() Authenticator) { From 11a07b69ec346010803db8ef1c8e974cddee71af Mon Sep 17 00:00:00 2001 From: Jeff Nash <9919536+jeffnash@users.noreply.github.com> Date: Sun, 30 Nov 2025 13:10:35 -0800 Subject: [PATCH 4/5] feat(copilot/executor): add Copilot request executor with core models - Implement Copilot executor with header injection - Add VS Code version headers and integration IDs - Add agent header logic (X-Initiator detection) - Add vision request header for image inputs - Add static model registry with raptor-mini and oswe-vscode-prime - Add management API endpoints for auth files Note: Full model list and dynamic fetching added in gemini branch --- cmd/server/main.go | 5 + .../api/handlers/management/auth_files.go | 116 ++++ .../api/handlers/management/config_lists.go | 105 ++++ internal/api/server.go | 6 + internal/registry/copilot_models.go | 58 ++ internal/runtime/executor/copilot_executor.go | 509 ++++++++++++++++++ .../runtime/executor/copilot_executor_test.go | 150 ++++++ internal/runtime/executor/copilot_headers.go | 171 ++++++ .../runtime/executor/copilot_headers_test.go | 164 ++++++ sdk/cliproxy/service.go | 16 + 10 files changed, 1300 insertions(+) create mode 100644 internal/registry/copilot_models.go create mode 100644 internal/runtime/executor/copilot_executor.go create mode 100644 internal/runtime/executor/copilot_executor_test.go create mode 100644 internal/runtime/executor/copilot_headers.go create mode 100644 internal/runtime/executor/copilot_headers_test.go diff --git a/cmd/server/main.go b/cmd/server/main.go index bbf500e7..d8e6bc92 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -56,6 +56,7 @@ func main() { // Command-line flags to control the application's behavior. var login bool var codexLogin bool + var copilotLogin bool var claudeLogin bool var qwenLogin bool var iflowLogin bool @@ -70,6 +71,7 @@ func main() { // Define command-line flags for different operation modes. flag.BoolVar(&login, "login", false, "Login Google Account") flag.BoolVar(&codexLogin, "codex-login", false, "Login to Codex using OAuth") + flag.BoolVar(&copilotLogin, "copilot-login", false, "Login to GitHub Copilot using device code flow") flag.BoolVar(&claudeLogin, "claude-login", false, "Login to Claude using OAuth") flag.BoolVar(&qwenLogin, "qwen-login", false, "Login to Qwen using OAuth") flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth") @@ -439,6 +441,9 @@ func main() { } else if codexLogin { // Handle Codex login cmd.DoCodexLogin(cfg, options) + } else if copilotLogin { + // Handle GitHub Copilot login + cmd.DoCopilotLogin(cfg, options) } else if claudeLogin { // Handle Claude login cmd.DoClaudeLogin(cfg, options) diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go index 824e3fb0..173e4e50 100644 --- a/internal/api/handlers/management/auth_files.go +++ b/internal/api/handlers/management/auth_files.go @@ -21,6 +21,7 @@ import ( "github.com/gin-gonic/gin" "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude" "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex" + "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot" geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini" iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow" "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen" @@ -2089,6 +2090,121 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec return true, nil } +// validateCopilotAccountType validates the account_type query parameter. +// If invalid, it writes a 400 error to the context and returns false. +func validateCopilotAccountType(c *gin.Context) (copilot.AccountType, bool) { + accountTypeStr := c.DefaultQuery("account_type", "individual") + validation := copilot.ValidateAccountType(accountTypeStr) + if !validation.Valid { + c.JSON(http.StatusBadRequest, gin.H{ + "error": validation.ErrorMessage, + "valid_values": validation.ValidValues, + "default": validation.DefaultValue, + }) + return "", false + } + return validation.AccountType, true +} + +// startCopilotAuthFlow starts the background polling for Copilot authentication. +func (h *Handler) startCopilotAuthFlow(ctx context.Context, state string, deviceCode *copilot.DeviceCodeResponse, accountType copilot.AccountType) { + go func() { + // Use a timeout based on the device code expiration + pollCtx, cancel := context.WithTimeout(ctx, time.Duration(deviceCode.ExpiresIn)*time.Second) + defer cancel() + + copilotAuth := copilot.NewCopilotAuth(h.cfg) + result, authErr := copilotAuth.CompleteAuthWithDeviceCode(pollCtx, deviceCode, accountType) + if authErr != nil { + oauthStatus[state] = fmt.Sprintf("Authentication failed: %v", authErr) + return + } + + if result == nil || result.Storage == nil { + oauthStatus[state] = "Authentication failed: no result returned" + return + } + + principal := result.Storage.Username + if principal == "" { + principal = result.Storage.Email + } + + // Ensure auth directory exists before saving + authDir, ensureErr := util.EnsureAuthDir(h.cfg.AuthDir) + if ensureErr != nil { + oauthStatus[state] = fmt.Sprintf("Failed to prepare auth directory: %v", ensureErr) + return + } + + // Save token using the filename from the shared helper + tokenPath := filepath.Join(authDir, result.SuggestedFilename) + if saveErr := result.Storage.SaveTokenToFile(tokenPath); saveErr != nil { + oauthStatus[state] = fmt.Sprintf("Failed to save token: %v", saveErr) + return + } + + log.Infof("copilot_auth_success: state=%s principal=%s", state, principal) + delete(oauthStatus, state) + }() +} + +// RequestCopilotToken initiates GitHub Copilot device code authentication flow. +// Poll GetAuthStatus with the returned state to check progress. GetAuthStatus returns: +// - status="wait": authentication in progress +// - status="ok": authentication completed successfully +// - status="error": authentication failed (see error) +func (h *Handler) RequestCopilotToken(c *gin.Context) { + ctx := context.Background() + + // Validate auth directory before starting auth flow + if h.cfg.AuthDir == "" { + log.Error("Copilot auth failed: auth directory not configured") + c.JSON(http.StatusInternalServerError, gin.H{"error": "auth directory not configured"}) + return + } + + // Get account type from query param and validate + accountType, ok := validateCopilotAccountType(c) + if !ok { + return + } + + state, err := misc.GenerateRandomState() + if err != nil { + log.Errorf("Failed to generate state parameter: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate state"}) + return + } + + // Initialize Copilot auth service + copilotAuth := copilot.NewCopilotAuth(h.cfg) + + // Get device code first to return to user immediately + deviceCode, err := copilotAuth.GetDeviceCode(ctx) + if err != nil { + log.Errorf("Failed to get device code: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to get device code"}) + return + } + + // Track this auth flow + oauthStatus[state] = "" + + // Start background goroutine to complete auth flow using shared helper + h.startCopilotAuthFlow(ctx, state, deviceCode, accountType) + + // Return device code info to user + c.JSON(http.StatusOK, gin.H{ + "status": "ok", + "state": state, + "user_code": deviceCode.UserCode, + "verification_uri": deviceCode.VerificationURI, + "expires_in": deviceCode.ExpiresIn, + "interval": deviceCode.Interval, + }) +} + func (h *Handler) GetAuthStatus(c *gin.Context) { state := c.Query("state") if err, ok := oauthStatus[state]; ok { diff --git a/internal/api/handlers/management/config_lists.go b/internal/api/handlers/management/config_lists.go index 71193084..90595b45 100644 --- a/internal/api/handlers/management/config_lists.go +++ b/internal/api/handlers/management/config_lists.go @@ -798,3 +798,108 @@ func normalizeClaudeKey(entry *config.ClaudeKey) { } entry.Models = normalized } + +// GetCopilotKeys returns the current Copilot API key configuration. +func (h *Handler) GetCopilotKeys(c *gin.Context) { + c.JSON(200, gin.H{"copilot-api-key": h.cfg.CopilotKey}) +} + +// PutCopilotKeys replaces the Copilot API key configuration. +func (h *Handler) PutCopilotKeys(c *gin.Context) { + data, err := c.GetRawData() + if err != nil { + c.JSON(400, gin.H{"error": "failed to read body"}) + return + } + var arr []config.CopilotKey + if err = json.Unmarshal(data, &arr); err != nil { + var obj struct { + Items []config.CopilotKey `json:"items"` + } + if err2 := json.Unmarshal(data, &obj); err2 != nil || len(obj.Items) == 0 { + c.JSON(400, gin.H{"error": "invalid body"}) + return + } + arr = obj.Items + } + // Normalize entries + filtered := make([]config.CopilotKey, 0, len(arr)) + for i := range arr { + entry := arr[i] + entry.AccountType = strings.TrimSpace(entry.AccountType) + entry.ProxyURL = strings.TrimSpace(entry.ProxyURL) + filtered = append(filtered, entry) + } + h.cfg.CopilotKey = filtered + h.cfg.SanitizeCopilotKeys() + h.persist(c) +} + +// PatchCopilotKey updates a single Copilot API key entry by index or match. +func (h *Handler) PatchCopilotKey(c *gin.Context) { + var body struct { + Index *int `json:"index"` + Match *string `json:"match"` // Match by account_type + Value *config.CopilotKey `json:"value"` + } + if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil { + c.JSON(400, gin.H{"error": "invalid body"}) + return + } + value := *body.Value + value.AccountType = strings.TrimSpace(value.AccountType) + value.ProxyURL = strings.TrimSpace(value.ProxyURL) + + h.mu.Lock() + defer h.mu.Unlock() + + // Update by index + if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.CopilotKey) { + h.cfg.CopilotKey[*body.Index] = value + h.cfg.SanitizeCopilotKeys() + h.persist(c) + return + } + // Update by matching account_type + if body.Match != nil { + for i := range h.cfg.CopilotKey { + if h.cfg.CopilotKey[i].AccountType == *body.Match { + h.cfg.CopilotKey[i] = value + h.cfg.SanitizeCopilotKeys() + h.persist(c) + return + } + } + } + c.JSON(404, gin.H{"error": "item not found"}) +} + +// DeleteCopilotKey removes a Copilot API key entry by index or account_type. +func (h *Handler) DeleteCopilotKey(c *gin.Context) { + h.mu.Lock() + defer h.mu.Unlock() + + if val := c.Query("account-type"); val != "" { + out := make([]config.CopilotKey, 0, len(h.cfg.CopilotKey)) + for _, v := range h.cfg.CopilotKey { + if v.AccountType != val { + out = append(out, v) + } + } + h.cfg.CopilotKey = out + h.cfg.SanitizeCopilotKeys() + h.persist(c) + return + } + if idxStr := c.Query("index"); idxStr != "" { + var idx int + _, err := fmt.Sscanf(idxStr, "%d", &idx) + if err == nil && idx >= 0 && idx < len(h.cfg.CopilotKey) { + h.cfg.CopilotKey = append(h.cfg.CopilotKey[:idx], h.cfg.CopilotKey[idx+1:]...) + h.cfg.SanitizeCopilotKeys() + h.persist(c) + return + } + } + c.JSON(400, gin.H{"error": "missing or invalid account-type or index query param"}) +} diff --git a/internal/api/server.go b/internal/api/server.go index ab9c0354..1b402707 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -538,6 +538,11 @@ func (s *Server) registerManagementRoutes() { mgmt.PATCH("/codex-api-key", s.mgmt.PatchCodexKey) mgmt.DELETE("/codex-api-key", s.mgmt.DeleteCodexKey) + mgmt.GET("/copilot-api-key", s.mgmt.GetCopilotKeys) + mgmt.PUT("/copilot-api-key", s.mgmt.PutCopilotKeys) + mgmt.PATCH("/copilot-api-key", s.mgmt.PatchCopilotKey) + mgmt.DELETE("/copilot-api-key", s.mgmt.DeleteCopilotKey) + mgmt.GET("/openai-compatibility", s.mgmt.GetOpenAICompat) mgmt.PUT("/openai-compatibility", s.mgmt.PutOpenAICompat) mgmt.PATCH("/openai-compatibility", s.mgmt.PatchOpenAICompat) @@ -556,6 +561,7 @@ func (s *Server) registerManagementRoutes() { mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken) mgmt.GET("/codex-auth-url", s.mgmt.RequestCodexToken) + mgmt.GET("/copilot-auth-url", s.mgmt.RequestCopilotToken) mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken) mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken) mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken) diff --git a/internal/registry/copilot_models.go b/internal/registry/copilot_models.go new file mode 100644 index 00000000..22de7ca2 --- /dev/null +++ b/internal/registry/copilot_models.go @@ -0,0 +1,58 @@ +package registry + +import "time" + +const CopilotModelPrefix = "copilot-" + +// GenerateCopilotAliases creates copilot- prefixed aliases for explicit routing. +// This allows users to explicitly route to Copilot when model names might conflict +// with other providers (e.g., "copilot-gpt-4o" vs "gpt-4o"). +func GenerateCopilotAliases(models []*ModelInfo) []*ModelInfo { + result := make([]*ModelInfo, 0, len(models)*2) + result = append(result, models...) + + for _, m := range models { + alias := *m + alias.ID = CopilotModelPrefix + m.ID + alias.DisplayName = m.DisplayName + " (Copilot)" + alias.Description = m.Description + " - explicit routing alias" + result = append(result, &alias) + } + + return result +} + +// GetCopilotModels returns the Copilot models (raptor-mini and oswe-vscode-prime). +func GetCopilotModels() []*ModelInfo { + now := time.Now().Unix() + defaultParams := []string{"temperature", "top_p", "max_tokens", "stream", "tools"} + + baseModels := []*ModelInfo{ + { + ID: "oswe-vscode-prime", + Object: "model", + Created: now, + OwnedBy: "copilot", + Type: "copilot", + DisplayName: "Raptor mini (Preview)", + Description: "Azure OpenAI fine-tuned model via GitHub Copilot (Preview)", + ContextLength: 264000, + MaxCompletionTokens: 64000, + SupportedParameters: defaultParams, + }, + { + ID: "raptor-mini", + Object: "model", + Created: now, + OwnedBy: "copilot", + Type: "copilot", + DisplayName: "Raptor mini (Preview)", + Description: "Azure OpenAI fine-tuned model via GitHub Copilot (Preview) - alias for oswe-vscode-prime", + ContextLength: 264000, + MaxCompletionTokens: 64000, + SupportedParameters: defaultParams, + }, + } + + return GenerateCopilotAliases(baseModels) +} diff --git a/internal/runtime/executor/copilot_executor.go b/internal/runtime/executor/copilot_executor.go new file mode 100644 index 00000000..8aaf0032 --- /dev/null +++ b/internal/runtime/executor/copilot_executor.go @@ -0,0 +1,509 @@ +package executor + +import ( + "bufio" + "bytes" + "context" + "errors" + "fmt" + "io" + "net/http" + "strings" + "sync" + "time" + + copilotauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" + sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" + log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +// CopilotExecutor handles requests to GitHub Copilot API. +// It manages token refresh and proper header injection for Copilot requests. +type CopilotExecutor struct { + cfg *config.Config + tokenMu sync.RWMutex + mu sync.Mutex + tokenCache map[string]*cachedToken + modelMu sync.Mutex + initiatorCount map[string]uint64 +} + +// cachedToken stores the Copilot token and its expiration time. +type cachedToken struct { + token string + expiresAt time.Time +} + +// NewCopilotExecutor creates a new CopilotExecutor instance. + +func NewCopilotExecutor(cfg *config.Config) *CopilotExecutor { + return &CopilotExecutor{ + cfg: cfg, + tokenCache: make(map[string]*cachedToken), + initiatorCount: make(map[string]uint64), + } +} + +func (e *CopilotExecutor) Identifier() string { return "copilot" } + +func (e *CopilotExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { return nil } + +// stripCopilotPrefix removes the "copilot-" prefix from model names if present. +// This allows users to explicitly route to Copilot using "copilot-gpt-5" while +// the actual API call uses "gpt-5". +func stripCopilotPrefix(model string) string { + return strings.TrimPrefix(model, registry.CopilotModelPrefix) +} + +// sanitizeCopilotPayload removes fields that Copilot's Chat Completions endpoint +// rejects (strip max_tokens and parallel_tool_calls). +func sanitizeCopilotPayload(body []byte, model string) []byte { + if len(body) == 0 { + return body + } + if gjson.GetBytes(body, "max_tokens").Exists() { + if cleaned, err := sjson.DeleteBytes(body, "max_tokens"); err == nil { + body = cleaned + } + } + if gjson.GetBytes(body, "parallel_tool_calls").Exists() { + if cleaned, err := sjson.DeleteBytes(body, "parallel_tool_calls"); err == nil { + body = cleaned + } + } + return body +} + +func (e *CopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { + copilotToken, accountType, err := e.getCopilotToken(ctx, auth) + if err != nil { + return resp, err + } + + apiModel := stripCopilotPrefix(req.Model) + + translatorModel := req.Model + + reporter := newUsageReporter(ctx, e.Identifier(), apiModel, auth) + defer reporter.trackFailure(ctx, &err) + + from := opts.SourceFormat + to := sdktranslator.FromString("openai") + + body := sdktranslator.TranslateRequest(from, to, apiModel, bytes.Clone(req.Payload), false) + body = applyPayloadConfig(e.cfg, apiModel, body) + body = sanitizeCopilotPayload(body, apiModel) + body, _ = sjson.SetBytes(body, "stream", false) + + baseURL := copilotauth.CopilotBaseURL(accountType) + url := baseURL + "/chat/completions" + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return resp, err + } + + e.applyCopilotHeaders(httpReq, copilotToken, req.Payload) + + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() + } + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: url, + Method: http.MethodPost, + Headers: httpReq.Header.Clone(), + Body: body, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + + httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + httpResp, err := httpClient.Do(httpReq) + if err != nil { + recordAPIResponseError(ctx, e.cfg, err) + return resp, err + } + defer func() { + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("copilot executor: close response body error: %v", errClose) + } + }() + + recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) + + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + b, _ := io.ReadAll(httpResp.Body) + appendAPIResponseChunk(ctx, e.cfg, b) + log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) + err = copilotStatusErr(httpResp.StatusCode, string(b)) + return resp, err + } + + data, err := io.ReadAll(httpResp.Body) + if err != nil { + recordAPIResponseError(ctx, e.cfg, err) + return resp, err + } + appendAPIResponseChunk(ctx, e.cfg, data) + + // Parse usage from response + reporter.publish(ctx, parseOpenAIUsage(data)) + + var param any + out := sdktranslator.TranslateNonStream(ctx, to, from, translatorModel, bytes.Clone(opts.OriginalRequest), body, data, ¶m) + resp = cliproxyexecutor.Response{Payload: []byte(out)} + return resp, nil +} + +func (e *CopilotExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { + copilotToken, accountType, err := e.getCopilotToken(ctx, auth) + if err != nil { + return nil, err + } + + apiModel := stripCopilotPrefix(req.Model) + + translatorModel := req.Model + + reporter := newUsageReporter(ctx, e.Identifier(), apiModel, auth) + defer reporter.trackFailure(ctx, &err) + + from := opts.SourceFormat + to := sdktranslator.FromString("openai") + + body := sdktranslator.TranslateRequest(from, to, apiModel, bytes.Clone(req.Payload), true) + body = applyPayloadConfig(e.cfg, apiModel, body) + body = sanitizeCopilotPayload(body, apiModel) + body, _ = sjson.SetBytes(body, "stream", true) + + baseURL := copilotauth.CopilotBaseURL(accountType) + url := baseURL + "/chat/completions" + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return nil, err + } + + e.applyCopilotHeaders(httpReq, copilotToken, req.Payload) + + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() + } + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: url, + Method: http.MethodPost, + Headers: httpReq.Header.Clone(), + Body: body, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + + httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + httpResp, err := httpClient.Do(httpReq) + if err != nil { + recordAPIResponseError(ctx, e.cfg, err) + return nil, err + } + + recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) + + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + data, readErr := io.ReadAll(httpResp.Body) + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("copilot executor: close response body error: %v", errClose) + } + if readErr != nil { + recordAPIResponseError(ctx, e.cfg, readErr) + return nil, readErr + } + appendAPIResponseChunk(ctx, e.cfg, data) + log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) + err = copilotStatusErr(httpResp.StatusCode, string(data)) + return nil, err + } + + out := make(chan cliproxyexecutor.StreamChunk) + stream = out + go func() { + defer close(out) + defer func() { + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("copilot executor: close response body error: %v", errClose) + } + }() + + scanner := bufio.NewScanner(httpResp.Body) + bufSize := e.cfg.ScannerBufferSize + if bufSize <= 0 { + bufSize = 20_971_520 + } + scanner.Buffer(nil, bufSize) + var param any + for scanner.Scan() { + line := scanner.Bytes() + appendAPIResponseChunk(ctx, e.cfg, line) + + // Parse usage from final chunk if present + if bytes.HasPrefix(line, dataTag) { + data := bytes.TrimSpace(line[5:]) + if gjson.GetBytes(data, "usage").Exists() { + reporter.publish(ctx, parseOpenAIUsage(data)) + } + } + + chunks := sdktranslator.TranslateStream(ctx, to, from, translatorModel, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m) + for i := range chunks { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])} + } + } + if errScan := scanner.Err(); errScan != nil { + recordAPIResponseError(ctx, e.cfg, errScan) + reporter.publishFailure(ctx) + out <- cliproxyexecutor.StreamChunk{Err: errScan} + } + }() + + return stream, nil +} + +func (e *CopilotExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { + log.Debugf("copilot executor: refresh called") + if auth == nil { + return nil, statusErr{code: 500, msg: "copilot executor: auth is nil (copilot_refresh_auth_nil)"} + } + + var githubToken string + if auth.Metadata != nil { + if v, ok := auth.Metadata["github_token"].(string); ok && v != "" { + githubToken = v + } + } + // Fallback to storage if metadata is missing github_token + if githubToken == "" { + if storage, ok := auth.Storage.(*copilotauth.CopilotTokenStorage); ok && storage != nil { + githubToken = storage.GitHubToken + } + } + + if githubToken == "" { + log.Debug("copilot executor: no github_token in metadata, skipping refresh") + return auth, nil + } + + authSvc := copilotauth.NewCopilotAuth(e.cfg) + tokenResp, err := authSvc.GetCopilotToken(ctx, githubToken) + if err != nil { + // Classify error: auth issues get 401, transient issues get 503 + // Use structured HTTPStatusError when available, fall back to sentinel errors + code := 503 + cause := "copilot_refresh_transient" + + switch { + case errors.Is(err, copilotauth.ErrNoCopilotSubscription): + code = 401 + cause = "copilot_no_subscription" + case errors.Is(err, copilotauth.ErrAccessDenied): + code = 401 + cause = "copilot_access_denied" + case errors.Is(err, copilotauth.ErrNoGitHubToken): + code = 401 + cause = "copilot_no_github_token" + default: + // Check for structured HTTP status code from HTTPStatusError + if httpCode := copilotauth.StatusCode(err); httpCode != 0 { + if httpCode == 401 || httpCode == 403 { + code = 401 + cause = "copilot_auth_rejected" + } else if httpCode >= 500 { + cause = "copilot_upstream_error" + } + } + } + + log.Warnf("copilot executor: token refresh failed [cause: %s]: %v", cause, err) + return nil, statusErr{code: code, msg: fmt.Sprintf("copilot token refresh failed (%s): %v", cause, err)} + } + + // Update in-memory cache + e.tokenMu.Lock() + e.tokenCache[githubToken] = &cachedToken{ + token: tokenResp.Token, + expiresAt: time.Unix(tokenResp.ExpiresAt, 0), + } + e.tokenMu.Unlock() + + // We no longer rely on metadata for token caching, but we update it + // for the current session in case other components need it. + if auth.Metadata == nil { + auth.Metadata = make(map[string]any) + } + auth.Metadata["copilot_token"] = tokenResp.Token + auth.Metadata["copilot_token_expiry"] = time.Unix(tokenResp.ExpiresAt, 0).Format(time.RFC3339) + auth.Metadata["type"] = "copilot" + + log.Debug("Copilot token refreshed successfully") + return auth, nil +} + +// getCopilotToken retrieves the Copilot token from auth metadata, refreshing if needed. +// Returns statusErr with appropriate HTTP codes: +// - 500 for missing auth or metadata (internal state error, cause: copilot_auth_nil, copilot_metadata_nil) +// - 401 for missing copilot token (auth configuration error, cause: copilot_token_missing) +// This allows callers to distinguish internal state issues from auth configuration problems. +// +// Note on account_type: See sdk/auth/copilot.go for full precedence documentation. +// Attributes["account_type"] is the canonical runtime source; storage is only a fallback. +// +// Note on metadata: auth.Metadata is used as a runtime cache and may be updated from +// CopilotTokenStorage. Both are kept in sync when tokens are refreshed. +func (e *CopilotExecutor) getCopilotToken(ctx context.Context, auth *cliproxyauth.Auth) (string, copilotauth.AccountType, error) { + if auth == nil { + return "", "", statusErr{code: 500, msg: "copilot executor: auth is nil (copilot_auth_nil)"} + } + + copilotauth.EnsureMetadataHydrated(auth) + githubToken := copilotauth.ResolveGitHubToken(auth) + accountType := copilotauth.ResolveAccountType(auth) + + // 1. Check Memory Cache + if token, valid := e.getValidCachedToken(githubToken); valid { + return token, accountType, nil + } + + // 2. Check Metadata (Storage) Cache + copilotToken, copilotExpiry, hasCopilotToken := copilotauth.ResolveCopilotToken(auth) + if hasCopilotToken { + if time.Now().Add(60 * time.Second).Before(copilotExpiry) { + e.setCachedToken(githubToken, copilotToken, copilotExpiry) + return copilotToken, accountType, nil + } + } + + // 3. Refresh if needed + if githubToken != "" { + if _, err := e.Refresh(ctx, auth); err == nil { + if token, valid := e.getValidCachedToken(githubToken); valid { + return token, accountType, nil + } + } + } + + // 4. Fallback: Use cached token if strictly valid (not expired) but near expiry + if hasCopilotToken && time.Now().Before(copilotExpiry) { + return copilotToken, accountType, nil + } + + return "", accountType, statusErr{code: 401, msg: "no valid token available"} +} + +func (e *CopilotExecutor) getValidCachedToken(githubToken string) (string, bool) { + e.tokenMu.RLock() + defer e.tokenMu.RUnlock() + if cached, ok := e.tokenCache[githubToken]; ok { + if time.Now().Add(60 * time.Second).Before(cached.expiresAt) { + return cached.token, true + } + } + return "", false +} + +func (e *CopilotExecutor) setCachedToken(githubToken, token string, expiresAt time.Time) { + e.tokenMu.Lock() + defer e.tokenMu.Unlock() + e.tokenCache[githubToken] = &cachedToken{ + token: token, + expiresAt: expiresAt, + } +} + +// CountTokens provides a token count estimate for Copilot models. +// +// This method uses the Codex/OpenAI tokenizer (via tokenizerForCodexModel) as an +// approximation for Copilot models. Since Copilot routes requests to various +// underlying models, the token counts are best-effort +// estimates rather than exact billing equivalents. +// +// If a Copilot-specific tokenizer becomes available in the future, it can be +// swapped in by replacing the tokenizerForCodexModel call below. +func (e *CopilotExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + apiModel := stripCopilotPrefix(req.Model) + + // Copilot uses OpenAI models, so we can reuse the OpenAI tokenizer logic + from := opts.SourceFormat + to := sdktranslator.FromString("openai") + body := sdktranslator.TranslateRequest(from, to, apiModel, bytes.Clone(req.Payload), false) + + // Use tiktoken for token counting via tokenizerForCodexModel helper. + // This provides OpenAI-compatible token estimates. + enc, err := tokenizerForCodexModel(apiModel) + if err != nil { + return cliproxyexecutor.Response{}, fmt.Errorf("copilot executor: tokenizer init failed: %w", err) + } + + // Extract messages and count tokens + var textParts []string + messages := gjson.GetBytes(body, "messages") + if messages.IsArray() { + for _, msg := range messages.Array() { + content := msg.Get("content") + if content.Type == gjson.String { + textParts = append(textParts, strings.TrimSpace(content.String())) + } else if content.IsArray() { + for _, part := range content.Array() { + if part.Get("type").String() == "text" { + textParts = append(textParts, strings.TrimSpace(part.Get("text").String())) + } + } + } + } + } + + text := strings.Join(textParts, "\n") + count, err := enc.Count(text) + if err != nil { + return cliproxyexecutor.Response{}, fmt.Errorf("copilot executor: token counting failed: %w", err) + } + + usageJSON := fmt.Sprintf(`{"usage":{"input_tokens":%d,"output_tokens":0}}`, count) + translated := sdktranslator.TranslateTokenCount(ctx, to, from, int64(count), []byte(usageJSON)) + return cliproxyexecutor.Response{Payload: []byte(translated)}, nil + +} +// copilotStatusErr creates a statusErr with appropriate retry timing for Copilot. + +// EvictCopilotModelCache clears cached model data for a given auth ID. +func EvictCopilotModelCache(_ string) {} + +// FetchModels returns the static Copilot model list. +func (e *CopilotExecutor) FetchModels(_ context.Context, _ *cliproxyauth.Auth, _ *config.Config) []*registry.ModelInfo { + return registry.GetCopilotModels() +} + +// For 429 errors, it sets a longer retry delay (30 seconds) since Copilot quota +// limits typically require more time to recover than standard rate limits. +func copilotStatusErr(code int, msg string) statusErr { + err := statusErr{code: code, msg: msg} + if code == 429 { + delay := 30 * time.Second + err.retryAfter = &delay + } + return err +} diff --git a/internal/runtime/executor/copilot_executor_test.go b/internal/runtime/executor/copilot_executor_test.go new file mode 100644 index 00000000..e20e0759 --- /dev/null +++ b/internal/runtime/executor/copilot_executor_test.go @@ -0,0 +1,150 @@ +package executor + +import ( + "testing" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" +) + +// TestStripCopilotPrefix verifies that the copilot- prefix is correctly stripped from model names. +func TestStripCopilotPrefix(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "model with copilot prefix", + input: "copilot-claude-opus-4.5", + expected: "claude-opus-4.5", + }, + { + name: "model with copilot prefix - gpt", + input: "copilot-gpt-5", + expected: "gpt-5", + }, + { + name: "model with copilot prefix - gemini", + input: "copilot-gemini-2.5-pro", + expected: "gemini-2.5-pro", + }, + { + name: "model without prefix", + input: "claude-opus-4.5", + expected: "claude-opus-4.5", + }, + { + name: "model without prefix - gpt", + input: "gpt-5", + expected: "gpt-5", + }, + { + name: "model with -copilot suffix (not prefix)", + input: "gpt-41-copilot", + expected: "gpt-41-copilot", + }, + { + name: "empty string", + input: "", + expected: "", + }, + { + name: "just the prefix", + input: "copilot-", + expected: "", + }, + { + name: "copilot without hyphen", + input: "copilotmodel", + expected: "copilotmodel", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := stripCopilotPrefix(tt.input) + if result != tt.expected { + t.Errorf("stripCopilotPrefix(%q) = %q, want %q", tt.input, result, tt.expected) + } + }) + } +} + +// TestCopilotModelPrefixConstant verifies the prefix constant is correct. +func TestCopilotModelPrefixConstant(t *testing.T) { + if registry.CopilotModelPrefix != "copilot-" { + t.Errorf("CopilotModelPrefix = %q, want %q", registry.CopilotModelPrefix, "copilot-") + } +} + +// TestStatusErr_Error verifies that statusErr implements error correctly. +func TestStatusErr_Error(t *testing.T) { + err := statusErr{code: 401, msg: "unauthorized"} + expected := "unauthorized" + if err.Error() != expected { + t.Errorf("statusErr.Error() = %q, want %q", err.Error(), expected) + } + + // Test fallback when msg is empty + err2 := statusErr{code: 500, msg: ""} + expected2 := "status 500" + if err2.Error() != expected2 { + t.Errorf("statusErr.Error() = %q, want %q", err2.Error(), expected2) + } +} + +// TestGetCopilotModels verifies the static model list contains expected core models. +func TestGetCopilotModels(t *testing.T) { + models := registry.GetCopilotModels() + + if len(models) == 0 { + t.Fatal("GetCopilotModels() returned empty list") + } + + // Check for expected core models (raptor-mini and oswe-vscode-prime) + expectedModels := map[string]bool{ + "oswe-vscode-prime": false, + "raptor-mini": false, + } + + for _, m := range models { + if _, ok := expectedModels[m.ID]; ok { + expectedModels[m.ID] = true + } + } + + for model, found := range expectedModels { + if !found { + t.Errorf("GetCopilotModels() missing expected model %q", model) + } + } +} + +// TestGenerateCopilotAliases verifies alias generation. +func TestGenerateCopilotAliases(t *testing.T) { + input := []*registry.ModelInfo{ + {ID: "gpt-5", DisplayName: "GPT-5", Description: "Test model"}, + } + + result := registry.GenerateCopilotAliases(input) + + // Should have original + alias + if len(result) != 2 { + t.Errorf("GenerateCopilotAliases() returned %d models, want 2", len(result)) + } + + // Check alias was created + var foundAlias bool + for _, m := range result { + if m.ID == "copilot-gpt-5" { + foundAlias = true + if m.DisplayName != "GPT-5 (Copilot)" { + t.Errorf("alias DisplayName = %q, want %q", m.DisplayName, "GPT-5 (Copilot)") + } + } + } + + if !foundAlias { + t.Error("GenerateCopilotAliases() did not create copilot- prefixed alias") + } +} diff --git a/internal/runtime/executor/copilot_headers.go b/internal/runtime/executor/copilot_headers.go new file mode 100644 index 00000000..ad42bd81 --- /dev/null +++ b/internal/runtime/executor/copilot_headers.go @@ -0,0 +1,171 @@ +package executor + +import ( + "net/http" + "strings" + + copilotauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot" + log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" +) + +// responsesAPIAgentTypes lists input types that indicate agent/tool activity in the +// OpenAI Responses API format. When any of these types appear in the input array, +// the request should be marked as an agent call (X-Initiator: agent). +// See: https://platform.openai.com/docs/api-reference/responses +var responsesAPIAgentTypes = map[string]bool{ + "function_call": true, + "function_call_output": true, + "computer_call": true, + "computer_call_output": true, + "web_search_call": true, + "file_search_call": true, + "code_interpreter_call": true, + "local_shell_call": true, + "local_shell_call_output": true, + "mcp_call": true, + "mcp_list_tools": true, + "mcp_approval_request": true, + "mcp_approval_response": true, + "image_generation_call": true, + "reasoning": true, +} + +// isResponsesAPIAgentItem checks if a single item from the Responses API input array +// indicates agent/tool activity. This is used to determine the X-Initiator header value. +func isResponsesAPIAgentItem(item gjson.Result) bool { + // Check for assistant role + if item.Get("role").String() == "assistant" { + return true + } + // Check for agent-related input types + return responsesAPIAgentTypes[item.Get("type").String()] +} + +// isResponsesAPIVisionContent checks if a content part from the Responses API +// contains image data, indicating a vision request. +func isResponsesAPIVisionContent(part gjson.Result) bool { + return part.Get("type").String() == "input_image" +} + +type copilotHeaderHints struct { + hasVision bool + agentFromPayload bool + promptCacheKey string +} + +func promptCacheKeyFromPayload(payload []byte) string { + if v := gjson.GetBytes(payload, "prompt_cache_key"); v.Exists() { + if key := strings.TrimSpace(v.String()); key != "" { + return key + } + } + if v := gjson.GetBytes(payload, "metadata.prompt_cache_key"); v.Exists() { + if key := strings.TrimSpace(v.String()); key != "" { + return key + } + } + return "" +} + +func collectCopilotHeaderHints(payload []byte) copilotHeaderHints { + hints := copilotHeaderHints{promptCacheKey: promptCacheKeyFromPayload(payload)} + + // Chat Completions format (messages array) + messages := gjson.GetBytes(payload, "messages") + if messages.IsArray() { + for _, msg := range messages.Array() { + content := msg.Get("content") + if content.IsArray() { + for _, part := range content.Array() { + if part.Get("type").String() == "image_url" { + hints.hasVision = true + } + } + } + role := msg.Get("role").String() + if role == "assistant" || role == "tool" { + hints.agentFromPayload = true + } + } + } + + // Responses API format (input array) + input := gjson.GetBytes(payload, "input") + if input.IsArray() { + for _, item := range input.Array() { + content := item.Get("content") + if content.IsArray() { + for _, part := range content.Array() { + if isResponsesAPIVisionContent(part) { + hints.hasVision = true + } + } + } + if isResponsesAPIAgentItem(item) { + hints.agentFromPayload = true + } + } + } + + return hints +} + +func (e *CopilotExecutor) agentInitiatorPersistEnabled() bool { + if e == nil || e.cfg == nil { + return false + } + for i := range e.cfg.CopilotKey { + if e.cfg.CopilotKey[i].AgentInitiatorPersist { + return true + } + } + return false +} + +func (e *CopilotExecutor) shouldUseAgentInitiator(h copilotHeaderHints) bool { + if e != nil && e.agentInitiatorPersistEnabled() && h.promptCacheKey != "" { + e.mu.Lock() + count := e.initiatorCount[h.promptCacheKey] + e.initiatorCount[h.promptCacheKey] = count + 1 + e.mu.Unlock() + + if h.agentFromPayload { + return true + } + return count > 0 + } + + return h.agentFromPayload +} + +// applyCopilotHeaders applies all necessary headers to the request. +// It handles both Chat Completions format (messages array) and Responses API format (input array). +func (e *CopilotExecutor) applyCopilotHeaders(r *http.Request, copilotToken string, payload []byte) { + hints := collectCopilotHeaderHints(payload) + isAgentCall := e.shouldUseAgentInitiator(hints) + + headers := copilotauth.CopilotHeaders(copilotToken, "", hints.hasVision) + for k, v := range headers { + r.Header.Set(k, v) + } + + // Align with Copilot CLI defaults + r.Header.Set("X-Interaction-Type", "conversation-agent") + r.Header.Set("Openai-Intent", "conversation-agent") + r.Header.Set("X-Stainless-Retry-Count", "0") + r.Header.Set("X-Stainless-Lang", "js") + r.Header.Set("X-Stainless-Package-Version", "5.20.1") + r.Header.Set("X-Stainless-OS", "Linux") + r.Header.Set("X-Stainless-Arch", "arm64") + r.Header.Set("X-Stainless-Runtime", "node") + r.Header.Set("X-Stainless-Runtime-Version", "v22.15.0") + r.Header.Set("User-Agent", copilotauth.CopilotUserAgent) + if isAgentCall { + r.Header.Set("X-Initiator", "agent") + log.Info("copilot executor: [agent call]") + } else { + r.Header.Set("X-Initiator", "user") + log.Info("copilot executor: [user call]") + } +} diff --git a/internal/runtime/executor/copilot_headers_test.go b/internal/runtime/executor/copilot_headers_test.go new file mode 100644 index 00000000..ad70d29d --- /dev/null +++ b/internal/runtime/executor/copilot_headers_test.go @@ -0,0 +1,164 @@ +package executor + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/tidwall/gjson" +) + +func TestIsResponsesAPIAgentItem(t *testing.T) { + tests := []struct { + name string + json string + expected bool + }{ + // User messages - not agent + { + name: "user message", + json: `{"role": "user", "content": [{"type": "input_text", "text": "hello"}]}`, + expected: false, + }, + { + name: "system message", + json: `{"role": "system", "content": "You are helpful"}`, + expected: false, + }, + // Assistant messages - agent + { + name: "assistant message", + json: `{"role": "assistant", "content": [{"type": "output_text", "text": "hi"}]}`, + expected: true, + }, + // Function/tool types - agent + { + name: "function_call", + json: `{"type": "function_call", "call_id": "123", "name": "test"}`, + expected: true, + }, + { + name: "function_call_output", + json: `{"type": "function_call_output", "call_id": "123", "output": "done"}`, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := gjson.Parse(tt.json) + got := isResponsesAPIAgentItem(result) + if got != tt.expected { + t.Errorf("isResponsesAPIAgentItem(%s) = %v, want %v", tt.name, got, tt.expected) + } + }) + } +} + +func TestIsResponsesAPIVisionContent(t *testing.T) { + tests := []struct { + name string + json string + expected bool + }{ + { + name: "text only", + json: `{"role": "user", "content": "hello"}`, + expected: false, + }, + { + name: "with image_url", + json: `{"role": "user", "content": [{"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}]}`, + expected: false, // image_url is handled separately in collectCopilotHeaderHints + }, + { + name: "with input_image", + json: `{"type": "input_image", "source": {"data": "..."}}`, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := gjson.Parse(tt.json) + got := isResponsesAPIVisionContent(result) + if got != tt.expected { + t.Errorf("isResponsesAPIVisionContent(%s) = %v, want %v", tt.name, got, tt.expected) + } + }) + } +} + +func TestApplyCopilotHeaders_XInitiator(t *testing.T) { + tests := []struct { + name string + payload string + expectedValue string + }{ + { + name: "user only - should be user", + payload: `{"messages": [{"role": "user", "content": "hello"}]}`, + expectedValue: "user", + }, + { + name: "with assistant - should be agent", + payload: `{"messages": [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}]}`, + expectedValue: "agent", + }, + { + name: "with tool_calls - should be agent", + payload: `{"messages": [{"role": "user", "content": "hi"}, {"role": "assistant", "tool_calls": [{"id": "1"}]}]}`, + expectedValue: "agent", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &config.Config{} + e := NewCopilotExecutor(cfg) + req := httptest.NewRequest(http.MethodPost, "/", nil) + + e.applyCopilotHeaders(req, "test-token", []byte(tt.payload)) + + got := req.Header.Get("X-Initiator") + if got != tt.expectedValue { + t.Errorf("X-Initiator = %q, want %q", got, tt.expectedValue) + } + }) + } +} + +func TestApplyCopilotHeaders_Vision(t *testing.T) { + tests := []struct { + name string + payload string + expectedValue string + }{ + { + name: "no images", + payload: `{"messages": [{"role": "user", "content": "hello"}]}`, + expectedValue: "", + }, + { + name: "with image_url", + payload: `{"messages": [{"role": "user", "content": [{"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}]}]}`, + expectedValue: "true", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &config.Config{} + e := NewCopilotExecutor(cfg) + req := httptest.NewRequest(http.MethodPost, "/", nil) + + e.applyCopilotHeaders(req, "test-token", []byte(tt.payload)) + + got := req.Header.Get("Copilot-Vision-Request") + if got != tt.expectedValue { + t.Errorf("Copilot-Vision-Request = %q, want %q", got, tt.expectedValue) + } + }) + } +} diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index c2ebba8d..18397838 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -105,6 +105,7 @@ func newDefaultAuthManager() *sdkAuth.Manager { sdkAuth.NewCodexAuthenticator(), sdkAuth.NewClaudeAuthenticator(), sdkAuth.NewQwenAuthenticator(), + sdkAuth.NewCopilotAuthenticator(), ) } @@ -300,6 +301,11 @@ func (s *Service) applyCoreAuthRemoval(ctx context.Context, id string) { return } GlobalModelRegistry().UnregisterClient(id) + + // Evict copilot model cache to prevent stale entries + executor.EvictCopilotModelCache(id) + // NOTE: EvictCopilotGeminiReasoningCache is added in copilot_gemini.go + if existing, ok := s.coreManager.GetByID(id); ok && existing != nil { existing.Disabled = true existing.Status = coreauth.StatusDisabled @@ -375,6 +381,8 @@ func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) { s.coreManager.RegisterExecutor(executor.NewClaudeExecutor(s.cfg)) case "codex": s.coreManager.RegisterExecutor(executor.NewCodexExecutor(s.cfg)) + case "copilot": + s.coreManager.RegisterExecutor(executor.NewCopilotExecutor(s.cfg)) case "qwen": s.coreManager.RegisterExecutor(executor.NewQwenExecutor(s.cfg)) case "iflow": @@ -710,6 +718,14 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { } } models = applyExcludedModels(models, excluded) + case "copilot": + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + models = executor.NewCopilotExecutor(s.cfg).FetchModels(ctx, a, s.cfg) + cancel() + if len(models) == 0 { + log.Warnf("copilot: using static fallback models for auth %s", a.ID) + models = registry.GetCopilotModels() + } case "qwen": models = registry.GetQwenModels() models = applyExcludedModels(models, excluded) From d4c3621d32b72bafb2ed3b10f17074d07ae11382 Mon Sep 17 00:00:00 2001 From: Jeff Nash <9919536+jeffnash@users.noreply.github.com> Date: Sun, 30 Nov 2025 14:59:33 -0800 Subject: [PATCH 5/5] feat(copilot/gemini): add Gemini 3 Pro reasoning and dynamic model fetching - Add full Copilot model registry (all GPT, Claude, Gemini, Grok models) - Implement dynamic model fetching from Copilot API with caching - Add Gemini reasoning capture and injection for tool calls - Add reasoning_opaque and reasoning_text handling for Gemini 3 models - Evict model and reasoning caches on auth removal - Add 30-second retry delay for 429 quota errors Credit: Reverse engineering insights adapted from github.com/aadishv/vscre --- internal/registry/copilot_models.go | 207 +++++++++++++++++- internal/runtime/executor/copilot_executor.go | 166 +++++++++++++- internal/runtime/executor/copilot_gemini.go | 193 ++++++++++++++++ .../runtime/executor/copilot_gemini_test.go | 127 +++++++++++ sdk/cliproxy/service.go | 2 +- 5 files changed, 685 insertions(+), 10 deletions(-) create mode 100644 internal/runtime/executor/copilot_gemini.go create mode 100644 internal/runtime/executor/copilot_gemini_test.go diff --git a/internal/registry/copilot_models.go b/internal/registry/copilot_models.go index 22de7ca2..b79ddc4a 100644 --- a/internal/registry/copilot_models.go +++ b/internal/registry/copilot_models.go @@ -22,12 +22,217 @@ func GenerateCopilotAliases(models []*ModelInfo) []*ModelInfo { return result } -// GetCopilotModels returns the Copilot models (raptor-mini and oswe-vscode-prime). +// GetCopilotModels returns a conservative set of fallback models for GitHub Copilot. +// These are used when dynamic model fetching from the Copilot API fails. func GetCopilotModels() []*ModelInfo { now := time.Now().Unix() defaultParams := []string{"temperature", "top_p", "max_tokens", "stream", "tools"} baseModels := []*ModelInfo{ + { + ID: "gpt-4.1", + Object: "model", + Created: now, + OwnedBy: "copilot", + Type: "copilot", + DisplayName: "GPT-4.1", + Description: "Azure OpenAI model via GitHub Copilot", + ContextLength: 128000, + MaxCompletionTokens: 16384, + SupportedParameters: defaultParams, + }, + { + ID: "gpt-4o", + Object: "model", + Created: now, + OwnedBy: "copilot", + Type: "copilot", + DisplayName: "GPT-4o", + Description: "Azure OpenAI model via GitHub Copilot", + ContextLength: 128000, + MaxCompletionTokens: 4096, + SupportedParameters: defaultParams, + }, + { + ID: "gpt-41-copilot", + Object: "model", + Created: now, + OwnedBy: "copilot", + Type: "copilot", + DisplayName: "GPT-4.1 Copilot", + Description: "Azure OpenAI fine-tuned model via GitHub Copilot", + ContextLength: 128000, + MaxCompletionTokens: 16384, + SupportedParameters: defaultParams, + }, + { + ID: "gpt-5", + Object: "model", + Created: now, + OwnedBy: "copilot", + Type: "copilot", + DisplayName: "GPT-5", + Description: "Azure OpenAI model via GitHub Copilot", + ContextLength: 400000, + MaxCompletionTokens: 128000, + SupportedParameters: defaultParams, + }, + { + ID: "gpt-5-mini", + Object: "model", + Created: now, + OwnedBy: "copilot", + Type: "copilot", + DisplayName: "GPT-5 mini", + Description: "Azure OpenAI model via GitHub Copilot", + ContextLength: 264000, + MaxCompletionTokens: 64000, + SupportedParameters: defaultParams, + }, + { + ID: "gpt-5-codex", + Object: "model", + Created: now, + OwnedBy: "copilot", + Type: "copilot", + DisplayName: "GPT-5-Codex (Preview)", + Description: "OpenAI model via GitHub Copilot (Preview)", + ContextLength: 400000, + MaxCompletionTokens: 128000, + SupportedParameters: defaultParams, + }, + { + ID: "gpt-5.1", + Object: "model", + Created: now, + OwnedBy: "copilot", + Type: "copilot", + DisplayName: "GPT-5.1", + Description: "OpenAI model via GitHub Copilot (Preview)", + ContextLength: 264000, + MaxCompletionTokens: 64000, + SupportedParameters: defaultParams, + }, + { + ID: "gpt-5.1-codex", + Object: "model", + Created: now, + OwnedBy: "copilot", + Type: "copilot", + DisplayName: "GPT-5.1-Codex", + Description: "OpenAI model via GitHub Copilot (Preview)", + ContextLength: 400000, + MaxCompletionTokens: 128000, + SupportedParameters: defaultParams, + }, + { + ID: "gpt-5.1-codex-mini", + Object: "model", + Created: now, + OwnedBy: "copilot", + Type: "copilot", + DisplayName: "GPT-5.1-Codex-Mini", + Description: "OpenAI model via GitHub Copilot (Preview)", + ContextLength: 400000, + MaxCompletionTokens: 128000, + SupportedParameters: defaultParams, + }, + { + ID: "claude-haiku-4.5", + Object: "model", + Created: now, + OwnedBy: "copilot", + Type: "copilot", + DisplayName: "Claude Haiku 4.5", + Description: "Anthropic model via GitHub Copilot", + ContextLength: 144000, + MaxCompletionTokens: 16000, + SupportedParameters: defaultParams, + }, + { + ID: "claude-opus-4.1", + Object: "model", + Created: now, + OwnedBy: "copilot", + Type: "copilot", + DisplayName: "Claude Opus 4.1", + Description: "Anthropic model via GitHub Copilot", + ContextLength: 200000, + MaxCompletionTokens: 16000, + SupportedParameters: defaultParams, + }, + { + ID: "claude-sonnet-4", + Object: "model", + Created: now, + OwnedBy: "copilot", + Type: "copilot", + DisplayName: "Claude Sonnet 4", + Description: "Anthropic model via GitHub Copilot", + ContextLength: 216000, + MaxCompletionTokens: 16000, + SupportedParameters: defaultParams, + }, + { + ID: "claude-sonnet-4.5", + Object: "model", + Created: now, + OwnedBy: "copilot", + Type: "copilot", + DisplayName: "Claude Sonnet 4.5", + Description: "Anthropic model via GitHub Copilot", + ContextLength: 144000, + MaxCompletionTokens: 16000, + SupportedParameters: defaultParams, + }, + { + ID: "claude-opus-4.5", + Object: "model", + Created: now, + OwnedBy: "copilot", + Type: "copilot", + DisplayName: "Claude Opus 4.5 (Preview)", + Description: "Anthropic model via GitHub Copilot (Preview)", + ContextLength: 144000, + MaxCompletionTokens: 16000, + SupportedParameters: defaultParams, + }, + { + ID: "gemini-2.5-pro", + Object: "model", + Created: now, + OwnedBy: "copilot", + Type: "copilot", + DisplayName: "Gemini 2.5 Pro", + Description: "Google model via GitHub Copilot", + ContextLength: 128000, + MaxCompletionTokens: 64000, + SupportedParameters: defaultParams, + }, + { + ID: "gemini-3-pro-preview", + Object: "model", + Created: now, + OwnedBy: "copilot", + Type: "copilot", + DisplayName: "Gemini 3 Pro (Preview)", + Description: "Google model via GitHub Copilot (Preview)", + ContextLength: 128000, + MaxCompletionTokens: 64000, + SupportedParameters: defaultParams, + }, + { + ID: "grok-code-fast-1", + Object: "model", + Created: now, + OwnedBy: "copilot", + Type: "copilot", + DisplayName: "Grok Code Fast 1", + Description: "xAI model via GitHub Copilot", + ContextLength: 128000, + MaxCompletionTokens: 64000, + SupportedParameters: defaultParams, + }, { ID: "oswe-vscode-prime", Object: "model", diff --git a/internal/runtime/executor/copilot_executor.go b/internal/runtime/executor/copilot_executor.go index 8aaf0032..66cc1806 100644 --- a/internal/runtime/executor/copilot_executor.go +++ b/internal/runtime/executor/copilot_executor.go @@ -40,6 +40,19 @@ type cachedToken struct { expiresAt time.Time } +// modelCacheEntry stores cached models. +// Shared model cache across executor instances (survives executor recreation). +var ( + sharedModelCacheMu sync.Mutex + sharedModelCache = make(map[string]*sharedModelCacheEntry) +) + +type sharedModelCacheEntry struct { + models []*registry.ModelInfo + fetchedAt time.Time +} + +const sharedModelCacheTTL = 30 * time.Minute // NewCopilotExecutor creates a new CopilotExecutor instance. func NewCopilotExecutor(cfg *config.Config) *CopilotExecutor { @@ -54,6 +67,15 @@ func (e *CopilotExecutor) Identifier() string { return "copilot" } func (e *CopilotExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { return nil } +// reasoningCache returns the shared Gemini reasoning cache for a given auth, or a fresh +// cache when auth is nil/unknown. This keeps Gemini reasoning warm across reauths. +func (e *CopilotExecutor) reasoningCache(auth *cliproxyauth.Auth) *geminiReasoningCache { + if auth == nil || strings.TrimSpace(auth.ID) == "" { + return newGeminiReasoningCache() + } + return getSharedGeminiReasoningCache(strings.TrimSpace(auth.ID)) +} + // stripCopilotPrefix removes the "copilot-" prefix from model names if present. // This allows users to explicitly route to Copilot using "copilot-gpt-5" while // the actual API call uses "gpt-5". @@ -89,6 +111,9 @@ func (e *CopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, apiModel := stripCopilotPrefix(req.Model) translatorModel := req.Model + if !strings.HasPrefix(strings.ToLower(req.Model), "copilot-") && strings.HasPrefix(strings.ToLower(apiModel), "gemini") { + translatorModel = "copilot-" + apiModel + } reporter := newUsageReporter(ctx, e.Identifier(), apiModel, auth) defer reporter.trackFailure(ctx, &err) @@ -101,6 +126,11 @@ func (e *CopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, body = sanitizeCopilotPayload(body, apiModel) body, _ = sjson.SetBytes(body, "stream", false) + // Inject cached Gemini reasoning for models that require it + if strings.HasPrefix(strings.ToLower(apiModel), "gemini") { + body = e.reasoningCache(auth).InjectReasoning(body) + } + baseURL := copilotauth.CopilotBaseURL(accountType) url := baseURL + "/chat/completions" @@ -176,6 +206,9 @@ func (e *CopilotExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth. apiModel := stripCopilotPrefix(req.Model) translatorModel := req.Model + if !strings.HasPrefix(strings.ToLower(req.Model), "copilot-") && strings.HasPrefix(strings.ToLower(apiModel), "gemini") { + translatorModel = "copilot-" + apiModel + } reporter := newUsageReporter(ctx, e.Identifier(), apiModel, auth) defer reporter.trackFailure(ctx, &err) @@ -188,6 +221,11 @@ func (e *CopilotExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth. body = sanitizeCopilotPayload(body, apiModel) body, _ = sjson.SetBytes(body, "stream", true) + // Inject cached Gemini reasoning for models that require it + if strings.HasPrefix(strings.ToLower(apiModel), "gemini") { + body = e.reasoningCache(auth).InjectReasoning(body) + } + baseURL := copilotauth.CopilotBaseURL(accountType) url := baseURL + "/chat/completions" @@ -250,6 +288,7 @@ func (e *CopilotExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth. } }() + isGemini := strings.HasPrefix(strings.ToLower(apiModel), "gemini") scanner := bufio.NewScanner(httpResp.Body) bufSize := e.cfg.ScannerBufferSize if bufSize <= 0 { @@ -267,6 +306,11 @@ func (e *CopilotExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth. if gjson.GetBytes(data, "usage").Exists() { reporter.publish(ctx, parseOpenAIUsage(data)) } + + // Cache Gemini reasoning data for subsequent requests + if isGemini { + e.reasoningCache(auth).CacheReasoning(data) + } } chunks := sdktranslator.TranslateStream(ctx, to, from, translatorModel, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m) @@ -438,7 +482,7 @@ func (e *CopilotExecutor) setCachedToken(githubToken, token string, expiresAt ti // // This method uses the Codex/OpenAI tokenizer (via tokenizerForCodexModel) as an // approximation for Copilot models. Since Copilot routes requests to various -// underlying models, the token counts are best-effort +// underlying models (GPT, Claude, Gemini), the token counts are best-effort // estimates rather than exact billing equivalents. // // If a Copilot-specific tokenizer becomes available in the future, it can be @@ -485,18 +529,124 @@ func (e *CopilotExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Au usageJSON := fmt.Sprintf(`{"usage":{"input_tokens":%d,"output_tokens":0}}`, count) translated := sdktranslator.TranslateTokenCount(ctx, to, from, int64(count), []byte(usageJSON)) return cliproxyexecutor.Response{Payload: []byte(translated)}, nil +} +func getCachedCopilotModels(authID string) []*registry.ModelInfo { + sharedModelCacheMu.Lock() + defer sharedModelCacheMu.Unlock() + if entry, ok := sharedModelCache[authID]; ok { + if time.Since(entry.fetchedAt) < sharedModelCacheTTL { + return entry.models + } + } + return nil +} + +func setCachedCopilotModels(authID string, models []*registry.ModelInfo) { + sharedModelCacheMu.Lock() + defer sharedModelCacheMu.Unlock() + sharedModelCache[authID] = &sharedModelCacheEntry{ + fetchedAt: time.Now(), + models: models, + } +} + +// EvictCopilotModelCache removes cached models for an auth ID when the auth is removed. +func EvictCopilotModelCache(authID string) { + if authID == "" { + return + } + sharedModelCacheMu.Lock() + delete(sharedModelCache, authID) + sharedModelCacheMu.Unlock() } -// copilotStatusErr creates a statusErr with appropriate retry timing for Copilot. -// EvictCopilotModelCache clears cached model data for a given auth ID. -func EvictCopilotModelCache(_ string) {} +func (e *CopilotExecutor) FetchModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *config.Config) []*registry.ModelInfo { + // 1. Check Cache + if models := getCachedCopilotModels(auth.ID); models != nil { + return models + } + + // 2. Resolve Tokens + copilotauth.EnsureMetadataHydrated(auth) + copilotToken, _, _ := copilotauth.ResolveCopilotToken(auth) + + // 3. Fetch (auto-refresh if 401) + authSvc := copilotauth.NewCopilotAuth(cfg) + var modelsResp *copilotauth.CopilotModelsResponse + var err error -// FetchModels returns the static Copilot model list. -func (e *CopilotExecutor) FetchModels(_ context.Context, _ *cliproxyauth.Auth, _ *config.Config) []*registry.ModelInfo { - return registry.GetCopilotModels() + if copilotToken != "" { + modelsResp, err = authSvc.GetModels(ctx, copilotToken, copilotauth.ResolveAccountType(auth)) + } + + if (copilotToken == "" || err != nil) && copilotauth.ResolveGitHubToken(auth) != "" { + // Attempt refresh + if _, refreshErr := e.Refresh(ctx, auth); refreshErr == nil { + copilotToken, _, _ = copilotauth.ResolveCopilotToken(auth) + modelsResp, err = authSvc.GetModels(ctx, copilotToken, copilotauth.ResolveAccountType(auth)) + } + } + + if err != nil || modelsResp == nil { + log.Warnf("copilot executor: failed to fetch models for auth %s: %v", auth.ID, err) + return nil + } + + // 4. Process and Cache + now := time.Now().Unix() + models := make([]*registry.ModelInfo, 0, len(modelsResp.Data)) + + for _, m := range modelsResp.Data { + if !m.ModelPickerEnabled { + continue + } + modelInfo := ®istry.ModelInfo{ + ID: m.ID, + Name: m.Name, + Object: "model", + Created: now, + OwnedBy: "copilot", + Type: "copilot", + DisplayName: m.Name, + Version: m.Version, + } + if m.Capabilities.Limits.MaxContextWindowTokens > 0 { + modelInfo.ContextLength = m.Capabilities.Limits.MaxContextWindowTokens + } + if m.Capabilities.Limits.MaxOutputTokens > 0 { + modelInfo.MaxCompletionTokens = m.Capabilities.Limits.MaxOutputTokens + } + params := []string{"temperature", "top_p", "max_tokens", "stream"} + if m.Capabilities.Supports.ToolCalls { + params = append(params, "tools") + } + modelInfo.SupportedParameters = params + desc := fmt.Sprintf("%s model via GitHub Copilot", m.Vendor) + if m.Preview { + desc += " (Preview)" + } + modelInfo.Description = desc + models = append(models, modelInfo) + } + + models = registry.GenerateCopilotAliases(models) + setCachedCopilotModels(auth.ID, models) + return models +} + +// FetchCopilotModels retrieves available models from the Copilot API using the supplied auth. +// Uses shared cache that persists across executor instances. +func FetchCopilotModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *config.Config) []*registry.ModelInfo { + // Use shared cache - check before creating executor + if models := getCachedCopilotModels(auth.ID); models != nil { + return models + } + e := NewCopilotExecutor(cfg) + return e.FetchModels(ctx, auth, cfg) } +// copilotStatusErr creates a statusErr with appropriate retry timing for Copilot. // For 429 errors, it sets a longer retry delay (30 seconds) since Copilot quota // limits typically require more time to recover than standard rate limits. func copilotStatusErr(code int, msg string) statusErr { @@ -506,4 +656,4 @@ func copilotStatusErr(code int, msg string) statusErr { err.retryAfter = &delay } return err -} +} \ No newline at end of file diff --git a/internal/runtime/executor/copilot_gemini.go b/internal/runtime/executor/copilot_gemini.go new file mode 100644 index 00000000..087e811f --- /dev/null +++ b/internal/runtime/executor/copilot_gemini.go @@ -0,0 +1,193 @@ +package executor + +import ( + "fmt" + "sync" + "time" + + log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +type geminiReasoningCache struct { + mu sync.RWMutex + cache map[string]*geminiReasoning +} + +type geminiReasoning struct { + Opaque string + Text string + createdAt time.Time +} + +const geminiReasoningTTL = 30 * time.Minute + +var ( + sharedGeminiReasoningMu sync.Mutex + sharedGeminiReasoning = make(map[string]*geminiReasoningCache) +) + +func newGeminiReasoningCache() *geminiReasoningCache { + return &geminiReasoningCache{ + cache: make(map[string]*geminiReasoning), + } +} + +// getSharedGeminiReasoningCache returns a cache keyed by authID to preserve +// reasoning data across executor re-creations (e.g., after reauth). +func getSharedGeminiReasoningCache(authID string) *geminiReasoningCache { + if authID == "" { + return newGeminiReasoningCache() + } + sharedGeminiReasoningMu.Lock() + defer sharedGeminiReasoningMu.Unlock() + if cache, ok := sharedGeminiReasoning[authID]; ok && cache != nil { + return cache + } + cache := newGeminiReasoningCache() + sharedGeminiReasoning[authID] = cache + return cache +} + +// EvictCopilotGeminiReasoningCache removes the shared cache for an auth ID when the auth is removed. +func EvictCopilotGeminiReasoningCache(authID string) { + if authID == "" { + return + } + sharedGeminiReasoningMu.Lock() + delete(sharedGeminiReasoning, authID) + sharedGeminiReasoningMu.Unlock() +} + +// InjectReasoning inserts cached reasoning fields back into assistant messages +// for tool calls (required by Gemini 3 models). +func (c *geminiReasoningCache) InjectReasoning(body []byte) []byte { + // Find assistant messages with tool_calls that are missing reasoning fields + messages := gjson.GetBytes(body, "messages") + if !messages.Exists() || !messages.IsArray() { + return body + } + + c.mu.RLock() + defer c.mu.RUnlock() + + if len(c.cache) == 0 { + log.Debug("copilot executor: no cached Gemini reasoning available") + return body + } + + var modified bool + var msgIdx int + messages.ForEach(func(_, msg gjson.Result) bool { + defer func() { msgIdx++ }() + if msg.Get("role").String() != "assistant" { + return true + } + toolCalls := msg.Get("tool_calls") + if !toolCalls.Exists() || !toolCalls.IsArray() { + return true + } + // Check if reasoning fields are missing + if msg.Get("reasoning_opaque").Exists() || msg.Get("reasoning_text").Exists() { + return true + } + + // Look up reasoning by the first tool_call's id + var callID string + toolCalls.ForEach(func(_, tc gjson.Result) bool { + if id := tc.Get("id").String(); id != "" { + callID = id + return false // stop after first + } + return true + }) + + if callID == "" { + return true + } + + reasoning := c.cache[callID] + if reasoning == nil || (reasoning.Opaque == "" && reasoning.Text == "") { + log.Debugf("copilot executor: no cached reasoning for call_id %s", callID) + return true + } + + // Check TTL + if time.Since(reasoning.createdAt) > geminiReasoningTTL { + log.Debugf("copilot executor: cached reasoning for call_id %s expired", callID) + return true + } + + log.Debugf("copilot executor: injecting reasoning for call_id %s (opaque=%d chars, text=%d chars)", callID, len(reasoning.Opaque), len(reasoning.Text)) + + msgPath := fmt.Sprintf("messages.%d", msgIdx) + if reasoning.Opaque != "" { + body, _ = sjson.SetBytes(body, msgPath+".reasoning_opaque", reasoning.Opaque) + modified = true + } + if reasoning.Text != "" { + body, _ = sjson.SetBytes(body, msgPath+".reasoning_text", reasoning.Text) + modified = true + } + return true + }) + + if modified { + log.Debug("copilot executor: injected cached Gemini reasoning into request") + } + return body +} + +// CacheReasoning captures reasoning fields from streaming deltas. +func (c *geminiReasoningCache) CacheReasoning(data []byte) { + delta := gjson.GetBytes(data, "choices.0.delta") + if !delta.Exists() { + return + } + + // Get the call_id from the first tool_call in the delta + callID := gjson.GetBytes(data, "choices.0.delta.tool_calls.0.id").String() + + opaque := delta.Get("reasoning_opaque").String() + text := delta.Get("reasoning_text").String() + + if opaque == "" && text == "" { + return + } + + c.mu.Lock() + defer c.mu.Unlock() + + // Lazy eviction: simple random cleanup if cache gets too big + if len(c.cache) > 1000 { + now := time.Now() + for k, v := range c.cache { + if now.Sub(v.createdAt) > geminiReasoningTTL { + delete(c.cache, k) + } + } + } + + if callID == "" { + return + } + + log.Debugf("copilot executor: caching Gemini reasoning for call_id %s (opaque=%d chars, text=%d chars)", callID, len(opaque), len(text)) + + if c.cache[callID] == nil { + c.cache[callID] = &geminiReasoning{ + createdAt: time.Now(), + } + } + + // Only update if we got new values + if opaque != "" { + c.cache[callID].Opaque = opaque + } + if text != "" { + // Append text since it comes in chunks + c.cache[callID].Text += text + } + c.cache[callID].createdAt = time.Now() +} diff --git a/internal/runtime/executor/copilot_gemini_test.go b/internal/runtime/executor/copilot_gemini_test.go new file mode 100644 index 00000000..9e65ac19 --- /dev/null +++ b/internal/runtime/executor/copilot_gemini_test.go @@ -0,0 +1,127 @@ + package executor + + import ( + "testing" + "time" + + "github.com/tidwall/gjson" + ) + + func TestGeminiReasoningCache_CacheAndInject(t *testing.T) { + cache := newGeminiReasoningCache() + + // Cache reasoning from a streaming delta + delta := `{"choices":[{"delta":{"tool_calls":[{"id":"call_123"}],"reasoning_opaque":"opaque_data","reasoning_text":"thinking..."}}]}` + cache.CacheReasoning([]byte(delta)) + + // Verify it was cached + if cache.cache["call_123"] == nil { + t.Fatal("reasoning was not cached") + } + if cache.cache["call_123"].Opaque != "opaque_data" { + t.Errorf("Opaque = %q, want %q", cache.cache["call_123"].Opaque, "opaque_data") + } + if cache.cache["call_123"].Text != "thinking..." { + t.Errorf("Text = %q, want %q", cache.cache["call_123"].Text, "thinking...") + } + + // Inject into a request body + body := `{"messages":[{"role":"assistant","tool_calls":[{"id":"call_123"}]}]}` + result := cache.InjectReasoning([]byte(body)) + + // Verify injection + if !gjson.GetBytes(result, "messages.0.reasoning_opaque").Exists() { + t.Error("reasoning_opaque was not injected") + } + if gjson.GetBytes(result, "messages.0.reasoning_opaque").String() != "opaque_data" { + t.Errorf("injected reasoning_opaque = %q, want %q", gjson.GetBytes(result, "messages.0.reasoning_opaque").String(), "opaque_data") + } + } + + func TestGeminiReasoningCache_TextAppends(t *testing.T) { + cache := newGeminiReasoningCache() + + // Simulate streaming chunks + chunk1 := `{"choices":[{"delta":{"tool_calls":[{"id":"call_456"}],"reasoning_text":"Hello "}}]}` + chunk2 := `{"choices":[{"delta":{"tool_calls":[{"id":"call_456"}],"reasoning_text":"World"}}]}` + + cache.CacheReasoning([]byte(chunk1)) + cache.CacheReasoning([]byte(chunk2)) + + if cache.cache["call_456"].Text != "Hello World" { + t.Errorf("Text = %q, want %q", cache.cache["call_456"].Text, "Hello World") + } + } + + func TestGeminiReasoningCache_NoInjectWhenAlreadyPresent(t *testing.T) { + cache := newGeminiReasoningCache() + + delta := `{"choices":[{"delta":{"tool_calls":[{"id":"call_789"}],"reasoning_opaque":"cached"}}]}` + cache.CacheReasoning([]byte(delta)) + + // Body already has reasoning_opaque + body := `{"messages":[{"role":"assistant","tool_calls":[{"id":"call_789"}],"reasoning_opaque":"existing"}]}` + result := cache.InjectReasoning([]byte(body)) + + // Should keep existing value + if gjson.GetBytes(result, "messages.0.reasoning_opaque").String() != "existing" { + t.Error("existing reasoning_opaque was overwritten") + } + } + + func TestGeminiReasoningCache_TTLExpiry(t *testing.T) { + cache := newGeminiReasoningCache() + + // Manually insert expired entry + cache.cache["expired_call"] = &geminiReasoning{ + Opaque: "old_data", + createdAt: time.Now().Add(-31 * time.Minute), // expired + } + + body := `{"messages":[{"role":"assistant","tool_calls":[{"id":"expired_call"}]}]}` + result := cache.InjectReasoning([]byte(body)) + + // Should not inject expired reasoning + if gjson.GetBytes(result, "messages.0.reasoning_opaque").Exists() { + t.Error("expired reasoning was injected") + } + } + + func TestEvictCopilotGeminiReasoningCache(t *testing.T) { + // Setup shared cache + testAuthID := "test-auth-evict" + cache := getSharedGeminiReasoningCache(testAuthID) + cache.cache["test"] = &geminiReasoning{Opaque: "data", createdAt: time.Now()} + + // Evict + EvictCopilotGeminiReasoningCache(testAuthID) + + // Get again - should be fresh + newCache := getSharedGeminiReasoningCache(testAuthID) + if len(newCache.cache) != 0 { + t.Error("cache was not evicted") + } + } + + func TestGetSharedGeminiReasoningCache_EmptyAuthID(t *testing.T) { + cache1 := getSharedGeminiReasoningCache("") + cache2 := getSharedGeminiReasoningCache("") + + // Empty authID should return new cache each time + if cache1 == cache2 { + t.Error("empty authID should return new cache instance each time") + } + } + + func TestGetSharedGeminiReasoningCache_SameAuthID(t *testing.T) { + testAuthID := "test-auth-same" + cache1 := getSharedGeminiReasoningCache(testAuthID) + cache2 := getSharedGeminiReasoningCache(testAuthID) + + if cache1 != cache2 { + t.Error("same authID should return same cache instance") + } + + // Cleanup + EvictCopilotGeminiReasoningCache(testAuthID) + } diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index 18397838..3dc2aedc 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -304,7 +304,7 @@ func (s *Service) applyCoreAuthRemoval(ctx context.Context, id string) { // Evict copilot model cache to prevent stale entries executor.EvictCopilotModelCache(id) - // NOTE: EvictCopilotGeminiReasoningCache is added in copilot_gemini.go + executor.EvictCopilotGeminiReasoningCache(id) if existing, ok := s.coreManager.GetByID(id); ok && existing != nil { existing.Disabled = true