From e857292512d409e5fcf8a7a033173deb20535139 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Thu, 30 Oct 2025 05:17:18 +0000 Subject: [PATCH 1/2] Add token provider infrastructure MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This introduces a flexible TokenProvider interface that allows custom authentication implementations: - TokenProvider interface with static, external function support - Token struct with expiration handling - Authenticator wrapper for integration with existing auth system - Connector functions: WithTokenProvider, WithExternalToken, WithStaticToken This foundation enables custom token management strategies without requiring changes to the core driver. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- auth/tokenprovider/authenticator.go | 44 +++ auth/tokenprovider/authenticator_test.go | 135 ++++++++ auth/tokenprovider/external.go | 56 +++ auth/tokenprovider/provider.go | 43 +++ auth/tokenprovider/provider_test.go | 423 +++++++++++++++++++++++ auth/tokenprovider/static.go | 47 +++ connector.go | 30 ++ go.mod | 1 + go.sum | 2 + 9 files changed, 781 insertions(+) create mode 100644 auth/tokenprovider/authenticator.go create mode 100644 auth/tokenprovider/authenticator_test.go create mode 100644 auth/tokenprovider/external.go create mode 100644 auth/tokenprovider/provider.go create mode 100644 auth/tokenprovider/provider_test.go create mode 100644 auth/tokenprovider/static.go diff --git a/auth/tokenprovider/authenticator.go b/auth/tokenprovider/authenticator.go new file mode 100644 index 00000000..3955a4c9 --- /dev/null +++ b/auth/tokenprovider/authenticator.go @@ -0,0 +1,44 @@ +package tokenprovider + +import ( + "context" + "fmt" + "net/http" + + "github.com/databricks/databricks-sql-go/auth" + "github.com/rs/zerolog/log" +) + +// TokenProviderAuthenticator implements auth.Authenticator using a TokenProvider +type TokenProviderAuthenticator struct { + provider TokenProvider +} + +// NewAuthenticator creates an authenticator from a token provider +func NewAuthenticator(provider TokenProvider) auth.Authenticator { + return &TokenProviderAuthenticator{ + provider: provider, + } +} + +// Authenticate implements auth.Authenticator +func (a *TokenProviderAuthenticator) Authenticate(r *http.Request) error { + ctx := r.Context() + if ctx == nil { + ctx = context.Background() + } + + token, err := a.provider.GetToken(ctx) + if err != nil { + return fmt.Errorf("token provider authenticator: failed to get token: %w", err) + } + + if token.AccessToken == "" { + return fmt.Errorf("token provider authenticator: empty access token") + } + + token.SetAuthHeader(r) + log.Debug().Msgf("token provider authenticator: authenticated using provider %s", a.provider.Name()) + + return nil +} diff --git a/auth/tokenprovider/authenticator_test.go b/auth/tokenprovider/authenticator_test.go new file mode 100644 index 00000000..a47dd6bc --- /dev/null +++ b/auth/tokenprovider/authenticator_test.go @@ -0,0 +1,135 @@ +package tokenprovider + +import ( + "context" + "errors" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTokenProviderAuthenticator(t *testing.T) { + t.Run("successful_authentication", func(t *testing.T) { + provider := NewStaticTokenProvider("test-token-123") + authenticator := NewAuthenticator(provider) + + req, _ := http.NewRequest("GET", "http://example.com", nil) + err := authenticator.Authenticate(req) + + require.NoError(t, err) + assert.Equal(t, "Bearer test-token-123", req.Header.Get("Authorization")) + }) + + t.Run("authentication_with_custom_token_type", func(t *testing.T) { + provider := NewStaticTokenProviderWithType("test-token", "MAC") + authenticator := NewAuthenticator(provider) + + req, _ := http.NewRequest("GET", "http://example.com", nil) + err := authenticator.Authenticate(req) + + require.NoError(t, err) + assert.Equal(t, "MAC test-token", req.Header.Get("Authorization")) + }) + + t.Run("authentication_error_propagation", func(t *testing.T) { + provider := &mockProvider{ + tokenFunc: func() (*Token, error) { + return nil, errors.New("provider failed") + }, + } + authenticator := NewAuthenticator(provider) + + req, _ := http.NewRequest("GET", "http://example.com", nil) + err := authenticator.Authenticate(req) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "provider failed") + assert.Empty(t, req.Header.Get("Authorization")) + }) + + t.Run("empty_token_error", func(t *testing.T) { + provider := &mockProvider{ + tokenFunc: func() (*Token, error) { + return &Token{ + AccessToken: "", + TokenType: "Bearer", + }, nil + }, + } + authenticator := NewAuthenticator(provider) + + req, _ := http.NewRequest("GET", "http://example.com", nil) + err := authenticator.Authenticate(req) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "empty access token") + assert.Empty(t, req.Header.Get("Authorization")) + }) + + t.Run("uses_request_context", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel immediately + + provider := &mockProvider{ + tokenFunc: func() (*Token, error) { + // This would normally check context cancellation + return &Token{ + AccessToken: "test-token", + TokenType: "Bearer", + }, nil + }, + } + authenticator := NewAuthenticator(provider) + + req, _ := http.NewRequestWithContext(ctx, "GET", "http://example.com", nil) + err := authenticator.Authenticate(req) + + // Even with cancelled context, this should work as our mock doesn't check it + require.NoError(t, err) + assert.Equal(t, "Bearer test-token", req.Header.Get("Authorization")) + }) + + t.Run("external_token_integration", func(t *testing.T) { + tokenFunc := func() (string, error) { + return "external-token-456", nil + } + provider := NewExternalTokenProvider(tokenFunc) + authenticator := NewAuthenticator(provider) + + req, _ := http.NewRequest("POST", "http://example.com/api", nil) + err := authenticator.Authenticate(req) + + require.NoError(t, err) + assert.Equal(t, "Bearer external-token-456", req.Header.Get("Authorization")) + }) + + t.Run("cached_provider_integration", func(t *testing.T) { + callCount := 0 + baseProvider := &mockProvider{ + tokenFunc: func() (*Token, error) { + callCount++ + return &Token{ + AccessToken: "cached-token", + TokenType: "Bearer", + }, nil + }, + name: "test", + } + + cachedProvider := NewCachedTokenProvider(baseProvider) + authenticator := NewAuthenticator(cachedProvider) + + // Multiple authentication attempts + for i := 0; i < 3; i++ { + req, _ := http.NewRequest("GET", "http://example.com", nil) + err := authenticator.Authenticate(req) + require.NoError(t, err) + assert.Equal(t, "Bearer cached-token", req.Header.Get("Authorization")) + } + + // Should only call base provider once due to caching + assert.Equal(t, 1, callCount) + }) +} diff --git a/auth/tokenprovider/external.go b/auth/tokenprovider/external.go new file mode 100644 index 00000000..0e511234 --- /dev/null +++ b/auth/tokenprovider/external.go @@ -0,0 +1,56 @@ +package tokenprovider + +import ( + "context" + "fmt" + "time" +) + +// ExternalTokenProvider provides tokens from an external source (passthrough) +type ExternalTokenProvider struct { + tokenFunc func() (string, error) + tokenType string +} + +// NewExternalTokenProvider creates a provider that gets tokens from an external function +func NewExternalTokenProvider(tokenFunc func() (string, error)) *ExternalTokenProvider { + return &ExternalTokenProvider{ + tokenFunc: tokenFunc, + tokenType: "Bearer", + } +} + +// NewExternalTokenProviderWithType creates a provider with a custom token type +func NewExternalTokenProviderWithType(tokenFunc func() (string, error), tokenType string) *ExternalTokenProvider { + return &ExternalTokenProvider{ + tokenFunc: tokenFunc, + tokenType: tokenType, + } +} + +// GetToken retrieves the token from the external source +func (p *ExternalTokenProvider) GetToken(ctx context.Context) (*Token, error) { + if p.tokenFunc == nil { + return nil, fmt.Errorf("external token provider: token function is nil") + } + + accessToken, err := p.tokenFunc() + if err != nil { + return nil, fmt.Errorf("external token provider: failed to get token: %w", err) + } + + if accessToken == "" { + return nil, fmt.Errorf("external token provider: empty token returned") + } + + return &Token{ + AccessToken: accessToken, + TokenType: p.tokenType, + ExpiresAt: time.Time{}, // External tokens don't provide expiry info + }, nil +} + +// Name returns the provider name +func (p *ExternalTokenProvider) Name() string { + return "external" +} diff --git a/auth/tokenprovider/provider.go b/auth/tokenprovider/provider.go new file mode 100644 index 00000000..3e94d6ef --- /dev/null +++ b/auth/tokenprovider/provider.go @@ -0,0 +1,43 @@ +package tokenprovider + +import ( + "context" + "net/http" + "time" +) + +// TokenProvider is the interface for providing tokens from various sources +type TokenProvider interface { + // GetToken retrieves a valid access token + GetToken(ctx context.Context) (*Token, error) + + // Name returns the provider name for logging/debugging + Name() string +} + +// Token represents an access token with metadata +type Token struct { + AccessToken string + TokenType string + ExpiresAt time.Time + RefreshToken string + Scopes []string +} + +// IsExpired checks if the token has expired +func (t *Token) IsExpired() bool { + if t.ExpiresAt.IsZero() { + return false // No expiry means token doesn't expire + } + // Consider token expired 5 minutes before actual expiry for safety + return time.Now().Add(5 * time.Minute).After(t.ExpiresAt) +} + +// SetAuthHeader sets the Authorization header on an HTTP request +func (t *Token) SetAuthHeader(r *http.Request) { + tokenType := t.TokenType + if tokenType == "" { + tokenType = "Bearer" + } + r.Header.Set("Authorization", tokenType+" "+t.AccessToken) +} diff --git a/auth/tokenprovider/provider_test.go b/auth/tokenprovider/provider_test.go new file mode 100644 index 00000000..5acb5538 --- /dev/null +++ b/auth/tokenprovider/provider_test.go @@ -0,0 +1,423 @@ +package tokenprovider + +import ( + "context" + "errors" + "net/http" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestToken_IsExpired(t *testing.T) { + tests := []struct { + name string + token *Token + expected bool + }{ + { + name: "token_without_expiry", + token: &Token{ + AccessToken: "test-token", + ExpiresAt: time.Time{}, + }, + expected: false, + }, + { + name: "token_expired", + token: &Token{ + AccessToken: "test-token", + ExpiresAt: time.Now().Add(-10 * time.Minute), + }, + expected: true, + }, + { + name: "token_not_expired", + token: &Token{ + AccessToken: "test-token", + ExpiresAt: time.Now().Add(10 * time.Minute), + }, + expected: false, + }, + { + name: "token_expires_within_5_minutes", + token: &Token{ + AccessToken: "test-token", + ExpiresAt: time.Now().Add(3 * time.Minute), + }, + expected: true, // Should be considered expired due to 5-minute buffer + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, tt.token.IsExpired()) + }) + } +} + +func TestToken_SetAuthHeader(t *testing.T) { + tests := []struct { + name string + token *Token + expectedHeader string + }{ + { + name: "bearer_token", + token: &Token{ + AccessToken: "test-access-token", + TokenType: "Bearer", + }, + expectedHeader: "Bearer test-access-token", + }, + { + name: "default_to_bearer", + token: &Token{ + AccessToken: "test-access-token", + TokenType: "", + }, + expectedHeader: "Bearer test-access-token", + }, + { + name: "custom_token_type", + token: &Token{ + AccessToken: "test-access-token", + TokenType: "CustomAuth", + }, + expectedHeader: "CustomAuth test-access-token", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req, _ := http.NewRequest("GET", "http://example.com", nil) + tt.token.SetAuthHeader(req) + assert.Equal(t, tt.expectedHeader, req.Header.Get("Authorization")) + }) + } +} + +func TestStaticTokenProvider(t *testing.T) { + t.Run("valid_token", func(t *testing.T) { + provider := NewStaticTokenProvider("static-token-123") + token, err := provider.GetToken(context.Background()) + + require.NoError(t, err) + assert.Equal(t, "static-token-123", token.AccessToken) + assert.Equal(t, "Bearer", token.TokenType) + assert.True(t, token.ExpiresAt.IsZero()) + assert.Equal(t, "static", provider.Name()) + }) + + t.Run("empty_token_error", func(t *testing.T) { + provider := NewStaticTokenProvider("") + token, err := provider.GetToken(context.Background()) + + assert.Error(t, err) + assert.Nil(t, token) + assert.Contains(t, err.Error(), "token is empty") + }) + + t.Run("custom_token_type", func(t *testing.T) { + provider := NewStaticTokenProviderWithType("static-token", "CustomAuth") + token, err := provider.GetToken(context.Background()) + + require.NoError(t, err) + assert.Equal(t, "static-token", token.AccessToken) + assert.Equal(t, "CustomAuth", token.TokenType) + }) + + t.Run("multiple_calls_same_token", func(t *testing.T) { + provider := NewStaticTokenProvider("static-token") + + token1, err1 := provider.GetToken(context.Background()) + token2, err2 := provider.GetToken(context.Background()) + + require.NoError(t, err1) + require.NoError(t, err2) + assert.Equal(t, token1.AccessToken, token2.AccessToken) + }) +} + +func TestExternalTokenProvider(t *testing.T) { + t.Run("successful_token_retrieval", func(t *testing.T) { + callCount := 0 + tokenFunc := func() (string, error) { + callCount++ + return "external-token-" + string(rune(callCount)), nil + } + + provider := NewExternalTokenProvider(tokenFunc) + token, err := provider.GetToken(context.Background()) + + require.NoError(t, err) + assert.Equal(t, "external-token-\x01", token.AccessToken) + assert.Equal(t, "Bearer", token.TokenType) + assert.Equal(t, "external", provider.Name()) + }) + + t.Run("token_function_error", func(t *testing.T) { + tokenFunc := func() (string, error) { + return "", errors.New("failed to retrieve token") + } + + provider := NewExternalTokenProvider(tokenFunc) + token, err := provider.GetToken(context.Background()) + + assert.Error(t, err) + assert.Nil(t, token) + assert.Contains(t, err.Error(), "failed to get token") + }) + + t.Run("empty_token_error", func(t *testing.T) { + tokenFunc := func() (string, error) { + return "", nil + } + + provider := NewExternalTokenProvider(tokenFunc) + token, err := provider.GetToken(context.Background()) + + assert.Error(t, err) + assert.Nil(t, token) + assert.Contains(t, err.Error(), "empty token returned") + }) + + t.Run("nil_function_error", func(t *testing.T) { + provider := NewExternalTokenProvider(nil) + token, err := provider.GetToken(context.Background()) + + assert.Error(t, err) + assert.Nil(t, token) + assert.Contains(t, err.Error(), "token function is nil") + }) + + t.Run("custom_token_type", func(t *testing.T) { + tokenFunc := func() (string, error) { + return "external-token", nil + } + + provider := NewExternalTokenProviderWithType(tokenFunc, "MAC") + token, err := provider.GetToken(context.Background()) + + require.NoError(t, err) + assert.Equal(t, "external-token", token.AccessToken) + assert.Equal(t, "MAC", token.TokenType) + }) + + t.Run("different_token_each_call", func(t *testing.T) { + counter := 0 + tokenFunc := func() (string, error) { + counter++ + return "token-" + string(rune(counter)), nil + } + + provider := NewExternalTokenProvider(tokenFunc) + + token1, err1 := provider.GetToken(context.Background()) + token2, err2 := provider.GetToken(context.Background()) + + require.NoError(t, err1) + require.NoError(t, err2) + assert.NotEqual(t, token1.AccessToken, token2.AccessToken) + assert.Equal(t, "token-\x01", token1.AccessToken) + assert.Equal(t, "token-\x02", token2.AccessToken) + }) +} + +func TestCachedTokenProvider(t *testing.T) { + t.Run("caches_valid_token", func(t *testing.T) { + callCount := 0 + baseProvider := &mockProvider{ + tokenFunc: func() (*Token, error) { + callCount++ + return &Token{ + AccessToken: "cached-token", + TokenType: "Bearer", + ExpiresAt: time.Now().Add(1 * time.Hour), + }, nil + }, + name: "mock", + } + + cachedProvider := NewCachedTokenProvider(baseProvider) + + // First call - should fetch from base provider + token1, err1 := cachedProvider.GetToken(context.Background()) + require.NoError(t, err1) + assert.Equal(t, "cached-token", token1.AccessToken) + assert.Equal(t, 1, callCount) + + // Second call - should use cache + token2, err2 := cachedProvider.GetToken(context.Background()) + require.NoError(t, err2) + assert.Equal(t, "cached-token", token2.AccessToken) + assert.Equal(t, 1, callCount) // Should still be 1 + }) + + t.Run("refreshes_expired_token", func(t *testing.T) { + callCount := 0 + baseProvider := &mockProvider{ + tokenFunc: func() (*Token, error) { + callCount++ + // Return token that expires soon + return &Token{ + AccessToken: "token-" + string(rune(callCount)), + TokenType: "Bearer", + ExpiresAt: time.Now().Add(2 * time.Minute), // Within refresh threshold + }, nil + }, + name: "mock", + } + + cachedProvider := NewCachedTokenProvider(baseProvider) + cachedProvider.RefreshThreshold = 5 * time.Minute + + // First call + token1, err1 := cachedProvider.GetToken(context.Background()) + require.NoError(t, err1) + assert.Equal(t, "token-\x01", token1.AccessToken) + assert.Equal(t, 1, callCount) + + // Second call - should refresh because token expires within threshold + token2, err2 := cachedProvider.GetToken(context.Background()) + require.NoError(t, err2) + assert.Equal(t, "token-\x02", token2.AccessToken) + assert.Equal(t, 2, callCount) + }) + + t.Run("handles_provider_error", func(t *testing.T) { + baseProvider := &mockProvider{ + tokenFunc: func() (*Token, error) { + return nil, errors.New("provider error") + }, + name: "mock", + } + + cachedProvider := NewCachedTokenProvider(baseProvider) + token, err := cachedProvider.GetToken(context.Background()) + + assert.Error(t, err) + assert.Nil(t, token) + assert.Contains(t, err.Error(), "provider error") + }) + + t.Run("no_expiry_token_not_refreshed", func(t *testing.T) { + callCount := 0 + baseProvider := &mockProvider{ + tokenFunc: func() (*Token, error) { + callCount++ + return &Token{ + AccessToken: "permanent-token", + TokenType: "Bearer", + ExpiresAt: time.Time{}, // No expiry + }, nil + }, + name: "mock", + } + + cachedProvider := NewCachedTokenProvider(baseProvider) + + // Multiple calls should all use cache + for i := 0; i < 5; i++ { + token, err := cachedProvider.GetToken(context.Background()) + require.NoError(t, err) + assert.Equal(t, "permanent-token", token.AccessToken) + } + + assert.Equal(t, 1, callCount) // Should only be called once + }) + + t.Run("clear_cache", func(t *testing.T) { + callCount := 0 + baseProvider := &mockProvider{ + tokenFunc: func() (*Token, error) { + callCount++ + return &Token{ + AccessToken: "token-" + string(rune(callCount)), + TokenType: "Bearer", + ExpiresAt: time.Now().Add(1 * time.Hour), + }, nil + }, + name: "mock", + } + + cachedProvider := NewCachedTokenProvider(baseProvider) + + // First call + token1, _ := cachedProvider.GetToken(context.Background()) + assert.Equal(t, "token-\x01", token1.AccessToken) + assert.Equal(t, 1, callCount) + + // Clear cache + cachedProvider.ClearCache() + + // Next call should fetch new token + token2, _ := cachedProvider.GetToken(context.Background()) + assert.Equal(t, "token-\x02", token2.AccessToken) + assert.Equal(t, 2, callCount) + }) + + t.Run("concurrent_access", func(t *testing.T) { + var callCount atomic.Int32 + baseProvider := &mockProvider{ + tokenFunc: func() (*Token, error) { + // Simulate slow token fetch + time.Sleep(100 * time.Millisecond) + callCount.Add(1) + return &Token{ + AccessToken: "concurrent-token", + TokenType: "Bearer", + ExpiresAt: time.Now().Add(1 * time.Hour), + }, nil + }, + name: "mock", + } + + cachedProvider := NewCachedTokenProvider(baseProvider) + + // Launch multiple goroutines + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + token, err := cachedProvider.GetToken(context.Background()) + assert.NoError(t, err) + assert.Equal(t, "concurrent-token", token.AccessToken) + }() + } + + wg.Wait() + + // Should only fetch token once despite concurrent access + assert.Equal(t, int32(1), callCount.Load()) + }) + + t.Run("provider_name", func(t *testing.T) { + baseProvider := &mockProvider{name: "test-provider"} + cachedProvider := NewCachedTokenProvider(baseProvider) + + assert.Equal(t, "cached[test-provider]", cachedProvider.Name()) + }) +} + +// Mock provider for testing +type mockProvider struct { + tokenFunc func() (*Token, error) + name string +} + +func (m *mockProvider) GetToken(ctx context.Context) (*Token, error) { + if m.tokenFunc != nil { + return m.tokenFunc() + } + return nil, errors.New("not implemented") +} + +func (m *mockProvider) Name() string { + return m.name +} diff --git a/auth/tokenprovider/static.go b/auth/tokenprovider/static.go new file mode 100644 index 00000000..46079ba0 --- /dev/null +++ b/auth/tokenprovider/static.go @@ -0,0 +1,47 @@ +package tokenprovider + +import ( + "context" + "fmt" + "time" +) + +// StaticTokenProvider provides a static token that never changes +type StaticTokenProvider struct { + token string + tokenType string +} + +// NewStaticTokenProvider creates a provider with a static token +func NewStaticTokenProvider(token string) *StaticTokenProvider { + return &StaticTokenProvider{ + token: token, + tokenType: "Bearer", + } +} + +// NewStaticTokenProviderWithType creates a provider with a static token and custom type +func NewStaticTokenProviderWithType(token string, tokenType string) *StaticTokenProvider { + return &StaticTokenProvider{ + token: token, + tokenType: tokenType, + } +} + +// GetToken returns the static token +func (p *StaticTokenProvider) GetToken(ctx context.Context) (*Token, error) { + if p.token == "" { + return nil, fmt.Errorf("static token provider: token is empty") + } + + return &Token{ + AccessToken: p.token, + TokenType: p.tokenType, + ExpiresAt: time.Time{}, // Static tokens don't expire + }, nil +} + +// Name returns the provider name +func (p *StaticTokenProvider) Name() string { + return "static" +} diff --git a/connector.go b/connector.go index 53908b4c..fce77970 100644 --- a/connector.go +++ b/connector.go @@ -12,6 +12,7 @@ import ( "github.com/databricks/databricks-sql-go/auth" "github.com/databricks/databricks-sql-go/auth/oauth/m2m" "github.com/databricks/databricks-sql-go/auth/pat" + "github.com/databricks/databricks-sql-go/auth/tokenprovider" "github.com/databricks/databricks-sql-go/driverctx" dbsqlerr "github.com/databricks/databricks-sql-go/errors" "github.com/databricks/databricks-sql-go/internal/cli_service" @@ -293,3 +294,32 @@ func WithClientCredentials(clientID, clientSecret string) ConnOption { } } } + +// WithTokenProvider sets up authentication using a custom token provider +func WithTokenProvider(provider tokenprovider.TokenProvider) ConnOption { + return func(c *config.Config) { + if provider != nil { + c.Authenticator = tokenprovider.NewAuthenticator(provider) + } + } +} + +// WithExternalToken sets up authentication using an external token function (passthrough) +func WithExternalToken(tokenFunc func() (string, error)) ConnOption { + return func(c *config.Config) { + if tokenFunc != nil { + provider := tokenprovider.NewExternalTokenProvider(tokenFunc) + c.Authenticator = tokenprovider.NewAuthenticator(provider) + } + } +} + +// WithStaticToken sets up authentication using a static token +func WithStaticToken(token string) ConnOption { + return func(c *config.Config) { + if token != "" { + provider := tokenprovider.NewStaticTokenProvider(token) + c.Authenticator = tokenprovider.NewAuthenticator(provider) + } + } +} diff --git a/go.mod b/go.mod index d9a517c5..1d1fdc78 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/apache/arrow/go/v12 v12.0.1 github.com/apache/thrift v0.17.0 github.com/coreos/go-oidc/v3 v3.5.0 + github.com/golang-jwt/jwt/v5 v5.2.1 github.com/joho/godotenv v1.4.0 github.com/mattn/go-isatty v0.0.20 github.com/pierrec/lz4/v4 v4.1.15 diff --git a/go.sum b/go.sum index edeb89ee..670487a8 100644 --- a/go.sum +++ b/go.sum @@ -26,6 +26,8 @@ github.com/go-jose/go-jose/v3 v3.0.4/go.mod h1:5b+7YgP7ZICgJDBdfjZaIt+H/9L9T/YQr github.com/goccy/go-json v0.9.11 h1:/pAaQDLHEoCq/5FFmSKBswWmK6H0e8g4159Kc/X/nqk= github.com/goccy/go-json v0.9.11/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= From 1fb8c1a0e05adf609656094c23b29cff1562caab Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Thu, 30 Oct 2025 05:18:39 +0000 Subject: [PATCH 2/2] Add token caching and federation support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds automatic token exchange (federation) and caching capabilities: - CachedTokenProvider: Automatic token refresh with 5min buffer - FederationProvider: Auto-detects and exchanges external JWT tokens - Supports both user federation and SP-wide (M2M) federation - Graceful fallback if token exchange unavailable - Connector functions: WithFederatedTokenProvider, WithFederatedTokenProviderAndClientID - Azure domain list updates for staging/dev environments Token exchange follows RFC 8693 standard. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- auth/oauth/oauth.go | 2 + auth/tokenprovider/cached.go | 86 +++++++ auth/tokenprovider/exchange.go | 204 +++++++++++++++ auth/tokenprovider/federation_test.go | 348 ++++++++++++++++++++++++++ connector.go | 23 ++ 5 files changed, 663 insertions(+) create mode 100644 auth/tokenprovider/cached.go create mode 100644 auth/tokenprovider/exchange.go create mode 100644 auth/tokenprovider/federation_test.go diff --git a/auth/oauth/oauth.go b/auth/oauth/oauth.go index 0df9d5c4..80313fa2 100644 --- a/auth/oauth/oauth.go +++ b/auth/oauth/oauth.go @@ -85,6 +85,8 @@ var databricksAWSDomains []string = []string{ } var databricksAzureDomains []string = []string{ + ".staging.azuredatabricks.net", + ".dev.azuredatabricks.net", ".azuredatabricks.net", ".databricks.azure.cn", ".databricks.azure.us", diff --git a/auth/tokenprovider/cached.go b/auth/tokenprovider/cached.go new file mode 100644 index 00000000..b59e883e --- /dev/null +++ b/auth/tokenprovider/cached.go @@ -0,0 +1,86 @@ +package tokenprovider + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/rs/zerolog/log" +) + +// CachedTokenProvider wraps another provider and caches tokens +type CachedTokenProvider struct { + provider TokenProvider + cache *Token + mutex sync.RWMutex + // RefreshThreshold determines when to refresh (default 5 minutes before expiry) + RefreshThreshold time.Duration +} + +// NewCachedTokenProvider creates a caching wrapper around any token provider +func NewCachedTokenProvider(provider TokenProvider) *CachedTokenProvider { + return &CachedTokenProvider{ + provider: provider, + RefreshThreshold: 5 * time.Minute, + } +} + +// GetToken retrieves a token, using cache if available and valid +func (p *CachedTokenProvider) GetToken(ctx context.Context) (*Token, error) { + // Try to get from cache first + p.mutex.RLock() + cached := p.cache + p.mutex.RUnlock() + + if cached != nil && !p.shouldRefresh(cached) { + log.Debug().Msgf("cached token provider: using cached token for provider %s", p.provider.Name()) + return cached, nil + } + + // Need to refresh + p.mutex.Lock() + defer p.mutex.Unlock() + + // Double-check after acquiring write lock + if p.cache != nil && !p.shouldRefresh(p.cache) { + return p.cache, nil + } + + log.Debug().Msgf("cached token provider: fetching new token from provider %s", p.provider.Name()) + token, err := p.provider.GetToken(ctx) + if err != nil { + return nil, fmt.Errorf("cached token provider: failed to get token: %w", err) + } + + p.cache = token + return token, nil +} + +// shouldRefresh determines if a token should be refreshed +func (p *CachedTokenProvider) shouldRefresh(token *Token) bool { + if token == nil { + return true + } + + // If no expiry time, assume token doesn't expire + if token.ExpiresAt.IsZero() { + return false + } + + // Refresh if within threshold of expiry + refreshAt := token.ExpiresAt.Add(-p.RefreshThreshold) + return time.Now().After(refreshAt) +} + +// Name returns the provider name +func (p *CachedTokenProvider) Name() string { + return fmt.Sprintf("cached[%s]", p.provider.Name()) +} + +// ClearCache clears the cached token +func (p *CachedTokenProvider) ClearCache() { + p.mutex.Lock() + p.cache = nil + p.mutex.Unlock() +} diff --git a/auth/tokenprovider/exchange.go b/auth/tokenprovider/exchange.go new file mode 100644 index 00000000..8c0bba60 --- /dev/null +++ b/auth/tokenprovider/exchange.go @@ -0,0 +1,204 @@ +package tokenprovider + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/rs/zerolog/log" +) + +// FederationProvider wraps another token provider and automatically handles token exchange +type FederationProvider struct { + baseProvider TokenProvider + databricksHost string + clientID string // For SP-wide federation + httpClient *http.Client + // Settings for token exchange + returnOriginalTokenIfAuthenticated bool +} + +// NewFederationProvider creates a federation provider that wraps another provider +// It automatically detects when token exchange is needed and falls back gracefully +func NewFederationProvider(baseProvider TokenProvider, databricksHost string) *FederationProvider { + return &FederationProvider{ + baseProvider: baseProvider, + databricksHost: databricksHost, + httpClient: &http.Client{Timeout: 30 * time.Second}, + returnOriginalTokenIfAuthenticated: true, + } +} + +// NewFederationProviderWithClientID creates a provider for SP-wide federation (M2M) +func NewFederationProviderWithClientID(baseProvider TokenProvider, databricksHost, clientID string) *FederationProvider { + return &FederationProvider{ + baseProvider: baseProvider, + databricksHost: databricksHost, + clientID: clientID, + httpClient: &http.Client{Timeout: 30 * time.Second}, + returnOriginalTokenIfAuthenticated: true, + } +} + +// GetToken gets token from base provider and exchanges if needed +func (p *FederationProvider) GetToken(ctx context.Context) (*Token, error) { + // Get token from base provider + baseToken, err := p.baseProvider.GetToken(ctx) + if err != nil { + return nil, fmt.Errorf("federation provider: failed to get base token: %w", err) + } + + // Check if token is a JWT and needs exchange + if p.needsTokenExchange(baseToken.AccessToken) { + log.Debug().Msgf("federation provider: attempting token exchange for %s", p.baseProvider.Name()) + + // Try token exchange + exchangedToken, err := p.tryTokenExchange(ctx, baseToken.AccessToken) + if err != nil { + log.Warn().Err(err).Msg("federation provider: token exchange failed, using original token") + return baseToken, nil // Fall back to original token + } + + log.Debug().Msg("federation provider: token exchange successful") + return exchangedToken, nil + } + + // Use original token + return baseToken, nil +} + +// needsTokenExchange determines if a token needs exchange by checking if it's from a different issuer +func (p *FederationProvider) needsTokenExchange(tokenString string) bool { + // Try to parse as JWT + token, _, err := new(jwt.Parser).ParseUnverified(tokenString, jwt.MapClaims{}) + if err != nil { + log.Debug().Err(err).Msg("federation provider: not a JWT token, skipping exchange") + return false + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + return false + } + + issuer, ok := claims["iss"].(string) + if !ok { + return false + } + + // Check if issuer is different from Databricks host + return !p.isSameHost(issuer, p.databricksHost) +} + +// tryTokenExchange attempts to exchange the token with Databricks +func (p *FederationProvider) tryTokenExchange(ctx context.Context, subjectToken string) (*Token, error) { + // Build exchange URL - add scheme if not present + exchangeURL := p.databricksHost + if !strings.HasPrefix(exchangeURL, "http://") && !strings.HasPrefix(exchangeURL, "https://") { + exchangeURL = "https://" + exchangeURL + } + if !strings.HasSuffix(exchangeURL, "/") { + exchangeURL += "/" + } + exchangeURL += "oidc/v1/token" + + // Prepare form data for token exchange + data := url.Values{} + data.Set("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange") + data.Set("scope", "sql") + data.Set("subject_token_type", "urn:ietf:params:oauth:token-type:jwt") + data.Set("subject_token", subjectToken) + + if p.returnOriginalTokenIfAuthenticated { + data.Set("return_original_token_if_authenticated", "true") + } + + // Add client_id for SP-wide federation + if p.clientID != "" { + data.Set("client_id", p.clientID) + } + + // Create request + req, err := http.NewRequestWithContext(ctx, "POST", exchangeURL, strings.NewReader(data.Encode())) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "*/*") + + // Make request + resp, err := p.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("request failed: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("exchange failed with status %d: %s", resp.StatusCode, string(body)) + } + + // Parse response + var tokenResp struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + Scope string `json:"scope"` + } + + if err := json.Unmarshal(body, &tokenResp); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + token := &Token{ + AccessToken: tokenResp.AccessToken, + TokenType: tokenResp.TokenType, + Scopes: strings.Fields(tokenResp.Scope), + } + + if tokenResp.ExpiresIn > 0 { + token.ExpiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) + } + + return token, nil +} + +// isSameHost compares two URLs to see if they have the same host +func (p *FederationProvider) isSameHost(url1, url2 string) bool { + // Add scheme to url2 if it doesn't have one (databricksHost may not have scheme) + parsedURL2 := url2 + if !strings.HasPrefix(url2, "http://") && !strings.HasPrefix(url2, "https://") { + parsedURL2 = "https://" + url2 + } + + u1, err1 := url.Parse(url1) + u2, err2 := url.Parse(parsedURL2) + + if err1 != nil || err2 != nil { + return false + } + + // Use Hostname() instead of Host to ignore port differences + // This handles cases like "host.com:443" == "host.com" for HTTPS + return u1.Hostname() == u2.Hostname() +} + +// Name returns the provider name +func (p *FederationProvider) Name() string { + baseName := p.baseProvider.Name() + if p.clientID != "" { + return fmt.Sprintf("federation[%s,sp:%s]", baseName, p.clientID[:8]) // Truncate client ID for readability + } + return fmt.Sprintf("federation[%s]", baseName) +} diff --git a/auth/tokenprovider/federation_test.go b/auth/tokenprovider/federation_test.go new file mode 100644 index 00000000..554b7333 --- /dev/null +++ b/auth/tokenprovider/federation_test.go @@ -0,0 +1,348 @@ +package tokenprovider + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Helper function to create JWT tokens for testing +func createTestJWT(issuer, audience string, expiryHours int) string { + claims := jwt.MapClaims{ + "iss": issuer, + "aud": audience, + "exp": time.Now().Add(time.Duration(expiryHours) * time.Hour).Unix(), + "sub": "test-user", + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenString, _ := token.SignedString([]byte("test-secret")) + return tokenString +} + +func TestFederationProvider_HostComparison(t *testing.T) { + tests := []struct { + name string + issuer string + databricksHost string + shouldExchange bool + }{ + { + name: "same_host_no_port", + issuer: "https://test.databricks.com", + databricksHost: "test.databricks.com", + shouldExchange: false, + }, + { + name: "same_host_with_port_443", + issuer: "https://test.databricks.com:443", + databricksHost: "test.databricks.com", + shouldExchange: false, + }, + { + name: "same_host_both_with_port", + issuer: "https://test.databricks.com:443", + databricksHost: "test.databricks.com:443", + shouldExchange: false, + }, + { + name: "different_host_azure", + issuer: "https://login.microsoftonline.com/tenant-id/", + databricksHost: "test.databricks.com", + shouldExchange: true, + }, + { + name: "different_host_google", + issuer: "https://accounts.google.com", + databricksHost: "test.databricks.com", + shouldExchange: true, + }, + { + name: "different_host_aws", + issuer: "https://cognito-identity.amazonaws.com", + databricksHost: "test.databricks.com", + shouldExchange: true, + }, + { + name: "different_databricks_host", + issuer: "https://test1.databricks.com", + databricksHost: "test2.databricks.com", + shouldExchange: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a JWT token with the specified issuer + jwtToken := createTestJWT(tt.issuer, "databricks", 1) + + // Create a mock base provider + baseProvider := NewStaticTokenProvider(jwtToken) + + // Create federation provider + fedProvider := NewFederationProvider(baseProvider, tt.databricksHost) + + // Check if token needs exchange + needsExchange := fedProvider.needsTokenExchange(jwtToken) + assert.Equal(t, tt.shouldExchange, needsExchange, + "issuer=%s, host=%s, expected shouldExchange=%v, got=%v", + tt.issuer, tt.databricksHost, tt.shouldExchange, needsExchange) + }) + } +} + +func TestFederationProvider_TokenExchangeSuccess(t *testing.T) { + // Create mock token exchange server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify request method and path + assert.Equal(t, "POST", r.Method) + assert.Contains(t, r.URL.Path, "/oidc/v1/token") + + // Verify headers + assert.Equal(t, "application/x-www-form-urlencoded", r.Header.Get("Content-Type")) + assert.Equal(t, "*/*", r.Header.Get("Accept")) + + // Parse form data + err := r.ParseForm() + require.NoError(t, err) + + // Verify form parameters + assert.Equal(t, "urn:ietf:params:oauth:grant-type:token-exchange", r.FormValue("grant_type")) + assert.Equal(t, "sql", r.FormValue("scope")) + assert.Equal(t, "urn:ietf:params:oauth:token-type:jwt", r.FormValue("subject_token_type")) + assert.NotEmpty(t, r.FormValue("subject_token")) + assert.Equal(t, "true", r.FormValue("return_original_token_if_authenticated")) + + // Return successful token response + response := map[string]interface{}{ + "access_token": "exchanged-databricks-token", + "token_type": "Bearer", + "expires_in": 3600, + "scope": "sql", + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + // Create external token with different issuer + externalToken := createTestJWT("https://login.microsoftonline.com/tenant-id/", "databricks", 1) + baseProvider := NewStaticTokenProvider(externalToken) + + // Create federation provider pointing to mock server + // Use full URL including http:// scheme for test server + fedProvider := NewFederationProvider(baseProvider, server.URL) + + // Get token - should trigger exchange + ctx := context.Background() + token, err := fedProvider.GetToken(ctx) + + require.NoError(t, err) + assert.Equal(t, "exchanged-databricks-token", token.AccessToken) + assert.Equal(t, "Bearer", token.TokenType) + assert.False(t, token.ExpiresAt.IsZero()) + assert.Contains(t, token.Scopes, "sql") +} + +func TestFederationProvider_TokenExchangeWithClientID(t *testing.T) { + clientID := "test-client-id-12345" + + // Create mock server that checks for client_id + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + err := r.ParseForm() + require.NoError(t, err) + + // Verify client_id is present + assert.Equal(t, clientID, r.FormValue("client_id")) + + response := map[string]interface{}{ + "access_token": "sp-wide-federation-token", + "token_type": "Bearer", + "expires_in": 3600, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + externalToken := createTestJWT("https://login.microsoftonline.com/tenant-id/", "databricks", 1) + baseProvider := NewStaticTokenProvider(externalToken) + + fedProvider := NewFederationProviderWithClientID(baseProvider, server.URL, clientID) + + ctx := context.Background() + token, err := fedProvider.GetToken(ctx) + + require.NoError(t, err) + assert.Equal(t, "sp-wide-federation-token", token.AccessToken) +} + +func TestFederationProvider_TokenExchangeFailureFallback(t *testing.T) { + // Create mock server that returns error + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"error": "invalid_request"}`)) + })) + defer server.Close() + + externalToken := createTestJWT("https://login.microsoftonline.com/tenant-id/", "databricks", 1) + baseProvider := NewStaticTokenProvider(externalToken) + + fedProvider := NewFederationProvider(baseProvider, server.URL) + + ctx := context.Background() + token, err := fedProvider.GetToken(ctx) + + // Should not error - falls back to external token + require.NoError(t, err) + assert.Equal(t, externalToken, token.AccessToken, "Should fall back to original token on exchange failure") + assert.Equal(t, "Bearer", token.TokenType) +} + +func TestFederationProvider_NoExchangeWhenSameIssuer(t *testing.T) { + // Create token with Databricks as issuer + databricksHost := "test.databricks.com" + databricksToken := createTestJWT("https://"+databricksHost, "databricks", 1) + baseProvider := NewStaticTokenProvider(databricksToken) + + fedProvider := NewFederationProvider(baseProvider, databricksHost) + + ctx := context.Background() + token, err := fedProvider.GetToken(ctx) + + // Should not exchange - just return original token + require.NoError(t, err) + assert.Equal(t, databricksToken, token.AccessToken, "Should use original token when issuer matches") +} + +func TestFederationProvider_NonJWTToken(t *testing.T) { + // Use a non-JWT token (e.g., opaque PAT) + opaqueToken := "dapi1234567890abcdef" + baseProvider := NewStaticTokenProvider(opaqueToken) + + fedProvider := NewFederationProvider(baseProvider, "test.databricks.com") + + ctx := context.Background() + token, err := fedProvider.GetToken(ctx) + + // Should not error - just pass through non-JWT token + require.NoError(t, err) + assert.Equal(t, opaqueToken, token.AccessToken, "Should pass through non-JWT tokens") +} + +func TestFederationProvider_ProviderName(t *testing.T) { + baseProvider := NewStaticTokenProvider("test-token") + + t.Run("without_client_id", func(t *testing.T) { + fedProvider := NewFederationProvider(baseProvider, "test.databricks.com") + assert.Equal(t, "federation[static]", fedProvider.Name()) + }) + + t.Run("with_client_id", func(t *testing.T) { + fedProvider := NewFederationProviderWithClientID(baseProvider, "test.databricks.com", "client-12345678-more") + // Should truncate client ID to first 8 chars + assert.Equal(t, "federation[static,sp:client-1]", fedProvider.Name()) + }) +} + +func TestFederationProvider_CachedIntegration(t *testing.T) { + callCount := 0 + exchangeCount := 0 + + // Mock server that counts exchanges + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + exchangeCount++ + response := map[string]interface{}{ + "access_token": "databricks-token", + "token_type": "Bearer", + "expires_in": 3600, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + // External provider that counts calls + externalProvider := &mockProvider{ + tokenFunc: func() (*Token, error) { + callCount++ + externalToken := createTestJWT("https://login.microsoftonline.com/tenant/", "databricks", 1) + return &Token{ + AccessToken: externalToken, + TokenType: "Bearer", + ExpiresAt: time.Now().Add(1 * time.Hour), + }, nil + }, + name: "external", + } + + fedProvider := NewFederationProvider(externalProvider, server.URL) + cachedProvider := NewCachedTokenProvider(fedProvider) + + ctx := context.Background() + + // First call - should call external provider and exchange + token1, err1 := cachedProvider.GetToken(ctx) + require.NoError(t, err1) + assert.Equal(t, "databricks-token", token1.AccessToken) + assert.Equal(t, 1, callCount, "External provider should be called once") + assert.Equal(t, 1, exchangeCount, "Token should be exchanged once") + + // Second call - should use cache + token2, err2 := cachedProvider.GetToken(ctx) + require.NoError(t, err2) + assert.Equal(t, "databricks-token", token2.AccessToken) + assert.Equal(t, 1, callCount, "External provider should still be called only once (cached)") + assert.Equal(t, 1, exchangeCount, "Token should still be exchanged only once (cached)") +} + +func TestFederationProvider_InvalidJWT(t *testing.T) { + // Test with various invalid JWT formats + testCases := []string{ + "not.a.jwt", + "invalid-token-format", + "", + } + + for _, invalidToken := range testCases { + t.Run("invalid_jwt_"+invalidToken, func(t *testing.T) { + baseProvider := NewStaticTokenProvider(invalidToken) + fedProvider := NewFederationProvider(baseProvider, "test.databricks.com") + + // Should not need exchange for invalid JWT + needsExchange := fedProvider.needsTokenExchange(invalidToken) + assert.False(t, needsExchange, "Invalid JWT should not require exchange") + }) + } +} + +func TestFederationProvider_RealWorldIssuers(t *testing.T) { + // Test with real-world identity provider issuers + issuers := map[string]string{ + "azure_ad": "https://login.microsoftonline.com/72f988bf-86f1-41af-91ab-2d7cd011db47/v2.0", + "google": "https://accounts.google.com", + "aws_cognito": "https://cognito-idp.us-east-1.amazonaws.com/us-east-1_example", + "okta": "https://dev-12345.okta.com/oauth2/default", + "auth0": "https://dev-12345.auth0.com/", + "github": "https://token.actions.githubusercontent.com", + } + + databricksHost := "test.databricks.com" + + for name, issuer := range issuers { + t.Run(name, func(t *testing.T) { + jwtToken := createTestJWT(issuer, "databricks", 1) + baseProvider := NewStaticTokenProvider(jwtToken) + fedProvider := NewFederationProvider(baseProvider, databricksHost) + + needsExchange := fedProvider.needsTokenExchange(jwtToken) + assert.True(t, needsExchange, "Token from %s should require exchange", name) + }) + } +} diff --git a/connector.go b/connector.go index fce77970..da5f7cbf 100644 --- a/connector.go +++ b/connector.go @@ -323,3 +323,26 @@ func WithStaticToken(token string) ConnOption { } } } + +// WithFederatedTokenProvider sets up authentication using token federation +// It wraps the base provider and automatically handles token exchange if needed +func WithFederatedTokenProvider(baseProvider tokenprovider.TokenProvider) ConnOption { + return func(c *config.Config) { + if baseProvider != nil { + // Wrap with federation provider that auto-detects need for token exchange + federationProvider := tokenprovider.NewFederationProvider(baseProvider, c.Host) + c.Authenticator = tokenprovider.NewAuthenticator(federationProvider) + } + } +} + +// WithFederatedTokenProviderAndClientID sets up SP-wide token federation +func WithFederatedTokenProviderAndClientID(baseProvider tokenprovider.TokenProvider, clientID string) ConnOption { + return func(c *config.Config) { + if baseProvider != nil { + // Wrap with federation provider for SP-wide federation + federationProvider := tokenprovider.NewFederationProviderWithClientID(baseProvider, c.Host, clientID) + c.Authenticator = tokenprovider.NewAuthenticator(federationProvider) + } + } +}