Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 56 additions & 1 deletion internal/googleapi/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"fmt"
"log/slog"
"net/http"
"strings"
"sync"
"time"

"github.com/99designs/keyring"
Expand All @@ -27,6 +29,58 @@ var (
openSecretsStore = secrets.OpenDefault
)

type persistingTokenSource struct {
base oauth2.TokenSource
store secrets.Store
client string
email string

mu sync.Mutex
tok secrets.Token
}

func newPersistingTokenSource(base oauth2.TokenSource, store secrets.Store, client string, email string, tok secrets.Token) oauth2.TokenSource {
return &persistingTokenSource{
base: base,
store: store,
client: client,
email: email,
tok: tok,
}
}

func (p *persistingTokenSource) Token() (*oauth2.Token, error) {
t, err := p.base.Token()
if err != nil {
return nil, err
}

refreshToken := strings.TrimSpace(t.RefreshToken)
if refreshToken == "" {
return t, nil
}

p.mu.Lock()
defer p.mu.Unlock()

if refreshToken == p.tok.RefreshToken {
return t, nil
}

updated := p.tok
updated.RefreshToken = refreshToken

if err := p.store.SetToken(p.client, p.email, updated); err != nil {
slog.Warn("persist rotated refresh token failed", "email", p.email, "client", p.client, "err", err)
return t, nil
}

p.tok = updated
slog.Debug("persisted rotated refresh token", "email", p.email, "client", p.client)

return t, nil
}

func tokenSourceForAccount(ctx context.Context, service googleauth.Service, email string) (oauth2.TokenSource, error) {
client, err := authclient.ResolveClient(ctx, email)
if err != nil {
Expand Down Expand Up @@ -80,7 +134,8 @@ func tokenSourceForAccountScopes(ctx context.Context, serviceLabel string, email
// Ensure refresh-token exchanges don't hang forever.
ctx = context.WithValue(ctx, oauth2.HTTPClient, &http.Client{Timeout: defaultHTTPTimeout})

return cfg.TokenSource(ctx, &oauth2.Token{RefreshToken: tok.RefreshToken}), nil
baseSource := cfg.TokenSource(ctx, &oauth2.Token{RefreshToken: tok.RefreshToken})
return newPersistingTokenSource(baseSource, store, client, email, tok), nil
}

func optionsForAccount(ctx context.Context, service googleauth.Service, email string) ([]option.ClientOption, error) {
Expand Down
118 changes: 112 additions & 6 deletions internal/googleapi/client_more_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ import (
"net/http"
"os"
"path/filepath"
"reflect"
"strings"
"testing"
"time"

"github.com/99designs/keyring"
"golang.org/x/oauth2"
Expand All @@ -29,14 +31,32 @@ type stubStore struct {
lastEmail string
tok secrets.Token
err error

setClient string
setEmail string
lastSet secrets.Token
setCalls int
setErr error
}

func (s *stubStore) Keys() ([]string, error) { return nil, nil }
func (s *stubStore) SetToken(string, string, secrets.Token) error { return nil }
func (s *stubStore) DeleteToken(string, string) error { return nil }
func (s *stubStore) ListTokens() ([]secrets.Token, error) { return nil, nil }
func (s *stubStore) GetDefaultAccount(string) (string, error) { return "", nil }
func (s *stubStore) SetDefaultAccount(string, string) error { return nil }
func (s *stubStore) Keys() ([]string, error) { return nil, nil }
func (s *stubStore) SetToken(client string, email string, tok secrets.Token) error {
s.setClient = client
s.setEmail = email
s.lastSet = tok
s.setCalls++

if s.setErr != nil {
return s.setErr
}

s.tok = tok
return nil
}
func (s *stubStore) DeleteToken(string, string) error { return nil }
func (s *stubStore) ListTokens() ([]secrets.Token, error) { return nil, nil }
func (s *stubStore) GetDefaultAccount(string) (string, error) { return "", nil }
func (s *stubStore) SetDefaultAccount(string, string) error { return nil }
func (s *stubStore) GetToken(client string, email string) (secrets.Token, error) {
s.lastClient = client
s.lastEmail = email
Expand Down Expand Up @@ -124,6 +144,92 @@ func TestTokenSourceForAccountScopes_HappyPath(t *testing.T) {
}
}

func TestPersistingTokenSource_PersistsRotatedRefreshToken(t *testing.T) {
stored := secrets.Token{
Client: config.DefaultClientName,
Email: "a@b.com",
RefreshToken: "old-refresh-token",
Services: []string{"gmail"},
Scopes: []string{"s1"},
CreatedAt: time.Unix(1735689600, 0).UTC(),
}

store := &stubStore{tok: stored}
base := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: "access", RefreshToken: "new-refresh-token"})
ts := newPersistingTokenSource(base, store, config.DefaultClientName, "A@B.COM", stored)

if _, err := ts.Token(); err != nil {
t.Fatalf("Token: %v", err)
}

if store.setCalls != 1 {
t.Fatalf("expected 1 SetToken call, got %d", store.setCalls)
}

if store.setClient != config.DefaultClientName {
t.Fatalf("unexpected client: %q", store.setClient)
}

if store.setEmail != "A@B.COM" {
t.Fatalf("unexpected email: %q", store.setEmail)
}

if store.lastSet.RefreshToken != "new-refresh-token" {
t.Fatalf("expected rotated refresh token to persist, got %q", store.lastSet.RefreshToken)
}

if !reflect.DeepEqual(store.lastSet.Services, stored.Services) {
t.Fatalf("services changed unexpectedly: %#v", store.lastSet.Services)
}

if !reflect.DeepEqual(store.lastSet.Scopes, stored.Scopes) {
t.Fatalf("scopes changed unexpectedly: %#v", store.lastSet.Scopes)
}

if !store.lastSet.CreatedAt.Equal(stored.CreatedAt) {
t.Fatalf("createdAt changed unexpectedly: %v", store.lastSet.CreatedAt)
}
}

func TestPersistingTokenSource_NoRotationDoesNotPersist(t *testing.T) {
stored := secrets.Token{Email: "a@b.com", RefreshToken: "same-token"}
store := &stubStore{tok: stored}
base := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: "access", RefreshToken: "same-token"})
ts := newPersistingTokenSource(base, store, config.DefaultClientName, "a@b.com", stored)

if _, err := ts.Token(); err != nil {
t.Fatalf("Token: %v", err)
}

if store.setCalls != 0 {
t.Fatalf("expected no SetToken calls, got %d", store.setCalls)
}
}

func TestPersistingTokenSource_PersistFailureIsNonFatal(t *testing.T) {
stored := secrets.Token{Email: "a@b.com", RefreshToken: "old-token"}
store := &stubStore{tok: stored, setErr: errBoom}
base := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: "access", RefreshToken: "new-token"})
ts := newPersistingTokenSource(base, store, config.DefaultClientName, "a@b.com", stored)

tok, err := ts.Token()
if err != nil {
t.Fatalf("Token: %v", err)
}

if tok.AccessToken != "access" {
t.Fatalf("unexpected access token: %q", tok.AccessToken)
}

if store.setCalls != 1 {
t.Fatalf("expected 1 SetToken attempt, got %d", store.setCalls)
}

if store.tok.RefreshToken != "old-token" {
t.Fatalf("store should keep old token on persist error, got %q", store.tok.RefreshToken)
}
}

func TestTokenSourceForAccount_ReadCredsError(t *testing.T) {
origRead := readClientCredentials

Expand Down