From 9a0a6e593ff462ea980070dca7f62eb393286969 Mon Sep 17 00:00:00 2001 From: Tom Fenech Date: Tue, 14 May 2024 16:42:21 +0200 Subject: [PATCH 1/2] Pass OIDC claims into post-login flow to include in web hook context The login flow doesn't trigger a refresh of the identity when the OIDC claims have changed. By passing the claims through to the web hook context, this means that an external handler can be configured to update the identity as appropriate, when there are changes. --- selfservice/flow/login/hook.go | 29 +++++-- selfservice/hook/error.go | 3 +- selfservice/hook/require_verified_address.go | 3 +- .../hook/require_verified_address_test.go | 10 +-- selfservice/hook/session_destroyer.go | 11 ++- selfservice/hook/session_destroyer_test.go | 1 + selfservice/hook/show_verification_ui.go | 3 +- selfservice/hook/show_verification_ui_test.go | 12 +-- selfservice/hook/stub/test_body.jsonnet | 10 ++- selfservice/hook/verification.go | 3 +- selfservice/hook/verification_test.go | 4 +- selfservice/hook/web_hook.go | 5 +- selfservice/hook/web_hook_integration_test.go | 32 +++++--- selfservice/strategy/oidc/claims/claims.go | 50 ++++++++++++ .../strategy/oidc/claims/claims_test.go | 18 +++++ selfservice/strategy/oidc/claims/locale.go | 29 +++++++ selfservice/strategy/oidc/provider.go | 80 ++----------------- selfservice/strategy/oidc/provider_apple.go | 10 ++- .../strategy/oidc/provider_apple_test.go | 11 +-- selfservice/strategy/oidc/provider_auth0.go | 5 +- .../strategy/oidc/provider_dingtalk.go | 7 +- selfservice/strategy/oidc/provider_discord.go | 7 +- .../strategy/oidc/provider_facebook.go | 5 +- .../strategy/oidc/provider_generic_oidc.go | 13 +-- selfservice/strategy/oidc/provider_github.go | 5 +- .../strategy/oidc/provider_github_app.go | 5 +- selfservice/strategy/oidc/provider_gitlab.go | 5 +- selfservice/strategy/oidc/provider_google.go | 5 +- selfservice/strategy/oidc/provider_lark.go | 7 +- .../strategy/oidc/provider_linkedin.go | 5 +- .../strategy/oidc/provider_linkedin_test.go | 5 +- .../strategy/oidc/provider_microsoft.go | 5 +- selfservice/strategy/oidc/provider_netid.go | 15 ++-- selfservice/strategy/oidc/provider_patreon.go | 5 +- .../strategy/oidc/provider_salesforce.go | 5 +- selfservice/strategy/oidc/provider_slack.go | 5 +- selfservice/strategy/oidc/provider_spotify.go | 5 +- selfservice/strategy/oidc/provider_test.go | 17 ++-- .../strategy/oidc/provider_test_fedcm.go | 6 +- .../strategy/oidc/provider_userinfo_test.go | 17 ++-- selfservice/strategy/oidc/provider_vk.go | 5 +- selfservice/strategy/oidc/provider_x.go | 14 ++-- selfservice/strategy/oidc/provider_yandex.go | 5 +- selfservice/strategy/oidc/strategy.go | 7 +- .../strategy/oidc/strategy_helper_test.go | 4 +- selfservice/strategy/oidc/strategy_login.go | 7 +- .../strategy/oidc/strategy_registration.go | 11 ++- .../strategy/oidc/strategy_settings.go | 3 +- selfservice/strategy/oidc/strategy_test.go | 6 +- selfservice/strategy/oidc/token_verifier.go | 6 +- 50 files changed, 321 insertions(+), 225 deletions(-) create mode 100644 selfservice/strategy/oidc/claims/claims.go create mode 100644 selfservice/strategy/oidc/claims/claims_test.go create mode 100644 selfservice/strategy/oidc/claims/locale.go diff --git a/selfservice/flow/login/hook.go b/selfservice/flow/login/hook.go index 8faed1ccc7de..76f697d3210d 100644 --- a/selfservice/flow/login/hook.go +++ b/selfservice/flow/login/hook.go @@ -22,6 +22,7 @@ import ( "github.com/ory/kratos/schema" "github.com/ory/kratos/selfservice/flow" "github.com/ory/kratos/selfservice/sessiontokenexchange" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/session" "github.com/ory/kratos/ui/container" "github.com/ory/kratos/ui/node" @@ -36,7 +37,7 @@ type ( } PostHookExecutor interface { - ExecuteLoginPostHook(w http.ResponseWriter, r *http.Request, g node.UiNodeGroup, a *Flow, s *session.Session) error + ExecuteLoginPostHook(w http.ResponseWriter, r *http.Request, g node.UiNodeGroup, a *Flow, s *session.Session, c *claims.Claims) error } HooksProvider interface { @@ -66,6 +67,7 @@ type ( } HookExecutor struct { d executorDependencies + c *claims.Claims } HookExecutorProvider interface { LoginHookExecutor() *HookExecutor @@ -123,6 +125,14 @@ func (e *HookExecutor) handleLoginError(_ http.ResponseWriter, r *http.Request, return flowError } +type PostLoginHookOpt func(*HookExecutor) + +func WithClaims(c *claims.Claims) PostLoginHookOpt { + return func(h *HookExecutor) { + h.c = c + } +} + func (e *HookExecutor) PostLoginHook( w http.ResponseWriter, r *http.Request, @@ -131,6 +141,7 @@ func (e *HookExecutor) PostLoginHook( i *identity.Identity, s *session.Session, provider string, + opts ...PostLoginHookOpt, ) (err error) { ctx := r.Context() ctx, span := e.d.Tracer(ctx).Tracer().Start(ctx, "HookExecutor.PostLoginHook") @@ -149,15 +160,15 @@ func (e *HookExecutor) PostLoginHook( return err } - c := e.d.Config() + cfg := e.d.Config() // Verify the redirect URL before we do any other processing. returnTo, err := redir.SecureRedirectTo(r, - c.SelfServiceBrowserDefaultReturnTo(ctx), + cfg.SelfServiceBrowserDefaultReturnTo(ctx), redir.SecureRedirectReturnTo(f.ReturnTo), redir.SecureRedirectUseSourceURL(f.RequestURL), - redir.SecureRedirectAllowURLs(c.SelfServiceBrowserAllowedReturnToDomains(ctx)), - redir.SecureRedirectAllowSelfServiceURLs(c.SelfPublicURL(ctx)), - redir.SecureRedirectOverrideDefaultReturnTo(c.SelfServiceFlowLoginReturnTo(ctx, f.Active.String())), + redir.SecureRedirectAllowURLs(cfg.SelfServiceBrowserAllowedReturnToDomains(ctx)), + redir.SecureRedirectAllowSelfServiceURLs(cfg.SelfPublicURL(ctx)), + redir.SecureRedirectOverrideDefaultReturnTo(cfg.SelfServiceFlowLoginReturnTo(ctx, f.Active.String())), ) if err != nil { return err @@ -175,6 +186,10 @@ func (e *HookExecutor) PostLoginHook( classified := s s = s.Declassified() + for _, o := range opts { + o(e) + } + e.d.Logger(). WithRequest(r). WithField("identity_id", i.ID). @@ -185,7 +200,7 @@ func (e *HookExecutor) PostLoginHook( return err } for k, executor := range hooks { - if err := executor.ExecuteLoginPostHook(w, r, g, f, s); err != nil { + if err := executor.ExecuteLoginPostHook(w, r, g, f, s, e.c); err != nil { if errors.Is(err, ErrHookAbortFlow) { e.d.Logger(). WithRequest(r). diff --git a/selfservice/hook/error.go b/selfservice/hook/error.go index bb396578e5f8..a82d9e3511fe 100644 --- a/selfservice/hook/error.go +++ b/selfservice/hook/error.go @@ -12,6 +12,7 @@ import ( "github.com/ory/kratos/selfservice/flow/recovery" "github.com/ory/kratos/selfservice/flow/verification" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/ui/node" "github.com/ory/kratos/identity" @@ -64,7 +65,7 @@ func (e Error) ExecuteSettingsPostPersistHook(w http.ResponseWriter, r *http.Req return e.err("ExecuteSettingsPostPersistHook", settings.ErrHookAbortFlow) } -func (e Error) ExecuteLoginPostHook(w http.ResponseWriter, r *http.Request, g node.UiNodeGroup, a *login.Flow, s *session.Session) error { +func (e Error) ExecuteLoginPostHook(w http.ResponseWriter, r *http.Request, g node.UiNodeGroup, a *login.Flow, s *session.Session, c *claims.Claims) error { return e.err("ExecuteLoginPostHook", login.ErrHookAbortFlow) } diff --git a/selfservice/hook/require_verified_address.go b/selfservice/hook/require_verified_address.go index 85f7f968d33f..ae7ed374ef41 100644 --- a/selfservice/hook/require_verified_address.go +++ b/selfservice/hook/require_verified_address.go @@ -9,6 +9,7 @@ import ( "github.com/ory/kratos/driver/config" "github.com/ory/kratos/selfservice/flow" "github.com/ory/kratos/selfservice/flow/verification" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/text" "github.com/ory/kratos/x" "github.com/ory/kratos/x/nosurfx" @@ -50,7 +51,7 @@ func NewAddressVerifier(r addressVerifierDependencies) *AddressVerifier { } } -func (e *AddressVerifier) ExecuteLoginPostHook(w http.ResponseWriter, r *http.Request, _ node.UiNodeGroup, f *login.Flow, s *session.Session) (err error) { +func (e *AddressVerifier) ExecuteLoginPostHook(w http.ResponseWriter, r *http.Request, _ node.UiNodeGroup, f *login.Flow, s *session.Session, _ *claims.Claims) (err error) { ctx, span := e.r.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.hook.Verifier.do") r = r.WithContext(ctx) defer otelx.End(span, &err) diff --git a/selfservice/hook/require_verified_address_test.go b/selfservice/hook/require_verified_address_test.go index e4066594b387..b4fdb0001689 100644 --- a/selfservice/hook/require_verified_address_test.go +++ b/selfservice/hook/require_verified_address_test.go @@ -101,7 +101,7 @@ func TestAddressVerifier(t *testing.T) { ID: x.NewUUID(), Identity: &identity.Identity{ID: x.NewUUID(), VerifiableAddresses: uc.verifiableAddresses}, } - err := verifier.ExecuteLoginPostHook(nil, httptest.NewRequest("GET", "http://example.com", nil), node.DefaultGroup, tc.flow, sessions) + err := verifier.ExecuteLoginPostHook(nil, httptest.NewRequest("GET", "http://example.com", nil), node.DefaultGroup, tc.flow, sessions, nil) if tc.neverError || uc.expectedError == nil { assert.NoError(t, err) } else { @@ -158,7 +158,7 @@ func TestAddressVerifier(t *testing.T) { } // Expect verification flow creation and ErrHookAbortFlow - err := verifier.ExecuteLoginPostHook(mockResponse, mockJSONReq, node.DefaultGroup, loginFlow, sessions) + err := verifier.ExecuteLoginPostHook(mockResponse, mockJSONReq, node.DefaultGroup, loginFlow, sessions, nil) assert.ErrorIs(t, err, login.ErrHookAbortFlow) // Verify response contains continueWith and ErrAddressNotVerified @@ -216,7 +216,7 @@ func TestAddressVerifier(t *testing.T) { } // Expect verification flow creation and redirect - err := verifier.ExecuteLoginPostHook(mockResponse, mockBrowserReq, node.DefaultGroup, browserFlow, sessions) + err := verifier.ExecuteLoginPostHook(mockResponse, mockBrowserReq, node.DefaultGroup, browserFlow, sessions, nil) assert.ErrorIs(t, err, login.ErrHookAbortFlow) // Verify redirect occurred @@ -254,7 +254,7 @@ func TestAddressVerifier(t *testing.T) { Identity: identity, } - err := verifier.ExecuteLoginPostHook(nil, mockRequest, node.DefaultGroup, verifiedFlow, sessions) + err := verifier.ExecuteLoginPostHook(nil, mockRequest, node.DefaultGroup, verifiedFlow, sessions, nil) assert.NoError(t, err) }) @@ -279,7 +279,7 @@ func TestAddressVerifier(t *testing.T) { Identity: identity, } - err := verifier.ExecuteLoginPostHook(nil, mockRequest, node.DefaultGroup, noAddressFlow, sessions) + err := verifier.ExecuteLoginPostHook(nil, mockRequest, node.DefaultGroup, noAddressFlow, sessions, nil) assert.ErrorIs(t, err, herodot.ErrMisconfiguration) }) }) diff --git a/selfservice/hook/session_destroyer.go b/selfservice/hook/session_destroyer.go index 16f7aa11b435..0e19b2d85d9a 100644 --- a/selfservice/hook/session_destroyer.go +++ b/selfservice/hook/session_destroyer.go @@ -11,14 +11,17 @@ import ( "github.com/ory/kratos/selfservice/flow/login" "github.com/ory/kratos/selfservice/flow/recovery" "github.com/ory/kratos/selfservice/flow/settings" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/session" "github.com/ory/kratos/ui/node" "github.com/ory/x/otelx" ) -var _ login.PostHookExecutor = new(SessionDestroyer) -var _ recovery.PostHookExecutor = new(SessionDestroyer) -var _ settings.PostHookPostPersistExecutor = new(SessionDestroyer) +var ( + _ login.PostHookExecutor = new(SessionDestroyer) + _ recovery.PostHookExecutor = new(SessionDestroyer) + _ settings.PostHookPostPersistExecutor = new(SessionDestroyer) +) type ( sessionDestroyerDependencies interface { @@ -34,7 +37,7 @@ func NewSessionDestroyer(r sessionDestroyerDependencies) *SessionDestroyer { return &SessionDestroyer{r: r} } -func (e *SessionDestroyer) ExecuteLoginPostHook(_ http.ResponseWriter, r *http.Request, _ node.UiNodeGroup, _ *login.Flow, s *session.Session) error { +func (e *SessionDestroyer) ExecuteLoginPostHook(_ http.ResponseWriter, r *http.Request, _ node.UiNodeGroup, _ *login.Flow, s *session.Session, _ *claims.Claims) error { return otelx.WithSpan(r.Context(), "selfservice.hook.SessionDestroyer.ExecuteLoginPostHook", func(ctx context.Context) error { if _, err := e.r.SessionPersister().RevokeSessionsIdentityExcept(ctx, s.Identity.ID, s.ID); err != nil { return err diff --git a/selfservice/hook/session_destroyer_test.go b/selfservice/hook/session_destroyer_test.go index e2d0cc21c2e2..833dd2968a61 100644 --- a/selfservice/hook/session_destroyer_test.go +++ b/selfservice/hook/session_destroyer_test.go @@ -52,6 +52,7 @@ func TestSessionDestroyer(t *testing.T) { node.DefaultGroup, nil, &session.Session{Identity: i}, + nil, ) }, }, diff --git a/selfservice/hook/show_verification_ui.go b/selfservice/hook/show_verification_ui.go index c670f50c6257..92cd8e1251c6 100644 --- a/selfservice/hook/show_verification_ui.go +++ b/selfservice/hook/show_verification_ui.go @@ -17,6 +17,7 @@ import ( "github.com/ory/kratos/selfservice/flow" "github.com/ory/kratos/selfservice/flow/login" "github.com/ory/kratos/selfservice/flow/registration" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/session" "github.com/ory/kratos/ui/node" "github.com/ory/kratos/x" @@ -59,7 +60,7 @@ func (e *ShowVerificationUIHook) ExecutePostRegistrationPostPersistHook(_ http.R // ExecuteLoginPostHook adds redirect headers and status code if the request is a browser request. // If the request is not a browser request, this hook does nothing. -func (e *ShowVerificationUIHook) ExecuteLoginPostHook(_ http.ResponseWriter, r *http.Request, _ node.UiNodeGroup, f *login.Flow, _ *session.Session) error { +func (e *ShowVerificationUIHook) ExecuteLoginPostHook(_ http.ResponseWriter, r *http.Request, _ node.UiNodeGroup, f *login.Flow, _ *session.Session, _ *claims.Claims) error { return e.execute(r, f) } diff --git a/selfservice/hook/show_verification_ui_test.go b/selfservice/hook/show_verification_ui_test.go index 75601d488039..f8e8e569d266 100644 --- a/selfservice/hook/show_verification_ui_test.go +++ b/selfservice/hook/show_verification_ui_test.go @@ -85,7 +85,7 @@ func TestExecutePostRegistrationPostPersistHook(t *testing.T) { browserRequest := httptest.NewRequest("GET", "/", nil) f := &login.Flow{} rec := httptest.NewRecorder() - require.NoError(t, h.ExecuteLoginPostHook(rec, browserRequest, "", f, nil)) + require.NoError(t, h.ExecuteLoginPostHook(rec, browserRequest, "", f, nil, nil)) require.Equal(t, 200, rec.Code) }) @@ -96,7 +96,7 @@ func TestExecutePostRegistrationPostPersistHook(t *testing.T) { browserRequest.Header.Add("Accept", "application/json") f := &login.Flow{} rec := httptest.NewRecorder() - require.NoError(t, h.ExecuteLoginPostHook(rec, browserRequest, "", f, nil)) + require.NoError(t, h.ExecuteLoginPostHook(rec, browserRequest, "", f, nil, nil)) require.Equal(t, 200, rec.Code) }) @@ -113,7 +113,7 @@ func TestExecutePostRegistrationPostPersistHook(t *testing.T) { flow.NewContinueWithVerificationUI(vf.ID, "some@ory.sh", ""), } rec := httptest.NewRecorder() - require.NoError(t, h.ExecuteLoginPostHook(rec, browserRequest, "", rf, nil)) + require.NoError(t, h.ExecuteLoginPostHook(rec, browserRequest, "", rf, nil, nil)) assert.Equal(t, 200, rec.Code) assert.Equal(t, "/verification?flow="+vf.ID.String(), rf.ReturnToVerification) }) @@ -128,7 +128,7 @@ func TestExecutePostRegistrationPostPersistHook(t *testing.T) { flow.NewContinueWithSetToken("token"), } rec := httptest.NewRecorder() - require.NoError(t, h.ExecuteLoginPostHook(rec, browserRequest, "", rf, nil)) + require.NoError(t, h.ExecuteLoginPostHook(rec, browserRequest, "", rf, nil, nil)) assert.Equal(t, 200, rec.Code) }) }) @@ -201,7 +201,7 @@ func TestExecutePostRegistrationPostPersistHook(t *testing.T) { lf.InternalContext = internalContext rec := httptest.NewRecorder() - require.NoError(t, h.ExecuteLoginPostHook(rec, browserRequest, "", lf, nil)) + require.NoError(t, h.ExecuteLoginPostHook(rec, browserRequest, "", lf, nil, nil)) assert.Equal(t, 200, rec.Code) assert.Equal(t, "/verification?flow="+vfID.String(), lf.ReturnToVerification) }) @@ -220,7 +220,7 @@ func TestExecutePostRegistrationPostPersistHook(t *testing.T) { lf.InternalContext = internalContext rec := httptest.NewRecorder() - err = h.ExecuteLoginPostHook(rec, browserRequest, "", lf, nil) + err = h.ExecuteLoginPostHook(rec, browserRequest, "", lf, nil, nil) require.Error(t, err) }) }) diff --git a/selfservice/hook/stub/test_body.jsonnet b/selfservice/hook/stub/test_body.jsonnet index 117ecc587707..f406ad51076e 100644 --- a/selfservice/hook/stub/test_body.jsonnet +++ b/selfservice/hook/stub/test_body.jsonnet @@ -1,10 +1,14 @@ function(ctx) std.prune({ flow_id: ctx.flow.id, - identity_id: if std.objectHas(ctx, "identity") then ctx.identity.id, - session_id: if std.objectHas(ctx, "session") then ctx.session.id, + identity_id: if std.objectHas(ctx, 'identity') then ctx.identity.id, + session_id: if std.objectHas(ctx, 'session') then ctx.session.id, headers: ctx.request_headers, url: ctx.request_url, method: ctx.request_method, cookies: ctx.request_cookies, - transient_payload: if std.objectHas(ctx.flow, "transient_payload") then ctx.flow.transient_payload, + transient_payload: if std.objectHas(ctx.flow, 'transient_payload') then ctx.flow.transient_payload, + nickname: if std.objectHas(ctx, 'claims') then ctx.claims.nickname, + groups: if std.objectHas(ctx, 'claims') && + std.objectHas(ctx.claims, 'raw_claims') && + std.objectHas(ctx.claims.raw_claims, 'groups') then ctx.claims.raw_claims.groups, }) diff --git a/selfservice/hook/verification.go b/selfservice/hook/verification.go index 25816d95e489..b17cdf0e5a76 100644 --- a/selfservice/hook/verification.go +++ b/selfservice/hook/verification.go @@ -21,6 +21,7 @@ import ( "github.com/ory/kratos/selfservice/flow/registration" "github.com/ory/kratos/selfservice/flow/settings" "github.com/ory/kratos/selfservice/flow/verification" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/session" "github.com/ory/kratos/text" "github.com/ory/kratos/ui/node" @@ -72,7 +73,7 @@ func (e *Verifier) ExecuteSettingsPostPersistHook(w http.ResponseWriter, r *http }) } -func (e *Verifier) ExecuteLoginPostHook(w http.ResponseWriter, r *http.Request, g node.UiNodeGroup, f *login.Flow, s *session.Session) (err error) { +func (e *Verifier) ExecuteLoginPostHook(w http.ResponseWriter, r *http.Request, g node.UiNodeGroup, f *login.Flow, s *session.Session, c *claims.Claims) (err error) { ctx, span := e.r.Tracer(r.Context()).Tracer().Start(r.Context(), "selfservice.hook.Verifier.ExecuteLoginPostHook") r = r.WithContext(ctx) defer otelx.End(span, &err) diff --git a/selfservice/hook/verification_test.go b/selfservice/hook/verification_test.go index 4e78f320491e..181371670c4c 100644 --- a/selfservice/hook/verification_test.go +++ b/selfservice/hook/verification_test.go @@ -53,7 +53,7 @@ func TestVerifier(t *testing.T) { name: "login", execHook: func(h *hook.Verifier, i *identity.Identity, f flow.Flow) error { return h.ExecuteLoginPostHook( - httptest.NewRecorder(), u, node.CodeGroup, f.(*login.Flow), &session.Session{ID: x.NewUUID(), Identity: i}) + httptest.NewRecorder(), u, node.CodeGroup, f.(*login.Flow), &session.Session{ID: x.NewUUID(), Identity: i}, nil) }, originalFlow: func() interface { flow.InternalContexter @@ -158,7 +158,7 @@ func TestVerifier(t *testing.T) { h := hook.NewVerifier(reg) i := identity.NewIdentity(config.DefaultIdentityTraitsSchemaID) f := &login.Flow{RequestedAAL: "aal2"} - require.NoError(t, h.ExecuteLoginPostHook(httptest.NewRecorder(), u, node.CodeGroup, f, &session.Session{ID: x.NewUUID(), Identity: i})) + require.NoError(t, h.ExecuteLoginPostHook(httptest.NewRecorder(), u, node.CodeGroup, f, &session.Session{ID: x.NewUUID(), Identity: i}, nil)) messages, err := reg.CourierPersister().NextMessages(context.Background(), 12) require.EqualError(t, err, "queue is empty") diff --git a/selfservice/hook/web_hook.go b/selfservice/hook/web_hook.go index e73857317551..cd7f5371ab0f 100644 --- a/selfservice/hook/web_hook.go +++ b/selfservice/hook/web_hook.go @@ -34,6 +34,7 @@ import ( "github.com/ory/kratos/selfservice/flow/registration" "github.com/ory/kratos/selfservice/flow/settings" "github.com/ory/kratos/selfservice/flow/verification" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/session" "github.com/ory/kratos/text" "github.com/ory/kratos/ui/node" @@ -85,6 +86,7 @@ type ( RequestCookies map[string]string `json:"request_cookies"` Identity *identity.Identity `json:"identity,omitempty"` Session *session.Session `json:"session,omitempty"` + Claims *claims.Claims `json:"claims,omitempty"` } WebHook struct { @@ -135,7 +137,7 @@ func (e *WebHook) ExecuteLoginPreHook(_ http.ResponseWriter, req *http.Request, }) } -func (e *WebHook) ExecuteLoginPostHook(_ http.ResponseWriter, req *http.Request, _ node.UiNodeGroup, flow *login.Flow, session *session.Session) error { +func (e *WebHook) ExecuteLoginPostHook(_ http.ResponseWriter, req *http.Request, _ node.UiNodeGroup, flow *login.Flow, session *session.Session, claims *claims.Claims) error { return otelx.WithSpan(req.Context(), "selfservice.hook.WebHook.ExecuteLoginPostHook", func(ctx context.Context) error { return e.execute(ctx, &templateContext{ Flow: flow, @@ -145,6 +147,7 @@ func (e *WebHook) ExecuteLoginPostHook(_ http.ResponseWriter, req *http.Request, RequestCookies: cookies(req), Identity: session.Identity, Session: session, + Claims: claims, }) }) } diff --git a/selfservice/hook/web_hook_integration_test.go b/selfservice/hook/web_hook_integration_test.go index 03813c97e8d1..78e299122b11 100644 --- a/selfservice/hook/web_hook_integration_test.go +++ b/selfservice/hook/web_hook_integration_test.go @@ -42,6 +42,7 @@ import ( "github.com/ory/kratos/selfservice/flow/settings" "github.com/ory/kratos/selfservice/flow/verification" "github.com/ory/kratos/selfservice/hook" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/session" "github.com/ory/kratos/text" "github.com/ory/kratos/ui/node" @@ -62,6 +63,13 @@ var transientPayload = json.RawMessage(`{ } }`) +var oidcClaims = claims.Claims{ + Nickname: "nicky", + RawClaims: map[string]interface{}{ + "groups": []string{"first", "second"}, + }, +} + func TestWebHooks(t *testing.T) { ctx := context.Background() conf, reg := internal.NewFastRegistryWithMocks(t) @@ -190,7 +198,7 @@ func TestWebHooks(t *testing.T) { return body } - bodyWithFlowAndIdentityAndSessionAndTransientPayload := func(req *http.Request, f flow.Flow, s *session.Session, tp json.RawMessage) string { + bodyWithFlowAndIdentityAndSessionAndClaimsAndTransientPayload := func(req *http.Request, f flow.Flow, s *session.Session, c *claims.Claims, tp json.RawMessage) string { body := fmt.Sprintf(`{ "flow_id": "%s", "identity_id": "%s", @@ -202,8 +210,10 @@ func TestWebHooks(t *testing.T) { "Some-Cookie-2": "Some-other-Cookie-Value", "Some-Cookie-3": "Third-Cookie-Value" }, - "transient_payload": %s - }`, f.GetID(), s.Identity.ID, s.ID, req.Method, "http://www.ory.sh/some_end_point", string(tp)) + "transient_payload": %s, + "nickname": "%s", + "groups": ["%s", "%s"] + }`, f.GetID(), s.Identity.ID, s.ID, req.Method, "http://www.ory.sh/some_end_point", string(tp), c.Nickname, c.RawClaims["groups"].([]string)[0], c.RawClaims["groups"].([]string)[1]) if len(req.Header) != 0 { if ua := req.Header.Get("User-Agent"); ua != "" { body, _ = sjson.Set(body, "headers.User-Agent", []string{ua}) @@ -233,10 +243,10 @@ func TestWebHooks(t *testing.T) { uc: "Post Login Hook", createFlow: func() flow.Flow { return &login.Flow{ID: x.NewUUID(), TransientPayload: transientPayload} }, callWebHook: func(wh *hook.WebHook, req *http.Request, f flow.Flow, s *session.Session) error { - return wh.ExecuteLoginPostHook(nil, req, node.PasswordGroup, f.(*login.Flow), s) + return wh.ExecuteLoginPostHook(nil, req, node.PasswordGroup, f.(*login.Flow), s, &oidcClaims) }, expectedBody: func(req *http.Request, f flow.Flow, s *session.Session) string { - return bodyWithFlowAndIdentityAndSessionAndTransientPayload(req, f, s, transientPayload) + return bodyWithFlowAndIdentityAndSessionAndClaimsAndTransientPayload(req, f, s, &oidcClaims, transientPayload) }, }, { @@ -469,7 +479,7 @@ func TestWebHooks(t *testing.T) { uc: "Post Login Hook - no block", createFlow: func() flow.Flow { return &login.Flow{ID: x.NewUUID()} }, callWebHook: func(wh *hook.WebHook, req *http.Request, f flow.Flow, s *session.Session) error { - return wh.ExecuteLoginPostHook(nil, req, node.PasswordGroup, f.(*login.Flow), s) + return wh.ExecuteLoginPostHook(nil, req, node.PasswordGroup, f.(*login.Flow), s, nil) }, webHookResponse: func() (int, []byte) { return http.StatusOK, []byte{} @@ -480,7 +490,7 @@ func TestWebHooks(t *testing.T) { uc: "Post Login Hook - block", createFlow: func() flow.Flow { return &login.Flow{ID: x.NewUUID()} }, callWebHook: func(wh *hook.WebHook, req *http.Request, f flow.Flow, s *session.Session) error { - return wh.ExecuteLoginPostHook(nil, req, node.PasswordGroup, f.(*login.Flow), s) + return wh.ExecuteLoginPostHook(nil, req, node.PasswordGroup, f.(*login.Flow), s, nil) }, webHookResponse: func() (int, []byte) { return http.StatusBadRequest, webHookResponse @@ -1058,7 +1068,7 @@ func TestDisallowPrivateIPRanges(t *testing.T) { Method: "GET", TemplateURI: "file://stub/test_body.jsonnet", }) - err := wh.ExecuteLoginPostHook(nil, req, node.DefaultGroup, f, s) + err := wh.ExecuteLoginPostHook(nil, req, node.DefaultGroup, f, s, nil) require.Error(t, err) require.Contains(t, err.Error(), "is not a permitted destination") }) @@ -1070,7 +1080,7 @@ func TestDisallowPrivateIPRanges(t *testing.T) { Method: "GET", TemplateURI: "file://stub/test_body.jsonnet", }) - err := wh.ExecuteLoginPostHook(nil, req, node.DefaultGroup, f, s) + err := wh.ExecuteLoginPostHook(nil, req, node.DefaultGroup, f, s, nil) require.Error(t, err, "the target does not exist and we still receive an error") require.NotContains(t, err.Error(), "is not a permitted destination", "but the error is not related to the IP range.") }) @@ -1091,7 +1101,7 @@ func TestDisallowPrivateIPRanges(t *testing.T) { Method: "GET", TemplateURI: "http://192.168.178.0/test_body.jsonnet", }) - err := wh.ExecuteLoginPostHook(nil, req, node.DefaultGroup, f, s) + err := wh.ExecuteLoginPostHook(nil, req, node.DefaultGroup, f, s, nil) require.Error(t, err) require.Contains(t, err.Error(), "is not a permitted destination") }) @@ -1149,7 +1159,7 @@ func TestAsyncWebhook(t *testing.T) { Ignore: true, }, }) - err := wh.ExecuteLoginPostHook(nil, req, node.DefaultGroup, f, s) + err := wh.ExecuteLoginPostHook(nil, req, node.DefaultGroup, f, s, nil) require.NoError(t, err) // execution returns immediately for async webhook select { case <-time.After(5 * time.Second): diff --git a/selfservice/strategy/oidc/claims/claims.go b/selfservice/strategy/oidc/claims/claims.go new file mode 100644 index 000000000000..924b8eee6738 --- /dev/null +++ b/selfservice/strategy/oidc/claims/claims.go @@ -0,0 +1,50 @@ +package claims + +import ( + "github.com/pkg/errors" + + "github.com/ory/herodot" + "github.com/ory/kratos/x" +) + +type Claims struct { + Issuer string `json:"iss,omitempty"` + Subject string `json:"sub,omitempty"` + Object string `json:"oid,omitempty"` + Name string `json:"name,omitempty"` + GivenName string `json:"given_name,omitempty"` + FamilyName string `json:"family_name,omitempty"` + LastName string `json:"last_name,omitempty"` + MiddleName string `json:"middle_name,omitempty"` + Nickname string `json:"nickname,omitempty"` + PreferredUsername string `json:"preferred_username,omitempty"` + Profile string `json:"profile,omitempty"` + Picture string `json:"picture,omitempty"` + Website string `json:"website,omitempty"` + Email string `json:"email,omitempty"` + // ConvertibleBoolean is used as Apple casually sends the email_verified field as a string. + EmailVerified x.ConvertibleBoolean `json:"email_verified,omitempty"` + Gender string `json:"gender,omitempty"` + Birthdate string `json:"birthdate,omitempty"` + Zoneinfo string `json:"zoneinfo,omitempty"` + Locale Locale `json:"locale,omitempty"` + PhoneNumber string `json:"phone_number,omitempty"` + PhoneNumberVerified bool `json:"phone_number_verified,omitempty"` + UpdatedAt int64 `json:"updated_at,omitempty"` + HD string `json:"hd,omitempty"` + Team string `json:"team,omitempty"` + Nonce string `json:"nonce,omitempty"` + NonceSupported bool `json:"nonce_supported,omitempty"` + RawClaims map[string]interface{} `json:"raw_claims,omitempty"` +} + +// Validate checks if the claims are valid. +func (c *Claims) Validate() error { + if c.Subject == "" { + return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("provider did not return a subject")) + } + if c.Issuer == "" { + return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("issuer not set in claims")) + } + return nil +} diff --git a/selfservice/strategy/oidc/claims/claims_test.go b/selfservice/strategy/oidc/claims/claims_test.go new file mode 100644 index 000000000000..47ada4a4695c --- /dev/null +++ b/selfservice/strategy/oidc/claims/claims_test.go @@ -0,0 +1,18 @@ +package claims_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/ory/kratos/selfservice/strategy/oidc/claims" +) + +func TestClaimsValidate(t *testing.T) { + require.Error(t, new(claims.Claims).Validate()) + require.Error(t, (&claims.Claims{Issuer: "not-empty"}).Validate()) + require.Error(t, (&claims.Claims{Issuer: "not-empty"}).Validate()) + require.Error(t, (&claims.Claims{Subject: "not-empty"}).Validate()) + require.Error(t, (&claims.Claims{Subject: "not-empty"}).Validate()) + require.NoError(t, (&claims.Claims{Issuer: "not-empty", Subject: "not-empty"}).Validate()) +} diff --git a/selfservice/strategy/oidc/claims/locale.go b/selfservice/strategy/oidc/claims/locale.go new file mode 100644 index 000000000000..07fb8576becc --- /dev/null +++ b/selfservice/strategy/oidc/claims/locale.go @@ -0,0 +1,29 @@ +package claims + +import ( + "encoding/json" + "strings" +) + +type Locale string + +func (l *Locale) UnmarshalJSON(data []byte) error { + var linkedInLocale struct { + Language string `json:"language"` + Country string `json:"country"` + } + if err := json.Unmarshal(data, &linkedInLocale); err == nil { + switch { + case linkedInLocale.Language == "": + *l = Locale(linkedInLocale.Country) + case linkedInLocale.Country == "": + *l = Locale(linkedInLocale.Language) + default: + *l = Locale(strings.Join([]string{linkedInLocale.Language, linkedInLocale.Country}, "-")) + } + + return nil + } + + return json.Unmarshal(data, (*string)(l)) +} diff --git a/selfservice/strategy/oidc/provider.go b/selfservice/strategy/oidc/provider.go index 8d2e5edad189..dd7a7433ffe8 100644 --- a/selfservice/strategy/oidc/provider.go +++ b/selfservice/strategy/oidc/provider.go @@ -5,19 +5,14 @@ package oidc import ( "context" - "encoding/json" "net/http" "net/url" - "strings" "github.com/dghubble/oauth1" - "github.com/pkg/errors" - "github.com/ory/herodot" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "golang.org/x/oauth2" - - "github.com/ory/kratos/x" ) type ( @@ -28,13 +23,13 @@ type ( Provider AuthCodeURLOptions(r ider) []oauth2.AuthCodeOption OAuth2(ctx context.Context) (*oauth2.Config, error) - Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) + Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) } OAuth1Provider interface { Provider OAuth1(ctx context.Context) *oauth1.Config AuthURL(ctx context.Context, state string) (string, error) - Claims(ctx context.Context, token *oauth1.Token) (*Claims, error) + Claims(ctx context.Context, token *oauth1.Token) (*claims.Claims, error) ExchangeToken(ctx context.Context, req *http.Request) (*oauth1.Token, error) } ) @@ -44,76 +39,11 @@ type OAuth2TokenExchanger interface { } type IDTokenVerifier interface { - Verify(ctx context.Context, rawIDToken string) (*Claims, error) + Verify(ctx context.Context, rawIDToken string) (*claims.Claims, error) } type NonceValidationSkipper interface { - CanSkipNonce(*Claims) bool -} - -type Claims struct { - Issuer string `json:"iss,omitempty"` - Subject string `json:"sub,omitempty"` - Object string `json:"oid,omitempty"` - Name string `json:"name,omitempty"` - GivenName string `json:"given_name,omitempty"` - FamilyName string `json:"family_name,omitempty"` - LastName string `json:"last_name,omitempty"` - MiddleName string `json:"middle_name,omitempty"` - Nickname string `json:"nickname,omitempty"` - PreferredUsername string `json:"preferred_username,omitempty"` - Profile string `json:"profile,omitempty"` - Picture string `json:"picture,omitempty"` - Website string `json:"website,omitempty"` - Email string `json:"email,omitempty"` - // ConvertibleBoolean is used as Apple casually sends the email_verified field as a string. - EmailVerified x.ConvertibleBoolean `json:"email_verified,omitempty"` - Gender string `json:"gender,omitempty"` - Birthdate string `json:"birthdate,omitempty"` - Zoneinfo string `json:"zoneinfo,omitempty"` - Locale Locale `json:"locale,omitempty"` - PhoneNumber string `json:"phone_number,omitempty"` - PhoneNumberVerified bool `json:"phone_number_verified,omitempty"` - UpdatedAt int64 `json:"updated_at,omitempty"` - HD string `json:"hd,omitempty"` - Team string `json:"team,omitempty"` - Nonce string `json:"nonce,omitempty"` - NonceSupported bool `json:"nonce_supported,omitempty"` - RawClaims map[string]interface{} `json:"raw_claims,omitempty"` -} - -type Locale string - -func (l *Locale) UnmarshalJSON(data []byte) error { - var linkedInLocale struct { - Language string `json:"language"` - Country string `json:"country"` - } - if err := json.Unmarshal(data, &linkedInLocale); err == nil { - switch { - case linkedInLocale.Language == "": - *l = Locale(linkedInLocale.Country) - case linkedInLocale.Country == "": - *l = Locale(linkedInLocale.Language) - default: - *l = Locale(strings.Join([]string{linkedInLocale.Language, linkedInLocale.Country}, "-")) - } - - return nil - } - - return json.Unmarshal(data, (*string)(l)) -} - -// Validate checks if the claims are valid. -func (c *Claims) Validate() error { - if c.Subject == "" { - return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("provider did not return a subject")) - } - if c.Issuer == "" { - return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("issuer not set in claims")) - } - return nil + CanSkipNonce(*claims.Claims) bool } // UpstreamParameters returns a list of oauth2.AuthCodeOption based on the upstream parameters. diff --git a/selfservice/strategy/oidc/provider_apple.go b/selfservice/strategy/oidc/provider_apple.go index bc5523b22bbd..a2ae14c3ec45 100644 --- a/selfservice/strategy/oidc/provider_apple.go +++ b/selfservice/strategy/oidc/provider_apple.go @@ -15,6 +15,8 @@ import ( "github.com/coreos/go-oidc/v3/oidc" "github.com/golang-jwt/jwt/v4" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" + "github.com/pkg/errors" "golang.org/x/oauth2" @@ -114,7 +116,7 @@ func (a *ProviderApple) AuthCodeURLOptions(r ider) []oauth2.AuthCodeOption { return options } -func (a *ProviderApple) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) { +func (a *ProviderApple) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) { claims, err := a.ProviderGenericOIDC.Claims(ctx, exchange, query) if err != nil { return claims, err @@ -128,7 +130,7 @@ func (a *ProviderApple) Claims(ctx context.Context, exchange *oauth2.Token, quer // The info is sent as an extra query parameter to the redirect URL. // See https://developer.apple.com/documentation/sign_in_with_apple/sign_in_with_apple_js/configuring_your_webpage_for_sign_in_with_apple#3331292 // Note that there's no way to make sure the info hasn't been tampered with. -func (a *ProviderApple) DecodeQuery(query url.Values, claims *Claims) { +func (a *ProviderApple) DecodeQuery(query url.Values, claims *claims.Claims) { var user struct { Name *struct { FirstName *string `json:"firstName"` @@ -158,7 +160,7 @@ var _ IDTokenVerifier = new(ProviderApple) const issuerURLApple = "https://appleid.apple.com" -func (a *ProviderApple) Verify(ctx context.Context, rawIDToken string) (*Claims, error) { +func (a *ProviderApple) Verify(ctx context.Context, rawIDToken string) (*claims.Claims, error) { keySet := oidc.NewRemoteKeySet(ctx, a.JWKSUrl) ctx = oidc.ClientContext(ctx, a.reg.HTTPClient(ctx).HTTPClient) @@ -167,6 +169,6 @@ func (a *ProviderApple) Verify(ctx context.Context, rawIDToken string) (*Claims, var _ NonceValidationSkipper = new(ProviderApple) -func (a *ProviderApple) CanSkipNonce(c *Claims) bool { +func (a *ProviderApple) CanSkipNonce(c *claims.Claims) bool { return c.NonceSupported } diff --git a/selfservice/strategy/oidc/provider_apple_test.go b/selfservice/strategy/oidc/provider_apple_test.go index 422ae643708a..39c6c95462f4 100644 --- a/selfservice/strategy/oidc/provider_apple_test.go +++ b/selfservice/strategy/oidc/provider_apple_test.go @@ -20,6 +20,7 @@ import ( "github.com/ory/kratos/internal" "github.com/ory/kratos/selfservice/strategy/oidc" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" ) func TestDecodeQuery(t *testing.T) { @@ -28,15 +29,15 @@ func TestDecodeQuery(t *testing.T) { } for k, tc := range []struct { - claims *oidc.Claims + claims *claims.Claims familyName string givenName string lastName string }{ - {claims: &oidc.Claims{}, familyName: "last", givenName: "first", lastName: "last"}, - {claims: &oidc.Claims{FamilyName: "fam"}, familyName: "fam", givenName: "first", lastName: "last"}, - {claims: &oidc.Claims{FamilyName: "fam", GivenName: "giv"}, familyName: "fam", givenName: "giv", lastName: "last"}, - {claims: &oidc.Claims{FamilyName: "fam", GivenName: "giv", LastName: "las"}, familyName: "fam", givenName: "giv", lastName: "las"}, + {claims: &claims.Claims{}, familyName: "last", givenName: "first", lastName: "last"}, + {claims: &claims.Claims{FamilyName: "fam"}, familyName: "fam", givenName: "first", lastName: "last"}, + {claims: &claims.Claims{FamilyName: "fam", GivenName: "giv"}, familyName: "fam", givenName: "giv", lastName: "last"}, + {claims: &claims.Claims{FamilyName: "fam", GivenName: "giv", LastName: "las"}, familyName: "fam", givenName: "giv", lastName: "las"}, } { t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { a := oidc.NewProviderApple(&oidc.Configuration{}, nil).(*oidc.ProviderApple) diff --git a/selfservice/strategy/oidc/provider_auth0.go b/selfservice/strategy/oidc/provider_auth0.go index 50f4c03fc45b..5d71ab7f9a65 100644 --- a/selfservice/strategy/oidc/provider_auth0.go +++ b/selfservice/strategy/oidc/provider_auth0.go @@ -11,6 +11,7 @@ import ( "path" "time" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/httpx" "github.com/ory/x/stringsx" @@ -73,7 +74,7 @@ func (g *ProviderAuth0) OAuth2(ctx context.Context) (*oauth2.Config, error) { return g.oauth2(ctx) } -func (g *ProviderAuth0) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) { +func (g *ProviderAuth0) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) { o, err := g.OAuth2(ctx) if err != nil { return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) @@ -115,7 +116,7 @@ func (g *ProviderAuth0) Claims(ctx context.Context, exchange *oauth2.Token, quer } // Once we get here, we know that if there is an updated_at field in the json, it is the correct type. - var claims Claims + var claims claims.Claims if err := json.Unmarshal(b, &claims); err != nil { return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) } diff --git a/selfservice/strategy/oidc/provider_dingtalk.go b/selfservice/strategy/oidc/provider_dingtalk.go index 466c7d76406d..40c33985a249 100644 --- a/selfservice/strategy/oidc/provider_dingtalk.go +++ b/selfservice/strategy/oidc/provider_dingtalk.go @@ -13,6 +13,7 @@ import ( "github.com/pkg/errors" "golang.org/x/oauth2" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/httpx" "github.com/hashicorp/go-retryablehttp" @@ -42,7 +43,7 @@ func (g *ProviderDingTalk) Config() *Configuration { } func (g *ProviderDingTalk) oauth2(ctx context.Context) *oauth2.Config { - var endpoint = oauth2.Endpoint{ + endpoint := oauth2.Endpoint{ AuthURL: "https://login.dingtalk.com/oauth2/auth", TokenURL: "https://api.dingtalk.com/v1.0/oauth2/userAccessToken", } @@ -124,7 +125,7 @@ func (g *ProviderDingTalk) ExchangeOAuth2Token(ctx context.Context, code string, return token, nil } -func (g *ProviderDingTalk) Claims(ctx context.Context, exchange *oauth2.Token, _ url.Values) (*Claims, error) { +func (g *ProviderDingTalk) Claims(ctx context.Context, exchange *oauth2.Token, _ url.Values) (*claims.Claims, error) { userInfoURL := "https://api.dingtalk.com/v1.0/contact/users/me" accessToken := exchange.AccessToken @@ -162,7 +163,7 @@ func (g *ProviderDingTalk) Claims(ctx context.Context, exchange *oauth2.Token, _ return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("userResp.ErrCode = %s, userResp.ErrMsg = %s", user.ErrCode, user.ErrMsg)) } - return &Claims{ + return &claims.Claims{ Issuer: userInfoURL, Subject: user.OpenId, Nickname: user.Nick, diff --git a/selfservice/strategy/oidc/provider_discord.go b/selfservice/strategy/oidc/provider_discord.go index 97c64a4b414e..55322269e191 100644 --- a/selfservice/strategy/oidc/provider_discord.go +++ b/selfservice/strategy/oidc/provider_discord.go @@ -8,6 +8,7 @@ import ( "fmt" "net/url" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/x" "github.com/bwmarrin/discordgo" @@ -68,7 +69,7 @@ func (d *ProviderDiscord) AuthCodeURLOptions(r ider) []oauth2.AuthCodeOption { } } -func (d *ProviderDiscord) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) { +func (d *ProviderDiscord) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) { grantedScopes := stringsx.Splitx(fmt.Sprintf("%s", exchange.Extra("scope")), " ") for _, check := range d.Config().Scope { if !stringslice.Has(grantedScopes, check) { @@ -86,7 +87,7 @@ func (d *ProviderDiscord) Claims(ctx context.Context, exchange *oauth2.Token, qu return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) } - claims := &Claims{ + claims := &claims.Claims{ Issuer: discordgo.EndpointOauth2, Subject: user.ID, Name: fmt.Sprintf("%s#%s", user.Username, user.Discriminator), @@ -95,7 +96,7 @@ func (d *ProviderDiscord) Claims(ctx context.Context, exchange *oauth2.Token, qu Picture: user.AvatarURL(""), Email: user.Email, EmailVerified: x.ConvertibleBoolean(user.Verified), - Locale: Locale(user.Locale), + Locale: claims.Locale(user.Locale), } return claims, nil diff --git a/selfservice/strategy/oidc/provider_facebook.go b/selfservice/strategy/oidc/provider_facebook.go index a7d2ec689eaf..96844c08ab94 100644 --- a/selfservice/strategy/oidc/provider_facebook.go +++ b/selfservice/strategy/oidc/provider_facebook.go @@ -16,6 +16,7 @@ import ( "github.com/ory/x/httpx" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/x" "github.com/pkg/errors" @@ -64,7 +65,7 @@ func (g *ProviderFacebook) OAuth2(ctx context.Context) (*oauth2.Config, error) { return g.oauth2ConfigFromEndpoint(ctx, endpoint), nil } -func (g *ProviderFacebook) Claims(ctx context.Context, token *oauth2.Token, query url.Values) (*Claims, error) { +func (g *ProviderFacebook) Claims(ctx context.Context, token *oauth2.Token, query url.Values) (*claims.Claims, error) { o, err := g.OAuth2(ctx) if err != nil { return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) @@ -121,7 +122,7 @@ func (g *ProviderFacebook) Claims(ctx context.Context, token *oauth2.Token, quer user.EmailVerified = true } - return &Claims{ + return &claims.Claims{ Issuer: u.String(), Subject: user.Id, Name: user.Name, diff --git a/selfservice/strategy/oidc/provider_generic_oidc.go b/selfservice/strategy/oidc/provider_generic_oidc.go index 3bdb8d24ec31..8cb9cf1426af 100644 --- a/selfservice/strategy/oidc/provider_generic_oidc.go +++ b/selfservice/strategy/oidc/provider_generic_oidc.go @@ -13,6 +13,7 @@ import ( "golang.org/x/oauth2" "github.com/ory/herodot" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" ) var _ OAuth2Provider = (*ProviderGenericOIDC)(nil) @@ -95,13 +96,13 @@ func (g *ProviderGenericOIDC) AuthCodeURLOptions(r ider) []oauth2.AuthCodeOption return options } -func (g *ProviderGenericOIDC) verifyAndDecodeClaimsWithProvider(ctx context.Context, provider *gooidc.Provider, raw string) (*Claims, error) { +func (g *ProviderGenericOIDC) verifyAndDecodeClaimsWithProvider(ctx context.Context, provider *gooidc.Provider, raw string) (*claims.Claims, error) { token, err := provider.VerifierContext(g.withHTTPClientContext(ctx), &gooidc.Config{ClientID: g.config.ClientID}).Verify(ctx, raw) if err != nil { return nil, errors.WithStack(herodot.ErrBadRequest.WithReasonf("%s", err)) } - var claims Claims + var claims claims.Claims if err := token.Claims(&claims); err != nil { return nil, errors.WithStack(herodot.ErrBadRequest.WithReasonf("%s", err)) } @@ -115,7 +116,7 @@ func (g *ProviderGenericOIDC) verifyAndDecodeClaimsWithProvider(ctx context.Cont return &claims, nil } -func (g *ProviderGenericOIDC) Claims(ctx context.Context, exchange *oauth2.Token, _ url.Values) (*Claims, error) { +func (g *ProviderGenericOIDC) Claims(ctx context.Context, exchange *oauth2.Token, _ url.Values) (*claims.Claims, error) { switch g.config.ClaimsSource { case ClaimsSourceIDToken, "": return g.claimsFromIDToken(ctx, exchange) @@ -127,7 +128,7 @@ func (g *ProviderGenericOIDC) Claims(ctx context.Context, exchange *oauth2.Token WithReasonf("Unknown claims source: %q", g.config.ClaimsSource)) } -func (g *ProviderGenericOIDC) claimsFromUserInfo(ctx context.Context, exchange *oauth2.Token) (*Claims, error) { +func (g *ProviderGenericOIDC) claimsFromUserInfo(ctx context.Context, exchange *oauth2.Token) (*claims.Claims, error) { p, err := g.provider(ctx) if err != nil { return nil, err @@ -138,7 +139,7 @@ func (g *ProviderGenericOIDC) claimsFromUserInfo(ctx context.Context, exchange * return nil, err } - var claims Claims + var claims claims.Claims if err = userInfo.Claims(&claims); err != nil { return nil, err } @@ -177,7 +178,7 @@ func (g *ProviderGenericOIDC) claimsFromUserInfo(ctx context.Context, exchange * return &claims, nil } -func (g *ProviderGenericOIDC) claimsFromIDToken(ctx context.Context, exchange *oauth2.Token) (*Claims, error) { +func (g *ProviderGenericOIDC) claimsFromIDToken(ctx context.Context, exchange *oauth2.Token) (*claims.Claims, error) { p, raw, err := g.idTokenAndProvider(ctx, exchange) if err != nil { return nil, err diff --git a/selfservice/strategy/oidc/provider_github.go b/selfservice/strategy/oidc/provider_github.go index 650778cd1506..d86591e68358 100644 --- a/selfservice/strategy/oidc/provider_github.go +++ b/selfservice/strategy/oidc/provider_github.go @@ -8,6 +8,7 @@ import ( "fmt" "net/url" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/x" "github.com/pkg/errors" @@ -62,7 +63,7 @@ func (g *ProviderGitHub) AuthCodeURLOptions(r ider) []oauth2.AuthCodeOption { return []oauth2.AuthCodeOption{} } -func (g *ProviderGitHub) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) { +func (g *ProviderGitHub) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) { grantedScopes := stringsx.Splitx(fmt.Sprintf("%s", exchange.Extra("scope")), ",") for _, check := range g.Config().Scope { if !stringslice.Has(grantedScopes, check) { @@ -78,7 +79,7 @@ func (g *ProviderGitHub) Claims(ctx context.Context, exchange *oauth2.Token, que return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) } - claims := &Claims{ + claims := &claims.Claims{ Subject: fmt.Sprintf("%d", user.GetID()), Issuer: github.Endpoint.TokenURL, Name: user.GetName(), diff --git a/selfservice/strategy/oidc/provider_github_app.go b/selfservice/strategy/oidc/provider_github_app.go index 83cfd9bdb882..dfa55d225849 100644 --- a/selfservice/strategy/oidc/provider_github_app.go +++ b/selfservice/strategy/oidc/provider_github_app.go @@ -8,6 +8,7 @@ import ( "fmt" "net/url" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/x" "github.com/ory/x/httpx" @@ -57,7 +58,7 @@ func (g *ProviderGitHubApp) AuthCodeURLOptions(r ider) []oauth2.AuthCodeOption { return []oauth2.AuthCodeOption{} } -func (g *ProviderGitHubApp) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) { +func (g *ProviderGitHubApp) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) { ctx, client := httpx.SetOAuth2(ctx, g.reg.HTTPClient(ctx), g.oauth2(ctx), exchange) gh := ghapi.NewClient(client.HTTPClient) @@ -66,7 +67,7 @@ func (g *ProviderGitHubApp) Claims(ctx context.Context, exchange *oauth2.Token, return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) } - claims := &Claims{ + claims := &claims.Claims{ Subject: fmt.Sprintf("%d", user.GetID()), Issuer: github.Endpoint.TokenURL, Name: user.GetName(), diff --git a/selfservice/strategy/oidc/provider_gitlab.go b/selfservice/strategy/oidc/provider_gitlab.go index a0cf7508c944..32fd90d0c4c9 100644 --- a/selfservice/strategy/oidc/provider_gitlab.go +++ b/selfservice/strategy/oidc/provider_gitlab.go @@ -9,6 +9,7 @@ import ( "net/url" "path" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/stringsx" "github.com/hashicorp/go-retryablehttp" @@ -71,7 +72,7 @@ func (g *ProviderGitLab) OAuth2(ctx context.Context) (*oauth2.Config, error) { return g.oauth2(ctx) } -func (g *ProviderGitLab) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) { +func (g *ProviderGitLab) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) { o, err := g.OAuth2(ctx) if err != nil { return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) @@ -100,7 +101,7 @@ func (g *ProviderGitLab) Claims(ctx context.Context, exchange *oauth2.Token, que return nil, err } - var claims Claims + var claims claims.Claims if err := json.NewDecoder(resp.Body).Decode(&claims); err != nil { return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) } diff --git a/selfservice/strategy/oidc/provider_google.go b/selfservice/strategy/oidc/provider_google.go index b1f758bd726b..59178d9256b7 100644 --- a/selfservice/strategy/oidc/provider_google.go +++ b/selfservice/strategy/oidc/provider_google.go @@ -9,6 +9,7 @@ import ( gooidc "github.com/coreos/go-oidc/v3/oidc" "golang.org/x/oauth2" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/stringslice" ) @@ -75,7 +76,7 @@ var _ IDTokenVerifier = new(ProviderGoogle) const issuerUrlGoogle = "https://accounts.google.com" -func (p *ProviderGoogle) Verify(ctx context.Context, rawIDToken string) (*Claims, error) { +func (p *ProviderGoogle) Verify(ctx context.Context, rawIDToken string) (*claims.Claims, error) { keySet := gooidc.NewRemoteKeySet(ctx, p.JWKSUrl) ctx = gooidc.ClientContext(ctx, p.reg.HTTPClient(ctx).HTTPClient) @@ -84,7 +85,7 @@ func (p *ProviderGoogle) Verify(ctx context.Context, rawIDToken string) (*Claims var _ NonceValidationSkipper = new(ProviderGoogle) -func (a *ProviderGoogle) CanSkipNonce(c *Claims) bool { +func (a *ProviderGoogle) CanSkipNonce(c *claims.Claims) bool { // Not all SDKs support nonce validation, so we skip it if no nonce is present in the claims of the ID Token. return c.Nonce == "" } diff --git a/selfservice/strategy/oidc/provider_lark.go b/selfservice/strategy/oidc/provider_lark.go index d66d5c0b2230..8ba95f15c692 100644 --- a/selfservice/strategy/oidc/provider_lark.go +++ b/selfservice/strategy/oidc/provider_lark.go @@ -13,6 +13,7 @@ import ( "golang.org/x/oauth2" "github.com/ory/herodot" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/httpx" ) @@ -48,7 +49,6 @@ func (g *ProviderLark) Config() *Configuration { } func (g *ProviderLark) OAuth2(ctx context.Context) (*oauth2.Config, error) { - return &oauth2.Config{ ClientID: g.config.ClientID, ClientSecret: g.config.ClientSecret, @@ -57,10 +57,9 @@ func (g *ProviderLark) OAuth2(ctx context.Context) (*oauth2.Config, error) { Scopes: g.config.Scope, RedirectURL: g.config.Redir(g.reg.Config().OIDCRedirectURIBase(ctx)), }, nil - } -func (g *ProviderLark) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) { +func (g *ProviderLark) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) { // larkClaim is defined in the https://open.feishu.cn/document/common-capabilities/sso/api/get-user-info type larkClaim struct { Sub string `json:"sub"` @@ -103,7 +102,7 @@ func (g *ProviderLark) Claims(ctx context.Context, exchange *oauth2.Token, query return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) } - return &Claims{ + return &claims.Claims{ Issuer: larkUserEndpoint, Subject: user.OpenID, Name: user.Name, diff --git a/selfservice/strategy/oidc/provider_linkedin.go b/selfservice/strategy/oidc/provider_linkedin.go index 475dd738b29f..166808bcea9f 100644 --- a/selfservice/strategy/oidc/provider_linkedin.go +++ b/selfservice/strategy/oidc/provider_linkedin.go @@ -9,6 +9,7 @@ import ( "net/http" "net/url" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/otelx" "github.com/hashicorp/go-retryablehttp" @@ -167,7 +168,7 @@ func (l *ProviderLinkedIn) ProfilePicture(profile *LinkedInProfile) string { return identifiers[0].Identifier } -func (l *ProviderLinkedIn) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (_ *Claims, err error) { +func (l *ProviderLinkedIn) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (_ *claims.Claims, err error) { ctx, span := l.reg.Tracer(ctx).Tracer().Start(ctx, "selfservice.strategy.oidc.ProviderLinkedIn.Claims") defer otelx.End(span, &err) @@ -187,7 +188,7 @@ func (l *ProviderLinkedIn) Claims(ctx context.Context, exchange *oauth2.Token, q return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) } - claims := &Claims{ + claims := &claims.Claims{ Subject: profile.ID, Issuer: "https://login.linkedin.com/", Email: email.Elements[0].Handle.EmailAddress, diff --git a/selfservice/strategy/oidc/provider_linkedin_test.go b/selfservice/strategy/oidc/provider_linkedin_test.go index d5b9df86d25a..cff4f0edb9f7 100644 --- a/selfservice/strategy/oidc/provider_linkedin_test.go +++ b/selfservice/strategy/oidc/provider_linkedin_test.go @@ -19,6 +19,7 @@ import ( "github.com/ory/kratos/internal" "github.com/ory/kratos/selfservice/strategy/oidc" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" ) func TestProviderLinkedin_Claims(t *testing.T) { @@ -122,7 +123,7 @@ func TestProviderLinkedin_Claims(t *testing.T) { ) require.NoError(t, err) - assert.Equal(t, &oidc.Claims{ + assert.Equal(t, &claims.Claims{ Issuer: "https://login.linkedin.com/", Subject: "5foOWOiYXD", GivenName: "John", @@ -198,7 +199,7 @@ func TestProviderLinkedin_No_Picture(t *testing.T) { ) require.NoError(t, err) - assert.Equal(t, &oidc.Claims{ + assert.Equal(t, &claims.Claims{ Issuer: "https://login.linkedin.com/", Subject: "5foOWOiYXD", GivenName: "John", diff --git a/selfservice/strategy/oidc/provider_microsoft.go b/selfservice/strategy/oidc/provider_microsoft.go index 408a11096573..e22864e02c99 100644 --- a/selfservice/strategy/oidc/provider_microsoft.go +++ b/selfservice/strategy/oidc/provider_microsoft.go @@ -17,6 +17,7 @@ import ( "golang.org/x/oauth2" "github.com/ory/herodot" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/httpx" ) @@ -52,7 +53,7 @@ func (m *ProviderMicrosoft) OAuth2(ctx context.Context) (*oauth2.Config, error) return m.oauth2ConfigFromEndpoint(ctx, endpoint), nil } -func (m *ProviderMicrosoft) Claims(ctx context.Context, exchange *oauth2.Token, _ url.Values) (*Claims, error) { +func (m *ProviderMicrosoft) Claims(ctx context.Context, exchange *oauth2.Token, _ url.Values) (*claims.Claims, error) { raw, ok := exchange.Extra("id_token").(string) if !ok || len(raw) == 0 { return nil, errors.WithStack(ErrIDTokenMissing) @@ -83,7 +84,7 @@ func (m *ProviderMicrosoft) Claims(ctx context.Context, exchange *oauth2.Token, return m.updateSubject(ctx, claims, exchange) } -func (m *ProviderMicrosoft) updateSubject(ctx context.Context, claims *Claims, exchange *oauth2.Token) (*Claims, error) { +func (m *ProviderMicrosoft) updateSubject(ctx context.Context, claims *claims.Claims, exchange *oauth2.Token) (*claims.Claims, error) { if m.config.SubjectSource == "me" { o, err := m.OAuth2(ctx) if err != nil { diff --git a/selfservice/strategy/oidc/provider_netid.go b/selfservice/strategy/oidc/provider_netid.go index 9e4a79aba581..390893ab27fc 100644 --- a/selfservice/strategy/oidc/provider_netid.go +++ b/selfservice/strategy/oidc/provider_netid.go @@ -17,6 +17,7 @@ import ( "golang.org/x/oauth2" "github.com/ory/herodot" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/httpx" "github.com/ory/x/urlx" ) @@ -71,7 +72,7 @@ func (n *ProviderNetID) oAuth2(ctx context.Context) (*oauth2.Config, error) { }, nil } -func (n *ProviderNetID) Claims(ctx context.Context, exchange *oauth2.Token, _ url.Values) (*Claims, error) { +func (n *ProviderNetID) Claims(ctx context.Context, exchange *oauth2.Token, _ url.Values) (*claims.Claims, error) { o, err := n.OAuth2(ctx) if err != nil { return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) @@ -103,21 +104,21 @@ func (n *ProviderNetID) Claims(ctx context.Context, exchange *oauth2.Token, _ ur return nil, errors.WithStack(ErrIDTokenMissing) } - claims, err := n.verifyAndDecodeClaimsWithProvider(ctx, p, raw) + dec, err := n.verifyAndDecodeClaimsWithProvider(ctx, p, raw) if err != nil { return nil, err } - var userinfo Claims + var userinfo claims.Claims if err := json.NewDecoder(resp.Body).Decode(&userinfo); err != nil { return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) } - userinfo.Issuer = claims.Issuer - userinfo.Subject = claims.Subject + userinfo.Issuer = dec.Issuer + userinfo.Subject = dec.Subject return &userinfo, nil } -func (n *ProviderNetID) Verify(ctx context.Context, rawIDToken string) (*Claims, error) { +func (n *ProviderNetID) Verify(ctx context.Context, rawIDToken string) (*claims.Claims, error) { provider, err := n.provider(ctx) if err != nil { return nil, err @@ -154,7 +155,7 @@ func (n *ProviderNetID) Verify(ctx context.Context, rawIDToken string) (*Claims, } var ( - claims Claims + claims claims.Claims rawClaims map[string]any ) diff --git a/selfservice/strategy/oidc/provider_patreon.go b/selfservice/strategy/oidc/provider_patreon.go index d89e1e2a3ebc..b73eee901297 100644 --- a/selfservice/strategy/oidc/provider_patreon.go +++ b/selfservice/strategy/oidc/provider_patreon.go @@ -10,6 +10,7 @@ import ( "github.com/hashicorp/go-retryablehttp" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/httpx" "github.com/pkg/errors" @@ -81,7 +82,7 @@ func (d *ProviderPatreon) AuthCodeURLOptions(r ider) []oauth2.AuthCodeOption { } } -func (d *ProviderPatreon) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) { +func (d *ProviderPatreon) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) { identityUrl := "https://www.patreon.com/api/oauth2/v2/identity?fields%5Buser%5D=first_name,last_name,url,full_name,email,image_url" o := d.oauth2(ctx) @@ -109,7 +110,7 @@ func (d *ProviderPatreon) Claims(ctx context.Context, exchange *oauth2.Token, qu return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", jsonErr)) } - claims := &Claims{ + claims := &claims.Claims{ Issuer: "https://www.patreon.com/", Subject: data.Data.Id, Name: data.Data.Attributes.FullName, diff --git a/selfservice/strategy/oidc/provider_salesforce.go b/selfservice/strategy/oidc/provider_salesforce.go index 04d514ccdf22..09c9df965a24 100644 --- a/selfservice/strategy/oidc/provider_salesforce.go +++ b/selfservice/strategy/oidc/provider_salesforce.go @@ -11,6 +11,7 @@ import ( "path" "time" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/httpx" "github.com/ory/x/stringsx" @@ -73,7 +74,7 @@ func (g *ProviderSalesforce) OAuth2(ctx context.Context) (*oauth2.Config, error) return g.oauth2(ctx) } -func (g *ProviderSalesforce) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) { +func (g *ProviderSalesforce) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) { o, err := g.OAuth2(ctx) if err != nil { return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) @@ -115,7 +116,7 @@ func (g *ProviderSalesforce) Claims(ctx context.Context, exchange *oauth2.Token, } // Once we get here, we know that if there is an updated_at field in the json, it is the correct type. - var claims Claims + var claims claims.Claims if err := json.Unmarshal(b, &claims); err != nil { return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) } diff --git a/selfservice/strategy/oidc/provider_slack.go b/selfservice/strategy/oidc/provider_slack.go index 0faed2220ae5..c4234ff3d297 100644 --- a/selfservice/strategy/oidc/provider_slack.go +++ b/selfservice/strategy/oidc/provider_slack.go @@ -9,6 +9,7 @@ import ( "net/url" "github.com/ory/herodot" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/pkg/errors" "golang.org/x/oauth2" @@ -63,7 +64,7 @@ func (d *ProviderSlack) AuthCodeURLOptions(r ider) []oauth2.AuthCodeOption { return []oauth2.AuthCodeOption{} } -func (d *ProviderSlack) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) { +func (d *ProviderSlack) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) { grantedScopes := stringsx.Splitx(fmt.Sprintf("%s", exchange.Extra("scope")), ",") for _, check := range d.Config().Scope { if !stringslice.Has(grantedScopes, check) { @@ -77,7 +78,7 @@ func (d *ProviderSlack) Claims(ctx context.Context, exchange *oauth2.Token, quer return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) } - claims := &Claims{ + claims := &claims.Claims{ Issuer: "https://slack.com/oauth/", Subject: identity.User.ID, Name: identity.User.Name, diff --git a/selfservice/strategy/oidc/provider_spotify.go b/selfservice/strategy/oidc/provider_spotify.go index 2c01d0764b3c..f4618e3d8bac 100644 --- a/selfservice/strategy/oidc/provider_spotify.go +++ b/selfservice/strategy/oidc/provider_spotify.go @@ -13,6 +13,7 @@ import ( "github.com/pkg/errors" "golang.org/x/oauth2" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/httpx" "github.com/ory/x/stringslice" "github.com/ory/x/stringsx" @@ -62,7 +63,7 @@ func (g *ProviderSpotify) AuthCodeURLOptions(r ider) []oauth2.AuthCodeOption { return []oauth2.AuthCodeOption{} } -func (g *ProviderSpotify) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) { +func (g *ProviderSpotify) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) { grantedScopes := stringsx.Splitx(fmt.Sprintf("%s", exchange.Extra("scope")), " ") for _, check := range g.Config().Scope { if !stringslice.Has(grantedScopes, check) { @@ -87,7 +88,7 @@ func (g *ProviderSpotify) Claims(ctx context.Context, exchange *oauth2.Token, qu userPicture = user.Images[0].URL } - claims := &Claims{ + claims := &claims.Claims{ Subject: user.ID, Issuer: spotify.Endpoint.TokenURL, Name: user.DisplayName, diff --git a/selfservice/strategy/oidc/provider_test.go b/selfservice/strategy/oidc/provider_test.go index 208421ad2ab0..ba18644034d4 100644 --- a/selfservice/strategy/oidc/provider_test.go +++ b/selfservice/strategy/oidc/provider_test.go @@ -11,16 +11,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" -) -func TestClaimsValidate(t *testing.T) { - require.Error(t, new(Claims).Validate()) - require.Error(t, (&Claims{Issuer: "not-empty"}).Validate()) - require.Error(t, (&Claims{Issuer: "not-empty"}).Validate()) - require.Error(t, (&Claims{Subject: "not-empty"}).Validate()) - require.Error(t, (&Claims{Subject: "not-empty"}).Validate()) - require.NoError(t, (&Claims{Issuer: "not-empty", Subject: "not-empty"}).Validate()) -} + "github.com/ory/kratos/selfservice/strategy/oidc/claims" +) type TestProvider struct { *ProviderGenericOIDC @@ -43,11 +36,11 @@ func RegisterTestProvider(t *testing.T, id string) { var _ IDTokenVerifier = new(TestProvider) -func (t *TestProvider) Verify(_ context.Context, token string) (*Claims, error) { +func (t *TestProvider) Verify(_ context.Context, token string) (*claims.Claims, error) { if token == "error" { return nil, fmt.Errorf("stub error") } - c := Claims{} + c := claims.Claims{} if err := json.Unmarshal([]byte(token), &c); err != nil { return nil, err } @@ -95,7 +88,7 @@ func TestLocale(t *testing.T) { expected: "", }} { t.Run(tc.name, func(t *testing.T) { - var c Claims + var c claims.Claims err := json.Unmarshal([]byte(tc.json), &c) if tc.assertErr != nil { tc.assertErr(t, err) diff --git a/selfservice/strategy/oidc/provider_test_fedcm.go b/selfservice/strategy/oidc/provider_test_fedcm.go index 5ea002faa74b..aef7912fc6ad 100644 --- a/selfservice/strategy/oidc/provider_test_fedcm.go +++ b/selfservice/strategy/oidc/provider_test_fedcm.go @@ -7,6 +7,8 @@ import ( "context" "github.com/golang-jwt/jwt/v5" + + "github.com/ory/kratos/selfservice/strategy/oidc/claims" ) // ProviderTestFedcm is a mock provider to test FedCM. @@ -28,9 +30,9 @@ func NewProviderTestFedcm( } } -func (g *ProviderTestFedcm) Verify(_ context.Context, rawIDToken string) (claims *Claims, err error) { +func (g *ProviderTestFedcm) Verify(_ context.Context, rawIDToken string) (c *claims.Claims, err error) { rawClaims := &struct { - Claims + claims.Claims jwt.MapClaims }{} _, err = jwt.ParseWithClaims(rawIDToken, rawClaims, func(token *jwt.Token) (interface{}, error) { diff --git a/selfservice/strategy/oidc/provider_userinfo_test.go b/selfservice/strategy/oidc/provider_userinfo_test.go index 9eb27914541e..4454fea5cc01 100644 --- a/selfservice/strategy/oidc/provider_userinfo_test.go +++ b/selfservice/strategy/oidc/provider_userinfo_test.go @@ -20,6 +20,7 @@ import ( "github.com/ory/kratos/driver/config" "github.com/ory/kratos/internal" "github.com/ory/kratos/selfservice/strategy/oidc" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/httpx" "github.com/ory/x/otelx" @@ -45,7 +46,7 @@ func TestProviderClaimsRespectsErrorCodes(t *testing.T) { ctx := context.Background() token := &oauth2.Token{AccessToken: "foo", Expiry: time.Now().Add(time.Hour)} - expectedClaims := &oidc.Claims{ + expectedClaims := &claims.Claims{ Issuer: "ignore-me", Subject: "123456789012345", Name: "John Doe", @@ -75,7 +76,7 @@ func TestProviderClaimsRespectsErrorCodes(t *testing.T) { config *oidc.Configuration provider oidc.Provider userInfoHandler func(req *http.Request) (*http.Response, error) - expectedClaims *oidc.Claims + expectedClaims *claims.Claims useToken *oauth2.Token hook func(t *testing.T) }{ @@ -135,7 +136,7 @@ func TestProviderClaimsRespectsErrorCodes(t *testing.T) { }, ) }, - expectedClaims: &oidc.Claims{Issuer: "https://broker.netid.de/", Subject: "1234567890", Name: "John Doe", GivenName: "John", FamilyName: "Doe", LastName: "", MiddleName: "", Nickname: "John Doe", PreferredUsername: "John Doe", Profile: "", Picture: "", Website: "", Email: "john.doe@example.com", EmailVerified: true, Gender: "", Birthdate: "01/01/1990", Zoneinfo: "", Locale: "", PhoneNumber: "", PhoneNumberVerified: false, UpdatedAt: 0, HD: "", Team: ""}, + expectedClaims: &claims.Claims{Issuer: "https://broker.netid.de/", Subject: "1234567890", Name: "John Doe", GivenName: "John", FamilyName: "Doe", LastName: "", MiddleName: "", Nickname: "John Doe", PreferredUsername: "John Doe", Profile: "", Picture: "", Website: "", Email: "john.doe@example.com", EmailVerified: true, Gender: "", Birthdate: "01/01/1990", Zoneinfo: "", Locale: "", PhoneNumber: "", PhoneNumberVerified: false, UpdatedAt: 0, HD: "", Team: ""}, }, { name: "vk", @@ -158,7 +159,7 @@ func TestProviderClaimsRespectsErrorCodes(t *testing.T) { return resp, err }, - expectedClaims: &oidc.Claims{ + expectedClaims: &claims.Claims{ Issuer: "https://api.vk.com/method/users.get", Subject: "123456789012345", Email: "john.doe@example.com", @@ -186,7 +187,7 @@ func TestProviderClaimsRespectsErrorCodes(t *testing.T) { return resp, err }, - expectedClaims: &oidc.Claims{ + expectedClaims: &claims.Claims{ Issuer: "https://login.yandex.ru/info", Subject: "123456789012345", Email: "john.doe@example.com", @@ -231,7 +232,7 @@ func TestProviderClaimsRespectsErrorCodes(t *testing.T) { }) return resp, err }, - expectedClaims: &oidc.Claims{ + expectedClaims: &claims.Claims{ Issuer: "https://graph.facebook.com/me?fields=id,name,first_name,last_name,middle_name,email,picture,birthday,gender&appsecret_proof=0c0d98f7e3d9d45e72e8877bc1b104327efb9c07b18f2ffeced76d81307f1fff", Subject: "123456789012345", Name: "John Doe", @@ -302,7 +303,7 @@ func TestProviderClaimsRespectsErrorCodes(t *testing.T) { }, ) }, - expectedClaims: &oidc.Claims{ + expectedClaims: &claims.Claims{ Issuer: "https://login.microsoftonline.com/a9b86385-f32c-4803-afc8-4b2312fbdf24/v2.0", Subject: "new-id", Name: "John Doe", Email: "john.doe@example.com", RawClaims: map[string]interface{}{"aud": []interface{}{"foo"}, "exp": 4.071728504e+09, "iat": 1.516239022e+09, "iss": "https://login.microsoftonline.com/a9b86385-f32c-4803-afc8-4b2312fbdf24/v2.0", "email": "john.doe@example.com", "name": "John Doe", "sub": "1234567890", "tid": "a9b86385-f32c-4803-afc8-4b2312fbdf24"}, }, @@ -327,7 +328,7 @@ func TestProviderClaimsRespectsErrorCodes(t *testing.T) { ID: "dingtalk", Provider: "dingtalk", }, reg), - expectedClaims: &oidc.Claims{ + expectedClaims: &claims.Claims{ Issuer: "https://api.dingtalk.com/v1.0/contact/users/me", Subject: "123456789012345", Email: "john.doe@example.com", diff --git a/selfservice/strategy/oidc/provider_vk.go b/selfservice/strategy/oidc/provider_vk.go index c60711504fd3..4e838d5e58a5 100644 --- a/selfservice/strategy/oidc/provider_vk.go +++ b/selfservice/strategy/oidc/provider_vk.go @@ -11,6 +11,7 @@ import ( "github.com/hashicorp/go-retryablehttp" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/httpx" "github.com/pkg/errors" @@ -61,7 +62,7 @@ func (g *ProviderVK) OAuth2(ctx context.Context) (*oauth2.Config, error) { return g.oauth2(ctx), nil } -func (g *ProviderVK) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) { +func (g *ProviderVK) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) { o, err := g.OAuth2(ctx) if err != nil { return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) @@ -120,7 +121,7 @@ func (g *ProviderVK) Claims(ctx context.Context, exchange *oauth2.Token, query u gender = "male" } - return &Claims{ + return &claims.Claims{ Issuer: "https://api.vk.com/method/users.get", Subject: strconv.Itoa(user.Id), GivenName: user.FirstName, diff --git a/selfservice/strategy/oidc/provider_x.go b/selfservice/strategy/oidc/provider_x.go index ca2acb6c5e25..cd6012965164 100644 --- a/selfservice/strategy/oidc/provider_x.go +++ b/selfservice/strategy/oidc/provider_x.go @@ -9,6 +9,7 @@ import ( "fmt" "net/http" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/otelx" "github.com/dghubble/oauth1" @@ -20,8 +21,10 @@ import ( var _ OAuth1Provider = (*ProviderX)(nil) -const xUserInfoBase = "https://api.twitter.com/1.1/account/verify_credentials.json" -const xUserInfoWithEmail = xUserInfoBase + "?include_email=true" +const ( + xUserInfoBase = "https://api.twitter.com/1.1/account/verify_credentials.json" + xUserInfoWithEmail = xUserInfoBase + "?include_email=true" +) type ProviderX struct { config *Configuration @@ -34,7 +37,8 @@ func (p *ProviderX) Config() *Configuration { func NewProviderX( config *Configuration, - reg Dependencies) Provider { + reg Dependencies, +) Provider { return &ProviderX{ config: config, reg: reg, @@ -106,7 +110,7 @@ func (p *ProviderX) userInfoEndpoint() string { return xUserInfoBase } -func (p *ProviderX) Claims(ctx context.Context, token *oauth1.Token) (*Claims, error) { +func (p *ProviderX) Claims(ctx context.Context, token *oauth1.Token) (*claims.Claims, error) { ctx = context.WithValue(ctx, oauth1.HTTPClient, p.reg.HTTPClient(ctx).HTTPClient) c := p.OAuth1(ctx) @@ -133,7 +137,7 @@ func (p *ProviderX) Claims(ctx context.Context, token *oauth1.Token) (*Claims, e website = *user.URL } - return &Claims{ + return &claims.Claims{ Issuer: endpoint, Subject: user.IDStr, Name: user.Name, diff --git a/selfservice/strategy/oidc/provider_yandex.go b/selfservice/strategy/oidc/provider_yandex.go index 9b11b8fbcf5e..9f99aba81762 100644 --- a/selfservice/strategy/oidc/provider_yandex.go +++ b/selfservice/strategy/oidc/provider_yandex.go @@ -12,6 +12,7 @@ import ( "github.com/pkg/errors" "golang.org/x/oauth2" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/x/httpx" "github.com/ory/herodot" @@ -59,7 +60,7 @@ func (g *ProviderYandex) OAuth2(ctx context.Context) (*oauth2.Config, error) { return g.oauth2(ctx), nil } -func (g *ProviderYandex) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) { +func (g *ProviderYandex) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*claims.Claims, error) { o, err := g.OAuth2(ctx) if err != nil { return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err)) @@ -102,7 +103,7 @@ func (g *ProviderYandex) Claims(ctx context.Context, exchange *oauth2.Token, que user.Picture = "" } - return &Claims{ + return &claims.Claims{ Issuer: "https://login.yandex.ru/info", Subject: user.Id, GivenName: user.FirstName, diff --git a/selfservice/strategy/oidc/strategy.go b/selfservice/strategy/oidc/strategy.go index 62867b886a7a..ac70ebb0b3e6 100644 --- a/selfservice/strategy/oidc/strategy.go +++ b/selfservice/strategy/oidc/strategy.go @@ -41,6 +41,7 @@ import ( "github.com/ory/kratos/selfservice/flow/settings" "github.com/ory/kratos/selfservice/sessiontokenexchange" "github.com/ory/kratos/selfservice/strategy" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/session" "github.com/ory/kratos/text" "github.com/ory/kratos/ui/container" @@ -148,7 +149,7 @@ type Strategy struct { conflictingIdentityPolicy ConflictingIdentityPolicy } -type ConflictingIdentityPolicy func(ctx context.Context, existingIdentity, newIdentity *identity.Identity, provider Provider, claims *Claims) ConflictingIdentityVerdict +type ConflictingIdentityPolicy func(ctx context.Context, existingIdentity, newIdentity *identity.Identity, provider Provider, claims *claims.Claims) ConflictingIdentityVerdict type AuthCodeContainer struct { FlowID string `json:"flow_id"` @@ -475,7 +476,7 @@ func (s *Strategy) HandleCallback(w http.ResponseWriter, r *http.Request, ps htt return } - var claims *Claims + var claims *claims.Claims var et *identity.CredentialsOIDCEncryptedTokens switch p := provider.(type) { case OAuth2Provider: @@ -801,7 +802,7 @@ func (s *Strategy) CompletedAuthenticationMethod(context.Context) session.Authen } } -func (s *Strategy) ProcessIDToken(r *http.Request, provider Provider, idToken, idTokenNonce string) (*Claims, error) { +func (s *Strategy) ProcessIDToken(r *http.Request, provider Provider, idToken, idTokenNonce string) (*claims.Claims, error) { verifier, ok := provider.(IDTokenVerifier) if !ok { return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("The provider %s does not support id_token verification", provider.Config().Provider)) diff --git a/selfservice/strategy/oidc/strategy_helper_test.go b/selfservice/strategy/oidc/strategy_helper_test.go index 2c25e7435ff0..c354fb653224 100644 --- a/selfservice/strategy/oidc/strategy_helper_test.go +++ b/selfservice/strategy/oidc/strategy_helper_test.go @@ -389,7 +389,7 @@ var publicJWKS []byte //go:embed stub/jwks_public2.json var publicJWKS2 []byte -type claims struct { +type jwtClaims struct { *jwt.RegisteredClaims Email string `json:"email"` } @@ -397,7 +397,7 @@ type claims struct { func createIdToken(t *testing.T, cl jwt.RegisteredClaims) string { key := &jwk.KeySpec{} require.NoError(t, json.Unmarshal(rawKey, key)) - token := jwt.NewWithClaims(jwt.SigningMethodRS256, &claims{ + token := jwt.NewWithClaims(jwt.SigningMethodRS256, &jwtClaims{ RegisteredClaims: &cl, Email: "acme@ory.sh", }) diff --git a/selfservice/strategy/oidc/strategy_login.go b/selfservice/strategy/oidc/strategy_login.go index 23ec8a2e2e13..88d1f6b1b7cd 100644 --- a/selfservice/strategy/oidc/strategy_login.go +++ b/selfservice/strategy/oidc/strategy_login.go @@ -22,6 +22,7 @@ import ( "github.com/ory/kratos/selfservice/flow/registration" "github.com/ory/kratos/selfservice/flowhelpers" "github.com/ory/kratos/selfservice/strategy/idfirst" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/session" "github.com/ory/kratos/text" "github.com/ory/kratos/ui/node" @@ -99,7 +100,7 @@ type UpdateLoginFlowWithOidcMethod struct { TransientPayload json.RawMessage `json:"transient_payload,omitempty" form:"transient_payload"` } -func (s *Strategy) handleConflictingIdentity(ctx context.Context, w http.ResponseWriter, r *http.Request, loginFlow *login.Flow, token *identity.CredentialsOIDCEncryptedTokens, claims *Claims, provider Provider, container *AuthCodeContainer) (verdict ConflictingIdentityVerdict, id *identity.Identity, credentials *identity.Credentials, err error) { +func (s *Strategy) handleConflictingIdentity(ctx context.Context, w http.ResponseWriter, r *http.Request, loginFlow *login.Flow, token *identity.CredentialsOIDCEncryptedTokens, claims *claims.Claims, provider Provider, container *AuthCodeContainer) (verdict ConflictingIdentityVerdict, id *identity.Identity, credentials *identity.Credentials, err error) { if s.conflictingIdentityPolicy == nil { return ConflictingIdentityVerdictReject, nil, nil, nil } @@ -159,7 +160,7 @@ func (s *Strategy) handleConflictingIdentity(ctx context.Context, w http.Respons return verdict, existingIdentity, creds, nil } -func (s *Strategy) ProcessLogin(ctx context.Context, w http.ResponseWriter, r *http.Request, loginFlow *login.Flow, token *identity.CredentialsOIDCEncryptedTokens, claims *Claims, provider Provider, container *AuthCodeContainer) (_ *registration.Flow, err error) { +func (s *Strategy) ProcessLogin(ctx context.Context, w http.ResponseWriter, r *http.Request, loginFlow *login.Flow, token *identity.CredentialsOIDCEncryptedTokens, claims *claims.Claims, provider Provider, container *AuthCodeContainer) (_ *registration.Flow, err error) { ctx, span := s.d.Tracer(ctx).Tracer().Start(ctx, "selfservice.strategy.oidc.Strategy.processLogin") defer otelx.End(span, &err) @@ -247,7 +248,7 @@ func (s *Strategy) ProcessLogin(ctx context.Context, w http.ResponseWriter, r *h for _, c := range oidcCredentials.Providers { if c.Subject == claims.Subject && c.Provider == provider.Config().ID { - if err = s.d.LoginHookExecutor().PostLoginHook(w, r, node.OpenIDConnectGroup, loginFlow, i, sess, provider.Config().ID); err != nil { + if err = s.d.LoginHookExecutor().PostLoginHook(w, r, node.OpenIDConnectGroup, loginFlow, i, sess, provider.Config().ID, login.WithClaims(claims)); err != nil { return nil, s.HandleError(ctx, w, r, loginFlow, provider.Config().ID, nil, err) } return nil, nil diff --git a/selfservice/strategy/oidc/strategy_registration.go b/selfservice/strategy/oidc/strategy_registration.go index f6ba6d05af2f..67f57e8e51d9 100644 --- a/selfservice/strategy/oidc/strategy_registration.go +++ b/selfservice/strategy/oidc/strategy_registration.go @@ -25,6 +25,7 @@ import ( "github.com/ory/kratos/selfservice/flow" "github.com/ory/kratos/selfservice/flow/login" "github.com/ory/kratos/selfservice/flow/registration" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/text" "github.com/ory/kratos/x" "github.com/ory/kratos/x/events" @@ -34,8 +35,10 @@ import ( "github.com/ory/x/sqlxx" ) -var _ registration.Strategy = new(Strategy) -var _ registration.FormHydrator = new(Strategy) +var ( + _ registration.Strategy = new(Strategy) + _ registration.FormHydrator = new(Strategy) +) var jsonnetCache, _ = ristretto.NewCache(&ristretto.Config[[]byte, []byte]{ MaxCost: 100 << 20, // 100MB, @@ -295,7 +298,7 @@ func (s *Strategy) registrationToLogin(ctx context.Context, w http.ResponseWrite return lf, nil } -func (s *Strategy) processRegistration(ctx context.Context, w http.ResponseWriter, r *http.Request, rf *registration.Flow, token *identity.CredentialsOIDCEncryptedTokens, claims *Claims, provider Provider, container *AuthCodeContainer) (_ *login.Flow, err error) { +func (s *Strategy) processRegistration(ctx context.Context, w http.ResponseWriter, r *http.Request, rf *registration.Flow, token *identity.CredentialsOIDCEncryptedTokens, claims *claims.Claims, provider Provider, container *AuthCodeContainer) (_ *login.Flow, err error) { ctx, span := s.d.Tracer(ctx).Tracer().Start(ctx, "selfservice.strategy.oidc.Strategy.processRegistration") defer otelx.End(span, &err) @@ -360,7 +363,7 @@ func (s *Strategy) processRegistration(ctx context.Context, w http.ResponseWrite return nil, nil } -func (s *Strategy) newIdentityFromClaims(ctx context.Context, claims *Claims, provider Provider, container *AuthCodeContainer) (_ *identity.Identity, _ []VerifiedAddress, err error) { +func (s *Strategy) newIdentityFromClaims(ctx context.Context, claims *claims.Claims, provider Provider, container *AuthCodeContainer) (_ *identity.Identity, _ []VerifiedAddress, err error) { fetch := fetcher.NewFetcher(fetcher.WithClient(s.d.HTTPClient(ctx)), fetcher.WithCache(jsonnetCache, 60*time.Minute)) jsonnetSnippet, err := fetch.FetchContext(ctx, provider.Config().Mapper) if err != nil { diff --git a/selfservice/strategy/oidc/strategy_settings.go b/selfservice/strategy/oidc/strategy_settings.go index 66dd8fb9876c..412e7492cf98 100644 --- a/selfservice/strategy/oidc/strategy_settings.go +++ b/selfservice/strategy/oidc/strategy_settings.go @@ -29,6 +29,7 @@ import ( "github.com/ory/kratos/selfservice/flow" "github.com/ory/kratos/selfservice/flow/settings" "github.com/ory/kratos/selfservice/strategy" + "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/session" "github.com/ory/kratos/x" ) @@ -402,7 +403,7 @@ func (s *Strategy) initLinkProvider(ctx context.Context, w http.ResponseWriter, return errors.WithStack(flow.ErrCompletedByStrategy) } -func (s *Strategy) linkProvider(ctx context.Context, w http.ResponseWriter, r *http.Request, ctxUpdate *settings.UpdateContext, token *identity.CredentialsOIDCEncryptedTokens, claims *Claims, provider Provider) error { +func (s *Strategy) linkProvider(ctx context.Context, w http.ResponseWriter, r *http.Request, ctxUpdate *settings.UpdateContext, token *identity.CredentialsOIDCEncryptedTokens, claims *claims.Claims, provider Provider) error { p := &updateSettingsFlowWithOidcMethod{ Link: provider.Config().ID, FlowID: ctxUpdate.Flow.ID.String(), } diff --git a/selfservice/strategy/oidc/strategy_test.go b/selfservice/strategy/oidc/strategy_test.go index 7672dda900d1..d02271810fe2 100644 --- a/selfservice/strategy/oidc/strategy_test.go +++ b/selfservice/strategy/oidc/strategy_test.go @@ -58,6 +58,7 @@ import ( "github.com/ory/kratos/selfservice/flow/registration" "github.com/ory/kratos/selfservice/strategy/oidc" + oidcclaims "github.com/ory/kratos/selfservice/strategy/oidc/claims" "github.com/ory/kratos/x" ) @@ -1690,13 +1691,12 @@ func TestStrategy(t *testing.T) { }) t.Run("suite=auto link policy", func(t *testing.T) { - t.Run("case=should automatically link credential if policy says so", func(t *testing.T) { subject = "user-in-org@ory.sh" scope = []string{"openid"} reg.AllLoginStrategies().MustStrategy("oidc").(*oidc.Strategy).SetOnConflictingIdentity(t, - func(ctx context.Context, existingIdentity, newIdentity *identity.Identity, _ oidc.Provider, _ *oidc.Claims) oidc.ConflictingIdentityVerdict { + func(ctx context.Context, existingIdentity, newIdentity *identity.Identity, _ oidc.Provider, _ *oidcclaims.Claims) oidc.ConflictingIdentityVerdict { return oidc.ConflictingIdentityVerdictMerge }) @@ -1730,7 +1730,7 @@ func TestStrategy(t *testing.T) { scope = []string{"openid"} reg.AllLoginStrategies().MustStrategy("oidc").(*oidc.Strategy).SetOnConflictingIdentity(t, - func(ctx context.Context, existingIdentity, newIdentity *identity.Identity, _ oidc.Provider, _ *oidc.Claims) oidc.ConflictingIdentityVerdict { + func(ctx context.Context, existingIdentity, newIdentity *identity.Identity, _ oidc.Provider, _ *oidcclaims.Claims) oidc.ConflictingIdentityVerdict { return oidc.ConflictingIdentityVerdictMerge }) diff --git a/selfservice/strategy/oidc/token_verifier.go b/selfservice/strategy/oidc/token_verifier.go index 42b16767a041..7d79f60104bb 100644 --- a/selfservice/strategy/oidc/token_verifier.go +++ b/selfservice/strategy/oidc/token_verifier.go @@ -9,9 +9,11 @@ import ( "strings" "github.com/coreos/go-oidc/v3/oidc" + + "github.com/ory/kratos/selfservice/strategy/oidc/claims" ) -func verifyToken(ctx context.Context, keySet oidc.KeySet, config *Configuration, rawIDToken, issuerURL string) (*Claims, error) { +func verifyToken(ctx context.Context, keySet oidc.KeySet, config *Configuration, rawIDToken, issuerURL string) (*claims.Claims, error) { tokenAudiences := append([]string{config.ClientID}, config.AdditionalIDTokenAudiences...) var token *oidc.IDToken err := fmt.Errorf("no audience matched the token's audience") @@ -34,7 +36,7 @@ func verifyToken(ctx context.Context, keySet oidc.KeySet, config *Configuration, // None of the allowed audiences matched the audience in the token return nil, fmt.Errorf("token audience didn't match allowed audiences: %+v %w", tokenAudiences, err) } - claims := &Claims{} + claims := &claims.Claims{} var rawClaims map[string]any if token == nil { From c6b85fc3fdbc621def85f40f6c4883bcd9fda84e Mon Sep 17 00:00:00 2001 From: Tom Fenech Date: Tue, 14 May 2024 16:57:10 +0200 Subject: [PATCH 2/2] chore: add license comment to new files --- selfservice/strategy/oidc/claims/claims.go | 3 +++ selfservice/strategy/oidc/claims/claims_test.go | 3 +++ selfservice/strategy/oidc/claims/locale.go | 3 +++ 3 files changed, 9 insertions(+) diff --git a/selfservice/strategy/oidc/claims/claims.go b/selfservice/strategy/oidc/claims/claims.go index 924b8eee6738..e042be95272c 100644 --- a/selfservice/strategy/oidc/claims/claims.go +++ b/selfservice/strategy/oidc/claims/claims.go @@ -1,3 +1,6 @@ +// Copyright © 2024 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + package claims import ( diff --git a/selfservice/strategy/oidc/claims/claims_test.go b/selfservice/strategy/oidc/claims/claims_test.go index 47ada4a4695c..8c532eab569c 100644 --- a/selfservice/strategy/oidc/claims/claims_test.go +++ b/selfservice/strategy/oidc/claims/claims_test.go @@ -1,3 +1,6 @@ +// Copyright © 2024 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + package claims_test import ( diff --git a/selfservice/strategy/oidc/claims/locale.go b/selfservice/strategy/oidc/claims/locale.go index 07fb8576becc..226737505170 100644 --- a/selfservice/strategy/oidc/claims/locale.go +++ b/selfservice/strategy/oidc/claims/locale.go @@ -1,3 +1,6 @@ +// Copyright © 2024 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + package claims import (