Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions auth/tokenprovider/authenticator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package tokenprovider

import (
"context"
"fmt"
"net/http"

"github.com/databricks/databricks-sql-go/auth"
"github.com/rs/zerolog/log"
)

// TokenProviderAuthenticator implements auth.Authenticator using a TokenProvider.
//
// Authentication Flow:
// 1. On each Authenticate() call, retrieves a token from the configured TokenProvider
// 2. The provider may implement its own caching and refresh logic
// 3. Validates the returned token is non-empty
// 4. Sets the Authorization header with the token type and value
//
// The authenticator delegates all token management (caching, refresh, expiry)
// to the underlying TokenProvider implementation.
type TokenProviderAuthenticator struct {
provider TokenProvider
}

// NewAuthenticator creates an authenticator from a token provider
func NewAuthenticator(provider TokenProvider) auth.Authenticator {
return &TokenProviderAuthenticator{
provider: provider,
}
}

// Authenticate implements auth.Authenticator
func (a *TokenProviderAuthenticator) Authenticate(r *http.Request) error {
ctx := r.Context()
if ctx == nil {
ctx = context.Background()
}

token, err := a.provider.GetToken(ctx)
if err != nil {
return fmt.Errorf("token provider authenticator: failed to get token: %w", err)
}

if token.AccessToken == "" {
return fmt.Errorf("token provider authenticator: empty access token")
}

token.SetAuthHeader(r)
log.Debug().Msgf("token provider authenticator: authenticated using provider %s", a.provider.Name())

return nil
}
107 changes: 107 additions & 0 deletions auth/tokenprovider/authenticator_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package tokenprovider

import (
"context"
"errors"
"net/http"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestTokenProviderAuthenticator(t *testing.T) {
t.Run("successful_authentication", func(t *testing.T) {
provider := NewStaticTokenProvider("test-token-123")
authenticator := NewAuthenticator(provider)

req, _ := http.NewRequest("GET", "http://example.com", nil)
err := authenticator.Authenticate(req)

require.NoError(t, err)
assert.Equal(t, "Bearer test-token-123", req.Header.Get("Authorization"))
})

t.Run("authentication_with_custom_token_type", func(t *testing.T) {
provider := NewStaticTokenProviderWithType("test-token", "MAC")
authenticator := NewAuthenticator(provider)

req, _ := http.NewRequest("GET", "http://example.com", nil)
err := authenticator.Authenticate(req)

require.NoError(t, err)
assert.Equal(t, "MAC test-token", req.Header.Get("Authorization"))
})

t.Run("authentication_error_propagation", func(t *testing.T) {
provider := &mockProvider{
tokenFunc: func() (*Token, error) {
return nil, errors.New("provider failed")
},
}
authenticator := NewAuthenticator(provider)

req, _ := http.NewRequest("GET", "http://example.com", nil)
err := authenticator.Authenticate(req)

assert.Error(t, err)
assert.Contains(t, err.Error(), "provider failed")
assert.Empty(t, req.Header.Get("Authorization"))
})

t.Run("empty_token_error", func(t *testing.T) {
provider := &mockProvider{
tokenFunc: func() (*Token, error) {
return &Token{
AccessToken: "",
TokenType: "Bearer",
}, nil
},
}
authenticator := NewAuthenticator(provider)

req, _ := http.NewRequest("GET", "http://example.com", nil)
err := authenticator.Authenticate(req)

assert.Error(t, err)
assert.Contains(t, err.Error(), "empty access token")
assert.Empty(t, req.Header.Get("Authorization"))
})

t.Run("uses_request_context", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately

provider := &mockProvider{
tokenFunc: func() (*Token, error) {
// This would normally check context cancellation
return &Token{
AccessToken: "test-token",
TokenType: "Bearer",
}, nil
},
}
authenticator := NewAuthenticator(provider)

req, _ := http.NewRequestWithContext(ctx, "GET", "http://example.com", nil)
err := authenticator.Authenticate(req)

// Even with cancelled context, this should work as our mock doesn't check it
require.NoError(t, err)
assert.Equal(t, "Bearer test-token", req.Header.Get("Authorization"))
})

t.Run("external_token_integration", func(t *testing.T) {
tokenFunc := func() (string, error) {
return "external-token-456", nil
}
provider := NewExternalTokenProvider(tokenFunc)
authenticator := NewAuthenticator(provider)

req, _ := http.NewRequest("POST", "http://example.com/api", nil)
err := authenticator.Authenticate(req)

require.NoError(t, err)
assert.Equal(t, "Bearer external-token-456", req.Header.Get("Authorization"))
})
}
58 changes: 58 additions & 0 deletions auth/tokenprovider/external.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package tokenprovider

import (
"context"
"fmt"
"time"
)

// ExternalTokenProvider provides tokens from an external source (passthrough).
// This provider calls a user-supplied function to retrieve tokens on-demand.
type ExternalTokenProvider struct {
tokenSource func() (string, error)
tokenType string
}

// NewExternalTokenProvider creates a provider that gets tokens from an external function
func NewExternalTokenProvider(tokenSource func() (string, error)) *ExternalTokenProvider {
return &ExternalTokenProvider{
tokenSource: tokenSource,
tokenType: "Bearer",
}
}

// NewExternalTokenProviderWithType creates a provider with a custom token type
func NewExternalTokenProviderWithType(tokenSource func() (string, error), tokenType string) *ExternalTokenProvider {
return &ExternalTokenProvider{
tokenSource: tokenSource,
tokenType: tokenType,
}
}

// GetToken retrieves the token from the external source
func (p *ExternalTokenProvider) GetToken(ctx context.Context) (*Token, error) {
// Check for cancellation first
if err := ctx.Err(); err != nil {
return nil, fmt.Errorf("external token provider: context cancelled: %w", err)
}

if p.tokenSource == nil {
return nil, fmt.Errorf("external token provider: token source is nil")
}

accessToken, err := p.tokenSource()
if err != nil {
return nil, fmt.Errorf("external token provider: failed to get token: %w", err)
}

return &Token{
AccessToken: accessToken,
TokenType: p.tokenType,
ExpiresAt: time.Time{}, // External tokens don't provide expiry info
}, nil
}

// Name returns the provider name
func (p *ExternalTokenProvider) Name() string {
return "external"
}
44 changes: 44 additions & 0 deletions auth/tokenprovider/provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package tokenprovider

import (
"context"
"net/http"
"time"
)

// TokenProvider is the interface for providing tokens from various sources
type TokenProvider interface {
// GetToken retrieves a valid access token
GetToken(ctx context.Context) (*Token, error)

// Name returns the provider name for logging/debugging
Name() string
}

// Token represents an access token with metadata
type Token struct {
AccessToken string
TokenType string
ExpiresAt time.Time
RefreshToken string
Scopes []string
}

// IsExpired checks if the token has expired
func (t *Token) IsExpired() bool {
if t.ExpiresAt.IsZero() {
return false // No expiry means token doesn't expire
}
// Consider token expired 30 seconds before actual expiry for safety
// This matches the standard buffer used by other Databricks SDKs
return time.Now().Add(30 * time.Second).After(t.ExpiresAt)
}

// SetAuthHeader sets the Authorization header on an HTTP request
func (t *Token) SetAuthHeader(r *http.Request) {
tokenType := t.TokenType
if tokenType == "" {
tokenType = "Bearer"
}
r.Header.Set("Authorization", tokenType+" "+t.AccessToken)
}
Loading
Loading