diff --git a/auth/tokenprovider/authenticator.go b/auth/tokenprovider/authenticator.go new file mode 100644 index 00000000..b4ef57ff --- /dev/null +++ b/auth/tokenprovider/authenticator.go @@ -0,0 +1,53 @@ +package tokenprovider + +import ( + "context" + "fmt" + "net/http" + + "github.com/databricks/databricks-sql-go/auth" + "github.com/rs/zerolog/log" +) + +// TokenProviderAuthenticator implements auth.Authenticator using a TokenProvider. +// +// Authentication Flow: +// 1. On each Authenticate() call, retrieves a token from the configured TokenProvider +// 2. The provider may implement its own caching and refresh logic +// 3. Validates the returned token is non-empty +// 4. Sets the Authorization header with the token type and value +// +// The authenticator delegates all token management (caching, refresh, expiry) +// to the underlying TokenProvider implementation. +type TokenProviderAuthenticator struct { + provider TokenProvider +} + +// NewAuthenticator creates an authenticator from a token provider +func NewAuthenticator(provider TokenProvider) auth.Authenticator { + return &TokenProviderAuthenticator{ + provider: provider, + } +} + +// Authenticate implements auth.Authenticator +func (a *TokenProviderAuthenticator) Authenticate(r *http.Request) error { + ctx := r.Context() + if ctx == nil { + ctx = context.Background() + } + + token, err := a.provider.GetToken(ctx) + if err != nil { + return fmt.Errorf("token provider authenticator: failed to get token: %w", err) + } + + if token.AccessToken == "" { + return fmt.Errorf("token provider authenticator: empty access token") + } + + token.SetAuthHeader(r) + log.Debug().Msgf("token provider authenticator: authenticated using provider %s", a.provider.Name()) + + return nil +} diff --git a/auth/tokenprovider/authenticator_test.go b/auth/tokenprovider/authenticator_test.go new file mode 100644 index 00000000..7b703c1a --- /dev/null +++ b/auth/tokenprovider/authenticator_test.go @@ -0,0 +1,107 @@ +package tokenprovider + +import ( + "context" + "errors" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTokenProviderAuthenticator(t *testing.T) { + t.Run("successful_authentication", func(t *testing.T) { + provider := NewStaticTokenProvider("test-token-123") + authenticator := NewAuthenticator(provider) + + req, _ := http.NewRequest("GET", "http://example.com", nil) + err := authenticator.Authenticate(req) + + require.NoError(t, err) + assert.Equal(t, "Bearer test-token-123", req.Header.Get("Authorization")) + }) + + t.Run("authentication_with_custom_token_type", func(t *testing.T) { + provider := NewStaticTokenProviderWithType("test-token", "MAC") + authenticator := NewAuthenticator(provider) + + req, _ := http.NewRequest("GET", "http://example.com", nil) + err := authenticator.Authenticate(req) + + require.NoError(t, err) + assert.Equal(t, "MAC test-token", req.Header.Get("Authorization")) + }) + + t.Run("authentication_error_propagation", func(t *testing.T) { + provider := &mockProvider{ + tokenFunc: func() (*Token, error) { + return nil, errors.New("provider failed") + }, + } + authenticator := NewAuthenticator(provider) + + req, _ := http.NewRequest("GET", "http://example.com", nil) + err := authenticator.Authenticate(req) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "provider failed") + assert.Empty(t, req.Header.Get("Authorization")) + }) + + t.Run("empty_token_error", func(t *testing.T) { + provider := &mockProvider{ + tokenFunc: func() (*Token, error) { + return &Token{ + AccessToken: "", + TokenType: "Bearer", + }, nil + }, + } + authenticator := NewAuthenticator(provider) + + req, _ := http.NewRequest("GET", "http://example.com", nil) + err := authenticator.Authenticate(req) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "empty access token") + assert.Empty(t, req.Header.Get("Authorization")) + }) + + t.Run("uses_request_context", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel immediately + + provider := &mockProvider{ + tokenFunc: func() (*Token, error) { + // This would normally check context cancellation + return &Token{ + AccessToken: "test-token", + TokenType: "Bearer", + }, nil + }, + } + authenticator := NewAuthenticator(provider) + + req, _ := http.NewRequestWithContext(ctx, "GET", "http://example.com", nil) + err := authenticator.Authenticate(req) + + // Even with cancelled context, this should work as our mock doesn't check it + require.NoError(t, err) + assert.Equal(t, "Bearer test-token", req.Header.Get("Authorization")) + }) + + t.Run("external_token_integration", func(t *testing.T) { + tokenFunc := func() (string, error) { + return "external-token-456", nil + } + provider := NewExternalTokenProvider(tokenFunc) + authenticator := NewAuthenticator(provider) + + req, _ := http.NewRequest("POST", "http://example.com/api", nil) + err := authenticator.Authenticate(req) + + require.NoError(t, err) + assert.Equal(t, "Bearer external-token-456", req.Header.Get("Authorization")) + }) +} diff --git a/auth/tokenprovider/external.go b/auth/tokenprovider/external.go new file mode 100644 index 00000000..c2b6a9c5 --- /dev/null +++ b/auth/tokenprovider/external.go @@ -0,0 +1,58 @@ +package tokenprovider + +import ( + "context" + "fmt" + "time" +) + +// ExternalTokenProvider provides tokens from an external source (passthrough). +// This provider calls a user-supplied function to retrieve tokens on-demand. +type ExternalTokenProvider struct { + tokenSource func() (string, error) + tokenType string +} + +// NewExternalTokenProvider creates a provider that gets tokens from an external function +func NewExternalTokenProvider(tokenSource func() (string, error)) *ExternalTokenProvider { + return &ExternalTokenProvider{ + tokenSource: tokenSource, + tokenType: "Bearer", + } +} + +// NewExternalTokenProviderWithType creates a provider with a custom token type +func NewExternalTokenProviderWithType(tokenSource func() (string, error), tokenType string) *ExternalTokenProvider { + return &ExternalTokenProvider{ + tokenSource: tokenSource, + tokenType: tokenType, + } +} + +// GetToken retrieves the token from the external source +func (p *ExternalTokenProvider) GetToken(ctx context.Context) (*Token, error) { + // Check for cancellation first + if err := ctx.Err(); err != nil { + return nil, fmt.Errorf("external token provider: context cancelled: %w", err) + } + + if p.tokenSource == nil { + return nil, fmt.Errorf("external token provider: token source is nil") + } + + accessToken, err := p.tokenSource() + if err != nil { + return nil, fmt.Errorf("external token provider: failed to get token: %w", err) + } + + return &Token{ + AccessToken: accessToken, + TokenType: p.tokenType, + ExpiresAt: time.Time{}, // External tokens don't provide expiry info + }, nil +} + +// Name returns the provider name +func (p *ExternalTokenProvider) Name() string { + return "external" +} diff --git a/auth/tokenprovider/provider.go b/auth/tokenprovider/provider.go new file mode 100644 index 00000000..6faf0a9d --- /dev/null +++ b/auth/tokenprovider/provider.go @@ -0,0 +1,44 @@ +package tokenprovider + +import ( + "context" + "net/http" + "time" +) + +// TokenProvider is the interface for providing tokens from various sources +type TokenProvider interface { + // GetToken retrieves a valid access token + GetToken(ctx context.Context) (*Token, error) + + // Name returns the provider name for logging/debugging + Name() string +} + +// Token represents an access token with metadata +type Token struct { + AccessToken string + TokenType string + ExpiresAt time.Time + RefreshToken string + Scopes []string +} + +// IsExpired checks if the token has expired +func (t *Token) IsExpired() bool { + if t.ExpiresAt.IsZero() { + return false // No expiry means token doesn't expire + } + // Consider token expired 30 seconds before actual expiry for safety + // This matches the standard buffer used by other Databricks SDKs + return time.Now().Add(30 * time.Second).After(t.ExpiresAt) +} + +// SetAuthHeader sets the Authorization header on an HTTP request +func (t *Token) SetAuthHeader(r *http.Request) { + tokenType := t.TokenType + if tokenType == "" { + tokenType = "Bearer" + } + r.Header.Set("Authorization", tokenType+" "+t.AccessToken) +} diff --git a/auth/tokenprovider/provider_test.go b/auth/tokenprovider/provider_test.go new file mode 100644 index 00000000..e3df4753 --- /dev/null +++ b/auth/tokenprovider/provider_test.go @@ -0,0 +1,245 @@ +package tokenprovider + +import ( + "context" + "errors" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestToken_IsExpired(t *testing.T) { + tests := []struct { + name string + token *Token + expected bool + }{ + { + name: "token_without_expiry", + token: &Token{ + AccessToken: "test-token", + ExpiresAt: time.Time{}, + }, + expected: false, + }, + { + name: "token_expired", + token: &Token{ + AccessToken: "test-token", + ExpiresAt: time.Now().Add(-10 * time.Minute), + }, + expected: true, + }, + { + name: "token_not_expired", + token: &Token{ + AccessToken: "test-token", + ExpiresAt: time.Now().Add(10 * time.Minute), + }, + expected: false, + }, + { + name: "token_expires_within_30_seconds", + token: &Token{ + AccessToken: "test-token", + ExpiresAt: time.Now().Add(15 * time.Second), + }, + expected: true, // Should be considered expired due to 30-second buffer + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, tt.token.IsExpired()) + }) + } +} + +func TestToken_SetAuthHeader(t *testing.T) { + tests := []struct { + name string + token *Token + expectedHeader string + }{ + { + name: "bearer_token", + token: &Token{ + AccessToken: "test-access-token", + TokenType: "Bearer", + }, + expectedHeader: "Bearer test-access-token", + }, + { + name: "default_to_bearer", + token: &Token{ + AccessToken: "test-access-token", + TokenType: "", + }, + expectedHeader: "Bearer test-access-token", + }, + { + name: "custom_token_type", + token: &Token{ + AccessToken: "test-access-token", + TokenType: "CustomAuth", + }, + expectedHeader: "CustomAuth test-access-token", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req, _ := http.NewRequest("GET", "http://example.com", nil) + tt.token.SetAuthHeader(req) + assert.Equal(t, tt.expectedHeader, req.Header.Get("Authorization")) + }) + } +} + +func TestStaticTokenProvider(t *testing.T) { + t.Run("valid_token", func(t *testing.T) { + provider := NewStaticTokenProvider("static-token-123") + token, err := provider.GetToken(context.Background()) + + require.NoError(t, err) + assert.Equal(t, "static-token-123", token.AccessToken) + assert.Equal(t, "Bearer", token.TokenType) + assert.True(t, token.ExpiresAt.IsZero()) + assert.Equal(t, "static", provider.Name()) + }) + + t.Run("empty_token_error", func(t *testing.T) { + provider := NewStaticTokenProvider("") + token, err := provider.GetToken(context.Background()) + + assert.Error(t, err) + assert.Nil(t, token) + assert.Contains(t, err.Error(), "token is empty") + }) + + t.Run("custom_token_type", func(t *testing.T) { + provider := NewStaticTokenProviderWithType("static-token", "CustomAuth") + token, err := provider.GetToken(context.Background()) + + require.NoError(t, err) + assert.Equal(t, "static-token", token.AccessToken) + assert.Equal(t, "CustomAuth", token.TokenType) + }) + + t.Run("multiple_calls_same_token", func(t *testing.T) { + provider := NewStaticTokenProvider("static-token") + + token1, err1 := provider.GetToken(context.Background()) + token2, err2 := provider.GetToken(context.Background()) + + require.NoError(t, err1) + require.NoError(t, err2) + assert.Equal(t, token1.AccessToken, token2.AccessToken) + }) +} + +func TestExternalTokenProvider(t *testing.T) { + t.Run("successful_token_retrieval", func(t *testing.T) { + callCount := 0 + tokenFunc := func() (string, error) { + callCount++ + return "external-token-" + string(rune(callCount)), nil + } + + provider := NewExternalTokenProvider(tokenFunc) + token, err := provider.GetToken(context.Background()) + + require.NoError(t, err) + assert.Equal(t, "external-token-\x01", token.AccessToken) + assert.Equal(t, "Bearer", token.TokenType) + assert.Equal(t, "external", provider.Name()) + }) + + t.Run("token_function_error", func(t *testing.T) { + tokenFunc := func() (string, error) { + return "", errors.New("failed to retrieve token") + } + + provider := NewExternalTokenProvider(tokenFunc) + token, err := provider.GetToken(context.Background()) + + assert.Error(t, err) + assert.Nil(t, token) + assert.Contains(t, err.Error(), "failed to get token") + }) + + t.Run("empty_token_allowed", func(t *testing.T) { + tokenFunc := func() (string, error) { + return "", nil + } + + provider := NewExternalTokenProvider(tokenFunc) + token, err := provider.GetToken(context.Background()) + + assert.NoError(t, err) + assert.NotNil(t, token) + assert.Empty(t, token.AccessToken) + // Empty tokens are validated at the authenticator level, not provider level + }) + + t.Run("nil_function_error", func(t *testing.T) { + provider := NewExternalTokenProvider(nil) + token, err := provider.GetToken(context.Background()) + + assert.Error(t, err) + assert.Nil(t, token) + assert.Contains(t, err.Error(), "token source is nil") + }) + + t.Run("custom_token_type", func(t *testing.T) { + tokenFunc := func() (string, error) { + return "external-token", nil + } + + provider := NewExternalTokenProviderWithType(tokenFunc, "MAC") + token, err := provider.GetToken(context.Background()) + + require.NoError(t, err) + assert.Equal(t, "external-token", token.AccessToken) + assert.Equal(t, "MAC", token.TokenType) + }) + + t.Run("different_token_each_call", func(t *testing.T) { + counter := 0 + tokenFunc := func() (string, error) { + counter++ + return "token-" + string(rune(counter)), nil + } + + provider := NewExternalTokenProvider(tokenFunc) + + token1, err1 := provider.GetToken(context.Background()) + token2, err2 := provider.GetToken(context.Background()) + + require.NoError(t, err1) + require.NoError(t, err2) + assert.NotEqual(t, token1.AccessToken, token2.AccessToken) + assert.Equal(t, "token-\x01", token1.AccessToken) + assert.Equal(t, "token-\x02", token2.AccessToken) + }) +} + +// Mock provider for testing +type mockProvider struct { + tokenFunc func() (*Token, error) + name string +} + +func (m *mockProvider) GetToken(ctx context.Context) (*Token, error) { + if m.tokenFunc != nil { + return m.tokenFunc() + } + return nil, errors.New("not implemented") +} + +func (m *mockProvider) Name() string { + return m.name +} diff --git a/auth/tokenprovider/static.go b/auth/tokenprovider/static.go new file mode 100644 index 00000000..46079ba0 --- /dev/null +++ b/auth/tokenprovider/static.go @@ -0,0 +1,47 @@ +package tokenprovider + +import ( + "context" + "fmt" + "time" +) + +// StaticTokenProvider provides a static token that never changes +type StaticTokenProvider struct { + token string + tokenType string +} + +// NewStaticTokenProvider creates a provider with a static token +func NewStaticTokenProvider(token string) *StaticTokenProvider { + return &StaticTokenProvider{ + token: token, + tokenType: "Bearer", + } +} + +// NewStaticTokenProviderWithType creates a provider with a static token and custom type +func NewStaticTokenProviderWithType(token string, tokenType string) *StaticTokenProvider { + return &StaticTokenProvider{ + token: token, + tokenType: tokenType, + } +} + +// GetToken returns the static token +func (p *StaticTokenProvider) GetToken(ctx context.Context) (*Token, error) { + if p.token == "" { + return nil, fmt.Errorf("static token provider: token is empty") + } + + return &Token{ + AccessToken: p.token, + TokenType: p.tokenType, + ExpiresAt: time.Time{}, // Static tokens don't expire + }, nil +} + +// Name returns the provider name +func (p *StaticTokenProvider) Name() string { + return "static" +} diff --git a/connector.go b/connector.go index 53908b4c..fce77970 100644 --- a/connector.go +++ b/connector.go @@ -12,6 +12,7 @@ import ( "github.com/databricks/databricks-sql-go/auth" "github.com/databricks/databricks-sql-go/auth/oauth/m2m" "github.com/databricks/databricks-sql-go/auth/pat" + "github.com/databricks/databricks-sql-go/auth/tokenprovider" "github.com/databricks/databricks-sql-go/driverctx" dbsqlerr "github.com/databricks/databricks-sql-go/errors" "github.com/databricks/databricks-sql-go/internal/cli_service" @@ -293,3 +294,32 @@ func WithClientCredentials(clientID, clientSecret string) ConnOption { } } } + +// WithTokenProvider sets up authentication using a custom token provider +func WithTokenProvider(provider tokenprovider.TokenProvider) ConnOption { + return func(c *config.Config) { + if provider != nil { + c.Authenticator = tokenprovider.NewAuthenticator(provider) + } + } +} + +// WithExternalToken sets up authentication using an external token function (passthrough) +func WithExternalToken(tokenFunc func() (string, error)) ConnOption { + return func(c *config.Config) { + if tokenFunc != nil { + provider := tokenprovider.NewExternalTokenProvider(tokenFunc) + c.Authenticator = tokenprovider.NewAuthenticator(provider) + } + } +} + +// WithStaticToken sets up authentication using a static token +func WithStaticToken(token string) ConnOption { + return func(c *config.Config) { + if token != "" { + provider := tokenprovider.NewStaticTokenProvider(token) + c.Authenticator = tokenprovider.NewAuthenticator(provider) + } + } +} diff --git a/go.mod b/go.mod index d9a517c5..1d1fdc78 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/apache/arrow/go/v12 v12.0.1 github.com/apache/thrift v0.17.0 github.com/coreos/go-oidc/v3 v3.5.0 + github.com/golang-jwt/jwt/v5 v5.2.1 github.com/joho/godotenv v1.4.0 github.com/mattn/go-isatty v0.0.20 github.com/pierrec/lz4/v4 v4.1.15 diff --git a/go.sum b/go.sum index edeb89ee..670487a8 100644 --- a/go.sum +++ b/go.sum @@ -26,6 +26,8 @@ github.com/go-jose/go-jose/v3 v3.0.4/go.mod h1:5b+7YgP7ZICgJDBdfjZaIt+H/9L9T/YQr github.com/goccy/go-json v0.9.11 h1:/pAaQDLHEoCq/5FFmSKBswWmK6H0e8g4159Kc/X/nqk= github.com/goccy/go-json v0.9.11/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw=