Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 115 additions & 14 deletions internal/auth/gemini/gemini_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package gemini

import (
"bufio"
"context"
"encoding/json"
"errors"
Expand All @@ -13,6 +14,8 @@ import (
"net"
"net/http"
"net/url"
"os"
"strings"
"time"

"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
Expand Down Expand Up @@ -200,6 +203,7 @@ func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Conf
// It starts a local HTTP server to listen for the callback from Google's auth server,
// opens the user's browser to the authorization URL, and exchanges the received
// authorization code for an access token.
// If the automatic callback fails (e.g., in Docker), it allows manual input of the callback URL.
//
// Parameters:
// - ctx: The context for the HTTP client
Expand All @@ -211,14 +215,17 @@ func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Conf
// - error: An error if the token acquisition fails, nil otherwise
func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, noBrowser ...bool) (*oauth2.Token, error) {
// Use a channel to pass the authorization code from the HTTP handler to the main function.
codeChan := make(chan string)
errChan := make(chan error)
codeChan := make(chan string, 1)
errChan := make(chan error, 1)
manualInputChan := make(chan string, 1)
contextDone := ctx.Done()

// Create a new HTTP server with its own multiplexer.
mux := http.NewServeMux()
server := &http.Server{Addr: ":8085", Handler: mux}
config.RedirectURL = "http://localhost:8085/oauth2callback"

// HTTP callback handler
mux.HandleFunc("/oauth2callback", func(w http.ResponseWriter, r *http.Request) {
if err := r.URL.Query().Get("error"); err != "" {
_, _ = fmt.Fprintf(w, "Authentication failed: %s", err)
Expand All @@ -238,51 +245,92 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
// Start the server in a goroutine.
go func() {
if err := server.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("ListenAndServe(): %v", err)
log.Debugf("ListenAndServe(): %v", err)
}
}()

// Open the authorization URL in the user's browser.
authURL := config.AuthCodeURL("state-token", oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent"))

browserAvailable := true
if len(noBrowser) == 1 && !noBrowser[0] {
fmt.Println("Opening browser for authentication...")

// Check if browser is available
if !browser.IsAvailable() {
log.Warn("No browser available on this system")
util.PrintSSHTunnelInstructions(8085)
fmt.Printf("Please manually open this URL in your browser:\n\n%s\n", authURL)
browserAvailable = false
} else {
if err := browser.OpenURL(authURL); err != nil {
authErr := codex.NewAuthenticationError(codex.ErrBrowserOpenFailed, err)
log.Warn(codex.GetUserFriendlyMessage(authErr))
util.PrintSSHTunnelInstructions(8085)
fmt.Printf("Please manually open this URL in your browser:\n\n%s\n", authURL)
browserAvailable = false

// Log platform info for debugging
platformInfo := browser.GetPlatformInfo()
log.Debugf("Browser platform info: %+v", platformInfo)
} else {
log.Debug("Browser opened successfully")
}
}
} else {
util.PrintSSHTunnelInstructions(8085)
fmt.Printf("Please open this URL in your browser:\n\n%s\n", authURL)
browserAvailable = false
}

fmt.Println("Waiting for authentication callback...")
// Show instructions
if !browserAvailable {
util.PrintSSHTunnelInstructions(8085)
}
fmt.Printf("Please manually open this URL in your browser:\n\n%s\n", authURL)

// Show manual input instructions
fmt.Println("\n" + strings.Repeat("=", 80))
fmt.Println("MANUAL CALLBACK INPUT (for Docker or remote environments):")
fmt.Println(strings.Repeat("=", 80))
fmt.Println("After completing authentication in the browser, you will be redirected to:")
fmt.Println(" http://localhost:8085/oauth2callback?code=... ")
fmt.Println("\nYou can either:")
fmt.Println(" 1. Let the callback reach this server automatically, OR")
fmt.Println(" 2. Manually paste the full URL or authorization code below:")
fmt.Println("(Waiting for input, or automatic callback...)\n")

// Start goroutine to listen for manual input
go func() {
reader := bufio.NewReader(os.Stdin)
input, err := reader.ReadString('\n')
if err != nil && err != io.EOF {
log.Debugf("Failed to read manual input: %v", err)
return
}
input = strings.TrimSpace(input)
if input != "" {
manualInputChan <- input
}
}()

// Wait for the authorization code or an error.
var authCode string
select {
case code := <-codeChan:
// Automatic callback succeeded
fmt.Println("✓ Automatic callback received")
authCode = code
case input := <-manualInputChan:
// Manual input provided
fmt.Println("✓ Manual input received, processing...")
code := extractCodeFromInput(input)
if code == "" {
_ = server.Close()
return nil, fmt.Errorf("could not extract authorization code from input")
}
authCode = code
case err := <-errChan:
_ = server.Close()
return nil, err
case <-time.After(5 * time.Minute): // Timeout
return nil, fmt.Errorf("oauth flow timed out")
case <-time.After(5 * time.Minute):
_ = server.Close()
return nil, fmt.Errorf("oauth flow timed out after 5 minutes")
case <-contextDone:
_ = server.Close()
return nil, fmt.Errorf("context cancelled")
}

// Shutdown the server.
Expand All @@ -299,3 +347,56 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
fmt.Println("Authentication successful.")
return token, nil
}

// extractCodeFromInput extracts the authorization code from various input formats
func extractCodeFromInput(input string) string {
input = strings.TrimSpace(input)

// If input looks like a full URL
if strings.HasPrefix(input, "http://") || strings.HasPrefix(input, "https://") {
parsedURL, err := url.Parse(input)
if err != nil {
log.Warnf("Failed to parse URL: %v", err)
return ""
}

// Get code from query parameters
code := parsedURL.Query().Get("code")
if code != "" {
return code
}

// Check for error parameter
if errParam := parsedURL.Query().Get("error"); errParam != "" {
log.Errorf("Authentication error from callback: %s", errParam)
return ""
}
}

// If input looks like a direct code (no spaces or special URL characters)
if !strings.Contains(input, " ") && !strings.Contains(input, "?") && !strings.Contains(input, "/") {
return input
}

// Try to parse as query string (code=... format)
if strings.HasPrefix(input, "code=") {
parts := strings.Split(input, "&")
for _, part := range parts {
if strings.HasPrefix(part, "code=") {
return strings.TrimPrefix(part, "code=")
}
}
}

// Try to extract code from query string even if not prefixed with "code="
if strings.Contains(input, "code=") {
startIdx := strings.Index(input, "code=")
endIdx := strings.IndexAny(input[startIdx+5:], "&")
if endIdx == -1 {
return strings.TrimSpace(input[startIdx+5:])
}
return strings.TrimSpace(input[startIdx+5 : startIdx+5+endIdx])
}

return ""
}