diff --git a/client/transport/oauth.go b/client/transport/oauth.go index aebbd316e..451a1feae 100644 --- a/client/transport/oauth.go +++ b/client/transport/oauth.go @@ -118,16 +118,34 @@ type OAuthHandler struct { expectedState string // Expected state value for CSRF protection } -// NewOAuthHandler creates a new OAuth handler -func NewOAuthHandler(config OAuthConfig) *OAuthHandler { +type OAuthHandlerOption func(*OAuthHandler) + +// WithOAuthHTTPClient allows setting a custom http.Client for the OAuthHandler. +func WithOAuthHTTPClient(client *http.Client) OAuthHandlerOption { + return func(h *OAuthHandler) { + if client != nil { + h.httpClient = client + } + } +} + +// NewOAuthHandler creates a new OAuth handler. +// Optionally accepts functional options such as WithOAuthHTTPClient. +func NewOAuthHandler(config OAuthConfig, opts ...OAuthHandlerOption) *OAuthHandler { if config.TokenStore == nil { config.TokenStore = NewMemoryTokenStore() } - return &OAuthHandler{ + handler := &OAuthHandler{ config: config, httpClient: &http.Client{Timeout: 30 * time.Second}, } + + for _, opt := range opts { + opt(handler) + } + + return handler } // GetAuthorizationHeader returns the Authorization header value for a request diff --git a/client/transport/oauth_test.go b/client/transport/oauth_test.go index 24dec6eff..a7172af33 100644 --- a/client/transport/oauth_test.go +++ b/client/transport/oauth_test.go @@ -3,6 +3,7 @@ package transport import ( "context" "errors" + "net/http" "strings" "testing" "time" @@ -300,3 +301,26 @@ func TestOAuthHandler_ProcessAuthorizationResponse_StateValidation(t *testing.T) t.Errorf("Got ErrInvalidState when expected a different error for empty expected state") } } + +func TestNewOAuthHandler_WithOAuthHTTPClient(t *testing.T) { + // Custom client with unique timeout + customClient := &http.Client{Timeout: 123 * time.Second} + + // Handler with custom client + handlerWithCustom := NewOAuthHandler(OAuthConfig{}, WithOAuthHTTPClient(customClient)) + if handlerWithCustom.httpClient != customClient { + t.Errorf("Expected custom http.Client to be set via WithOAuthHTTPClient") + } + if handlerWithCustom.httpClient.Timeout != 123*time.Second { + t.Errorf("Expected custom http.Client timeout to be 123s, got %v", handlerWithCustom.httpClient.Timeout) + } + + // Handler with default client + handlerDefault := NewOAuthHandler(OAuthConfig{}) + if handlerDefault.httpClient == nil { + t.Errorf("Expected default http.Client to be set") + } + if handlerDefault.httpClient.Timeout != 30*time.Second { + t.Errorf("Expected default http.Client timeout to be 30s, got %v", handlerDefault.httpClient.Timeout) + } +}