Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 215 additions & 0 deletions providers/feishu/feishu.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
package feishu

import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"

"github.com/markbates/goth"
"golang.org/x/oauth2"
)

// See: https://open.feishu.cn/document/sso/web-application-sso/login-overview
var (
AuthURL = "https://accounts.feishu.cn/open-apis/authen/v1/authorize"
TokenURL = "https://open.feishu.cn/open-apis/authen/v2/oauth/token"
ProfileURL = "https://open.feishu.cn/open-apis/authen/v1/user_info"
)

// Provider is the implementation of `goth.Provider` for accessing Feishu.
type Provider struct {
ClientKey string
Secret string
CallbackURL string
HTTPClient *http.Client
config *oauth2.Config
providerName string
AuthURL string
TokenURL string
ProfileURL string
}

// New creates a new Feishu provider, and sets up important connection details.
// You should always call `feishu.New` to get a new Provider. Never try to create
// one manually.
func New(clientKey, secret, callbackURL string, scopes ...string) *Provider {
return NewCustomisedURL(clientKey, secret, callbackURL, AuthURL, TokenURL, ProfileURL, scopes...)
}

// NewCustomisedURL is similar to New(...) but can be used to set custom URLs to connect to
func NewCustomisedURL(clientKey, secret, callbackURL, AuthURL, TokenURL, ProfileURL string, scopes ...string) *Provider {
p := &Provider{
ClientKey: clientKey,
Secret: secret,
CallbackURL: callbackURL,
providerName: "feishu",
AuthURL: AuthURL,
TokenURL: TokenURL,
ProfileURL: ProfileURL,
}
p.config = newConfig(p, scopes)
return p
}

func newConfig(provider *Provider, scopes []string) *oauth2.Config {
c := &oauth2.Config{
ClientID: provider.ClientKey,
ClientSecret: provider.Secret,
RedirectURL: provider.CallbackURL,
Endpoint: oauth2.Endpoint{
AuthURL: provider.AuthURL,
TokenURL: provider.TokenURL,
},
Scopes: []string{},
}

if len(scopes) > 0 {
c.Scopes = append(c.Scopes, scopes...)
} else {
// If no scope is provided, add the default "auth:user.id:read"
c.Scopes = []string{"auth:user.id:read"}
}

return c
}

func (p *Provider) Client() *http.Client {
return goth.HTTPClientWithFallBack(p.HTTPClient)
}

func (p *Provider) Name() string {
return p.providerName
}

// SetName is to update the name of the provider (needed in case of multiple providers of 1 type)
func (p *Provider) SetName(name string) {
p.providerName = name
}

// BeginAuth asks Feishu for an authentication end-point.
func (p *Provider) BeginAuth(state string) (goth.Session, error) {
url := p.config.AuthCodeURL(state)
session := &Session{
AuthURL: url,
}
return session, nil
}

// Debug is a no-op for the amazon package.
func (p *Provider) Debug(debug bool) {}

// RefreshToken get new access token based on the refresh token
func (p *Provider) RefreshToken(refreshToken string) (*oauth2.Token, error) {
token := &oauth2.Token{RefreshToken: refreshToken}
ts := p.config.TokenSource(goth.ContextForClient(p.Client()), token)
newToken, err := ts.Token()
if err != nil {
return nil, err
}
return newToken, err
}

// RefreshTokenAvailable refresh token is provided by Feishu
func (p *Provider) RefreshTokenAvailable() bool {
return true
}

type feishuUser struct {
Name string `json:"name"`
EnName string `json:"en_name"`
AvatarURL string `json:"avatar_url"`
AvatarThumb string `json:"avatar_thumb"`
AvatarMiddle string `json:"avatar_middle"`
AvatarBig string `json:"avatar_big"`
OpenID string `json:"open_id"`
UnionID string `json:"union_id"`
Email string `json:"email,omitempty"`
EnterpriseEmail string `json:"enterprise_email,omitempty"`
UserID string `json:"user_id,omitempty"`
Mobile string `json:"mobile,omitempty"`
TenantKey string `json:"tenant_key"`
EmployeeNo string `json:"employee_no,omitempty"`
}

// FetchUser will go to Feishu and access basic information about the user.
func (p *Provider) FetchUser(session goth.Session) (goth.User, error) {
sess := session.(*Session)
user := goth.User{
AccessToken: sess.AccessToken,
Provider: p.Name(),
RefreshToken: sess.RefreshToken,
ExpiresAt: sess.ExpiresAt,
}

if user.AccessToken == "" {
// data is not yet retrieved since accessToken is still empty
return user, fmt.Errorf("%s cannot get user information without accessToken", p.providerName)
}

// Get user information
reqProfile, err := http.NewRequest("GET", p.ProfileURL, nil)
if err != nil {
return user, err
}

reqProfile.Header.Add("Authorization", fmt.Sprintf("Bearer %s", user.AccessToken))
reqProfile.Header.Add("Content-Type", "application/json")

response, err := p.Client().Do(reqProfile)
if err != nil {
return user, err
}
defer response.Body.Close()

if response.StatusCode != http.StatusOK {
return user, fmt.Errorf("%s responded with a %d trying to fetch user information", p.providerName, response.StatusCode)
}

bits, err := io.ReadAll(response.Body)
if err != nil {
return user, err
}

resBody := struct {
Code int `json:"code"`
Msg string `json:"msg"`
Data map[string]interface{} `json:"data"`
}{}
err = json.Unmarshal(bits, &resBody)
if err != nil {
return user, err
}
if resBody.Code != 0 {
return user, fmt.Errorf("%s", resBody.Msg)
}

dataBits, err := json.Marshal(resBody.Data)
if err != nil {
return user, err
}

err = userFromReader(bytes.NewReader(dataBits), &user)
return user, err
}

func userFromReader(r io.Reader, user *goth.User) error {
// Extract user fields directly
u := feishuUser{}
err := json.NewDecoder(r).Decode(&u)
if err != nil {
return err
}
bits, _ := json.Marshal(u)
json.NewDecoder(bytes.NewReader(bits)).Decode(&user.RawData)

// Populate user struct
user.Email = u.EnterpriseEmail
user.Name = u.Name
user.NickName = u.Name
user.UserID = u.OpenID
user.AvatarURL = u.AvatarURL

return nil
}
53 changes: 53 additions & 0 deletions providers/feishu/feishu_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package feishu_test

import (
"os"
"testing"

"github.com/markbates/goth"
"github.com/markbates/goth/providers/feishu"
"github.com/stretchr/testify/assert"
)

func Test_New(t *testing.T) {
t.Parallel()
a := assert.New(t)
p := provider()

a.Equal(p.ClientKey, os.Getenv("FEISHU_KEY"))
a.Equal(p.Secret, os.Getenv("FEISHU_SECRET"))
a.Equal(p.CallbackURL, "/foo")
}

func Test_Implements_Provider(t *testing.T) {
t.Parallel()
a := assert.New(t)
a.Implements((*goth.Provider)(nil), provider())
}

func Test_BeginAuth(t *testing.T) {
t.Parallel()
a := assert.New(t)
p := provider()
session, err := p.BeginAuth("test_state")
s := session.(*feishu.Session)
a.NoError(err)
a.Contains(s.AuthURL, "accounts.feishu.cn/open-apis/authen/v1/authorize")
}

func Test_SessionFromJSON(t *testing.T) {
t.Parallel()
a := assert.New(t)

p := provider()
session, err := p.UnmarshalSession(`{"AuthURL":"https://open.larksuite.cn/open-apis/authen/v2/oauth/authorize","AccessToken":"1234567890"}`)
a.NoError(err)

s := session.(*feishu.Session)
a.Equal(s.AuthURL, "https://open.larksuite.cn/open-apis/authen/v2/oauth/authorize")
a.Equal(s.AccessToken, "1234567890")
}

func provider() *feishu.Provider {
return feishu.New(os.Getenv("FEISHU_KEY"), os.Getenv("FEISHU_SECRET"), "/foo")
}
61 changes: 61 additions & 0 deletions providers/feishu/session.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package feishu

import (
"encoding/json"
"errors"
"strings"
"time"

"github.com/markbates/goth"
)

type Session struct {
AuthURL string
AccessToken string
RefreshToken string
ExpiresAt time.Time
RefreshTokenExpiresAt time.Time
}

func (s Session) GetAuthURL() (string, error) {
if s.AuthURL == "" {
return "", errors.New(goth.NoAuthUrlErrorMessage)
}
return s.AuthURL, nil
}

// Marshal the session into a string
func (s Session) Marshal() string {
b, _ := json.Marshal(s)
return string(b)
}

// UnmarshalSession will unmarshal a JSON string into a session.
func (p *Provider) UnmarshalSession(data string) (goth.Session, error) {
sess := &Session{}
err := json.NewDecoder(strings.NewReader(data)).Decode(sess)
return sess, err
}

func (s *Session) Authorize(provider goth.Provider, params goth.Params) (string, error) {
p := provider.(*Provider)
token, err := p.config.Exchange(goth.ContextForClient(p.Client()), params.Get("code"))
if err != nil {
return "", err
}

if !token.Valid() {
return "", errors.New("Invalid token received from provider")
}

s.AccessToken = token.AccessToken
s.RefreshToken = token.RefreshToken
s.ExpiresAt = token.Expiry

refreshTokenExpiresAt := token.Extra("refresh_token_expires_in")
if refreshTokenExpiresAt2, ok := refreshTokenExpiresAt.(int); ok {
s.RefreshTokenExpiresAt = time.Now().Add(time.Second * time.Duration(refreshTokenExpiresAt2))
}

return token.AccessToken, err
}
49 changes: 49 additions & 0 deletions providers/feishu/session_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package feishu_test

import (
"testing"
"time"

"github.com/markbates/goth"
"github.com/markbates/goth/providers/feishu"
"github.com/stretchr/testify/assert"
)

func Test_Implements_Session(t *testing.T) {
t.Parallel()
a := assert.New(t)
s := &feishu.Session{}

a.Implements((*goth.Session)(nil), s)
}

func Test_GetAuthURL(t *testing.T) {
t.Parallel()
a := assert.New(t)
s := &feishu.Session{}

_, err := s.GetAuthURL()
a.Error(err)

s.AuthURL = "/foo"

url, _ := s.GetAuthURL()
a.Equal(url, "/foo")
}

func Test_ToJSON(t *testing.T) {
t.Parallel()
a := assert.New(t)
s := &feishu.Session{}

data := s.Marshal()
a.Equal(data, `{"AuthURL":"","AccessToken":"","RefreshToken":"","ExpiresAt":"0001-01-01T00:00:00Z","RefreshTokenExpiresAt":"0001-01-01T00:00:00Z"}`)
}

func Test_GetExpiresAt(t *testing.T) {
t.Parallel()
a := assert.New(t)
s := &feishu.Session{}

a.Equal(s.ExpiresAt, time.Time{})
}
Loading