From 2df726dd37c1d6be91538e308f8b7a31b2a46142 Mon Sep 17 00:00:00 2001 From: eslizn Date: Sat, 23 Nov 2019 19:57:15 +0800 Subject: [PATCH 1/2] extend send confirm handler --- providers/password/confirm.go | 31 ++++++++++++++++++++++++++++++ providers/password/password.go | 35 +++++++++------------------------- 2 files changed, 40 insertions(+), 26 deletions(-) diff --git a/providers/password/confirm.go b/providers/password/confirm.go index 72ff442..1444456 100644 --- a/providers/password/confirm.go +++ b/providers/password/confirm.go @@ -6,6 +6,7 @@ import ( "net/mail" "path" "reflect" + "strings" "time" "github.com/qor/auth" @@ -61,6 +62,36 @@ var DefaultConfirmationMailer = func(email string, context *auth.Context, claims })) } +// DefaultConfirmHandler default send confirm handler +var DefaultSendConfirmHandler = func(context *auth.Context) error { + var ( + currentUser interface{} + authInfo auth_identity.Basic + provider, _ = context.Provider.(*Provider) + req = context.Request + tx = context.Auth.GetDB(req) + err error + ) + req.ParseForm() + authInfo.Provider = provider.GetName() + authInfo.UID = strings.TrimSpace(req.Form.Get("email")) + if tx.Model(context.Auth.AuthIdentityModel).Where(authInfo).Scan(&authInfo).RecordNotFound() { + return auth.ErrInvalidAccount + } + + if currentUser, err = context.Auth.UserStorer.Get(authInfo.ToClaims(), context); err != nil { + return err + } + + if err = provider.Config.ConfirmMailer(authInfo.UID, context, authInfo.ToClaims(), currentUser); err != nil { + return err + } + + context.SessionStorer.Flash(context.Writer, req, session.Message{Message: ConfirmFlashMessage, Type: "success"}) + context.Auth.Redirector.Redirect(context.Writer, context.Request, "send_confirmation") + return nil +} + // DefaultConfirmHandler default confirm handler var DefaultConfirmHandler = func(context *auth.Context) error { var ( diff --git a/providers/password/password.go b/providers/password/password.go index d47b3d9..09b418a 100644 --- a/providers/password/password.go +++ b/providers/password/password.go @@ -6,7 +6,6 @@ import ( "strings" "github.com/qor/auth" - "github.com/qor/auth/auth_identity" "github.com/qor/auth/claims" "github.com/qor/auth/providers/password/encryptor" "github.com/qor/auth/providers/password/encryptor/bcrypt_encryptor" @@ -15,9 +14,10 @@ import ( // Config password config type Config struct { - Confirmable bool - ConfirmMailer func(email string, context *auth.Context, claims *claims.Claims, currentUser interface{}) error - ConfirmHandler func(*auth.Context) error + Confirmable bool + ConfirmMailer func(email string, context *auth.Context, claims *claims.Claims, currentUser interface{}) error + SendConfirmHandler func(*auth.Context) error + ConfirmHandler func(*auth.Context) error ResetPasswordMailer func(email string, context *auth.Context, claims *claims.Claims, currentUser interface{}) error ResetPasswordHandler func(*auth.Context) error @@ -44,6 +44,10 @@ func New(config *Config) *Provider { config.ConfirmMailer = DefaultConfirmationMailer } + if config.SendConfirmHandler == nil { + config.SendConfirmHandler = DefaultSendConfirmHandler + } + if config.ConfirmHandler == nil { config.ConfirmHandler = DefaultConfirmHandler } @@ -128,28 +132,7 @@ func (provider Provider) ServeHTTP(context *auth.Context) { // render new confirmation page context.Auth.Config.Render.Execute("auth/confirmation/new", context, context.Request, context.Writer) case "send": - var ( - currentUser interface{} - authInfo auth_identity.Basic - tx = context.Auth.GetDB(req) - ) - - authInfo.Provider = provider.GetName() - authInfo.UID = strings.TrimSpace(req.Form.Get("email")) - if tx.Model(context.Auth.AuthIdentityModel).Where(authInfo).Scan(&authInfo).RecordNotFound() { - err = auth.ErrInvalidAccount - } - - if err == nil { - if currentUser, err = context.Auth.UserStorer.Get(authInfo.ToClaims(), context); err == nil { - err = provider.Config.ConfirmMailer(authInfo.UID, context, authInfo.ToClaims(), currentUser) - } - } - - if err == nil { - context.SessionStorer.Flash(context.Writer, req, session.Message{Message: ConfirmFlashMessage, Type: "success"}) - context.Auth.Redirector.Redirect(context.Writer, context.Request, "send_confirmation") - } + err = provider.SendConfirmHandler(context) } } From 1b0f4e5ec42b860f51a535bf4ba6293a2582dd3e Mon Sep 17 00:00:00 2001 From: eslizn Date: Sat, 23 Nov 2019 19:59:28 +0800 Subject: [PATCH 2/2] fix multiple render confirmation/new --- providers/password/password.go | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/providers/password/password.go b/providers/password/password.go index 09b418a..4eeedd3 100644 --- a/providers/password/password.go +++ b/providers/password/password.go @@ -125,22 +125,17 @@ func (provider Provider) ServeHTTP(context *auth.Context) { switch paths[1] { case "confirmation": var err error - if len(paths) >= 3 { switch paths[2] { - case "new": - // render new confirmation page - context.Auth.Config.Render.Execute("auth/confirmation/new", context, context.Request, context.Writer) case "send": err = provider.SendConfirmHandler(context) + default: + err = context.Auth.Config.Render.Execute("auth/confirmation/new", context, context.Request, context.Writer) } } - if err != nil { context.SessionStorer.Flash(context.Writer, req, session.Message{Message: template.HTML(err.Error()), Type: "error"}) } - // render new confirmation page - context.Auth.Config.Render.Execute("auth/confirmation/new", context, context.Request, context.Writer) case "confirm": // confirm user err := provider.ConfirmHandler(context)