Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions providers/password/confirm.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"net/mail"
"path"
"reflect"
"strings"
"time"

"github.com/qor/auth"
Expand Down Expand Up @@ -61,6 +62,36 @@ var DefaultConfirmationMailer = func(email string, context *auth.Context, claims
}))
}

// DefaultConfirmHandler default send confirm handler
var DefaultSendConfirmHandler = func(context *auth.Context) error {
var (
currentUser interface{}
authInfo auth_identity.Basic
provider, _ = context.Provider.(*Provider)
req = context.Request
tx = context.Auth.GetDB(req)
err error
)
req.ParseForm()
authInfo.Provider = provider.GetName()
authInfo.UID = strings.TrimSpace(req.Form.Get("email"))
if tx.Model(context.Auth.AuthIdentityModel).Where(authInfo).Scan(&authInfo).RecordNotFound() {
return auth.ErrInvalidAccount
}

if currentUser, err = context.Auth.UserStorer.Get(authInfo.ToClaims(), context); err != nil {
return err
}

if err = provider.Config.ConfirmMailer(authInfo.UID, context, authInfo.ToClaims(), currentUser); err != nil {
return err
}

context.SessionStorer.Flash(context.Writer, req, session.Message{Message: ConfirmFlashMessage, Type: "success"})
context.Auth.Redirector.Redirect(context.Writer, context.Request, "send_confirmation")
return nil
}

// DefaultConfirmHandler default confirm handler
var DefaultConfirmHandler = func(context *auth.Context) error {
var (
Expand Down
44 changes: 11 additions & 33 deletions providers/password/password.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"strings"

"github.com/qor/auth"
"github.com/qor/auth/auth_identity"
"github.com/qor/auth/claims"
"github.com/qor/auth/providers/password/encryptor"
"github.com/qor/auth/providers/password/encryptor/bcrypt_encryptor"
Expand All @@ -15,9 +14,10 @@ import (

// Config password config
type Config struct {
Confirmable bool
ConfirmMailer func(email string, context *auth.Context, claims *claims.Claims, currentUser interface{}) error
ConfirmHandler func(*auth.Context) error
Confirmable bool
ConfirmMailer func(email string, context *auth.Context, claims *claims.Claims, currentUser interface{}) error
SendConfirmHandler func(*auth.Context) error
ConfirmHandler func(*auth.Context) error

ResetPasswordMailer func(email string, context *auth.Context, claims *claims.Claims, currentUser interface{}) error
ResetPasswordHandler func(*auth.Context) error
Expand All @@ -44,6 +44,10 @@ func New(config *Config) *Provider {
config.ConfirmMailer = DefaultConfirmationMailer
}

if config.SendConfirmHandler == nil {
config.SendConfirmHandler = DefaultSendConfirmHandler
}

if config.ConfirmHandler == nil {
config.ConfirmHandler = DefaultConfirmHandler
}
Expand Down Expand Up @@ -121,43 +125,17 @@ func (provider Provider) ServeHTTP(context *auth.Context) {
switch paths[1] {
case "confirmation":
var err error

if len(paths) >= 3 {
switch paths[2] {
case "new":
// render new confirmation page
context.Auth.Config.Render.Execute("auth/confirmation/new", context, context.Request, context.Writer)
case "send":
var (
currentUser interface{}
authInfo auth_identity.Basic
tx = context.Auth.GetDB(req)
)

authInfo.Provider = provider.GetName()
authInfo.UID = strings.TrimSpace(req.Form.Get("email"))
if tx.Model(context.Auth.AuthIdentityModel).Where(authInfo).Scan(&authInfo).RecordNotFound() {
err = auth.ErrInvalidAccount
}

if err == nil {
if currentUser, err = context.Auth.UserStorer.Get(authInfo.ToClaims(), context); err == nil {
err = provider.Config.ConfirmMailer(authInfo.UID, context, authInfo.ToClaims(), currentUser)
}
}

if err == nil {
context.SessionStorer.Flash(context.Writer, req, session.Message{Message: ConfirmFlashMessage, Type: "success"})
context.Auth.Redirector.Redirect(context.Writer, context.Request, "send_confirmation")
}
err = provider.SendConfirmHandler(context)
default:
err = context.Auth.Config.Render.Execute("auth/confirmation/new", context, context.Request, context.Writer)
}
}

if err != nil {
context.SessionStorer.Flash(context.Writer, req, session.Message{Message: template.HTML(err.Error()), Type: "error"})
}
// render new confirmation page
context.Auth.Config.Render.Execute("auth/confirmation/new", context, context.Request, context.Writer)
case "confirm":
// confirm user
err := provider.ConfirmHandler(context)
Expand Down