diff --git a/cmd/root/login.go b/cmd/root/login.go
new file mode 100644
index 000000000..0e0470744
--- /dev/null
+++ b/cmd/root/login.go
@@ -0,0 +1,80 @@
+package root
+
+import (
+	"fmt"
+
+	"github.com/spf13/cobra"
+
+	"github.com/docker/cagent/pkg/chatgpt"
+)
+
+// newLoginCmd builds the "login" command. It takes exactly one positional
+// argument, the provider name, and dispatches to that provider's login flow.
+func newLoginCmd() *cobra.Command {
+	cmd := &cobra.Command{
+		Use:     "login <provider>",
+		Short:   "Authenticate with a model provider",
+		Long:    "Authenticate with a model provider using OAuth. Currently supports 'chatgpt' for ChatGPT Plus/Pro subscriptions.",
+		GroupID: "core",
+		Example: `  cagent login chatgpt`,
+		Args:    cobra.ExactArgs(1),
+		RunE: func(cmd *cobra.Command, args []string) error {
+			provider := args[0]
+			switch provider {
+			case "chatgpt":
+				return loginChatGPT(cmd)
+			default:
+				return fmt.Errorf("unsupported provider %q (supported: chatgpt)", provider)
+			}
+		},
+	}
+
+	return cmd
+}
+
+// loginChatGPT runs the ChatGPT OAuth flow and persists the resulting token.
+// Progress messages go to the command's configured output writer.
+func loginChatGPT(cmd *cobra.Command) error {
+	fmt.Fprintln(cmd.OutOrStdout(), "Opening browser to authenticate with ChatGPT...")
+
+	token, err := chatgpt.Login(cmd.Context())
+	if err != nil {
+		return fmt.Errorf("ChatGPT login failed: %w", err)
+	}
+
+	if err := chatgpt.SaveToken(token); err != nil {
+		return fmt.Errorf("failed to save token: %w", err)
+	}
+
+	fmt.Fprintln(cmd.OutOrStdout(), "Successfully authenticated with ChatGPT!")
+	fmt.Fprintln(cmd.OutOrStdout(), "You can now use 'chatgpt' as a provider, e.g.: chatgpt/o3")
+	return nil
+}
+
+// newLogoutCmd builds the "logout" command. It takes exactly one positional
+// argument, the provider name, and removes that provider's stored token.
+func newLogoutCmd() *cobra.Command {
+	cmd := &cobra.Command{
+		Use:     "logout <provider>",
+		Short:   "Remove stored authentication for a provider",
+		Long:    "Remove stored authentication tokens for a model provider.",
+		GroupID: "core",
+		Example: `  cagent logout chatgpt`,
+		Args:    cobra.ExactArgs(1),
+		RunE: func(cmd *cobra.Command, args []string) error {
+			provider := args[0]
+			switch provider {
+			case "chatgpt":
+				if err := chatgpt.RemoveToken(); err != nil {
+					return fmt.Errorf("failed to remove ChatGPT token: %w", err)
+				}
+				fmt.Fprintln(cmd.OutOrStdout(), "Successfully logged out from ChatGPT.")
+				return nil
+			default:
+				return fmt.Errorf("unsupported provider %q (supported: chatgpt)", provider)
+			}
+		},
+	}
+
+	return cmd
+}
diff --git a/cmd/root/root.go b/cmd/root/root.go
index f1284ac18..df332eb72 100644
--- a/cmd/root/root.go
+++ b/cmd/root/root.go
@@ -123,6 +123,8 @@ func NewRootCmd() *cobra.Command {
 	cmd.AddCommand(newDebugCmd())
 	cmd.AddCommand(newAliasCmd())
 	cmd.AddCommand(newServeCmd())
+	cmd.AddCommand(newLoginCmd())
+	cmd.AddCommand(newLogoutCmd())
 
 	// Define groups
 	cmd.AddGroup(&cobra.Group{ID: "core", Title: "Core Commands:"})
diff --git a/pkg/chatgpt/auth.go b/pkg/chatgpt/auth.go
new file mode 100644
index 000000000..ccab13724
--- /dev/null
+++ b/pkg/chatgpt/auth.go
@@ -0,0 +1,327 @@
+// Package chatgpt implements OAuth authentication for ChatGPT Plus/Pro subscriptions.
+// It uses the OAuth2 PKCE flow against auth.openai.com to obtain access tokens
+// that can be exchanged for an OpenAI API key.
+package chatgpt
+
+import (
+	"bytes"
+	"context"
+	"crypto/rand"
+	"encoding/hex"
+	"encoding/json"
+	"fmt"
+	"io"
+	"log/slog"
+	"net"
+	"net/http"
+	"net/url"
+	"time"
+
+	"golang.org/x/oauth2"
+
+	"github.com/docker/cagent/pkg/browser"
+)
+
+const (
+	// OAuth endpoints for ChatGPT authentication
+	authorizationEndpoint = "https://auth.openai.com/oauth/authorize"
+
+	// OAuth client configuration (same as Codex CLI)
+	clientID = "app_EMoamEEZ73f0CkXaXp7hrann"
+
+	// OAuth scopes
+	defaultScopes = "openid profile email offline_access"
+
+	// defaultPort is the preferred local port for the OAuth callback server.
+	defaultPort = 1455
+)
+
+// tokenEndpointURL is the OAuth token endpoint. It is a variable so tests can override it.
+var tokenEndpointURL = "https://auth.openai.com/oauth/token"
+
+// Token represents the persisted authentication state from the ChatGPT OAuth flow.
+type Token struct {
+	AccessToken  string    `json:"access_token"`
+	RefreshToken string    `json:"refresh_token,omitempty"`
+	IDToken      string    `json:"id_token,omitempty"`
+	TokenType    string    `json:"token_type"`
+	ExpiresIn    int       `json:"expires_in,omitempty"`
+	ExpiresAt    time.Time `json:"expires_at"`
+}
+
+// IsExpired checks if the token is expired.
+// Returns true if the token will expire within 60 seconds.
+// A zero ExpiresAt is treated as "never expires".
+func (t *Token) IsExpired() bool {
+	if t.ExpiresAt.IsZero() {
+		return false
+	}
+	return time.Now().Add(60 * time.Second).After(t.ExpiresAt)
+}
+
+// Login performs the OAuth PKCE flow to authenticate with ChatGPT.
+// It opens the user's browser to the OpenAI login page, starts a local
+// callback server, exchanges the authorization code for tokens, and then
+// exchanges the id_token for a standard OpenAI API key.
+//
+// The returned Token carries the API key in AccessToken. Login returns an
+// error if ctx is cancelled before the callback arrives, if the callback
+// reports an OAuth error or a state mismatch, or if either exchange fails.
+func Login(ctx context.Context) (*Token, error) {
+	// Generate PKCE code verifier and challenge
+	verifier := oauth2.GenerateVerifier()
+	challenge := oauth2.S256ChallengeFromVerifier(verifier)
+
+	// Generate random state for CSRF protection
+	state, err := generateState()
+	if err != nil {
+		return nil, fmt.Errorf("failed to generate state: %w", err)
+	}
+
+	// Start local callback server, preferring the default port
+	listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", defaultPort))
+	if err != nil {
+		// Fall back to a random port
+		listener, err = net.Listen("tcp", "127.0.0.1:0")
+		if err != nil {
+			return nil, fmt.Errorf("failed to start callback server: %w", err)
+		}
+	}
+	port := listener.Addr().(*net.TCPAddr).Port
+	redirectURI := fmt.Sprintf("http://localhost:%d/auth/callback", port)
+
+	// Build authorization URL
+	authURL := buildAuthURL(redirectURI, state, challenge)
+
+	slog.Debug("Starting ChatGPT OAuth login", "redirect_uri", redirectURI)
+
+	// Channel to receive the authorization code
+	type callbackResult struct {
+		code string
+		err  error
+	}
+	resultCh := make(chan callbackResult, 1)
+
+	// report hands the first callback outcome to the waiting select below.
+	// Login reads resultCh once, so later callbacks (browser reloads,
+	// retries) are dropped instead of blocking their handler goroutine.
+	report := func(res callbackResult) {
+		select {
+		case resultCh <- res:
+		default:
+		}
+	}
+
+	// Set up callback handler
+	mux := http.NewServeMux()
+	mux.HandleFunc("/auth/callback", func(w http.ResponseWriter, r *http.Request) {
+		query := r.URL.Query()
+
+		// Verify state
+		if query.Get("state") != state {
+			report(callbackResult{err: fmt.Errorf("state mismatch")})
+			http.Error(w, "State mismatch", http.StatusBadRequest)
+			return
+		}
+
+		// Check for errors
+		if errParam := query.Get("error"); errParam != "" {
+			desc := query.Get("error_description")
+			report(callbackResult{err: fmt.Errorf("OAuth error: %s: %s", errParam, desc)})
+			// error_description comes straight from the query string, so it
+			// must not be interpolated into HTML. http.Error sends it as
+			// text/plain with X-Content-Type-Options: nosniff.
+			http.Error(w, "Authentication failed: "+desc+"\nYou can close this window.", http.StatusBadRequest)
+			return
+		}
+
+		code := query.Get("code")
+		if code == "" {
+			report(callbackResult{err: fmt.Errorf("no authorization code received")})
+			http.Error(w, "No code received", http.StatusBadRequest)
+			return
+		}
+
+		report(callbackResult{code: code})
+		w.Header().Set("Content-Type", "text/html; charset=utf-8")
+		fmt.Fprint(w, "<html><body><h1>Authentication successful!</h1><p>You can close this window and return to the terminal.</p></body></html>")
+	})
+
+	server := &http.Server{
+		Handler:           mux,
+		ReadHeaderTimeout: 10 * time.Second,
+	}
+
+	// Start server in background
+	go func() {
+		if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
+			slog.Error("Callback server error", "error", err)
+		}
+	}()
+	defer func() {
+		// Detach from ctx so shutdown still gets its grace period when the
+		// caller's context is what ended the login.
+		shutdownCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second)
+		defer cancel()
+		_ = server.Shutdown(shutdownCtx)
+	}()
+
+	// Open browser
+	if err := browser.Open(ctx, authURL); err != nil {
+		return nil, fmt.Errorf("failed to open browser (visit this URL manually):\n%s\n\nerror: %w", authURL, err)
+	}
+
+	// Wait for callback or context cancellation
+	select {
+	case <-ctx.Done():
+		return nil, ctx.Err()
+	case result := <-resultCh:
+		if result.err != nil {
+			return nil, fmt.Errorf("authentication failed: %w", result.err)
+		}
+
+		// Exchange code for OAuth tokens
+		tokens, err := exchangeCode(ctx, result.code, verifier, redirectURI)
+		if err != nil {
+			return nil, err
+		}
+
+		// The API key exchange uses the id_token as its subject token, so a
+		// response without one cannot be completed.
+		if tokens.IDToken == "" {
+			return nil, fmt.Errorf("failed to obtain API key: code exchange response did not include an id_token")
+		}
+
+		// Exchange id_token for an OpenAI API key
+		apiKey, err := exchangeForAPIKey(ctx, tokens.IDToken)
+		if err != nil {
+			return nil, fmt.Errorf("failed to obtain API key: %w", err)
+		}
+
+		tokens.AccessToken = apiKey
+		return tokens, nil
+	}
+}
+
+// RefreshAccessToken refreshes an expired access token using the refresh token.
+// It obtains new OAuth tokens and then exchanges the new id_token for an API key.
+//
+// When the response carries no new refresh token, the supplied refreshToken
+// is kept in the returned Token. ExpiresAt is set to one hour from now.
+func RefreshAccessToken(ctx context.Context, refreshToken string) (*Token, error) {
+	if refreshToken == "" {
+		return nil, fmt.Errorf("token refresh failed: no refresh token available")
+	}
+
+	payload, err := json.Marshal(map[string]string{
+		"grant_type":    "refresh_token",
+		"client_id":     clientID,
+		"refresh_token": refreshToken,
+		"scope":         "openid profile email",
+	})
+	if err != nil {
+		return nil, fmt.Errorf("failed to marshal refresh request: %w", err)
+	}
+
+	var refreshResp struct {
+		IDToken      string `json:"id_token"`
+		AccessToken  string `json:"access_token"`
+		RefreshToken string `json:"refresh_token"`
+	}
+	if err := postJSON(ctx, tokenEndpointURL, "application/json", bytes.NewReader(payload), &refreshResp); err != nil {
+		return nil, fmt.Errorf("token refresh failed: %w", err)
+	}
+
+	// The id_token is the subject of the API key exchange below.
+	if refreshResp.IDToken == "" {
+		return nil, fmt.Errorf("refresh response did not include an id_token")
+	}
+
+	// Use new refresh token if provided, otherwise keep the old one
+	newRefreshToken := refreshResp.RefreshToken
+	if newRefreshToken == "" {
+		newRefreshToken = refreshToken
+	}
+
+	// Exchange the new id_token for an API key
+	apiKey, err := exchangeForAPIKey(ctx, refreshResp.IDToken)
+	if err != nil {
+		return nil, fmt.Errorf("failed to obtain API key after refresh: %w", err)
+	}
+
+	slog.Debug("ChatGPT token refreshed successfully")
+	return &Token{
+		AccessToken:  apiKey,
+		RefreshToken: newRefreshToken,
+		IDToken:      refreshResp.IDToken,
+		TokenType:    "Bearer",
+		ExpiresAt:    time.Now().Add(1 * time.Hour),
+	}, nil
+}
+
+// buildAuthURL constructs the OAuth authorization URL with PKCE parameters.
+func buildAuthURL(redirectURI, state, codeChallenge string) string {
+	params := url.Values{}
+	params.Set("response_type", "code")
+	params.Set("client_id", clientID)
+	params.Set("redirect_uri", redirectURI)
+	params.Set("scope", defaultScopes)
+	params.Set("code_challenge", codeChallenge)
+	params.Set("code_challenge_method", "S256")
+	params.Set("state", state)
+	return authorizationEndpoint + "?" + params.Encode()
+}
+
+// exchangeCode exchanges an authorization code for OAuth tokens (id_token, access_token, refresh_token).
+// ExpiresAt on the returned Token is set to one hour from now.
+func exchangeCode(ctx context.Context, code, verifier, redirectURI string) (*Token, error) {
+	data := url.Values{}
+	data.Set("grant_type", "authorization_code")
+	data.Set("code", code)
+	data.Set("redirect_uri", redirectURI)
+	data.Set("client_id", clientID)
+	data.Set("code_verifier", verifier)
+
+	var tokenResp struct {
+		IDToken      string `json:"id_token"`
+		AccessToken  string `json:"access_token"`
+		RefreshToken string `json:"refresh_token"`
+	}
+	if err := postForm(ctx, tokenEndpointURL, data, &tokenResp); err != nil {
+		return nil, fmt.Errorf("code exchange failed: %w", err)
+	}
+
+	slog.Debug("ChatGPT OAuth code exchange successful")
+	return &Token{
+		IDToken:      tokenResp.IDToken,
+		AccessToken:  tokenResp.AccessToken,
+		RefreshToken: tokenResp.RefreshToken,
+		TokenType:    "Bearer",
+		ExpiresAt:    time.Now().Add(1 * time.Hour),
+	}, nil
+}
+
+// exchangeForAPIKey exchanges an id_token for a standard OpenAI API key
+// using the token exchange grant type.
+//
+// It returns an error rather than an empty key when idToken is empty or the
+// response carries no access_token, so callers never persist a blank key.
+func exchangeForAPIKey(ctx context.Context, idToken string) (string, error) {
+	if idToken == "" {
+		return "", fmt.Errorf("API key exchange failed: empty id_token")
+	}
+
+	data := url.Values{}
+	data.Set("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange")
+	data.Set("client_id", clientID)
+	data.Set("requested_token", "openai-api-key")
+	data.Set("subject_token", idToken)
+	data.Set("subject_token_type", "urn:ietf:params:oauth:token-type:id_token")
+
+	var exchangeResp struct {
+		AccessToken string `json:"access_token"`
+	}
+	if err := postForm(ctx, tokenEndpointURL, data, &exchangeResp); err != nil {
+		return "", fmt.Errorf("API key exchange failed: %w", err)
+	}
+	if exchangeResp.AccessToken == "" {
+		return "", fmt.Errorf("API key exchange failed: response did not include an access_token")
+	}
+
+	return exchangeResp.AccessToken, nil
+}
+
+// postForm sends a POST request with form-encoded data and decodes the JSON response.
+func postForm(ctx context.Context, endpoint string, data url.Values, result any) error {
+	return postJSON(ctx, endpoint, "application/x-www-form-urlencoded", bytes.NewBufferString(data.Encode()), result)
+}
+
+// tokenHTTPClient is the shared client for token endpoint calls. The timeout
+// bounds each request even when the caller's context has no deadline.
+var tokenHTTPClient = &http.Client{Timeout: 30 * time.Second}
+
+// postJSON sends a POST request with the given content type and body, then
+// decodes the JSON response into result.
+//
+// Any status other than 200 is returned as an error that includes the status
+// code and at most the first 4 KiB of the response body.
+func postJSON(ctx context.Context, endpoint, contentType string, body io.Reader, result any) error {
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, body)
+	if err != nil {
+		return fmt.Errorf("create request: %w", err)
+	}
+	req.Header.Set("Content-Type", contentType)
+
+	resp, err := tokenHTTPClient.Do(req)
+	if err != nil {
+		return fmt.Errorf("send request: %w", err)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		// Best effort: the body only decorates the error message, so a read
+		// failure here is ignored and the size is capped.
+		respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4<<10))
+		return fmt.Errorf("status %d: %s", resp.StatusCode, string(respBody))
+	}
+
+	if err := json.NewDecoder(resp.Body).Decode(result); err != nil {
+		return fmt.Errorf("decode response: %w", err)
+	}
+
+	return nil
+}
+
+// generateState generates a random state string for CSRF protection.
+// The result is 16 random bytes, hex-encoded (32 characters).
+func generateState() (string, error) {
+	b := make([]byte, 16)
+	if _, err := rand.Read(b); err != nil {
+		return "", err
+	}
+	return hex.EncodeToString(b), nil
+}
diff --git a/pkg/chatgpt/chatgpt_test.go b/pkg/chatgpt/chatgpt_test.go
new file mode 100644
index 000000000..081b7cf99
--- /dev/null
+++ b/pkg/chatgpt/chatgpt_test.go
@@ -0,0 +1,446 @@
+package chatgpt
+
+import (
+	"context"
+	"encoding/json"
+	"net/http"
+	"net/http/httptest"
+	"os"
+	"path/filepath"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// overrideTokenPath redirects the token store to a temp directory for the
+// duration of the test. Because it mutates a package-level variable, tests
+// that use this helper must NOT be marked parallel.
+func overrideTokenPath(t *testing.T) {
+	t.Helper()
+
+	tempDir := t.TempDir()
+	previous := tokenFilePathFunc
+	t.Cleanup(func() { tokenFilePathFunc = previous })
+	tokenFilePathFunc = func() string {
+		return filepath.Join(tempDir, "chatgpt_token.json")
+	}
+}
+
+// overrideTokenEndpoint points tokenEndpointURL at url and restores the
+// previous value when the test finishes.
+func overrideTokenEndpoint(t *testing.T, url string) {
+	t.Helper()
+
+	previous := tokenEndpointURL
+	t.Cleanup(func() { tokenEndpointURL = previous })
+	tokenEndpointURL = url
+}
+
+// TestToken_IsExpired checks IsExpired on both sides of its 60-second
+// margin, plus the zero ExpiresAt case.
+func TestToken_IsExpired(t *testing.T) {
+	t.Parallel()
+
+	// expiringIn builds a token whose expiry lies offset away from now.
+	expiringIn := func(offset time.Duration) Token {
+		return Token{AccessToken: "test", ExpiresAt: time.Now().Add(offset)}
+	}
+
+	cases := []struct {
+		name  string
+		token Token
+		want  bool
+	}{
+		{name: "zero expiry is never expired", token: Token{AccessToken: "test"}, want: false},
+		{name: "future expiry is not expired", token: expiringIn(10 * time.Minute), want: false},
+		{name: "past expiry is expired", token: expiringIn(-10 * time.Minute), want: true},
+		{name: "expiry within 60 seconds is considered expired", token: expiringIn(30 * time.Second), want: true},
+		{name: "expiry beyond 60 seconds is not expired", token: expiringIn(90 * time.Second), want: false},
+	}
+
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			t.Parallel()
+			assert.Equal(t, tc.want, tc.token.IsExpired())
+		})
+	}
+}
+
+// TestTokenStore_SaveLoadRemove tests the token store using a helper that
+// overrides the package-level tokenFilePathFunc. Because mutating a package
+// global is not parallel-safe, this test is NOT marked parallel.
+func TestTokenStore_SaveLoadRemove(t *testing.T) {
+	overrideTokenPath(t)
+
+	// Initially no token
+	token, err := LoadToken()
+	require.NoError(t, err)
+	assert.Nil(t, token)
+
+	// Save a token
+	testToken := &Token{
+		AccessToken:  "test-access-token",
+		RefreshToken: "test-refresh-token",
+		TokenType:    "Bearer",
+		ExpiresAt:    time.Now().Add(1 * time.Hour).Truncate(time.Second),
+	}
+	require.NoError(t, SaveToken(testToken))
+
+	// Load it back
+	loaded, err := LoadToken()
+	require.NoError(t, err)
+	require.NotNil(t, loaded)
+	assert.Equal(t, testToken.AccessToken, loaded.AccessToken)
+	assert.Equal(t, testToken.RefreshToken, loaded.RefreshToken)
+	assert.Equal(t, testToken.TokenType, loaded.TokenType)
+
+	// Remove it
+	require.NoError(t, RemoveToken())
+
+	// Gone
+	token, err = LoadToken()
+	require.NoError(t, err)
+	assert.Nil(t, token)
+}
+
+// TestTokenStore_RemoveNoFile checks that RemoveToken succeeds when no token
+// file exists.
+func TestTokenStore_RemoveNoFile(t *testing.T) {
+	overrideTokenPath(t)
+
+	require.NoError(t, RemoveToken())
+}
+
+// TestTokenStore_InvalidJSON checks that LoadToken reports a parse error for
+// a token file that is not valid JSON.
+func TestTokenStore_InvalidJSON(t *testing.T) {
+	overrideTokenPath(t)
+
+	require.NoError(t, os.WriteFile(tokenFilePath(), []byte("not json"), 0o600))
+
+	_, err := LoadToken()
+	require.Error(t, err)
+	assert.Contains(t, err.Error(), "failed to parse token file")
+}
+
+// TestRefreshAccessToken drives RefreshAccessToken against a mock endpoint
+// that answers the first request as the refresh call and every later request
+// as the API key exchange, so it depends on the order of the two calls.
+func TestRefreshAccessToken(t *testing.T) {
+	callCount := 0
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		callCount++
+		w.Header().Set("Content-Type", "application/json")
+
+		if callCount == 1 {
+			// First call: token refresh (JSON body)
+			assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
+			_ = json.NewEncoder(w).Encode(map[string]any{
+				"id_token":      "new-id-token",
+				"access_token":  "new-access-token",
+				"refresh_token": "", // no new refresh token
+			})
+		} else {
+			// Second call: API key exchange (form body)
+			assert.Equal(t, "application/x-www-form-urlencoded", r.Header.Get("Content-Type"))
+			_ = json.NewEncoder(w).Encode(map[string]any{
+				"access_token": "new-api-key",
+			})
+		}
+	}))
+	defer server.Close()
+
+	overrideTokenEndpoint(t, server.URL)
+
+	token, err := RefreshAccessToken(t.Context(), "test-refresh")
+	require.NoError(t, err)
+	assert.Equal(t, "new-api-key", token.AccessToken)
+	assert.Equal(t, "test-refresh", token.RefreshToken) // preserved when empty in response
+	assert.Equal(t, "new-id-token", token.IDToken)
+	assert.Equal(t, "Bearer", token.TokenType)
+	assert.False(t, token.ExpiresAt.IsZero())
+}
+
+// TestRefreshAccessToken_NewRefreshToken checks that a refresh token returned
+// by the endpoint replaces the one that was passed in.
+func TestRefreshAccessToken_NewRefreshToken(t *testing.T) {
+	callCount := 0
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		callCount++
+		w.Header().Set("Content-Type", "application/json")
+
+		if callCount == 1 {
+			_ = json.NewEncoder(w).Encode(map[string]any{
+				"id_token":      "new-id-token",
+				"access_token":  "new-access-token",
+				"refresh_token": "new-refresh-token",
+			})
+		} else {
+			_ = json.NewEncoder(w).Encode(map[string]any{
+				"access_token": "new-api-key",
+			})
+		}
+	}))
+	defer server.Close()
+
+	overrideTokenEndpoint(t, server.URL)
+
+	token, err := RefreshAccessToken(t.Context(), "old-refresh")
+	require.NoError(t, err)
+	assert.Equal(t, "new-refresh-token", token.RefreshToken)
+}
+
+// TestRefreshAccessToken_ServerError checks that a non-200 response from the
+// endpoint surfaces as a "token refresh failed" error.
+func TestRefreshAccessToken_ServerError(t *testing.T) {
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		http.Error(w, "bad request", http.StatusBadRequest)
+	}))
+	defer server.Close()
+
+	overrideTokenEndpoint(t, server.URL)
+
+	_, err := RefreshAccessToken(t.Context(), "test-refresh")
+	require.Error(t, err)
+	assert.Contains(t, err.Error(), "token refresh failed")
+}
+
+// TestProvider_Get_NonMatchingVar checks that Get ignores variable names
+// other than TokenEnvVar.
+func TestProvider_Get_NonMatchingVar(t *testing.T) {
+	t.Parallel()
+
+	p := NewProvider()
+	val, ok := p.Get(t.Context(), "OTHER_VAR")
+	assert.Empty(t, val)
+	assert.False(t, ok)
+}
+
+// TestProvider_Get_NoToken checks that Get reports no value when nothing is
+// stored on disk.
+func TestProvider_Get_NoToken(t *testing.T) {
+	overrideTokenPath(t)
+
+	p := NewProvider()
+	val, ok := p.Get(t.Context(), TokenEnvVar)
+	assert.Empty(t, val)
+	assert.False(t, ok)
+}
+
+// TestProvider_Get_ValidToken checks that Get returns the stored access token
+// while it is still valid.
+func TestProvider_Get_ValidToken(t *testing.T) {
+	overrideTokenPath(t)
+
+	testToken := &Token{
+		AccessToken:  "valid-token",
+		RefreshToken: "refresh",
+		TokenType:    "Bearer",
+		ExpiresAt:    time.Now().Add(1 * time.Hour),
+	}
+	require.NoError(t, SaveToken(testToken))
+
+	p := NewProvider()
+	val, ok := p.Get(t.Context(), TokenEnvVar)
+	assert.True(t, ok)
+	assert.Equal(t, "valid-token", val)
+}
+
+// TestProvider_Get_ExpiredTokenRefreshes checks that Get refreshes an expired
+// token and returns the newly issued API key.
+func TestProvider_Get_ExpiredTokenRefreshes(t *testing.T) {
+	overrideTokenPath(t)
+
+	// Set up a mock token endpoint that handles refresh + API key exchange
+	callCount := 0
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		callCount++
+		w.Header().Set("Content-Type", "application/json")
+
+		if callCount == 1 {
+			_ = json.NewEncoder(w).Encode(map[string]any{
+				"id_token":     "refreshed-id",
+				"access_token": "refreshed-access",
+			})
+		} else {
+			_ = json.NewEncoder(w).Encode(map[string]any{
+				"access_token": "refreshed-api-key",
+			})
+		}
+	}))
+	defer server.Close()
+
+	overrideTokenEndpoint(t, server.URL)
+
+	// Save an expired token with a refresh token
+	expiredToken := &Token{
+		AccessToken:  "expired-token",
+		RefreshToken: "refresh-token",
+		TokenType:    "Bearer",
+		ExpiresAt:    time.Now().Add(-10 * time.Minute),
+	}
+	require.NoError(t, SaveToken(expiredToken))
+
+	p := NewProvider()
+	val, ok := p.Get(t.Context(), TokenEnvVar)
+	assert.True(t, ok)
+	assert.Equal(t, "refreshed-api-key", val)
+}
+
+// TestProvider_Get_ExpiredTokenNoRefresh checks that Get reports no value for
+// an expired token that has no refresh token.
+func TestProvider_Get_ExpiredTokenNoRefresh(t *testing.T) {
+	overrideTokenPath(t)
+
+	// Save an expired token without a refresh token
+	expiredToken := &Token{
+		AccessToken: "expired-token",
+		TokenType:   "Bearer",
+		ExpiresAt:   time.Now().Add(-10 * time.Minute),
+	}
+	require.NoError(t, SaveToken(expiredToken))
+
+	p := NewProvider()
+	val, ok := p.Get(t.Context(), TokenEnvVar)
+	assert.False(t, ok)
+	assert.Empty(t, val)
+}
+
+// TestProvider_GetAccessToken_NotLoggedIn checks the "not logged in" error
+// returned when no token is stored.
+func TestProvider_GetAccessToken_NotLoggedIn(t *testing.T) {
+	overrideTokenPath(t)
+
+	p := NewProvider()
+	_, err := p.GetAccessToken(t.Context())
+	require.Error(t, err)
+	assert.Contains(t, err.Error(), "not logged in")
+}
+
+// TestProvider_GetAccessToken_ValidToken checks that GetAccessToken returns
+// the stored access token while it is still valid.
+func TestProvider_GetAccessToken_ValidToken(t *testing.T) {
+	overrideTokenPath(t)
+
+	testToken := &Token{
+		AccessToken: "my-access-token",
+		TokenType:   "Bearer",
+		ExpiresAt:   time.Now().Add(1 * time.Hour),
+	}
+	require.NoError(t, SaveToken(testToken))
+
+	p := NewProvider()
+	val, err := p.GetAccessToken(t.Context())
+	require.NoError(t, err)
+	assert.Equal(t, "my-access-token", val)
+}
+
+// TestBuildAuthURL checks that the authorization URL contains the endpoint
+// and each expected query parameter.
+func TestBuildAuthURL(t *testing.T) {
+	t.Parallel()
+
+	authURL := buildAuthURL("http://localhost:1455/auth/callback", "test-state", "test-challenge")
+	assert.Contains(t, authURL, authorizationEndpoint)
+	assert.Contains(t, authURL, "response_type=code")
+	assert.Contains(t, authURL, "client_id="+clientID)
+	assert.Contains(t, authURL, "redirect_uri=http")
+	assert.Contains(t, authURL, "state=test-state")
+	assert.Contains(t, authURL, "code_challenge=test-challenge")
+	assert.Contains(t, authURL, "code_challenge_method=S256")
+}
+
+// TestGenerateState checks the length of the generated state and that two
+// consecutive values differ.
+func TestGenerateState(t *testing.T) {
+	t.Parallel()
+
+	s1, err := generateState()
+	require.NoError(t, err)
+	assert.Len(t, s1, 32) // 16 bytes hex-encoded
+
+	s2, err := generateState()
+	require.NoError(t, err)
+	assert.NotEqual(t, s1, s2) // should be random
+}
+
+// TestExchangeCode checks the form fields sent by exchangeCode and the Token
+// built from the endpoint's response.
+func TestExchangeCode(t *testing.T) {
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		assert.Equal(t, http.MethodPost, r.Method)
+		if err := r.ParseForm(); err != nil {
+			http.Error(w, "bad form", http.StatusBadRequest)
+			return
+		}
+		assert.Equal(t, "authorization_code", r.Form.Get("grant_type"))
+		assert.Equal(t, "test-code", r.Form.Get("code"))
+		assert.Equal(t, "test-verifier", r.Form.Get("code_verifier"))
+		assert.Equal(t, clientID, r.Form.Get("client_id"))
+
+		w.Header().Set("Content-Type", "application/json")
+		_ = json.NewEncoder(w).Encode(map[string]any{
+			"id_token":      "exchanged-id-token",
+			"access_token":  "exchanged-access",
+			"refresh_token": "exchanged-refresh",
+		})
+	}))
+	defer server.Close()
+
+	overrideTokenEndpoint(t, server.URL)
+
+	token, err := exchangeCode(t.Context(), "test-code", "test-verifier", "http://localhost:1455/auth/callback")
+	require.NoError(t, err)
+	assert.Equal(t, "exchanged-access", token.AccessToken)
+	assert.Equal(t, "exchanged-id-token", token.IDToken)
+	assert.Equal(t, "exchanged-refresh", token.RefreshToken)
+	assert.False(t, token.ExpiresAt.IsZero())
+}
+
+// TestExchangeForAPIKey checks the form fields sent by exchangeForAPIKey and
+// that the endpoint's access_token is returned as the API key.
+func TestExchangeForAPIKey(t *testing.T) {
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		assert.Equal(t, http.MethodPost, r.Method)
+		if err := r.ParseForm(); err != nil {
+			http.Error(w, "bad form", http.StatusBadRequest)
+			return
+		}
+		assert.Equal(t, "urn:ietf:params:oauth:grant-type:token-exchange", r.Form.Get("grant_type"))
+		assert.Equal(t, clientID, r.Form.Get("client_id"))
+		assert.Equal(t, "openai-api-key", r.Form.Get("requested_token"))
+		assert.Equal(t, "my-id-token", r.Form.Get("subject_token"))
+		assert.Equal(t, "urn:ietf:params:oauth:token-type:id_token", r.Form.Get("subject_token_type"))
+
+		w.Header().Set("Content-Type", "application/json")
+		_ = json.NewEncoder(w).Encode(map[string]any{
+			"access_token": "sk-api-key-12345",
+		})
+	}))
+	defer server.Close()
+
+	overrideTokenEndpoint(t, server.URL)
+
+	apiKey, err := exchangeForAPIKey(t.Context(), "my-id-token")
+	require.NoError(t, err)
+	assert.Equal(t, "sk-api-key-12345", apiKey)
+}
+
+// TestExchangeForAPIKey_ServerError checks that a non-200 response surfaces
+// as an "API key exchange failed" error.
+func TestExchangeForAPIKey_ServerError(t *testing.T) {
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		http.Error(w, "forbidden", http.StatusForbidden)
+	}))
+	defer server.Close()
+
+	overrideTokenEndpoint(t, server.URL)
+
+	_, err := exchangeForAPIKey(t.Context(), "bad-token")
+	require.Error(t, err)
+	assert.Contains(t, err.Error(), "API key exchange failed")
+}
+
+// TestProvider_ContextCancellation checks that Get with a non-matching name
+// still returns ("", false) when the context is already cancelled.
+func TestProvider_ContextCancellation(t *testing.T) {
+	t.Parallel()
+
+	p := NewProvider()
+
+	// Non-matching key should return quickly regardless
+	ctx, cancel := context.WithCancel(t.Context())
+	cancel()
+
+	val, ok := p.Get(ctx, "OTHER_VAR")
+	assert.Empty(t, val)
+	assert.False(t, ok)
+}
diff --git a/pkg/chatgpt/provider.go b/pkg/chatgpt/provider.go
new file mode 100644
index 000000000..0238d4562
--- /dev/null
+++ b/pkg/chatgpt/provider.go
@@ -0,0 +1,87 @@
+package chatgpt
+
+import (
+	"context"
+	"fmt"
+	"log/slog"
+	"sync"
+)
+
+const (
+	// TokenEnvVar is the virtual environment variable name used for ChatGPT auth.
+	// This is used by the provider system to check/resolve the ChatGPT token.
+	TokenEnvVar = "CHATGPT_ACCESS_TOKEN"
+)
+
+// Provider implements environment.Provider for ChatGPT tokens.
+// It loads the stored token from disk, refreshes it if expired,
+// and returns the access token when TokenEnvVar is queried.
+type Provider struct {
+	mu    sync.Mutex // guards token
+	token *Token     // cached token; nil until first loaded from disk
+}
+
+// NewProvider creates a new ChatGPT environment provider.
+func NewProvider() *Provider {
+	return &Provider{}
+}
+
+// Get retrieves the ChatGPT access token when the requested variable
+// matches TokenEnvVar. For all other variables, it returns ("", false).
+// Failures to resolve the token are logged at debug level and also
+// reported as ("", false).
+func (p *Provider) Get(ctx context.Context, name string) (string, bool) {
+	if name != TokenEnvVar {
+		return "", false
+	}
+
+	token, err := p.resolveToken(ctx)
+	if err != nil {
+		slog.Debug("ChatGPT token not available", "error", err)
+		return "", false
+	}
+
+	return token, true
+}
+
+// GetAccessToken returns the current access token, refreshing if needed.
+// Unlike Get, this returns an error on failure for direct use by the provider.
+func (p *Provider) GetAccessToken(ctx context.Context) (string, error) {
+	return p.resolveToken(ctx)
+}
+
+// resolveToken loads the token from disk (if not cached), refreshes it if
+// expired, persists the refreshed token, and returns the access token.
+//
+// The provider mutex is held for the whole call, including the refresh.
+func (p *Provider) resolveToken(ctx context.Context) (string, error) {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+
+	// First use: populate the cache from disk.
+	if p.token == nil {
+		stored, err := LoadToken()
+		switch {
+		case err != nil:
+			return "", fmt.Errorf("failed to load ChatGPT token: %w", err)
+		case stored == nil:
+			return "", fmt.Errorf("not logged in to ChatGPT - run 'cagent login chatgpt' first")
+		}
+		p.token = stored
+	}
+
+	// Fast path: the cached token is still usable.
+	if !p.token.IsExpired() {
+		return p.token.AccessToken, nil
+	}
+
+	// Expired: a refresh token is required to continue.
+	if p.token.RefreshToken == "" {
+		return "", fmt.Errorf("ChatGPT token expired - run 'cagent login chatgpt' to re-authenticate")
+	}
+
+	refreshed, err := RefreshAccessToken(ctx, p.token.RefreshToken)
+	if err != nil {
+		return "", fmt.Errorf("failed to refresh ChatGPT token: %w", err)
+	}
+	p.token = refreshed
+
+	// Persisting is best effort: the refreshed token is already cached in
+	// memory, so a write failure is only logged.
+	if saveErr := SaveToken(refreshed); saveErr != nil {
+		slog.Warn("Failed to save refreshed ChatGPT token", "error", saveErr)
+	}
+
+	return refreshed.AccessToken, nil
+}
diff --git a/pkg/chatgpt/token_store.go b/pkg/chatgpt/token_store.go
new file mode 100644
index 000000000..ff9cef673
--- /dev/null
+++ b/pkg/chatgpt/token_store.go
@@ -0,0 +1,70 @@
+package chatgpt
+
+import (
+	"encoding/json"
+	"fmt"
+	"os"
+	"path/filepath"
+
+	"github.com/docker/cagent/pkg/paths"
+)
+
+// tokenFilePathFunc resolves the token file location. It is a variable so
+// tests can point the store at a temporary directory.
+var tokenFilePathFunc = defaultTokenFilePath
+
+// defaultTokenFilePath returns chatgpt_token.json inside the config directory.
+func defaultTokenFilePath() string {
+	configDir := paths.GetConfigDir()
+	return filepath.Join(configDir, "chatgpt_token.json")
+}
+
+// tokenFilePath returns the token file location via tokenFilePathFunc.
+func tokenFilePath() string {
+	return tokenFilePathFunc()
+}
+
+// LoadToken loads a stored ChatGPT token from disk.
+// Returns (nil, nil) if no token is stored; read and parse failures are
+// returned as errors.
+func LoadToken() (*Token, error) {
+	data, err := os.ReadFile(tokenFilePath())
+	if err != nil {
+		if os.IsNotExist(err) {
+			// No file means "not logged in", which is not an error.
+			return nil, nil
+		}
+		return nil, fmt.Errorf("failed to read token file: %w", err)
+	}
+
+	var token Token
+	if err := json.Unmarshal(data, &token); err != nil {
+		return nil, fmt.Errorf("failed to parse token file: %w", err)
+	}
+
+	return &token, nil
+}
+
+// SaveToken saves a ChatGPT token to disk.
+//
+// The token is written to a temporary file in the same directory and then
+// renamed over the destination, so an interrupted write leaves the old file
+// in place instead of a truncated one. os.CreateTemp creates the file with
+// mode 0600, so the result is owner-only even if an earlier file at the
+// destination had wider permissions. A nil token is rejected.
+func SaveToken(token *Token) error {
+	if token == nil {
+		return fmt.Errorf("failed to marshal token: token is nil")
+	}
+
+	path := tokenFilePath()
+	dir := filepath.Dir(path)
+	if err := os.MkdirAll(dir, 0o700); err != nil {
+		return fmt.Errorf("failed to create config directory: %w", err)
+	}
+
+	data, err := json.MarshalIndent(token, "", "  ")
+	if err != nil {
+		return fmt.Errorf("failed to marshal token: %w", err)
+	}
+
+	tmp, err := os.CreateTemp(dir, filepath.Base(path)+".*.tmp")
+	if err != nil {
+		return fmt.Errorf("failed to write token file: %w", err)
+	}
+	tmpName := tmp.Name()
+	// Clean up the temporary file on every failure path. After a successful
+	// rename it no longer exists, so the removal error is ignored.
+	defer func() { _ = os.Remove(tmpName) }()
+
+	if _, err := tmp.Write(data); err != nil {
+		_ = tmp.Close() // the write error is the one worth reporting
+		return fmt.Errorf("failed to write token file: %w", err)
+	}
+	if err := tmp.Close(); err != nil {
+		return fmt.Errorf("failed to write token file: %w", err)
+	}
+	if err := os.Rename(tmpName, path); err != nil {
+		return fmt.Errorf("failed to write token file: %w", err)
+	}
+
+	return nil
+}
+
+// RemoveToken removes the stored ChatGPT token from disk.
+// A missing token file is not an error.
+func RemoveToken() error {
+	if err := os.Remove(tokenFilePath()); err != nil && !os.IsNotExist(err) {
+		return fmt.Errorf("failed to remove token file: %w", err)
+	}
+	return nil
+}
diff --git a/pkg/environment/default.go b/pkg/environment/default.go
index 093cfa10f..e65e2456c 100644
--- a/pkg/environment/default.go
+++ b/pkg/environment/default.go
@@ -1,6 +1,7 @@
 package environment
 
 import (
+	"github.com/docker/cagent/pkg/chatgpt"
 	"github.com/docker/cagent/pkg/paths"
 	"github.com/docker/cagent/pkg/userconfig"
 )
@@ -35,8 +36,9 @@ func NewDefaultProvider() Provider {
 		providers = append(providers, NewCredentialHelperProvider(cfg.CredentialHelper.Command, cfg.CredentialHelper.Args...))
 	}
 
-	// Docker Desktop provider comes after credential helper
-	providers = append(providers, NewDockerDesktopProvider())
+	// Docker Desktop provider comes after credential helper.
+	// ChatGPT OAuth token provider for chatgpt/* models.
+	providers = append(providers, NewDockerDesktopProvider(), chatgpt.NewProvider())
 
 	// Append pass provider at the end if available
 	if passProvider, err := NewPassProvider(); err == nil {
diff --git a/pkg/model/provider/provider.go b/pkg/model/provider/provider.go
index d85c593ca..44b47815b 100644
--- a/pkg/model/provider/provider.go
+++ b/pkg/model/provider/provider.go
@@ -9,6 +9,7 @@ import (
 	"strings"
 
 	"github.com/docker/cagent/pkg/chat"
+	"github.com/docker/cagent/pkg/chatgpt"
 	"github.com/docker/cagent/pkg/config/latest"
 	"github.com/docker/cagent/pkg/environment"
 	"github.com/docker/cagent/pkg/model/provider/anthropic"
@@ -128,6 +129,10 @@ var Aliases = map[string]Alias{
 		BaseURL:     "https://api.minimax.io/v1",
 		TokenEnvVar: "MINIMAX_API_KEY",
 	},
+	"chatgpt": {
+		APIType:     "openai_responses",
+		TokenEnvVar: chatgpt.TokenEnvVar,
+	},
 }
 
 // Provider defines the interface for model providers