diff --git a/internal/auth/domain.go b/internal/auth/domain.go index 228e84b04..76964854f 100644 --- a/internal/auth/domain.go +++ b/internal/auth/domain.go @@ -39,6 +39,8 @@ type UserInfo struct { ID model.UserID GroupID user.GroupID Permission Permission + Scope Scope + Legacy bool } // Auth is the basic authorization represent a user. @@ -48,6 +50,18 @@ type Auth struct { ID model.UserID // user id GroupID user.GroupID Permission Permission + Scope Scope + Legacy bool +} + +type Scope map[string]bool + +func (u Auth) HasScope(s string) bool { + if u.Legacy || u.Scope == nil { + return true + } + + return u.Scope[s] } const nsfwThreshold = gtime.OneDay * 60 diff --git a/internal/auth/domain_test.go b/internal/auth/domain_test.go index 257d27eda..89c44531b 100644 --- a/internal/auth/domain_test.go +++ b/internal/auth/domain_test.go @@ -47,3 +47,31 @@ func TestNotAllowNsfw(t *testing.T) { require.False(t, u.AllowNSFW()) } + +func TestAuthHasScope(t *testing.T) { + t.Parallel() + + u := auth.Auth{ + Scope: auth.Scope{ + "write:collection": true, + }, + } + + require.True(t, u.HasScope("write:collection")) + require.False(t, u.HasScope("write:indices")) +} + +func TestAuthHasScopeLegacy(t *testing.T) { + t.Parallel() + + u := auth.Auth{Legacy: true} + require.True(t, u.HasScope("write:collection")) + require.True(t, u.HasScope("any:scope")) +} + +func TestAuthHasScopeNilScopeCompatible(t *testing.T) { + t.Parallel() + + u := auth.Auth{} + require.True(t, u.HasScope("write:collection")) +} diff --git a/internal/auth/mysql_repository.go b/internal/auth/mysql_repository.go index 9e5ede854..55892f36a 100644 --- a/internal/auth/mysql_repository.go +++ b/internal/auth/mysql_repository.go @@ -17,6 +17,7 @@ package auth import ( "context" "database/sql" + "encoding/json" "errors" "time" @@ -47,10 +48,11 @@ type mysqlRepo struct { func (m mysqlRepo) GetByToken(ctx context.Context, token string) (UserInfo, error) { var access struct { - UserID string `db:"user_id"` + UserID string `db:"user_id"` + Scope sql.NullString `db:"scope"` } err := m.db.GetContext(ctx, &access, - `select user_id from chii_oauth_access_tokens + `select user_id, scope from chii_oauth_access_tokens where access_token = BINARY ? and expires > ? limit 1`, token, time.Now()) if err != nil { if errors.Is(err, sql.ErrNoRows) { @@ -87,14 +89,31 @@ func (m mysqlRepo) GetByToken(ctx context.Context, token string) (UserInfo, erro return UserInfo{}, errgo.Wrap(err, "parsing permission") } + scope, legacy := parseTokenScope(access.Scope) + return UserInfo{ RegTime: time.Unix(u.Regdate, 0), ID: id, GroupID: u.GroupID, Permission: perm, + Scope: scope, + Legacy: legacy, }, nil } +func parseTokenScope(scope sql.NullString) (Scope, bool) { + if !scope.Valid || scope.String == "" { + return nil, true + } + + var parsed map[string]bool + if err := json.Unmarshal([]byte(scope.String), &parsed); err != nil { + return Scope{}, false + } + + return parsed, false +} + func (m mysqlRepo) GetPermission(ctx context.Context, groupID uint8) (Permission, error) { r, err := m.q.UserGroup.WithContext(ctx).Where(m.q.UserGroup.ID.Eq(groupID)).Take() if err != nil { diff --git a/internal/auth/mysql_repository_scope_internal_test.go b/internal/auth/mysql_repository_scope_internal_test.go new file mode 100644 index 000000000..fb6d96c93 --- /dev/null +++ b/internal/auth/mysql_repository_scope_internal_test.go @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: AGPL-3.0-only + +package auth + +import ( + "database/sql" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParseTokenScope_LegacyNull(t *testing.T) { + t.Parallel() + + scope, legacy := parseTokenScope(sql.NullString{}) + require.True(t, legacy) + require.Nil(t, scope) +} + +func TestParseTokenScope_LegacyEmptyString(t *testing.T) { + t.Parallel() + + scope, legacy := parseTokenScope(sql.NullString{Valid: true, String: ""}) + require.True(t, legacy) + require.Nil(t, scope) +} + +func TestParseTokenScope_Object(t *testing.T) { + t.Parallel() + + scope, legacy := parseTokenScope(sql.NullString{Valid: true, String: `{"write:collection":true,"write:indices":false}`}) + require.False(t, legacy) + require.Equal(t, Scope{"write:collection": true, "write:indices": false}, scope) +} + +func TestParseTokenScope_NonObject(t *testing.T) { + t.Parallel() + + scope, legacy := parseTokenScope(sql.NullString{Valid: true, String: `["write:collection"]`}) + require.False(t, legacy) + require.Empty(t, scope) +} diff --git a/internal/auth/service.go b/internal/auth/service.go index 2382ef25a..5479e380d 100644 --- a/internal/auth/service.go +++ b/internal/auth/service.go @@ -76,6 +76,8 @@ func (s service) GetByToken(ctx context.Context, token string) (Auth, error) { ID: a.ID, GroupID: a.GroupID, Permission: permission.Merge(a.Permission), + Scope: a.Scope, + Legacy: a.Legacy, }, nil } diff --git a/openapi/v0.yaml b/openapi/v0.yaml index 8134dd7c2..44250854b 100644 --- a/openapi/v0.yaml +++ b/openapi/v0.yaml @@ -742,13 +742,15 @@ paths: schema: "$ref": "#/components/schemas/ErrorDetail" security: - - HTTPBearer: [] + - HTTPBearer: + - write:collection delete: tags: - 角色 summary: Uncollect character for current user operationId: uncollectCharacterByCharacterIdAndUserId - description: 为当前用户取消收藏角色 + description: | + 为当前用户取消收藏角色 parameters: - $ref: "#/components/parameters/path_character_id" responses: @@ -773,7 +775,8 @@ paths: schema: "$ref": "#/components/schemas/ErrorDetail" security: - - HTTPBearer: [] + - HTTPBearer: + - write:collection "/v0/persons/{person_id}": get: @@ -931,7 +934,8 @@ paths: schema: "$ref": "#/components/schemas/ErrorDetail" security: - - OptionalHTTPBearer: [] + - HTTPBearer: + - write:collection delete: tags: - 人物 @@ -962,7 +966,8 @@ paths: schema: "$ref": "#/components/schemas/ErrorDetail" security: - - OptionalHTTPBearer: [] + - HTTPBearer: + - write:collection "/v0/users/{username}": get: @@ -1198,7 +1203,8 @@ paths: schema: "$ref": "#/components/schemas/ErrorDetail" security: - - OptionalHTTPBearer: [] + - HTTPBearer: + - write:collection patch: tags: - 收藏 @@ -1239,7 +1245,8 @@ paths: schema: "$ref": "#/components/schemas/ErrorDetail" security: - - OptionalHTTPBearer: [] + - HTTPBearer: + - write:collection "/v0/users/-/collections/{subject_id}/episodes": get: @@ -1348,7 +1355,8 @@ paths: schema: "$ref": "#/components/schemas/ErrorDetail" security: - - HTTPBearer: [] + - HTTPBearer: + - write:collection "/v0/users/-/collections/-/episodes/{episode_id}": get: @@ -1424,7 +1432,8 @@ paths: schema: "$ref": "#/components/schemas/ErrorDetail" security: - - HTTPBearer: [] + - HTTPBearer: + - write:collection "/v0/users/{username}/collections/-/characters": get: @@ -1782,7 +1791,8 @@ paths: schema: "$ref": "#/components/schemas/ErrorDetail" security: - - HTTPBearer: [] + - HTTPBearer: + - write:indices "/v0/indices/{index_id}": get: tags: @@ -1828,7 +1838,8 @@ paths: "404": "$ref": "#/components/responses/404" security: - - HTTPBearer: [] + - HTTPBearer: + - write:indices "/v0/indices/{index_id}/subjects": get: tags: @@ -1876,7 +1887,8 @@ paths: "404": "$ref": "#/components/responses/404" security: - - HTTPBearer: [] + - HTTPBearer: + - write:indices "/v0/indices/{index_id}/subjects/{subject_id}": put: tags: @@ -1902,7 +1914,8 @@ paths: "400": "$ref": "#/components/responses/400" security: - - HTTPBearer: [] + - HTTPBearer: + - write:indices delete: tags: - 目录 @@ -1919,7 +1932,8 @@ paths: "401": "$ref": "#/components/responses/401" security: - - HTTPBearer: [] + - HTTPBearer: + - write:indices "/v0/indices/{index_id}/collect": post: tags: @@ -1939,13 +1953,15 @@ paths: "500": "$ref": "#/components/responses/500" security: - - HTTPBearer: [] + - HTTPBearer: + - write:collection delete: tags: - 目录 summary: Uncollect index for current user operationId: uncollectIndexByIndexIdAndUserId - description: 为当前用户取消收藏一条目录 + description: | + 为当前用户取消收藏一条目录 parameters: - $ref: "#/components/parameters/path_index_id" responses: @@ -1958,7 +1974,8 @@ paths: "500": "$ref": "#/components/responses/500" security: - - HTTPBearer: [] + - HTTPBearer: + - write:collection components: parameters: path_subject_id: @@ -3316,9 +3333,15 @@ components: description: 不强制要求用户认证,但是可能看不到某些敏感内容内容(如 NSFW 或者仅用户自己可见的收藏) scheme: Bearer HTTPBearer: - type: http - description: 需要使用 access token 进行认证 - scheme: Bearer + type: oauth2 + description: OAuth2 access token(写操作会校验 scope) + flows: + authorizationCode: + authorizationUrl: /oauth/authorize + tokenUrl: /oauth/access_token + scopes: + write:collection: 修改收藏相关数据 + write:indices: 修改目录及目录条目 responses: 200-no-content: description: Successful Response diff --git a/web/mw/middleware.go b/web/mw/middleware.go index e29d9f746..a5c6d5cba 100644 --- a/web/mw/middleware.go +++ b/web/mw/middleware.go @@ -22,6 +22,12 @@ import ( ) var errNeedLogin = res.Unauthorized("this API need authorization") +var errInsufficientScope = res.Forbidden("insufficient token scope") + +const ( + ScopeWriteCollection = "write:collection" + ScopeWriteIndices = "write:indices" +) func NeedLogin(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { @@ -32,3 +38,20 @@ func NeedLogin(next echo.HandlerFunc) echo.HandlerFunc { return next(c) } } + +func NeedScope(scope string) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + u := accessor.GetFromCtx(c) + if !u.Login { + return errNeedLogin + } + + if !u.HasScope(scope) { + return errInsufficientScope + } + + return next(c) + } + } +} diff --git a/web/mw/middleware_test.go b/web/mw/middleware_test.go new file mode 100644 index 000000000..fb0d7da32 --- /dev/null +++ b/web/mw/middleware_test.go @@ -0,0 +1,98 @@ +// SPDX-License-Identifier: AGPL-3.0-only + +package mw + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/require" + + "github.com/bangumi/server/internal/auth" + "github.com/bangumi/server/web/accessor" + "github.com/bangumi/server/web/internal/ctxkey" +) + +func TestNeedScope_NeedLogin(t *testing.T) { + t.Parallel() + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h := NeedScope(ScopeWriteCollection)(func(c echo.Context) error { + return nil + }) + + err := h(c) + require.ErrorIs(t, err, errNeedLogin) +} + +func TestNeedScope_Insufficient(t *testing.T) { + t.Parallel() + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + a := &accessor.Accessor{} + a.SetAuth(auth.Auth{Scope: auth.Scope{}}) + c.Set(ctxkey.User, a) + + h := NeedScope(ScopeWriteCollection)(func(c echo.Context) error { + return nil + }) + + err := h(c) + require.ErrorIs(t, err, errInsufficientScope) +} + +func TestNeedScope_Match(t *testing.T) { + t.Parallel() + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + a := &accessor.Accessor{} + a.SetAuth(auth.Auth{Scope: auth.Scope{ScopeWriteCollection: true}}) + c.Set(ctxkey.User, a) + + reached := false + h := NeedScope(ScopeWriteCollection)(func(c echo.Context) error { + reached = true + return nil + }) + + err := h(c) + require.NoError(t, err) + require.True(t, reached) +} + +func TestNeedScope_Legacy(t *testing.T) { + t.Parallel() + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + a := &accessor.Accessor{} + a.SetAuth(auth.Auth{Legacy: true}) + c.Set(ctxkey.User, a) + + reached := false + h := NeedScope(ScopeWriteCollection)(func(c echo.Context) error { + reached = true + return nil + }) + + err := h(c) + require.NoError(t, err) + require.True(t, reached) +} diff --git a/web/routes.go b/web/routes.go index f1af85e29..284541325 100644 --- a/web/routes.go +++ b/web/routes.go @@ -62,7 +62,7 @@ func AddRouters( v0.GET("/persons/:id/image", personHandler.GetImage) v0.GET("/persons/:id/subjects", personHandler.GetRelatedSubjects) v0.GET("/persons/:id/characters", personHandler.GetRelatedCharacters) - v0.POST("/persons/:id/collect", personHandler.CollectPerson, mw.NeedLogin) + v0.POST("/persons/:id/collect", personHandler.CollectPerson, mw.NeedLogin, mw.NeedScope(mw.ScopeWriteCollection)) // TODO: wait for soft delete // v0.DELETE("/persons/:id/collect", personHandler.UncollectPerson, mw.NeedLogin) @@ -70,7 +70,7 @@ func AddRouters( v0.GET("/characters/:id/image", characterHandler.GetImage) v0.GET("/characters/:id/subjects", characterHandler.GetRelatedSubjects) v0.GET("/characters/:id/persons", characterHandler.GetRelatedPersons) - v0.POST("/characters/:id/collect", characterHandler.CollectCharacter, mw.NeedLogin) + v0.POST("/characters/:id/collect", characterHandler.CollectCharacter, mw.NeedLogin, mw.NeedScope(mw.ScopeWriteCollection)) // TODO: wait for soft delete // v0.DELETE("/characters/:id/collect", characterHandler.UncollectCharacter, mw.NeedLogin) @@ -85,12 +85,15 @@ func AddRouters( v0.GET("/users/:username/collections/:subject_id", userHandler.GetSubjectCollection) v0.GET("/users/-/collections/-/episodes/:episode_id", userHandler.GetEpisodeCollection, mw.NeedLogin) - v0.PUT("/users/-/collections/-/episodes/:episode_id", userHandler.PutEpisodeCollection, req.JSON, mw.NeedLogin) + v0.PUT("/users/-/collections/-/episodes/:episode_id", userHandler.PutEpisodeCollection, + req.JSON, mw.NeedLogin, mw.NeedScope(mw.ScopeWriteCollection)) v0.GET("/users/-/collections/:subject_id/episodes", userHandler.GetSubjectEpisodeCollection, mw.NeedLogin) - v0.PATCH("/users/-/collections/:subject_id", userHandler.PatchSubjectCollection, req.JSON, mw.NeedLogin) - v0.POST("/users/-/collections/:subject_id", userHandler.PostSubjectCollection, req.JSON, mw.NeedLogin) + v0.PATCH("/users/-/collections/:subject_id", userHandler.PatchSubjectCollection, + req.JSON, mw.NeedLogin, mw.NeedScope(mw.ScopeWriteCollection)) + v0.POST("/users/-/collections/:subject_id", userHandler.PostSubjectCollection, + req.JSON, mw.NeedLogin, mw.NeedScope(mw.ScopeWriteCollection)) v0.PATCH("/users/-/collections/:subject_id/episodes", - userHandler.PatchEpisodeCollectionBatch, req.JSON, mw.NeedLogin) + userHandler.PatchEpisodeCollectionBatch, req.JSON, mw.NeedLogin, mw.NeedScope(mw.ScopeWriteCollection)) v0.GET("/users/:username/collections/-/characters", userHandler.ListCharacterCollection) v0.GET("/users/:username/collections/-/characters/:character_id", userHandler.GetCharacterCollection) @@ -102,15 +105,18 @@ func AddRouters( v0.GET("/indices/:id", i.GetIndex) v0.GET("/indices/:id/subjects", i.GetIndexSubjects) // indices - v0.POST("/indices", i.NewIndex, req.JSON, mw.NeedLogin) - v0.PUT("/indices/:id", i.UpdateIndex, req.JSON, mw.NeedLogin) + v0.POST("/indices", i.NewIndex, req.JSON, mw.NeedLogin, mw.NeedScope(mw.ScopeWriteIndices)) + v0.PUT("/indices/:id", i.UpdateIndex, req.JSON, mw.NeedLogin, mw.NeedScope(mw.ScopeWriteIndices)) // indices subjects - v0.POST("/indices/:id/subjects", i.AddIndexSubject, req.JSON, mw.NeedLogin) - v0.PUT("/indices/:id/subjects/:subject_id", i.UpdateIndexSubject, req.JSON, mw.NeedLogin) - v0.DELETE("/indices/:id/subjects/:subject_id", i.RemoveIndexSubject, mw.NeedLogin) + v0.POST("/indices/:id/subjects", i.AddIndexSubject, + req.JSON, mw.NeedLogin, mw.NeedScope(mw.ScopeWriteIndices)) + v0.PUT("/indices/:id/subjects/:subject_id", i.UpdateIndexSubject, + req.JSON, mw.NeedLogin, mw.NeedScope(mw.ScopeWriteIndices)) + v0.DELETE("/indices/:id/subjects/:subject_id", + i.RemoveIndexSubject, mw.NeedLogin, mw.NeedScope(mw.ScopeWriteIndices)) // collect - v0.POST("/indices/:id/collect", i.CollectIndex, mw.NeedLogin) - v0.DELETE("/indices/:id/collect", i.UncollectIndex, mw.NeedLogin) + v0.POST("/indices/:id/collect", i.CollectIndex, mw.NeedLogin, mw.NeedScope(mw.ScopeWriteCollection)) + v0.DELETE("/indices/:id/collect", i.UncollectIndex, mw.NeedLogin, mw.NeedScope(mw.ScopeWriteCollection)) } v0.GET("/revisions/persons/:id", h.GetPersonRevision)