diff --git a/README.md b/README.md index b48a3469e..08ac96247 100644 --- a/README.md +++ b/README.md @@ -55,6 +55,7 @@ $ go get github.com/markbates/goth * OpenID Connect (auto discovery) * Oura * Patreon +* Password Grant * Paypal * SalesForce * Shopify diff --git a/examples/main.go b/examples/main.go index 6f421d81a..966bb4f5c 100644 --- a/examples/main.go +++ b/examples/main.go @@ -1,6 +1,7 @@ package main import ( + "errors" "fmt" "html/template" "log" @@ -21,6 +22,7 @@ import ( "github.com/markbates/goth/providers/dailymotion" "github.com/markbates/goth/providers/deezer" "github.com/markbates/goth/providers/digitalocean" + "github.com/markbates/goth/providers/direct" "github.com/markbates/goth/providers/discord" "github.com/markbates/goth/providers/dropbox" "github.com/markbates/goth/providers/eveonline" @@ -159,6 +161,30 @@ func main() { goth.UseProviders(openidConnect) } + var userFetcher = func(email string) (goth.User, error) { + if email != "john@doe.com" { + return goth.User{}, errors.New("user not found") + } + + // Possible to associate the token with the user + return goth.User{ + Email: "john@doe.com", + FirstName: "John", + LastName: "Doe", + NickName: "JD", + UserID: "123456789", + Provider: "direct", + }, nil + } + var credChecker = func(email, password string) error { + if email == "john@doe.com" && password == "password" { + return nil + } + return errors.New("invalid username or password") + } + directProvider := direct.New("/login", userFetcher, credChecker) + goth.UseProviders(directProvider) + m := map[string]string{ "amazon": "Amazon", "apple": "Apple", @@ -170,6 +196,7 @@ func main() { "dailymotion": "Dailymotion", "deezer": "Deezer", "digitalocean": "Digital Ocean", + "direct": "Password Grant Flow", "discord": "Discord", "dropbox": "Dropbox", "eveonline": "Eve Online", @@ -231,12 +258,14 @@ func main() { p := pat.New() p.Get("/auth/{provider}/callback", func(res http.ResponseWriter, req *http.Request) { - user, err := gothic.CompleteUserAuth(res, req) if err != nil { fmt.Fprintln(res, err) return } + + // Here you can persist the user in your session store of choice + t, _ := template.New("foo").Parse(userTemplate) t.Execute(res, user) }) @@ -250,6 +279,9 @@ func main() { p.Get("/auth/{provider}", func(res http.ResponseWriter, req *http.Request) { // try to get the user without re-authenticating if gothUser, err := gothic.CompleteUserAuth(res, req); err == nil { + + // Here you can persist the user in your session store of choice + t, _ := template.New("foo").Parse(userTemplate) t.Execute(res, gothUser) } else { @@ -257,6 +289,22 @@ func main() { } }) + p.Post("/auth/direct", func(res http.ResponseWriter, req *http.Request) { + if gothUser, err := gothic.PasswordGrantAuth(res, req); err == nil { + t, _ := template.New("foo").Parse(userTemplate) + t.Execute(res, gothUser) + } else { + log.Println("error:", err) + res.Header().Set("Location", "/") + res.WriteHeader(http.StatusFound) + } + }) + + p.Get("/login", func(res http.ResponseWriter, req *http.Request) { + t, _ := template.New("foo").Parse(loginTemplate) + t.Execute(res, providerIndex) + }) + p.Get("/", func(res http.ResponseWriter, req *http.Request) { t, _ := template.New("foo").Parse(indexTemplate) t.Execute(res, providerIndex) @@ -271,9 +319,27 @@ type ProviderIndex struct { ProvidersMap map[string]string } -var indexTemplate = `{{range $key,$value:=.Providers}} +var loginTemplate = ` + + +
+ + + +
+ + +` + +var indexTemplate = ` + + +{{range $key,$value:=.Providers}}

Log in with {{index $.ProvidersMap $value}}

-{{end}}` +{{end}} + + +` var userTemplate = `

logout

diff --git a/gothic/gothic.go b/gothic/gothic.go index 3a814d9a6..8143f83b3 100644 --- a/gothic/gothic.go +++ b/gothic/gothic.go @@ -216,6 +216,16 @@ var CompleteUserAuth = func(res http.ResponseWriter, req *http.Request) (goth.Us return gu, err } +// PasswordGrantAuth is a helper function that make sure +// authentication happens via the "direct" provider +func PasswordGrantAuth(res http.ResponseWriter, req *http.Request) (goth.User, error) { + ctx := req.Context() + ctx = context.WithValue(ctx, ProviderParamKey, "direct") + req = req.WithContext(ctx) + + return CompleteUserAuth(res, req) +} + // validateState ensures that the state token param from the original // AuthURL matches the one included in the current (callback) request. func validateState(req *http.Request, sess goth.Session) error { diff --git a/providers/direct/direct.go b/providers/direct/direct.go new file mode 100644 index 000000000..e97af255d --- /dev/null +++ b/providers/direct/direct.go @@ -0,0 +1,90 @@ +package direct + +import ( + "encoding/json" + "errors" + "fmt" + "strings" + + "github.com/markbates/goth" + "golang.org/x/oauth2" +) + +type UserFetcher func(email string) (goth.User, error) + +type CredChecker func(email, password string) error + +type Provider struct { + name string + debug bool + AuthURL string + UserFetcher + CredChecker +} + +func New(authUrl string, userFetcher UserFetcher, credChecker CredChecker) *Provider { + return &Provider{ + name: "direct", + AuthURL: authUrl, + UserFetcher: userFetcher, + CredChecker: credChecker, + } +} + +func (p *Provider) Name() string { + return p.name +} + +func (p *Provider) SetName(name string) { + p.name = name +} + +func (p *Provider) BeginAuth(state string) (goth.Session, error) { + return &Session{ + AuthURL: p.AuthURL, + }, nil +} + +func (p *Provider) UnmarshalSession(data string) (goth.Session, error) { + sess := &Session{} + err := json.NewDecoder(strings.NewReader(data)).Decode(sess) + return sess, err +} + +func (p *Provider) FetchUser(session goth.Session) (goth.User, error) { + directSession := session.(*Session) + + if directSession.Email == "" { + // data is not yet retrieved since email is still empty + return goth.User{}, fmt.Errorf("%s cannot get user information without accessToken", p.name) + } + + user, err := p.UserFetcher(directSession.Email) + if err != nil { + return goth.User{}, err + } + + return user, nil +} + +func (p *Provider) Debug(debug bool) { + p.debug = debug +} + +func (p *Provider) RefreshToken(refreshToken string) (*oauth2.Token, error) { + return nil, errors.New("refreshToken not supported for the password grant") +} + +func (p *Provider) RefreshTokenAvailable() bool { + return false +} + +func (p *Provider) IssueSession(email, password string) (goth.Session, error) { + if p.CredChecker(email, password) != nil { + return nil, errors.New("invalid username or password") + } + + return &Session{ + Email: email, + }, nil +} diff --git a/providers/direct/direct_test.go b/providers/direct/direct_test.go new file mode 100644 index 000000000..37dae8228 --- /dev/null +++ b/providers/direct/direct_test.go @@ -0,0 +1,87 @@ +package direct_test + +import ( + "errors" + "testing" + + "github.com/markbates/goth" + "github.com/markbates/goth/providers/direct" +) + +func TestDirectProvider(t *testing.T) { + users := map[string]goth.User{ + "test@example.com": { + Email: "test@example.com", + }, + } + + var userFetcher = func(email string) (goth.User, error) { + if user, ok := users[email]; ok { + return user, nil + } + return goth.User{}, errors.New("user not found") + } + var credChecker = func(email, password string) error { + if email == "test@example.com" && password == "password" { + return nil + } + return errors.New("invalid email or password") + } + p := direct.New("/login", userFetcher, credChecker) + + t.Run("Name", func(t *testing.T) { + if p.Name() != "direct" { + t.Errorf("expected provider name to be 'direct', got %s", p.Name()) + } + }) + + t.Run("SetName", func(t *testing.T) { + p.SetName("direct_custom") + if p.Name() != "direct_custom" { + t.Errorf("expected provider name to be 'direct_custom', got %s", p.Name()) + } + }) + + t.Run("IssueSession", func(t *testing.T) { + _, err := p.IssueSession("test@example.com", "password") + if err != nil { + t.Errorf("expected no error, got %v", err) + } + + _, err = p.IssueSession("test@example.com", "wrong_password") + if err == nil { + t.Error("expected error for invalid password, got nil") + } + + _, err = p.IssueSession("nonexistent@example.com", "password") + if err == nil { + t.Error("expected error for non-existent user, got nil") + } + }) + + t.Run("FetchUser", func(t *testing.T) { + session, _ := p.IssueSession("test@example.com", "password") + + user, err := p.FetchUser(session) + if err != nil { + t.Errorf("expected no error, got %v", err) + } + if user.Email != "test@example.com" { + t.Errorf("expected email to be 'test@example.com', got %s", user.Email) + } + }) + + t.Run("UnmarshalSession", func(t *testing.T) { + session, _ := p.IssueSession("test@example.com", "password") + data := session.Marshal() + + unmarshalledSession, err := p.UnmarshalSession(data) + if err != nil { + t.Errorf("expected no error, got %v", err) + } + + if session.Marshal() != unmarshalledSession.Marshal() { + t.Error("unmarshalled session data does not match the original session data") + } + }) +} diff --git a/providers/direct/session.go b/providers/direct/session.go new file mode 100644 index 000000000..3b51a4752 --- /dev/null +++ b/providers/direct/session.go @@ -0,0 +1,46 @@ +package direct + +import ( + "encoding/json" + "errors" + + "github.com/markbates/goth" +) + +type Session struct { + AuthURL string + Email string +} + +func (s *Session) GetAuthURL() (string, error) { + return s.AuthURL, nil +} + +func (s *Session) Marshal() string { + b, _ := json.Marshal(s) + return string(b) +} + +func (s *Session) Authorize(provider goth.Provider, params goth.Params) (string, error) { + email := params.Get("email") + password := params.Get("password") + + directProvider, ok := provider.(*Provider) + if !ok { + return "", errors.New("invalid provider type") + } + + session, err := directProvider.IssueSession(email, password) + if err != nil { + return "", err + } + + sess, ok := session.(*Session) + if !ok { + return "", errors.New("invalid session type") + } + + s.Email = sess.Email + // Result of Authorize is not used by gothic package + return "", nil +} diff --git a/providers/direct/session_test.go b/providers/direct/session_test.go new file mode 100644 index 000000000..905684e0f --- /dev/null +++ b/providers/direct/session_test.go @@ -0,0 +1,48 @@ +package direct_test + +import ( + "encoding/json" + "testing" + + "github.com/markbates/goth/providers/direct" +) + +func TestDirectSession(t *testing.T) { + t.Run("Marshal", func(t *testing.T) { + session := &direct.Session{ + Email: "test@mail.com", + AuthURL: "/login", + } + marshaled := session.Marshal() + + var unmarshaled direct.Session + err := json.Unmarshal([]byte(marshaled), &unmarshaled) + + if err != nil { + t.Errorf("unexpected error when unmarshaling session data: %v", err) + } + + if unmarshaled.Email != session.Email { + t.Errorf("expected email to be '%s', got '%s'", session.Email, unmarshaled.Email) + } + + if unmarshaled.AuthURL != session.AuthURL { + t.Errorf("expected auth url to be '%s', got '%s'", session.AuthURL, unmarshaled.AuthURL) + } + }) + + t.Run("GetAuthURL", func(t *testing.T) { + session := &direct.Session{ + AuthURL: "/", + } + + url, err := session.GetAuthURL() + if err != nil { + t.Error("unexpected error when calling GetAuthURL") + } + + if url != "/" { + t.Errorf("expected auth url to be '/', got '%s'", url) + } + }) +}