diff --git a/internal/googleapi/client.go b/internal/googleapi/client.go index 40a77d73..f8d4e4f0 100644 --- a/internal/googleapi/client.go +++ b/internal/googleapi/client.go @@ -7,6 +7,8 @@ import ( "fmt" "log/slog" "net/http" + "strings" + "sync" "time" "github.com/99designs/keyring" @@ -27,6 +29,58 @@ var ( openSecretsStore = secrets.OpenDefault ) +type persistingTokenSource struct { + base oauth2.TokenSource + store secrets.Store + client string + email string + + mu sync.Mutex + tok secrets.Token +} + +func newPersistingTokenSource(base oauth2.TokenSource, store secrets.Store, client string, email string, tok secrets.Token) oauth2.TokenSource { + return &persistingTokenSource{ + base: base, + store: store, + client: client, + email: email, + tok: tok, + } +} + +func (p *persistingTokenSource) Token() (*oauth2.Token, error) { + t, err := p.base.Token() + if err != nil { + return nil, err + } + + refreshToken := strings.TrimSpace(t.RefreshToken) + if refreshToken == "" { + return t, nil + } + + p.mu.Lock() + defer p.mu.Unlock() + + if refreshToken == p.tok.RefreshToken { + return t, nil + } + + updated := p.tok + updated.RefreshToken = refreshToken + + if err := p.store.SetToken(p.client, p.email, updated); err != nil { + slog.Warn("persist rotated refresh token failed", "email", p.email, "client", p.client, "err", err) + return t, nil + } + + p.tok = updated + slog.Debug("persisted rotated refresh token", "email", p.email, "client", p.client) + + return t, nil +} + func tokenSourceForAccount(ctx context.Context, service googleauth.Service, email string) (oauth2.TokenSource, error) { client, err := authclient.ResolveClient(ctx, email) if err != nil { @@ -80,7 +134,8 @@ func tokenSourceForAccountScopes(ctx context.Context, serviceLabel string, email // Ensure refresh-token exchanges don't hang forever. ctx = context.WithValue(ctx, oauth2.HTTPClient, &http.Client{Timeout: defaultHTTPTimeout}) - return cfg.TokenSource(ctx, &oauth2.Token{RefreshToken: tok.RefreshToken}), nil + baseSource := cfg.TokenSource(ctx, &oauth2.Token{RefreshToken: tok.RefreshToken}) + return newPersistingTokenSource(baseSource, store, client, email, tok), nil } func optionsForAccount(ctx context.Context, service googleauth.Service, email string) ([]option.ClientOption, error) { diff --git a/internal/googleapi/client_more_test.go b/internal/googleapi/client_more_test.go index d520b86d..f23b17f0 100644 --- a/internal/googleapi/client_more_test.go +++ b/internal/googleapi/client_more_test.go @@ -7,8 +7,10 @@ import ( "net/http" "os" "path/filepath" + "reflect" "strings" "testing" + "time" "github.com/99designs/keyring" "golang.org/x/oauth2" @@ -29,14 +31,32 @@ type stubStore struct { lastEmail string tok secrets.Token err error + + setClient string + setEmail string + lastSet secrets.Token + setCalls int + setErr error } -func (s *stubStore) Keys() ([]string, error) { return nil, nil } -func (s *stubStore) SetToken(string, string, secrets.Token) error { return nil } -func (s *stubStore) DeleteToken(string, string) error { return nil } -func (s *stubStore) ListTokens() ([]secrets.Token, error) { return nil, nil } -func (s *stubStore) GetDefaultAccount(string) (string, error) { return "", nil } -func (s *stubStore) SetDefaultAccount(string, string) error { return nil } +func (s *stubStore) Keys() ([]string, error) { return nil, nil } +func (s *stubStore) SetToken(client string, email string, tok secrets.Token) error { + s.setClient = client + s.setEmail = email + s.lastSet = tok + s.setCalls++ + + if s.setErr != nil { + return s.setErr + } + + s.tok = tok + return nil +} +func (s *stubStore) DeleteToken(string, string) error { return nil } +func (s *stubStore) ListTokens() ([]secrets.Token, error) { return nil, nil } +func (s *stubStore) GetDefaultAccount(string) (string, error) { return "", nil } +func (s *stubStore) SetDefaultAccount(string, string) error { return nil } func (s *stubStore) GetToken(client string, email string) (secrets.Token, error) { s.lastClient = client s.lastEmail = email @@ -124,6 +144,92 @@ func TestTokenSourceForAccountScopes_HappyPath(t *testing.T) { } } +func TestPersistingTokenSource_PersistsRotatedRefreshToken(t *testing.T) { + stored := secrets.Token{ + Client: config.DefaultClientName, + Email: "a@b.com", + RefreshToken: "old-refresh-token", + Services: []string{"gmail"}, + Scopes: []string{"s1"}, + CreatedAt: time.Unix(1735689600, 0).UTC(), + } + + store := &stubStore{tok: stored} + base := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: "access", RefreshToken: "new-refresh-token"}) + ts := newPersistingTokenSource(base, store, config.DefaultClientName, "A@B.COM", stored) + + if _, err := ts.Token(); err != nil { + t.Fatalf("Token: %v", err) + } + + if store.setCalls != 1 { + t.Fatalf("expected 1 SetToken call, got %d", store.setCalls) + } + + if store.setClient != config.DefaultClientName { + t.Fatalf("unexpected client: %q", store.setClient) + } + + if store.setEmail != "A@B.COM" { + t.Fatalf("unexpected email: %q", store.setEmail) + } + + if store.lastSet.RefreshToken != "new-refresh-token" { + t.Fatalf("expected rotated refresh token to persist, got %q", store.lastSet.RefreshToken) + } + + if !reflect.DeepEqual(store.lastSet.Services, stored.Services) { + t.Fatalf("services changed unexpectedly: %#v", store.lastSet.Services) + } + + if !reflect.DeepEqual(store.lastSet.Scopes, stored.Scopes) { + t.Fatalf("scopes changed unexpectedly: %#v", store.lastSet.Scopes) + } + + if !store.lastSet.CreatedAt.Equal(stored.CreatedAt) { + t.Fatalf("createdAt changed unexpectedly: %v", store.lastSet.CreatedAt) + } +} + +func TestPersistingTokenSource_NoRotationDoesNotPersist(t *testing.T) { + stored := secrets.Token{Email: "a@b.com", RefreshToken: "same-token"} + store := &stubStore{tok: stored} + base := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: "access", RefreshToken: "same-token"}) + ts := newPersistingTokenSource(base, store, config.DefaultClientName, "a@b.com", stored) + + if _, err := ts.Token(); err != nil { + t.Fatalf("Token: %v", err) + } + + if store.setCalls != 0 { + t.Fatalf("expected no SetToken calls, got %d", store.setCalls) + } +} + +func TestPersistingTokenSource_PersistFailureIsNonFatal(t *testing.T) { + stored := secrets.Token{Email: "a@b.com", RefreshToken: "old-token"} + store := &stubStore{tok: stored, setErr: errBoom} + base := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: "access", RefreshToken: "new-token"}) + ts := newPersistingTokenSource(base, store, config.DefaultClientName, "a@b.com", stored) + + tok, err := ts.Token() + if err != nil { + t.Fatalf("Token: %v", err) + } + + if tok.AccessToken != "access" { + t.Fatalf("unexpected access token: %q", tok.AccessToken) + } + + if store.setCalls != 1 { + t.Fatalf("expected 1 SetToken attempt, got %d", store.setCalls) + } + + if store.tok.RefreshToken != "old-token" { + t.Fatalf("store should keep old token on persist error, got %q", store.tok.RefreshToken) + } +} + func TestTokenSourceForAccount_ReadCredsError(t *testing.T) { origRead := readClientCredentials