From e857292512d409e5fcf8a7a033173deb20535139 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Thu, 30 Oct 2025 05:17:18 +0000 Subject: [PATCH 1/3] Add token provider infrastructure MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This introduces a flexible TokenProvider interface that allows custom authentication implementations: - TokenProvider interface with static, external function support - Token struct with expiration handling - Authenticator wrapper for integration with existing auth system - Connector functions: WithTokenProvider, WithExternalToken, WithStaticToken This foundation enables custom token management strategies without requiring changes to the core driver. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- auth/tokenprovider/authenticator.go | 44 +++ auth/tokenprovider/authenticator_test.go | 135 ++++++++ auth/tokenprovider/external.go | 56 +++ auth/tokenprovider/provider.go | 43 +++ auth/tokenprovider/provider_test.go | 423 +++++++++++++++++++++++ auth/tokenprovider/static.go | 47 +++ connector.go | 30 ++ go.mod | 1 + go.sum | 2 + 9 files changed, 781 insertions(+) create mode 100644 auth/tokenprovider/authenticator.go create mode 100644 auth/tokenprovider/authenticator_test.go create mode 100644 auth/tokenprovider/external.go create mode 100644 auth/tokenprovider/provider.go create mode 100644 auth/tokenprovider/provider_test.go create mode 100644 auth/tokenprovider/static.go diff --git a/auth/tokenprovider/authenticator.go b/auth/tokenprovider/authenticator.go new file mode 100644 index 00000000..3955a4c9 --- /dev/null +++ b/auth/tokenprovider/authenticator.go @@ -0,0 +1,44 @@ +package tokenprovider + +import ( + "context" + "fmt" + "net/http" + + "github.com/databricks/databricks-sql-go/auth" + "github.com/rs/zerolog/log" +) + +// TokenProviderAuthenticator implements auth.Authenticator using a TokenProvider +type TokenProviderAuthenticator struct { + provider TokenProvider +} + +// NewAuthenticator creates an authenticator from a token provider +func NewAuthenticator(provider TokenProvider) auth.Authenticator { + return &TokenProviderAuthenticator{ + provider: provider, + } +} + +// Authenticate implements auth.Authenticator +func (a *TokenProviderAuthenticator) Authenticate(r *http.Request) error { + ctx := r.Context() + if ctx == nil { + ctx = context.Background() + } + + token, err := a.provider.GetToken(ctx) + if err != nil { + return fmt.Errorf("token provider authenticator: failed to get token: %w", err) + } + + if token.AccessToken == "" { + return fmt.Errorf("token provider authenticator: empty access token") + } + + token.SetAuthHeader(r) + log.Debug().Msgf("token provider authenticator: authenticated using provider %s", a.provider.Name()) + + return nil +} diff --git a/auth/tokenprovider/authenticator_test.go b/auth/tokenprovider/authenticator_test.go new file mode 100644 index 00000000..a47dd6bc --- /dev/null +++ b/auth/tokenprovider/authenticator_test.go @@ -0,0 +1,135 @@ +package tokenprovider + +import ( + "context" + "errors" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTokenProviderAuthenticator(t *testing.T) { + t.Run("successful_authentication", func(t *testing.T) { + provider := NewStaticTokenProvider("test-token-123") + authenticator := NewAuthenticator(provider) + + req, _ := http.NewRequest("GET", "http://example.com", nil) + err := authenticator.Authenticate(req) + + require.NoError(t, err) + assert.Equal(t, "Bearer test-token-123", req.Header.Get("Authorization")) + }) + + t.Run("authentication_with_custom_token_type", func(t *testing.T) { + provider := NewStaticTokenProviderWithType("test-token", "MAC") + authenticator := NewAuthenticator(provider) + + req, _ := http.NewRequest("GET", "http://example.com", nil) + err := authenticator.Authenticate(req) + + require.NoError(t, err) + assert.Equal(t, "MAC test-token", req.Header.Get("Authorization")) + }) + + t.Run("authentication_error_propagation", func(t *testing.T) { + provider := &mockProvider{ + tokenFunc: func() (*Token, error) { + return nil, errors.New("provider failed") + }, + } + authenticator := NewAuthenticator(provider) + + req, _ := http.NewRequest("GET", "http://example.com", nil) + err := authenticator.Authenticate(req) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "provider failed") + assert.Empty(t, req.Header.Get("Authorization")) + }) + + t.Run("empty_token_error", func(t *testing.T) { + provider := &mockProvider{ + tokenFunc: func() (*Token, error) { + return &Token{ + AccessToken: "", + TokenType: "Bearer", + }, nil + }, + } + authenticator := NewAuthenticator(provider) + + req, _ := http.NewRequest("GET", "http://example.com", nil) + err := authenticator.Authenticate(req) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "empty access token") + assert.Empty(t, req.Header.Get("Authorization")) + }) + + t.Run("uses_request_context", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel immediately + + provider := &mockProvider{ + tokenFunc: func() (*Token, error) { + // This would normally check context cancellation + return &Token{ + AccessToken: "test-token", + TokenType: "Bearer", + }, nil + }, + } + authenticator := NewAuthenticator(provider) + + req, _ := http.NewRequestWithContext(ctx, "GET", "http://example.com", nil) + err := authenticator.Authenticate(req) + + // Even with cancelled context, this should work as our mock doesn't check it + require.NoError(t, err) + assert.Equal(t, "Bearer test-token", req.Header.Get("Authorization")) + }) + + t.Run("external_token_integration", func(t *testing.T) { + tokenFunc := func() (string, error) { + return "external-token-456", nil + } + provider := NewExternalTokenProvider(tokenFunc) + authenticator := NewAuthenticator(provider) + + req, _ := http.NewRequest("POST", "http://example.com/api", nil) + err := authenticator.Authenticate(req) + + require.NoError(t, err) + assert.Equal(t, "Bearer external-token-456", req.Header.Get("Authorization")) + }) + + t.Run("cached_provider_integration", func(t *testing.T) { + callCount := 0 + baseProvider := &mockProvider{ + tokenFunc: func() (*Token, error) { + callCount++ + return &Token{ + AccessToken: "cached-token", + TokenType: "Bearer", + }, nil + }, + name: "test", + } + + cachedProvider := NewCachedTokenProvider(baseProvider) + authenticator := NewAuthenticator(cachedProvider) + + // Multiple authentication attempts + for i := 0; i < 3; i++ { + req, _ := http.NewRequest("GET", "http://example.com", nil) + err := authenticator.Authenticate(req) + require.NoError(t, err) + assert.Equal(t, "Bearer cached-token", req.Header.Get("Authorization")) + } + + // Should only call base provider once due to caching + assert.Equal(t, 1, callCount) + }) +} diff --git a/auth/tokenprovider/external.go b/auth/tokenprovider/external.go new file mode 100644 index 00000000..0e511234 --- /dev/null +++ b/auth/tokenprovider/external.go @@ -0,0 +1,56 @@ +package tokenprovider + +import ( + "context" + "fmt" + "time" +) + +// ExternalTokenProvider provides tokens from an external source (passthrough) +type ExternalTokenProvider struct { + tokenFunc func() (string, error) + tokenType string +} + +// NewExternalTokenProvider creates a provider that gets tokens from an external function +func NewExternalTokenProvider(tokenFunc func() (string, error)) *ExternalTokenProvider { + return &ExternalTokenProvider{ + tokenFunc: tokenFunc, + tokenType: "Bearer", + } +} + +// NewExternalTokenProviderWithType creates a provider with a custom token type +func NewExternalTokenProviderWithType(tokenFunc func() (string, error), tokenType string) *ExternalTokenProvider { + return &ExternalTokenProvider{ + tokenFunc: tokenFunc, + tokenType: tokenType, + } +} + +// GetToken retrieves the token from the external source +func (p *ExternalTokenProvider) GetToken(ctx context.Context) (*Token, error) { + if p.tokenFunc == nil { + return nil, fmt.Errorf("external token provider: token function is nil") + } + + accessToken, err := p.tokenFunc() + if err != nil { + return nil, fmt.Errorf("external token provider: failed to get token: %w", err) + } + + if accessToken == "" { + return nil, fmt.Errorf("external token provider: empty token returned") + } + + return &Token{ + AccessToken: accessToken, + TokenType: p.tokenType, + ExpiresAt: time.Time{}, // External tokens don't provide expiry info + }, nil +} + +// Name returns the provider name +func (p *ExternalTokenProvider) Name() string { + return "external" +} diff --git a/auth/tokenprovider/provider.go b/auth/tokenprovider/provider.go new file mode 100644 index 00000000..3e94d6ef --- /dev/null +++ b/auth/tokenprovider/provider.go @@ -0,0 +1,43 @@ +package tokenprovider + +import ( + "context" + "net/http" + "time" +) + +// TokenProvider is the interface for providing tokens from various sources +type TokenProvider interface { + // GetToken retrieves a valid access token + GetToken(ctx context.Context) (*Token, error) + + // Name returns the provider name for logging/debugging + Name() string +} + +// Token represents an access token with metadata +type Token struct { + AccessToken string + TokenType string + ExpiresAt time.Time + RefreshToken string + Scopes []string +} + +// IsExpired checks if the token has expired +func (t *Token) IsExpired() bool { + if t.ExpiresAt.IsZero() { + return false // No expiry means token doesn't expire + } + // Consider token expired 5 minutes before actual expiry for safety + return time.Now().Add(5 * time.Minute).After(t.ExpiresAt) +} + +// SetAuthHeader sets the Authorization header on an HTTP request +func (t *Token) SetAuthHeader(r *http.Request) { + tokenType := t.TokenType + if tokenType == "" { + tokenType = "Bearer" + } + r.Header.Set("Authorization", tokenType+" "+t.AccessToken) +} diff --git a/auth/tokenprovider/provider_test.go b/auth/tokenprovider/provider_test.go new file mode 100644 index 00000000..5acb5538 --- /dev/null +++ b/auth/tokenprovider/provider_test.go @@ -0,0 +1,423 @@ +package tokenprovider + +import ( + "context" + "errors" + "net/http" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestToken_IsExpired(t *testing.T) { + tests := []struct { + name string + token *Token + expected bool + }{ + { + name: "token_without_expiry", + token: &Token{ + AccessToken: "test-token", + ExpiresAt: time.Time{}, + }, + expected: false, + }, + { + name: "token_expired", + token: &Token{ + AccessToken: "test-token", + ExpiresAt: time.Now().Add(-10 * time.Minute), + }, + expected: true, + }, + { + name: "token_not_expired", + token: &Token{ + AccessToken: "test-token", + ExpiresAt: time.Now().Add(10 * time.Minute), + }, + expected: false, + }, + { + name: "token_expires_within_5_minutes", + token: &Token{ + AccessToken: "test-token", + ExpiresAt: time.Now().Add(3 * time.Minute), + }, + expected: true, // Should be considered expired due to 5-minute buffer + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, tt.token.IsExpired()) + }) + } +} + +func TestToken_SetAuthHeader(t *testing.T) { + tests := []struct { + name string + token *Token + expectedHeader string + }{ + { + name: "bearer_token", + token: &Token{ + AccessToken: "test-access-token", + TokenType: "Bearer", + }, + expectedHeader: "Bearer test-access-token", + }, + { + name: "default_to_bearer", + token: &Token{ + AccessToken: "test-access-token", + TokenType: "", + }, + expectedHeader: "Bearer test-access-token", + }, + { + name: "custom_token_type", + token: &Token{ + AccessToken: "test-access-token", + TokenType: "CustomAuth", + }, + expectedHeader: "CustomAuth test-access-token", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req, _ := http.NewRequest("GET", "http://example.com", nil) + tt.token.SetAuthHeader(req) + assert.Equal(t, tt.expectedHeader, req.Header.Get("Authorization")) + }) + } +} + +func TestStaticTokenProvider(t *testing.T) { + t.Run("valid_token", func(t *testing.T) { + provider := NewStaticTokenProvider("static-token-123") + token, err := provider.GetToken(context.Background()) + + require.NoError(t, err) + assert.Equal(t, "static-token-123", token.AccessToken) + assert.Equal(t, "Bearer", token.TokenType) + assert.True(t, token.ExpiresAt.IsZero()) + assert.Equal(t, "static", provider.Name()) + }) + + t.Run("empty_token_error", func(t *testing.T) { + provider := NewStaticTokenProvider("") + token, err := provider.GetToken(context.Background()) + + assert.Error(t, err) + assert.Nil(t, token) + assert.Contains(t, err.Error(), "token is empty") + }) + + t.Run("custom_token_type", func(t *testing.T) { + provider := NewStaticTokenProviderWithType("static-token", "CustomAuth") + token, err := provider.GetToken(context.Background()) + + require.NoError(t, err) + assert.Equal(t, "static-token", token.AccessToken) + assert.Equal(t, "CustomAuth", token.TokenType) + }) + + t.Run("multiple_calls_same_token", func(t *testing.T) { + provider := NewStaticTokenProvider("static-token") + + token1, err1 := provider.GetToken(context.Background()) + token2, err2 := provider.GetToken(context.Background()) + + require.NoError(t, err1) + require.NoError(t, err2) + assert.Equal(t, token1.AccessToken, token2.AccessToken) + }) +} + +func TestExternalTokenProvider(t *testing.T) { + t.Run("successful_token_retrieval", func(t *testing.T) { + callCount := 0 + tokenFunc := func() (string, error) { + callCount++ + return "external-token-" + string(rune(callCount)), nil + } + + provider := NewExternalTokenProvider(tokenFunc) + token, err := provider.GetToken(context.Background()) + + require.NoError(t, err) + assert.Equal(t, "external-token-\x01", token.AccessToken) + assert.Equal(t, "Bearer", token.TokenType) + assert.Equal(t, "external", provider.Name()) + }) + + t.Run("token_function_error", func(t *testing.T) { + tokenFunc := func() (string, error) { + return "", errors.New("failed to retrieve token") + } + + provider := NewExternalTokenProvider(tokenFunc) + token, err := provider.GetToken(context.Background()) + + assert.Error(t, err) + assert.Nil(t, token) + assert.Contains(t, err.Error(), "failed to get token") + }) + + t.Run("empty_token_error", func(t *testing.T) { + tokenFunc := func() (string, error) { + return "", nil + } + + provider := NewExternalTokenProvider(tokenFunc) + token, err := provider.GetToken(context.Background()) + + assert.Error(t, err) + assert.Nil(t, token) + assert.Contains(t, err.Error(), "empty token returned") + }) + + t.Run("nil_function_error", func(t *testing.T) { + provider := NewExternalTokenProvider(nil) + token, err := provider.GetToken(context.Background()) + + assert.Error(t, err) + assert.Nil(t, token) + assert.Contains(t, err.Error(), "token function is nil") + }) + + t.Run("custom_token_type", func(t *testing.T) { + tokenFunc := func() (string, error) { + return "external-token", nil + } + + provider := NewExternalTokenProviderWithType(tokenFunc, "MAC") + token, err := provider.GetToken(context.Background()) + + require.NoError(t, err) + assert.Equal(t, "external-token", token.AccessToken) + assert.Equal(t, "MAC", token.TokenType) + }) + + t.Run("different_token_each_call", func(t *testing.T) { + counter := 0 + tokenFunc := func() (string, error) { + counter++ + return "token-" + string(rune(counter)), nil + } + + provider := NewExternalTokenProvider(tokenFunc) + + token1, err1 := provider.GetToken(context.Background()) + token2, err2 := provider.GetToken(context.Background()) + + require.NoError(t, err1) + require.NoError(t, err2) + assert.NotEqual(t, token1.AccessToken, token2.AccessToken) + assert.Equal(t, "token-\x01", token1.AccessToken) + assert.Equal(t, "token-\x02", token2.AccessToken) + }) +} + +func TestCachedTokenProvider(t *testing.T) { + t.Run("caches_valid_token", func(t *testing.T) { + callCount := 0 + baseProvider := &mockProvider{ + tokenFunc: func() (*Token, error) { + callCount++ + return &Token{ + AccessToken: "cached-token", + TokenType: "Bearer", + ExpiresAt: time.Now().Add(1 * time.Hour), + }, nil + }, + name: "mock", + } + + cachedProvider := NewCachedTokenProvider(baseProvider) + + // First call - should fetch from base provider + token1, err1 := cachedProvider.GetToken(context.Background()) + require.NoError(t, err1) + assert.Equal(t, "cached-token", token1.AccessToken) + assert.Equal(t, 1, callCount) + + // Second call - should use cache + token2, err2 := cachedProvider.GetToken(context.Background()) + require.NoError(t, err2) + assert.Equal(t, "cached-token", token2.AccessToken) + assert.Equal(t, 1, callCount) // Should still be 1 + }) + + t.Run("refreshes_expired_token", func(t *testing.T) { + callCount := 0 + baseProvider := &mockProvider{ + tokenFunc: func() (*Token, error) { + callCount++ + // Return token that expires soon + return &Token{ + AccessToken: "token-" + string(rune(callCount)), + TokenType: "Bearer", + ExpiresAt: time.Now().Add(2 * time.Minute), // Within refresh threshold + }, nil + }, + name: "mock", + } + + cachedProvider := NewCachedTokenProvider(baseProvider) + cachedProvider.RefreshThreshold = 5 * time.Minute + + // First call + token1, err1 := cachedProvider.GetToken(context.Background()) + require.NoError(t, err1) + assert.Equal(t, "token-\x01", token1.AccessToken) + assert.Equal(t, 1, callCount) + + // Second call - should refresh because token expires within threshold + token2, err2 := cachedProvider.GetToken(context.Background()) + require.NoError(t, err2) + assert.Equal(t, "token-\x02", token2.AccessToken) + assert.Equal(t, 2, callCount) + }) + + t.Run("handles_provider_error", func(t *testing.T) { + baseProvider := &mockProvider{ + tokenFunc: func() (*Token, error) { + return nil, errors.New("provider error") + }, + name: "mock", + } + + cachedProvider := NewCachedTokenProvider(baseProvider) + token, err := cachedProvider.GetToken(context.Background()) + + assert.Error(t, err) + assert.Nil(t, token) + assert.Contains(t, err.Error(), "provider error") + }) + + t.Run("no_expiry_token_not_refreshed", func(t *testing.T) { + callCount := 0 + baseProvider := &mockProvider{ + tokenFunc: func() (*Token, error) { + callCount++ + return &Token{ + AccessToken: "permanent-token", + TokenType: "Bearer", + ExpiresAt: time.Time{}, // No expiry + }, nil + }, + name: "mock", + } + + cachedProvider := NewCachedTokenProvider(baseProvider) + + // Multiple calls should all use cache + for i := 0; i < 5; i++ { + token, err := cachedProvider.GetToken(context.Background()) + require.NoError(t, err) + assert.Equal(t, "permanent-token", token.AccessToken) + } + + assert.Equal(t, 1, callCount) // Should only be called once + }) + + t.Run("clear_cache", func(t *testing.T) { + callCount := 0 + baseProvider := &mockProvider{ + tokenFunc: func() (*Token, error) { + callCount++ + return &Token{ + AccessToken: "token-" + string(rune(callCount)), + TokenType: "Bearer", + ExpiresAt: time.Now().Add(1 * time.Hour), + }, nil + }, + name: "mock", + } + + cachedProvider := NewCachedTokenProvider(baseProvider) + + // First call + token1, _ := cachedProvider.GetToken(context.Background()) + assert.Equal(t, "token-\x01", token1.AccessToken) + assert.Equal(t, 1, callCount) + + // Clear cache + cachedProvider.ClearCache() + + // Next call should fetch new token + token2, _ := cachedProvider.GetToken(context.Background()) + assert.Equal(t, "token-\x02", token2.AccessToken) + assert.Equal(t, 2, callCount) + }) + + t.Run("concurrent_access", func(t *testing.T) { + var callCount atomic.Int32 + baseProvider := &mockProvider{ + tokenFunc: func() (*Token, error) { + // Simulate slow token fetch + time.Sleep(100 * time.Millisecond) + callCount.Add(1) + return &Token{ + AccessToken: "concurrent-token", + TokenType: "Bearer", + ExpiresAt: time.Now().Add(1 * time.Hour), + }, nil + }, + name: "mock", + } + + cachedProvider := NewCachedTokenProvider(baseProvider) + + // Launch multiple goroutines + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + token, err := cachedProvider.GetToken(context.Background()) + assert.NoError(t, err) + assert.Equal(t, "concurrent-token", token.AccessToken) + }() + } + + wg.Wait() + + // Should only fetch token once despite concurrent access + assert.Equal(t, int32(1), callCount.Load()) + }) + + t.Run("provider_name", func(t *testing.T) { + baseProvider := &mockProvider{name: "test-provider"} + cachedProvider := NewCachedTokenProvider(baseProvider) + + assert.Equal(t, "cached[test-provider]", cachedProvider.Name()) + }) +} + +// Mock provider for testing +type mockProvider struct { + tokenFunc func() (*Token, error) + name string +} + +func (m *mockProvider) GetToken(ctx context.Context) (*Token, error) { + if m.tokenFunc != nil { + return m.tokenFunc() + } + return nil, errors.New("not implemented") +} + +func (m *mockProvider) Name() string { + return m.name +} diff --git a/auth/tokenprovider/static.go b/auth/tokenprovider/static.go new file mode 100644 index 00000000..46079ba0 --- /dev/null +++ b/auth/tokenprovider/static.go @@ -0,0 +1,47 @@ +package tokenprovider + +import ( + "context" + "fmt" + "time" +) + +// StaticTokenProvider provides a static token that never changes +type StaticTokenProvider struct { + token string + tokenType string +} + +// NewStaticTokenProvider creates a provider with a static token +func NewStaticTokenProvider(token string) *StaticTokenProvider { + return &StaticTokenProvider{ + token: token, + tokenType: "Bearer", + } +} + +// NewStaticTokenProviderWithType creates a provider with a static token and custom type +func NewStaticTokenProviderWithType(token string, tokenType string) *StaticTokenProvider { + return &StaticTokenProvider{ + token: token, + tokenType: tokenType, + } +} + +// GetToken returns the static token +func (p *StaticTokenProvider) GetToken(ctx context.Context) (*Token, error) { + if p.token == "" { + return nil, fmt.Errorf("static token provider: token is empty") + } + + return &Token{ + AccessToken: p.token, + TokenType: p.tokenType, + ExpiresAt: time.Time{}, // Static tokens don't expire + }, nil +} + +// Name returns the provider name +func (p *StaticTokenProvider) Name() string { + return "static" +} diff --git a/connector.go b/connector.go index 53908b4c..fce77970 100644 --- a/connector.go +++ b/connector.go @@ -12,6 +12,7 @@ import ( "github.com/databricks/databricks-sql-go/auth" "github.com/databricks/databricks-sql-go/auth/oauth/m2m" "github.com/databricks/databricks-sql-go/auth/pat" + "github.com/databricks/databricks-sql-go/auth/tokenprovider" "github.com/databricks/databricks-sql-go/driverctx" dbsqlerr "github.com/databricks/databricks-sql-go/errors" "github.com/databricks/databricks-sql-go/internal/cli_service" @@ -293,3 +294,32 @@ func WithClientCredentials(clientID, clientSecret string) ConnOption { } } } + +// WithTokenProvider sets up authentication using a custom token provider +func WithTokenProvider(provider tokenprovider.TokenProvider) ConnOption { + return func(c *config.Config) { + if provider != nil { + c.Authenticator = tokenprovider.NewAuthenticator(provider) + } + } +} + +// WithExternalToken sets up authentication using an external token function (passthrough) +func WithExternalToken(tokenFunc func() (string, error)) ConnOption { + return func(c *config.Config) { + if tokenFunc != nil { + provider := tokenprovider.NewExternalTokenProvider(tokenFunc) + c.Authenticator = tokenprovider.NewAuthenticator(provider) + } + } +} + +// WithStaticToken sets up authentication using a static token +func WithStaticToken(token string) ConnOption { + return func(c *config.Config) { + if token != "" { + provider := tokenprovider.NewStaticTokenProvider(token) + c.Authenticator = tokenprovider.NewAuthenticator(provider) + } + } +} diff --git a/go.mod b/go.mod index d9a517c5..1d1fdc78 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/apache/arrow/go/v12 v12.0.1 github.com/apache/thrift v0.17.0 github.com/coreos/go-oidc/v3 v3.5.0 + github.com/golang-jwt/jwt/v5 v5.2.1 github.com/joho/godotenv v1.4.0 github.com/mattn/go-isatty v0.0.20 github.com/pierrec/lz4/v4 v4.1.15 diff --git a/go.sum b/go.sum index edeb89ee..670487a8 100644 --- a/go.sum +++ b/go.sum @@ -26,6 +26,8 @@ github.com/go-jose/go-jose/v3 v3.0.4/go.mod h1:5b+7YgP7ZICgJDBdfjZaIt+H/9L9T/YQr github.com/goccy/go-json v0.9.11 h1:/pAaQDLHEoCq/5FFmSKBswWmK6H0e8g4159Kc/X/nqk= github.com/goccy/go-json v0.9.11/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= From c7bd4b20044d72268727d159350fd6e1521f7bb3 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Thu, 30 Oct 2025 05:36:18 +0000 Subject: [PATCH 2/3] Remove cached provider tests from foundation PR --- auth/tokenprovider/authenticator_test.go | 28 ---- auth/tokenprovider/provider_test.go | 179 ----------------------- 2 files changed, 207 deletions(-) diff --git a/auth/tokenprovider/authenticator_test.go b/auth/tokenprovider/authenticator_test.go index a47dd6bc..7b703c1a 100644 --- a/auth/tokenprovider/authenticator_test.go +++ b/auth/tokenprovider/authenticator_test.go @@ -104,32 +104,4 @@ func TestTokenProviderAuthenticator(t *testing.T) { require.NoError(t, err) assert.Equal(t, "Bearer external-token-456", req.Header.Get("Authorization")) }) - - t.Run("cached_provider_integration", func(t *testing.T) { - callCount := 0 - baseProvider := &mockProvider{ - tokenFunc: func() (*Token, error) { - callCount++ - return &Token{ - AccessToken: "cached-token", - TokenType: "Bearer", - }, nil - }, - name: "test", - } - - cachedProvider := NewCachedTokenProvider(baseProvider) - authenticator := NewAuthenticator(cachedProvider) - - // Multiple authentication attempts - for i := 0; i < 3; i++ { - req, _ := http.NewRequest("GET", "http://example.com", nil) - err := authenticator.Authenticate(req) - require.NoError(t, err) - assert.Equal(t, "Bearer cached-token", req.Header.Get("Authorization")) - } - - // Should only call base provider once due to caching - assert.Equal(t, 1, callCount) - }) } diff --git a/auth/tokenprovider/provider_test.go b/auth/tokenprovider/provider_test.go index 5acb5538..82182eb7 100644 --- a/auth/tokenprovider/provider_test.go +++ b/auth/tokenprovider/provider_test.go @@ -4,8 +4,6 @@ import ( "context" "errors" "net/http" - "sync" - "sync/atomic" "testing" "time" @@ -228,183 +226,6 @@ func TestExternalTokenProvider(t *testing.T) { }) } -func TestCachedTokenProvider(t *testing.T) { - t.Run("caches_valid_token", func(t *testing.T) { - callCount := 0 - baseProvider := &mockProvider{ - tokenFunc: func() (*Token, error) { - callCount++ - return &Token{ - AccessToken: "cached-token", - TokenType: "Bearer", - ExpiresAt: time.Now().Add(1 * time.Hour), - }, nil - }, - name: "mock", - } - - cachedProvider := NewCachedTokenProvider(baseProvider) - - // First call - should fetch from base provider - token1, err1 := cachedProvider.GetToken(context.Background()) - require.NoError(t, err1) - assert.Equal(t, "cached-token", token1.AccessToken) - assert.Equal(t, 1, callCount) - - // Second call - should use cache - token2, err2 := cachedProvider.GetToken(context.Background()) - require.NoError(t, err2) - assert.Equal(t, "cached-token", token2.AccessToken) - assert.Equal(t, 1, callCount) // Should still be 1 - }) - - t.Run("refreshes_expired_token", func(t *testing.T) { - callCount := 0 - baseProvider := &mockProvider{ - tokenFunc: func() (*Token, error) { - callCount++ - // Return token that expires soon - return &Token{ - AccessToken: "token-" + string(rune(callCount)), - TokenType: "Bearer", - ExpiresAt: time.Now().Add(2 * time.Minute), // Within refresh threshold - }, nil - }, - name: "mock", - } - - cachedProvider := NewCachedTokenProvider(baseProvider) - cachedProvider.RefreshThreshold = 5 * time.Minute - - // First call - token1, err1 := cachedProvider.GetToken(context.Background()) - require.NoError(t, err1) - assert.Equal(t, "token-\x01", token1.AccessToken) - assert.Equal(t, 1, callCount) - - // Second call - should refresh because token expires within threshold - token2, err2 := cachedProvider.GetToken(context.Background()) - require.NoError(t, err2) - assert.Equal(t, "token-\x02", token2.AccessToken) - assert.Equal(t, 2, callCount) - }) - - t.Run("handles_provider_error", func(t *testing.T) { - baseProvider := &mockProvider{ - tokenFunc: func() (*Token, error) { - return nil, errors.New("provider error") - }, - name: "mock", - } - - cachedProvider := NewCachedTokenProvider(baseProvider) - token, err := cachedProvider.GetToken(context.Background()) - - assert.Error(t, err) - assert.Nil(t, token) - assert.Contains(t, err.Error(), "provider error") - }) - - t.Run("no_expiry_token_not_refreshed", func(t *testing.T) { - callCount := 0 - baseProvider := &mockProvider{ - tokenFunc: func() (*Token, error) { - callCount++ - return &Token{ - AccessToken: "permanent-token", - TokenType: "Bearer", - ExpiresAt: time.Time{}, // No expiry - }, nil - }, - name: "mock", - } - - cachedProvider := NewCachedTokenProvider(baseProvider) - - // Multiple calls should all use cache - for i := 0; i < 5; i++ { - token, err := cachedProvider.GetToken(context.Background()) - require.NoError(t, err) - assert.Equal(t, "permanent-token", token.AccessToken) - } - - assert.Equal(t, 1, callCount) // Should only be called once - }) - - t.Run("clear_cache", func(t *testing.T) { - callCount := 0 - baseProvider := &mockProvider{ - tokenFunc: func() (*Token, error) { - callCount++ - return &Token{ - AccessToken: "token-" + string(rune(callCount)), - TokenType: "Bearer", - ExpiresAt: time.Now().Add(1 * time.Hour), - }, nil - }, - name: "mock", - } - - cachedProvider := NewCachedTokenProvider(baseProvider) - - // First call - token1, _ := cachedProvider.GetToken(context.Background()) - assert.Equal(t, "token-\x01", token1.AccessToken) - assert.Equal(t, 1, callCount) - - // Clear cache - cachedProvider.ClearCache() - - // Next call should fetch new token - token2, _ := cachedProvider.GetToken(context.Background()) - assert.Equal(t, "token-\x02", token2.AccessToken) - assert.Equal(t, 2, callCount) - }) - - t.Run("concurrent_access", func(t *testing.T) { - var callCount atomic.Int32 - baseProvider := &mockProvider{ - tokenFunc: func() (*Token, error) { - // Simulate slow token fetch - time.Sleep(100 * time.Millisecond) - callCount.Add(1) - return &Token{ - AccessToken: "concurrent-token", - TokenType: "Bearer", - ExpiresAt: time.Now().Add(1 * time.Hour), - }, nil - }, - name: "mock", - } - - cachedProvider := NewCachedTokenProvider(baseProvider) - - // Launch multiple goroutines - var wg sync.WaitGroup - for i := 0; i < 10; i++ { - wg.Add(1) - go func() { - defer wg.Done() - token, err := cachedProvider.GetToken(context.Background()) - assert.NoError(t, err) - assert.Equal(t, "concurrent-token", token.AccessToken) - }() - } - - wg.Wait() - - // Should only fetch token once despite concurrent access - assert.Equal(t, int32(1), callCount.Load()) - }) - - t.Run("provider_name", func(t *testing.T) { - baseProvider := &mockProvider{name: "test-provider"} - cachedProvider := NewCachedTokenProvider(baseProvider) - - assert.Equal(t, "cached[test-provider]", cachedProvider.Name()) - }) -} - // Mock provider for testing type mockProvider struct { tokenFunc func() (*Token, error) From 08d163dbc25262e1c9d1e771a82a4b13f8db1a02 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Tue, 18 Nov 2025 09:02:51 +0000 Subject: [PATCH 3/3] Address PR review comments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Reduce token expiry buffer from 5 minutes to 30 seconds (matches SDK standard) - Add detailed documentation to TokenProviderAuthenticator explaining flow - Add ctx.Err() check in ExternalTokenProvider for cancellation support - Rename tokenFunc to tokenSource for better clarity - Remove duplicate empty token validation from ExternalTokenProvider - Update tests to reflect changes 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- auth/tokenprovider/authenticator.go | 11 +++++++++- auth/tokenprovider/external.go | 34 +++++++++++++++-------------- auth/tokenprovider/provider.go | 5 +++-- auth/tokenprovider/provider_test.go | 17 ++++++++------- 4 files changed, 40 insertions(+), 27 deletions(-) diff --git a/auth/tokenprovider/authenticator.go b/auth/tokenprovider/authenticator.go index 3955a4c9..b4ef57ff 100644 --- a/auth/tokenprovider/authenticator.go +++ b/auth/tokenprovider/authenticator.go @@ -9,7 +9,16 @@ import ( "github.com/rs/zerolog/log" ) -// TokenProviderAuthenticator implements auth.Authenticator using a TokenProvider +// TokenProviderAuthenticator implements auth.Authenticator using a TokenProvider. +// +// Authentication Flow: +// 1. On each Authenticate() call, retrieves a token from the configured TokenProvider +// 2. The provider may implement its own caching and refresh logic +// 3. Validates the returned token is non-empty +// 4. Sets the Authorization header with the token type and value +// +// The authenticator delegates all token management (caching, refresh, expiry) +// to the underlying TokenProvider implementation. type TokenProviderAuthenticator struct { provider TokenProvider } diff --git a/auth/tokenprovider/external.go b/auth/tokenprovider/external.go index 0e511234..c2b6a9c5 100644 --- a/auth/tokenprovider/external.go +++ b/auth/tokenprovider/external.go @@ -6,41 +6,43 @@ import ( "time" ) -// ExternalTokenProvider provides tokens from an external source (passthrough) +// ExternalTokenProvider provides tokens from an external source (passthrough). +// This provider calls a user-supplied function to retrieve tokens on-demand. type ExternalTokenProvider struct { - tokenFunc func() (string, error) - tokenType string + tokenSource func() (string, error) + tokenType string } // NewExternalTokenProvider creates a provider that gets tokens from an external function -func NewExternalTokenProvider(tokenFunc func() (string, error)) *ExternalTokenProvider { +func NewExternalTokenProvider(tokenSource func() (string, error)) *ExternalTokenProvider { return &ExternalTokenProvider{ - tokenFunc: tokenFunc, - tokenType: "Bearer", + tokenSource: tokenSource, + tokenType: "Bearer", } } // NewExternalTokenProviderWithType creates a provider with a custom token type -func NewExternalTokenProviderWithType(tokenFunc func() (string, error), tokenType string) *ExternalTokenProvider { +func NewExternalTokenProviderWithType(tokenSource func() (string, error), tokenType string) *ExternalTokenProvider { return &ExternalTokenProvider{ - tokenFunc: tokenFunc, - tokenType: tokenType, + tokenSource: tokenSource, + tokenType: tokenType, } } // GetToken retrieves the token from the external source func (p *ExternalTokenProvider) GetToken(ctx context.Context) (*Token, error) { - if p.tokenFunc == nil { - return nil, fmt.Errorf("external token provider: token function is nil") + // Check for cancellation first + if err := ctx.Err(); err != nil { + return nil, fmt.Errorf("external token provider: context cancelled: %w", err) } - accessToken, err := p.tokenFunc() - if err != nil { - return nil, fmt.Errorf("external token provider: failed to get token: %w", err) + if p.tokenSource == nil { + return nil, fmt.Errorf("external token provider: token source is nil") } - if accessToken == "" { - return nil, fmt.Errorf("external token provider: empty token returned") + accessToken, err := p.tokenSource() + if err != nil { + return nil, fmt.Errorf("external token provider: failed to get token: %w", err) } return &Token{ diff --git a/auth/tokenprovider/provider.go b/auth/tokenprovider/provider.go index 3e94d6ef..6faf0a9d 100644 --- a/auth/tokenprovider/provider.go +++ b/auth/tokenprovider/provider.go @@ -29,8 +29,9 @@ func (t *Token) IsExpired() bool { if t.ExpiresAt.IsZero() { return false // No expiry means token doesn't expire } - // Consider token expired 5 minutes before actual expiry for safety - return time.Now().Add(5 * time.Minute).After(t.ExpiresAt) + // Consider token expired 30 seconds before actual expiry for safety + // This matches the standard buffer used by other Databricks SDKs + return time.Now().Add(30 * time.Second).After(t.ExpiresAt) } // SetAuthHeader sets the Authorization header on an HTTP request diff --git a/auth/tokenprovider/provider_test.go b/auth/tokenprovider/provider_test.go index 82182eb7..e3df4753 100644 --- a/auth/tokenprovider/provider_test.go +++ b/auth/tokenprovider/provider_test.go @@ -42,12 +42,12 @@ func TestToken_IsExpired(t *testing.T) { expected: false, }, { - name: "token_expires_within_5_minutes", + name: "token_expires_within_30_seconds", token: &Token{ AccessToken: "test-token", - ExpiresAt: time.Now().Add(3 * time.Minute), + ExpiresAt: time.Now().Add(15 * time.Second), }, - expected: true, // Should be considered expired due to 5-minute buffer + expected: true, // Should be considered expired due to 30-second buffer }, } @@ -171,7 +171,7 @@ func TestExternalTokenProvider(t *testing.T) { assert.Contains(t, err.Error(), "failed to get token") }) - t.Run("empty_token_error", func(t *testing.T) { + t.Run("empty_token_allowed", func(t *testing.T) { tokenFunc := func() (string, error) { return "", nil } @@ -179,9 +179,10 @@ func TestExternalTokenProvider(t *testing.T) { provider := NewExternalTokenProvider(tokenFunc) token, err := provider.GetToken(context.Background()) - assert.Error(t, err) - assert.Nil(t, token) - assert.Contains(t, err.Error(), "empty token returned") + assert.NoError(t, err) + assert.NotNil(t, token) + assert.Empty(t, token.AccessToken) + // Empty tokens are validated at the authenticator level, not provider level }) t.Run("nil_function_error", func(t *testing.T) { @@ -190,7 +191,7 @@ func TestExternalTokenProvider(t *testing.T) { assert.Error(t, err) assert.Nil(t, token) - assert.Contains(t, err.Error(), "token function is nil") + assert.Contains(t, err.Error(), "token source is nil") }) t.Run("custom_token_type", func(t *testing.T) {