diff --git a/providers/feishu/feishu.go b/providers/feishu/feishu.go new file mode 100644 index 00000000..40fc6fcc --- /dev/null +++ b/providers/feishu/feishu.go @@ -0,0 +1,215 @@ +package feishu + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/markbates/goth" + "golang.org/x/oauth2" +) + +// See: https://open.feishu.cn/document/sso/web-application-sso/login-overview +var ( + AuthURL = "https://accounts.feishu.cn/open-apis/authen/v1/authorize" + TokenURL = "https://open.feishu.cn/open-apis/authen/v2/oauth/token" + ProfileURL = "https://open.feishu.cn/open-apis/authen/v1/user_info" +) + +// Provider is the implementation of `goth.Provider` for accessing Feishu. +type Provider struct { + ClientKey string + Secret string + CallbackURL string + HTTPClient *http.Client + config *oauth2.Config + providerName string + AuthURL string + TokenURL string + ProfileURL string +} + +// New creates a new Feishu provider, and sets up important connection details. +// You should always call `feishu.New` to get a new Provider. Never try to create +// one manually. +func New(clientKey, secret, callbackURL string, scopes ...string) *Provider { + return NewCustomisedURL(clientKey, secret, callbackURL, AuthURL, TokenURL, ProfileURL, scopes...) +} + +// NewCustomisedURL is similar to New(...) but can be used to set custom URLs to connect to +func NewCustomisedURL(clientKey, secret, callbackURL, AuthURL, TokenURL, ProfileURL string, scopes ...string) *Provider { + p := &Provider{ + ClientKey: clientKey, + Secret: secret, + CallbackURL: callbackURL, + providerName: "feishu", + AuthURL: AuthURL, + TokenURL: TokenURL, + ProfileURL: ProfileURL, + } + p.config = newConfig(p, scopes) + return p +} + +func newConfig(provider *Provider, scopes []string) *oauth2.Config { + c := &oauth2.Config{ + ClientID: provider.ClientKey, + ClientSecret: provider.Secret, + RedirectURL: provider.CallbackURL, + Endpoint: oauth2.Endpoint{ + AuthURL: provider.AuthURL, + TokenURL: provider.TokenURL, + }, + Scopes: []string{}, + } + + if len(scopes) > 0 { + c.Scopes = append(c.Scopes, scopes...) + } else { + // If no scope is provided, add the default "auth:user.id:read" + c.Scopes = []string{"auth:user.id:read"} + } + + return c +} + +func (p *Provider) Client() *http.Client { + return goth.HTTPClientWithFallBack(p.HTTPClient) +} + +func (p *Provider) Name() string { + return p.providerName +} + +// SetName is to update the name of the provider (needed in case of multiple providers of 1 type) +func (p *Provider) SetName(name string) { + p.providerName = name +} + +// BeginAuth asks Feishu for an authentication end-point. +func (p *Provider) BeginAuth(state string) (goth.Session, error) { + url := p.config.AuthCodeURL(state) + session := &Session{ + AuthURL: url, + } + return session, nil +} + +// Debug is a no-op for the amazon package. +func (p *Provider) Debug(debug bool) {} + +// RefreshToken get new access token based on the refresh token +func (p *Provider) RefreshToken(refreshToken string) (*oauth2.Token, error) { + token := &oauth2.Token{RefreshToken: refreshToken} + ts := p.config.TokenSource(goth.ContextForClient(p.Client()), token) + newToken, err := ts.Token() + if err != nil { + return nil, err + } + return newToken, err +} + +// RefreshTokenAvailable refresh token is provided by Feishu +func (p *Provider) RefreshTokenAvailable() bool { + return true +} + +type feishuUser struct { + Name string `json:"name"` + EnName string `json:"en_name"` + AvatarURL string `json:"avatar_url"` + AvatarThumb string `json:"avatar_thumb"` + AvatarMiddle string `json:"avatar_middle"` + AvatarBig string `json:"avatar_big"` + OpenID string `json:"open_id"` + UnionID string `json:"union_id"` + Email string `json:"email,omitempty"` + EnterpriseEmail string `json:"enterprise_email,omitempty"` + UserID string `json:"user_id,omitempty"` + Mobile string `json:"mobile,omitempty"` + TenantKey string `json:"tenant_key"` + EmployeeNo string `json:"employee_no,omitempty"` +} + +// FetchUser will go to Feishu and access basic information about the user. +func (p *Provider) FetchUser(session goth.Session) (goth.User, error) { + sess := session.(*Session) + user := goth.User{ + AccessToken: sess.AccessToken, + Provider: p.Name(), + RefreshToken: sess.RefreshToken, + ExpiresAt: sess.ExpiresAt, + } + + if user.AccessToken == "" { + // data is not yet retrieved since accessToken is still empty + return user, fmt.Errorf("%s cannot get user information without accessToken", p.providerName) + } + + // Get user information + reqProfile, err := http.NewRequest("GET", p.ProfileURL, nil) + if err != nil { + return user, err + } + + reqProfile.Header.Add("Authorization", fmt.Sprintf("Bearer %s", user.AccessToken)) + reqProfile.Header.Add("Content-Type", "application/json") + + response, err := p.Client().Do(reqProfile) + if err != nil { + return user, err + } + defer response.Body.Close() + + if response.StatusCode != http.StatusOK { + return user, fmt.Errorf("%s responded with a %d trying to fetch user information", p.providerName, response.StatusCode) + } + + bits, err := io.ReadAll(response.Body) + if err != nil { + return user, err + } + + resBody := struct { + Code int `json:"code"` + Msg string `json:"msg"` + Data map[string]interface{} `json:"data"` + }{} + err = json.Unmarshal(bits, &resBody) + if err != nil { + return user, err + } + if resBody.Code != 0 { + return user, fmt.Errorf("%s", resBody.Msg) + } + + dataBits, err := json.Marshal(resBody.Data) + if err != nil { + return user, err + } + + err = userFromReader(bytes.NewReader(dataBits), &user) + return user, err +} + +func userFromReader(r io.Reader, user *goth.User) error { + // Extract user fields directly + u := feishuUser{} + err := json.NewDecoder(r).Decode(&u) + if err != nil { + return err + } + bits, _ := json.Marshal(u) + json.NewDecoder(bytes.NewReader(bits)).Decode(&user.RawData) + + // Populate user struct + user.Email = u.EnterpriseEmail + user.Name = u.Name + user.NickName = u.Name + user.UserID = u.OpenID + user.AvatarURL = u.AvatarURL + + return nil +} diff --git a/providers/feishu/feishu_test.go b/providers/feishu/feishu_test.go new file mode 100644 index 00000000..4d5dd0e9 --- /dev/null +++ b/providers/feishu/feishu_test.go @@ -0,0 +1,53 @@ +package feishu_test + +import ( + "os" + "testing" + + "github.com/markbates/goth" + "github.com/markbates/goth/providers/feishu" + "github.com/stretchr/testify/assert" +) + +func Test_New(t *testing.T) { + t.Parallel() + a := assert.New(t) + p := provider() + + a.Equal(p.ClientKey, os.Getenv("FEISHU_KEY")) + a.Equal(p.Secret, os.Getenv("FEISHU_SECRET")) + a.Equal(p.CallbackURL, "/foo") +} + +func Test_Implements_Provider(t *testing.T) { + t.Parallel() + a := assert.New(t) + a.Implements((*goth.Provider)(nil), provider()) +} + +func Test_BeginAuth(t *testing.T) { + t.Parallel() + a := assert.New(t) + p := provider() + session, err := p.BeginAuth("test_state") + s := session.(*feishu.Session) + a.NoError(err) + a.Contains(s.AuthURL, "accounts.feishu.cn/open-apis/authen/v1/authorize") +} + +func Test_SessionFromJSON(t *testing.T) { + t.Parallel() + a := assert.New(t) + + p := provider() + session, err := p.UnmarshalSession(`{"AuthURL":"https://open.larksuite.cn/open-apis/authen/v2/oauth/authorize","AccessToken":"1234567890"}`) + a.NoError(err) + + s := session.(*feishu.Session) + a.Equal(s.AuthURL, "https://open.larksuite.cn/open-apis/authen/v2/oauth/authorize") + a.Equal(s.AccessToken, "1234567890") +} + +func provider() *feishu.Provider { + return feishu.New(os.Getenv("FEISHU_KEY"), os.Getenv("FEISHU_SECRET"), "/foo") +} diff --git a/providers/feishu/session.go b/providers/feishu/session.go new file mode 100644 index 00000000..e8624a07 --- /dev/null +++ b/providers/feishu/session.go @@ -0,0 +1,61 @@ +package feishu + +import ( + "encoding/json" + "errors" + "strings" + "time" + + "github.com/markbates/goth" +) + +type Session struct { + AuthURL string + AccessToken string + RefreshToken string + ExpiresAt time.Time + RefreshTokenExpiresAt time.Time +} + +func (s Session) GetAuthURL() (string, error) { + if s.AuthURL == "" { + return "", errors.New(goth.NoAuthUrlErrorMessage) + } + return s.AuthURL, nil +} + +// Marshal the session into a string +func (s Session) Marshal() string { + b, _ := json.Marshal(s) + return string(b) +} + +// UnmarshalSession will unmarshal a JSON string into a session. +func (p *Provider) UnmarshalSession(data string) (goth.Session, error) { + sess := &Session{} + err := json.NewDecoder(strings.NewReader(data)).Decode(sess) + return sess, err +} + +func (s *Session) Authorize(provider goth.Provider, params goth.Params) (string, error) { + p := provider.(*Provider) + token, err := p.config.Exchange(goth.ContextForClient(p.Client()), params.Get("code")) + if err != nil { + return "", err + } + + if !token.Valid() { + return "", errors.New("Invalid token received from provider") + } + + s.AccessToken = token.AccessToken + s.RefreshToken = token.RefreshToken + s.ExpiresAt = token.Expiry + + refreshTokenExpiresAt := token.Extra("refresh_token_expires_in") + if refreshTokenExpiresAt2, ok := refreshTokenExpiresAt.(int); ok { + s.RefreshTokenExpiresAt = time.Now().Add(time.Second * time.Duration(refreshTokenExpiresAt2)) + } + + return token.AccessToken, err +} diff --git a/providers/feishu/session_test.go b/providers/feishu/session_test.go new file mode 100644 index 00000000..59947741 --- /dev/null +++ b/providers/feishu/session_test.go @@ -0,0 +1,49 @@ +package feishu_test + +import ( + "testing" + "time" + + "github.com/markbates/goth" + "github.com/markbates/goth/providers/feishu" + "github.com/stretchr/testify/assert" +) + +func Test_Implements_Session(t *testing.T) { + t.Parallel() + a := assert.New(t) + s := &feishu.Session{} + + a.Implements((*goth.Session)(nil), s) +} + +func Test_GetAuthURL(t *testing.T) { + t.Parallel() + a := assert.New(t) + s := &feishu.Session{} + + _, err := s.GetAuthURL() + a.Error(err) + + s.AuthURL = "/foo" + + url, _ := s.GetAuthURL() + a.Equal(url, "/foo") +} + +func Test_ToJSON(t *testing.T) { + t.Parallel() + a := assert.New(t) + s := &feishu.Session{} + + data := s.Marshal() + a.Equal(data, `{"AuthURL":"","AccessToken":"","RefreshToken":"","ExpiresAt":"0001-01-01T00:00:00Z","RefreshTokenExpiresAt":"0001-01-01T00:00:00Z"}`) +} + +func Test_GetExpiresAt(t *testing.T) { + t.Parallel() + a := assert.New(t) + s := &feishu.Session{} + + a.Equal(s.ExpiresAt, time.Time{}) +} diff --git a/providers/lark/lark.go b/providers/lark/lark.go deleted file mode 100644 index d9900b9c..00000000 --- a/providers/lark/lark.go +++ /dev/null @@ -1,307 +0,0 @@ -package lark - -import ( - "bytes" - "encoding/json" - "fmt" - "io" - "net/http" - "net/url" - "strings" - "sync" - "time" - - "github.com/markbates/goth" - "golang.org/x/oauth2" -) - -const ( - appAccessTokenURL string = "https://open.feishu.cn/open-apis/auth/v3/app_access_token/internal/" // get app_access_token - - authURL string = "https://open.feishu.cn/open-apis/authen/v1/authorize" // obtain authorization code - tokenURL string = "https://open.feishu.cn/open-apis/authen/v1/oidc/access_token" // get user_access_token - refreshTokenURL string = "https://open.feishu.cn/open-apis/authen/v1/oidc/refresh_access_token" // refresh user_access_token - endpointProfile string = "https://open.feishu.cn/open-apis/authen/v1/user_info" // get user info -) - -// Lark is the implementation of `goth.Provider` for accessing Lark -type Lark interface { - GetAppAccessToken() error // get app access token -} - -// Provider is the implementation of `goth.Provider` for accessing Lark -type Provider struct { - ClientKey string - Secret string - CallbackURL string - HTTPClient *http.Client - config *oauth2.Config - providerName string - - appAccessToken *appAccessToken -} - -// New creates a new Lark provider and sets up important connection details. -func New(clientKey, secret, callbackURL string, scopes ...string) *Provider { - p := &Provider{ - ClientKey: clientKey, - Secret: secret, - CallbackURL: callbackURL, - providerName: "lark", - appAccessToken: &appAccessToken{}, - } - p.config = newConfig(p, authURL, tokenURL, scopes) - return p -} - -func newConfig(provider *Provider, authURL, tokenURL string, scopes []string) *oauth2.Config { - c := &oauth2.Config{ - ClientID: provider.ClientKey, - ClientSecret: provider.Secret, - RedirectURL: provider.CallbackURL, - Endpoint: oauth2.Endpoint{ - AuthURL: authURL, - TokenURL: tokenURL, - }, - Scopes: []string{}, - } - - if len(scopes) > 0 { - c.Scopes = append(c.Scopes, scopes...) - } - return c -} - -func (p *Provider) Client() *http.Client { - return goth.HTTPClientWithFallBack(p.HTTPClient) -} - -func (p *Provider) Name() string { - return p.providerName -} - -func (p *Provider) SetName(name string) { - p.providerName = name -} - -type appAccessToken struct { - Token string - ExpiresAt time.Time - rMutex sync.RWMutex -} - -type appAccessTokenReq struct { - AppID string `json:"app_id"` // 自建应用的 app_id - AppSecret string `json:"app_secret"` // 自建应用的 app_secret -} - -type appAccessTokenResp struct { - Code int `json:"code"` // 错误码 - Msg string `json:"msg"` // 错误信息 - AppAccessToken string `json:"app_access_token"` // 用于调用应用级接口的 app_access_token - Expire int64 `json:"expire"` // app_access_token 的过期时间 -} - -// GetAppAccessToken get lark app access token -func (p *Provider) GetAppAccessToken() error { - // get from cache app access token - p.appAccessToken.rMutex.RLock() - if time.Now().Before(p.appAccessToken.ExpiresAt) { - p.appAccessToken.rMutex.RUnlock() - return nil - } - p.appAccessToken.rMutex.RUnlock() - - reqBody, err := json.Marshal(&appAccessTokenReq{ - AppID: p.ClientKey, - AppSecret: p.Secret, - }) - if err != nil { - return fmt.Errorf("failed to marshal request body: %w", err) - } - - req, err := http.NewRequest(http.MethodPost, appAccessTokenURL, bytes.NewBuffer(reqBody)) - if err != nil { - return fmt.Errorf("failed to create app access token request: %w", err) - } - req.Header.Set("Content-Type", "application/json") - - resp, err := p.Client().Do(req) - if err != nil { - return fmt.Errorf("failed to send app access token request: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("unexpected status code while fetching app access token: %d", resp.StatusCode) - } - - tokenResp := new(appAccessTokenResp) - if err = json.NewDecoder(resp.Body).Decode(tokenResp); err != nil { - return fmt.Errorf("failed to decode app access token response: %w", err) - } - - if tokenResp.Code != 0 { - return fmt.Errorf("failed to get app access token: code:%v msg: %s", tokenResp.Code, tokenResp.Msg) - } - - // update local cache - expirationDuration := time.Duration(tokenResp.Expire) * time.Second - p.appAccessToken.rMutex.Lock() - p.appAccessToken.Token = tokenResp.AppAccessToken - p.appAccessToken.ExpiresAt = time.Now().Add(expirationDuration) - p.appAccessToken.rMutex.Unlock() - - return nil -} - -func (p *Provider) BeginAuth(state string) (goth.Session, error) { - // build lark auth url - u, err := url.Parse(p.config.AuthCodeURL(state)) - if err != nil { - panic(err) - } - query := u.Query() - query.Del("response_type") - query.Del("client_id") - query.Add("app_id", p.ClientKey) - u.RawQuery = query.Encode() - - return &Session{ - AuthURL: u.String(), - }, nil -} - -func (p *Provider) UnmarshalSession(data string) (goth.Session, error) { - s := &Session{} - err := json.NewDecoder(strings.NewReader(data)).Decode(s) - return s, err -} - -func (p *Provider) Debug(b bool) { -} - -type getUserAccessTokenResp struct { - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` - TokenType string `json:"token_type"` - ExpiresIn int `json:"expires_in"` - RefreshExpiresIn int `json:"refresh_expires_in"` - Scope string `json:"scope"` -} - -func (p *Provider) RefreshToken(refreshToken string) (*oauth2.Token, error) { - if err := p.GetAppAccessToken(); err != nil { - return nil, fmt.Errorf("failed to get app access token: %w", err) - } - reqBody := strings.NewReader(`{"grant_type":"refresh_token","refresh_token":"` + refreshToken + `"}`) - - req, err := http.NewRequest(http.MethodPost, refreshTokenURL, reqBody) - if err != nil { - return nil, fmt.Errorf("failed to create refresh token request: %w", err) - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", p.appAccessToken.Token)) - - resp, err := p.Client().Do(req) - if err != nil { - return nil, fmt.Errorf("failed to send refresh token request: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("unexpected status code while refreshing token: %d", resp.StatusCode) - } - - var oauthResp commResponse[getUserAccessTokenResp] - err = json.NewDecoder(resp.Body).Decode(&oauthResp) - if err != nil { - return nil, fmt.Errorf("failed to decode refreshed token: %w", err) - } - if oauthResp.Code != 0 { - return nil, fmt.Errorf("failed to refresh token: code:%v msg: %s", oauthResp.Code, oauthResp.Msg) - } - - token := oauth2.Token{ - AccessToken: oauthResp.Data.AccessToken, - RefreshToken: oauthResp.Data.RefreshToken, - Expiry: time.Now().Add(time.Duration(oauthResp.Data.ExpiresIn) * time.Second), - } - - return &token, nil -} - -func (p *Provider) RefreshTokenAvailable() bool { - return true -} - -type commResponse[T any] struct { - Code int `json:"code"` - Msg string `json:"msg"` - Data T `json:"data"` -} - -type larkUser struct { - OpenID string `json:"open_id"` - UnionID string `json:"union_id"` - UserID string `json:"user_id"` - Name string `json:"name"` - Email string `json:"enterprise_email"` - AvatarURL string `json:"avatar_url"` - Mobile string `json:"mobile,omitempty"` -} - -// FetchUser will go to Lark and access basic information about the user. -func (p *Provider) FetchUser(session goth.Session) (goth.User, error) { - sess := session.(*Session) - user := goth.User{ - AccessToken: sess.AccessToken, - Provider: p.Name(), - RefreshToken: sess.RefreshToken, - ExpiresAt: sess.ExpiresAt, - } - if user.AccessToken == "" { - return user, fmt.Errorf("%s cannot get user information without accessToken", p.providerName) - } - - req, err := http.NewRequest("GET", endpointProfile, nil) - if err != nil { - return user, fmt.Errorf("%s failed to create request: %w", p.providerName, err) - } - req.Header.Set("Authorization", "Bearer "+user.AccessToken) - - resp, err := p.Client().Do(req) - if err != nil { - return user, fmt.Errorf("%s failed to get user information: %w", p.providerName, err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return user, fmt.Errorf("%s responded with a %d trying to fetch user information", p.providerName, resp.StatusCode) - } - - responseBytes, err := io.ReadAll(resp.Body) - if err != nil { - return user, fmt.Errorf("failed to read response body: %w", err) - } - - var oauthResp commResponse[larkUser] - if err = json.Unmarshal(responseBytes, &oauthResp); err != nil { - return user, fmt.Errorf("failed to decode user info: %w", err) - } - if oauthResp.Code != 0 { - return user, fmt.Errorf("failed to get user info: code:%v msg: %s", oauthResp.Code, oauthResp.Msg) - } - - u := oauthResp.Data - user.UserID = u.UserID - user.Name = u.Name - user.Email = u.Email - user.AvatarURL = u.AvatarURL - user.NickName = u.Name - - if err = json.Unmarshal(responseBytes, &user.RawData); err != nil { - return user, err - } - return user, nil -} diff --git a/providers/lark/lark_test.go b/providers/lark/lark_test.go deleted file mode 100644 index cda49e52..00000000 --- a/providers/lark/lark_test.go +++ /dev/null @@ -1,185 +0,0 @@ -package lark_test - -import ( - "errors" - "fmt" - "io/ioutil" - "net/http" - "net/url" - "os" - "strings" - "testing" - - "github.com/markbates/goth/providers/lark" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" -) - -type MockedHTTPClient struct { - mock.Mock -} - -func (m *MockedHTTPClient) RoundTrip(req *http.Request) (*http.Response, error) { - args := m.Mock.Called(req) - return args.Get(0).(*http.Response), args.Error(1) -} - -func Test_New(t *testing.T) { - t.Parallel() - a := assert.New(t) - p := larkProvider() - - a.Equal(p.ClientKey, os.Getenv("LARK_APP_ID")) - a.Equal(p.Secret, os.Getenv("LARK_APP_SECRET")) - a.Equal(p.CallbackURL, "/foo") -} - -func Test_BeginAuth(t *testing.T) { - t.Parallel() - a := assert.New(t) - p := larkProvider() - session, err := p.BeginAuth("test_state") - s := session.(*lark.Session) - a.NoError(err) - a.Contains(s.AuthURL, "https://open.feishu.cn/open-apis/authen/v1/authorize") - a.Contains(s.AuthURL, "app_id="+os.Getenv("LARK_APP_ID")) - a.Contains(s.AuthURL, "state=test_state") - a.Contains(s.AuthURL, fmt.Sprintf("redirect_uri=%s", url.QueryEscape("/foo"))) -} - -func Test_GetAppAccessToken(t *testing.T) { - t.Run("happy path", func(t *testing.T) { - mockClient := new(MockedHTTPClient) - p := larkProvider() - p.HTTPClient = &http.Client{Transport: mockClient} - - mockClient.On("RoundTrip", mock.Anything).Return(&http.Response{ - StatusCode: http.StatusOK, - Body: ioutil.NopCloser(strings.NewReader(`{"code":0,"msg":"ok","app_access_token":"test_token","expire":3600}`)), - }, nil) - - err := p.GetAppAccessToken() - assert.NoError(t, err) - }) - - t.Run("error on request", func(t *testing.T) { - mockClient := new(MockedHTTPClient) - p := larkProvider() - p.HTTPClient = &http.Client{Transport: mockClient} - - mockClient.On("RoundTrip", mock.Anything).Return(&http.Response{}, errors.New("request error")) - - err := p.GetAppAccessToken() - assert.Error(t, err) - }) - - t.Run("non-200 status code", func(t *testing.T) { - mockClient := new(MockedHTTPClient) - p := larkProvider() - p.HTTPClient = &http.Client{Transport: mockClient} - - mockClient.On("RoundTrip", mock.Anything).Return(&http.Response{ - StatusCode: http.StatusForbidden, - Body: ioutil.NopCloser(strings.NewReader(``)), - }, nil) - - err := p.GetAppAccessToken() - assert.Error(t, err) - }) - - t.Run("error on response decode", func(t *testing.T) { - mockClient := new(MockedHTTPClient) - p := larkProvider() - p.HTTPClient = &http.Client{Transport: mockClient} - - mockClient.On("RoundTrip", mock.Anything).Return(&http.Response{ - StatusCode: http.StatusOK, - Body: ioutil.NopCloser(strings.NewReader(`not a json`)), - }, nil) - - err := p.GetAppAccessToken() - assert.Error(t, err) - }) - - t.Run("error code in response", func(t *testing.T) { - mockClient := new(MockedHTTPClient) - p := larkProvider() - p.HTTPClient = &http.Client{Transport: mockClient} - - mockClient.On("RoundTrip", mock.Anything).Return(&http.Response{ - StatusCode: http.StatusOK, - Body: ioutil.NopCloser(strings.NewReader(`{"code":1,"msg":"error message"}`)), - }, nil) - - err := p.GetAppAccessToken() - assert.Error(t, err) - }) -} - -func Test_FetchUser(t *testing.T) { - session := &lark.Session{ - AccessToken: "user_access_token", - } - - t.Run("happy path", func(t *testing.T) { - mockClient := new(MockedHTTPClient) - p := larkProvider() - p.HTTPClient = &http.Client{Transport: mockClient} - mockClient.On("RoundTrip", mock.Anything).Return(&http.Response{ - StatusCode: http.StatusOK, - Body: ioutil.NopCloser(strings.NewReader(`{"code":0,"msg":"ok","data":{"user_id":"test_user_id","name":"test_name","avatar_url":"test_avatar_url","enterprise_email":"test_email"}}`)), - }, nil) - user, err := p.FetchUser(session) - require.NoError(t, err) - assert.Equal(t, user.UserID, "test_user_id") - assert.Equal(t, user.Name, "test_name") - assert.Equal(t, user.AvatarURL, "test_avatar_url") - assert.Equal(t, user.Email, "test_email") - }) - t.Run("error on request", func(t *testing.T) { - mockClient := new(MockedHTTPClient) - p := larkProvider() - p.HTTPClient = &http.Client{Transport: mockClient} - mockClient.On("RoundTrip", mock.Anything).Return(&http.Response{}, errors.New("request error")) - _, err := p.FetchUser(session) - require.Error(t, err) - }) - t.Run("non-200 status code", func(t *testing.T) { - mockClient := new(MockedHTTPClient) - p := larkProvider() - p.HTTPClient = &http.Client{Transport: mockClient} - mockClient.On("RoundTrip", mock.Anything).Return(&http.Response{ - StatusCode: http.StatusForbidden, - Body: ioutil.NopCloser(strings.NewReader(``)), - }, nil) - _, err := p.FetchUser(session) - require.Error(t, err) - }) - t.Run("error on response decode", func(t *testing.T) { - mockClient := new(MockedHTTPClient) - p := larkProvider() - p.HTTPClient = &http.Client{Transport: mockClient} - mockClient.On("RoundTrip", mock.Anything).Return(&http.Response{ - StatusCode: http.StatusOK, - Body: ioutil.NopCloser(strings.NewReader(`not a json`)), - }, nil) - _, err := p.FetchUser(session) - require.Error(t, err) - }) - t.Run("error code in response", func(t *testing.T) { - mockClient := new(MockedHTTPClient) - p := larkProvider() - p.HTTPClient = &http.Client{Transport: mockClient} - mockClient.On("RoundTrip", mock.Anything).Return(&http.Response{ - StatusCode: http.StatusOK, - Body: ioutil.NopCloser(strings.NewReader(`{"code":1,"msg":"error message"}`)), - }, nil) - _, err := p.FetchUser(session) - require.Error(t, err) - }) -} - -func larkProvider() *lark.Provider { - return lark.New(os.Getenv("LARK_APP_ID"), os.Getenv("LARK_APP_SECRET"), "/foo") -} diff --git a/providers/lark/session.go b/providers/lark/session.go deleted file mode 100644 index 2fdf260c..00000000 --- a/providers/lark/session.go +++ /dev/null @@ -1,71 +0,0 @@ -package lark - -import ( - "encoding/json" - "errors" - "fmt" - "net/http" - "strings" - "time" - - "github.com/markbates/goth" -) - -type Session struct { - AuthURL string - AccessToken string - RefreshToken string - ExpiresAt time.Time - RefreshTokenExpiresAt time.Time -} - -func (s *Session) GetAuthURL() (string, error) { - if s.AuthURL == "" { - return "", errors.New("lark: missing AuthURL") - } - return s.AuthURL, nil -} - -func (s *Session) Marshal() string { - b, _ := json.Marshal(s) - return string(b) -} - -func (s *Session) Authorize(provider goth.Provider, params goth.Params) (string, error) { - p := provider.(*Provider) - reqBody := strings.NewReader(`{"grant_type":"authorization_code","code":"` + params.Get("code") + `"}`) - req, err := http.NewRequest(http.MethodPost, tokenURL, reqBody) - if err != nil { - return "", fmt.Errorf("failed to create refresh token request: %w", err) - } - if err = p.GetAppAccessToken(); err != nil { - return "", fmt.Errorf("failed to get app access token: %w", err) - } - req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", p.appAccessToken.Token)) - req.Header.Add("Content-Type", "application/json; charset=utf-8") - - resp, err := p.Client().Do(req) - if err != nil { - return "", fmt.Errorf("failed to send refresh token request: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return "", fmt.Errorf("unexpected status code while authorizing: %d", resp.StatusCode) - } - - var larkCommResp commResponse[getUserAccessTokenResp] - err = json.NewDecoder(resp.Body).Decode(&larkCommResp) - if err != nil { - return "", fmt.Errorf("failed to decode commResponse: %w", err) - } - if larkCommResp.Code != 0 { - return "", fmt.Errorf("failed to get accessToken: code:%v msg: %s", larkCommResp.Code, larkCommResp.Msg) - } - - s.AccessToken = larkCommResp.Data.AccessToken - s.RefreshToken = larkCommResp.Data.RefreshToken - s.ExpiresAt = time.Now().Add(time.Duration(larkCommResp.Data.ExpiresIn) * time.Second) - s.RefreshTokenExpiresAt = time.Now().Add(time.Duration(larkCommResp.Data.RefreshExpiresIn) * time.Second) - return s.AccessToken, nil -} diff --git a/providers/lark/session_test.go b/providers/lark/session_test.go deleted file mode 100644 index 59dc53f2..00000000 --- a/providers/lark/session_test.go +++ /dev/null @@ -1,112 +0,0 @@ -package lark_test - -import ( - "errors" - "io/ioutil" - "net/http" - "strings" - "testing" - - "github.com/markbates/goth" - "github.com/markbates/goth/providers/lark" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" -) - -type MockParams struct { - params map[string]string -} - -func (m *MockParams) Get(key string) string { - return m.params[key] -} - -func Test_Implements_Session(t *testing.T) { - t.Parallel() - a := assert.New(t) - s := &lark.Session{} - - a.Implements((*goth.Session)(nil), s) -} - -func Test_GetAuthURL(t *testing.T) { - t.Run("happy path", func(t *testing.T) { - session := &lark.Session{ - AuthURL: "https://auth.url", - } - url, err := session.GetAuthURL() - assert.NoError(t, err) - assert.Equal(t, "https://auth.url", url) - }) - - t.Run("missing AuthURL", func(t *testing.T) { - session := &lark.Session{} - _, err := session.GetAuthURL() - assert.Error(t, err) - }) -} - -func Test_Marshal(t *testing.T) { - session := &lark.Session{ - AuthURL: "https://auth.url", - AccessToken: "access_token", - } - marshaled := session.Marshal() - assert.Contains(t, marshaled, "https://auth.url") - assert.Contains(t, marshaled, "access_token") -} - -func Test_Authorize(t *testing.T) { - session := &lark.Session{} - params := &MockParams{ - params: map[string]string{ - "code": "authorization_code", - }, - } - - t.Run("error on request", func(t *testing.T) { - mockClient := new(MockedHTTPClient) - p := larkProvider() - p.HTTPClient = &http.Client{Transport: mockClient} - mockClient.On("RoundTrip", mock.Anything).Return(&http.Response{}, errors.New("request error")) - _, err := session.Authorize(p, params) - require.Error(t, err) - }) - - t.Run("non-200 status code", func(t *testing.T) { - mockClient := new(MockedHTTPClient) - p := larkProvider() - p.HTTPClient = &http.Client{Transport: mockClient} - mockClient.On("RoundTrip", mock.Anything).Return(&http.Response{ - StatusCode: http.StatusForbidden, - Body: ioutil.NopCloser(strings.NewReader(``)), - }, nil) - _, err := session.Authorize(p, params) - require.Error(t, err) - }) - - t.Run("error on response decode", func(t *testing.T) { - mockClient := new(MockedHTTPClient) - p := larkProvider() - p.HTTPClient = &http.Client{Transport: mockClient} - mockClient.On("RoundTrip", mock.Anything).Return(&http.Response{ - StatusCode: http.StatusOK, - Body: ioutil.NopCloser(strings.NewReader(`not a json`)), - }, nil) - _, err := session.Authorize(p, params) - require.Error(t, err) - }) - - t.Run("error code in response", func(t *testing.T) { - mockClient := new(MockedHTTPClient) - p := larkProvider() - p.HTTPClient = &http.Client{Transport: mockClient} - mockClient.On("RoundTrip", mock.Anything).Return(&http.Response{ - StatusCode: http.StatusOK, - Body: ioutil.NopCloser(strings.NewReader(`{"code":1,"msg":"error message"}`)), - }, nil) - _, err := session.Authorize(p, params) - require.Error(t, err) - }) -}