From 1deb2ae1144129ebad28b5510c503159cffaeb33 Mon Sep 17 00:00:00 2001 From: amirejaz Date: Wed, 10 Sep 2025 16:33:00 +0100 Subject: [PATCH 1/2] implement flexible issuer validation for OIDC discovery --- cmd/thv/app/proxy.go | 34 +++--- pkg/auth/discovery/discovery.go | 21 ++-- pkg/auth/oauth/dynamic_registration_test.go | 2 +- pkg/auth/oauth/oidc.go | 17 +-- pkg/auth/oauth/oidc_test.go | 116 ++++++++++++++------ pkg/runner/remote_auth.go | 20 ++-- 6 files changed, 133 insertions(+), 77 deletions(-) diff --git a/cmd/thv/app/proxy.go b/cmd/thv/app/proxy.go index 291f1fe83..50d252512 100644 --- a/cmd/thv/app/proxy.go +++ b/cmd/thv/app/proxy.go @@ -295,14 +295,15 @@ func handleOutgoingAuthentication(ctx context.Context) (*oauth2.TokenSource, *oa } flowConfig := &discovery.OAuthFlowConfig{ - ClientID: remoteAuthFlags.RemoteAuthClientID, - ClientSecret: clientSecret, - AuthorizeURL: remoteAuthFlags.RemoteAuthAuthorizeURL, - TokenURL: remoteAuthFlags.RemoteAuthTokenURL, - Scopes: remoteAuthFlags.RemoteAuthScopes, - CallbackPort: remoteAuthFlags.RemoteAuthCallbackPort, - Timeout: remoteAuthFlags.RemoteAuthTimeout, - SkipBrowser: remoteAuthFlags.RemoteAuthSkipBrowser, + ClientID: remoteAuthFlags.RemoteAuthClientID, + ClientSecret: clientSecret, + AuthorizeURL: remoteAuthFlags.RemoteAuthAuthorizeURL, + TokenURL: remoteAuthFlags.RemoteAuthTokenURL, + Scopes: remoteAuthFlags.RemoteAuthScopes, + CallbackPort: remoteAuthFlags.RemoteAuthCallbackPort, + Timeout: remoteAuthFlags.RemoteAuthTimeout, + SkipBrowser: remoteAuthFlags.RemoteAuthSkipBrowser, + IssuerProvided: remoteAuthFlags.RemoteAuthIssuer != "", // Issuer was explicitly provided } result, err := discovery.PerformOAuthFlow(ctx, remoteAuthFlags.RemoteAuthIssuer, flowConfig) @@ -325,14 +326,15 @@ func handleOutgoingAuthentication(ctx context.Context) (*oauth2.TokenSource, *oa // Perform OAuth flow with discovered configuration flowConfig := &discovery.OAuthFlowConfig{ - ClientID: remoteAuthFlags.RemoteAuthClientID, - ClientSecret: clientSecret, - AuthorizeURL: remoteAuthFlags.RemoteAuthAuthorizeURL, - TokenURL: remoteAuthFlags.RemoteAuthTokenURL, - Scopes: remoteAuthFlags.RemoteAuthScopes, - CallbackPort: remoteAuthFlags.RemoteAuthCallbackPort, - Timeout: remoteAuthFlags.RemoteAuthTimeout, - SkipBrowser: remoteAuthFlags.RemoteAuthSkipBrowser, + ClientID: remoteAuthFlags.RemoteAuthClientID, + ClientSecret: clientSecret, + AuthorizeURL: remoteAuthFlags.RemoteAuthAuthorizeURL, + TokenURL: remoteAuthFlags.RemoteAuthTokenURL, + Scopes: remoteAuthFlags.RemoteAuthScopes, + CallbackPort: remoteAuthFlags.RemoteAuthCallbackPort, + Timeout: remoteAuthFlags.RemoteAuthTimeout, + SkipBrowser: remoteAuthFlags.RemoteAuthSkipBrowser, + IssuerProvided: false, // Issuer was derived from WWW-Authenticate header } result, err := discovery.PerformOAuthFlow(ctx, authInfo.Realm, flowConfig) diff --git a/pkg/auth/discovery/discovery.go b/pkg/auth/discovery/discovery.go index 48020ce1a..486ca9e37 100644 --- a/pkg/auth/discovery/discovery.go +++ b/pkg/auth/discovery/discovery.go @@ -240,15 +240,16 @@ func DeriveIssuerFromURL(remoteURL string) string { // OAuthFlowConfig contains configuration for performing OAuth flows type OAuthFlowConfig struct { - ClientID string - ClientSecret string - AuthorizeURL string // Manual OAuth endpoint (optional) - TokenURL string // Manual OAuth endpoint (optional) - Scopes []string - CallbackPort int - Timeout time.Duration - SkipBrowser bool - OAuthParams map[string]string + ClientID string + ClientSecret string + AuthorizeURL string // Manual OAuth endpoint (optional) + TokenURL string // Manual OAuth endpoint (optional) + Scopes []string + CallbackPort int + Timeout time.Duration + SkipBrowser bool + OAuthParams map[string]string + IssuerProvided bool // Whether the issuer was explicitly provided (not derived from URL) } // OAuthFlowResult contains the result of an OAuth flow @@ -272,7 +273,7 @@ func PerformOAuthFlow(ctx context.Context, issuer string, config *OAuthFlowConfi var oauthConfig *oauth.Config var err error if shouldDynamicallyRegisterClient(config) { - discoveredDoc, err := oauth.DiscoverOIDCEndpoints(ctx, issuer) + discoveredDoc, err := oauth.DiscoverOIDCEndpoints(ctx, issuer, config.IssuerProvided) if err != nil { return nil, fmt.Errorf("failed to discover registration endpoint: %w", err) } diff --git a/pkg/auth/oauth/dynamic_registration_test.go b/pkg/auth/oauth/dynamic_registration_test.go index 91b69204f..3b2507bdc 100644 --- a/pkg/auth/oauth/dynamic_registration_test.go +++ b/pkg/auth/oauth/dynamic_registration_test.go @@ -125,7 +125,7 @@ func TestDiscoverOIDCEndpointsWithRegistration(t *testing.T) { issuer = server.URL } - result, err := DiscoverOIDCEndpoints(context.Background(), issuer) + result, err := DiscoverOIDCEndpoints(context.Background(), issuer, false) if tt.expectedError { assert.Error(t, err) diff --git a/pkg/auth/oauth/oidc.go b/pkg/auth/oauth/oidc.go index 028f2274c..902bc92ef 100644 --- a/pkg/auth/oauth/oidc.go +++ b/pkg/auth/oauth/oidc.go @@ -37,12 +37,13 @@ type httpClient interface { } // DiscoverOIDCEndpoints discovers OAuth endpoints from an OIDC issuer -func DiscoverOIDCEndpoints(ctx context.Context, issuer string) (*OIDCDiscoveryDocument, error) { - return discoverOIDCEndpointsWithClient(ctx, issuer, nil) +// Uses flexible issuer validation to support cases where issuer is derived from URL +func DiscoverOIDCEndpoints(ctx context.Context, issuer string, validateIssuerMatch bool) (*OIDCDiscoveryDocument, error) { + return discoverOIDCEndpointsWithClient(ctx, issuer, nil, validateIssuerMatch) } // discoverOIDCEndpointsWithClient discovers OAuth endpoints from an OIDC issuer with a custom HTTP client (private for testing) -func discoverOIDCEndpointsWithClient(ctx context.Context, issuer string, client httpClient) (*OIDCDiscoveryDocument, error) { +func discoverOIDCEndpointsWithClient(ctx context.Context, issuer string, client httpClient, validateIssuerMatch bool) (*OIDCDiscoveryDocument, error) { // Validate issuer URL issuerURL, err := url.Parse(issuer) if err != nil { @@ -98,7 +99,7 @@ func discoverOIDCEndpointsWithClient(ctx context.Context, issuer string, client if err := json.NewDecoder(io.LimitReader(resp.Body, maxResponseSize)).Decode(&doc); err != nil { return nil, fmt.Errorf("%s: unexpected response: %w", urlStr, err) } - if err := validateOIDCDocument(&doc, issuer, oidc); err != nil { + if err := validateOIDCDocument(&doc, issuer, validateIssuerMatch, oidc); err != nil { return nil, fmt.Errorf("%s: invalid metadata: %w", urlStr, err) } return &doc, nil @@ -120,12 +121,14 @@ func discoverOIDCEndpointsWithClient(ctx context.Context, issuer string, client } // validateOIDCDocument validates the OIDC discovery document -func validateOIDCDocument(doc *OIDCDiscoveryDocument, expectedIssuer string, oidc bool) error { +func validateOIDCDocument(doc *OIDCDiscoveryDocument, expectedIssuer string, validateIssuerMatch bool, oidc bool) error { if doc.Issuer == "" { return fmt.Errorf("missing issuer") } - if doc.Issuer != expectedIssuer { + // Only validate issuer match if explicitly requested + // This allows for cases where issuer is derived from URL and might not match exactly + if validateIssuerMatch && doc.Issuer != expectedIssuer { return fmt.Errorf("issuer mismatch: expected %s, got %s", expectedIssuer, doc.Issuer) } @@ -184,7 +187,7 @@ func createOAuthConfigFromOIDCWithClient( client httpClient, ) (*Config, error) { // Discover OIDC endpoints - doc, err := discoverOIDCEndpointsWithClient(ctx, issuer, client) + doc, err := discoverOIDCEndpointsWithClient(ctx, issuer, client, true) if err != nil { return nil, fmt.Errorf("failed to discover OIDC endpoints: %w", err) } diff --git a/pkg/auth/oauth/oidc_test.go b/pkg/auth/oauth/oidc_test.go index 898517007..a3671b50a 100644 --- a/pkg/auth/oauth/oidc_test.go +++ b/pkg/auth/oauth/oidc_test.go @@ -96,7 +96,7 @@ func testDiscoverOIDCEndpoints( } // Validate that we got the required fields - if err := validateOIDCDocument(&doc, issuer, true); err != nil { + if err := validateOIDCDocument(&doc, issuer, true, true); err != nil { return nil, fmt.Errorf("invalid OIDC configuration: %w", err) } @@ -319,11 +319,12 @@ func TestDiscoverOIDCEndpoints(t *testing.T) { func TestValidateOIDCDocument(t *testing.T) { t.Parallel() tests := []struct { - name string - doc *OIDCDiscoveryDocument - expectedIssuer string - expectError bool - errorMsg string + name string + doc *OIDCDiscoveryDocument + expectedIssuer string + validateIssuerMatch bool + expectError bool + errorMsg string }{ { name: "missing issuer", @@ -332,21 +333,23 @@ func TestValidateOIDCDocument(t *testing.T) { TokenEndpoint: "https://example.com/token", JWKSURI: "https://example.com/jwks", }, - expectedIssuer: "https://example.com", - expectError: true, - errorMsg: "missing issuer", + expectedIssuer: "https://example.com", + validateIssuerMatch: true, + expectError: true, + errorMsg: "missing issuer", }, { - name: "issuer mismatch", + name: "issuer mismatch (strict validation)", doc: &OIDCDiscoveryDocument{ Issuer: "https://malicious.com", AuthorizationEndpoint: "https://example.com/auth", TokenEndpoint: "https://example.com/token", JWKSURI: "https://example.com/jwks", }, - expectedIssuer: "https://example.com", - expectError: true, - errorMsg: "issuer mismatch", + expectedIssuer: "https://example.com", + validateIssuerMatch: true, + expectError: true, + errorMsg: "issuer mismatch", }, { name: "missing authorization endpoint", @@ -355,9 +358,10 @@ func TestValidateOIDCDocument(t *testing.T) { TokenEndpoint: "https://example.com/token", JWKSURI: "https://example.com/jwks", }, - expectedIssuer: "https://example.com", - expectError: true, - errorMsg: "missing authorization_endpoint", + expectedIssuer: "https://example.com", + validateIssuerMatch: true, + expectError: true, + errorMsg: "missing authorization_endpoint", }, { name: "missing token endpoint", @@ -366,9 +370,10 @@ func TestValidateOIDCDocument(t *testing.T) { AuthorizationEndpoint: "https://example.com/auth", JWKSURI: "https://example.com/jwks", }, - expectedIssuer: "https://example.com", - expectError: true, - errorMsg: "missing token_endpoint", + expectedIssuer: "https://example.com", + validateIssuerMatch: true, + expectError: true, + errorMsg: "missing token_endpoint", }, { name: "missing JWKS URI", @@ -377,9 +382,10 @@ func TestValidateOIDCDocument(t *testing.T) { AuthorizationEndpoint: "https://example.com/auth", TokenEndpoint: "https://example.com/token", }, - expectedIssuer: "https://example.com", - expectError: true, - errorMsg: "missing jwks_uri", + expectedIssuer: "https://example.com", + validateIssuerMatch: true, + expectError: true, + errorMsg: "missing jwks_uri", }, { name: "invalid authorization endpoint URL", @@ -389,9 +395,10 @@ func TestValidateOIDCDocument(t *testing.T) { TokenEndpoint: "https://example.com/token", JWKSURI: "https://example.com/jwks", }, - expectedIssuer: "https://example.com", - expectError: true, - errorMsg: "invalid authorization_endpoint", + expectedIssuer: "https://example.com", + validateIssuerMatch: true, + expectError: true, + errorMsg: "invalid authorization_endpoint", }, { name: "non-HTTPS endpoint (security check)", @@ -401,9 +408,10 @@ func TestValidateOIDCDocument(t *testing.T) { TokenEndpoint: "https://example.com/token", JWKSURI: "https://example.com/jwks", }, - expectedIssuer: "https://example.com", - expectError: true, - errorMsg: "invalid authorization_endpoint", + expectedIssuer: "https://example.com", + validateIssuerMatch: true, + expectError: true, + errorMsg: "invalid authorization_endpoint", }, { name: "valid document", @@ -414,8 +422,9 @@ func TestValidateOIDCDocument(t *testing.T) { JWKSURI: "https://example.com/jwks", UserinfoEndpoint: "https://example.com/userinfo", }, - expectedIssuer: "https://example.com", - expectError: false, + expectedIssuer: "https://example.com", + validateIssuerMatch: true, + expectError: false, }, { name: "localhost endpoints allowed", @@ -425,15 +434,54 @@ func TestValidateOIDCDocument(t *testing.T) { TokenEndpoint: "http://localhost:8080/token", JWKSURI: "http://localhost:8080/jwks", }, - expectedIssuer: "http://localhost:8080", - expectError: false, + expectedIssuer: "http://localhost:8080", + validateIssuerMatch: true, + expectError: false, + }, + // Flexible validation test cases + { + name: "flexible validation allows issuer mismatch", + doc: &OIDCDiscoveryDocument{ + Issuer: "https://auth.example.com", // Different from expected + AuthorizationEndpoint: "https://auth.example.com/auth", + TokenEndpoint: "https://auth.example.com/token", + JWKSURI: "https://auth.example.com/jwks", + }, + expectedIssuer: "https://example.com", // Expected issuer + validateIssuerMatch: false, // Flexible validation + expectError: false, // Should NOT error with flexible validation + }, + { + name: "flexible validation allows derived issuer mismatch (Neon scenario)", + doc: &OIDCDiscoveryDocument{ + Issuer: "https://auth.neon.com", // Different from derived issuer + AuthorizationEndpoint: "https://auth.neon.com/oauth/authorize", + TokenEndpoint: "https://auth.neon.com/oauth/token", + JWKSURI: "https://auth.neon.com/.well-known/jwks.json", + }, + expectedIssuer: "https://api.neon.com", // Derived from URL + validateIssuerMatch: false, // Flexible validation + expectError: false, // Should NOT error with flexible validation + }, + { + name: "flexible validation still requires issuer field", + doc: &OIDCDiscoveryDocument{ + // Missing issuer field + AuthorizationEndpoint: "https://example.com/auth", + TokenEndpoint: "https://example.com/token", + JWKSURI: "https://example.com/jwks", + }, + expectedIssuer: "https://example.com", + validateIssuerMatch: false, // Flexible validation + expectError: true, + errorMsg: "missing issuer", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - err := validateOIDCDocument(tt.doc, tt.expectedIssuer, true) + err := validateOIDCDocument(tt.doc, tt.expectedIssuer, tt.validateIssuerMatch, true) if tt.expectError { require.Error(t, err) @@ -1054,7 +1102,7 @@ func TestDiscoverOIDCEndpoints_Production(t *testing.T) { }, } } - doc, err := discoverOIDCEndpointsWithClient(ctx, issuer, client) + doc, err := discoverOIDCEndpointsWithClient(ctx, issuer, client, true) if tt.expectError { require.Error(t, err) diff --git a/pkg/runner/remote_auth.go b/pkg/runner/remote_auth.go index 847e661f6..c8e7995bb 100644 --- a/pkg/runner/remote_auth.go +++ b/pkg/runner/remote_auth.go @@ -40,6 +40,7 @@ func (h *RemoteAuthHandler) Authenticate(ctx context.Context, remoteURL string) // Handle OAuth authentication if authInfo.Type == "OAuth" { issuer := h.config.Issuer + issuerProvided := issuer != "" if issuer == "" { issuer = discovery.DeriveIssuerFromURL(remoteURL) } @@ -51,15 +52,16 @@ func (h *RemoteAuthHandler) Authenticate(ctx context.Context, remoteURL string) // Create OAuth flow config from RemoteAuthConfig flowConfig := &discovery.OAuthFlowConfig{ - ClientID: h.config.ClientID, - ClientSecret: h.config.ClientSecret, - AuthorizeURL: h.config.AuthorizeURL, - TokenURL: h.config.TokenURL, - Scopes: h.config.Scopes, - CallbackPort: h.config.CallbackPort, - Timeout: h.config.Timeout, - SkipBrowser: h.config.SkipBrowser, - OAuthParams: h.config.OAuthParams, + ClientID: h.config.ClientID, + ClientSecret: h.config.ClientSecret, + AuthorizeURL: h.config.AuthorizeURL, + TokenURL: h.config.TokenURL, + Scopes: h.config.Scopes, + CallbackPort: h.config.CallbackPort, + Timeout: h.config.Timeout, + SkipBrowser: h.config.SkipBrowser, + OAuthParams: h.config.OAuthParams, + IssuerProvided: issuerProvided, } result, err := discovery.PerformOAuthFlow(ctx, issuer, flowConfig) From 46b2813762c1a8391f007b6dac77b7d3e616d64f Mon Sep 17 00:00:00 2001 From: amirejaz Date: Wed, 10 Sep 2025 17:03:04 +0100 Subject: [PATCH 2/2] fix linting --- pkg/auth/oauth/oidc.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pkg/auth/oauth/oidc.go b/pkg/auth/oauth/oidc.go index 902bc92ef..3d0b8d640 100644 --- a/pkg/auth/oauth/oidc.go +++ b/pkg/auth/oauth/oidc.go @@ -43,7 +43,12 @@ func DiscoverOIDCEndpoints(ctx context.Context, issuer string, validateIssuerMat } // discoverOIDCEndpointsWithClient discovers OAuth endpoints from an OIDC issuer with a custom HTTP client (private for testing) -func discoverOIDCEndpointsWithClient(ctx context.Context, issuer string, client httpClient, validateIssuerMatch bool) (*OIDCDiscoveryDocument, error) { +func discoverOIDCEndpointsWithClient( + ctx context.Context, + issuer string, + client httpClient, + validateIssuerMatch bool, +) (*OIDCDiscoveryDocument, error) { // Validate issuer URL issuerURL, err := url.Parse(issuer) if err != nil {